DOLPHIN/nautilus_dolphin/dvae/hierarchical_dvae.py
hjnormey 01c19662cb initial: import DOLPHIN baseline 2026-04-21 from dolphinng5_predict working tree
Includes core prod + GREEN/BLUE subsystems:
- prod/ (BLUE harness, configs, scripts, docs)
- nautilus_dolphin/ (GREEN Nautilus-native impl + dvae/ preserved)
- adaptive_exit/ (AEM engine + models/bucket_assignments.pkl)
- Observability/ (EsoF advisor, TUI, dashboards)
- external_factors/ (EsoF producer)
- mc_forewarning_qlabs_fork/ (MC regime/envelope)

Excludes runtime caches, logs, backups, and reproducible artifacts per .gitignore.
2026-04-21 16:58:38 +02:00

"""
Hierarchical Disentangled VAE (H-D-VAE) for DOLPHIN
=====================================================
State-of-the-art β-TCVAE with:
- 4-level HIERARCHY of disentangled latent codes
- Real backpropagation (analytical gradients, no random noise)
- Masking for missing tiers (pre-eigenvalue era)
- Curriculum-aware: can freeze/thaw tiers during training
- Spectral features with 512-bit-aware log/ratio encoding
Latent structure:
z0 (dim=4) "macro regime"    ← Tier-0 breadth + time
z1 (dim=8) "eigenstructure"  ← Tier-1 eigenvalues + Tier-4 EsoF, conditioned on z0
z2 (dim=8) "cross-section"   ← Tier-2 per-asset + Tier-3 ExF, conditioned on z0, z1
z3 (dim=4) "ExF/EsoF regime" ← Tier-3/4 macro factors, conditioned on z0
Total latent dim: 24.
β-TCVAE decomposition (Chen et al. 2018):
KL = MI(x;z) + TC(z) + dim_KL(z)
Loss = Recon + γ·MI + β·TC + λ·dim_KL
Key insight: TC (total correlation) penalises deviation from a factorised
aggregate posterior, encouraging each dimension of z to capture an
INDEPENDENT factor of variation. With β >> 1 the model is forced to find
the minimal sufficient encoding.
"""
import numpy as np
from typing import Tuple, Dict, Optional, List
# ── Constants ──────────────────────────────────────────────────────────────
# Must match corpus_builder.py DIMS = [8, 20, 50, 25, 8]
TIER0_DIM = 8
TIER1_DIM = 20
TIER2_DIM = 50
TIER3_DIM = 25 # ExF macro
TIER4_DIM = 8 # EsoF
# Tier offsets into 111-dim feature vector
T_OFF = [0, 8, 28, 78, 103]
Z0_DIM = 4 # macro regime
Z1_DIM = 8 # eigenstructure
Z2_DIM = 8 # cross-section + ExF
Z3_DIM = 4 # esoteric
Z_TOTAL = Z0_DIM + Z1_DIM + Z2_DIM + Z3_DIM # 24
EPS = 1e-8
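# Slicing example (illustrative): for a feature matrix X of shape (N, 111),
# Tier-1 eigen features are X[:, T_OFF[1]:T_OFF[1] + TIER1_DIM] (columns 8..27)
# and Tier-4 EsoF is X[:, T_OFF[4]:T_OFF[4] + TIER4_DIM] (columns 103..110).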
# ── Numerics ───────────────────────────────────────────────────────────────
def _relu(x): return np.maximum(0, x)
def _drelu(x): return (x > 0).astype(x.dtype)
def _tanh(x): return np.tanh(x)
def _dtanh(x): return 1.0 - np.tanh(x)**2
def _softplus(x): return np.log1p(np.exp(np.clip(x, -20, 20)))
# ── Linear layer (weights + bias) with gradient ───────────────────────────
class Linear:
"""Affine layer with Adam optimiser state."""
def __init__(self, in_dim: int, out_dim: int, seed: int = 42):
rng = np.random.RandomState(seed)
scale = np.sqrt(2.0 / in_dim)
self.W = rng.randn(in_dim, out_dim).astype(np.float32) * scale
self.b = np.zeros(out_dim, dtype=np.float32)
# Adam state
self.mW = np.zeros_like(self.W)
self.vW = np.zeros_like(self.W)
self.mb = np.zeros_like(self.b)
self.vb = np.zeros_like(self.b)
def forward(self, x: np.ndarray) -> np.ndarray:
self._x = x
return x @ self.W + self.b
def backward(self, dout: np.ndarray) -> np.ndarray:
self.dW = self._x.T @ dout
self.db = dout.sum(axis=0)
return dout @ self.W.T
def step(self, lr: float, t: int, beta1=0.9, beta2=0.999):
bc1 = 1 - beta1**t
bc2 = 1 - beta2**t
for (p, m, v, g) in [(self.W, self.mW, self.vW, self.dW),
(self.b, self.mb, self.vb, self.db)]:
m[:] = beta1 * m + (1 - beta1) * g
v[:] = beta2 * v + (1 - beta2) * g**2
p -= lr * (m / bc1) / (np.sqrt(v / bc2) + 1e-8)
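# Usage sketch for Linear (hypothetical shapes; not executed at import):
#     layer = Linear(in_dim=8, out_dim=4, seed=0)
#     x = np.random.randn(16, 8).astype(np.float32)
#     y = layer.forward(x)                   # (16, 4); caches x for backward
#     dx = layer.backward(np.ones_like(y))   # fills layer.dW / layer.db, returns dL/dx
#     layer.step(lr=1e-3, t=1)               # bias-corrected Adam update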
# ── Simple 2-layer MLP with ReLU ──────────────────────────────────────────
class MLP:
def __init__(self, dims: List[int], seed: int = 42):
self.layers = []
for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
self.layers.append(Linear(d_in, d_out, seed=seed + i))
# Activation cache
self._acts = []
def forward(self, x: np.ndarray, final_linear=True) -> np.ndarray:
self._acts = []
h = x
for i, layer in enumerate(self.layers):
h = layer.forward(h)
if i < len(self.layers) - 1 or not final_linear:
self._acts.append(h.copy())
h = _relu(h)
else:
self._acts.append(None)
return h
def backward(self, dout: np.ndarray) -> np.ndarray:
dh = dout
for i in range(len(self.layers) - 1, -1, -1):
if i < len(self.layers) - 1:
dh = dh * _drelu(self._acts[i])
dh = self.layers[i].backward(dh)
return dh
def step(self, lr: float, t: int):
for layer in self.layers:
layer.step(lr, t)
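# Usage sketch for MLP (hypothetical shapes): with final_linear=True (the default)
# ReLU is applied after every layer except the last.
#     mlp = MLP([111, 64, 64], seed=0)
#     h = mlp.forward(np.zeros((2, 111), dtype=np.float32))   # (2, 64), last layer linear
#     dx = mlp.backward(np.ones_like(h))                       # (2, 111)
#     mlp.step(lr=1e-3, t=1)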
# ── Encoder: MLP → (mu, logvar) ───────────────────────────────────────────
class VAEEncoder:
"""Encodes input (optionally concatenated with conditioning z) to (mu, logvar)."""
def __init__(self, in_dim: int, hidden: int, z_dim: int, seed: int = 0):
self.mlp = MLP([in_dim, hidden, hidden], seed=seed)
self.mu_head = Linear(hidden, z_dim, seed=seed + 100)
self.lv_head = Linear(hidden, z_dim, seed=seed + 200)
self.z_dim = z_dim
def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
h = self.mlp.forward(x) # (B, hidden)
h = _relu(h)
self._h = h
mu = self.mu_head.forward(h)
logvar = self.lv_head.forward(h)
logvar = np.clip(logvar, -10, 4) # keep variance sane
return mu, logvar
def backward(self, dmu: np.ndarray, dlv: np.ndarray) -> np.ndarray:
dh_mu = self.mu_head.backward(dmu)
dh_lv = self.lv_head.backward(dlv)
dh = (dh_mu + dh_lv) * _drelu(self._h)
return self.mlp.backward(dh)
def step(self, lr: float, t: int):
self.mlp.step(lr, t)
self.mu_head.step(lr, t)
self.lv_head.step(lr, t)
# ── Decoder: z → x_hat ────────────────────────────────────────────────────
class VAEDecoder:
def __init__(self, z_dim: int, hidden: int, out_dim: int, seed: int = 0):
self.mlp = MLP([z_dim, hidden, hidden, out_dim], seed=seed + 300)
self.out_dim = out_dim
def forward(self, z: np.ndarray) -> np.ndarray:
return self.mlp.forward(z, final_linear=True)
def backward(self, dout: np.ndarray) -> np.ndarray:
return self.mlp.backward(dout)
def step(self, lr: float, t: int):
self.mlp.step(lr, t)
# ── β-TCVAE loss (Chen et al. 2018, proper dataset-size corrected form) ───
def btcvae_kl(mu: np.ndarray, logvar: np.ndarray,
N_dataset: int, beta: float = 4.0,
gamma: float = 1.0, lam: float = 1.0
) -> Tuple[float, np.ndarray, np.ndarray]:
"""
Decompose KL into: MI + beta*TC + lambda*dim_KL
Returns (total_kl, dmu, dlogvar).
MI = E_q[log q(z|x)] - E_q[log q(z)]
TC = E_q[log q(z)] - E_q[sum_j log q(z_j)]
dKL = E_q[sum_j log q(z_j) - log p(z_j)]
Minibatch-weighted estimator — numerically stable.
"""
B, D = mu.shape
var = np.exp(logvar)
# Per-sample KL (closed form vs N(0,1) prior)
kl_per_dim = 0.5 * (mu**2 + var - logvar - 1) # (B, D)
# Log q(z|x_i) for sample z_i ~ q(z|x_i) (diagonal Gaussian)
# We use the mean as the sample point (reparametrization noise ≈ 0 for gradient estimation)
# log q(z_i | x_i) = sum_d -0.5*(log(2π) + logvar + (z-mu)^2/var)
# At z = mu: (z-mu)=0, so log q(z|x) = -0.5*sum(log(2π) + logvar)
log_q_z_given_x = -0.5 * (D * np.log(2 * np.pi) + logvar.sum(axis=1)) # (B,)
# Minibatch estimate of log q(z): E_data[sum_x log q(z|x)] / N
# log q(z_i) ≈ log(1/N * sum_j q(z_i | x_j))
# Using pairwise: log q(z_i) ≈ logsumexp over j of log q(z_i | x_j) - log(N)
# q(z_i | x_j) for each pair: diffs = z_i[none,:] - mu[none,:] shape (B,B,D)
z_sample = mu # (B, D) — use mean as point estimate
# (B_i, B_j, D) pairwise differences
diff = z_sample[:, None, :] - mu[None, :, :] # (B, B, D)
log_q_z_pair = -0.5 * (
D * np.log(2 * np.pi)
+ logvar[None, :, :].sum(axis=2)
+ (diff**2 / (var[None, :, :] + EPS)).sum(axis=2)
) # (B, B)
log_q_z = np.log(np.sum(np.exp(log_q_z_pair - log_q_z_pair.max(axis=1, keepdims=True)), axis=1) + EPS) \
+ log_q_z_pair.max(axis=1) - np.log(N_dataset + EPS) # (B,)
# log q(z_j) marginal — independence assumption for TC
# log q(z_i_j) for each dim j: logsumexp over data points
log_q_z_product = np.zeros(B, dtype=np.float64)
for j in range(D):
diff_j = z_sample[:, j:j+1] - mu[:, j:j+1].T # (B, B)
log_q_zj = -0.5 * (np.log(2 * np.pi) + logvar[None, :, j] + diff_j**2 / (var[None, :, j] + EPS)) # (B,B)
log_q_zj_marginal = np.log(np.sum(np.exp(log_q_zj - log_q_zj.max(axis=1, keepdims=True)), axis=1) + EPS) \
+ log_q_zj.max(axis=1) - np.log(N_dataset + EPS) # (B,)
log_q_z_product += log_q_zj_marginal
# log p(z) = N(0,1)
log_p_z = -0.5 * (D * np.log(2 * np.pi) + (z_sample**2).sum(axis=1)) # (B,)
mi_loss = np.mean(log_q_z_given_x - log_q_z)
tc_loss = np.mean(log_q_z - log_q_z_product)
dkl_loss = np.mean(log_q_z_product - log_p_z)
total_kl = gamma * mi_loss + beta * max(tc_loss, 0) + lam * dkl_loss
# Gradients w.r.t. mu and logvar: use standard closed-form KL gradient
# ∂KL/∂mu = mu, ∂KL/∂logvar = 0.5*(exp(logvar) - 1)
scale = (gamma + beta + lam) / (3.0 * B + EPS)
dmu = scale * mu
dlogvar = scale * 0.5 * (var - 1.0)
return float(total_kl), dmu.astype(np.float32), dlogvar.astype(np.float32)
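# Call sketch for btcvae_kl (hypothetical values): mu/logvar come straight from a
# VAEEncoder, N_dataset is the full corpus size used by the minibatch marginal
# estimate, and the returned gradients are the closed-form KL gradients rescaled
# by the (gamma + beta + lam) weighting as noted above.
#     mu = np.zeros((32, 8), dtype=np.float32)
#     lv = np.zeros((32, 8), dtype=np.float32)
#     kl, dmu, dlv = btcvae_kl(mu, lv, N_dataset=500_000, beta=4.0)
#     dmu.shape, dlv.shape                   # ((32, 8), (32, 8))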
# ── Reparametrization ─────────────────────────────────────────────────────
def reparametrize(mu, logvar, rng):
std = np.exp(0.5 * logvar)
eps = rng.randn(*mu.shape).astype(np.float32)
return mu + eps * std
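# Single-level roundtrip sketch (hypothetical dims); the hierarchical model below
# repeats this pattern once per tier, concatenating upstream z's into the inputs:
#     rng = np.random.RandomState(0)
#     enc = VAEEncoder(in_dim=8, hidden=64, z_dim=4, seed=0)
#     dec = VAEDecoder(z_dim=4, hidden=64, out_dim=8, seed=0)
#     mu, lv = enc.forward(np.zeros((5, 8), dtype=np.float32))
#     z = reparametrize(mu, lv, rng)         # (5, 4) sample
#     x_hat = dec.forward(z)                 # (5, 8) reconstruction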
# ── Hierarchical D-VAE ────────────────────────────────────────────────────
class HierarchicalDVAE:
"""
4-level Hierarchical Disentangled VAE (111-dim input, 24-dim latent).
Latent hierarchy:
z0 (4) = enc0(T0) macro regime (breadth + time)
z1 (8) = enc1(T1 + T4 + z0) eigenstructure + esoteric, cond on z0
z2 (8) = enc2(T2 + T3 + z0 + z1) cross-section + ExF, cond on z0,z1
z3 (4) = enc3(T3 + T4 + z0) ExF+EsoF regime, cond on z0
Decoding:
T0_hat = dec0(z0)
T1_hat = dec1(z0, z1)
T2_hat = dec2(z0, z1, z2)
T3_hat = dec3(z0, z3)
T4_hat = dec4(z0) (EsoF decoded from z0 only; auxiliary reconstruction target)
Training phases:
0: enc0/dec0 only — full 500K+ corpus (even pre-eigen NG1/NG2)
1: + enc1/dec1 — eigen-tier (NG3+)
2: + enc2/dec2 — pricing + ExF (NG3+ with pricing)
3: + enc3/dec3 — ExF regime
4: joint fine-tune all
"""
def __init__(self, hidden: int = 64, beta: float = 4.0,
gamma: float = 1.0, lam: float = 1.0, seed: int = 42):
self.hidden = hidden
self.beta = beta
self.gamma = gamma
self.lam = lam
self.seed = seed
# Encoders (each conditioned on upstream z)
self.enc0 = VAEEncoder(TIER0_DIM, hidden, Z0_DIM, seed=seed)
self.enc1 = VAEEncoder(TIER1_DIM + TIER4_DIM + Z0_DIM, hidden, Z1_DIM, seed=seed+10)
self.enc2 = VAEEncoder(TIER2_DIM + TIER3_DIM + Z0_DIM + Z1_DIM, hidden, Z2_DIM, seed=seed+20)
self.enc3 = VAEEncoder(TIER3_DIM + TIER4_DIM + Z0_DIM, hidden//2, Z3_DIM, seed=seed+30)
# Decoders
self.dec0 = VAEDecoder(Z0_DIM, hidden, TIER0_DIM, seed=seed)
self.dec1 = VAEDecoder(Z0_DIM + Z1_DIM, hidden, TIER1_DIM, seed=seed+10)
self.dec2 = VAEDecoder(Z0_DIM + Z1_DIM + Z2_DIM, hidden, TIER2_DIM, seed=seed+20)
self.dec3 = VAEDecoder(Z0_DIM + Z3_DIM, hidden//2, TIER3_DIM, seed=seed+30)
self.dec4 = VAEDecoder(Z0_DIM, hidden//2, TIER4_DIM, seed=seed+40)
# Normalisation statistics (fit on training data)
self._mu_t0 = np.zeros(TIER0_DIM, dtype=np.float32)
self._sd_t0 = np.ones(TIER0_DIM, dtype=np.float32)
self._mu_t1 = np.zeros(TIER1_DIM, dtype=np.float32)
self._sd_t1 = np.ones(TIER1_DIM, dtype=np.float32)
self._mu_t2 = np.zeros(TIER2_DIM, dtype=np.float32)
self._sd_t2 = np.ones(TIER2_DIM, dtype=np.float32)
self._mu_t3 = np.zeros(TIER3_DIM, dtype=np.float32)
self._sd_t3 = np.ones(TIER3_DIM, dtype=np.float32)
self._mu_t4 = np.zeros(TIER4_DIM, dtype=np.float32)
self._sd_t4 = np.ones(TIER4_DIM, dtype=np.float32)
self.step_t = 0 # Adam time step
self.train_losses: List[Dict] = []
# ── Normalisation ──────────────────────────────────────────────────────
def fit_normaliser(self, X: np.ndarray, mask: np.ndarray):
"""Fit per-tier normalisation on a representative sample."""
def _fit(rows, dim):
if len(rows) < 2:
return np.zeros(dim, dtype=np.float32), np.ones(dim, dtype=np.float32)
return rows.mean(0).astype(np.float32), (rows.std(0) + EPS).astype(np.float32)
self._mu_t0, self._sd_t0 = _fit(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM], TIER0_DIM)
idx1 = mask[:, 1]
self._mu_t1, self._sd_t1 = _fit(X[idx1, T_OFF[1]:T_OFF[1]+TIER1_DIM], TIER1_DIM)
idx2 = mask[:, 2]
self._mu_t2, self._sd_t2 = _fit(X[idx2, T_OFF[2]:T_OFF[2]+TIER2_DIM], TIER2_DIM)
idx3 = mask[:, 3]
self._mu_t3, self._sd_t3 = _fit(X[idx3, T_OFF[3]:T_OFF[3]+TIER3_DIM], TIER3_DIM)
self._mu_t4, self._sd_t4 = _fit(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM], TIER4_DIM)
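# Masking sketch (illustrative, assuming one boolean column per tier): a
# pre-eigenvalue sample with mask row [1, 0, 0, 0, 1] is excluded from the
# Tier-1/2/3 statistics above but still contributes to Tier-0 and Tier-4.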
def _norm(self, x, mu, sd): return (x - mu) / sd
def _norm0(self, x): return self._norm(x, self._mu_t0, self._sd_t0)
def _norm1(self, x): return self._norm(x, self._mu_t1, self._sd_t1)
def _norm2(self, x): return self._norm(x, self._mu_t2, self._sd_t2)
def _norm3(self, x): return self._norm(x, self._mu_t3, self._sd_t3)
def _norm4(self, x): return self._norm(x, self._mu_t4, self._sd_t4)
# ── Forward pass ──────────────────────────────────────────────────────
def _split(self, X):
"""Split 111-dim row into normalised tier vectors."""
t0 = self._norm0(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM])
t1 = self._norm1(X[:, T_OFF[1]:T_OFF[1]+TIER1_DIM])
t2 = self._norm2(X[:, T_OFF[2]:T_OFF[2]+TIER2_DIM])
t3 = self._norm3(X[:, T_OFF[3]:T_OFF[3]+TIER3_DIM])
t4 = self._norm4(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM])
return t0, t1, t2, t3, t4
def encode(self, X: np.ndarray, mask: np.ndarray, rng) -> Dict:
t0, t1, t2, t3, t4 = self._split(X)
mu0, lv0 = self.enc0.forward(t0)
z0 = reparametrize(mu0, lv0, rng)
mu1, lv1 = self.enc1.forward(np.concatenate([t1, t4, z0], axis=1))
z1 = reparametrize(mu1, lv1, rng)
mu2, lv2 = self.enc2.forward(np.concatenate([t2, t3, z0, z1], axis=1))
z2 = reparametrize(mu2, lv2, rng)
mu3, lv3 = self.enc3.forward(np.concatenate([t3, t4, z0], axis=1))
z3 = reparametrize(mu3, lv3, rng)
return dict(t0=t0, t1=t1, t2=t2, t3=t3, t4=t4,
mu0=mu0, lv0=lv0, z0=z0,
mu1=mu1, lv1=lv1, z1=z1,
mu2=mu2, lv2=lv2, z2=z2,
mu3=mu3, lv3=lv3, z3=z3)
def decode(self, enc: Dict) -> Dict:
z0, z1, z2, z3 = enc['z0'], enc['z1'], enc['z2'], enc['z3']
x0_hat = self.dec0.forward(z0)
x1_hat = self.dec1.forward(np.concatenate([z0, z1], axis=1))
x2_hat = self.dec2.forward(np.concatenate([z0, z1, z2], axis=1))
x3_hat = self.dec3.forward(np.concatenate([z0, z3], axis=1))
x4_hat = self.dec4.forward(z0) # EsoF decoded from macro only
return dict(x0_hat=x0_hat, x1_hat=x1_hat, x2_hat=x2_hat,
x3_hat=x3_hat, x4_hat=x4_hat)
# ── Loss and gradients ─────────────────────────────────────────────────
def loss_and_grads(self, enc: Dict, dec: Dict, mask: np.ndarray,
N_dataset: int, phase: int) -> Dict:
"""Compute total loss and all gradients via chain rule."""
t0, t1, t2 = enc['t0'], enc['t1'], enc['t2']
x0_hat = dec['x0_hat']
x1_hat = dec['x1_hat']
x2_hat = dec['x2_hat']
B = len(t0)
# Mask weights: per-sample, per-tier
w1 = mask[:, 1].astype(np.float32)[:, None] # (B,1)
w2 = mask[:, 2].astype(np.float32)[:, None]
# Reconstruction losses
recon0 = np.mean((x0_hat - t0)**2)
recon1 = np.mean(w1 * (x1_hat - t1)**2)
recon2 = np.mean(w2 * (x2_hat - t2)**2)
# Reconstruction gradients
dr0 = 2 * (x0_hat - t0) / (B * TIER0_DIM)
dr1 = 2 * w1 * (x1_hat - t1) / (B * TIER1_DIM)
dr2 = 2 * w2 * (x2_hat - t2) / (B * TIER2_DIM)
# KL losses
kl0, dmu0_kl, dlv0_kl = btcvae_kl(enc['mu0'], enc['lv0'], N_dataset, self.beta, self.gamma, self.lam)
kl1, dmu1_kl, dlv1_kl = btcvae_kl(enc['mu1'], enc['lv1'], N_dataset, self.beta, self.gamma, self.lam)
kl2, dmu2_kl, dlv2_kl = btcvae_kl(enc['mu2'], enc['lv2'], N_dataset, self.beta, self.gamma, self.lam)
# Phase-based scaling: only activate eigen/pricing in correct phase
if phase < 1:
kl1 = recon1 = 0.0; dmu1_kl[:] = 0; dlv1_kl[:] = 0; dr1[:] = 0
if phase < 2:
kl2 = recon2 = 0.0; dmu2_kl[:] = 0; dlv2_kl[:] = 0; dr2[:] = 0
total = recon0 + recon1 + recon2 + kl0 + kl1 + kl2
return dict(
total=total, recon0=recon0, recon1=recon1, recon2=recon2,
kl0=kl0, kl1=kl1, kl2=kl2,
dr0=dr0, dr1=dr1, dr2=dr2,
dmu0_kl=dmu0_kl, dlv0_kl=dlv0_kl,
dmu1_kl=dmu1_kl, dlv1_kl=dlv1_kl,
dmu2_kl=dmu2_kl, dlv2_kl=dlv2_kl,
)
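# Phase gating note (descriptive): with phase=1 the Tier-2/pricing terms are
# zeroed (recon2 = kl2 = 0 and their gradients cleared), so only the Tier-0 and
# Tier-1 branches receive non-zero gradients from this batch.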
def backward(self, enc: Dict, dec: Dict, grads: Dict, phase: int, lr: float):
"""Backpropagate through all encoders and decoders, then Adam step."""
z0, z1, z2 = enc['z0'], enc['z1'], enc['z2']
t = self.step_t
# ── Decoder 2 (z0+z1+z2 → x2) ──────────────────────────────────
if phase >= 2:
dz_all_2 = self.dec2.backward(grads['dr2']) # (B, Z0+Z1+Z2)
dz0_from_d2 = dz_all_2[:, :Z0_DIM]
dz1_from_d2 = dz_all_2[:, Z0_DIM:Z0_DIM + Z1_DIM]
dz2_from_d2 = dz_all_2[:, Z0_DIM + Z1_DIM:]
else:
dz0_from_d2 = np.zeros_like(z0)
dz1_from_d2 = np.zeros_like(z1)
dz2_from_d2 = np.zeros_like(z2)
self.dec2.backward(np.zeros_like(grads['dr2']))
# ── Encoder 2 ─────────────────────────────────────────────────
# enc2 input = concat(T2, T3, z0, z1) dims = TIER2+TIER3+Z0+Z1
if phase >= 2:
dmu2 = dz2_from_d2 + grads['dmu2_kl']
dlv2 = dz2_from_d2 * 0.5 + grads['dlv2_kl']
dinp2 = self.enc2.backward(dmu2, dlv2) # (B, T2+T3+Z0+Z1)
_off = TIER2_DIM + TIER3_DIM
dz0_from_e2 = dinp2[:, _off:_off + Z0_DIM]
dz1_from_e2 = dinp2[:, _off + Z0_DIM:]
self.dec2.step(lr, t)
self.enc2.step(lr, t)
else:
dz0_from_e2 = np.zeros_like(z0)
dz1_from_e2 = np.zeros_like(z1)
# ── Decoder 1 (z0+z1 → x1) ──────────────────────────────────────
if phase >= 1:
dz_all_1 = self.dec1.backward(grads['dr1']) # (B, Z0+Z1)
dz0_from_d1 = dz_all_1[:, :Z0_DIM]
dz1_from_d1 = dz_all_1[:, Z0_DIM:]
else:
dz0_from_d1 = np.zeros_like(z0)
dz1_from_d1 = np.zeros_like(z1)
self.dec1.backward(np.zeros_like(grads['dr1']))
# ── Encoder 1 ─────────────────────────────────────────────────
# enc1 input = concat(T1, T4, z0) dims = TIER1+TIER4+Z0
if phase >= 1:
dmu1 = dz1_from_d1 + dz1_from_d2 + dz1_from_e2 + grads['dmu1_kl']
dlv1 = dz1_from_d1 * 0.5 + grads['dlv1_kl']
dinp1 = self.enc1.backward(dmu1, dlv1) # (B, T1+T4+Z0)
dz0_from_e1 = dinp1[:, TIER1_DIM + TIER4_DIM:] # correct: skip T1+T4
self.dec1.step(lr, t)
self.enc1.step(lr, t)
else:
dz0_from_e1 = np.zeros_like(z0)
# ── Decoder 0 (z0 → x0) ─────────────────────────────────────────
dz0_from_d0 = self.dec0.backward(grads['dr0']) # (B, Z0)
# ── Encoder 0 ────────────────────────────────────────────────────
dmu0 = dz0_from_d0 + dz0_from_d1 + dz0_from_d2 + dz0_from_e1 + dz0_from_e2 + grads['dmu0_kl']
dlv0 = dz0_from_d0 * 0.5 + grads['dlv0_kl']
self.enc0.backward(dmu0, dlv0)
self.enc0.step(lr, t)
self.dec0.step(lr, t)
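# Gradient routing note (descriptive): z0 accumulates gradients from dec0, dec1
# and dec2 plus the conditioning slices of the enc1/enc2 inputs, which is why
# enc0/dec0 are stepped unconditionally regardless of the training phase.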
# ── Training ──────────────────────────────────────────────────────────
def train_epoch(self, X: np.ndarray, mask: np.ndarray,
lr: float, batch_size: int, phase: int, rng) -> Dict:
N = len(X)
idx = rng.permutation(N)
epoch_stats = dict(total=0, recon0=0, recon1=0, recon2=0, kl0=0, kl1=0, kl2=0, n=0)
for start in range(0, N, batch_size):
bi = idx[start:start + batch_size]
Xb = X[bi]
mb = mask[bi]
self.step_t += 1 # Adam global step (cumulative across all epochs)
enc_out = self.encode(Xb, mb, rng)
dec_out = self.decode(enc_out)
grads = self.loss_and_grads(enc_out, dec_out, mb, N, phase)
self.backward(enc_out, dec_out, grads, phase, lr)
for k in ['total', 'recon0', 'recon1', 'recon2', 'kl0', 'kl1', 'kl2']:
epoch_stats[k] += float(grads[k]) * len(bi)
epoch_stats['n'] += len(bi)
return {k: v / max(epoch_stats['n'], 1) for k, v in epoch_stats.items() if k != 'n'}
# ── Inference ─────────────────────────────────────────────────────────
def get_latents(self, X: np.ndarray, mask: np.ndarray) -> Dict[str, np.ndarray]:
"""Return disentangled latent codes for analysis."""
rng = np.random.RandomState(0) # deterministic for inference
enc = self.encode(X, mask, rng)
return {
'z0_macro': enc['mu0'], # (N, 4) deterministic mean
'z1_eigen': enc['mu1'], # (N, 8)
'z2_xsection': enc['mu2'], # (N, 8)
'z_all': np.concatenate([enc['mu0'], enc['mu1'], enc['mu2']], axis=1), # (N, 20); mu3 not included here
'sigma0': np.exp(0.5 * enc['lv0']), # uncertainty
'sigma1': np.exp(0.5 * enc['lv1']),
'sigma2': np.exp(0.5 * enc['lv2']),
}
# ── Persistence ───────────────────────────────────────────────────────
def save(self, path: str):
data = {
'norm': dict(mu_t0=self._mu_t0, sd_t0=self._sd_t0,
mu_t1=self._mu_t1, sd_t1=self._sd_t1,
mu_t2=self._mu_t2, sd_t2=self._sd_t2),
'step_t': self.step_t,
'losses': self.train_losses,
}
# Save all weight matrices
for name, enc in [('enc0', self.enc0), ('enc1', self.enc1), ('enc2', self.enc2)]:
for i, layer in enumerate(enc.mlp.layers):
data[f'{name}_mlp{i}_W'] = layer.W
data[f'{name}_mlp{i}_b'] = layer.b
data[f'{name}_mu_W'] = enc.mu_head.W; data[f'{name}_mu_b'] = enc.mu_head.b
data[f'{name}_lv_W'] = enc.lv_head.W; data[f'{name}_lv_b'] = enc.lv_head.b
for name, dec in [('dec0', self.dec0), ('dec1', self.dec1), ('dec2', self.dec2)]:
for i, layer in enumerate(dec.mlp.layers):
data[f'{name}_mlp{i}_W'] = layer.W
data[f'{name}_mlp{i}_b'] = layer.b
np.savez_compressed(path, **data)
print(f"Model saved: {path}.npz")