Files
DOLPHIN/nautilus_dolphin/dvae/test_flint_hd_vae.py

111 lines
4.5 KiB
Python
Raw Normal View History

"""
Unit test for FlintHDVAE inverse projection decoder.
Criteria:
1. No NaN during training
2. z_var per dim > 0.01 after 20 epochs (no posterior collapse)
3. Reconstruction MSE < 1.0 on held-out 20%
"""
import sys, os
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, r"C:\Users\Lenovo\Documents\- DOLPHIN NG HD HCM TSF Predict")
import numpy as np
from pathlib import Path
HERE = Path(__file__).parent
_rule = "=" * 65
print(_rule)
print("UNIT TEST: FlintHDVAE Inverse Projection Decoder")
print(_rule)

# ── Load the T1 slice of the eigen corpus ───────────────────────
print("\nLoading 16K eigen corpus...")
from corpus_builder import DolphinCorpus, OFF, T1 as T1_DIM

corpus = DolphinCorpus.load(str(HERE / 'corpus_cache.npz'))
# Row selector taken from column 1 of the corpus mask (presumably a boolean
# "row has this tier" flag — confirm in corpus_builder).
idx_mask = corpus.mask[:, 1]
X_e = corpus.X[idx_mask]
# Column window for T1: starts at OFF[1] and spans T1_DIM features.
_t1_start = OFF[1]
_t1_stop = _t1_start + T1_DIM
T1 = X_e[:, _t1_start:_t1_stop].copy()  # (16607, 20) at time of writing
N = T1.shape[0]

_t1_has_nan = np.isnan(T1).any()
print(f" T1 shape: {T1.shape} dtype: {T1.dtype}")
print(f" Any NaN in T1: {_t1_has_nan}")
print(f" T1 stats: min={T1.min():.4f} max={T1.max():.4f} std={T1.std():.4f}")
# ── Train/val split ─────────────────────────────────────────────
# Random (not chronological) 80/20 split, so train and val are drawn from the
# same regime distribution. Fixed seed keeps the split reproducible.
rng_split = np.random.RandomState(42)
idx_all = rng_split.permutation(N)
n_train = int(N * 0.8)
_train_idx, _val_idx = np.split(idx_all, [n_train])
X_train = T1[_train_idx]
X_val = T1[_val_idx]
print(f"\nTrain: {len(X_train)} Val: {len(X_val)}")
# ── Instantiate model ────────────────────────────────────────────
from flint_hd_vae import FlintHDVAE
print("\nInstantiating FlintHDVAE(beta=0.1, use_flint_norm=False — plain z-score for decoder test)...")
# input_dim follows the width of the data actually fed to fit() instead of a
# hard-coded 20, so the model stays in sync if the corpus T1 slice
# (T1_DIM in corpus_builder) ever changes width.
model = FlintHDVAE(
    input_dim=X_train.shape[1],
    hd_dim=512,
    latent_dim=8,
    beta=0.1,
    seed=42,
    use_flint_norm=False,
)
print(" OK")
# ── Training: 40 epochs on the 80% split, warm-up over the first 30% ──
print("\nTraining 40 epochs on 80% T1 data (warmup first 30%)...")
PASS = True  # overall verdict; any failing stage below flips it to False
_fit_kwargs = dict(
    epochs=40,
    lr=1e-3,
    batch_size=256,
    verbose=True,
    warmup_frac=0.3,
)
try:
    model.fit(X_train, **_fit_kwargs)
except Exception:
    # Top-level boundary for this script: show the full traceback and mark
    # the run failed instead of aborting, so the summary banner still prints.
    import traceback
    traceback.print_exc()
    PASS = False
if PASS:
    # Check 1: the recorded training-loss history must be entirely finite.
    # np.isfinite flags +/-inf as well as NaN (np.isnan alone lets inf pass).
    # An empty history is a failure because nothing could be verified.
    losses_arr = np.array(model.train_losses)
    nan_losses = int((~np.isfinite(losses_arr)).sum())
    print(f"\n[CHECK 1] NaN/inf in training losses: {nan_losses}")
    if losses_arr.size == 0:
        print(" FAIL: No training losses recorded — cannot verify!")
        PASS = False
    elif nan_losses == 0:
        print(" PASS: No NaN/inf detected")
    else:
        print(" FAIL: NaN/inf losses detected!")
        PASS = False

    # Check 2: posterior collapse. Encode 1000 training rows and measure the
    # variance of each latent dimension; a dim with variance <= 0.01 is
    # counted as collapsed. encode() applies the model's own input
    # normalisation (original note: "global MCDAIN" — confirm given
    # use_flint_norm=False). The latent size is read from the output rather
    # than hard-coded.
    mu_s = model.encode(X_train[:1000])  # expected (n_rows, latent_dim)
    var_per_dim = mu_s.var(axis=0)
    n_latent = var_per_dim.shape[0]
    min_var = var_per_dim.min()
    active_dims = int((var_per_dim > 0.01).sum())
    print(f"\n[CHECK 2] z_var per dim: {var_per_dim.round(4)}")
    print(f" Min var: {min_var:.6f} Active dims (>0.01): {active_dims}/{n_latent}")
    if active_dims >= (n_latent + 1) // 2:  # at least half active
        print(f" PASS: {active_dims}/{n_latent} dims active (no posterior collapse)")
    elif active_dims >= 1:
        # Deliberately non-fatal: weak latent usage is reported, PASS kept.
        print(f" PARTIAL: Only {active_dims}/{n_latent} dims active — weak but not fully collapsed")
    else:
        print(" FAIL: All dims collapsed (posterior collapse)")
        PASS = False

    # Check 3: reconstruction MSE < 1.0 on the held-out split.
    # reconstruct() returns (T1_hat, X_val_norm) — both in the same
    # normalised space. A NaN MSE fails the `< 1.0` comparison -> FAIL.
    T1_hat, X_val_norm = model.reconstruct(X_val)
    recon_mse = float(np.mean((T1_hat - X_val_norm) ** 2))
    print(f"\n[CHECK 3] Val reconstruction MSE: {recon_mse:.4f}")
    if recon_mse < 1.0:
        print(f" PASS: MSE={recon_mse:.4f} < 1.0")
    else:
        print(f" FAIL: MSE={recon_mse:.4f} >= 1.0 (decoder not learning)")
        PASS = False

    # Bonus (informational only): sanity-check encode() shape and value range
    # on a few validation rows; NaN here warns but does not fail the run.
    z_all = model.encode(X_val[:100])
    print(f"\n[BONUS] encode() output shape: {z_all.shape}")
    print(f" z range: [{z_all.min():.4f}, {z_all.max():.4f}] std: {z_all.std():.4f}")
    if np.isnan(z_all).any():
        print(" WARNING: NaN in encode() output!")
# ── Overall verdict ─────────────────────────────────────────────
print("\n" + "=" * 65)
if PASS:
    print("OVERALL: PASS — FlintHDVAE inverse projection decoder functional")
else:
    print("OVERALL: FAIL — see issues above")
print("=" * 65)
# Report the verdict through the process exit status as well, so CI and
# shell callers can detect a failing run without parsing the output.
if not PASS:
    sys.exit(1)