# Source: DOLPHIN/adaptive_exit/continuation_model.py
"""
Per-bucket continuation probability model.
Architecture:
- One LogisticRegression per bucket (warm_start=True for online updates)
- Global fallback model trained on all buckets
- Online update: accumulate buffer partial_fit periodically
Anti-degradation (basin guard):
Shadow-only exits create a feedback loop: model says EXIT only early-exit
outcomes are observed model learns from biased short-horizon data drifts
to "always EXIT". Three safeguards prevent this:
1. NATURAL_EXITS_ONLY online updates only from FIXED_TP / MAX_HOLD exits.
Forced exits (HIBERNATE_HALT, SUBDAY_ACB_NORMALIZATION) are excluded because
they are regime artifacts, not continuation-relevant outcomes.
2. Rolling accuracy monitor tracks whether the model's continuation predictions
match actual outcomes over a sliding window. If accuracy drops below
DEGRADATION_THRESHOLD, online updates are paused until it recovers.
3. Label balance guard if the online update buffer is >80% one class,
the flush is skipped (insufficient signal diversity).
Features: [mae_norm, mfe_norm, tau_norm, ret_1, ret_3, spread_bps, depth_usd, fill_prob]
Target: continuation (1 = still favorable, 0 = adverse)
Usage:
model = ContinuationModelBank.load() # or .train(df)
p = model.predict(asset="BTCUSDT", mae_norm=0.5, mfe_norm=0.2, tau_norm=0.3,
ret_1=-0.001, ret_3=-0.003, bucket_id=4)
"""
import os
import pickle
import threading
from collections import defaultdict, deque
from typing import Optional
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "continuation_models.pkl")

# Feature vector layout. Order MUST match the np.array construction in
# ContinuationModelBank.predict / ContinuationModelBank.online_update.
FEATURE_COLS = [
    # trade state
    "mae_norm", "mfe_norm", "tau_norm", "ret_1", "ret_3",
    # eigenvalue signal — entry quality and current divergence
    "vel_div_entry",    # vel_div at entry bar; always < -0.02 at BLUE inference
    "vel_div_now",      # vel_div at current bar k; live signal during excursion
    # OBF (static median; zero when unavailable)
    "spread_bps", "depth_usd", "fill_prob",
    # ExF — macro/sentiment (daily NPZ backfill; zero-filled when unavailable)
    "exf_fng",          # Fear & Greed / 100 (0–1)
    "exf_fng_delta",    # (fng - fng_prev) / 100
    "exf_funding_btc",  # BTC perpetual funding rate
    "exf_dvol_btc",     # BTC implied vol / 100
    "exf_chg24_btc",    # BTC 24h return / 100
]

# Online update config
ONLINE_BUFFER_SIZE = 200   # samples before triggering partial retrain
ONLINE_MIN_SAMPLES = 50    # min samples per bucket to attempt fit

# Anti-degradation config
# Exit reasons considered "natural": online updates learn only from these;
# forced/regime exits (e.g. HIBERNATE_HALT) are filtered out upstream.
NATURAL_EXIT_REASONS = frozenset({"FIXED_TP", "MAX_HOLD", "V7_MAE_SL_VOL_NORM",
                                  "V7_COMPOSITE_PRESSURE", "AE_MAE_STOP",
                                  "AE_GIVEBACK_LOW_CONT", "AE_TIME"})
ACCURACY_WINDOW = 50          # rolling window for accuracy monitoring
DEGRADATION_THRESHOLD = 0.40  # pause updates if accuracy drops below this
LABEL_BALANCE_MIN = 0.15      # skip flush if minority class < 15% of buffer
class DegradationGuard:
    """Rolling accuracy monitor. Pauses online updates when model degrades."""

    def __init__(self):
        # Each entry is a correctness flag (1 = thresholded prediction matched
        # the realized outcome, 0 = miss), bounded to the last ACCURACY_WINDOW.
        self._preds: deque = deque(maxlen=ACCURACY_WINDOW)
        self._paused = False

    def record(self, p_cont: float, actual_continuation: int) -> None:
        """Score one prediction against its outcome and refresh the pause flag."""
        hit = int((p_cont >= 0.5) == bool(actual_continuation))
        self._preds.append(hit)
        # Only judge accuracy once at least half a window has accumulated;
        # below that, the estimate is too noisy to act on.
        if len(self._preds) >= ACCURACY_WINDOW // 2:
            rolling_acc = sum(self._preds) / len(self._preds)
            self._paused = rolling_acc < DEGRADATION_THRESHOLD

    @property
    def updates_allowed(self) -> bool:
        """True while the model is healthy enough to keep learning online."""
        return not self._paused

    @property
    def accuracy(self) -> float:
        """Rolling hit rate; 0.5 (uninformative prior) before any samples."""
        n = len(self._preds)
        return sum(self._preds) / n if n else 0.5
class _BucketModel:
    """Single-bucket LR with online update support.

    Holds a StandardScaler + LogisticRegression pair, an online sample
    buffer, and a DegradationGuard that can pause learning.
    """

    def __init__(self, bucket_id: int):
        self.bucket_id = bucket_id
        self.scaler = StandardScaler()
        # warm_start=True lets later .fit() calls start from the current
        # coefficients, so online batch refits are cheap incremental updates.
        self.lr = LogisticRegression(
            C=0.01,
            max_iter=500,
            warm_start=True,
            solver="lbfgs",
            class_weight="balanced",
        )
        self._fitted = False
        self._n_train = 0
        self._online_buf_X: list = []
        self._online_buf_y: list = []
        self._guard = DegradationGuard()

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        """Full (re)fit from scratch.

        No-op when there is too little data, or when only one class is
        present — LogisticRegression.fit raises on single-class targets, and
        a bucket whose history is all-continuation (or all-adverse) carries
        no usable decision boundary anyway. Staying unfitted keeps
        predict_proba at the neutral 0.5 prior.
        """
        if len(X) < ONLINE_MIN_SAMPLES:
            return
        if len(np.unique(y)) < 2:
            return
        Xs = self.scaler.fit_transform(X)
        self.lr.fit(Xs, y)
        self._fitted = True
        self._n_train = len(X)

    def predict_proba(self, x: np.ndarray) -> float:
        """Return P(continuation=1) for a single sample (0.5 if unfitted)."""
        if not self._fitted:
            return 0.5
        xs = self.scaler.transform(x.reshape(1, -1))
        # fit() guarantees both classes were present, so classes_ == [0, 1]
        # and column 1 is P(y == 1).
        return float(self.lr.predict_proba(xs)[0, 1])

    def online_update(self, x: np.ndarray, y: int, p_pred: float = 0.5) -> None:
        """Buffer one labeled sample; flush into a warm-start refit when full.

        Args:
            x: feature vector in FEATURE_COLS order.
            y: observed continuation label (1 favorable / 0 adverse).
            p_pred: probability the model emitted for this sample, fed to the
                degradation guard's rolling-accuracy monitor.
        """
        # Record accuracy FIRST so the guard keeps observing outcomes even
        # while paused — that is what allows it to detect recovery.
        self._guard.record(p_pred, y)
        if not self._guard.updates_allowed:
            return  # model degraded — pause updates until accuracy recovers
        self._online_buf_X.append(x.copy())
        self._online_buf_y.append(y)
        if len(self._online_buf_X) >= ONLINE_BUFFER_SIZE:
            self._flush_online_buffer()

    def _flush_online_buffer(self) -> None:
        """Refit on the buffered batch, then clear the buffer."""
        if not self._online_buf_X:
            return
        X_new = np.array(self._online_buf_X)
        y_new = np.array(self._online_buf_y)
        # Label balance guard: skip if minority class < LABEL_BALANCE_MIN
        pos_rate = y_new.mean()
        if pos_rate < LABEL_BALANCE_MIN or pos_rate > (1.0 - LABEL_BALANCE_MIN):
            self._online_buf_X.clear()
            self._online_buf_y.clear()
            return
        if not self._fitted:
            self.fit(X_new, y_new)
        else:
            # Reuse the original scaler: feature scale must stay consistent
            # with the coefficients being warm-started.
            Xs = self.scaler.transform(X_new)
            if len(np.unique(y_new)) > 1:
                self.lr.fit(Xs, y_new)
        self._online_buf_X.clear()
        self._online_buf_y.clear()
class ContinuationModelBank:
    """Registry of per-bucket models with a global fallback.

    predict/online_update take an internal lock; the lock is dropped on
    pickling (__getstate__) and recreated on unpickling (__setstate__).
    """

    def __init__(self):
        self._models: dict[int, _BucketModel] = {}
        self._global: Optional[_BucketModel] = None
        self._lock = threading.Lock()

    # ── Training ──────────────────────────────────────────────────────────────
    def train(self, df: pd.DataFrame) -> None:
        """Fit all per-bucket models from training DataFrame."""
        print(f"[ContinuationModelBank] Training on {len(df)} samples, "
              f"{df['bucket_id'].nunique()} buckets")
        # Drop rows missing any feature, bucket id, or label before fitting.
        df = df.dropna(subset=FEATURE_COLS + ["bucket_id", "continuation"])
        # Global fallback
        X_all = df[FEATURE_COLS].values
        y_all = df["continuation"].values.astype(int)
        self._global = _BucketModel(bucket_id=-1)
        self._global.fit(X_all, y_all)
        print(f" Global model: n={len(X_all)}, "
              f"pos_rate={y_all.mean():.2f}")
        # Per-bucket
        for bid, grp in df.groupby("bucket_id"):
            X = grp[FEATURE_COLS].values
            y = grp["continuation"].values.astype(int)
            m = _BucketModel(bucket_id=int(bid))
            m.fit(X, y)
            self._models[int(bid)] = m
            print(f" Bucket {bid:2d}: n={len(X):6d}, pos_rate={y.mean():.2f}, "
                  f"fitted={m._fitted}")
        print(f"[ContinuationModelBank] Training complete: "
              f"{sum(m._fitted for m in self._models.values())}/{len(self._models)} buckets fitted")

    # ── Inference ─────────────────────────────────────────────────────────────
    def predict(
        self,
        mae_norm: float,
        mfe_norm: float,
        tau_norm: float,
        ret_1: float = 0.0,
        ret_3: float = 0.0,
        vel_div_entry: float = 0.0,
        vel_div_now: float = 0.0,
        spread_bps: float = 0.0,
        depth_usd: float = 0.0,
        fill_prob: float = 0.9,
        exf_fng: float = 0.0,
        exf_fng_delta: float = 0.0,
        exf_funding_btc: float = 0.0,
        exf_dvol_btc: float = 0.0,
        exf_chg24_btc: float = 0.0,
        bucket_id: int = 0,
    ) -> float:
        """Return P(continuation | state). Fallback to global if bucket missing."""
        # NOTE: element order must mirror FEATURE_COLS exactly.
        x = np.array([mae_norm, mfe_norm, tau_norm, ret_1, ret_3,
                      vel_div_entry, vel_div_now,
                      spread_bps, depth_usd, fill_prob,
                      exf_fng, exf_fng_delta, exf_funding_btc, exf_dvol_btc, exf_chg24_btc],
                     dtype=float)
        with self._lock:
            m = self._models.get(bucket_id)
            if m is not None and m._fitted:
                return m.predict_proba(x)
            # Bucket model missing or unfitted — fall back to the global model.
            if self._global is not None and self._global._fitted:
                return self._global.predict_proba(x)
            # Nothing fitted yet — neutral prior.
            return 0.5

    # ── Online update ─────────────────────────────────────────────────────────
    def online_update(
        self,
        bucket_id: int,
        mae_norm: float,
        mfe_norm: float,
        tau_norm: float,
        ret_1: float,
        ret_3: float,
        vel_div_entry: float = 0.0,
        vel_div_now: float = 0.0,
        spread_bps: float = 0.0,
        depth_usd: float = 0.0,
        fill_prob: float = 0.9,
        exf_fng: float = 0.0,
        exf_fng_delta: float = 0.0,
        exf_funding_btc: float = 0.0,
        exf_dvol_btc: float = 0.0,
        exf_chg24_btc: float = 0.0,
        continuation: int = 0,
        exit_reason: str = "",
        p_pred: float = 0.5,
    ) -> None:
        """Feed one exit outcome to the bucket model AND the global fallback."""
        # Natural-exits-only guard: skip forced/regime exits (HIBERNATE_HALT, etc.)
        # These don't reflect continuation dynamics and would bias the model.
        # An empty exit_reason is accepted (treated as natural).
        if exit_reason and exit_reason not in NATURAL_EXIT_REASONS:
            return
        # Element order must mirror FEATURE_COLS exactly (same as predict()).
        x = np.array([mae_norm, mfe_norm, tau_norm, ret_1, ret_3,
                      vel_div_entry, vel_div_now,
                      spread_bps, depth_usd, fill_prob,
                      exf_fng, exf_fng_delta, exf_funding_btc, exf_dvol_btc, exf_chg24_btc],
                     dtype=float)
        with self._lock:
            # Lazily create a bucket model for buckets unseen during training.
            if bucket_id not in self._models:
                self._models[bucket_id] = _BucketModel(bucket_id)
            self._models[bucket_id].online_update(x, continuation, p_pred)
            if self._global is not None:
                self._global.online_update(x, continuation, p_pred)

    # ── Persistence ───────────────────────────────────────────────────────────
    def __getstate__(self):
        # threading.Lock is not picklable — drop it; __setstate__ recreates it.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()

    def save(self, path: str = _MODEL_PATH) -> None:
        """Pickle the whole bank (models, scalers, online buffers) to *path*."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self, f)
        print(f"[ContinuationModelBank] Saved → {path}")

    @classmethod
    def load(cls, path: str = _MODEL_PATH) -> "ContinuationModelBank":
        """Unpickle a saved bank.

        SECURITY NOTE(review): pickle.load executes arbitrary code from the
        file — only load model files from trusted sources.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"No trained model at {path} — run train.py first")
        with open(path, "rb") as f:
            return pickle.load(f)

    def summary(self) -> dict:
        """Lightweight status dict for logging/monitoring."""
        return {
            "n_buckets": len(self._models),
            "fitted_buckets": sum(m._fitted for m in self._models.values()),
            "global_fitted": self._global._fitted if self._global else False,
            "n_train_global": self._global._n_train if self._global else 0,
        }