"""
flint_hd_vae.py
===============
SILOQY-compatible HD-VAE with inverse projection decoder.

Architecture:
  Encoder:
    T1 (20-dim)
    → MCDAIN 550-bit normalisation (no upstream modification — read-only call)
    → HD random projection W_enc (20×512), ReLU → h (512)
    → Linear bottleneck: W_mu (512×8), W_lv (512×8) → mu, logvar (8)
    → reparameterisation → z (8)

  Decoder (inverse projection — THE NEW PIECE):
    z (8)
    → Linear W_dec (8×512), ReLU → h_hat (512)   *inverse of bottleneck*
    → Linear W_out (512×20) → T1_hat (20)        *pseudo-inverse of HD proj*

Loss:
  recon = MSE(T1_hat, T1_norm)
  KL    = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))   [standard VAE KL]
  total = recon + beta * KL

No upstream files are modified. All SILOQY calls are read-only.
"""
import sys, os
|
|||
|
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|||
|
|
sys.path.insert(0, r"C:\Users\Lenovo\Documents\- DOLPHIN NG HD HCM TSF Predict")
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
from pathlib import Path
|
|||
|
|
from SILOQY_NN_Kernel_COMPLETE6 import arb, safe_float, FLINT_AVAILABLE, with_precision
|
|||
|
|
|
|||
|
|
EPS = 1e-8
|
|||
|
|
|
|||
|
|
# ── MCDAIN 550-bit normalisation (read-only logic, no upstream changes) ────
|
|||
|
|
def mcdain_550bit(X_raw: np.ndarray) -> np.ndarray:
|
|||
|
|
"""Apply MCDAIN analytical normalisation at 550-bit precision."""
|
|||
|
|
rows, cols = X_raw.shape
|
|||
|
|
X_norm = np.zeros_like(X_raw, dtype=np.float64)
|
|||
|
|
with with_precision(550):
|
|||
|
|
for j in range(cols):
|
|||
|
|
col = X_raw[:, j]
|
|||
|
|
col_abs = np.abs(col[np.isfinite(col)])
|
|||
|
|
if len(col_abs) == 0 or col_abs.mean() < 1e-12:
|
|||
|
|
continue
|
|||
|
|
magnitude = arb(str(float(col_abs.mean())))
|
|||
|
|
log_mag = magnitude.log()
|
|||
|
|
mean_val = magnitude * arb("0.1")
|
|||
|
|
scale_val = arb("1.0") / (log_mag + arb("1e-8"))
|
|||
|
|
gate_val = arb("1.0") / (arb("1.0") + (-log_mag).exp())
|
|||
|
|
m = safe_float(mean_val)
|
|||
|
|
s = safe_float(scale_val)
|
|||
|
|
g = safe_float(gate_val)
|
|||
|
|
X_norm[:, j] = np.clip((X_raw[:, j] - m) * s * g, -10, 10)
|
|||
|
|
return np.nan_to_num(X_norm, nan=0.0, posinf=5.0, neginf=-5.0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── Adam optimiser state ───────────────────────────────────────────────────
|
|||
|
|
class AdamParam:
|
|||
|
|
def __init__(self, shape, seed=0):
|
|||
|
|
rng = np.random.RandomState(seed)
|
|||
|
|
scale = np.sqrt(2.0 / shape[0])
|
|||
|
|
self.W = rng.randn(*shape).astype(np.float64) * scale
|
|||
|
|
self.m = np.zeros_like(self.W)
|
|||
|
|
self.v = np.zeros_like(self.W)
|
|||
|
|
self.t = 0
|
|||
|
|
|
|||
|
|
def step(self, grad, lr=1e-3, b1=0.9, b2=0.999):
|
|||
|
|
self.t += 1
|
|||
|
|
self.m = b1 * self.m + (1 - b1) * grad
|
|||
|
|
self.v = b2 * self.v + (1 - b2) * grad**2
|
|||
|
|
m_hat = self.m / (1 - b1**self.t)
|
|||
|
|
v_hat = self.v / (1 - b2**self.t)
|
|||
|
|
self.W -= lr * m_hat / (np.sqrt(v_hat) + EPS)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ── FlintHDVAE ────────────────────────────────────────────────────────────
|
|||
|
|
class FlintHDVAE:
|
|||
|
|
"""
|
|||
|
|
HD-VAE with 550-bit MCDAIN encoder normalisation.
|
|||
|
|
Inverse projection decoder: z(8) → Linear+ReLU(512) → Linear(20).
|
|||
|
|
"""
|
|||
|
|
def __init__(self, input_dim=20, hd_dim=512, latent_dim=8,
|
|||
|
|
beta=0.5, seed=42, use_flint_norm=True):
|
|||
|
|
self.input_dim = input_dim
|
|||
|
|
self.hd_dim = hd_dim
|
|||
|
|
self.latent_dim = latent_dim
|
|||
|
|
self.beta = beta
|
|||
|
|
self.use_flint = use_flint_norm and FLINT_AVAILABLE
|
|||
|
|
|
|||
|
|
rng = np.random.RandomState(seed)
|
|||
|
|
# Fixed random HD projection (encoder side, non-trainable)
|
|||
|
|
self.W_hd = rng.randn(input_dim, hd_dim).astype(np.float64) * np.sqrt(2.0/input_dim)
|
|||
|
|
|
|||
|
|
# Trainable parameters — encoder bottleneck
|
|||
|
|
self.P_mu = AdamParam((hd_dim, latent_dim), seed=seed+1)
|
|||
|
|
self.P_lv = AdamParam((hd_dim, latent_dim), seed=seed+2)
|
|||
|
|
|
|||
|
|
# Trainable parameters — DECODER (inverse projection, THE NEW PIECE)
|
|||
|
|
self.P_dec = AdamParam((latent_dim, hd_dim), seed=seed+3) # z→h_hat
|
|||
|
|
self.P_out = AdamParam((hd_dim, input_dim), seed=seed+4) # h_hat→T1_hat
|
|||
|
|
|
|||
|
|
# Normaliser stats (fitted once)
|
|||
|
|
self._norm_fitted = False
|
|||
|
|
self._norm_mu = np.zeros(input_dim)
|
|||
|
|
self._norm_sd = np.ones(input_dim)
|
|||
|
|
|
|||
|
|
self.train_losses = []
|
|||
|
|
|
|||
|
|
# ── Normalisation ──────────────────────────────────────────────────────
|
|||
|
|
def fit_normaliser(self, X: np.ndarray):
|
|||
|
|
"""Fit normaliser stats from the FULL training set (called once).
|
|||
|
|
For MCDAIN: computes global per-column m/s/g and stores them so that
|
|||
|
|
all subsequent _normalise() calls are deterministic (no batch-dependency).
|
|||
|
|
Falls back to z-score if FLINT unavailable."""
|
|||
|
|
self._norm_mu = X.mean(0)
|
|||
|
|
self._norm_sd = X.std(0) + EPS
|
|||
|
|
if self.use_flint:
|
|||
|
|
# Compute MCDAIN params column-wise on full X, store as fixed stats
|
|||
|
|
X_norm_full = mcdain_550bit(X)
|
|||
|
|
# Store the effective per-column shift/scale as z-score of the MCDAIN output
|
|||
|
|
self._mcdain_mu = X_norm_full.mean(0)
|
|||
|
|
self._mcdain_sd = X_norm_full.std(0) + EPS
|
|||
|
|
# Also store the raw MCDAIN params by fitting a passthrough
|
|||
|
|
self._mcdain_fitted = True
|
|||
|
|
self._X_norm_ref = X_norm_full # kept for diagnostics only (not used in loops)
|
|||
|
|
self._norm_fitted = True
|
|||
|
|
|
|||
|
|
def _normalise(self, X: np.ndarray) -> np.ndarray:
|
|||
|
|
if self.use_flint and self._norm_fitted and hasattr(self, '_mcdain_fitted'):
|
|||
|
|
# Apply MCDAIN then standardise using TRAINING statistics
|
|||
|
|
# This makes normalisation deterministic regardless of batch size
|
|||
|
|
raw = mcdain_550bit(X)
|
|||
|
|
return (raw - self._mcdain_mu) / self._mcdain_sd
|
|||
|
|
return (X - self._norm_mu) / self._norm_sd
|
|||
|
|
|
|||
|
|
# ── Forward pass ──────────────────────────────────────────────────────
|
|||
|
|
def _encode(self, X_norm, rng):
|
|||
|
|
"""X_norm (B,20) → h (B,512) → mu,logvar (B,8) → z (B,8)"""
|
|||
|
|
h = np.maximum(0, X_norm @ self.W_hd) # (B, 512) ReLU
|
|||
|
|
mu = h @ self.P_mu.W # (B, 8)
|
|||
|
|
lv = np.clip(h @ self.P_lv.W, -4, 4) # (B, 8)
|
|||
|
|
eps = rng.randn(*mu.shape)
|
|||
|
|
z = mu + np.exp(0.5 * lv) * eps # reparam
|
|||
|
|
return h, mu, lv, z
|
|||
|
|
|
|||
|
|
def _decode(self, z):
|
|||
|
|
"""z (B,8) → h_hat (B,512) → T1_hat (B,20) — INVERSE PROJECTION"""
|
|||
|
|
h_hat = np.maximum(0, z @ self.P_dec.W) # (B, 512) ReLU
|
|||
|
|
T1_hat = h_hat @ self.P_out.W # (B, 20) linear
|
|||
|
|
return h_hat, T1_hat
|
|||
|
|
|
|||
|
|
# ── Loss ──────────────────────────────────────────────────────────────
|
|||
|
|
def _loss(self, T1_norm, T1_hat, mu, lv):
|
|||
|
|
B = len(T1_norm)
|
|||
|
|
recon = np.mean((T1_hat - T1_norm)**2)
|
|||
|
|
kl = -0.5 * np.mean(1 + lv - mu**2 - np.exp(lv))
|
|||
|
|
total = recon + self.beta * kl
|
|||
|
|
return total, recon, kl
|
|||
|
|
|
|||
|
|
# ── Backward (analytical gradients) ───────────────────────────────────
|
|||
|
|
def _backward(self, T1_norm, T1_hat, h, h_hat, mu, lv, z, lr):
|
|||
|
|
B = len(T1_norm)
|
|||
|
|
|
|||
|
|
# ── Decoder gradients ────────────────────────────────────────────
|
|||
|
|
# dL/dT1_hat = 2*(T1_hat - T1_norm) / (B*D)
|
|||
|
|
dT1 = 2.0 * (T1_hat - T1_norm) / (B * self.input_dim)
|
|||
|
|
|
|||
|
|
# W_out: h_hat.T @ dT1
|
|||
|
|
dW_out = h_hat.T @ dT1 # (512, 20)
|
|||
|
|
self.P_out.step(dW_out, lr)
|
|||
|
|
|
|||
|
|
# Back through ReLU of h_hat
|
|||
|
|
dh_hat = (dT1 @ self.P_out.W.T) * (h_hat > 0) # (B, 512)
|
|||
|
|
|
|||
|
|
# W_dec: z.T @ dh_hat
|
|||
|
|
dW_dec = z.T @ dh_hat # (8, 512)
|
|||
|
|
self.P_dec.step(dW_dec, lr)
|
|||
|
|
|
|||
|
|
# dz from decoder
|
|||
|
|
dz_dec = dh_hat @ self.P_dec.W.T # (B, 8)
|
|||
|
|
|
|||
|
|
# ── KL gradients (standard VAE) ──────────────────────────────────
|
|||
|
|
# dKL/dmu = mu/B; dKL/dlv = 0.5*(exp(lv)-1)/B
|
|||
|
|
dmu_kl = self.beta * mu / B
|
|||
|
|
dlv_kl = self.beta * 0.5 * (np.exp(lv) - 1) / B
|
|||
|
|
|
|||
|
|
# ── Reparameterisation: dz flows back to mu and lv ───────────────
|
|||
|
|
# z = mu + exp(0.5*lv)*eps → dmu = dz, dlv = dz*0.5*z (approx)
|
|||
|
|
dmu = dz_dec + dmu_kl
|
|||
|
|
dlv = dz_dec * 0.5 * (z - mu) + dlv_kl # chain rule
|
|||
|
|
|
|||
|
|
# ── Encoder bottleneck gradients ─────────────────────────────────
|
|||
|
|
dW_mu = h.T @ dmu # (512, 8)
|
|||
|
|
dW_lv = h.T @ dlv
|
|||
|
|
self.P_mu.step(dW_mu, lr)
|
|||
|
|
self.P_lv.step(dW_lv, lr)
|
|||
|
|
# (W_hd is fixed, no gradient needed for it)
|
|||
|
|
|
|||
|
|
# ── Training ──────────────────────────────────────────────────────────
|
|||
|
|
def fit(self, X: np.ndarray, epochs=30, lr=1e-3,
|
|||
|
|
batch_size=256, verbose=True, warmup_frac=0.3):
|
|||
|
|
"""
|
|||
|
|
warmup_frac: fraction of epochs over which beta ramps 0 → self.beta.
|
|||
|
|
Prevents KL from dominating before the decoder learns to reconstruct.
|
|||
|
|
"""
|
|||
|
|
rng = np.random.RandomState(42)
|
|||
|
|
self.fit_normaliser(X) # computes global MCDAIN stats once
|
|||
|
|
X_norm = self._normalise(X) # normalise full dataset once; stable across batches
|
|||
|
|
N = len(X_norm)
|
|||
|
|
target_beta = self.beta
|
|||
|
|
warmup_epochs = max(1, int(epochs * warmup_frac))
|
|||
|
|
|
|||
|
|
for epoch in range(1, epochs + 1):
|
|||
|
|
# KL warmup: ramp beta from 0 to target over first warmup_epochs
|
|||
|
|
if epoch <= warmup_epochs:
|
|||
|
|
self.beta = target_beta * (epoch / warmup_epochs)
|
|||
|
|
else:
|
|||
|
|
self.beta = target_beta
|
|||
|
|
|
|||
|
|
idx = rng.permutation(N)
|
|||
|
|
ep_loss = ep_recon = ep_kl = 0.0
|
|||
|
|
n_batches = 0
|
|||
|
|
for start in range(0, N, batch_size):
|
|||
|
|
bi = idx[start:start + batch_size]
|
|||
|
|
Xb = X_norm[bi] # already normalised with global stats
|
|||
|
|
h, mu, lv, z = self._encode(Xb, rng)
|
|||
|
|
h_hat, T1_hat = self._decode(z)
|
|||
|
|
loss, recon, kl = self._loss(Xb, T1_hat, mu, lv)
|
|||
|
|
self._backward(Xb, T1_hat, h, h_hat, mu, lv, z, lr)
|
|||
|
|
ep_loss += loss; ep_recon += recon; ep_kl += kl
|
|||
|
|
n_batches += 1
|
|||
|
|
|
|||
|
|
ep_loss /= n_batches; ep_recon /= n_batches; ep_kl /= n_batches
|
|||
|
|
self.train_losses.append(ep_loss)
|
|||
|
|
if verbose and (epoch % 5 == 0 or epoch == 1):
|
|||
|
|
# Anti-collapse diagnostic: encode a fixed held-out sample
|
|||
|
|
sample_norm = X_norm[:min(1000, N)]
|
|||
|
|
_, mu_s, _, _ = self._encode(sample_norm, rng)
|
|||
|
|
var_per_dim = mu_s.var(0)
|
|||
|
|
print(f" ep{epoch:3d}/{epochs} beta={self.beta:.3f} "
|
|||
|
|
f"loss={ep_loss:.4f} recon={ep_recon:.4f} kl={ep_kl:.4f} "
|
|||
|
|
f"z_var=[{' '.join(f'{v:.3f}' for v in var_per_dim)}]")
|
|||
|
|
|
|||
|
|
self.beta = target_beta # restore after training
|
|||
|
|
return self
|
|||
|
|
|
|||
|
|
# ── Encode for downstream use ─────────────────────────────────────────
|
|||
|
|
def encode(self, X: np.ndarray) -> np.ndarray:
|
|||
|
|
"""Return deterministic mu (B, latent_dim) for all samples.
|
|||
|
|
Normalisation is deterministic (global MCDAIN stats from fit_normaliser)."""
|
|||
|
|
rng = np.random.RandomState(0)
|
|||
|
|
STEP = 512
|
|||
|
|
mus = []
|
|||
|
|
for s in range(0, len(X), STEP):
|
|||
|
|
Xb = self._normalise(X[s:s+STEP])
|
|||
|
|
_, mu, _, _ = self._encode(Xb, rng)
|
|||
|
|
mus.append(mu)
|
|||
|
|
return np.concatenate(mus)
|
|||
|
|
|
|||
|
|
def reconstruct(self, X: np.ndarray) -> np.ndarray:
|
|||
|
|
"""Returns (T1_hat, X_norm) both in the same normalised space.
|
|||
|
|
Normalisation is deterministic (global MCDAIN stats from fit_normaliser)."""
|
|||
|
|
rng = np.random.RandomState(0)
|
|||
|
|
Xn = self._normalise(X)
|
|||
|
|
STEP = 512
|
|||
|
|
hats = []
|
|||
|
|
for s in range(0, len(Xn), STEP):
|
|||
|
|
_, mu, _, _ = self._encode(Xn[s:s+STEP], rng)
|
|||
|
|
_, T1_hat = self._decode(mu)
|
|||
|
|
hats.append(T1_hat)
|
|||
|
|
return np.concatenate(hats), Xn
|