"""
Unit test for FlintHDVAE inverse projection decoder.

Criteria:
  1. No NaN during training
  2. z_var per dim > 0.01 after 40 epochs (no posterior collapse)
  3. Reconstruction MSE < 1.0 on held-out 20%
"""

import os
import sys

# Force UTF-8 console output so the box-drawing / dash characters printed
# below cannot raise UnicodeEncodeError on legacy code pages.  The stream may
# have been replaced by an object without ``reconfigure`` (IDE consoles,
# notebooks, ``None`` under pythonw), so only call it when it exists.
_reconfigure = getattr(sys.stdout, "reconfigure", None)
if callable(_reconfigure):
    _reconfigure(encoding='utf-8', errors='replace')

# Make sibling modules importable regardless of the current working directory.
sys.path.insert(0, os.path.dirname(__file__))
# NOTE(review): machine-specific absolute path kept for backward
# compatibility; it is a harmless no-op entry on any other machine.
sys.path.insert(0, r"C:\Users\Lenovo\Documents\- DOLPHIN NG HD HCM TSF Predict")

import numpy as np
from pathlib import Path

# Directory containing this script; used to locate the cached corpus file.
HERE = Path(__file__).parent

print("=" * 65)
print("UNIT TEST: FlintHDVAE Inverse Projection Decoder")
print("=" * 65)

# ── Load T1 corpus ──────────────────────────────────────────────
# The project import is deliberately placed after the progress message so the
# message is shown before any import-time work happens.
print("\nLoading 16K eigen corpus...")
from corpus_builder import DolphinCorpus, OFF, T1 as T1_DIM

corpus = DolphinCorpus.load(str(HERE / 'corpus_cache.npz'))
# Select the rows flagged by column 1 of the corpus mask, then cut out the
# T1_DIM-wide column window that starts at OFF[1].
# NOTE(review): presumably index 1 designates the T1 tier — confirm against
# corpus_builder.
idx_mask = corpus.mask[:, 1]
X_e = corpus.X[idx_mask]
T1 = X_e[:, OFF[1]:OFF[1] + T1_DIM].copy()  # expected (16607, 20) per the original note
N = len(T1)
print(f" T1 shape: {T1.shape} dtype: {T1.dtype}")
print(f" Any NaN in T1: {np.isnan(T1).any()}")
print(f" T1 stats: min={T1.min():.4f} max={T1.max():.4f} std={T1.std():.4f}")

# ── Train/val split (random 80/20 — avoids regime distribution shift) ────
# A dedicated, fixed-seed generator keeps the split reproducible from run to
# run and independent of any other random state.
rng_split = np.random.RandomState(42)
idx_all = rng_split.permutation(N)
n_train = int(N * 0.8)
X_train = T1[idx_all[:n_train]]   # first 80% of the shuffled indices
X_val = T1[idx_all[n_train:]]     # remaining 20%, held out for Check 3
print(f"\nTrain: {len(X_train)} Val: {len(X_val)}")

# ── Instantiate model ────────────────────────────────────────────
from flint_hd_vae import FlintHDVAE

print("\nInstantiating FlintHDVAE(beta=0.1, use_flint_norm=False — plain z-score for decoder test)...")
# input_dim matches the 20-column T1 slice; latent_dim=8 is the width that
# the posterior-collapse check inspects later on.
model = FlintHDVAE(input_dim=20, hd_dim=512, latent_dim=8, beta=0.1, seed=42,
                   use_flint_norm=False)
print(" OK")

# ── Train 40 epochs ──────────────────────────────────────────────
print("\nTraining 40 epochs on 80% T1 data (warmup first 30%)...")
# Overall verdict flag; any failing step below flips it to False.
PASS = True

try:
    model.fit(X_train, epochs=40, lr=1e-3, batch_size=256, verbose=True, warmup_frac=0.3)
except Exception:
    # Top-level boundary of the test: report the full traceback and mark the
    # run as failed instead of aborting, so the summary is still printed.
    import traceback
    traceback.print_exc()
    PASS = False

if PASS:
    # Check 1: No NaN in losses
    losses_arr = np.array(model.train_losses)
    nan_losses = np.isnan(losses_arr).sum()
    print(f"\n[CHECK 1] NaN in training losses: {nan_losses}")
    if nan_losses == 0:
        print(" PASS: No NaN detected")
    else:
        print(" FAIL: NaN losses detected!")
        PASS = False

    # Check 2: z_var per dim > 0.01
    # (original note: use encode(), which applies global MCDAIN — not
    # verifiable from this file)
    mu_s = model.encode(X_train[:1000])    # expected (1000, latent_dim)
    var_per_dim = mu_s.var(axis=0)         # one variance per latent dim
    n_dims = var_per_dim.size              # 8 for the model built above
    min_var = var_per_dim.min()
    active_dims = (var_per_dim > 0.01).sum()
    print(f"\n[CHECK 2] z_var per dim: {var_per_dim.round(4)}")
    print(f" Min var: {min_var:.6f} Active dims (>0.01): {active_dims}/{n_dims}")
    if 2 * active_dims >= n_dims:  # at least half active
        print(f" PASS: {active_dims}/{n_dims} dims active (no posterior collapse)")
    elif active_dims >= 1:
        # Deliberately not a failure: a partially active latent is reported
        # but does not flip the overall verdict.
        print(f" PARTIAL: Only {active_dims}/{n_dims} dims active — weak but not fully collapsed")
    else:
        print(" FAIL: All dims collapsed (posterior collapse)")
        PASS = False

    # Check 3: Reconstruction MSE < 1.0 on val
    # reconstruct() returns (T1_hat, X_val_norm) — both in same normalised space
    T1_hat, X_val_norm = model.reconstruct(X_val)
    recon_mse = float(np.mean((T1_hat - X_val_norm) ** 2))
    print(f"\n[CHECK 3] Val reconstruction MSE: {recon_mse:.4f}")
    if recon_mse < 1.0:
        print(f" PASS: MSE={recon_mse:.4f} < 1.0")
    else:
        # A NaN MSE also lands here because ``nan < 1.0`` is False.
        print(f" FAIL: MSE={recon_mse:.4f} >= 1.0 (decoder not learning)")
        PASS = False

    # Bonus: check encode output shape and values (informational only — a
    # NaN here prints a warning but does not change the verdict)
    z_all = model.encode(X_val[:100])
    print(f"\n[BONUS] encode() output shape: {z_all.shape}")
    print(f" z range: [{z_all.min():.4f}, {z_all.max():.4f}] std: {z_all.std():.4f}")
    if np.isnan(z_all).any():
        print(" WARNING: NaN in encode() output!")

# ── Summary ──────────────────────────────────────────────────────
print("\n" + "=" * 65)
if PASS:
    print("OVERALL: PASS — FlintHDVAE inverse projection decoder functional")
else:
    print("OVERALL: FAIL — see issues above")
print("=" * 65)

# Propagate failure to the shell / CI through the exit status.  Only done
# when executed as a script, so importing this module behaves as before.
if not PASS and __name__ == "__main__":
    raise SystemExit(1)