Files
DOLPHIN/nautilus_dolphin/dvae/convnext_query.py
hjnormey 01c19662cb initial: import DOLPHIN baseline 2026-04-21 from dolphinng5_predict working tree
Includes core prod + GREEN/BLUE subsystems:
- prod/ (BLUE harness, configs, scripts, docs)
- nautilus_dolphin/ (GREEN Nautilus-native impl + dvae/ preserved)
- adaptive_exit/ (AEM engine + models/bucket_assignments.pkl)
- Observability/ (EsoF advisor, TUI, dashboards)
- external_factors/ (EsoF producer)
- mc_forewarning_qlabs_fork/ (MC regime/envelope)

Excludes runtime caches, logs, backups, and reproducible artifacts per .gitignore.
2026-04-21 16:58:38 +02:00

146 lines
6.7 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
convnext_query.py — inference query against trained convnext_model.json
Reports:
1. Per-channel reconstruction correlation (orig vs recon)
2. z-dim activity and spread
3. Top z-dims correlated with proxy_B (ch7)
"""
import os, sys, json
import numpy as np
import glob
import pandas as pd
# ── paths & constants ───────────────────────────────────────────────────────
# ROOT is the directory that contains nautilus_dolphin/ (three dirname() hops
# up from this file: dvae/ -> nautilus_dolphin/ -> ROOT).
ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
DVAE_DIR = os.path.join(ROOT, 'nautilus_dolphin', 'dvae')
# Make sibling modules in DVAE_DIR (e.g. convnext_dvae) importable.
sys.path.insert(0, DVAE_DIR)
MODEL_PATH = os.path.join(DVAE_DIR, 'convnext_model.json')
KLINES_DIR = os.path.join(ROOT, 'vbt_cache_klines')
# Root of the per-day scan directories that hold the ExF indicator snapshots
# (<EIGENVALUES_PATH>/<date>/<EXF_NPZ_NAME>). The default is a machine-specific
# absolute path; set DOLPHIN_EIGENVALUES_PATH to point elsewhere on other hosts.
_DEFAULT_EIGENVALUES_PATH = r"C:\Users\Lenovo\Documents\- Dolphin NG HD (NG3)\correlation_arb512\eigenvalues"
EIGENVALUES_PATH = os.environ.get('DOLPHIN_EIGENVALUES_PATH') or _DEFAULT_EIGENVALUES_PATH
EXF_NPZ_NAME = "scan_000001__Indicators.npz"
# Per-timestep feature columns read from each klines parquet file.
FEATURE_COLS = [
    'v50_lambda_max_velocity', 'v150_lambda_max_velocity',
    'v300_lambda_max_velocity', 'v750_lambda_max_velocity',
    'vel_div', 'instability_50', 'instability_150',
]
# External-factor indicators, one value per day.
EXF_COLS = ['dvol_btc', 'fng', 'funding_btc']
CH_NAMES = FEATURE_COLS + ['proxy_B'] + EXF_COLS  # 11 channels: 7 features + proxy_B + 3 ExF
T_WIN = 32      # time steps per probe window
N_PROBES = 100  # maximum number of probe windows
# ── load model ──────────────────────────────────────────────────────────────
# convnext_dvae lives in DVAE_DIR, which is prepended to sys.path earlier in this file.
from convnext_dvae import ConvNeXtVAE

# The checkpoint is JSON; read it as UTF-8 explicitly so parsing does not depend
# on the platform default encoding (e.g. cp1252 on Windows).
with open(MODEL_PATH, encoding='utf-8') as f:
    meta = json.load(f)
arch = meta.get('architecture', {})
# Every architecture hyper-parameter is taken from the checkpoint when present;
# the literals are only fallbacks for checkpoints that omit a key.
model = ConvNeXtVAE(
    C_in=arch.get('C_in', 11),
    T_in=arch.get('T_in', 32),
    z_dim=arch.get('z_dim', 32),
    base_ch=arch.get('base_ch', 32),
    n_blocks=arch.get('n_blocks', 3),  # was hard-coded to 3 regardless of the checkpoint
    seed=42,
)
model.load(MODEL_PATH)
# Optional per-channel normalisation statistics stored alongside the weights;
# a missing or null entry means "not available".
norm_mean = np.array(meta['norm_mean']) if meta.get('norm_mean') is not None else None
norm_std = np.array(meta['norm_std']) if meta.get('norm_std') is not None else None
# val_loss may be absent or null in the JSON; print NaN instead of crashing the format.
_val_loss = meta.get('val_loss')
_val_loss = float('nan') if _val_loss is None else float(_val_loss)
print(f"Model: epoch={meta.get('epoch')} val_loss={_val_loss:.4f}")
print(f" C_in={arch.get('C_in')} z_dim={arch.get('z_dim')} base_ch={arch.get('base_ch')}\n")
# ── build probe set ──────────────────────────────────────────────────────────
# Cache for the ExF name -> index mapping; filled on the first successful lookup.
_exf_idx = None


def get_exf_indices():
    """Return a mapping from ExF indicator name to its index in the scan arrays.

    Walks the per-day sub-directories of ``EIGENVALUES_PATH`` in sorted order
    and reads ``api_names`` from the first ``EXF_NPZ_NAME`` file that loads
    successfully. A successful result is cached in the module-level
    ``_exf_idx`` and returned on every later call.

    Returns:
        dict[str, int]: indicator name -> position in ``api_names``. Returns
        ``{}`` (without caching, so a later call retries) when the directory
        is missing/unreadable or no day holds a readable indicator file.
    """
    global _exf_idx
    if _exf_idx is not None:
        return _exf_idx
    try:
        day_dirs = sorted(os.listdir(EIGENVALUES_PATH))
    except OSError:
        # No eigenvalues tree on this machine: callers treat {} as "no ExF data".
        return {}
    for ds in day_dirs:
        p = os.path.join(EIGENVALUES_PATH, ds, EXF_NPZ_NAME)
        if not os.path.exists(p):
            continue
        try:
            # NOTE: allow_pickle=True can execute pickled payloads; only point
            # EIGENVALUES_PATH at trusted local data.
            # The context manager closes the underlying zip file handle.
            with np.load(p, allow_pickle=True) as d:
                idx = {str(n): i for i, n in enumerate(d['api_names'])}
        except Exception:
            # Best-effort: a corrupt or incomplete day file is skipped and the
            # next day is tried.
            continue
        _exf_idx = idx
        return _exf_idx
    return {}
def _load_exf_row(date_str, idx_map):
    """Return one day's ExF values as a float64 vector of length len(EXF_COLS).

    Reads ``EIGENVALUES_PATH/<date_str>/EXF_NPZ_NAME``. An entry stays 0.0 when
    the file is missing, ``idx_map`` is empty, the indicator name is not in
    ``idx_map``, or ``api_success`` is false for it. Read errors propagate.
    """
    vals = np.zeros(len(EXF_COLS), dtype=np.float64)
    npz_p = os.path.join(EIGENVALUES_PATH, date_str, EXF_NPZ_NAME)
    if not idx_map or not os.path.exists(npz_p):
        return vals
    # NOTE: allow_pickle=True can execute pickled payloads; trusted local files only.
    # The context manager closes the zip file handle (previously leaked per day).
    with np.load(npz_p, allow_pickle=True) as d:
        for ci, col in enumerate(EXF_COLS):
            fi = idx_map.get(col, -1)
            if fi >= 0 and bool(d['api_success'][fi]):
                vals[ci] = float(d['api_indicators'][fi])
    return vals


files = sorted(glob.glob(os.path.join(KLINES_DIR, '*.parquet')))
# Subsample files evenly so the probes span the whole date range.
step = max(1, len(files) // N_PROBES)
idx_map = get_exf_indices()
probes_raw, proxy_B_vals = [], []
n_failed = 0
for f in files[::step]:
    try:
        df = pd.read_parquet(f, columns=FEATURE_COLS).dropna()
        if len(df) < T_WIN + 10:
            continue
        # Take one T_WIN-long window starting at the middle row, clamped so the
        # window always fits. Without the clamp, files with fewer than 2*T_WIN
        # rows produced a short window that later failed to concatenate and
        # was silently dropped.
        pos = min(len(df) // 2, len(df) - T_WIN)
        arr = df[FEATURE_COLS].values[pos:pos + T_WIN].astype(np.float64)
        # proxy_B = instability_50 - v750_lambda_max_velocity (FEATURE_COLS[5] - [3]).
        proxy_B = (arr[:, 5] - arr[:, 3]).reshape(-1, 1)
        arr = np.concatenate([arr, proxy_B], axis=1)  # (T, 8)
        date_str = os.path.basename(f).replace('.parquet', '')
        # ExF values are per-day scalars, repeated over every time step.
        exf = np.tile(_load_exf_row(date_str, idx_map), (T_WIN, 1))
        arr = np.concatenate([arr, exf], axis=1).T  # (11, T)
    except Exception as exc:
        # Best-effort: an unreadable/incomplete day is skipped, but no longer silently.
        n_failed += 1
        print(f" skipped {os.path.basename(f)}: {type(exc).__name__}: {exc}")
        continue
    probes_raw.append(arr)
    proxy_B_vals.append(float(proxy_B.mean()))
    if len(probes_raw) >= N_PROBES:
        break
if n_failed:
    print(f" ({n_failed} file(s) skipped due to read errors)\n")
if not probes_raw:
    sys.exit(f"No usable probe windows found in {KLINES_DIR} "
             f"(each file needs >= {T_WIN + 10} non-NaN rows of {FEATURE_COLS}).")
probes_raw = np.stack(probes_raw)  # (N, 11, T)
proxy_B_arr = np.array(proxy_B_vals)  # (N,)
print(f"Probe set: {probes_raw.shape} ({len(probes_raw)} windows × {probes_raw.shape[1]} ch × {T_WIN} steps)\n")
# ── normalise ────────────────────────────────────────────────────────────────
# Apply the checkpoint's per-channel z-score so probes are in the same input
# space as the model expects, then clip z-scores at ±6.
# NOTE(review): the original indentation was lost; the clip is assumed to apply
# only to normalised data — confirm against the training pipeline.
probes = probes_raw.copy()
if (norm_mean is None) != (norm_std is None):
    sys.exit("convnext_model.json has only one of norm_mean / norm_std; cannot normalise probes.")
if norm_mean is not None:
    # A zero (or non-finite) std would turn a channel into inf/NaN, and NaN
    # survives np.clip; divide such channels by 1 (centre only) instead.
    safe_std = np.where(np.isfinite(norm_std) & (norm_std > 0), norm_std, 1.0)
    probes = (probes - norm_mean[None, :, None]) / safe_std[None, :, None]
    np.clip(probes, -6.0, 6.0, out=probes)
# ── encode / decode ──────────────────────────────────────────────────────────
# Decode from z_mu (the encoder's mean output, not a sample), so the
# reconstruction is deterministic for a given probe set.
z_mu, z_logvar = model.encode(probes)
x_recon = model.decode(z_mu)
# ── 1. Per-channel reconstruction correlation ────────────────────────────────
# For each channel: Pearson r between the original and reconstructed time series
# of each probe window, averaged over windows. Windows where either series is
# (near-)constant are skipped because r is undefined there. The ExF channels
# hold one value repeated across the whole window, so they always report r=nan.
print("── Per-channel reconstruction r (orig vs recon) ──────────────────")
for c, name in enumerate(CH_NAMES):
    rs = []
    for orig_ts, recon_ts in zip(probes[:, c], x_recon[:, c]):
        if orig_ts.std() > 1e-6 and recon_ts.std() > 1e-6:
            rv = float(np.corrcoef(orig_ts, recon_ts)[0, 1])
            if np.isfinite(rv):
                rs.append(rv)
    mean_r = float(np.mean(rs)) if rs else float('nan')
    # 20 glyphs == r of 1.0; negative or undefined r draws no bar. (The glyph
    # had been lost, so '' * n always printed an empty bar.)
    bar = '█' * int(max(0.0, mean_r) * 20) if np.isfinite(mean_r) else ''
    print(f" ch{c:2d} {name:<30s} r={mean_r:+.3f} {bar}")
# ── 2. z-dim activity ────────────────────────────────────────────────────────
# A latent dimension counts as "active" when z_mu varies across the probe
# windows (std over the batch axis) by more than 0.01.
z_std_per_dim = z_mu.std(axis=0)  # (D,)
z_active = int(np.count_nonzero(z_std_per_dim > 0.01))
# Mean posterior standard deviation, recovered from the log-variance output.
z_post_std = float(np.mean(np.exp(z_logvar * 0.5)))
n_latent_dims = z_mu.shape[1]
print(f"\n── Latent space ──────────────────────────────────────────────────")
print(f" z_active_dims : {z_active} / {n_latent_dims}")
print(f" z_post_std : {z_post_std:.4f} (>1 = posterior wider than prior)")
# ── 3. z-dim × proxy_B correlation ──────────────────────────────────────────
# Rank active latent dims by |Pearson r| between z_mu[:, d] and the per-window
# mean of proxy_B across probe windows, and print the strongest 10.
print(f"\n── z-dim correlation with proxy_B (top 10) ──────────────────────")
corrs = []
# A constant proxy_B makes every r undefined (np.corrcoef warns and returns
# NaN), so skip the scan entirely in that case.
if proxy_B_arr.std() > 0:
    for d in range(z_mu.shape[1]):
        if z_std_per_dim[d] > 0.01:
            r = float(np.corrcoef(z_mu[:, d], proxy_B_arr)[0, 1])
            if np.isfinite(r):
                corrs.append((abs(r), r, d))
corrs.sort(reverse=True)
if not corrs:
    print(" (no finite correlations)")
for _, r, d in corrs[:10]:
    # 20 glyphs == |r| of 1.0 (the glyph had been lost, so the bar was always empty).
    bar = '█' * int(abs(r) * 20)
    print(f" z[{d:2d}] r={r:+.3f} {bar}")
print(f"\nDone.")