264 lines
12 KiB
Python
264 lines
12 KiB
Python
|
|
"""Real Zinc-backed hot-path plane for DITAv2.
|
||
|
|
|
||
|
|
This wrapper uses the Zinc Python adapter directly. The kernel still talks to
|
||
|
|
the narrow ``ZincPlane`` interface; this module just makes that interface real.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations

import json
import os
import struct
import sys
import threading
from dataclasses import asdict, fields, is_dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from .contracts import KernelIntent, TradeSide, TradeSlot, TradeStage, VenueOrder, VenueOrderStatus
from .control import KernelControlSnapshot
|
||
|
|
|
||
|
|
_ZINC_ADAPTER_PATH = Path(__file__).resolve().parents[3] / "zinc" / "adapters" / "python"
|
||
|
|
if _ZINC_ADAPTER_PATH.exists() and str(_ZINC_ADAPTER_PATH) not in sys.path:
|
||
|
|
sys.path.insert(0, str(_ZINC_ADAPTER_PATH))
|
||
|
|
|
||
|
|
try: # pragma: no cover - exercised in integration tests
|
||
|
|
from zinc import SharedRegion
|
||
|
|
except Exception as exc: # pragma: no cover
|
||
|
|
SharedRegion = None # type: ignore[assignment]
|
||
|
|
_ZINC_IMPORT_ERROR = exc
|
||
|
|
else:
|
||
|
|
_ZINC_IMPORT_ERROR = None
|
||
|
|
|
||
|
|
|
||
|
|
class RealZincUnavailable(RuntimeError):
    """Signals that ``zinc.SharedRegion`` could not be imported, so no real plane can be built."""
|
||
|
|
|
||
|
|
|
||
|
|
def require_real_zinc() -> None:
    """Fail fast when the Zinc adapter could not be imported at module load.

    Raises:
        RealZincUnavailable: if ``SharedRegion`` is ``None``. The message is the
            text of the stored import error, and that error is chained as
            ``__cause__`` so its original traceback is not lost.
    """
    if SharedRegion is None:
        # ``from`` keeps the real import failure (e.g. a missing native library)
        # visible in the traceback instead of only its ``str()``.
        raise RealZincUnavailable(str(_ZINC_IMPORT_ERROR)) from _ZINC_IMPORT_ERROR
|
||
|
|
|
||
|
|
|
||
|
|
def _json_default(value: Any) -> Any:
    """``json.dumps`` fallback for values the encoder cannot serialize natively.

    Conversions are tried in this order (the first that applies wins):

    1. anything with a ``value`` attribute (e.g. ``Enum`` members) -> ``value.value``;
    2. anything whose ``isoformat()`` call succeeds (e.g. ``datetime``) -> that string;
    3. objects with a ``__dict__`` -> a shallow copy of their attributes;
    4. dataclass instances without ``__dict__`` (``__slots__``) -> a shallow
       ``{field_name: value}`` dict;
    5. ``set`` / ``frozenset`` -> a list, sorted when the elements are orderable
       so the JSON is stable across runs (otherwise in iteration order).

    ``json`` recurses into the returned containers and calls this hook again
    for any nested value it cannot encode.

    Raises:
        TypeError: for any other value, as the ``json`` ``default`` protocol expects.
    """
    if hasattr(value, "value"):
        return value.value
    if hasattr(value, "isoformat"):
        try:
            return value.isoformat()
        except Exception:
            # Best effort: if the call fails (e.g. ``isoformat`` looked up on a
            # class rather than an instance), try the remaining conversions.
            pass
    if hasattr(value, "__dict__"):
        return dict(vars(value))
    if is_dataclass(value) and not isinstance(value, type):
        # Slotted dataclasses have no ``__dict__``; read their declared fields.
        return {field.name: getattr(value, field.name) for field in fields(value)}
    if isinstance(value, (set, frozenset)):
        try:
            return sorted(value)
        except TypeError:
            # Mixed, non-orderable element types: keep iteration order.
            return list(value)
    raise TypeError(f"Unsupported value: {type(value)!r}")
|
||
|
|
|
||
|
|
|
||
|
|
def _slot_to_payload(slot: TradeSlot) -> Dict[str, Any]:
    """Return the dict form of *slot*, exactly as produced by ``slot.to_dict()``."""
    return slot.to_dict()
|
||
|
|
|
||
|
|
|
||
|
|
def _venue_order_from_payload(order: Dict[str, Any], payload: Dict[str, Any]) -> VenueOrder:
    """Rebuild one ``VenueOrder`` from its serialized dict nested in a slot payload.

    ``internal_trade_id`` is taken from the enclosing slot's ``trade_id`` and
    ``intended_size`` falls back to the slot's ``size`` when the order omits it.
    A missing or ``None`` ``metadata`` entry yields an empty dict.
    """
    return VenueOrder(
        internal_trade_id=str(payload.get("trade_id", "")),
        venue_order_id=str(order.get("venue_order_id", "")),
        venue_client_id=str(order.get("venue_client_id", "")),
        side=TradeSide(str(order.get("side", TradeSide.FLAT.value))),
        intended_size=float(order.get("intended_size", payload.get("size", 0.0))),
        filled_size=float(order.get("filled_size", 0.0)),
        average_fill_price=float(order.get("average_fill_price", 0.0)),
        status=VenueOrderStatus(str(order.get("status", VenueOrderStatus.NEW.value))),
        metadata=dict(order.get("metadata") or {}),
    )


def _slot_from_payload(payload: Dict[str, Any]) -> TradeSlot:
    """Rebuild a ``TradeSlot`` from the dict form stored in the state region.

    Missing keys fall back to neutral defaults: empty strings, ``0`` / ``0.0``,
    ``TradeSide.FLAT``, ``TradeStage.IDLE``, ``closed=False`` and a single
    ``1.0`` exit leg. ``active_entry_order`` / ``active_exit_order`` are rebuilt
    only when stored as dicts; anything else becomes ``None``. ``entry_time``
    and ``last_event_time`` are parsed with ``datetime.fromisoformat`` when
    truthy, else ``None``. Collection fields stored as ``None``
    (``exit_leg_ratios``, ``seen_event_ids``, ``metadata``) are treated as
    missing rather than raising ``TypeError``.

    Conversion errors (e.g. ``ValueError`` from ``float`` or
    ``datetime.fromisoformat``) propagate to the caller.
    """
    entry_raw = payload.get("active_entry_order")
    exit_raw = payload.get("active_exit_order")
    active_entry_order = _venue_order_from_payload(entry_raw, payload) if isinstance(entry_raw, dict) else None
    active_exit_order = _venue_order_from_payload(exit_raw, payload) if isinstance(exit_raw, dict) else None

    exit_leg_ratios = payload.get("exit_leg_ratios")
    if exit_leg_ratios is None:
        # Same default as an absent key: one exit leg covering the whole size.
        exit_leg_ratios = (1.0,)

    return TradeSlot(
        slot_id=int(payload.get("slot_id", 0)),
        trade_id=str(payload.get("trade_id", "")),
        asset=str(payload.get("asset", "")),
        side=TradeSide(str(payload.get("side", TradeSide.FLAT.value))),
        entry_price=float(payload.get("entry_price", 0.0)),
        size=float(payload.get("size", 0.0)),
        initial_size=float(payload.get("initial_size", 0.0)),
        leverage=float(payload.get("leverage", 0.0)),
        entry_time=datetime.fromisoformat(payload["entry_time"]) if payload.get("entry_time") else None,
        unrealized_pnl=float(payload.get("unrealized_pnl", 0.0)),
        realized_pnl=float(payload.get("realized_pnl", 0.0)),
        closed=bool(payload.get("closed", False)),
        exit_leg_ratios=tuple(float(ratio) for ratio in exit_leg_ratios),
        active_leg_index=int(payload.get("active_leg_index", 0)),
        active_exit_order=active_exit_order,
        active_entry_order=active_entry_order,
        fsm_state=TradeStage(str(payload.get("fsm_state", TradeStage.IDLE.value))),
        close_reason=str(payload.get("close_reason", "")),
        last_event_time=(
            datetime.fromisoformat(payload["last_event_time"]) if payload.get("last_event_time") else None
        ),
        seen_event_ids=tuple(str(event_id) for event_id in payload.get("seen_event_ids") or ()),
        metadata=dict(payload.get("metadata") or {}),
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _encode_packet(seq: int, payload: Dict[str, Any]) -> bytes:
    """Frame *payload* as one region packet.

    Layout: an 8-byte big-endian unsigned sequence number, an 8-byte big-endian
    unsigned body length (``struct`` format ``"!QQ"``), then the UTF-8 JSON
    body. The JSON is compact (``","`` / ``":"`` separators), key-sorted, keeps
    non-ASCII characters, and routes non-native values through ``_json_default``.
    """
    body = json.dumps(
        payload,
        default=_json_default,
        ensure_ascii=False,
        separators=(",", ":"),
        sort_keys=True,
    ).encode("utf-8")
    header = struct.pack("!QQ", int(seq), len(body))
    return b"".join((header, body))
|
||
|
|
|
||
|
|
|
||
|
|
def _decode_packet(buf: memoryview) -> Dict[str, Any]:
    """Decode the packet written by ``_encode_packet`` at the start of *buf*.

    Returns the decoded JSON object with the packet's sequence number added
    under ``"_seq"``. Returns ``{}`` whenever *buf* does not hold a complete,
    well-formed packet:

    * fewer than 16 bytes (no room for the ``"!QQ"`` header);
    * a body length of 0 or larger than the bytes after the header;
    * a body that is not valid UTF-8 JSON (e.g. a read that overlapped a
      concurrent rewrite of the region);
    * JSON that is not an object.

    Callers can therefore treat ``{}`` uniformly as "nothing usable published".
    """
    if len(buf) < 16:
        return {}
    seq, size = struct.unpack_from("!QQ", buf, 0)
    if size <= 0 or size > len(buf) - 16:
        return {}
    try:
        out = json.loads(bytes(buf[16 : 16 + size]).decode("utf-8"))
    except (UnicodeDecodeError, ValueError):
        # ``json.JSONDecodeError`` is a ``ValueError``; a torn/corrupt body must
        # degrade like a bad header instead of crashing the reader.
        return {}
    if not isinstance(out, dict):
        return {}
    out["_seq"] = seq
    return out
|
||
|
|
|
||
|
|
|
||
|
|
class RealZincPlane:
    """Shared-memory Zinc plane used by the Python prototype.

    The plane is backed by three Zinc ``SharedRegion`` objects named
    ``<base>_intent``, ``<base>_state`` and ``<base>_control``, where ``<base>``
    is *prefix* with leading/trailing ``/`` stripped and inner ``/`` replaced
    by ``_``. Each region holds a single packet (``_encode_packet`` framing)
    that is rewritten in full on every update:

    * intent region  -> ``{"items": [<intent row>, ...]}`` (most recent rows only)
    * state region   -> ``{"slots": [<slot dict>, ...]}`` for ids ``0..slot_count-1``
    * control region -> ``{"control": <KernelControlSnapshot.as_dict()>}``

    Writers in this instance serialize through a per-instance
    ``threading.Lock`` (it does not coordinate other processes). The ``read_*``
    methods always decode the shared buffers rather than the local caches.
    The plane can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(
        self,
        *,
        prefix: str,
        slot_count: int = 10,
        intent_capacity: int = 1 << 20,
        state_capacity: int = 1 << 20,
        control_capacity: int = 1 << 20,
        create: bool = True,
        intent_history: int = 512,
    ) -> None:
        """Create (``create=True``) or attach to (``create=False``) the regions.

        Args:
            prefix: Name prefix shared by the three regions.
            slot_count: Number of trade slots published in the state region.
            intent_capacity: Size passed to ``SharedRegion.create`` for the
                intent region (presumably bytes -- confirm with the adapter).
            state_capacity: Same, for the state region.
            control_capacity: Same, for the control region.
            create: Create and seed new regions when true; otherwise open the
                existing ones and load their current contents into the caches.
            intent_history: Maximum number of most-recent intents kept locally
                and published in the intent region.

        Raises:
            RealZincUnavailable: if the Zinc adapter could not be imported.
            ValueError: if *intent_history* is smaller than 1.
        """
        require_real_zinc()
        if intent_history < 1:
            raise ValueError(f"intent_history must be >= 1, got {intent_history!r}")
        base = prefix.strip("/").replace("/", "_")
        self.intent_name = f"{base}_intent"
        self.state_name = f"{base}_state"
        self.control_name = f"{base}_control"
        self._intent_seq = 0
        self._state_seq = 0
        self._control_seq = 0
        self._intent_history = int(intent_history)
        self._lock = threading.Lock()
        self._slot_count = int(slot_count)
        self._slot_cache: Dict[int, TradeSlot] = {i: TradeSlot(slot_id=i) for i in range(self._slot_count)}
        self._intent_cache: List[Dict[str, Any]] = []
        self._control_cache = KernelControlSnapshot()
        if create:
            self._create_regions(intent_capacity, state_capacity, control_capacity)
        else:
            self._open_regions()

    def _create_regions(self, intent_capacity: int, state_capacity: int, control_capacity: int) -> None:
        """Create the three regions and seed each with an initial packet (seq 0)."""
        self.intent_region = SharedRegion.create(self.intent_name, intent_capacity)
        self.state_region = SharedRegion.create(self.state_name, state_capacity)
        self.control_region = SharedRegion.create(self.control_name, control_capacity)
        self._write_region(self.control_region, self._control_seq, {"control": self._control_cache.as_dict()})
        self._write_region(self.state_region, self._state_seq, self._slots_payload())
        self._write_region(self.intent_region, self._intent_seq, {"items": []})

    def _open_regions(self) -> None:
        """Open existing regions and load their published contents into the caches."""
        self.intent_region = SharedRegion.open(self.intent_name)
        self.state_region = SharedRegion.open(self.state_name)
        self.control_region = SharedRegion.open(self.control_name)
        control_payload = self._dict_or_empty(_decode_packet(self.control_region.as_buffer()))
        state_payload = self._dict_or_empty(_decode_packet(self.state_region.as_buffer()))
        intent_payload = self._dict_or_empty(_decode_packet(self.intent_region.as_buffer()))
        # Continue from the already-published sequence numbers; restarting at 0
        # would make this instance's first write look older than the current one.
        self._control_seq = int(control_payload.get("_seq", 0))
        self._state_seq = int(state_payload.get("_seq", 0))
        self._intent_seq = int(intent_payload.get("_seq", 0))
        if isinstance(control_payload.get("control"), dict):
            self._control_cache = KernelControlSnapshot(**control_payload["control"])
        if isinstance(state_payload.get("slots"), list):
            for slot_payload in state_payload["slots"]:
                if isinstance(slot_payload, dict):
                    slot = _slot_from_payload(slot_payload)
                    self._slot_cache[int(slot.slot_id)] = slot
        if isinstance(intent_payload.get("items"), list):
            self._intent_cache = list(intent_payload["items"])[-self._intent_history :]

    @staticmethod
    def _dict_or_empty(payload: Any) -> Dict[str, Any]:
        """Return *payload* if it is a dict, else an empty dict."""
        return payload if isinstance(payload, dict) else {}

    def _slots_payload(self) -> Dict[str, Any]:
        """Build the state-region payload from cached slots ``0..slot_count-1``."""
        return {"slots": [self._slot_cache[key].to_dict() for key in range(self._slot_count)]}

    def close(self) -> None:
        """Close the intent, state and control regions, in that order.

        Every ``close()`` is attempted even if an earlier one raises; the last
        exception raised (if any) propagates with earlier ones chained as its
        ``__context__``.
        """
        try:
            self.intent_region.close()
        finally:
            try:
                self.state_region.close()
            finally:
                self.control_region.close()

    def __enter__(self) -> RealZincPlane:
        """Return the plane itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close all regions when leaving a ``with`` block; exceptions propagate."""
        self.close()

    def publish_intent(self, intent: KernelIntent) -> None:
        """Append *intent* to the intent history and republish the intent region.

        The intent is flattened to a JSON-ready row: ``timestamp`` as ISO-8601,
        ``side`` / ``action`` / ``stage`` as their ``.value``,
        ``exit_leg_ratios`` as a list and ``metadata`` round-tripped through
        JSON (with ``_json_default``). Only the newest ``intent_history`` rows
        are kept and published.
        """
        with self._lock:
            self._intent_seq += 1
            row = intent.__dict__.copy()
            row["timestamp"] = intent.timestamp.isoformat()
            row["side"] = intent.side.value
            row["action"] = intent.action.value
            row["stage"] = intent.stage.value
            row["exit_leg_ratios"] = list(intent.exit_leg_ratios)
            row["metadata"] = json.loads(json.dumps(intent.metadata, default=_json_default))
            self._intent_cache.append(row)
            # Bound the local cache to the published window; otherwise it grows
            # by one row per intent for the lifetime of the process.
            del self._intent_cache[: -self._intent_history]
            self._write_region(self.intent_region, self._intent_seq, {"items": self._intent_cache})

    def write_slot(self, slot: TradeSlot) -> None:
        """Cache *slot* under its ``slot_id`` and republish the state region.

        Only ids in ``range(slot_count)`` are published; a slot with any other
        id is stored in the local cache but does not appear in the region.
        """
        with self._lock:
            self._state_seq += 1
            self._slot_cache[int(slot.slot_id)] = slot
            self._write_region(self.state_region, self._state_seq, self._slots_payload())

    def read_slots(self) -> List[TradeSlot]:
        """Decode the state region and return its slots sorted by ``slot_id``.

        Returns an empty list when the region holds no decodable ``slots``
        list; entries that are not dicts are skipped.
        """
        payload = _decode_packet(self.state_region.as_buffer())
        rows = payload.get("slots", []) if isinstance(payload, dict) else []
        if not isinstance(rows, list):
            rows = []
        rows = [row for row in rows if isinstance(row, dict)]
        return [_slot_from_payload(row) for row in sorted(rows, key=lambda row: int(row.get("slot_id", 0)))]

    def read_intents(self) -> List[Dict[str, Any]]:
        """Return the intent rows currently published in the intent region.

        Returns an empty list when the region holds no decodable ``items`` list.
        """
        payload = _decode_packet(self.intent_region.as_buffer())
        items = payload.get("items", []) if isinstance(payload, dict) else []
        return list(items) if isinstance(items, list) else []

    def update_control(self, control: KernelControlSnapshot) -> None:
        """Cache *control* and publish it to the control region."""
        with self._lock:
            self._control_seq += 1
            self._control_cache = control
            self._write_region(self.control_region, self._control_seq, {"control": control.as_dict()})

    def read_control(self) -> KernelControlSnapshot:
        """Decode the control region.

        Falls back to the locally cached snapshot when the region holds no
        ``control`` dict.
        """
        payload = _decode_packet(self.control_region.as_buffer())
        control = payload.get("control") if isinstance(payload, dict) else None
        if not isinstance(control, dict):
            return self._control_cache
        return KernelControlSnapshot(**control)

    def wait_on_state(self, timeout_ms: int = 1000) -> bool:
        """Call ``wait(timeout_ms)`` on the state region and return its result as a bool.

        Presumably true when notified before the timeout -- confirm with the adapter.
        """
        return bool(self.state_region.wait(timeout_ms))

    def notify_state(self) -> None:
        """Call ``notify()`` on the state region."""
        self.state_region.notify()

    def wait_on_control(self, timeout_ms: int = 1000) -> bool:
        """Call ``wait(timeout_ms)`` on the control region and return its result as a bool."""
        return bool(self.control_region.wait(timeout_ms))

    def notify_control(self) -> None:
        """Call ``notify()`` on the control region."""
        self.control_region.notify()

    def wait_on_intent(self, timeout_ms: int = 1000) -> bool:
        """Call ``wait(timeout_ms)`` on the intent region and return its result as a bool."""
        return bool(self.intent_region.wait(timeout_ms))

    def notify_intent(self) -> None:
        """Call ``notify()`` on the intent region."""
        self.intent_region.notify()

    def _write_region(self, region: Any, seq: int, payload: Dict[str, Any]) -> None:
        """Serialize *payload* as packet *seq* into *region*, then notify it.

        The new packet is copied over the start of the buffer first and only
        the bytes after it are zeroed afterwards. The final buffer contents
        match zero-filling first, but a concurrent reader never sees an
        all-zero header (which decodes as an empty packet).

        Raises:
            ValueError: if the encoded packet does not fit in the region.
        """
        packet = _encode_packet(seq, payload)
        buf = region.as_buffer()
        if len(packet) > len(buf):
            raise ValueError(f"payload too large for Zinc region: {len(packet)} > {len(buf)}")
        view = memoryview(buf)
        end = len(packet)
        view[:end] = packet
        view[end:] = bytes(len(view) - end)
        region.notify()
|