initial: import DOLPHIN baseline 2026-04-21 from dolphinng5_predict working tree
Includes core prod + GREEN/BLUE subsystems: - prod/ (BLUE harness, configs, scripts, docs) - nautilus_dolphin/ (GREEN Nautilus-native impl + dvae/ preserved) - adaptive_exit/ (AEM engine + models/bucket_assignments.pkl) - Observability/ (EsoF advisor, TUI, dashboards) - external_factors/ (EsoF producer) - mc_forewarning_qlabs_fork/ (MC regime/envelope) Excludes runtime caches, logs, backups, and reproducible artifacts per .gitignore.
This commit is contained in:
559
nautilus_dolphin/dvae/hierarchical_dvae.py
Executable file
559
nautilus_dolphin/dvae/hierarchical_dvae.py
Executable file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
Hierarchical Disentangled VAE (H-D-VAE) for DOLPHIN
|
||||
=====================================================
|
||||
State-of-the-art β-TCVAE with:
|
||||
- 4-level HIERARCHY of disentangled latent codes
|
||||
- Real backpropagation (analytical gradients, no random noise)
|
||||
- Masking for missing tiers (pre-eigenvalue era)
|
||||
- Curriculum-aware: can freeze/thaw tiers during training
|
||||
- Spectral features with 512-bit-aware log/ratio encoding
|
||||
|
||||
Latent structure:
  z0 (dim=4) "macro regime"    ← Tier-0 breadth + time
  z1 (dim=8) "eigenstructure"  ← Tier-1 eigenvalues, conditioned on z0
  z2 (dim=8) "cross-section"   ← Tier-2 per-asset, conditioned on z0+z1
  z3 (dim=4) "esoteric"        ← Tier-3/4 regime, conditioned on z0

Total latent dim: 24.
|
||||
|
||||
β-TCVAE decomposition (Chen et al. 2018):
|
||||
KL = MI(x;z) + TC(z) + dim_KL(z)
|
||||
Loss = Recon + γ·MI + β·TC + λ·dim_KL
|
||||
|
||||
Key insight: TC penalises factorial structure violation, encouraging
|
||||
each dimension of z to capture an INDEPENDENT factor of variation.
|
||||
With β >> 1 the model is forced to find the minimal sufficient encoding.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple, Dict, Optional, List
|
||||
|
||||
# ── Constants ──────────────────────────────────────────────────────────────
# Per-tier feature widths. Must match corpus_builder.py DIMS = [8, 20, 50, 25, 8]
TIER0_DIM = 8    # breadth + time
TIER1_DIM = 20   # eigenvalue features
TIER2_DIM = 50   # per-asset cross-section
TIER3_DIM = 25   # ExF macro
TIER4_DIM = 8    # EsoF

# Tier start offsets into the 111-dim feature vector — cumulative sums of the
# widths above (0, 8, 28, 78, 103; 103 + 8 = 111).
T_OFF = [0, 8, 28, 78, 103]

# Latent code width per hierarchy level (see HierarchicalDVAE docstring).
Z0_DIM = 4   # macro regime
Z1_DIM = 8   # eigenstructure
Z2_DIM = 8   # cross-section + ExF
Z3_DIM = 4   # esoteric
Z_TOTAL = Z0_DIM + Z1_DIM + Z2_DIM + Z3_DIM  # 24

# Guard constant for divisions / logs of possibly-zero quantities.
EPS = 1e-8
|
||||
|
||||
|
||||
# ── Numerics ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _relu(x): return np.maximum(0, x)
|
||||
def _drelu(x): return (x > 0).astype(x.dtype)
|
||||
def _tanh(x): return np.tanh(x)
|
||||
def _dtanh(x): return 1.0 - np.tanh(x)**2
|
||||
def _softplus(x): return np.log1p(np.exp(np.clip(x, -20, 20)))
|
||||
|
||||
|
||||
# ── Linear layer (weights + bias) with gradient ───────────────────────────

class Linear:
    """Affine map y = xW + b with cached input, stored gradients, and Adam state."""

    def __init__(self, in_dim: int, out_dim: int, seed: int = 42):
        rng = np.random.RandomState(seed)
        he_scale = np.sqrt(2.0 / in_dim)  # He init, suited to ReLU stacks
        self.W = rng.randn(in_dim, out_dim).astype(np.float32) * he_scale
        self.b = np.zeros(out_dim, dtype=np.float32)
        # Adam first/second-moment accumulators.
        self.mW = np.zeros_like(self.W)
        self.vW = np.zeros_like(self.W)
        self.mb = np.zeros_like(self.b)
        self.vb = np.zeros_like(self.b)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Affine forward pass; caches x for the backward pass."""
        self._x = x
        return x @ self.W + self.b

    def backward(self, dout: np.ndarray) -> np.ndarray:
        """Store dW/db from upstream gradient dout; return grad w.r.t. the input."""
        self.dW = self._x.T @ dout
        self.db = dout.sum(axis=0)
        return dout @ self.W.T

    def step(self, lr: float, t: int, beta1=0.9, beta2=0.999):
        """One bias-corrected Adam update of W and b from the stored gradients."""
        corr1 = 1 - beta1 ** t
        corr2 = 1 - beta2 ** t
        pairs = ((self.W, self.mW, self.vW, self.dW),
                 (self.b, self.mb, self.vb, self.db))
        for p, m, v, g in pairs:
            m[:] = beta1 * m + (1 - beta1) * g
            v[:] = beta2 * v + (1 - beta2) * g ** 2
            p -= lr * (m / corr1) / (np.sqrt(v / corr2) + 1e-8)
|
||||
|
||||
|
||||
# ── Simple 2-layer MLP with ReLU ──────────────────────────────────────────

class MLP:
    """Stack of Linear layers with ReLU between them.

    ``final_linear=True`` (the default) leaves the last layer's output linear;
    ``final_linear=False`` applies ReLU to the last layer as well.
    """

    def __init__(self, dims: List[int], seed: int = 42):
        """dims = [in, hidden..., out]; one Linear per consecutive pair."""
        self.layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            self.layers.append(Linear(d_in, d_out, seed=seed + i))
        # Per-layer pre-activation cache; None marks a linear (no-ReLU) output.
        self._acts = []

    def forward(self, x: np.ndarray, final_linear=True) -> np.ndarray:
        self._acts = []
        h = x
        for i, layer in enumerate(self.layers):
            h = layer.forward(h)
            if i < len(self.layers) - 1 or not final_linear:
                self._acts.append(h.copy())  # pre-activation, needed by _drelu
                h = _relu(h)
            else:
                self._acts.append(None)      # final layer stayed linear
        return h

    def backward(self, dout: np.ndarray) -> np.ndarray:
        """Backprop dout through the stack; returns gradient w.r.t. the input.

        Bug fix: the activation derivative is keyed off the cache entry
        (None == linear output) instead of the layer index, so a forward
        pass run with ``final_linear=False`` now backprops through its
        final ReLU too. The previous index-based test silently skipped
        that derivative. Behavior on the ``final_linear=True`` path is
        unchanged (its last cache entry is None).
        """
        dh = dout
        for i in range(len(self.layers) - 1, -1, -1):
            if self._acts[i] is not None:
                dh = dh * _drelu(self._acts[i])
            dh = self.layers[i].backward(dh)
        return dh

    def step(self, lr: float, t: int):
        """Adam-update every layer."""
        for layer in self.layers:
            layer.step(lr, t)
|
||||
|
||||
|
||||
# ── Encoder: MLP → (mu, logvar) ───────────────────────────────────────────

class VAEEncoder:
    """Encodes input (optionally concatenated with conditioning z) to (mu, logvar)."""

    def __init__(self, in_dim: int, hidden: int, z_dim: int, seed: int = 0):
        # Shared trunk in_dim → hidden → hidden; the trailing ReLU is applied
        # explicitly in forward() so its derivative is handled in backward().
        self.mlp = MLP([in_dim, hidden, hidden], seed=seed)
        # Separate affine heads for the Gaussian posterior parameters.
        self.mu_head = Linear(hidden, z_dim, seed=seed + 100)
        self.lv_head = Linear(hidden, z_dim, seed=seed + 200)
        self.z_dim = z_dim

    def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return (mu, logvar) of q(z|x); logvar is clipped to [-10, 4]."""
        trunk = _relu(self.mlp.forward(x))  # (B, hidden)
        self._h = trunk                      # cached for backward
        mu = self.mu_head.forward(trunk)
        logvar = np.clip(self.lv_head.forward(trunk), -10, 4)  # keep variance sane
        return mu, logvar

    def backward(self, dmu: np.ndarray, dlv: np.ndarray) -> np.ndarray:
        """Backprop head gradients through the trunk; returns grad w.r.t. x.

        NOTE(review): the logvar clip in forward() is treated as identity
        here (straight-through) — clipped positions keep their gradient.
        """
        dtrunk = (self.mu_head.backward(dmu)
                  + self.lv_head.backward(dlv)) * _drelu(self._h)
        return self.mlp.backward(dtrunk)

    def step(self, lr: float, t: int):
        """Adam-update trunk and both heads."""
        self.mlp.step(lr, t)
        self.mu_head.step(lr, t)
        self.lv_head.step(lr, t)
|
||||
|
||||
|
||||
# ── Decoder: z → x_hat ────────────────────────────────────────────────────

class VAEDecoder:
    """Maps a latent code z to a reconstruction x_hat via a 3-layer MLP."""

    def __init__(self, z_dim: int, hidden: int, out_dim: int, seed: int = 0):
        # z_dim → hidden → hidden → out_dim; seed offset keeps decoder weights
        # decorrelated from the matching encoder's.
        self.mlp = MLP([z_dim, hidden, hidden, out_dim], seed=seed + 300)
        self.out_dim = out_dim

    def forward(self, z: np.ndarray) -> np.ndarray:
        """Decode z; the output layer is linear (no activation)."""
        return self.mlp.forward(z, final_linear=True)

    def backward(self, dout: np.ndarray) -> np.ndarray:
        """Backprop reconstruction gradient; returns grad w.r.t. z."""
        return self.mlp.backward(dout)

    def step(self, lr: float, t: int):
        """Adam-update all decoder layers."""
        self.mlp.step(lr, t)
|
||||
|
||||
|
||||
# ── β-TCVAE loss (Chen et al. 2018, proper dataset-size corrected form) ───

def btcvae_kl(mu: np.ndarray, logvar: np.ndarray,
              N_dataset: int, beta: float = 4.0,
              gamma: float = 1.0, lam: float = 1.0
              ) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Decompose KL into: MI + beta*TC + lambda*dim_KL
    Returns (total_kl, dmu, dlogvar).

    MI  = E_q[log q(z|x)] - E_q[log q(z)]
    TC  = E_q[log q(z)] - E_q[sum_j log q(z_j)]
    dKL = E_q[sum_j log q(z_j) - log p(z_j)]

    Minibatch-weighted estimator — numerically stable (logsumexp with
    max-shift throughout). The latent sample is taken at z = mu, so all
    densities are evaluated at the posterior means rather than at noisy
    reparametrized draws.

    Args:
        mu, logvar: (B, D) Gaussian posterior parameters per sample.
        N_dataset:  total training-set size used in the marginal estimate.
        beta/gamma/lam: weights on TC / MI / dimension-wise KL.

    NOTE(review): the marginal estimates subtract log(N_dataset) only;
    Chen et al.'s minibatch-weighted estimator subtracts log(N·M) with
    M = batch size — confirm this offset is intended ("dataset-size
    corrected form"). The returned gradients are a closed-form KL
    surrogate, NOT gradients of the decomposed objective (see below).
    """
    B, D = mu.shape
    var = np.exp(logvar)

    # Per-sample KL (closed form vs N(0,1) prior)
    # NOTE(review): kl_per_dim is computed but never used or returned.
    kl_per_dim = 0.5 * (mu**2 + var - logvar - 1)  # (B, D)

    # Log q(z|x_i) for sample z_i ~ q(z|x_i) (diagonal Gaussian)
    # We use the mean as the sample point (reparametrization noise ≈ 0 for gradient estimation)
    # log q(z_i | x_i) = sum_d -0.5*(log(2π) + logvar + (z-mu)^2/var)
    # At z = mu: (z-mu)=0, so log q(z|x) = -0.5*sum(log(2π) + logvar)
    log_q_z_given_x = -0.5 * (D * np.log(2 * np.pi) + logvar.sum(axis=1))  # (B,)

    # Minibatch estimate of the aggregate posterior log q(z):
    # log q(z_i) ≈ logsumexp_j log q(z_i | x_j) - log(N)  (max-shifted for stability)
    z_sample = mu  # (B, D) — use mean as point estimate
    # (B_i, B_j, D) pairwise differences between sample i and posterior j
    diff = z_sample[:, None, :] - mu[None, :, :]  # (B, B, D)
    log_q_z_pair = -0.5 * (
        D * np.log(2 * np.pi)
        + logvar[None, :, :].sum(axis=2)
        + (diff**2 / (var[None, :, :] + EPS)).sum(axis=2)
    )  # (B, B)
    log_q_z = np.log(np.sum(np.exp(log_q_z_pair - log_q_z_pair.max(axis=1, keepdims=True)), axis=1) + EPS) \
              + log_q_z_pair.max(axis=1) - np.log(N_dataset + EPS)  # (B,)

    # Product of per-dimension marginals: sum_j log q(z_j), each q(z_j)
    # estimated by the same max-shifted logsumexp over the batch.
    log_q_z_product = np.zeros(B, dtype=np.float64)
    for j in range(D):
        diff_j = z_sample[:, j:j+1] - mu[:, j:j+1].T  # (B, B)
        log_q_zj = -0.5 * (np.log(2 * np.pi) + logvar[None, :, j] + diff_j**2 / (var[None, :, j] + EPS))  # (B,B)
        log_q_zj_marginal = np.log(np.sum(np.exp(log_q_zj - log_q_zj.max(axis=1, keepdims=True)), axis=1) + EPS) \
                            + log_q_zj.max(axis=1) - np.log(N_dataset + EPS)  # (B,)
        log_q_z_product += log_q_zj_marginal

    # log p(z) under the standard-normal prior, evaluated at z = mu.
    log_p_z = -0.5 * (D * np.log(2 * np.pi) + (z_sample**2).sum(axis=1))  # (B,)

    mi_loss = np.mean(log_q_z_given_x - log_q_z)
    tc_loss = np.mean(log_q_z - log_q_z_product)
    dkl_loss = np.mean(log_q_z_product - log_p_z)

    # TC is floored at 0 so a negative estimate cannot reward the model.
    total_kl = gamma * mi_loss + beta * max(tc_loss, 0) + lam * dkl_loss

    # Gradients w.r.t. mu and logvar: use standard closed-form KL gradient
    # ∂KL/∂mu = mu, ∂KL/∂logvar = 0.5*(exp(logvar) - 1)
    # scaled by the mean of the three loss weights and averaged over the batch.
    scale = (gamma + beta + lam) / (3.0 * B + EPS)
    dmu = scale * mu
    dlogvar = scale * 0.5 * (var - 1.0)

    return float(total_kl), dmu.astype(np.float32), dlogvar.astype(np.float32)
|
||||
|
||||
|
||||
# ── Reparametrization ─────────────────────────────────────────────────────

def reparametrize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) from rng (float32 noise)."""
    sigma = np.exp(0.5 * logvar)
    noise = rng.randn(*mu.shape).astype(np.float32)
    return mu + noise * sigma
|
||||
|
||||
|
||||
# ── Hierarchical D-VAE ────────────────────────────────────────────────────

class HierarchicalDVAE:
    """
    4-level Hierarchical Disentangled VAE (111-dim input, 24-dim latent).

    Latent hierarchy:
      z0 (4) = enc0(T0)                macro regime (breadth + time)
      z1 (8) = enc1(T1 + T4 + z0)      eigenstructure + esoteric, cond on z0
      z2 (8) = enc2(T2 + T3 + z0 + z1) cross-section + ExF, cond on z0,z1
      z3 (4) = enc3(T3 + T4 + z0)      ExF+EsoF regime, cond on z0

    Decoding:
      T0_hat = dec0(z0)
      T1_hat = dec1(z0, z1)
      T2_hat = dec2(z0, z1, z2)
      T3_hat = dec3(z0, z3)
      T4_hat = dec4(z0)   (EsoF is deterministic, serves as auxiliary)

    Training phases:
      0: enc0/dec0 only — full 500K+ corpus (even pre-eigen NG1/NG2)
      1: + enc1/dec1    — eigen-tier (NG3+)
      2: + enc2/dec2    — pricing + ExF (NG3+ with pricing)
      3: + enc3/dec3    — ExF regime
      4: joint fine-tune all

    NOTE(review): loss_and_grads/backward below implement phases 0-2 only;
    enc3/dec3/dec4 are constructed (and now persisted) but receive no
    gradient updates here — confirm whether phases 3-4 are trained elsewhere.
    """

    def __init__(self, hidden: int = 64, beta: float = 4.0,
                 gamma: float = 1.0, lam: float = 1.0, seed: int = 42):
        self.hidden = hidden
        self.beta = beta     # TC weight
        self.gamma = gamma   # MI weight
        self.lam = lam       # dimension-wise KL weight
        self.seed = seed

        # Encoders (each conditioned on upstream z; input dim = tier widths + z dims)
        self.enc0 = VAEEncoder(TIER0_DIM, hidden, Z0_DIM, seed=seed)
        self.enc1 = VAEEncoder(TIER1_DIM + TIER4_DIM + Z0_DIM, hidden, Z1_DIM, seed=seed+10)
        self.enc2 = VAEEncoder(TIER2_DIM + TIER3_DIM + Z0_DIM + Z1_DIM, hidden, Z2_DIM, seed=seed+20)
        self.enc3 = VAEEncoder(TIER3_DIM + TIER4_DIM + Z0_DIM, hidden//2, Z3_DIM, seed=seed+30)

        # Decoders (inputs are concatenated latent codes, see class docstring)
        self.dec0 = VAEDecoder(Z0_DIM, hidden, TIER0_DIM, seed=seed)
        self.dec1 = VAEDecoder(Z0_DIM + Z1_DIM, hidden, TIER1_DIM, seed=seed+10)
        self.dec2 = VAEDecoder(Z0_DIM + Z1_DIM + Z2_DIM, hidden, TIER2_DIM, seed=seed+20)
        self.dec3 = VAEDecoder(Z0_DIM + Z3_DIM, hidden//2, TIER3_DIM, seed=seed+30)
        self.dec4 = VAEDecoder(Z0_DIM, hidden//2, TIER4_DIM, seed=seed+40)

        # Normalisation statistics (fit on training data via fit_normaliser)
        self._mu_t0 = np.zeros(TIER0_DIM, dtype=np.float32)
        self._sd_t0 = np.ones(TIER0_DIM, dtype=np.float32)
        self._mu_t1 = np.zeros(TIER1_DIM, dtype=np.float32)
        self._sd_t1 = np.ones(TIER1_DIM, dtype=np.float32)
        self._mu_t2 = np.zeros(TIER2_DIM, dtype=np.float32)
        self._sd_t2 = np.ones(TIER2_DIM, dtype=np.float32)
        self._mu_t3 = np.zeros(TIER3_DIM, dtype=np.float32)
        self._sd_t3 = np.ones(TIER3_DIM, dtype=np.float32)
        self._mu_t4 = np.zeros(TIER4_DIM, dtype=np.float32)
        self._sd_t4 = np.ones(TIER4_DIM, dtype=np.float32)

        self.step_t = 0  # Adam time step (cumulative across epochs)
        self.train_losses: List[Dict] = []

    # ── Normalisation ──────────────────────────────────────────────────────

    def fit_normaliser(self, X: np.ndarray, mask: np.ndarray):
        """Fit per-tier normalisation on a representative sample.

        X    : (N, 111) feature rows.
        mask : per-row tier-availability matrix; columns 1-3 select rows that
               have tiers 1-3 populated (presumably boolean — TODO confirm
               against the caller). Tiers 0 and 4 are fit on all rows.
        """
        def _fit(rows, dim):
            # Fewer than 2 rows → identity normalisation (avoid zero std).
            if len(rows) < 2:
                return np.zeros(dim, dtype=np.float32), np.ones(dim, dtype=np.float32)
            return rows.mean(0).astype(np.float32), (rows.std(0) + EPS).astype(np.float32)

        self._mu_t0, self._sd_t0 = _fit(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM], TIER0_DIM)
        idx1 = mask[:, 1]
        self._mu_t1, self._sd_t1 = _fit(X[idx1, T_OFF[1]:T_OFF[1]+TIER1_DIM], TIER1_DIM)
        idx2 = mask[:, 2]
        self._mu_t2, self._sd_t2 = _fit(X[idx2, T_OFF[2]:T_OFF[2]+TIER2_DIM], TIER2_DIM)
        idx3 = mask[:, 3]
        self._mu_t3, self._sd_t3 = _fit(X[idx3, T_OFF[3]:T_OFF[3]+TIER3_DIM], TIER3_DIM)
        self._mu_t4, self._sd_t4 = _fit(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM], TIER4_DIM)

    # Z-score helpers, one per tier.
    def _norm(self, x, mu, sd): return (x - mu) / sd
    def _norm0(self, x): return self._norm(x, self._mu_t0, self._sd_t0)
    def _norm1(self, x): return self._norm(x, self._mu_t1, self._sd_t1)
    def _norm2(self, x): return self._norm(x, self._mu_t2, self._sd_t2)
    def _norm3(self, x): return self._norm(x, self._mu_t3, self._sd_t3)
    def _norm4(self, x): return self._norm(x, self._mu_t4, self._sd_t4)

    # ── Forward pass ──────────────────────────────────────────────────────

    def _split(self, X):
        """Split 111-dim row into normalised tier vectors (t0..t4)."""
        t0 = self._norm0(X[:, T_OFF[0]:T_OFF[0]+TIER0_DIM])
        t1 = self._norm1(X[:, T_OFF[1]:T_OFF[1]+TIER1_DIM])
        t2 = self._norm2(X[:, T_OFF[2]:T_OFF[2]+TIER2_DIM])
        t3 = self._norm3(X[:, T_OFF[3]:T_OFF[3]+TIER3_DIM])
        t4 = self._norm4(X[:, T_OFF[4]:T_OFF[4]+TIER4_DIM])
        return t0, t1, t2, t3, t4

    def encode(self, X: np.ndarray, mask: np.ndarray, rng) -> Dict:
        """Run all four encoders top-down; each level conditions on the
        sampled z of its ancestors. Returns tiers, posterior params and samples.

        NOTE(review): `mask` is accepted but unused here — missing tiers are
        handled via loss weighting in loss_and_grads, not at encode time.
        """
        t0, t1, t2, t3, t4 = self._split(X)

        mu0, lv0 = self.enc0.forward(t0)
        z0 = reparametrize(mu0, lv0, rng)

        mu1, lv1 = self.enc1.forward(np.concatenate([t1, t4, z0], axis=1))
        z1 = reparametrize(mu1, lv1, rng)

        mu2, lv2 = self.enc2.forward(np.concatenate([t2, t3, z0, z1], axis=1))
        z2 = reparametrize(mu2, lv2, rng)

        mu3, lv3 = self.enc3.forward(np.concatenate([t3, t4, z0], axis=1))
        z3 = reparametrize(mu3, lv3, rng)

        return dict(t0=t0, t1=t1, t2=t2, t3=t3, t4=t4,
                    mu0=mu0, lv0=lv0, z0=z0,
                    mu1=mu1, lv1=lv1, z1=z1,
                    mu2=mu2, lv2=lv2, z2=z2,
                    mu3=mu3, lv3=lv3, z3=z3)

    def decode(self, enc: Dict) -> Dict:
        """Decode each tier from its latent parents (see class docstring)."""
        z0, z1, z2, z3 = enc['z0'], enc['z1'], enc['z2'], enc['z3']
        x0_hat = self.dec0.forward(z0)
        x1_hat = self.dec1.forward(np.concatenate([z0, z1], axis=1))
        x2_hat = self.dec2.forward(np.concatenate([z0, z1, z2], axis=1))
        x3_hat = self.dec3.forward(np.concatenate([z0, z3], axis=1))
        x4_hat = self.dec4.forward(z0)  # EsoF decoded from macro only
        return dict(x0_hat=x0_hat, x1_hat=x1_hat, x2_hat=x2_hat,
                    x3_hat=x3_hat, x4_hat=x4_hat)

    # ── Loss and gradients ─────────────────────────────────────────────────

    def loss_and_grads(self, enc: Dict, dec: Dict, mask: np.ndarray,
                       N_dataset: int, phase: int) -> Dict:
        """Compute total loss and all gradients via chain rule.

        Covers tiers 0-2 only (masked MSE reconstruction + β-TCVAE KL per
        level); tier-3/4 reconstructions and the z3 KL are not part of the
        objective here — see the class-level review note. Phases < 1 / < 2
        zero out the level-1 / level-2 terms and gradients.
        """
        t0, t1, t2 = enc['t0'], enc['t1'], enc['t2']
        x0_hat = dec['x0_hat']
        x1_hat = dec['x1_hat']
        x2_hat = dec['x2_hat']
        B = len(t0)

        # Mask weights: per-sample, per-tier (0/1 column broadcast over features)
        w1 = mask[:, 1].astype(np.float32)[:, None]  # (B,1)
        w2 = mask[:, 2].astype(np.float32)[:, None]

        # Reconstruction losses (tier-1/2 masked where the tier is missing)
        recon0 = np.mean((x0_hat - t0)**2)
        recon1 = np.mean(w1 * (x1_hat - t1)**2)
        recon2 = np.mean(w2 * (x2_hat - t2)**2)

        # Reconstruction gradients (d mean-squared-error / d x_hat)
        dr0 = 2 * (x0_hat - t0) / (B * TIER0_DIM)
        dr1 = 2 * w1 * (x1_hat - t1) / (B * TIER1_DIM)
        dr2 = 2 * w2 * (x2_hat - t2) / (B * TIER2_DIM)

        # KL losses per level (β-TCVAE decomposition + surrogate gradients)
        kl0, dmu0_kl, dlv0_kl = btcvae_kl(enc['mu0'], enc['lv0'], N_dataset, self.beta, self.gamma, self.lam)
        kl1, dmu1_kl, dlv1_kl = btcvae_kl(enc['mu1'], enc['lv1'], N_dataset, self.beta, self.gamma, self.lam)
        kl2, dmu2_kl, dlv2_kl = btcvae_kl(enc['mu2'], enc['lv2'], N_dataset, self.beta, self.gamma, self.lam)

        # Phase-based scaling: only activate eigen/pricing in correct phase
        if phase < 1:
            kl1 = recon1 = 0.0; dmu1_kl[:] = 0; dlv1_kl[:] = 0; dr1[:] = 0
        if phase < 2:
            kl2 = recon2 = 0.0; dmu2_kl[:] = 0; dlv2_kl[:] = 0; dr2[:] = 0

        total = recon0 + recon1 + recon2 + kl0 + kl1 + kl2

        return dict(
            total=total, recon0=recon0, recon1=recon1, recon2=recon2,
            kl0=kl0, kl1=kl1, kl2=kl2,
            dr0=dr0, dr1=dr1, dr2=dr2,
            dmu0_kl=dmu0_kl, dlv0_kl=dlv0_kl,
            dmu1_kl=dmu1_kl, dlv1_kl=dlv1_kl,
            dmu2_kl=dmu2_kl, dlv2_kl=dlv2_kl,
        )

    def backward(self, enc: Dict, dec: Dict, grads: Dict, phase: int, lr: float):
        """Backpropagate through all encoders and decoders, then Adam step.

        Gradients flow bottom-up: dec2/enc2 → dec1/enc1 → dec0/enc0, with
        each level's latent gradient accumulated from every consumer of that
        latent. Levels disabled by `phase` contribute zero gradient.

        NOTE(review): dz treated as dmu (z ≈ mu); dlv* carry only the own
        decoder path with a fixed 0.5 factor, while dmu1/dmu0 also include
        the cross-level paths — confirm this asymmetry is intended.
        """
        z0, z1, z2 = enc['z0'], enc['z1'], enc['z2']
        t = self.step_t

        # ── Decoder 2 (z0+z1+z2 → x2) ──────────────────────────────────
        if phase >= 2:
            dz_all_2 = self.dec2.backward(grads['dr2'])  # (B, Z0+Z1+Z2)
            dz0_from_d2 = dz_all_2[:, :Z0_DIM]
            dz1_from_d2 = dz_all_2[:, Z0_DIM:Z0_DIM + Z1_DIM]
            dz2_from_d2 = dz_all_2[:, Z0_DIM + Z1_DIM:]
        else:
            dz0_from_d2 = np.zeros_like(z0)
            dz1_from_d2 = np.zeros_like(z1)
            dz2_from_d2 = np.zeros_like(z2)
            # Keep dec2's gradient caches populated even when inactive.
            self.dec2.backward(np.zeros_like(grads['dr2']))

        # ── Encoder 2 ─────────────────────────────────────────────────
        # enc2 input = concat(T2, T3, z0, z1) dims = TIER2+TIER3+Z0+Z1
        if phase >= 2:
            dmu2 = dz2_from_d2 + grads['dmu2_kl']
            dlv2 = dz2_from_d2 * 0.5 + grads['dlv2_kl']
            dinp2 = self.enc2.backward(dmu2, dlv2)  # (B, T2+T3+Z0+Z1)
            _off = TIER2_DIM + TIER3_DIM
            dz0_from_e2 = dinp2[:, _off:_off + Z0_DIM]
            dz1_from_e2 = dinp2[:, _off + Z0_DIM:]
            self.dec2.step(lr, t)
            self.enc2.step(lr, t)
        else:
            dz0_from_e2 = np.zeros_like(z0)
            dz1_from_e2 = np.zeros_like(z1)

        # ── Decoder 1 (z0+z1 → x1) ──────────────────────────────────────
        if phase >= 1:
            dz_all_1 = self.dec1.backward(grads['dr1'])  # (B, Z0+Z1)
            dz0_from_d1 = dz_all_1[:, :Z0_DIM]
            dz1_from_d1 = dz_all_1[:, Z0_DIM:]
        else:
            dz0_from_d1 = np.zeros_like(z0)
            dz1_from_d1 = np.zeros_like(z1)
            self.dec1.backward(np.zeros_like(grads['dr1']))

        # ── Encoder 1 ─────────────────────────────────────────────────
        # enc1 input = concat(T1, T4, z0) dims = TIER1+TIER4+Z0
        if phase >= 1:
            dmu1 = dz1_from_d1 + dz1_from_d2 + dz1_from_e2 + grads['dmu1_kl']
            dlv1 = dz1_from_d1 * 0.5 + grads['dlv1_kl']
            dinp1 = self.enc1.backward(dmu1, dlv1)  # (B, T1+T4+Z0)
            dz0_from_e1 = dinp1[:, TIER1_DIM + TIER4_DIM:]  # correct: skip T1+T4
            self.dec1.step(lr, t)
            self.enc1.step(lr, t)
        else:
            dz0_from_e1 = np.zeros_like(z0)

        # ── Decoder 0 (z0 → x0) ─────────────────────────────────────────
        dz0_from_d0 = self.dec0.backward(grads['dr0'])  # (B, Z0)

        # ── Encoder 0 ────────────────────────────────────────────────────
        # z0 feeds every downstream module; accumulate all five paths + KL.
        dmu0 = dz0_from_d0 + dz0_from_d1 + dz0_from_d2 + dz0_from_e1 + dz0_from_e2 + grads['dmu0_kl']
        dlv0 = dz0_from_d0 * 0.5 + grads['dlv0_kl']
        self.enc0.backward(dmu0, dlv0)
        self.enc0.step(lr, t)
        self.dec0.step(lr, t)

    # ── Training ──────────────────────────────────────────────────────────

    def train_epoch(self, X: np.ndarray, mask: np.ndarray,
                    lr: float, batch_size: int, phase: int, rng) -> Dict:
        """One shuffled pass over X; returns per-sample mean loss components."""
        N = len(X)
        idx = rng.permutation(N)
        epoch_stats = dict(total=0, recon0=0, recon1=0, recon2=0, kl0=0, kl1=0, kl2=0, n=0)

        for start in range(0, N, batch_size):
            bi = idx[start:start + batch_size]
            Xb = X[bi]
            mb = mask[bi]
            self.step_t += 1  # Adam global step (cumulative across all epochs)

            enc_out = self.encode(Xb, mb, rng)
            dec_out = self.decode(enc_out)
            grads = self.loss_and_grads(enc_out, dec_out, mb, N, phase)
            self.backward(enc_out, dec_out, grads, phase, lr)

            # Accumulate sample-weighted sums; divided by n below.
            for k in ['total', 'recon0', 'recon1', 'recon2', 'kl0', 'kl1', 'kl2']:
                epoch_stats[k] += float(grads[k]) * len(bi)
            epoch_stats['n'] += len(bi)

        return {k: v / max(epoch_stats['n'], 1) for k, v in epoch_stats.items() if k != 'n'}

    # ── Inference ─────────────────────────────────────────────────────────

    def get_latents(self, X: np.ndarray, mask: np.ndarray) -> Dict[str, np.ndarray]:
        """Return disentangled latent codes (posterior means) for analysis."""
        rng = np.random.RandomState(0)  # deterministic for inference
        enc = self.encode(X, mask, rng)
        return {
            'z0_macro': enc['mu0'],     # (N, 4) deterministic mean
            'z1_eigen': enc['mu1'],     # (N, 8)
            'z2_xsection': enc['mu2'],  # (N, 8)
            'z3_esoteric': enc['mu3'],  # (N, 4) — added; was missing from output
            # 'z_all' kept at 20 dims (z0+z1+z2) for backward compatibility;
            # concatenate 'z3_esoteric' for the full 24-dim code.
            'z_all': np.concatenate([enc['mu0'], enc['mu1'], enc['mu2']], axis=1),  # (N, 20)
            'sigma0': np.exp(0.5 * enc['lv0']),  # per-dim posterior std (uncertainty)
            'sigma1': np.exp(0.5 * enc['lv1']),
            'sigma2': np.exp(0.5 * enc['lv2']),
            'sigma3': np.exp(0.5 * enc['lv3']),
        }

    # ── Persistence ───────────────────────────────────────────────────────

    def save(self, path: str):
        """Persist normaliser stats, Adam step, loss history, and ALL weights.

        Fix: previously only enc0-2/dec0-2 and the t0-t2 normalisers were
        written, so levels 3/4 (enc3, dec3, dec4) and the t3/t4 statistics
        were silently lost on reload. All keys are additive, so existing
        readers of the .npz are unaffected.
        """
        data = {
            'norm': dict(mu_t0=self._mu_t0, sd_t0=self._sd_t0,
                         mu_t1=self._mu_t1, sd_t1=self._sd_t1,
                         mu_t2=self._mu_t2, sd_t2=self._sd_t2,
                         mu_t3=self._mu_t3, sd_t3=self._sd_t3,
                         mu_t4=self._mu_t4, sd_t4=self._sd_t4),
            'step_t': self.step_t,
            'losses': self.train_losses,
        }
        # Save all weight matrices (trunk layers + both Gaussian heads per encoder)
        for name, enc in [('enc0', self.enc0), ('enc1', self.enc1),
                          ('enc2', self.enc2), ('enc3', self.enc3)]:
            for i, layer in enumerate(enc.mlp.layers):
                data[f'{name}_mlp{i}_W'] = layer.W
                data[f'{name}_mlp{i}_b'] = layer.b
            data[f'{name}_mu_W'] = enc.mu_head.W; data[f'{name}_mu_b'] = enc.mu_head.b
            data[f'{name}_lv_W'] = enc.lv_head.W; data[f'{name}_lv_b'] = enc.lv_head.b
        for name, dec in [('dec0', self.dec0), ('dec1', self.dec1), ('dec2', self.dec2),
                          ('dec3', self.dec3), ('dec4', self.dec4)]:
            for i, layer in enumerate(dec.mlp.layers):
                data[f'{name}_mlp{i}_W'] = layer.W
                data[f'{name}_mlp{i}_b'] = layer.b
        np.savez_compressed(path, **data)
        print(f"Model saved: {path}.npz")
|
||||
Reference in New Issue
Block a user