# Source: DOLPHIN/nautilus_dolphin/dvae/hierarchical_dvae.py
# (560 lines, 26 KiB, Python — repository-viewer header removed)
"""
Hierarchical Disentangled VAE (H-D-VAE) for DOLPHIN
=====================================================
State-of-the-art β-TCVAE with:
- 3-level HIERARCHY of disentangled latent codes
- Real backpropagation (analytical gradients, no random noise)
- Masking for missing tiers (pre-eigenvalue era)
- Curriculum-aware: can freeze/thaw tiers during training
- Spectral features with 512-bit-aware log/ratio encoding
Latent structure:
z0 (dim=4) "macro regime" Tier-0 breadth + time
z1 (dim=8) "eigenstructure" Tier-1 eigenvalues, conditioned on z0
z2 (dim=8) "cross-section" Tier-2 per-asset, conditioned on z0+z1
Total latent dim: 20.
β-TCVAE decomposition (Chen et al. 2018):
KL = MI(x;z) + TC(z) + dim_KL(z)
Loss = Recon + γ·MI + β·TC + λ·dim_KL
Key insight: TC penalises factorial structure violation, encouraging
each dimension of z to capture an INDEPENDENT factor of variation.
With β >> 1 the model is forced to find the minimal sufficient encoding.
"""
import numpy as np
from typing import Tuple, Dict, Optional, List
# ── Constants ──────────────────────────────────────────────────────────────
# Must match corpus_builder.py DIMS = [8, 20, 50, 25, 8]
TIER0_DIM = 8    # breadth + time (macro)
TIER1_DIM = 20   # eigenvalue block
TIER2_DIM = 50   # per-asset cross-section
TIER3_DIM = 25  # ExF macro
TIER4_DIM = 8  # EsoF
# Tier offsets into the 111-dim feature vector (cumulative sums of the dims above)
T_OFF = [0, 8, 28, 78, 103]
# Latent code widths per hierarchy level
Z0_DIM = 4  # macro regime
Z1_DIM = 8  # eigenstructure
Z2_DIM = 8  # cross-section + ExF
Z3_DIM = 4  # esoteric
Z_TOTAL = Z0_DIM + Z1_DIM + Z2_DIM + Z3_DIM  # 24
EPS = 1e-8  # numerical floor used throughout (div-by-zero / log(0) guard)
# ── Numerics ───────────────────────────────────────────────────────────────
def _relu(v):
    """Elementwise rectified linear unit, max(0, v)."""
    return np.maximum(0, v)


def _drelu(v):
    """ReLU derivative w.r.t. its pre-activation: 0/1 indicator, same dtype as v."""
    return (v > 0).astype(v.dtype)


def _tanh(v):
    """Elementwise hyperbolic tangent."""
    return np.tanh(v)


def _dtanh(v):
    """tanh derivative: 1 - tanh(v)^2."""
    return 1.0 - np.tanh(v) ** 2


def _softplus(v):
    """Softplus log(1 + e^v); input clipped to [-20, 20] so exp() cannot overflow."""
    return np.log1p(np.exp(np.clip(v, -20, 20)))
# ── Linear layer (weights + bias) with gradient ───────────────────────────
class Linear:
    """Affine layer y = xW + b with hand-rolled Adam optimiser state.

    forward() caches its input so backward() can form dW; step() applies
    one bias-corrected Adam update in place.
    """

    def __init__(self, in_dim: int, out_dim: int, seed: int = 42):
        rng = np.random.RandomState(seed)
        # He initialisation: sqrt(2 / fan_in), suited to ReLU trunks.
        fan_scale = np.sqrt(2.0 / in_dim)
        self.W = rng.randn(in_dim, out_dim).astype(np.float32) * fan_scale
        self.b = np.zeros(out_dim, dtype=np.float32)
        # Adam first/second moment accumulators for W and b.
        self.mW, self.vW = np.zeros_like(self.W), np.zeros_like(self.W)
        self.mb, self.vb = np.zeros_like(self.b), np.zeros_like(self.b)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Affine map; caches x for the backward pass."""
        self._x = x
        return x @ self.W + self.b

    def backward(self, dout: np.ndarray) -> np.ndarray:
        """Store dW/db from the cached input and return the gradient w.r.t. x."""
        self.dW = self._x.T @ dout
        self.db = dout.sum(axis=0)
        return dout @ self.W.T

    def step(self, lr: float, t: int, beta1=0.9, beta2=0.999):
        """One Adam update with bias correction at (1-based) time step t."""
        corr1 = 1 - beta1 ** t
        corr2 = 1 - beta2 ** t
        params = ((self.W, self.mW, self.vW, self.dW),
                  (self.b, self.mb, self.vb, self.db))
        for param, m, v, grad in params:
            m[:] = beta1 * m + (1 - beta1) * grad
            v[:] = beta2 * v + (1 - beta2) * grad ** 2
            param -= lr * (m / corr1) / (np.sqrt(v / corr2) + 1e-8)
# ── Simple 2-layer MLP with ReLU ──────────────────────────────────────────
class MLP:
    """Stack of Linear layers with ReLU between them.

    forward(x, final_linear=True) keeps the last layer linear; with
    final_linear=False the last layer also gets a ReLU.  The pre-activation
    of every ReLU'd layer is cached in _acts (None marks a layer that
    stayed linear) so backward() knows exactly where to gate gradients.
    """
    def __init__(self, dims: List[int], seed: int = 42):
        self.layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            self.layers.append(Linear(d_in, d_out, seed=seed + i))
        # Pre-activation cache, parallel to self.layers.
        self._acts = []

    def forward(self, x: np.ndarray, final_linear=True) -> np.ndarray:
        self._acts = []
        h = x
        for i, layer in enumerate(self.layers):
            h = layer.forward(h)
            if i < len(self.layers) - 1 or not final_linear:
                self._acts.append(h.copy())
                h = _relu(h)
            else:
                self._acts.append(None)  # final layer stayed linear
        return h

    def backward(self, dout: np.ndarray) -> np.ndarray:
        dh = dout
        for i in range(len(self.layers) - 1, -1, -1):
            # BUGFIX: gate through ReLU wherever one was actually applied in
            # forward().  The old condition `i < len(self.layers) - 1` skipped
            # the final layer's ReLU whenever forward ran with
            # final_linear=False, producing wrong gradients in that mode.
            if self._acts[i] is not None:
                dh = dh * _drelu(self._acts[i])
            dh = self.layers[i].backward(dh)
        return dh

    def step(self, lr: float, t: int):
        for layer in self.layers:
            layer.step(lr, t)
# ── Encoder: MLP → (mu, logvar) ───────────────────────────────────────────
class VAEEncoder:
    """Encodes input (optionally concatenated with conditioning z) to (mu, logvar)."""

    def __init__(self, in_dim: int, hidden: int, z_dim: int, seed: int = 0):
        self.mlp = MLP([in_dim, hidden, hidden], seed=seed)
        self.mu_head = Linear(hidden, z_dim, seed=seed + 100)
        self.lv_head = Linear(hidden, z_dim, seed=seed + 200)
        self.z_dim = z_dim

    def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # Shared trunk: MLP (linear final layer) followed by one extra ReLU.
        trunk = _relu(self.mlp.forward(x))  # (B, hidden)
        self._h = trunk
        mu = self.mu_head.forward(trunk)
        # Clip logvar so exp() stays finite and the KL remains well-behaved.
        logvar = np.clip(self.lv_head.forward(trunk), -10, 4)
        return mu, logvar

    def backward(self, dmu: np.ndarray, dlv: np.ndarray) -> np.ndarray:
        # Merge both head gradients, gate through the extra ReLU
        # (trunk > 0 iff its pre-activation > 0), then descend the trunk.
        dtrunk = self.mu_head.backward(dmu) + self.lv_head.backward(dlv)
        return self.mlp.backward(dtrunk * _drelu(self._h))

    def step(self, lr: float, t: int):
        for part in (self.mlp, self.mu_head, self.lv_head):
            part.step(lr, t)
# ── Decoder: z → x_hat ────────────────────────────────────────────────────
class VAEDecoder:
    """Maps a latent code z back to a tier reconstruction x_hat."""

    def __init__(self, z_dim: int, hidden: int, out_dim: int, seed: int = 0):
        self.mlp = MLP([z_dim, hidden, hidden, out_dim], seed=seed + 300)
        self.out_dim = out_dim

    def forward(self, z: np.ndarray) -> np.ndarray:
        # Final layer stays linear: reconstructions live in normalised space.
        return self.mlp.forward(z, final_linear=True)

    def backward(self, dout: np.ndarray) -> np.ndarray:
        # Pure delegation: the MLP owns all trainable state.
        return self.mlp.backward(dout)

    def step(self, lr: float, t: int):
        self.mlp.step(lr, t)
# ── β-TCVAE loss (Chen et al. 2018, proper dataset-size corrected form) ───
def btcvae_kl(mu: np.ndarray, logvar: np.ndarray,
              N_dataset: int, beta: float = 4.0,
              gamma: float = 1.0, lam: float = 1.0
              ) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Decompose KL into: MI + beta*TC + lambda*dim_KL
    Returns (total_kl, dmu, dlogvar).
    MI = E_q[log q(z|x)] - E_q[log q(z)]
    TC = E_q[log q(z)] - E_q[sum_j log q(z_j)]
    dKL = E_q[sum_j log q(z_j) - log p(z_j)]
    Minibatch-weighted estimator numerically stable.

    mu, logvar: (B, D) diagonal-Gaussian posterior parameters.
    NOTE(review): the returned gradients are NOT the gradients of the
    decomposed estimator — they are a scaled closed-form standard-KL
    gradient (see bottom of function).  The decomposition is therefore
    used for monitoring/weighting only; confirm this is intentional.
    """
    B, D = mu.shape
    var = np.exp(logvar)
    # Per-sample KL (closed form vs N(0,1) prior)
    # NOTE(review): kl_per_dim is computed but never used below — presumably
    # kept for debugging; confirm before removing.
    kl_per_dim = 0.5 * (mu**2 + var - logvar - 1) # (B, D)
    # Log q(z|x_i) for sample z_i ~ q(z|x_i) (diagonal Gaussian)
    # We use the mean as the sample point (reparametrization noise ≈ 0 for gradient estimation)
    # log q(z_i | x_i) = sum_d -0.5*(log(2π) + logvar + (z-mu)^2/var)
    # At z = mu: (z-mu)=0, so log q(z|x) = -0.5*sum(log(2π) + logvar)
    log_q_z_given_x = -0.5 * (D * np.log(2 * np.pi) + logvar.sum(axis=1)) # (B,)
    # Minibatch estimate of log q(z): E_data[sum_x log q(z|x)] / N
    # log q(z_i) ≈ log(1/N * sum_j q(z_i | x_j))
    # Using pairwise: log q(z_i) ≈ logsumexp over j of log q(z_i | x_j) - log(N)
    # NOTE(review): Chen et al.'s minibatch-weighted estimator subtracts
    # log(N·B); only log(N_dataset) is subtracted here, shifting MI/TC by a
    # constant ≈ log B.  Affects reported values (and the max(tc,0) clamp),
    # not the gradients below — confirm intentional.
    z_sample = mu # (B, D) — use mean as point estimate
    # (B_i, B_j, D) pairwise differences between each sample and each posterior mean
    diff = z_sample[:, None, :] - mu[None, :, :] # (B, B, D)
    log_q_z_pair = -0.5 * (
        D * np.log(2 * np.pi)
        + logvar[None, :, :].sum(axis=2)
        + (diff**2 / (var[None, :, :] + EPS)).sum(axis=2)
    ) # (B, B)
    # Stabilised logsumexp over the batch axis j.
    log_q_z = np.log(np.sum(np.exp(log_q_z_pair - log_q_z_pair.max(axis=1, keepdims=True)), axis=1) + EPS) \
        + log_q_z_pair.max(axis=1) - np.log(N_dataset + EPS) # (B,)
    # log prod_j q(z_j): per-dimension marginal, same logsumexp construction.
    log_q_z_product = np.zeros(B, dtype=np.float64)
    for j in range(D):
        diff_j = z_sample[:, j:j+1] - mu[:, j:j+1].T # (B, B)
        log_q_zj = -0.5 * (np.log(2 * np.pi) + logvar[None, :, j] + diff_j**2 / (var[None, :, j] + EPS)) # (B,B)
        log_q_zj_marginal = np.log(np.sum(np.exp(log_q_zj - log_q_zj.max(axis=1, keepdims=True)), axis=1) + EPS) \
            + log_q_zj.max(axis=1) - np.log(N_dataset + EPS) # (B,)
        log_q_z_product += log_q_zj_marginal
    # log p(z) under the standard-normal prior, evaluated at z = mu.
    log_p_z = -0.5 * (D * np.log(2 * np.pi) + (z_sample**2).sum(axis=1)) # (B,)
    mi_loss = np.mean(log_q_z_given_x - log_q_z)
    tc_loss = np.mean(log_q_z - log_q_z_product)
    dkl_loss = np.mean(log_q_z_product - log_p_z)
    # TC is clamped at 0 so a (biased) negative estimate cannot reward the model.
    total_kl = gamma * mi_loss + beta * max(tc_loss, 0) + lam * dkl_loss
    # Gradients w.r.t. mu and logvar: use standard closed-form KL gradient
    # ∂KL/∂mu = mu, ∂KL/∂logvar = 0.5*(exp(logvar) - 1)
    # scaled by the mean loss weight and batch size.
    scale = (gamma + beta + lam) / (3.0 * B + EPS)
    dmu = scale * mu
    dlogvar = scale * 0.5 * (var - 1.0)
    return float(total_kl), dmu.astype(np.float32), dlogvar.astype(np.float32)
# ── Reparametrization ─────────────────────────────────────────────────────
def reparametrize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) drawn from `rng`.

    logvar is log(sigma^2), hence sigma = exp(0.5 * logvar).  The float32
    cast on eps keeps the noise in the network's working dtype.
    """
    noise = rng.randn(*mu.shape).astype(np.float32)
    return mu + noise * np.exp(0.5 * logvar)
# ── Hierarchical D-VAE ────────────────────────────────────────────────────
class HierarchicalDVAE:
    """
    4-level Hierarchical Disentangled VAE (111-dim input, 24-dim latent).
    Latent hierarchy:
        z0 (4) = enc0(T0)                 macro regime (breadth + time)
        z1 (8) = enc1(T1 + T4 + z0)       eigenstructure + esoteric, cond on z0
        z2 (8) = enc2(T2 + T3 + z0 + z1)  cross-section + ExF, cond on z0,z1
        z3 (4) = enc3(T3 + T4 + z0)       ExF+EsoF regime, cond on z0
    Decoding:
        T0_hat = dec0(z0)
        T1_hat = dec1(z0, z1)
        T2_hat = dec2(z0, z1, z2)
        T3_hat = dec3(z0, z3)
        T4_hat = dec4(z0)    (EsoF is deterministic, serves as auxiliary)
    Training phases:
        0: enc0/dec0 only    full 500K+ corpus (even pre-eigen NG1/NG2)
        1: + enc1/dec1       eigen-tier (NG3+)
        2: + enc2/dec2       pricing + ExF (NG3+ with pricing)
        3: + enc3/dec3       ExF regime
        4: joint fine-tune all
    """
    def __init__(self, hidden: int = 64, beta: float = 4.0,
                 gamma: float = 1.0, lam: float = 1.0, seed: int = 42):
        """Build encoders/decoders for all four latent levels.

        hidden: trunk width for levels 0-2 (levels 3/4 use hidden//2).
        beta/gamma/lam: β-TCVAE loss weights forwarded to btcvae_kl.
        seed: base RNG seed; each sub-network gets a fixed offset.
        """
        self.hidden = hidden
        self.beta = beta
        self.gamma = gamma
        self.lam = lam
        self.seed = seed
        # Encoders (each conditioned on upstream z)
        self.enc0 = VAEEncoder(TIER0_DIM, hidden, Z0_DIM, seed=seed)
        self.enc1 = VAEEncoder(TIER1_DIM + TIER4_DIM + Z0_DIM, hidden, Z1_DIM, seed=seed+10)
        self.enc2 = VAEEncoder(TIER2_DIM + TIER3_DIM + Z0_DIM + Z1_DIM, hidden, Z2_DIM, seed=seed+20)
        self.enc3 = VAEEncoder(TIER3_DIM + TIER4_DIM + Z0_DIM, hidden//2, Z3_DIM, seed=seed+30)
        # Decoders (input = concatenation of the conditioning latents)
        self.dec0 = VAEDecoder(Z0_DIM, hidden, TIER0_DIM, seed=seed)
        self.dec1 = VAEDecoder(Z0_DIM + Z1_DIM, hidden, TIER1_DIM, seed=seed+10)
        self.dec2 = VAEDecoder(Z0_DIM + Z1_DIM + Z2_DIM, hidden, TIER2_DIM, seed=seed+20)
        self.dec3 = VAEDecoder(Z0_DIM + Z3_DIM, hidden//2, TIER3_DIM, seed=seed+30)
        self.dec4 = VAEDecoder(Z0_DIM, hidden//2, TIER4_DIM, seed=seed+40)
        # Normalisation statistics (identity until fit_normaliser is called)
        self._mu_t0 = np.zeros(TIER0_DIM, dtype=np.float32)
        self._sd_t0 = np.ones(TIER0_DIM, dtype=np.float32)
        self._mu_t1 = np.zeros(TIER1_DIM, dtype=np.float32)
        self._sd_t1 = np.ones(TIER1_DIM, dtype=np.float32)
        self._mu_t2 = np.zeros(TIER2_DIM, dtype=np.float32)
        self._sd_t2 = np.ones(TIER2_DIM, dtype=np.float32)
        self._mu_t3 = np.zeros(TIER3_DIM, dtype=np.float32)
        self._sd_t3 = np.ones(TIER3_DIM, dtype=np.float32)
        self._mu_t4 = np.zeros(TIER4_DIM, dtype=np.float32)
        self._sd_t4 = np.ones(TIER4_DIM, dtype=np.float32)
        self.step_t = 0  # Adam time step (global, incremented once per batch)
        self.train_losses: List[Dict] = []
# ── Normalisation ──────────────────────────────────────────────────────
def fit_normaliser(self, X: np.ndarray, mask: np.ndarray):
"""Fit per-tier normalisation on a representative sample."""
def _fit(rows, dim):
if len(rows) < 2:
return np.zeros(dim, dtype=np.float32), np.ones(dim, dtype=np.float32)
return rows.mean(0).astype(np.float32), (rows.std(0) + EPS).astype(np.float32)
self._mu_t0, self._sd_t0 = _fit(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM], TIER0_DIM)
idx1 = mask[:, 1]
self._mu_t1, self._sd_t1 = _fit(X[idx1, T_OFF[1]:T_OFF[1]+TIER1_DIM], TIER1_DIM)
idx2 = mask[:, 2]
self._mu_t2, self._sd_t2 = _fit(X[idx2, T_OFF[2]:T_OFF[2]+TIER2_DIM], TIER2_DIM)
idx3 = mask[:, 3]
self._mu_t3, self._sd_t3 = _fit(X[idx3, T_OFF[3]:T_OFF[3]+TIER3_DIM], TIER3_DIM)
self._mu_t4, self._sd_t4 = _fit(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM], TIER4_DIM)
def _norm(self, x, mu, sd): return (x - mu) / sd
def _norm0(self, x): return self._norm(x, self._mu_t0, self._sd_t0)
def _norm1(self, x): return self._norm(x, self._mu_t1, self._sd_t1)
def _norm2(self, x): return self._norm(x, self._mu_t2, self._sd_t2)
def _norm3(self, x): return self._norm(x, self._mu_t3, self._sd_t3)
def _norm4(self, x): return self._norm(x, self._mu_t4, self._sd_t4)
# ── Forward pass ──────────────────────────────────────────────────────
def _split(self, X):
"""Split 111-dim row into normalised tier vectors."""
t0 = self._norm0(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM])
t1 = self._norm1(X[:, T_OFF[1]:T_OFF[1]+TIER1_DIM])
t2 = self._norm2(X[:, T_OFF[2]:T_OFF[2]+TIER2_DIM])
t3 = self._norm3(X[:, T_OFF[3]:T_OFF[3]+TIER3_DIM])
t4 = self._norm4(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM])
return t0, t1, t2, t3, t4
def encode(self, X: np.ndarray, mask: np.ndarray, rng) -> Dict:
t0, t1, t2, t3, t4 = self._split(X)
mu0, lv0 = self.enc0.forward(t0)
z0 = reparametrize(mu0, lv0, rng)
mu1, lv1 = self.enc1.forward(np.concatenate([t1, t4, z0], axis=1))
z1 = reparametrize(mu1, lv1, rng)
mu2, lv2 = self.enc2.forward(np.concatenate([t2, t3, z0, z1], axis=1))
z2 = reparametrize(mu2, lv2, rng)
mu3, lv3 = self.enc3.forward(np.concatenate([t3, t4, z0], axis=1))
z3 = reparametrize(mu3, lv3, rng)
return dict(t0=t0, t1=t1, t2=t2, t3=t3, t4=t4,
mu0=mu0, lv0=lv0, z0=z0,
mu1=mu1, lv1=lv1, z1=z1,
mu2=mu2, lv2=lv2, z2=z2,
mu3=mu3, lv3=lv3, z3=z3)
def decode(self, enc: Dict) -> Dict:
z0, z1, z2, z3 = enc['z0'], enc['z1'], enc['z2'], enc['z3']
x0_hat = self.dec0.forward(z0)
x1_hat = self.dec1.forward(np.concatenate([z0, z1], axis=1))
x2_hat = self.dec2.forward(np.concatenate([z0, z1, z2], axis=1))
x3_hat = self.dec3.forward(np.concatenate([z0, z3], axis=1))
x4_hat = self.dec4.forward(z0) # EsoF decoded from macro only
return dict(x0_hat=x0_hat, x1_hat=x1_hat, x2_hat=x2_hat,
x3_hat=x3_hat, x4_hat=x4_hat)
    # ── Loss and gradients ─────────────────────────────────────────────────
    def loss_and_grads(self, enc: Dict, dec: Dict, mask: np.ndarray,
                       N_dataset: int, phase: int) -> Dict:
        """Compute total loss and all gradients via chain rule.

        Masked MSE per tier plus a β-TCVAE KL term per latent level.
        NOTE(review): tiers 3/4 (x3_hat, x4_hat, z3) never enter this loss,
        so enc3/dec3/dec4 get no training signal here even in phase ≥ 3 —
        confirm whether phase-3 training happens elsewhere.
        """
        t0, t1, t2 = enc['t0'], enc['t1'], enc['t2']
        x0_hat = dec['x0_hat']
        x1_hat = dec['x1_hat']
        x2_hat = dec['x2_hat']
        B = len(t0)
        # Mask weights: per-sample, per-tier (0 = tier missing for that row)
        w1 = mask[:, 1].astype(np.float32)[:, None] # (B,1)
        w2 = mask[:, 2].astype(np.float32)[:, None]
        # Reconstruction losses (mean over B·dim, masked rows contribute 0)
        recon0 = np.mean((x0_hat - t0)**2)
        recon1 = np.mean(w1 * (x1_hat - t1)**2)
        recon2 = np.mean(w2 * (x2_hat - t2)**2)
        # Reconstruction gradients (d mean-MSE / d x_hat)
        dr0 = 2 * (x0_hat - t0) / (B * TIER0_DIM)
        dr1 = 2 * w1 * (x1_hat - t1) / (B * TIER1_DIM)
        dr2 = 2 * w2 * (x2_hat - t2) / (B * TIER2_DIM)
        # KL losses (one β-TCVAE decomposition per latent level)
        kl0, dmu0_kl, dlv0_kl = btcvae_kl(enc['mu0'], enc['lv0'], N_dataset, self.beta, self.gamma, self.lam)
        kl1, dmu1_kl, dlv1_kl = btcvae_kl(enc['mu1'], enc['lv1'], N_dataset, self.beta, self.gamma, self.lam)
        kl2, dmu2_kl, dlv2_kl = btcvae_kl(enc['mu2'], enc['lv2'], N_dataset, self.beta, self.gamma, self.lam)
        # Phase-based scaling: only activate eigen/pricing in correct phase.
        # Gradients are zeroed IN PLACE so backward() sees all-zero arrays.
        if phase < 1:
            kl1 = recon1 = 0.0; dmu1_kl[:] = 0; dlv1_kl[:] = 0; dr1[:] = 0
        if phase < 2:
            kl2 = recon2 = 0.0; dmu2_kl[:] = 0; dlv2_kl[:] = 0; dr2[:] = 0
        total = recon0 + recon1 + recon2 + kl0 + kl1 + kl2
        return dict(
            total=total, recon0=recon0, recon1=recon1, recon2=recon2,
            kl0=kl0, kl1=kl1, kl2=kl2,
            dr0=dr0, dr1=dr1, dr2=dr2,
            dmu0_kl=dmu0_kl, dlv0_kl=dlv0_kl,
            dmu1_kl=dmu1_kl, dlv1_kl=dlv1_kl,
            dmu2_kl=dmu2_kl, dlv2_kl=dlv2_kl,
        )
    def backward(self, enc: Dict, dec: Dict, grads: Dict, phase: int, lr: float):
        """Backpropagate through all encoders and decoders, then Adam step.

        Gradient flow mirrors the conditioning graph: each decoder/encoder
        that consumed a latent feeds gradient back into it, and that sum is
        routed into the producing encoder's (mu, logvar) heads.
        NOTE(review): reparametrization gradients are approximated — dz/dlogvar
        is taken as 0.5·dz (exact form is 0.5·eps·std·dz), and dlv0/dlv1 omit
        cross-terms that dmu0/dmu1 include; confirm these simplifications are
        intentional.  enc3/dec3/dec4 are never stepped here (see loss).
        """
        z0, z1, z2 = enc['z0'], enc['z1'], enc['z2']
        t = self.step_t
        # ── Decoder 2 (z0+z1+z2 → x2) ──────────────────────────────────
        if phase >= 2:
            dz_all_2 = self.dec2.backward(grads['dr2']) # (B, Z0+Z1+Z2)
            dz0_from_d2 = dz_all_2[:, :Z0_DIM]
            dz1_from_d2 = dz_all_2[:, Z0_DIM:Z0_DIM + Z1_DIM]
            dz2_from_d2 = dz_all_2[:, Z0_DIM + Z1_DIM:]
        else:
            dz0_from_d2 = np.zeros_like(z0)
            dz1_from_d2 = np.zeros_like(z1)
            dz2_from_d2 = np.zeros_like(z2)
            # Zero-grad pass keeps dec2's dW/db defined; dec2.step is not
            # called in this branch, so this looks defensive — confirm.
            self.dec2.backward(np.zeros_like(grads['dr2']))
        # ── Encoder 2 ─────────────────────────────────────────────────
        # enc2 input = concat(T2, T3, z0, z1) dims = TIER2+TIER3+Z0+Z1
        if phase >= 2:
            dmu2 = dz2_from_d2 + grads['dmu2_kl']
            dlv2 = dz2_from_d2 * 0.5 + grads['dlv2_kl']
            dinp2 = self.enc2.backward(dmu2, dlv2) # (B, T2+T3+Z0+Z1)
            _off = TIER2_DIM + TIER3_DIM
            dz0_from_e2 = dinp2[:, _off:_off + Z0_DIM]
            dz1_from_e2 = dinp2[:, _off + Z0_DIM:]
            self.dec2.step(lr, t)
            self.enc2.step(lr, t)
        else:
            dz0_from_e2 = np.zeros_like(z0)
            dz1_from_e2 = np.zeros_like(z1)
        # ── Decoder 1 (z0+z1 → x1) ──────────────────────────────────────
        if phase >= 1:
            dz_all_1 = self.dec1.backward(grads['dr1']) # (B, Z0+Z1)
            dz0_from_d1 = dz_all_1[:, :Z0_DIM]
            dz1_from_d1 = dz_all_1[:, Z0_DIM:]
        else:
            dz0_from_d1 = np.zeros_like(z0)
            dz1_from_d1 = np.zeros_like(z1)
            # Same defensive zero-grad pass as dec2 above.
            self.dec1.backward(np.zeros_like(grads['dr1']))
        # ── Encoder 1 ─────────────────────────────────────────────────
        # enc1 input = concat(T1, T4, z0) dims = TIER1+TIER4+Z0
        if phase >= 1:
            # z1 gradient gathers its decoder plus everything downstream
            # (dec2 and enc2 both consumed z1).
            dmu1 = dz1_from_d1 + dz1_from_d2 + dz1_from_e2 + grads['dmu1_kl']
            dlv1 = dz1_from_d1 * 0.5 + grads['dlv1_kl']
            dinp1 = self.enc1.backward(dmu1, dlv1) # (B, T1+T4+Z0)
            dz0_from_e1 = dinp1[:, TIER1_DIM + TIER4_DIM:] # correct: skip T1+T4
            self.dec1.step(lr, t)
            self.enc1.step(lr, t)
        else:
            dz0_from_e1 = np.zeros_like(z0)
        # ── Decoder 0 (z0 → x0) ─────────────────────────────────────────
        dz0_from_d0 = self.dec0.backward(grads['dr0']) # (B, Z0)
        # ── Encoder 0 ────────────────────────────────────────────────────
        # z0 gradient gathers every consumer: its own decoder, both deeper
        # decoders, and both deeper encoders (conditioning inputs), plus KL.
        dmu0 = dz0_from_d0 + dz0_from_d1 + dz0_from_d2 + dz0_from_e1 + dz0_from_e2 + grads['dmu0_kl']
        dlv0 = dz0_from_d0 * 0.5 + grads['dlv0_kl']
        self.enc0.backward(dmu0, dlv0)
        self.enc0.step(lr, t)
        self.dec0.step(lr, t)
# ── Training ──────────────────────────────────────────────────────────
def train_epoch(self, X: np.ndarray, mask: np.ndarray,
lr: float, batch_size: int, phase: int, rng) -> Dict:
N = len(X)
idx = rng.permutation(N)
epoch_stats = dict(total=0, recon0=0, recon1=0, recon2=0, kl0=0, kl1=0, kl2=0, n=0)
for start in range(0, N, batch_size):
bi = idx[start:start + batch_size]
Xb = X[bi]
mb = mask[bi]
self.step_t += 1 # Adam global step (cumulative across all epochs)
enc_out = self.encode(Xb, mb, rng)
dec_out = self.decode(enc_out)
grads = self.loss_and_grads(enc_out, dec_out, mb, N, phase)
self.backward(enc_out, dec_out, grads, phase, lr)
for k in ['total', 'recon0', 'recon1', 'recon2', 'kl0', 'kl1', 'kl2']:
epoch_stats[k] += float(grads[k]) * len(bi)
epoch_stats['n'] += len(bi)
return {k: v / max(epoch_stats['n'], 1) for k, v in epoch_stats.items() if k != 'n'}
# ── Inference ─────────────────────────────────────────────────────────
def get_latents(self, X: np.ndarray, mask: np.ndarray) -> Dict[str, np.ndarray]:
"""Return disentangled latent codes for analysis."""
rng = np.random.RandomState(0) # deterministic for inference
enc = self.encode(X, mask, rng)
return {
'z0_macro': enc['mu0'], # (N, 4) deterministic mean
'z1_eigen': enc['mu1'], # (N, 8)
'z2_xsection': enc['mu2'], # (N, 8)
'z_all': np.concatenate([enc['mu0'], enc['mu1'], enc['mu2']], axis=1), # (N, 20)
'sigma0': np.exp(0.5 * enc['lv0']), # uncertainty
'sigma1': np.exp(0.5 * enc['lv1']),
'sigma2': np.exp(0.5 * enc['lv2']),
}
# ── Persistence ───────────────────────────────────────────────────────
def save(self, path: str):
data = {
'norm': dict(mu_t0=self._mu_t0, sd_t0=self._sd_t0,
mu_t1=self._mu_t1, sd_t1=self._sd_t1,
mu_t2=self._mu_t2, sd_t2=self._sd_t2),
'step_t': self.step_t,
'losses': self.train_losses,
}
# Save all weight matrices
for name, enc in [('enc0', self.enc0), ('enc1', self.enc1), ('enc2', self.enc2)]:
for i, layer in enumerate(enc.mlp.layers):
data[f'{name}_mlp{i}_W'] = layer.W
data[f'{name}_mlp{i}_b'] = layer.b
data[f'{name}_mu_W'] = enc.mu_head.W; data[f'{name}_mu_b'] = enc.mu_head.b
data[f'{name}_lv_W'] = enc.lv_head.W; data[f'{name}_lv_b'] = enc.lv_head.b
for name, dec in [('dec0', self.dec0), ('dec1', self.dec1), ('dec2', self.dec2)]:
for i, layer in enumerate(dec.mlp.layers):
data[f'{name}_mlp{i}_W'] = layer.W
data[f'{name}_mlp{i}_b'] = layer.b
np.savez_compressed(path, **data)
print(f"Model saved: {path}.npz")