"""Unit tests for the live-mode path of ``OBFeatureEngine`` and for timestamp
handling in ``HZOBProvider``.

The engine tests drive ``OBFeatureEngine.step_live`` against a
``MagicMock(spec=HZOBProvider)`` and then inspect the engine's caches and
private counters.  The provider tests build an ``HZOBProvider``, replace its
``_imap`` attribute with a mock, and check the ``timestamp`` of the snapshot
returned by ``get_snapshot``.
"""
import sys
import os
import unittest
import json
import time
from pathlib import Path

# Add correctly mapped paths for the ND system
ROOT_DIR = Path(__file__).parent.parent.parent
sys.path.insert(0, str(ROOT_DIR / "nautilus_dolphin"))
sys.path.insert(0, str(ROOT_DIR))

import numpy as np
import logging
from unittest.mock import MagicMock, patch
from collections import deque
from datetime import datetime, timezone

from nautilus_dolphin.nautilus.ob_features import (
    OBFeatureEngine, OBPlacementFeatures, OBSignalFeatures, OBMacroFeatures,
    NEUTRAL_PLACEMENT, NEUTRAL_SIGNAL, NEUTRAL_MACRO
)
from nautilus_dolphin.nautilus.ob_provider import OBSnapshot
from nautilus_dolphin.nautilus.hz_ob_provider import HZOBProvider

# Number of book levels used by every fixture in this module.
_N_LEVELS = 5


def _flat_snapshot(asset, level_notional):
    """Build an ``OBSnapshot`` with ``level_notional`` at every level.

    Bid and ask notionals are identical (separate arrays), and both depth
    arrays are all ones.  The snapshot is stamped with the current wall-clock
    time.
    """
    per_level = [level_notional] * _N_LEVELS
    return OBSnapshot(
        time.time(),
        asset,
        np.array(per_level),
        np.array(per_level),
        np.ones(_N_LEVELS),
        np.ones(_N_LEVELS),
    )


def _hz_payload(timestamp):
    """Return the JSON string used as the mocked ``_imap.get`` result.

    ``timestamp`` is embedded as-is (ISO string or float); all four book
    arrays are five ``1.0`` values.
    """
    return json.dumps({
        "timestamp": timestamp,
        "bid_notional": [1.0] * _N_LEVELS,
        "ask_notional": [1.0] * _N_LEVELS,
        "bid_depth": [1.0] * _N_LEVELS,
        "ask_depth": [1.0] * _N_LEVELS,
    })


class TestHZOBProviderLive(unittest.TestCase):
    """Live-mode tests for ``OBFeatureEngine`` plus ``HZOBProvider`` timestamps."""

    def setUp(self):
        """Create an engine wired to a spec'd mock provider.

        ``MagicMock(spec=HZOBProvider)`` means no real provider instance is
        constructed for the engine tests.
        """
        self.mock_provider = MagicMock(spec=HZOBProvider)
        self.engine = OBFeatureEngine(self.mock_provider)

    def test_step_live_fetches_snapshots(self):
        """Test that step_live calls provider.get_snapshot for all assets."""
        assets = ["BTCUSDT", "ETHUSDT"]
        self.mock_provider.get_snapshot.return_value = None

        self.engine.step_live(assets, bar_idx=100)

        self.assertEqual(self.mock_provider.get_snapshot.call_count, 2)
        self.assertTrue(self.engine._live_mode)
        self.assertEqual(self.engine._live_bar_idx, 100)

    def test_step_live_populates_placement_cache(self):
        """Test that placement features are correctly computed and cached in live mode."""
        asset = "BTCUSDT"
        snap = OBSnapshot(
            timestamp=time.time(),
            asset=asset,
            bid_notional=np.array([1000.0, 2000.0, 3000.0, 4000.0, 5000.0]),
            ask_notional=np.array([1100.0, 2100.0, 3100.0, 4100.0, 5100.0]),
            bid_depth=np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
            ask_depth=np.array([1.1, 2.1, 3.1, 4.1, 5.1])
        )
        self.mock_provider.get_snapshot.return_value = snap

        self.engine.step_live([asset], bar_idx=5)

        placement = self.engine.get_placement(asset, 5)
        self.assertAlmostEqual(placement.depth_1pct_usd, 2100.0)  # 1000 + 1100
        self.assertGreater(placement.fill_probability, 0.5)

    def test_step_live_populates_signal_cache(self):
        """Test that signal features (imbalance, persistence) are computed in live mode."""
        asset = "BTCUSDT"
        # Snapshot with heavy bid exposure (imbalance > 0)
        snap = OBSnapshot(
            timestamp=time.time(),
            asset=asset,
            bid_notional=np.array([5000.0, 0, 0, 0, 0]),
            ask_notional=np.array([1000.0, 0, 0, 0, 0]),
            bid_depth=np.ones(5),
            ask_depth=np.ones(5)
        )
        self.mock_provider.get_snapshot.return_value = snap

        # Step twice to check histories
        self.engine.step_live([asset], bar_idx=10)
        self.engine.step_live([asset], bar_idx=11)

        signal = self.engine.get_signal(asset, 11)
        self.assertAlmostEqual(signal.imbalance, (5000-1000)/(5000+1000))
        self.assertEqual(signal.imbalance_persistence, 1.0)  # both positive

    def test_step_live_market_features(self):
        """Test cross-asset agreement and cascade signal."""
        assets = ["BTCUSDT", "ETHUSDT"]
        # BTC withdrawing (vel < -0.1), ETH building (vel > 0)
        snaps = {
            "BTCUSDT": [
                _flat_snapshot("BTCUSDT", 2000.0),
                _flat_snapshot("BTCUSDT", 1000.0),
            ],
            "ETHUSDT": [
                _flat_snapshot("ETHUSDT", 1000.0),
                _flat_snapshot("ETHUSDT", 1200.0),
            ],
        }

        def side_effect(asset, ts):
            # Serve the snapshot for whichever step is currently being run.
            return snaps[asset][self._snap_idx]

        self.mock_provider.get_snapshot.side_effect = side_effect

        # Step 0 serves the first snapshot of each asset, step 1 the second.
        for step in (0, 1):
            self._snap_idx = step
            self.engine.step_live(assets, bar_idx=step)

        macro = self.engine.get_macro(1)
        # BTC vel = (2000-4000)/4000 = -0.5
        # ETH vel = (2400-2000)/2000 = +0.2
        # cascade count should be 1 if threshold is -0.1
        self.assertEqual(macro.cascade_count, 1)

    def test_step_live_none_snapshot_skipped(self):
        """Test that None snapshots are skipped without error."""
        self.mock_provider.get_snapshot.return_value = None

        self.engine.step_live(["BTCUSDT"], bar_idx=20)

        self.assertEqual(self.engine._live_stale_count, 1)

    def test_step_live_stale_warning(self):
        """Test that stale count increments correctly."""
        self.mock_provider.get_snapshot.return_value = None

        for i in range(3):
            self.engine.step_live(["BTCUSDT"], bar_idx=i)

        self.assertEqual(self.engine._live_stale_count, 3)

    def test_step_live_cache_eviction(self):
        """Test that live caches are evicted after MAX_LIVE_CACHE entries."""
        asset = "BTCUSDT"
        self.mock_provider.get_snapshot.return_value = _flat_snapshot(asset, 1000.0)

        # 505 steps; the assertions below expect only the newest 500 to remain.
        for i in range(505):
            self.engine.step_live([asset], bar_idx=i)

        self.assertEqual(len(self.engine._live_placement[asset]), 500)
        self.assertNotIn(0, self.engine._live_placement[asset])
        self.assertIn(504, self.engine._live_placement[asset])

    def test_resolve_idx_live_mode(self):
        """Test index resolution in live mode."""
        self.engine._live_mode = True
        self.engine._live_placement["BTCUSDT"] = {10: MagicMock()}

        idx = self.engine._resolve_idx("BTCUSDT", 10.0)

        self.assertEqual(idx, 10)

    def test_resolve_idx_live_fallback(self):
        """Test fallback to latest bar in live mode."""
        self.engine._live_mode = True
        self.engine._live_placement["BTCUSDT"] = {10: MagicMock(), 15: MagicMock()}

        idx = self.engine._resolve_idx("BTCUSDT", 20.0)  # unknown bar

        self.assertEqual(idx, 15)

    def test_median_depth_ema(self):
        """Test that _median_depth_ref converges via EMA."""
        asset = "BTCUSDT"

        # Init with 2000
        self.mock_provider.get_snapshot.return_value = _flat_snapshot(asset, 1000.0)
        self.engine.step_live([asset], bar_idx=0)
        self.assertEqual(self.engine._median_depth_ref[asset], 2000.0)

        # Next value 4000
        self.mock_provider.get_snapshot.return_value = _flat_snapshot(asset, 2000.0)
        self.engine.step_live([asset], bar_idx=1)
        # 0.99 * 2000 + 0.01 * 4000 = 1980 + 40 = 2020
        self.assertAlmostEqual(self.engine._median_depth_ref[asset], 2020.0)

    def test_hz_ob_provider_timestamp_iso(self):
        """Test ISO string normalization in HZOBProvider."""
        # NOTE(review): HZOBProvider() is constructed for real here; presumably
        # its constructor performs no network I/O - confirm.
        provider = HZOBProvider()
        mock_imap = MagicMock()
        provider._imap = mock_imap

        iso_ts = "2026-03-26T12:00:00+00:00"
        expected_ts = datetime.fromisoformat(iso_ts).replace(tzinfo=timezone.utc).timestamp()
        mock_imap.get.return_value = _hz_payload(iso_ts)

        snap = provider.get_snapshot("BTCUSDT", time.time())

        self.assertEqual(snap.timestamp, expected_ts)

    def test_hz_ob_provider_timestamp_float(self):
        """Test float timestamp pass-through in HZOBProvider."""
        # NOTE(review): HZOBProvider() is constructed for real here; presumably
        # its constructor performs no network I/O - confirm.
        provider = HZOBProvider()
        mock_imap = MagicMock()
        provider._imap = mock_imap

        float_ts = 1711454400.0
        mock_imap.get.return_value = _hz_payload(float_ts)

        snap = provider.get_snapshot("BTCUSDT", time.time())

        self.assertEqual(snap.timestamp, float_ts)


if __name__ == "__main__":
    unittest.main()