"""
Per-bucket continuation probability model.

Architecture:
  - One LogisticRegression per bucket (warm_start=True for online updates)
  - Global fallback model trained on all buckets
  - Online update: accumulate a sample buffer, then periodically refit each
    model on that buffer. LogisticRegression has no ``partial_fit``;
    ``warm_start`` only seeds the solver with the previous coefficients.

Anti-degradation (basin guard):
  Shadow-only exits create a feedback loop: model says EXIT → only early-exit
  outcomes are observed → model learns from biased short-horizon data →
  drifts to "always EXIT". Three safeguards prevent this:

  1. NATURAL_EXITS_ONLY — online updates are accepted only from exits whose
     reason is in NATURAL_EXIT_REASONS (FIXED_TP, MAX_HOLD and the V7_* /
     AE_* rule exits). Forced exits (HIBERNATE_HALT,
     SUBDAY_ACB_NORMALIZATION) are excluded because they are regime
     artifacts, not continuation-relevant outcomes. An empty exit reason is
     not filtered.
  2. Rolling accuracy monitor — tracks whether the model's continuation
     predictions match actual outcomes over a sliding window. If accuracy
     drops below DEGRADATION_THRESHOLD, online updates are paused until it
     recovers.
  3. Label balance guard — if the minority class makes up less than
     LABEL_BALANCE_MIN of the online update buffer (i.e. the buffer is >85%
     one class), the flush is skipped (insufficient signal diversity).

Features (FEATURE_COLS order):
    mae_norm, mfe_norm, tau_norm, ret_1, ret_3,
    vel_div_entry, vel_div_now,
    spread_bps, depth_usd, fill_prob,
    exf_fng, exf_fng_delta, exf_funding_btc, exf_dvol_btc, exf_chg24_btc
Target: continuation (1 = still favorable, 0 = adverse)

Usage:
    model = ContinuationModelBank.load()   # or .train(df)
    p = model.predict(mae_norm=0.5, mfe_norm=0.2, tau_norm=0.3,
                      ret_1=-0.001, ret_3=-0.003, bucket_id=4)
"""
import os
import pickle
import threading
from collections import defaultdict, deque
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "continuation_models.pkl")

FEATURE_COLS = [
    # trade state
    "mae_norm", "mfe_norm", "tau_norm", "ret_1", "ret_3",
    # eigenvalue signal — entry quality and current divergence
    "vel_div_entry",    # vel_div at entry bar; always <-0.02 at BLUE inference
    "vel_div_now",      # vel_div at current bar k; live signal during excursion
    # OBF (static median; zero when unavailable)
    "spread_bps", "depth_usd", "fill_prob",
    # ExF — macro/sentiment (daily NPZ backfill; zero-filled when unavailable)
    "exf_fng",          # Fear & Greed / 100 (0–1)
    "exf_fng_delta",    # (fng - fng_prev) / 100
    "exf_funding_btc",  # BTC perpetual funding rate
    "exf_dvol_btc",     # BTC implied vol / 100
    "exf_chg24_btc",    # BTC 24h return / 100
]

# Online update config
ONLINE_BUFFER_SIZE = 200  # samples before triggering partial retrain
ONLINE_MIN_SAMPLES = 50   # min samples per bucket to attempt fit

# Anti-degradation config
NATURAL_EXIT_REASONS = frozenset({"FIXED_TP", "MAX_HOLD",
                                  "V7_MAE_SL_VOL_NORM", "V7_COMPOSITE_PRESSURE",
                                  "AE_MAE_STOP", "AE_GIVEBACK_LOW_CONT", "AE_TIME"})
ACCURACY_WINDOW = 50          # rolling window for accuracy monitoring
DEGRADATION_THRESHOLD = 0.40  # pause updates if accuracy drops below this
LABEL_BALANCE_MIN = 0.15      # skip flush if minority class < 15% of buffer


def _feature_vector(
    mae_norm: float,
    mfe_norm: float,
    tau_norm: float,
    ret_1: float,
    ret_3: float,
    vel_div_entry: float,
    vel_div_now: float,
    spread_bps: float,
    depth_usd: float,
    fill_prob: float,
    exf_fng: float,
    exf_fng_delta: float,
    exf_funding_btc: float,
    exf_dvol_btc: float,
    exf_chg24_btc: float,
) -> np.ndarray:
    """Pack scalar features into a 1-D float array in FEATURE_COLS order.

    Single source of truth for the column order shared by ``predict`` and
    ``online_update``; it must match the order used when training on
    ``df[FEATURE_COLS]``.
    """
    return np.array([mae_norm, mfe_norm, tau_norm, ret_1, ret_3,
                     vel_div_entry, vel_div_now,
                     spread_bps, depth_usd, fill_prob,
                     exf_fng, exf_fng_delta, exf_funding_btc,
                     exf_dvol_btc, exf_chg24_btc], dtype=float)


class DegradationGuard:
    """Rolling accuracy monitor. Pauses online updates when model degrades.

    Each ``record`` call stores whether the thresholded prediction
    (p >= 0.5) matched the observed label. Once at least half the window is
    filled, updates are paused while rolling accuracy is below
    DEGRADATION_THRESHOLD, and resume as soon as it climbs back.
    """

    def __init__(self):
        # 1 if the thresholded prediction matched the outcome, else 0.
        self._preds: deque = deque(maxlen=ACCURACY_WINDOW)
        self._paused = False

    def record(self, p_cont: float, actual_continuation: int) -> None:
        """Record one (prediction, outcome) pair and refresh the pause flag."""
        correct = int((p_cont >= 0.5) == bool(actual_continuation))
        self._preds.append(correct)
        # Don't judge the model until half a window of evidence exists.
        if len(self._preds) >= ACCURACY_WINDOW // 2:
            acc = sum(self._preds) / len(self._preds)
            self._paused = acc < DEGRADATION_THRESHOLD

    @property
    def updates_allowed(self) -> bool:
        """True unless rolling accuracy is currently below the threshold."""
        return not self._paused

    @property
    def accuracy(self) -> float:
        """Rolling hit-rate of recorded predictions (0.5 when empty)."""
        return sum(self._preds) / len(self._preds) if self._preds else 0.5


class _BucketModel:
    """Single-bucket LR with online update support."""

    def __init__(self, bucket_id: int):
        self.bucket_id = bucket_id
        self.scaler = StandardScaler()
        self.lr = LogisticRegression(
            C=0.01,
            max_iter=500,
            warm_start=True,
            solver="lbfgs",
            class_weight="balanced",
        )
        self._fitted = False
        self._n_train = 0
        self._online_buf_X: list = []
        self._online_buf_y: list = []
        self._guard = DegradationGuard()

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        """Fit scaler + LR from scratch on ``X`` / ``y``.

        The model is left unfitted (callers then fall back to the global model
        or 0.5) when there are fewer than ONLINE_MIN_SAMPLES rows, or when
        ``y`` contains a single class — LogisticRegression cannot be fit on
        one class and would raise ValueError, aborting the whole ``train``.
        """
        if len(X) < ONLINE_MIN_SAMPLES:
            return
        if np.unique(y).size < 2:
            return
        Xs = self.scaler.fit_transform(X)
        self.lr.fit(Xs, y)
        self._fitted = True
        self._n_train = len(X)

    def predict_proba(self, x: np.ndarray) -> float:
        """Return P(continuation=1) for a single sample (0.5 if unfitted)."""
        if not self._fitted:
            return 0.5
        xs = self.scaler.transform(x.reshape(1, -1))
        return float(self.lr.predict_proba(xs)[0, 1])

    def online_update(self, x: np.ndarray, y: int, p_pred: float = 0.5) -> None:
        """Buffer one labelled sample; flush into the model when the buffer fills.

        ``p_pred`` is the probability that was predicted for this sample; it
        feeds the degradation guard before the sample is buffered.
        """
        # Anti-degradation: record prediction accuracy
        self._guard.record(p_pred, y)
        if not self._guard.updates_allowed:
            return  # model degraded — pause updates until accuracy recovers

        self._online_buf_X.append(x.copy())
        self._online_buf_y.append(y)
        if len(self._online_buf_X) >= ONLINE_BUFFER_SIZE:
            self._flush_online_buffer()

    def _clear_buffer(self) -> None:
        """Drop all buffered online samples."""
        self._online_buf_X.clear()
        self._online_buf_y.clear()

    def _flush_online_buffer(self) -> None:
        """Refit on the buffered samples (subject to the label balance guard)."""
        if not self._online_buf_X:
            return
        X_new = np.array(self._online_buf_X)
        y_new = np.array(self._online_buf_y)

        # Label balance guard: skip if minority class < LABEL_BALANCE_MIN
        pos_rate = y_new.mean()
        if pos_rate < LABEL_BALANCE_MIN or pos_rate > (1.0 - LABEL_BALANCE_MIN):
            self._clear_buffer()
            return

        if not self._fitted:
            self.fit(X_new, y_new)
        else:
            # Scaler stays frozen at its training-time statistics.
            Xs = self.scaler.transform(X_new)
            if len(np.unique(y_new)) > 1:
                # NOTE(review): this refits on the buffer only — warm_start
                # seeds lbfgs with the old coefficients but does not keep
                # them; the L2 penalty pulls toward zero, not toward them.
                self.lr.fit(Xs, y_new)
        self._clear_buffer()


class ContinuationModelBank:
    """Registry of per-bucket models with a global fallback."""

    def __init__(self):
        self._models: dict[int, _BucketModel] = {}
        self._global: Optional[_BucketModel] = None
        self._lock = threading.Lock()

    # ── Training ──────────────────────────────────────────────────────────────

    def train(self, df: pd.DataFrame) -> None:
        """Fit all per-bucket models from training DataFrame.

        ``df`` must contain FEATURE_COLS, ``bucket_id`` and ``continuation``;
        rows with NaN in any of them are dropped. Buckets with too few rows
        or a single label class are left unfitted and use the global model.
        """
        print(f"[ContinuationModelBank] Training on {len(df)} samples, "
              f"{df['bucket_id'].nunique()} buckets")

        df = df.dropna(subset=FEATURE_COLS + ["bucket_id", "continuation"])

        # Global fallback
        X_all = df[FEATURE_COLS].values
        y_all = df["continuation"].values.astype(int)
        self._global = _BucketModel(bucket_id=-1)
        self._global.fit(X_all, y_all)
        print(f" Global model: n={len(X_all)}, "
              f"pos_rate={y_all.mean():.2f}")

        # Per-bucket
        for bid, grp in df.groupby("bucket_id"):
            # bucket_id may be float dtype (e.g. if the column held NaNs);
            # normalise once so both the dict key and the ':2d' format work.
            bucket = int(bid)
            X = grp[FEATURE_COLS].values
            y = grp["continuation"].values.astype(int)
            m = _BucketModel(bucket_id=bucket)
            m.fit(X, y)
            self._models[bucket] = m
            print(f" Bucket {bucket:2d}: n={len(X):6d}, pos_rate={y.mean():.2f}, "
                  f"fitted={m._fitted}")

        print(f"[ContinuationModelBank] Training complete: "
              f"{sum(m._fitted for m in self._models.values())}/{len(self._models)} buckets fitted")

    # ── Inference ─────────────────────────────────────────────────────────────

    def predict(
        self,
        mae_norm: float,
        mfe_norm: float,
        tau_norm: float,
        ret_1: float = 0.0,
        ret_3: float = 0.0,
        vel_div_entry: float = 0.0,
        vel_div_now: float = 0.0,
        spread_bps: float = 0.0,
        depth_usd: float = 0.0,
        fill_prob: float = 0.9,
        exf_fng: float = 0.0,
        exf_fng_delta: float = 0.0,
        exf_funding_btc: float = 0.0,
        exf_dvol_btc: float = 0.0,
        exf_chg24_btc: float = 0.0,
        bucket_id: int = 0,
    ) -> float:
        """Return P(continuation | state).

        Uses the bucket's model if it is fitted, otherwise the global model
        if fitted, otherwise 0.5.
        """
        x = _feature_vector(
            mae_norm=mae_norm, mfe_norm=mfe_norm, tau_norm=tau_norm,
            ret_1=ret_1, ret_3=ret_3,
            vel_div_entry=vel_div_entry, vel_div_now=vel_div_now,
            spread_bps=spread_bps, depth_usd=depth_usd, fill_prob=fill_prob,
            exf_fng=exf_fng, exf_fng_delta=exf_fng_delta,
            exf_funding_btc=exf_funding_btc, exf_dvol_btc=exf_dvol_btc,
            exf_chg24_btc=exf_chg24_btc,
        )
        with self._lock:
            m = self._models.get(bucket_id)
            if m is not None and m._fitted:
                return m.predict_proba(x)
            if self._global is not None and self._global._fitted:
                return self._global.predict_proba(x)
        return 0.5

    # ── Online update ─────────────────────────────────────────────────────────

    def online_update(
        self,
        bucket_id: int,
        mae_norm: float,
        mfe_norm: float,
        tau_norm: float,
        ret_1: float,
        ret_3: float,
        vel_div_entry: float = 0.0,
        vel_div_now: float = 0.0,
        spread_bps: float = 0.0,
        depth_usd: float = 0.0,
        fill_prob: float = 0.9,
        exf_fng: float = 0.0,
        exf_fng_delta: float = 0.0,
        exf_funding_btc: float = 0.0,
        exf_dvol_btc: float = 0.0,
        exf_chg24_btc: float = 0.0,
        continuation: int = 0,
        exit_reason: str = "",
        p_pred: float = 0.5,
    ) -> None:
        """Feed one observed outcome to the bucket model and the global model.

        Samples with a non-empty ``exit_reason`` outside NATURAL_EXIT_REASONS
        are ignored. A bucket model is created on first sight of an unknown
        ``bucket_id``. The same ``p_pred`` is recorded by both the bucket and
        the global degradation guards.
        """
        # Natural-exits-only guard: skip forced/regime exits (HIBERNATE_HALT, etc.)
        # These don't reflect continuation dynamics and would bias the model.
        # An empty exit_reason is accepted (caller did not classify the exit).
        if exit_reason and exit_reason not in NATURAL_EXIT_REASONS:
            return

        x = _feature_vector(
            mae_norm=mae_norm, mfe_norm=mfe_norm, tau_norm=tau_norm,
            ret_1=ret_1, ret_3=ret_3,
            vel_div_entry=vel_div_entry, vel_div_now=vel_div_now,
            spread_bps=spread_bps, depth_usd=depth_usd, fill_prob=fill_prob,
            exf_fng=exf_fng, exf_fng_delta=exf_fng_delta,
            exf_funding_btc=exf_funding_btc, exf_dvol_btc=exf_dvol_btc,
            exf_chg24_btc=exf_chg24_btc,
        )
        with self._lock:
            if bucket_id not in self._models:
                self._models[bucket_id] = _BucketModel(bucket_id)
            self._models[bucket_id].online_update(x, continuation, p_pred)
            if self._global is not None:
                self._global.online_update(x, continuation, p_pred)

    # ── Persistence ───────────────────────────────────────────────────────────

    def __getstate__(self):
        # threading.Lock is not picklable; it is recreated in __setstate__.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def save(self, path: str = _MODEL_PATH) -> None:
        """Pickle the bank to ``path``, creating the parent directory if needed.

        The pickle is written to a temporary sibling file and moved into place
        with ``os.replace`` so an interrupted save never leaves a truncated
        model at ``path``.
        """
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises for a bare filename
            os.makedirs(directory, exist_ok=True)
        tmp_path = path + ".tmp"
        try:
            # Hold the lock so online_update cannot mutate the model dicts
            # mid-pickle. __getstate__ never takes the lock, so no deadlock.
            with self._lock, open(tmp_path, "wb") as f:
                pickle.dump(self, f)
            os.replace(tmp_path, path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"[ContinuationModelBank] Saved → {path}")

    @classmethod
    def load(cls, path: str = _MODEL_PATH) -> "ContinuationModelBank":
        """Unpickle a bank previously written by ``save``.

        Raises:
            FileNotFoundError: if no file exists at ``path``.

        Security: ``pickle.load`` executes arbitrary code from the file —
        only load model files from trusted locations.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"No trained model at {path} — run train.py first")
        with open(path, "rb") as f:
            return pickle.load(f)

    def summary(self) -> dict:
        """Return counts of buckets/fitted models and global training size."""
        return {
            "n_buckets": len(self._models),
            "fitted_buckets": sum(m._fitted for m in self._models.values()),
            "global_fitted": self._global._fitted if self._global else False,
            "n_train_global": self._global._n_train if self._global else 0,
        }