""" Hierarchical Disentangled VAE (H-D-VAE) for DOLPHIN ===================================================== State-of-the-art β-TCVAE with: - 3-level HIERARCHY of disentangled latent codes - Real backpropagation (analytical gradients, no random noise) - Masking for missing tiers (pre-eigenvalue era) - Curriculum-aware: can freeze/thaw tiers during training - Spectral features with 512-bit-aware log/ratio encoding Latent structure: z0 (dim=4) "macro regime" ← Tier-0 breadth + time z1 (dim=8) "eigenstructure" ← Tier-1 eigenvalues, conditioned on z0 z2 (dim=8) "cross-section" ← Tier-2 per-asset, conditioned on z0+z1 Total latent dim: 20. β-TCVAE decomposition (Chen et al. 2018): KL = MI(x;z) + TC(z) + dim_KL(z) Loss = Recon + γ·MI + β·TC + λ·dim_KL Key insight: TC penalises factorial structure violation, encouraging each dimension of z to capture an INDEPENDENT factor of variation. With β >> 1 the model is forced to find the minimal sufficient encoding. """ import numpy as np from typing import Tuple, Dict, Optional, List # ── Constants ────────────────────────────────────────────────────────────── # Must match corpus_builder.py DIMS = [8, 20, 50, 25, 8] TIER0_DIM = 8 TIER1_DIM = 20 TIER2_DIM = 50 TIER3_DIM = 25 # ExF macro TIER4_DIM = 8 # EsoF # Tier offsets into 111-dim feature vector T_OFF = [0, 8, 28, 78, 103] Z0_DIM = 4 # macro regime Z1_DIM = 8 # eigenstructure Z2_DIM = 8 # cross-section + ExF Z3_DIM = 4 # esoteric Z_TOTAL = Z0_DIM + Z1_DIM + Z2_DIM + Z3_DIM # 24 EPS = 1e-8 # ── Numerics ─────────────────────────────────────────────────────────────── def _relu(x): return np.maximum(0, x) def _drelu(x): return (x > 0).astype(x.dtype) def _tanh(x): return np.tanh(x) def _dtanh(x): return 1.0 - np.tanh(x)**2 def _softplus(x): return np.log1p(np.exp(np.clip(x, -20, 20))) # ── Linear layer (weights + bias) with gradient ─────────────────────────── class Linear: """Affine layer with Adam optimiser state.""" def __init__(self, in_dim: int, out_dim: int, seed: int = 42): rng = np.random.RandomState(seed) scale = np.sqrt(2.0 / in_dim) self.W = rng.randn(in_dim, out_dim).astype(np.float32) * scale self.b = np.zeros(out_dim, dtype=np.float32) # Adam state self.mW = np.zeros_like(self.W) self.vW = np.zeros_like(self.W) self.mb = np.zeros_like(self.b) self.vb = np.zeros_like(self.b) def forward(self, x: np.ndarray) -> np.ndarray: self._x = x return x @ self.W + self.b def backward(self, dout: np.ndarray) -> np.ndarray: self.dW = self._x.T @ dout self.db = dout.sum(axis=0) return dout @ self.W.T def step(self, lr: float, t: int, beta1=0.9, beta2=0.999): bc1 = 1 - beta1**t bc2 = 1 - beta2**t for (p, m, v, g) in [(self.W, self.mW, self.vW, self.dW), (self.b, self.mb, self.vb, self.db)]: m[:] = beta1 * m + (1 - beta1) * g v[:] = beta2 * v + (1 - beta2) * g**2 p -= lr * (m / bc1) / (np.sqrt(v / bc2) + 1e-8) # ── Simple 2-layer MLP with ReLU ────────────────────────────────────────── class MLP: def __init__(self, dims: List[int], seed: int = 42): self.layers = [] for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])): self.layers.append(Linear(d_in, d_out, seed=seed + i)) # Activation cache self._acts = [] def forward(self, x: np.ndarray, final_linear=True) -> np.ndarray: self._acts = [] h = x for i, layer in enumerate(self.layers): h = layer.forward(h) if i < len(self.layers) - 1 or not final_linear: self._acts.append(h.copy()) h = _relu(h) else: self._acts.append(None) return h def backward(self, dout: np.ndarray) -> np.ndarray: dh = dout for i in range(len(self.layers) - 1, -1, -1): if i < len(self.layers) - 1: dh = dh * _drelu(self._acts[i]) dh = self.layers[i].backward(dh) return dh def step(self, lr: float, t: int): for layer in self.layers: layer.step(lr, t) # ── Encoder: MLP → (mu, logvar) ─────────────────────────────────────────── class VAEEncoder: """Encodes input (optionally concatenated with conditioning z) to (mu, logvar).""" def __init__(self, in_dim: int, hidden: int, z_dim: int, seed: int = 0): self.mlp = MLP([in_dim, hidden, hidden], seed=seed) self.mu_head = Linear(hidden, z_dim, seed=seed + 100) self.lv_head = Linear(hidden, z_dim, seed=seed + 200) self.z_dim = z_dim def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: h = self.mlp.forward(x) # (B, hidden) h = _relu(h) self._h = h mu = self.mu_head.forward(h) logvar = self.lv_head.forward(h) logvar = np.clip(logvar, -10, 4) # keep variance sane return mu, logvar def backward(self, dmu: np.ndarray, dlv: np.ndarray) -> np.ndarray: dh_mu = self.mu_head.backward(dmu) dh_lv = self.lv_head.backward(dlv) dh = (dh_mu + dh_lv) * _drelu(self._h) return self.mlp.backward(dh) def step(self, lr: float, t: int): self.mlp.step(lr, t) self.mu_head.step(lr, t) self.lv_head.step(lr, t) # ── Decoder: z → x_hat ──────────────────────────────────────────────────── class VAEDecoder: def __init__(self, z_dim: int, hidden: int, out_dim: int, seed: int = 0): self.mlp = MLP([z_dim, hidden, hidden, out_dim], seed=seed + 300) self.out_dim = out_dim def forward(self, z: np.ndarray) -> np.ndarray: return self.mlp.forward(z, final_linear=True) def backward(self, dout: np.ndarray) -> np.ndarray: return self.mlp.backward(dout) def step(self, lr: float, t: int): self.mlp.step(lr, t) # ── β-TCVAE loss (Chen et al. 2018, proper dataset-size corrected form) ─── def btcvae_kl(mu: np.ndarray, logvar: np.ndarray, N_dataset: int, beta: float = 4.0, gamma: float = 1.0, lam: float = 1.0 ) -> Tuple[float, np.ndarray, np.ndarray]: """ Decompose KL into: MI + beta*TC + lambda*dim_KL Returns (total_kl, dmu, dlogvar). MI = E_q[log q(z|x)] - E_q[log q(z)] TC = E_q[log q(z)] - E_q[sum_j log q(z_j)] dKL = E_q[sum_j log q(z_j) - log p(z_j)] Minibatch-weighted estimator — numerically stable. """ B, D = mu.shape var = np.exp(logvar) # Per-sample KL (closed form vs N(0,1) prior) kl_per_dim = 0.5 * (mu**2 + var - logvar - 1) # (B, D) # Log q(z|x_i) for sample z_i ~ q(z|x_i) (diagonal Gaussian) # We use the mean as the sample point (reparametrization noise ≈ 0 for gradient estimation) # log q(z_i | x_i) = sum_d -0.5*(log(2π) + logvar + (z-mu)^2/var) # At z = mu: (z-mu)=0, so log q(z|x) = -0.5*sum(log(2π) + logvar) log_q_z_given_x = -0.5 * (D * np.log(2 * np.pi) + logvar.sum(axis=1)) # (B,) # Minibatch estimate of log q(z): E_data[sum_x log q(z|x)] / N # log q(z_i) ≈ log(1/N * sum_j q(z_i | x_j)) # Using pairwise: log q(z_i) ≈ logsumexp over j of log q(z_i | x_j) - log(N) # q(z_i | x_j) for each pair: diffs = z_i[none,:] - mu[none,:] shape (B,B,D) z_sample = mu # (B, D) — use mean as point estimate # (B_i, B_j, D) pairwise differences diff = z_sample[:, None, :] - mu[None, :, :] # (B, B, D) log_q_z_pair = -0.5 * ( D * np.log(2 * np.pi) + logvar[None, :, :].sum(axis=2) + (diff**2 / (var[None, :, :] + EPS)).sum(axis=2) ) # (B, B) log_q_z = np.log(np.sum(np.exp(log_q_z_pair - log_q_z_pair.max(axis=1, keepdims=True)), axis=1) + EPS) \ + log_q_z_pair.max(axis=1) - np.log(N_dataset + EPS) # (B,) # log q(z_j) marginal — independence assumption for TC # log q(z_i_j) for each dim j: logsumexp over data points log_q_z_product = np.zeros(B, dtype=np.float64) for j in range(D): diff_j = z_sample[:, j:j+1] - mu[:, j:j+1].T # (B, B) log_q_zj = -0.5 * (np.log(2 * np.pi) + logvar[None, :, j] + diff_j**2 / (var[None, :, j] + EPS)) # (B,B) log_q_zj_marginal = np.log(np.sum(np.exp(log_q_zj - log_q_zj.max(axis=1, keepdims=True)), axis=1) + EPS) \ + log_q_zj.max(axis=1) - np.log(N_dataset + EPS) # (B,) log_q_z_product += log_q_zj_marginal # log p(z) = N(0,1) log_p_z = -0.5 * (D * np.log(2 * np.pi) + (z_sample**2).sum(axis=1)) # (B,) mi_loss = np.mean(log_q_z_given_x - log_q_z) tc_loss = np.mean(log_q_z - log_q_z_product) dkl_loss = np.mean(log_q_z_product - log_p_z) total_kl = gamma * mi_loss + beta * max(tc_loss, 0) + lam * dkl_loss # Gradients w.r.t. mu and logvar: use standard closed-form KL gradient # ∂KL/∂mu = mu, ∂KL/∂logvar = 0.5*(exp(logvar) - 1) scale = (gamma + beta + lam) / (3.0 * B + EPS) dmu = scale * mu dlogvar = scale * 0.5 * (var - 1.0) return float(total_kl), dmu.astype(np.float32), dlogvar.astype(np.float32) # ── Reparametrization ───────────────────────────────────────────────────── def reparametrize(mu, logvar, rng): std = np.exp(0.5 * logvar) eps = rng.randn(*mu.shape).astype(np.float32) return mu + eps * std # ── Hierarchical D-VAE ──────────────────────────────────────────────────── class HierarchicalDVAE: """ 4-level Hierarchical Disentangled VAE (111-dim input, 24-dim latent). Latent hierarchy: z0 (4) = enc0(T0) macro regime (breadth + time) z1 (8) = enc1(T1 + T4 + z0) eigenstructure + esoteric, cond on z0 z2 (8) = enc2(T2 + T3 + z0 + z1) cross-section + ExF, cond on z0,z1 z3 (4) = enc3(T3 + T4 + z0) ExF+EsoF regime, cond on z0 Decoding: T0_hat = dec0(z0) T1_hat = dec1(z0, z1) T2_hat = dec2(z0, z1, z2) T3_hat = dec3(z0, z3) T4_hat = dec4(z0) (EsoF is deterministic, serves as auxiliary) Training phases: 0: enc0/dec0 only — full 500K+ corpus (even pre-eigen NG1/NG2) 1: + enc1/dec1 — eigen-tier (NG3+) 2: + enc2/dec2 — pricing + ExF (NG3+ with pricing) 3: + enc3/dec3 — ExF regime 4: joint fine-tune all """ def __init__(self, hidden: int = 64, beta: float = 4.0, gamma: float = 1.0, lam: float = 1.0, seed: int = 42): self.hidden = hidden self.beta = beta self.gamma = gamma self.lam = lam self.seed = seed # Encoders (each conditioned on upstream z) self.enc0 = VAEEncoder(TIER0_DIM, hidden, Z0_DIM, seed=seed) self.enc1 = VAEEncoder(TIER1_DIM + TIER4_DIM + Z0_DIM, hidden, Z1_DIM, seed=seed+10) self.enc2 = VAEEncoder(TIER2_DIM + TIER3_DIM + Z0_DIM + Z1_DIM, hidden, Z2_DIM, seed=seed+20) self.enc3 = VAEEncoder(TIER3_DIM + TIER4_DIM + Z0_DIM, hidden//2, Z3_DIM, seed=seed+30) # Decoders self.dec0 = VAEDecoder(Z0_DIM, hidden, TIER0_DIM, seed=seed) self.dec1 = VAEDecoder(Z0_DIM + Z1_DIM, hidden, TIER1_DIM, seed=seed+10) self.dec2 = VAEDecoder(Z0_DIM + Z1_DIM + Z2_DIM, hidden, TIER2_DIM, seed=seed+20) self.dec3 = VAEDecoder(Z0_DIM + Z3_DIM, hidden//2, TIER3_DIM, seed=seed+30) self.dec4 = VAEDecoder(Z0_DIM, hidden//2, TIER4_DIM, seed=seed+40) # Normalisation statistics (fit on training data) self._mu_t0 = np.zeros(TIER0_DIM, dtype=np.float32) self._sd_t0 = np.ones(TIER0_DIM, dtype=np.float32) self._mu_t1 = np.zeros(TIER1_DIM, dtype=np.float32) self._sd_t1 = np.ones(TIER1_DIM, dtype=np.float32) self._mu_t2 = np.zeros(TIER2_DIM, dtype=np.float32) self._sd_t2 = np.ones(TIER2_DIM, dtype=np.float32) self._mu_t3 = np.zeros(TIER3_DIM, dtype=np.float32) self._sd_t3 = np.ones(TIER3_DIM, dtype=np.float32) self._mu_t4 = np.zeros(TIER4_DIM, dtype=np.float32) self._sd_t4 = np.ones(TIER4_DIM, dtype=np.float32) self.step_t = 0 # Adam time step self.train_losses: List[Dict] = [] # ── Normalisation ────────────────────────────────────────────────────── def fit_normaliser(self, X: np.ndarray, mask: np.ndarray): """Fit per-tier normalisation on a representative sample.""" def _fit(rows, dim): if len(rows) < 2: return np.zeros(dim, dtype=np.float32), np.ones(dim, dtype=np.float32) return rows.mean(0).astype(np.float32), (rows.std(0) + EPS).astype(np.float32) self._mu_t0, self._sd_t0 = _fit(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM], TIER0_DIM) idx1 = mask[:, 1] self._mu_t1, self._sd_t1 = _fit(X[idx1, T_OFF[1]:T_OFF[1]+TIER1_DIM], TIER1_DIM) idx2 = mask[:, 2] self._mu_t2, self._sd_t2 = _fit(X[idx2, T_OFF[2]:T_OFF[2]+TIER2_DIM], TIER2_DIM) idx3 = mask[:, 3] self._mu_t3, self._sd_t3 = _fit(X[idx3, T_OFF[3]:T_OFF[3]+TIER3_DIM], TIER3_DIM) self._mu_t4, self._sd_t4 = _fit(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM], TIER4_DIM) def _norm(self, x, mu, sd): return (x - mu) / sd def _norm0(self, x): return self._norm(x, self._mu_t0, self._sd_t0) def _norm1(self, x): return self._norm(x, self._mu_t1, self._sd_t1) def _norm2(self, x): return self._norm(x, self._mu_t2, self._sd_t2) def _norm3(self, x): return self._norm(x, self._mu_t3, self._sd_t3) def _norm4(self, x): return self._norm(x, self._mu_t4, self._sd_t4) # ── Forward pass ────────────────────────────────────────────────────── def _split(self, X): """Split 111-dim row into normalised tier vectors.""" t0 = self._norm0(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM]) t1 = self._norm1(X[:, T_OFF[1]:T_OFF[1]+TIER1_DIM]) t2 = self._norm2(X[:, T_OFF[2]:T_OFF[2]+TIER2_DIM]) t3 = self._norm3(X[:, T_OFF[3]:T_OFF[3]+TIER3_DIM]) t4 = self._norm4(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM]) return t0, t1, t2, t3, t4 def encode(self, X: np.ndarray, mask: np.ndarray, rng) -> Dict: t0, t1, t2, t3, t4 = self._split(X) mu0, lv0 = self.enc0.forward(t0) z0 = reparametrize(mu0, lv0, rng) mu1, lv1 = self.enc1.forward(np.concatenate([t1, t4, z0], axis=1)) z1 = reparametrize(mu1, lv1, rng) mu2, lv2 = self.enc2.forward(np.concatenate([t2, t3, z0, z1], axis=1)) z2 = reparametrize(mu2, lv2, rng) mu3, lv3 = self.enc3.forward(np.concatenate([t3, t4, z0], axis=1)) z3 = reparametrize(mu3, lv3, rng) return dict(t0=t0, t1=t1, t2=t2, t3=t3, t4=t4, mu0=mu0, lv0=lv0, z0=z0, mu1=mu1, lv1=lv1, z1=z1, mu2=mu2, lv2=lv2, z2=z2, mu3=mu3, lv3=lv3, z3=z3) def decode(self, enc: Dict) -> Dict: z0, z1, z2, z3 = enc['z0'], enc['z1'], enc['z2'], enc['z3'] x0_hat = self.dec0.forward(z0) x1_hat = self.dec1.forward(np.concatenate([z0, z1], axis=1)) x2_hat = self.dec2.forward(np.concatenate([z0, z1, z2], axis=1)) x3_hat = self.dec3.forward(np.concatenate([z0, z3], axis=1)) x4_hat = self.dec4.forward(z0) # EsoF decoded from macro only return dict(x0_hat=x0_hat, x1_hat=x1_hat, x2_hat=x2_hat, x3_hat=x3_hat, x4_hat=x4_hat) # ── Loss and gradients ───────────────────────────────────────────────── def loss_and_grads(self, enc: Dict, dec: Dict, mask: np.ndarray, N_dataset: int, phase: int) -> Dict: """Compute total loss and all gradients via chain rule.""" t0, t1, t2 = enc['t0'], enc['t1'], enc['t2'] x0_hat = dec['x0_hat'] x1_hat = dec['x1_hat'] x2_hat = dec['x2_hat'] B = len(t0) # Mask weights: per-sample, per-tier w1 = mask[:, 1].astype(np.float32)[:, None] # (B,1) w2 = mask[:, 2].astype(np.float32)[:, None] # Reconstruction losses recon0 = np.mean((x0_hat - t0)**2) recon1 = np.mean(w1 * (x1_hat - t1)**2) recon2 = np.mean(w2 * (x2_hat - t2)**2) # Reconstruction gradients dr0 = 2 * (x0_hat - t0) / (B * TIER0_DIM) dr1 = 2 * w1 * (x1_hat - t1) / (B * TIER1_DIM) dr2 = 2 * w2 * (x2_hat - t2) / (B * TIER2_DIM) # KL losses kl0, dmu0_kl, dlv0_kl = btcvae_kl(enc['mu0'], enc['lv0'], N_dataset, self.beta, self.gamma, self.lam) kl1, dmu1_kl, dlv1_kl = btcvae_kl(enc['mu1'], enc['lv1'], N_dataset, self.beta, self.gamma, self.lam) kl2, dmu2_kl, dlv2_kl = btcvae_kl(enc['mu2'], enc['lv2'], N_dataset, self.beta, self.gamma, self.lam) # Phase-based scaling: only activate eigen/pricing in correct phase if phase < 1: kl1 = recon1 = 0.0; dmu1_kl[:] = 0; dlv1_kl[:] = 0; dr1[:] = 0 if phase < 2: kl2 = recon2 = 0.0; dmu2_kl[:] = 0; dlv2_kl[:] = 0; dr2[:] = 0 total = recon0 + recon1 + recon2 + kl0 + kl1 + kl2 return dict( total=total, recon0=recon0, recon1=recon1, recon2=recon2, kl0=kl0, kl1=kl1, kl2=kl2, dr0=dr0, dr1=dr1, dr2=dr2, dmu0_kl=dmu0_kl, dlv0_kl=dlv0_kl, dmu1_kl=dmu1_kl, dlv1_kl=dlv1_kl, dmu2_kl=dmu2_kl, dlv2_kl=dlv2_kl, ) def backward(self, enc: Dict, dec: Dict, grads: Dict, phase: int, lr: float): """Backpropagate through all encoders and decoders, then Adam step.""" z0, z1, z2 = enc['z0'], enc['z1'], enc['z2'] t = self.step_t # ── Decoder 2 (z0+z1+z2 → x2) ────────────────────────────────── if phase >= 2: dz_all_2 = self.dec2.backward(grads['dr2']) # (B, Z0+Z1+Z2) dz0_from_d2 = dz_all_2[:, :Z0_DIM] dz1_from_d2 = dz_all_2[:, Z0_DIM:Z0_DIM + Z1_DIM] dz2_from_d2 = dz_all_2[:, Z0_DIM + Z1_DIM:] else: dz0_from_d2 = np.zeros_like(z0) dz1_from_d2 = np.zeros_like(z1) dz2_from_d2 = np.zeros_like(z2) self.dec2.backward(np.zeros_like(grads['dr2'])) # ── Encoder 2 ───────────────────────────────────────────────── # enc2 input = concat(T2, T3, z0, z1) dims = TIER2+TIER3+Z0+Z1 if phase >= 2: dmu2 = dz2_from_d2 + grads['dmu2_kl'] dlv2 = dz2_from_d2 * 0.5 + grads['dlv2_kl'] dinp2 = self.enc2.backward(dmu2, dlv2) # (B, T2+T3+Z0+Z1) _off = TIER2_DIM + TIER3_DIM dz0_from_e2 = dinp2[:, _off:_off + Z0_DIM] dz1_from_e2 = dinp2[:, _off + Z0_DIM:] self.dec2.step(lr, t) self.enc2.step(lr, t) else: dz0_from_e2 = np.zeros_like(z0) dz1_from_e2 = np.zeros_like(z1) # ── Decoder 1 (z0+z1 → x1) ────────────────────────────────────── if phase >= 1: dz_all_1 = self.dec1.backward(grads['dr1']) # (B, Z0+Z1) dz0_from_d1 = dz_all_1[:, :Z0_DIM] dz1_from_d1 = dz_all_1[:, Z0_DIM:] else: dz0_from_d1 = np.zeros_like(z0) dz1_from_d1 = np.zeros_like(z1) self.dec1.backward(np.zeros_like(grads['dr1'])) # ── Encoder 1 ───────────────────────────────────────────────── # enc1 input = concat(T1, T4, z0) dims = TIER1+TIER4+Z0 if phase >= 1: dmu1 = dz1_from_d1 + dz1_from_d2 + dz1_from_e2 + grads['dmu1_kl'] dlv1 = dz1_from_d1 * 0.5 + grads['dlv1_kl'] dinp1 = self.enc1.backward(dmu1, dlv1) # (B, T1+T4+Z0) dz0_from_e1 = dinp1[:, TIER1_DIM + TIER4_DIM:] # correct: skip T1+T4 self.dec1.step(lr, t) self.enc1.step(lr, t) else: dz0_from_e1 = np.zeros_like(z0) # ── Decoder 0 (z0 → x0) ───────────────────────────────────────── dz0_from_d0 = self.dec0.backward(grads['dr0']) # (B, Z0) # ── Encoder 0 ──────────────────────────────────────────────────── dmu0 = dz0_from_d0 + dz0_from_d1 + dz0_from_d2 + dz0_from_e1 + dz0_from_e2 + grads['dmu0_kl'] dlv0 = dz0_from_d0 * 0.5 + grads['dlv0_kl'] self.enc0.backward(dmu0, dlv0) self.enc0.step(lr, t) self.dec0.step(lr, t) # ── Training ────────────────────────────────────────────────────────── def train_epoch(self, X: np.ndarray, mask: np.ndarray, lr: float, batch_size: int, phase: int, rng) -> Dict: N = len(X) idx = rng.permutation(N) epoch_stats = dict(total=0, recon0=0, recon1=0, recon2=0, kl0=0, kl1=0, kl2=0, n=0) for start in range(0, N, batch_size): bi = idx[start:start + batch_size] Xb = X[bi] mb = mask[bi] self.step_t += 1 # Adam global step (cumulative across all epochs) enc_out = self.encode(Xb, mb, rng) dec_out = self.decode(enc_out) grads = self.loss_and_grads(enc_out, dec_out, mb, N, phase) self.backward(enc_out, dec_out, grads, phase, lr) for k in ['total', 'recon0', 'recon1', 'recon2', 'kl0', 'kl1', 'kl2']: epoch_stats[k] += float(grads[k]) * len(bi) epoch_stats['n'] += len(bi) return {k: v / max(epoch_stats['n'], 1) for k, v in epoch_stats.items() if k != 'n'} # ── Inference ───────────────────────────────────────────────────────── def get_latents(self, X: np.ndarray, mask: np.ndarray) -> Dict[str, np.ndarray]: """Return disentangled latent codes for analysis.""" rng = np.random.RandomState(0) # deterministic for inference enc = self.encode(X, mask, rng) return { 'z0_macro': enc['mu0'], # (N, 4) deterministic mean 'z1_eigen': enc['mu1'], # (N, 8) 'z2_xsection': enc['mu2'], # (N, 8) 'z_all': np.concatenate([enc['mu0'], enc['mu1'], enc['mu2']], axis=1), # (N, 20) 'sigma0': np.exp(0.5 * enc['lv0']), # uncertainty 'sigma1': np.exp(0.5 * enc['lv1']), 'sigma2': np.exp(0.5 * enc['lv2']), } # ── Persistence ─────────────────────────────────────────────────────── def save(self, path: str): data = { 'norm': dict(mu_t0=self._mu_t0, sd_t0=self._sd_t0, mu_t1=self._mu_t1, sd_t1=self._sd_t1, mu_t2=self._mu_t2, sd_t2=self._sd_t2), 'step_t': self.step_t, 'losses': self.train_losses, } # Save all weight matrices for name, enc in [('enc0', self.enc0), ('enc1', self.enc1), ('enc2', self.enc2)]: for i, layer in enumerate(enc.mlp.layers): data[f'{name}_mlp{i}_W'] = layer.W data[f'{name}_mlp{i}_b'] = layer.b data[f'{name}_mu_W'] = enc.mu_head.W; data[f'{name}_mu_b'] = enc.mu_head.b data[f'{name}_lv_W'] = enc.lv_head.W; data[f'{name}_lv_b'] = enc.lv_head.b for name, dec in [('dec0', self.dec0), ('dec1', self.dec1), ('dec2', self.dec2)]: for i, layer in enumerate(dec.mlp.layers): data[f'{name}_mlp{i}_W'] = layer.W data[f'{name}_mlp{i}_b'] = layer.b np.savez_compressed(path, **data) print(f"Model saved: {path}.npz")