"""
convnext_5s_query.py — inference query against the trained convnext_model_5s.json

Reports:
  1. Per-channel reconstruction correlation (original vs. reconstruction)
  2. z-dim activity and spread
  3. Top z-dims correlated with proxy_B (ch7)
  4. Per-dim statistics (mean/std/min/max and r with proxy_B) for the
     16 highest-spread z-dims

Uses vbt_cache/*.parquet (5s scan corpus, C_in=8, no ExF).
"""

import glob
import io
import json
import os
import sys

import numpy as np
import pandas as pd


def force_utf8_stream(stream):
    """Return *stream* set up to emit UTF-8, replacing unencodable characters.

    The report prints box-drawing and bar glyphs, which crash consoles whose
    default encoding cannot represent them.

    Prefers ``TextIOWrapper.reconfigure`` (Python 3.7+), which changes the
    encoding in place and keeps the stream's existing buffering mode. For
    streams without ``reconfigure`` but with a binary ``.buffer``, the
    pending text is flushed first. The buffer is then rewrapped with line
    buffering, so output is not held back until the buffer fills or the
    process exits. Streams with neither (e.g. an in-memory ``io.StringIO``)
    are returned unchanged rather than raising AttributeError.
    """
    if hasattr(stream, 'reconfigure'):
        stream.reconfigure(encoding='utf-8', errors='replace')
        return stream
    if hasattr(stream, 'buffer'):
        flush = getattr(stream, 'flush', None)
        if flush is not None:
            flush()  # don't let text buffered in the old wrapper be reordered/lost
        return io.TextIOWrapper(stream.buffer, encoding='utf-8',
                                errors='replace', line_buffering=True)
    return stream


sys.stdout = force_utf8_stream(sys.stdout)
sys.stderr = force_utf8_stream(sys.stderr)

# Repository root: three dirname() steps above this file's absolute path,
# i.e. the parent of the parent of the folder holding this script.
ROOT = os.path.abspath(__file__)
for _ in range(3):
    ROOT = os.path.dirname(ROOT)

# Model package directory, prepended to sys.path so the convnext_dvae
# module imported further down can be found.
DVAE_DIR = os.path.join(ROOT, 'nautilus_dolphin', 'dvae')
sys.path.insert(0, DVAE_DIR)

# Inputs: the 5s scan parquet corpus and the trained model checkpoint.
SCANS_DIR = os.path.join(ROOT, 'vbt_cache')
MODEL_PATH = os.path.join(DVAE_DIR, 'convnext_model_5s.json')

# Raw scan features read from each parquet file, in channel order (ch0..ch6).
FEATURE_COLS = [
    'v50_lambda_max_velocity',
    'v150_lambda_max_velocity',
    'v300_lambda_max_velocity',
    'v750_lambda_max_velocity',
    'vel_div',
    'instability_50',
    'instability_150',
]
# Model input channels: the raw features followed by the derived proxy_B (ch7).
CH_NAMES = [*FEATURE_COLS, 'proxy_B']  # 8 channels

T_WIN = 32      # time steps per probe window
N_PROBES = 200  # more probes — 56 files, sample ~3-4 per file

# ── load model ──────────────────────────────────────────────────────────────
# Project module; resolved through DVAE_DIR on sys.path.
from convnext_dvae import ConvNeXtVAE

# Model metadata JSON, decoded explicitly as UTF-8 so the result does not
# depend on the platform's default text encoding.
with open(MODEL_PATH, encoding='utf-8') as f:
    meta = json.load(f)

# Rebuild the network from the stored architecture (a missing or null
# 'architecture' entry falls back to the defaults below), then let the
# model load its saved state from the same file.
arch = meta.get('architecture') or {}
model = ConvNeXtVAE(
    C_in=arch.get('C_in', 8),
    T_in=arch.get('T_in', 32),
    z_dim=arch.get('z_dim', 32),
    base_ch=arch.get('base_ch', 32),
    n_blocks=arch.get('n_blocks', 3),
    seed=42,
)
model.load(MODEL_PATH)

# Per-channel input normalisation statistics. They are only meaningful as a
# pair. A file carrying just one of them is rejected here; otherwise it
# would either fail later with an opaque arithmetic error or silently
# skip normalisation.
has_mean, has_std = 'norm_mean' in meta, 'norm_std' in meta
if has_mean != has_std:
    raise ValueError(f"{MODEL_PATH}: 'norm_mean' and 'norm_std' must both be "
                     f"present or both absent (norm_mean={has_mean}, norm_std={has_std})")
norm_mean = np.array(meta['norm_mean']) if has_mean else None
norm_std = np.array(meta['norm_std']) if has_std else None

# JSON may store val_loss as null; show that as nan instead of crashing the
# float format spec.
val_loss = meta.get('val_loss')
val_loss = float('nan') if val_loss is None else float(val_loss)
print(f"Model: epoch={meta.get('epoch')} val_loss={val_loss:.5f}")
print(f" C_in={arch.get('C_in')} z_dim={arch.get('z_dim')} base_ch={arch.get('base_ch')}\n")

# ── build probe set ──────────────────────────────────────────────────────────
# Windows sampled from each visited file; also sets the file stride below.
PROBES_PER_FILE = 4

files = sorted(f for f in glob.glob(os.path.join(SCANS_DIR, '*.parquet'))
               if 'catalog' not in f)
# Stride through the sorted file list so about N_PROBES / PROBES_PER_FILE
# files are visited. The inner max() keeps a small N_PROBES from dividing by 0.
step = max(1, len(files) // max(1, N_PROBES // PROBES_PER_FILE))
probes_raw, proxy_B_vals = [], []
n_skipped = 0

rng = np.random.default_rng(42)
for f in files[::step]:
    try:
        df = pd.read_parquet(f, columns=FEATURE_COLS).dropna()
    except Exception as exc:  # best-effort: skip unreadable files, but report them
        n_skipped += 1
        print(f"  [skip] {os.path.basename(f)}: {type(exc).__name__}: {exc}",
              file=sys.stderr)
        continue
    if len(df) < T_WIN + 4:
        continue
    # Convert to a float array once per file, not once per window.
    values = df[FEATURE_COLS].to_numpy(dtype=np.float64)  # (rows, 7)
    # Generator.integers excludes `high`; the +1 keeps the last full window
    # (start = rows - T_WIN) reachable.
    positions = rng.integers(0, len(values) - T_WIN + 1, size=PROBES_PER_FILE)
    for pos in positions:
        arr = values[pos:pos + T_WIN]                       # (T, 7)
        proxy_B = (arr[:, 5] - arr[:, 3]).reshape(-1, 1)    # instability_50 - v750
        arr8 = np.concatenate([arr, proxy_B], axis=1)       # (T, 8)
        if not np.isfinite(arr8).all():
            continue
        probes_raw.append(arr8.T)                           # (8, T)
        proxy_B_vals.append(float(proxy_B.mean()))
        if len(probes_raw) >= N_PROBES:
            break
    if len(probes_raw) >= N_PROBES:
        break

if not probes_raw:
    sys.exit(f"No usable {T_WIN}-step windows in {SCANS_DIR} "
             f"({len(files)} parquet files found, {n_skipped} unreadable).")

probes_raw = np.stack(probes_raw)      # (N, 8, T)
proxy_B_arr = np.array(proxy_B_vals)   # (N,)
print(f"Probe set: {probes_raw.shape} ({len(probes_raw)} windows × {probes_raw.shape[1]} ch × {T_WIN} steps)\n")

# ── normalise ────────────────────────────────────────────────────────────────
# Standardise each channel with the statistics stored alongside the model
# (broadcast over the batch and time axes), then clip to ±6.
probes = probes_raw.copy()
if norm_mean is not None:
    mean = np.asarray(norm_mean, dtype=np.float64)[None, :, None]
    std = np.asarray(norm_std, dtype=np.float64)[None, :, None]
    # A zero or non-finite std (e.g. a channel that was constant when the
    # stats were computed) would give inf/NaN here, and NaN passes through
    # np.clip unchanged and would poison the encoder. Such channels are
    # only mean-centred instead.
    std = np.where(np.isfinite(std) & (std > 0), std, 1.0)
    probes = (probes - mean) / std
np.clip(probes, -6.0, 6.0, out=probes)

# ── encode / decode ──────────────────────────────────────────────────────────
# Posterior parameters for every probe window, and a reconstruction decoded
# from the posterior mean z_mu rather than from a sampled z.
posterior = model.encode(probes)
z_mu, z_logvar = posterior
x_recon = model.decode(z_mu)

# ── 1. Per-channel reconstruction correlation ────────────────────────────────
# For each channel: mean Pearson r between original and reconstructed
# windows. Windows where either side is (near-)flat are skipped because
# their correlation is undefined, and so are non-finite r values.
print("── Per-channel reconstruction r (orig vs recon) ──────────────────")
for c, name in enumerate(CH_NAMES):
    window_rs = []
    for orig_w, recon_w in zip(probes[:, c], x_recon[:, c]):
        if not (orig_w.std() > 1e-6 and recon_w.std() > 1e-6):
            continue
        rv = float(np.corrcoef(orig_w, recon_w)[0, 1])
        if np.isfinite(rv):
            window_rs.append(rv)
    mean_r = np.mean(window_rs) if window_rs else float('nan')
    bar = '█' * int(max(0, mean_r) * 20)  # 20 glyphs == r of 1.0
    print(f" ch{c:2d} {name:<30s} r={mean_r:+.3f} {bar}")

# ── 2. z-dim activity ────────────────────────────────────────────────────────
# Spread of the posterior means across probes, one value per latent dim;
# dims whose spread is at most 0.01 are not counted as active.
z_std_per_dim = z_mu.std(0)  # (D,)
z_active = int((z_std_per_dim > 0.01).sum())
# Average posterior std, exp(logvar / 2), over every probe and every dim.
z_post_std = float(np.exp(0.5 * z_logvar).mean())

# (dim, spread) pairs ordered from largest to smallest spread.
z_stds_sorted = sorted(enumerate(z_std_per_dim), key=lambda item: -item[1])
top_dims = " ".join(f"z[{i}]={s:.3f}" for i, s in z_stds_sorted[:8])

print("\n── Latent space ──────────────────────────────────────────────────")
print(f" z_active_dims : {z_active} / {z_mu.shape[1]}")
print(f" z_post_std : {z_post_std:.4f} (>1 = posterior wider than prior)")
print(" Top z-dim stds: " + top_dims)

# ── 3. z-dim × proxy_B correlation ──────────────────────────────────────────
print("\n── z-dim correlation with proxy_B (all active dims) ─────────────")


def _proxy_b_r(dim):
    """Pearson r between z_mu[:, dim] and the per-window proxy_B means."""
    return float(np.corrcoef(z_mu[:, dim], proxy_B_arr)[0, 1])


# (|r|, r, dim) for every active dim with a finite r, strongest first
# (ties broken by signed r, then by dim index).
active_dims = [d for d in range(z_mu.shape[1]) if z_std_per_dim[d] > 0.01]
scored = [(_proxy_b_r(d), d) for d in active_dims]
corrs = sorted(((abs(r), r, d) for r, d in scored if np.isfinite(r)), reverse=True)

print(f" (proxy_B mean={proxy_B_arr.mean():+.4f} std={proxy_B_arr.std():.4f})")
for _, r, d in corrs[:15]:
    print(f" z[{d:2d}] r={r:+.4f} {'█' * int(abs(r) * 30)}")

# ── 4. z-dim statistics ──────────────────────────────────────────────────────
# One row for each of the 16 highest-spread dims. r_proxyB is NaN for dims
# at or below the 0.01 activity threshold.
print("\n── z-dim statistics (z_mu) ──────────────────────────────────────")
print(f" {'dim':>4} {'mean':>8} {'std':>8} {'min':>8} {'max':>8} {'r_proxyB':>10}")
for dim, spread in z_stds_sorted[:16]:
    col = z_mu[:, dim]
    if spread > 0.01:
        r_pb = float(np.corrcoef(col, proxy_B_arr)[0, 1])
    else:
        r_pb = float('nan')
    print(f" z[{dim:2d}] {col.mean():>+8.4f} {spread:>8.4f} "
          f"{col.min():>+8.4f} {col.max():>+8.4f} {r_pb:>+10.4f}")

print("\nDone.")