Includes core prod + GREEN/BLUE subsystems: - prod/ (BLUE harness, configs, scripts, docs) - nautilus_dolphin/ (GREEN Nautilus-native impl + dvae/ preserved) - adaptive_exit/ (AEM engine + models/bucket_assignments.pkl) - Observability/ (EsoF advisor, TUI, dashboards) - external_factors/ (EsoF producer) - mc_forewarning_qlabs_fork/ (MC regime/envelope) Excludes runtime caches, logs, backups, and reproducible artifacts per .gitignore.
560 lines
26 KiB
Python
Executable File
560 lines
26 KiB
Python
Executable File
"""
|
||
Hierarchical Disentangled VAE (H-D-VAE) for DOLPHIN
|
||
=====================================================
|
||
State-of-the-art β-TCVAE with:
|
||
- 3-level HIERARCHY of disentangled latent codes
|
||
- Real backpropagation (analytical gradients, no random noise)
|
||
- Masking for missing tiers (pre-eigenvalue era)
|
||
- Curriculum-aware: can freeze/thaw tiers during training
|
||
- Spectral features with 512-bit-aware log/ratio encoding
|
||
|
||
Latent structure:
|
||
z0 (dim=4) "macro regime" ← Tier-0 breadth + time
|
||
z1 (dim=8) "eigenstructure" ← Tier-1 eigenvalues, conditioned on z0
|
||
z2 (dim=8) "cross-section" ← Tier-2 per-asset, conditioned on z0+z1
|
||
|
||
Total latent dim: 20.
|
||
|
||
β-TCVAE decomposition (Chen et al. 2018):
|
||
KL = MI(x;z) + TC(z) + dim_KL(z)
|
||
Loss = Recon + γ·MI + β·TC + λ·dim_KL
|
||
|
||
Key insight: TC penalises factorial structure violation, encouraging
|
||
each dimension of z to capture an INDEPENDENT factor of variation.
|
||
With β >> 1 the model is forced to find the minimal sufficient encoding.
|
||
"""
|
||
|
||
import numpy as np
|
||
from typing import Tuple, Dict, Optional, List
|
||
|
||
# ── Constants ──────────────────────────────────────────────────────────────
|
||
# Must match corpus_builder.py DIMS = [8, 20, 50, 25, 8]
|
||
TIER0_DIM = 8
|
||
TIER1_DIM = 20
|
||
TIER2_DIM = 50
|
||
TIER3_DIM = 25 # ExF macro
|
||
TIER4_DIM = 8 # EsoF
|
||
|
||
# Tier offsets into 111-dim feature vector
|
||
T_OFF = [0, 8, 28, 78, 103]
|
||
|
||
Z0_DIM = 4 # macro regime
|
||
Z1_DIM = 8 # eigenstructure
|
||
Z2_DIM = 8 # cross-section + ExF
|
||
Z3_DIM = 4 # esoteric
|
||
Z_TOTAL = Z0_DIM + Z1_DIM + Z2_DIM + Z3_DIM # 24
|
||
|
||
EPS = 1e-8
|
||
|
||
|
||
# ── Numerics ───────────────────────────────────────────────────────────────
|
||
|
||
def _relu(x): return np.maximum(0, x)
|
||
def _drelu(x): return (x > 0).astype(x.dtype)
|
||
def _tanh(x): return np.tanh(x)
|
||
def _dtanh(x): return 1.0 - np.tanh(x)**2
|
||
def _softplus(x): return np.log1p(np.exp(np.clip(x, -20, 20)))
|
||
|
||
|
||
# ── Linear layer (weights + bias) with gradient ───────────────────────────
|
||
|
||
class Linear:
|
||
"""Affine layer with Adam optimiser state."""
|
||
def __init__(self, in_dim: int, out_dim: int, seed: int = 42):
|
||
rng = np.random.RandomState(seed)
|
||
scale = np.sqrt(2.0 / in_dim)
|
||
self.W = rng.randn(in_dim, out_dim).astype(np.float32) * scale
|
||
self.b = np.zeros(out_dim, dtype=np.float32)
|
||
# Adam state
|
||
self.mW = np.zeros_like(self.W)
|
||
self.vW = np.zeros_like(self.W)
|
||
self.mb = np.zeros_like(self.b)
|
||
self.vb = np.zeros_like(self.b)
|
||
|
||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||
self._x = x
|
||
return x @ self.W + self.b
|
||
|
||
def backward(self, dout: np.ndarray) -> np.ndarray:
|
||
self.dW = self._x.T @ dout
|
||
self.db = dout.sum(axis=0)
|
||
return dout @ self.W.T
|
||
|
||
def step(self, lr: float, t: int, beta1=0.9, beta2=0.999):
|
||
bc1 = 1 - beta1**t
|
||
bc2 = 1 - beta2**t
|
||
for (p, m, v, g) in [(self.W, self.mW, self.vW, self.dW),
|
||
(self.b, self.mb, self.vb, self.db)]:
|
||
m[:] = beta1 * m + (1 - beta1) * g
|
||
v[:] = beta2 * v + (1 - beta2) * g**2
|
||
p -= lr * (m / bc1) / (np.sqrt(v / bc2) + 1e-8)
|
||
|
||
|
||
# ── Simple 2-layer MLP with ReLU ──────────────────────────────────────────
|
||
|
||
class MLP:
|
||
def __init__(self, dims: List[int], seed: int = 42):
|
||
self.layers = []
|
||
for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||
self.layers.append(Linear(d_in, d_out, seed=seed + i))
|
||
# Activation cache
|
||
self._acts = []
|
||
|
||
def forward(self, x: np.ndarray, final_linear=True) -> np.ndarray:
|
||
self._acts = []
|
||
h = x
|
||
for i, layer in enumerate(self.layers):
|
||
h = layer.forward(h)
|
||
if i < len(self.layers) - 1 or not final_linear:
|
||
self._acts.append(h.copy())
|
||
h = _relu(h)
|
||
else:
|
||
self._acts.append(None)
|
||
return h
|
||
|
||
def backward(self, dout: np.ndarray) -> np.ndarray:
|
||
dh = dout
|
||
for i in range(len(self.layers) - 1, -1, -1):
|
||
if i < len(self.layers) - 1:
|
||
dh = dh * _drelu(self._acts[i])
|
||
dh = self.layers[i].backward(dh)
|
||
return dh
|
||
|
||
def step(self, lr: float, t: int):
|
||
for layer in self.layers:
|
||
layer.step(lr, t)
|
||
|
||
|
||
# ── Encoder: MLP → (mu, logvar) ───────────────────────────────────────────
|
||
|
||
class VAEEncoder:
|
||
"""Encodes input (optionally concatenated with conditioning z) to (mu, logvar)."""
|
||
def __init__(self, in_dim: int, hidden: int, z_dim: int, seed: int = 0):
|
||
self.mlp = MLP([in_dim, hidden, hidden], seed=seed)
|
||
self.mu_head = Linear(hidden, z_dim, seed=seed + 100)
|
||
self.lv_head = Linear(hidden, z_dim, seed=seed + 200)
|
||
self.z_dim = z_dim
|
||
|
||
def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||
h = self.mlp.forward(x) # (B, hidden)
|
||
h = _relu(h)
|
||
self._h = h
|
||
mu = self.mu_head.forward(h)
|
||
logvar = self.lv_head.forward(h)
|
||
logvar = np.clip(logvar, -10, 4) # keep variance sane
|
||
return mu, logvar
|
||
|
||
def backward(self, dmu: np.ndarray, dlv: np.ndarray) -> np.ndarray:
|
||
dh_mu = self.mu_head.backward(dmu)
|
||
dh_lv = self.lv_head.backward(dlv)
|
||
dh = (dh_mu + dh_lv) * _drelu(self._h)
|
||
return self.mlp.backward(dh)
|
||
|
||
def step(self, lr: float, t: int):
|
||
self.mlp.step(lr, t)
|
||
self.mu_head.step(lr, t)
|
||
self.lv_head.step(lr, t)
|
||
|
||
|
||
# ── Decoder: z → x_hat ────────────────────────────────────────────────────
|
||
|
||
class VAEDecoder:
|
||
def __init__(self, z_dim: int, hidden: int, out_dim: int, seed: int = 0):
|
||
self.mlp = MLP([z_dim, hidden, hidden, out_dim], seed=seed + 300)
|
||
self.out_dim = out_dim
|
||
|
||
def forward(self, z: np.ndarray) -> np.ndarray:
|
||
return self.mlp.forward(z, final_linear=True)
|
||
|
||
def backward(self, dout: np.ndarray) -> np.ndarray:
|
||
return self.mlp.backward(dout)
|
||
|
||
def step(self, lr: float, t: int):
|
||
self.mlp.step(lr, t)
|
||
|
||
|
||
# ── β-TCVAE loss (Chen et al. 2018, proper dataset-size corrected form) ───
|
||
|
||
def btcvae_kl(mu: np.ndarray, logvar: np.ndarray,
|
||
N_dataset: int, beta: float = 4.0,
|
||
gamma: float = 1.0, lam: float = 1.0
|
||
) -> Tuple[float, np.ndarray, np.ndarray]:
|
||
"""
|
||
Decompose KL into: MI + beta*TC + lambda*dim_KL
|
||
Returns (total_kl, dmu, dlogvar).
|
||
|
||
MI = E_q[log q(z|x)] - E_q[log q(z)]
|
||
TC = E_q[log q(z)] - E_q[sum_j log q(z_j)]
|
||
dKL = E_q[sum_j log q(z_j) - log p(z_j)]
|
||
|
||
Minibatch-weighted estimator — numerically stable.
|
||
"""
|
||
B, D = mu.shape
|
||
var = np.exp(logvar)
|
||
|
||
# Per-sample KL (closed form vs N(0,1) prior)
|
||
kl_per_dim = 0.5 * (mu**2 + var - logvar - 1) # (B, D)
|
||
|
||
# Log q(z|x_i) for sample z_i ~ q(z|x_i) (diagonal Gaussian)
|
||
# We use the mean as the sample point (reparametrization noise ≈ 0 for gradient estimation)
|
||
# log q(z_i | x_i) = sum_d -0.5*(log(2π) + logvar + (z-mu)^2/var)
|
||
# At z = mu: (z-mu)=0, so log q(z|x) = -0.5*sum(log(2π) + logvar)
|
||
log_q_z_given_x = -0.5 * (D * np.log(2 * np.pi) + logvar.sum(axis=1)) # (B,)
|
||
|
||
# Minibatch estimate of log q(z): E_data[sum_x log q(z|x)] / N
|
||
# log q(z_i) ≈ log(1/N * sum_j q(z_i | x_j))
|
||
# Using pairwise: log q(z_i) ≈ logsumexp over j of log q(z_i | x_j) - log(N)
|
||
# q(z_i | x_j) for each pair: diffs = z_i[none,:] - mu[none,:] shape (B,B,D)
|
||
z_sample = mu # (B, D) — use mean as point estimate
|
||
# (B_i, B_j, D) pairwise differences
|
||
diff = z_sample[:, None, :] - mu[None, :, :] # (B, B, D)
|
||
log_q_z_pair = -0.5 * (
|
||
D * np.log(2 * np.pi)
|
||
+ logvar[None, :, :].sum(axis=2)
|
||
+ (diff**2 / (var[None, :, :] + EPS)).sum(axis=2)
|
||
) # (B, B)
|
||
log_q_z = np.log(np.sum(np.exp(log_q_z_pair - log_q_z_pair.max(axis=1, keepdims=True)), axis=1) + EPS) \
|
||
+ log_q_z_pair.max(axis=1) - np.log(N_dataset + EPS) # (B,)
|
||
|
||
# log q(z_j) marginal — independence assumption for TC
|
||
# log q(z_i_j) for each dim j: logsumexp over data points
|
||
log_q_z_product = np.zeros(B, dtype=np.float64)
|
||
for j in range(D):
|
||
diff_j = z_sample[:, j:j+1] - mu[:, j:j+1].T # (B, B)
|
||
log_q_zj = -0.5 * (np.log(2 * np.pi) + logvar[None, :, j] + diff_j**2 / (var[None, :, j] + EPS)) # (B,B)
|
||
log_q_zj_marginal = np.log(np.sum(np.exp(log_q_zj - log_q_zj.max(axis=1, keepdims=True)), axis=1) + EPS) \
|
||
+ log_q_zj.max(axis=1) - np.log(N_dataset + EPS) # (B,)
|
||
log_q_z_product += log_q_zj_marginal
|
||
|
||
# log p(z) = N(0,1)
|
||
log_p_z = -0.5 * (D * np.log(2 * np.pi) + (z_sample**2).sum(axis=1)) # (B,)
|
||
|
||
mi_loss = np.mean(log_q_z_given_x - log_q_z)
|
||
tc_loss = np.mean(log_q_z - log_q_z_product)
|
||
dkl_loss = np.mean(log_q_z_product - log_p_z)
|
||
|
||
total_kl = gamma * mi_loss + beta * max(tc_loss, 0) + lam * dkl_loss
|
||
|
||
# Gradients w.r.t. mu and logvar: use standard closed-form KL gradient
|
||
# ∂KL/∂mu = mu, ∂KL/∂logvar = 0.5*(exp(logvar) - 1)
|
||
scale = (gamma + beta + lam) / (3.0 * B + EPS)
|
||
dmu = scale * mu
|
||
dlogvar = scale * 0.5 * (var - 1.0)
|
||
|
||
return float(total_kl), dmu.astype(np.float32), dlogvar.astype(np.float32)
|
||
|
||
|
||
# ── Reparametrization ─────────────────────────────────────────────────────
|
||
|
||
def reparametrize(mu, logvar, rng):
|
||
std = np.exp(0.5 * logvar)
|
||
eps = rng.randn(*mu.shape).astype(np.float32)
|
||
return mu + eps * std
|
||
|
||
|
||
# ── Hierarchical D-VAE ────────────────────────────────────────────────────
|
||
|
||
class HierarchicalDVAE:
|
||
"""
|
||
4-level Hierarchical Disentangled VAE (111-dim input, 24-dim latent).
|
||
|
||
Latent hierarchy:
|
||
z0 (4) = enc0(T0) macro regime (breadth + time)
|
||
z1 (8) = enc1(T1 + T4 + z0) eigenstructure + esoteric, cond on z0
|
||
z2 (8) = enc2(T2 + T3 + z0 + z1) cross-section + ExF, cond on z0,z1
|
||
z3 (4) = enc3(T3 + T4 + z0) ExF+EsoF regime, cond on z0
|
||
|
||
Decoding:
|
||
T0_hat = dec0(z0)
|
||
T1_hat = dec1(z0, z1)
|
||
T2_hat = dec2(z0, z1, z2)
|
||
T3_hat = dec3(z0, z3)
|
||
T4_hat = dec4(z0) (EsoF is deterministic, serves as auxiliary)
|
||
|
||
Training phases:
|
||
0: enc0/dec0 only — full 500K+ corpus (even pre-eigen NG1/NG2)
|
||
1: + enc1/dec1 — eigen-tier (NG3+)
|
||
2: + enc2/dec2 — pricing + ExF (NG3+ with pricing)
|
||
3: + enc3/dec3 — ExF regime
|
||
4: joint fine-tune all
|
||
"""
|
||
|
||
def __init__(self, hidden: int = 64, beta: float = 4.0,
|
||
gamma: float = 1.0, lam: float = 1.0, seed: int = 42):
|
||
self.hidden = hidden
|
||
self.beta = beta
|
||
self.gamma = gamma
|
||
self.lam = lam
|
||
self.seed = seed
|
||
|
||
# Encoders (each conditioned on upstream z)
|
||
self.enc0 = VAEEncoder(TIER0_DIM, hidden, Z0_DIM, seed=seed)
|
||
self.enc1 = VAEEncoder(TIER1_DIM + TIER4_DIM + Z0_DIM, hidden, Z1_DIM, seed=seed+10)
|
||
self.enc2 = VAEEncoder(TIER2_DIM + TIER3_DIM + Z0_DIM + Z1_DIM, hidden, Z2_DIM, seed=seed+20)
|
||
self.enc3 = VAEEncoder(TIER3_DIM + TIER4_DIM + Z0_DIM, hidden//2, Z3_DIM, seed=seed+30)
|
||
|
||
# Decoders
|
||
self.dec0 = VAEDecoder(Z0_DIM, hidden, TIER0_DIM, seed=seed)
|
||
self.dec1 = VAEDecoder(Z0_DIM + Z1_DIM, hidden, TIER1_DIM, seed=seed+10)
|
||
self.dec2 = VAEDecoder(Z0_DIM + Z1_DIM + Z2_DIM, hidden, TIER2_DIM, seed=seed+20)
|
||
self.dec3 = VAEDecoder(Z0_DIM + Z3_DIM, hidden//2, TIER3_DIM, seed=seed+30)
|
||
self.dec4 = VAEDecoder(Z0_DIM, hidden//2, TIER4_DIM, seed=seed+40)
|
||
|
||
# Normalisation statistics (fit on training data)
|
||
self._mu_t0 = np.zeros(TIER0_DIM, dtype=np.float32)
|
||
self._sd_t0 = np.ones(TIER0_DIM, dtype=np.float32)
|
||
self._mu_t1 = np.zeros(TIER1_DIM, dtype=np.float32)
|
||
self._sd_t1 = np.ones(TIER1_DIM, dtype=np.float32)
|
||
self._mu_t2 = np.zeros(TIER2_DIM, dtype=np.float32)
|
||
self._sd_t2 = np.ones(TIER2_DIM, dtype=np.float32)
|
||
self._mu_t3 = np.zeros(TIER3_DIM, dtype=np.float32)
|
||
self._sd_t3 = np.ones(TIER3_DIM, dtype=np.float32)
|
||
self._mu_t4 = np.zeros(TIER4_DIM, dtype=np.float32)
|
||
self._sd_t4 = np.ones(TIER4_DIM, dtype=np.float32)
|
||
|
||
self.step_t = 0 # Adam time step
|
||
self.train_losses: List[Dict] = []
|
||
|
||
# ── Normalisation ──────────────────────────────────────────────────────
|
||
|
||
def fit_normaliser(self, X: np.ndarray, mask: np.ndarray):
|
||
"""Fit per-tier normalisation on a representative sample."""
|
||
def _fit(rows, dim):
|
||
if len(rows) < 2:
|
||
return np.zeros(dim, dtype=np.float32), np.ones(dim, dtype=np.float32)
|
||
return rows.mean(0).astype(np.float32), (rows.std(0) + EPS).astype(np.float32)
|
||
|
||
self._mu_t0, self._sd_t0 = _fit(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM], TIER0_DIM)
|
||
idx1 = mask[:, 1]
|
||
self._mu_t1, self._sd_t1 = _fit(X[idx1, T_OFF[1]:T_OFF[1]+TIER1_DIM], TIER1_DIM)
|
||
idx2 = mask[:, 2]
|
||
self._mu_t2, self._sd_t2 = _fit(X[idx2, T_OFF[2]:T_OFF[2]+TIER2_DIM], TIER2_DIM)
|
||
idx3 = mask[:, 3]
|
||
self._mu_t3, self._sd_t3 = _fit(X[idx3, T_OFF[3]:T_OFF[3]+TIER3_DIM], TIER3_DIM)
|
||
self._mu_t4, self._sd_t4 = _fit(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM], TIER4_DIM)
|
||
|
||
def _norm(self, x, mu, sd): return (x - mu) / sd
|
||
def _norm0(self, x): return self._norm(x, self._mu_t0, self._sd_t0)
|
||
def _norm1(self, x): return self._norm(x, self._mu_t1, self._sd_t1)
|
||
def _norm2(self, x): return self._norm(x, self._mu_t2, self._sd_t2)
|
||
def _norm3(self, x): return self._norm(x, self._mu_t3, self._sd_t3)
|
||
def _norm4(self, x): return self._norm(x, self._mu_t4, self._sd_t4)
|
||
|
||
# ── Forward pass ──────────────────────────────────────────────────────
|
||
|
||
def _split(self, X):
|
||
"""Split 111-dim row into normalised tier vectors."""
|
||
t0 = self._norm0(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM])
|
||
t1 = self._norm1(X[:, T_OFF[1]:T_OFF[1]+TIER1_DIM])
|
||
t2 = self._norm2(X[:, T_OFF[2]:T_OFF[2]+TIER2_DIM])
|
||
t3 = self._norm3(X[:, T_OFF[3]:T_OFF[3]+TIER3_DIM])
|
||
t4 = self._norm4(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM])
|
||
return t0, t1, t2, t3, t4
|
||
|
||
def encode(self, X: np.ndarray, mask: np.ndarray, rng) -> Dict:
|
||
t0, t1, t2, t3, t4 = self._split(X)
|
||
|
||
mu0, lv0 = self.enc0.forward(t0)
|
||
z0 = reparametrize(mu0, lv0, rng)
|
||
|
||
mu1, lv1 = self.enc1.forward(np.concatenate([t1, t4, z0], axis=1))
|
||
z1 = reparametrize(mu1, lv1, rng)
|
||
|
||
mu2, lv2 = self.enc2.forward(np.concatenate([t2, t3, z0, z1], axis=1))
|
||
z2 = reparametrize(mu2, lv2, rng)
|
||
|
||
mu3, lv3 = self.enc3.forward(np.concatenate([t3, t4, z0], axis=1))
|
||
z3 = reparametrize(mu3, lv3, rng)
|
||
|
||
return dict(t0=t0, t1=t1, t2=t2, t3=t3, t4=t4,
|
||
mu0=mu0, lv0=lv0, z0=z0,
|
||
mu1=mu1, lv1=lv1, z1=z1,
|
||
mu2=mu2, lv2=lv2, z2=z2,
|
||
mu3=mu3, lv3=lv3, z3=z3)
|
||
|
||
def decode(self, enc: Dict) -> Dict:
|
||
z0, z1, z2, z3 = enc['z0'], enc['z1'], enc['z2'], enc['z3']
|
||
x0_hat = self.dec0.forward(z0)
|
||
x1_hat = self.dec1.forward(np.concatenate([z0, z1], axis=1))
|
||
x2_hat = self.dec2.forward(np.concatenate([z0, z1, z2], axis=1))
|
||
x3_hat = self.dec3.forward(np.concatenate([z0, z3], axis=1))
|
||
x4_hat = self.dec4.forward(z0) # EsoF decoded from macro only
|
||
return dict(x0_hat=x0_hat, x1_hat=x1_hat, x2_hat=x2_hat,
|
||
x3_hat=x3_hat, x4_hat=x4_hat)
|
||
|
||
# ── Loss and gradients ─────────────────────────────────────────────────
|
||
|
||
def loss_and_grads(self, enc: Dict, dec: Dict, mask: np.ndarray,
|
||
N_dataset: int, phase: int) -> Dict:
|
||
"""Compute total loss and all gradients via chain rule."""
|
||
t0, t1, t2 = enc['t0'], enc['t1'], enc['t2']
|
||
x0_hat = dec['x0_hat']
|
||
x1_hat = dec['x1_hat']
|
||
x2_hat = dec['x2_hat']
|
||
B = len(t0)
|
||
|
||
# Mask weights: per-sample, per-tier
|
||
w1 = mask[:, 1].astype(np.float32)[:, None] # (B,1)
|
||
w2 = mask[:, 2].astype(np.float32)[:, None]
|
||
|
||
# Reconstruction losses
|
||
recon0 = np.mean((x0_hat - t0)**2)
|
||
recon1 = np.mean(w1 * (x1_hat - t1)**2)
|
||
recon2 = np.mean(w2 * (x2_hat - t2)**2)
|
||
|
||
# Reconstruction gradients
|
||
dr0 = 2 * (x0_hat - t0) / (B * TIER0_DIM)
|
||
dr1 = 2 * w1 * (x1_hat - t1) / (B * TIER1_DIM)
|
||
dr2 = 2 * w2 * (x2_hat - t2) / (B * TIER2_DIM)
|
||
|
||
# KL losses
|
||
kl0, dmu0_kl, dlv0_kl = btcvae_kl(enc['mu0'], enc['lv0'], N_dataset, self.beta, self.gamma, self.lam)
|
||
kl1, dmu1_kl, dlv1_kl = btcvae_kl(enc['mu1'], enc['lv1'], N_dataset, self.beta, self.gamma, self.lam)
|
||
kl2, dmu2_kl, dlv2_kl = btcvae_kl(enc['mu2'], enc['lv2'], N_dataset, self.beta, self.gamma, self.lam)
|
||
|
||
# Phase-based scaling: only activate eigen/pricing in correct phase
|
||
if phase < 1:
|
||
kl1 = recon1 = 0.0; dmu1_kl[:] = 0; dlv1_kl[:] = 0; dr1[:] = 0
|
||
if phase < 2:
|
||
kl2 = recon2 = 0.0; dmu2_kl[:] = 0; dlv2_kl[:] = 0; dr2[:] = 0
|
||
|
||
total = recon0 + recon1 + recon2 + kl0 + kl1 + kl2
|
||
|
||
return dict(
|
||
total=total, recon0=recon0, recon1=recon1, recon2=recon2,
|
||
kl0=kl0, kl1=kl1, kl2=kl2,
|
||
dr0=dr0, dr1=dr1, dr2=dr2,
|
||
dmu0_kl=dmu0_kl, dlv0_kl=dlv0_kl,
|
||
dmu1_kl=dmu1_kl, dlv1_kl=dlv1_kl,
|
||
dmu2_kl=dmu2_kl, dlv2_kl=dlv2_kl,
|
||
)
|
||
|
||
def backward(self, enc: Dict, dec: Dict, grads: Dict, phase: int, lr: float):
|
||
"""Backpropagate through all encoders and decoders, then Adam step."""
|
||
z0, z1, z2 = enc['z0'], enc['z1'], enc['z2']
|
||
t = self.step_t
|
||
|
||
# ── Decoder 2 (z0+z1+z2 → x2) ──────────────────────────────────
|
||
if phase >= 2:
|
||
dz_all_2 = self.dec2.backward(grads['dr2']) # (B, Z0+Z1+Z2)
|
||
dz0_from_d2 = dz_all_2[:, :Z0_DIM]
|
||
dz1_from_d2 = dz_all_2[:, Z0_DIM:Z0_DIM + Z1_DIM]
|
||
dz2_from_d2 = dz_all_2[:, Z0_DIM + Z1_DIM:]
|
||
else:
|
||
dz0_from_d2 = np.zeros_like(z0)
|
||
dz1_from_d2 = np.zeros_like(z1)
|
||
dz2_from_d2 = np.zeros_like(z2)
|
||
self.dec2.backward(np.zeros_like(grads['dr2']))
|
||
|
||
# ── Encoder 2 ─────────────────────────────────────────────────
|
||
# enc2 input = concat(T2, T3, z0, z1) dims = TIER2+TIER3+Z0+Z1
|
||
if phase >= 2:
|
||
dmu2 = dz2_from_d2 + grads['dmu2_kl']
|
||
dlv2 = dz2_from_d2 * 0.5 + grads['dlv2_kl']
|
||
dinp2 = self.enc2.backward(dmu2, dlv2) # (B, T2+T3+Z0+Z1)
|
||
_off = TIER2_DIM + TIER3_DIM
|
||
dz0_from_e2 = dinp2[:, _off:_off + Z0_DIM]
|
||
dz1_from_e2 = dinp2[:, _off + Z0_DIM:]
|
||
self.dec2.step(lr, t)
|
||
self.enc2.step(lr, t)
|
||
else:
|
||
dz0_from_e2 = np.zeros_like(z0)
|
||
dz1_from_e2 = np.zeros_like(z1)
|
||
|
||
# ── Decoder 1 (z0+z1 → x1) ──────────────────────────────────────
|
||
if phase >= 1:
|
||
dz_all_1 = self.dec1.backward(grads['dr1']) # (B, Z0+Z1)
|
||
dz0_from_d1 = dz_all_1[:, :Z0_DIM]
|
||
dz1_from_d1 = dz_all_1[:, Z0_DIM:]
|
||
else:
|
||
dz0_from_d1 = np.zeros_like(z0)
|
||
dz1_from_d1 = np.zeros_like(z1)
|
||
self.dec1.backward(np.zeros_like(grads['dr1']))
|
||
|
||
# ── Encoder 1 ─────────────────────────────────────────────────
|
||
# enc1 input = concat(T1, T4, z0) dims = TIER1+TIER4+Z0
|
||
if phase >= 1:
|
||
dmu1 = dz1_from_d1 + dz1_from_d2 + dz1_from_e2 + grads['dmu1_kl']
|
||
dlv1 = dz1_from_d1 * 0.5 + grads['dlv1_kl']
|
||
dinp1 = self.enc1.backward(dmu1, dlv1) # (B, T1+T4+Z0)
|
||
dz0_from_e1 = dinp1[:, TIER1_DIM + TIER4_DIM:] # correct: skip T1+T4
|
||
self.dec1.step(lr, t)
|
||
self.enc1.step(lr, t)
|
||
else:
|
||
dz0_from_e1 = np.zeros_like(z0)
|
||
|
||
# ── Decoder 0 (z0 → x0) ─────────────────────────────────────────
|
||
dz0_from_d0 = self.dec0.backward(grads['dr0']) # (B, Z0)
|
||
|
||
# ── Encoder 0 ────────────────────────────────────────────────────
|
||
dmu0 = dz0_from_d0 + dz0_from_d1 + dz0_from_d2 + dz0_from_e1 + dz0_from_e2 + grads['dmu0_kl']
|
||
dlv0 = dz0_from_d0 * 0.5 + grads['dlv0_kl']
|
||
self.enc0.backward(dmu0, dlv0)
|
||
self.enc0.step(lr, t)
|
||
self.dec0.step(lr, t)
|
||
|
||
# ── Training ──────────────────────────────────────────────────────────
|
||
|
||
def train_epoch(self, X: np.ndarray, mask: np.ndarray,
|
||
lr: float, batch_size: int, phase: int, rng) -> Dict:
|
||
N = len(X)
|
||
idx = rng.permutation(N)
|
||
epoch_stats = dict(total=0, recon0=0, recon1=0, recon2=0, kl0=0, kl1=0, kl2=0, n=0)
|
||
|
||
for start in range(0, N, batch_size):
|
||
bi = idx[start:start + batch_size]
|
||
Xb = X[bi]
|
||
mb = mask[bi]
|
||
self.step_t += 1 # Adam global step (cumulative across all epochs)
|
||
|
||
enc_out = self.encode(Xb, mb, rng)
|
||
dec_out = self.decode(enc_out)
|
||
grads = self.loss_and_grads(enc_out, dec_out, mb, N, phase)
|
||
self.backward(enc_out, dec_out, grads, phase, lr)
|
||
|
||
for k in ['total', 'recon0', 'recon1', 'recon2', 'kl0', 'kl1', 'kl2']:
|
||
epoch_stats[k] += float(grads[k]) * len(bi)
|
||
epoch_stats['n'] += len(bi)
|
||
|
||
return {k: v / max(epoch_stats['n'], 1) for k, v in epoch_stats.items() if k != 'n'}
|
||
|
||
# ── Inference ─────────────────────────────────────────────────────────
|
||
|
||
def get_latents(self, X: np.ndarray, mask: np.ndarray) -> Dict[str, np.ndarray]:
|
||
"""Return disentangled latent codes for analysis."""
|
||
rng = np.random.RandomState(0) # deterministic for inference
|
||
enc = self.encode(X, mask, rng)
|
||
return {
|
||
'z0_macro': enc['mu0'], # (N, 4) deterministic mean
|
||
'z1_eigen': enc['mu1'], # (N, 8)
|
||
'z2_xsection': enc['mu2'], # (N, 8)
|
||
'z_all': np.concatenate([enc['mu0'], enc['mu1'], enc['mu2']], axis=1), # (N, 20)
|
||
'sigma0': np.exp(0.5 * enc['lv0']), # uncertainty
|
||
'sigma1': np.exp(0.5 * enc['lv1']),
|
||
'sigma2': np.exp(0.5 * enc['lv2']),
|
||
}
|
||
|
||
# ── Persistence ───────────────────────────────────────────────────────
|
||
|
||
def save(self, path: str):
|
||
data = {
|
||
'norm': dict(mu_t0=self._mu_t0, sd_t0=self._sd_t0,
|
||
mu_t1=self._mu_t1, sd_t1=self._sd_t1,
|
||
mu_t2=self._mu_t2, sd_t2=self._sd_t2),
|
||
'step_t': self.step_t,
|
||
'losses': self.train_losses,
|
||
}
|
||
# Save all weight matrices
|
||
for name, enc in [('enc0', self.enc0), ('enc1', self.enc1), ('enc2', self.enc2)]:
|
||
for i, layer in enumerate(enc.mlp.layers):
|
||
data[f'{name}_mlp{i}_W'] = layer.W
|
||
data[f'{name}_mlp{i}_b'] = layer.b
|
||
data[f'{name}_mu_W'] = enc.mu_head.W; data[f'{name}_mu_b'] = enc.mu_head.b
|
||
data[f'{name}_lv_W'] = enc.lv_head.W; data[f'{name}_lv_b'] = enc.lv_head.b
|
||
for name, dec in [('dec0', self.dec0), ('dec1', self.dec1), ('dec2', self.dec2)]:
|
||
for i, layer in enumerate(dec.mlp.layers):
|
||
data[f'{name}_mlp{i}_W'] = layer.W
|
||
data[f'{name}_mlp{i}_b'] = layer.b
|
||
np.savez_compressed(path, **data)
|
||
print(f"Model saved: {path}.npz")
|