DOLPHIN/nautilus_dolphin/dvae/convnext_dvae.py

"""
ConvNeXt-1D β-TCVAE — pure numpy (no PyTorch required).
========================================================
Architecture
------------
Input : (B, C_in=8, T=32) — 8 eigenvalue channels × 32 timestep window
Stem : PointwiseProj(8→32) + LayerNorm
Stage 0: 3 × ConvNeXtBlock1D(32, dw_k=7)
Pool : AdaptiveAvgPool(16) + PointwiseProj(32→64)
Stage 1: 3 × ConvNeXtBlock1D(64, dw_k=7)
Head : GlobalAvgPool → Linear(64→32) for z_mu / z_logvar
Decoder (mirrored):
Expand : Linear(32→64) → repeat 16 times along T axis
Stage 1D: 3 × ConvNeXtBlock1D(64, dw_k=7)
Upsample: RepeatInterleave(×2) + PointwiseProj(64→32)
Stage 0D: 3 × ConvNeXtBlock1D(32, dw_k=7)
Out : PointwiseProj(32→8)
Objective: β-TCVAE (Chen et al. 2018, "Isolating Sources of Disentanglement")
β_tc=4.0, α_mi=1.0. Minibatch TC estimator.
Gradient implementation notes
------------------------------
All depthwise convolutions use the im2col trick so that gradients reduce to
standard matmul / einsum — no custom numerical differentiation required.
DWConv1d forward : windows[b,c,t,k] = x_pad[b,c,t+k]
out[b,c,t] = sum_k windows * w[c,k] + b[c]
DWConv1d backward: grad_w[c] = einsum('btk,bt->k', windows_c, grad_out_c) over B
col2im for grad_x (accumulate k-shifted copies of grad_out * w)
β-TCVAE TC gradient: computed via minibatch weight matrix w[i,j] ∝ q(z_i|x_j)
"""
from __future__ import annotations
import json
import numpy as np
from scipy.special import erf
try:
from numba import njit as _njit
_NUMBA = True
except ImportError:
_NUMBA = False
def _njit(*args, **kwargs):  # no-op fallback; also usable as @_njit(cache=True)
    if args and callable(args[0]):
        return args[0]
    return lambda fn: fn
__all__ = ['ConvNeXtVAE', 'btcvae_loss', 'btcvae_loss_backward']
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _kaiming(fan_in: int, fan_out: int, rng: np.random.RandomState) -> np.ndarray:
std = np.sqrt(2.0 / fan_in)
return rng.randn(fan_in, fan_out) * std
def _gelu(x: np.ndarray) -> np.ndarray:
return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
def _gelu_grad(x: np.ndarray) -> np.ndarray:
# Two allocations instead of three: out= avoids the final x*pdf temporary
cdf = 0.5 * (1.0 + erf(x * (1.0 / np.sqrt(2.0))))
pdf = np.exp(-0.5 * x * x)
pdf *= 1.0 / np.sqrt(2.0 * np.pi)
np.multiply(x, pdf, out=pdf) # pdf now holds x * pdf, in-place
pdf += cdf # pdf now holds cdf + x*pdf
return pdf
def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
a_max = a.max(axis=axis, keepdims=True)
out = np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True)) + a_max
return out.squeeze(axis)
# ---------------------------------------------------------------------------
# Parameter container — stores weight, gradient, Adam state
# ---------------------------------------------------------------------------
class Param:
__slots__ = ('data', 'grad', '_m', '_v', 't')
def __init__(self, data: np.ndarray):
self.data = data.astype(np.float64)
self.grad = np.zeros_like(self.data)
self._m = np.zeros_like(self.data)
self._v = np.zeros_like(self.data)
self.t = 0
def zero_grad(self):
self.grad.fill(0.0)
def adam_step(self, lr: float, β1: float = 0.9, β2: float = 0.999,
eps: float = 1e-8, wd: float = 0.0):
self.t += 1
g = self.grad
if wd > 0:
g = g + wd * self.data
self._m = β1 * self._m + (1 - β1) * g
self._v = β2 * self._v + (1 - β2) * g * g
m_hat = self._m / (1 - β1 ** self.t)
v_hat = self._v / (1 - β2 ** self.t)
self.data -= lr * m_hat / (np.sqrt(v_hat) + eps)
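# For reference: this is the standard bias-corrected Adam update (Kingma & Ba,
# 2015). Note that `wd` is folded into the gradient (L2-style weight decay)
# rather than applied decoupled from the adaptive step as in AdamW.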
# ---------------------------------------------------------------------------
# im2col helpers
# ---------------------------------------------------------------------------
def _im2col1d(x: np.ndarray, k: int, pad: int) -> np.ndarray:
"""
x : (B, C, T)
returns: (B, C, T, k) — zero-padded windows
Uses np.lib.stride_tricks.as_strided for zero-copy window view,
then immediately makes a contiguous copy so writes to the returned
array never corrupt x_pad.
"""
B, C, T = x.shape
x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
# stride-trick view: step one element along T axis per k-position
shape = (B, C, T, k)
s0, s1, s2 = x_pad.strides
strides = (s0, s1, s2, s2)
windows = np.lib.stride_tricks.as_strided(x_pad, shape=shape, strides=strides)
return windows.copy() # contiguous copy — safe for in-place ops downstream
@_njit(cache=True)
def _col2im1d_kernel(g_pad, grad_win, T, k):
"""Numba-JIT accumulation loop for col2im."""
for i in range(k):
g_pad[:, :, i:i + T] += grad_win[:, :, :, i]
def _col2im1d(grad_win: np.ndarray, T: int, k: int, pad: int) -> np.ndarray:
"""
grad_win: (B, C, T, k)
returns : (B, C, T)
"""
B, C = grad_win.shape[0], grad_win.shape[1]
g_pad = np.zeros((B, C, T + 2 * pad))
_col2im1d_kernel(g_pad, grad_win, T, k)
return g_pad[:, :, pad:pad + T]
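# A tiny worked example of the im2col / col2im pair (illustrative numbers):
# for k=3, pad=1 and x[0, 0] = [1, 2, 3, 4], _im2col1d gives windows[0, 0] =
#   [[0, 1, 2],
#    [1, 2, 3],
#    [2, 3, 4],
#    [3, 4, 0]]
# and _col2im1d(np.ones_like(windows), T=4, k=3, pad=1) scatters every window
# slot back onto its source position, yielding [2, 3, 3, 2] for that channel,
# i.e. col2im is the adjoint of the im2col gather.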
# ---------------------------------------------------------------------------
# Layer: DWConv1d (depthwise conv via im2col)
# ---------------------------------------------------------------------------
class DWConv1d:
def __init__(self, C: int, k: int, rng: np.random.RandomState):
self.C, self.k, self.pad = C, k, k // 2
self.w = Param(rng.randn(C, k) * 0.02)
self.b = Param(np.zeros(C))
self._cache: tuple | None = None
def params(self) -> list[Param]:
return [self.w, self.b]
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T)"""
wins = _im2col1d(x, self.k, self.pad) # (B, C, T, k)
self._cache = (wins,)
# einsum with optimize: uses BLAS when possible
out = np.einsum('bctk,ck->bct', wins, self.w.data, optimize=True)
out += self.b.data[None, :, None]
return out
def backward(self, grad_out: np.ndarray) -> np.ndarray:
"""grad_out: (B, C, T) → grad_x: (B, C, T)"""
wins, = self._cache
B, C, T = grad_out.shape
# grad w: einsum over B and T
self.w.grad += np.einsum('bct,bctk->ck', grad_out, wins, optimize=True)
# grad b
self.b.grad += grad_out.sum(axis=(0, 2)) # (C,)
# grad x via col2im
grad_win = np.einsum('bct,ck->bctk', grad_out, self.w.data, optimize=True)
return _col2im1d(grad_win, T, self.k, self.pad)
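# Minimal finite-difference sketch of the DWConv1d gradient (illustrative
# only: it is not called anywhere in this module, and the shapes, seed and
# step size are arbitrary assumptions).
def _check_dwconv_grad(seed: int = 0, h: float = 1e-6) -> float:
    """Return |numeric - analytic| d<out, g>/dx for one probed input element."""
    rng = np.random.RandomState(seed)
    conv = DWConv1d(C=2, k=3, rng=rng)
    x = rng.randn(1, 2, 5)
    g = rng.randn(1, 2, 5)                    # arbitrary upstream gradient
    conv.forward(x)                           # populates the im2col cache
    gx = conv.backward(g)                     # analytic dL/dx for L = <out, g>
    idx = (0, 1, 2)                           # probe a single input element
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num = ((conv.forward(xp) - conv.forward(xm)) * g).sum() / (2 * h)
    # the conv is linear in x, so the difference should be at round-off level
    return float(abs(num - gx[idx]))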
# ---------------------------------------------------------------------------
# Layer: LayerNorm (over last axis, applied to (B, T, C) or (B, C))
# ---------------------------------------------------------------------------
class LayerNorm:
EPS = 1e-6
def __init__(self, D: int, rng: np.random.RandomState | None = None):
self.D = D
self.gamma = Param(np.ones(D))
self.beta = Param(np.zeros(D))
self._cache: tuple | None = None
def params(self) -> list[Param]:
return [self.gamma, self.beta]
def forward(self, x: np.ndarray) -> np.ndarray:
"""Normalize last axis."""
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_hat = (x - mean) / np.sqrt(var + self.EPS)
self._cache = (x_hat, var)
return self.gamma.data * x_hat + self.beta.data
def backward(self, dy: np.ndarray) -> np.ndarray:
x_hat, var = self._cache
D = x_hat.shape[-1]
orig_shape = dy.shape
dy_2d = dy.reshape(-1, D)
xhat_2d = x_hat.reshape(-1, D)
var_2d = var.reshape(-1, 1)
self.gamma.grad += (dy_2d * xhat_2d).sum(0)
self.beta.grad += dy_2d.sum(0)
dx_hat = dy_2d * self.gamma.data # (N, D)
std_inv = 1.0 / np.sqrt(var_2d + self.EPS) # (N, 1)
dx = std_inv * (dx_hat
- dx_hat.mean(-1, keepdims=True)
- xhat_2d * (dx_hat * xhat_2d).mean(-1, keepdims=True))
return dx.reshape(orig_shape)
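# For reference, the closed-form LayerNorm backward computed above (standard
# derivation, restated for readability):
#   dx = (1 / sqrt(var + EPS)) * ( dy*gamma
#                                  - mean_D(dy*gamma)
#                                  - x_hat * mean_D(dy*gamma * x_hat) )
# where mean_D averages over the normalized (last) axis.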
# ---------------------------------------------------------------------------
# Layer: Linear
# ---------------------------------------------------------------------------
class Linear:
def __init__(self, in_f: int, out_f: int, rng: np.random.RandomState,
bias: bool = True):
self.W = Param(_kaiming(in_f, out_f, rng))
self.b = Param(np.zeros(out_f)) if bias else None
self._x_cache: np.ndarray | None = None
def params(self) -> list[Param]:
return [self.W] + ([self.b] if self.b is not None else [])
def forward(self, x: np.ndarray) -> np.ndarray:
self._x_cache = x
out = x @ self.W.data
if self.b is not None:
out = out + self.b.data
return out
def backward(self, dy: np.ndarray) -> np.ndarray:
x = self._x_cache
x2 = x.reshape(-1, self.W.data.shape[0])
dy2 = dy.reshape(-1, self.W.data.shape[1])
self.W.grad += x2.T @ dy2
if self.b is not None:
self.b.grad += dy2.sum(0)
return (dy2 @ self.W.data.T).reshape(x.shape)
# ---------------------------------------------------------------------------
# Block: ConvNeXtBlock1D
# ---------------------------------------------------------------------------
# Input/output: (B, C, T)
# 1. DWConv1d(C, k)
# 2. Permute → (B, T, C)
# 3. LayerNorm(C)
# 4. Linear(C → 4C) + GELU
# 5. Linear(4C → C)
# 6. Permute → (B, C, T)
# 7. Skip
class ConvNeXtBlock1D:
def __init__(self, C: int, k: int, rng: np.random.RandomState):
self.dwconv = DWConv1d(C, k, rng)
self.ln = LayerNorm(C)
self.fc1 = Linear(C, 4 * C, rng)
self.fc2 = Linear(4 * C, C, rng)
self._pre_gelu: np.ndarray | None = None  # pre-GELU activations cached for backward
def params(self) -> list[Param]:
return (self.dwconv.params() + self.ln.params() +
self.fc1.params() + self.fc2.params())
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T)"""
h = self.dwconv.forward(x) # (B, C, T)
h = h.transpose(0, 2, 1) # (B, T, C)
h = self.ln.forward(h) # (B, T, C)
h_fc1 = self.fc1.forward(h) # (B, T, 4C) — pre-GELU
act = _gelu(h_fc1)
h = self.fc2.forward(act) # (B, T, C)
h = h.transpose(0, 2, 1) # (B, C, T)
self._pre_gelu = h_fc1 # cache for backward
return x + h # skip
def backward(self, grad_out: np.ndarray) -> np.ndarray:
"""grad_out: (B, C, T) → grad_x: (B, C, T)"""
# skip: grad flows unchanged to x, AND through the block
grad_block = grad_out.copy()
# 6. permute (B, C, T) → (B, T, C)
grad_block = grad_block.transpose(0, 2, 1)
# 5. fc2 backward → grad w.r.t. GELU output (act), shape (B, T, 4C)
grad_block = self.fc2.backward(grad_block)
# 4. GELU backward (pre_gelu cached during forward)
grad_block = grad_block * _gelu_grad(self._pre_gelu)
# 4. fc1 backward
grad_block = self.fc1.backward(grad_block)
# 3. LN backward
grad_block = self.ln.backward(grad_block)
# 2. permute (B, T, C) → (B, C, T)
grad_block = grad_block.transpose(0, 2, 1)
# 1. DWConv backward
grad_block = self.dwconv.backward(grad_block)
return grad_out + grad_block # skip: grad_x = grad_out + grad_block
# ---------------------------------------------------------------------------
# Pool: AvgPool1D (fixed stride-2 average pooling, T=32→T=16; stands in for the docstring's AdaptiveAvgPool(16))
# ---------------------------------------------------------------------------
class AvgPool1D:
"""Stride-2 average pooling over temporal axis."""
def __init__(self, stride: int = 2):
self.stride = stride
self._T_in: int | None = None
def params(self) -> list[Param]:
return []
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T//stride)"""
B, C, T = x.shape
self._T_in = T
T_out = T // self.stride
# reshape and mean
return x[:, :, :T_out * self.stride].reshape(B, C, T_out, self.stride).mean(-1)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
B, C, T_out = grad_out.shape
T_in = self._T_in
# each output contributes 1/stride to each input in the pool
g = grad_out[:, :, :, None] * np.ones((1, 1, 1, self.stride)) / self.stride
g = g.reshape(B, C, T_out * self.stride)
if T_in > T_out * self.stride:
g = np.pad(g, ((0, 0), (0, 0), (0, T_in - T_out * self.stride)))
return g
# ---------------------------------------------------------------------------
# Upsample: nearest-neighbor ×2 along temporal axis
# ---------------------------------------------------------------------------
class Upsample1D:
"""Nearest-neighbor upsampling ×factor along temporal axis."""
def __init__(self, factor: int = 2):
self.factor = factor
def params(self) -> list[Param]:
return []
def forward(self, x: np.ndarray) -> np.ndarray:
return np.repeat(x, self.factor, axis=-1)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
f = self.factor
B, C, T = grad_out.shape
T_in = T // f
return grad_out.reshape(B, C, T_in, f).sum(-1)
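# Note: Upsample1D(factor=s) followed by AvgPool1D(stride=s) is the identity
# (repeat each value s times, then average each group of s), so the decoder's
# nearest-neighbor upsample mirrors the encoder's stride-s pooling.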
# ---------------------------------------------------------------------------
# PointwiseProj: 1×1 conv (= Linear applied per timestep)
# ---------------------------------------------------------------------------
class PointwiseProj:
"""(B, C_in, T) → (B, C_out, T) via shared Linear(C_in, C_out)."""
def __init__(self, C_in: int, C_out: int, rng: np.random.RandomState):
self.lin = Linear(C_in, C_out, rng)
def params(self) -> list[Param]:
return self.lin.params()
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C_in, T)"""
B, C, T = x.shape
x_t = x.transpose(0, 2, 1) # (B, T, C_in)
out = self.lin.forward(x_t) # (B, T, C_out)
return out.transpose(0, 2, 1) # (B, C_out, T)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
g_t = grad_out.transpose(0, 2, 1) # (B, T, C_out)
g_t = self.lin.backward(g_t) # (B, T, C_in)
return g_t.transpose(0, 2, 1) # (B, C_in, T)
# ---------------------------------------------------------------------------
# Main model: ConvNeXtVAE
# ---------------------------------------------------------------------------
class ConvNeXtVAE:
"""
ConvNeXt-1D β-TCVAE.
Parameters
----------
C_in : input channels (default 8)
T_in : input timesteps (default 32)
z_dim : latent dimensionality (default 32)
base_ch : stem channel width (default 32)
dw_k : depthwise kernel size (default 7)
n_blocks : ConvNeXt blocks per stage (default 3)
seed : RNG seed
"""
def __init__(
self,
C_in: int = 8,
T_in: int = 32,
z_dim: int = 32,
base_ch: int = 32,
dw_k: int = 7,
n_blocks: int = 3,
seed: int = 42,
):
rng = np.random.RandomState(seed)
self.C_in = C_in
self.T_in = T_in
self.z_dim = z_dim
self.base_ch = base_ch
ch1 = base_ch * 2 # 64
# ---- ENCODER ----
# Stem
self.stem_proj = PointwiseProj(C_in, base_ch, rng)
self.stem_ln = LayerNorm(base_ch)
# Stage 0
self.stage0 = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]
# Pool + channel proj
self.pool = AvgPool1D(stride=2)
self.pool_ln = LayerNorm(base_ch)
self.pool_proj = PointwiseProj(base_ch, ch1, rng)
# Stage 1
self.stage1 = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]
# Global avg pool → z_mu, z_logvar
self.head_mu = Linear(ch1, z_dim, rng)
self.head_logvar = Linear(ch1, z_dim, rng)
# ---- DECODER ----
T_bot = T_in // 2 # bottleneck temporal size (16 for T_in=32)
self.dec_linear = Linear(z_dim, ch1, rng)
# Stage 1D
self.stage1d = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]
# Upsample + channel proj
self.up = Upsample1D(factor=2)
self.up_proj = PointwiseProj(ch1, base_ch, rng)
self.up_ln = LayerNorm(base_ch)
# Stage 0D
self.stage0d = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]
# Output projection
self.out_proj = PointwiseProj(base_ch, C_in, rng)
# Cache for backward
self._enc_cache: dict = {}
self._dec_cache: dict = {}
self._T_bot = T_bot
# ------------------------------------------------------------------
def all_params(self) -> list[Param]:
ps = []
ps += self.stem_proj.params() + self.stem_ln.params()
for blk in self.stage0:
ps += blk.params()
ps += (self.pool_ln.params() + self.pool_proj.params())
for blk in self.stage1:
ps += blk.params()
ps += self.head_mu.params() + self.head_logvar.params()
ps += self.dec_linear.params()
for blk in self.stage1d:
ps += blk.params()
ps += self.up_proj.params() + self.up_ln.params()
for blk in self.stage0d:
ps += blk.params()
ps += self.out_proj.params()
return ps
def zero_grad(self):
for p in self.all_params():
p.zero_grad()
def adam_step(self, lr: float, wd: float = 1e-4):
for p in self.all_params():
p.adam_step(lr, wd=wd)
def n_params(self) -> int:
return sum(p.data.size for p in self.all_params())
# ------------------------------------------------------------------
# ENCODER
# ------------------------------------------------------------------
def encode(self, x: np.ndarray):
"""
x: (B, C_in, T_in)
Returns z_mu (B,z), z_logvar (B,z)
"""
c = self._enc_cache
c['x'] = x
# Stem
h = self.stem_proj.forward(x) # (B, base_ch, T)
h = h.transpose(0, 2, 1) # (B, T, base_ch)
h = self.stem_ln.forward(h) # (B, T, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T)
c['after_stem'] = h
# Stage 0
for blk in self.stage0:
h = blk.forward(h)
c['after_stage0'] = h
# Pool + proj
h = self.pool.forward(h) # (B, base_ch, T//2)
h = h.transpose(0, 2, 1) # (B, T//2, base_ch)
h = self.pool_ln.forward(h) # (B, T//2, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T//2)
h = self.pool_proj.forward(h) # (B, ch1, T//2)
c['after_pool'] = h
# Stage 1
for blk in self.stage1:
h = blk.forward(h)
c['after_stage1'] = h
# Global avg pool
h_gap = h.mean(axis=-1) # (B, ch1)
c['gap'] = h_gap
z_mu = self.head_mu.forward(h_gap) # (B, z_dim)
z_logvar = self.head_logvar.forward(h_gap) # (B, z_dim)
return z_mu, z_logvar
# ------------------------------------------------------------------
# DECODER
# ------------------------------------------------------------------
def decode(self, z: np.ndarray) -> np.ndarray:
"""
z: (B, z_dim)
Returns x_recon: (B, C_in, T_in)
"""
c = self._dec_cache
T_bot = self._T_bot
h = self.dec_linear.forward(z) # (B, ch1)
h = h[:, :, None] * np.ones((1, 1, T_bot)) # broadcast (B, ch1, T_bot)
c['dec_expand'] = h
for blk in self.stage1d:
h = blk.forward(h)
c['after_stage1d'] = h
# Upsample + proj
h = self.up.forward(h) # (B, ch1, T_in)
h = self.up_proj.forward(h) # (B, base_ch, T_in)
h = h.transpose(0, 2, 1) # (B, T_in, base_ch)
h = self.up_ln.forward(h) # (B, T_in, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T_in)
c['after_up'] = h
for blk in self.stage0d:
h = blk.forward(h)
x_recon = self.out_proj.forward(h) # (B, C_in, T_in)
return x_recon
# ------------------------------------------------------------------
# BACKWARD (encoder)
# ------------------------------------------------------------------
def backward_encode(self, dz_mu: np.ndarray, dz_logvar: np.ndarray):
"""
Given grad w.r.t. z_mu and z_logvar, backprop through encoder.
Returns grad w.r.t. x (not used further, just for completeness).
"""
c = self._enc_cache
# head
d_gap = self.head_mu.backward(dz_mu) + self.head_logvar.backward(dz_logvar)
# global avg pool backward: distribute evenly over T axis
h_after = c['after_stage1']
T_bot = h_after.shape[-1]
d_h = d_gap[:, :, None] / T_bot * np.ones((1, 1, T_bot)) # (B, ch1, T_bot)
# stage 1 backward
for blk in reversed(self.stage1):
d_h = blk.backward(d_h)
# pool_proj backward
d_h = self.pool_proj.backward(d_h) # (B, base_ch, T//2)
# pool_ln backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.pool_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
# pool backward
d_h = self.pool.backward(d_h) # (B, base_ch, T)
# stage 0 backward
for blk in reversed(self.stage0):
d_h = blk.backward(d_h)
# stem backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.stem_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
d_h = self.stem_proj.backward(d_h)
return d_h # grad w.r.t. input x
# ------------------------------------------------------------------
# BACKWARD (decoder)
# ------------------------------------------------------------------
def backward_decode(self, d_recon: np.ndarray):
"""
d_recon: (B, C_in, T_in) — grad w.r.t. x_recon
Returns grad w.r.t. z.
"""
c = self._dec_cache
T_bot = self._T_bot
# out_proj backward
d_h = self.out_proj.backward(d_recon) # (B, base_ch, T_in)
# stage 0d backward
for blk in reversed(self.stage0d):
d_h = blk.backward(d_h)
# up_ln backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.up_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
# up_proj backward
d_h = self.up_proj.backward(d_h) # (B, ch1, T_in)
# upsample backward
d_h = self.up.backward(d_h) # (B, ch1, T_bot)
# stage 1d backward
for blk in reversed(self.stage1d):
d_h = blk.backward(d_h)
# expand backward: sum over temporal axis (it was broadcast)
d_h_expand = d_h.sum(axis=-1) # (B, ch1)
# dec_linear backward
dz = self.dec_linear.backward(d_h_expand) # (B, z_dim)
return dz
# ------------------------------------------------------------------
# FORWARD convenience
# ------------------------------------------------------------------
def forward(self, x: np.ndarray):
"""Full forward: encode + reparameterize + decode."""
z_mu, z_logvar = self.encode(x)
eps = np.random.randn(*z_mu.shape)
z = z_mu + eps * np.exp(0.5 * z_logvar)
x_recon = self.decode(z)
return x_recon, z_mu, z_logvar, z
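# Single training step, sketched (illustrative; the real training harness may
# chain these differently, and `lr` is a placeholder value):
#   x_recon, z_mu, z_logvar, z = vae.forward(x)
#   eps = (z - z_mu) * np.exp(-0.5 * z_logvar)       # recover the reparam noise
#   loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z)
#   d_recon, d_mu, d_lv = btcvae_loss_backward(x, x_recon, z_mu, z_logvar, z, eps)
#   vae.zero_grad()
#   dz = vae.backward_decode(d_recon)                # decoder grads + dL/dz (recon path)
#   d_mu += dz                                       # dz/d_mu = 1
#   d_lv += dz * 0.5 * eps * np.exp(0.5 * z_logvar)  # dz/d_logvar = 0.5*eps*std
#   vae.backward_encode(d_mu, d_lv)                  # encoder grads
#   vae.adam_step(lr=1e-3)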
# ------------------------------------------------------------------
# SAVE / LOAD
# ------------------------------------------------------------------
def save(self, path: str, norm_mean=None, norm_std=None,
extra: dict | None = None, save_adam: bool = False):
data: dict = {}
params = self.all_params()
names = self._param_names()
for name, p in zip(names, params):
data[name] = p.data.tolist()
if save_adam:
data[f'{name}__m'] = p._m.tolist()
data[f'{name}__v'] = p._v.tolist()
data[f'{name}__t'] = int(p.t)
if norm_mean is not None:
data['norm_mean'] = norm_mean.tolist()
data['norm_std'] = norm_std.tolist()
data['architecture'] = {
'C_in': self.C_in, 'T_in': self.T_in, 'z_dim': self.z_dim,
'base_ch': self.base_ch,
}
if extra:
data.update(extra)
with open(path, 'w') as f:
json.dump(data, f)
def load(self, path: str) -> dict:
with open(path) as f:
data = json.load(f)
params = self.all_params()
names = self._param_names()
reserved = set(names)
for name, p in zip(names, params):
if name in data:
p.data[:] = np.array(data[name], dtype=np.float64)
if f'{name}__m' in data:
p._m[:] = np.array(data[f'{name}__m'], dtype=np.float64)
p._v[:] = np.array(data[f'{name}__v'], dtype=np.float64)
p.t = int(data[f'{name}__t'])
reserved.update({f'{name}__m', f'{name}__v', f'{name}__t'})
extras = {k: v for k, v in data.items()
if k not in reserved and k != 'architecture'}
return extras
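# Round-trip sketch (illustrative; the path and `extra` payload are placeholders):
#   vae.save('/tmp/vae.json', norm_mean=mu, norm_std=sd,
#            extra={'epoch': 3}, save_adam=True)
#   vae2 = ConvNeXtVAE()                  # must match the saved architecture
#   extras = vae2.load('/tmp/vae.json')
#   # -> {'norm_mean': [...], 'norm_std': [...], 'epoch': 3}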
def _param_names(self) -> list[str]:
"""Return stable flat list of parameter name strings matching all_params()."""
names = []
# stem
for i, _ in enumerate(self.stem_proj.params()):
names.append(f'stem_proj_p{i}')
for i, _ in enumerate(self.stem_ln.params()):
names.append(f'stem_ln_p{i}')
# stage0
for bi, blk in enumerate(self.stage0):
for i, _ in enumerate(blk.params()):
names.append(f'stage0_b{bi}_p{i}')
# pool
for i, _ in enumerate(self.pool_ln.params()):
names.append(f'pool_ln_p{i}')
for i, _ in enumerate(self.pool_proj.params()):
names.append(f'pool_proj_p{i}')
# stage1
for bi, blk in enumerate(self.stage1):
for i, _ in enumerate(blk.params()):
names.append(f'stage1_b{bi}_p{i}')
# head
for i, _ in enumerate(self.head_mu.params()):
names.append(f'head_mu_p{i}')
for i, _ in enumerate(self.head_logvar.params()):
names.append(f'head_logvar_p{i}')
# dec_linear
for i, _ in enumerate(self.dec_linear.params()):
names.append(f'dec_linear_p{i}')
# stage1d
for bi, blk in enumerate(self.stage1d):
for i, _ in enumerate(blk.params()):
names.append(f'stage1d_b{bi}_p{i}')
# up
for i, _ in enumerate(self.up_proj.params()):
names.append(f'up_proj_p{i}')
for i, _ in enumerate(self.up_ln.params()):
names.append(f'up_ln_p{i}')
# stage0d
for bi, blk in enumerate(self.stage0d):
for i, _ in enumerate(blk.params()):
names.append(f'stage0d_b{bi}_p{i}')
# out_proj
for i, _ in enumerate(self.out_proj.params()):
names.append(f'out_proj_p{i}')
return names
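# Note: names are purely positional (e.g. 'stem_proj_p0' is the Linear weight,
# 'stem_proj_p1' its bias), so checkpoints written by save() only load into a
# model with the identical architecture and the same all_params() ordering.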
# ---------------------------------------------------------------------------
# β-TCVAE loss
# ---------------------------------------------------------------------------
def btcvae_loss(
x: np.ndarray,
x_recon: np.ndarray,
z_mu: np.ndarray,
z_logvar: np.ndarray,
z: np.ndarray,
beta_tc: float = 4.0,
alpha_mi: float = 1.0,
free_bits: float = 0.0,
) -> tuple[float, dict]:
"""
β-TCVAE loss with minibatch TC estimator.
Parameters
----------
x : (B, C, T) — input
x_recon : (B, C, T) — reconstruction
z_mu : (B, D)
z_logvar : (B, D)
z : (B, D) — reparameterized sample
beta_tc : weight on Total Correlation term
alpha_mi : weight on Mutual Information term
free_bits : minimum KL per latent dimension (nats). Replaces the minibatch
dimKL term with analytical per-dim KL clamped from below at
free_bits. Prevents posterior collapse. 0 = standard behaviour.
Returns
-------
loss : scalar
info : dict with breakdown (recon, MI, TC, dimKL)
"""
B, D = z.shape
# Reconstruction loss (per-sample mean, then mean over batch)
recon = ((x - x_recon) ** 2).mean(axis=(-1, -2)).mean()
# log q(z|x): (B,)
log_qz_x = -0.5 * (
np.log(2 * np.pi) + z_logvar + (z - z_mu) ** 2 * np.exp(-z_logvar)
).sum(axis=1)
# log p(z): (B,) — standard Gaussian
log_pz = -0.5 * (np.log(2 * np.pi) + z ** 2).sum(axis=1)
# Minibatch estimator for log q(z) and log prod_d q(z_d)
z_i = z[:, None, :] # (B, 1, D)
mu_j = z_mu[None, :, :] # (1, B, D)
lv_j = z_logvar[None, :, :] # (1, B, D)
# (B, B, D)
log_q_ij_d = -0.5 * (
np.log(2 * np.pi) + lv_j + (z_i - mu_j) ** 2 * np.exp(-lv_j)
)
# log q(z_i) = logsumexp_j[sum_d log_q_ij_d] - log B
log_qz = _logsumexp(log_q_ij_d.sum(-1), axis=1) - np.log(B) # (B,)
# log prod_d q(z_d[i]) = sum_d logsumexp_j[log_q_ij_d] - D*log(B)
log_qz_prod = (_logsumexp(log_q_ij_d, axis=1) - np.log(B)).sum(-1) # (B,)
MI = (log_qz_x - log_qz).mean()
TC = (log_qz - log_qz_prod).mean()
if free_bits > 0.0:
# Analytical per-dim KL, clamped: sum_d max(free_bits, KL_d)
kl_analytic = 0.5 * (z_mu ** 2 + np.exp(z_logvar) - 1.0 - z_logvar)
kl_per_dim = kl_analytic.mean(0) # (D,)
dimKL = float(np.maximum(free_bits, kl_per_dim).sum())
else:
dimKL = float((log_qz_prod - log_pz).mean())
loss = recon + alpha_mi * MI + beta_tc * TC + dimKL
return float(loss), {
'recon': float(recon),
'MI': float(MI),
'TC': float(TC),
'dimKL': dimKL,
}
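# Consistency note (free_bits == 0 case): the three KL-side terms telescope,
#   MI + TC + dimKL = mean_i [ log q(z_i|x_i) - log p(z_i) ],
# a single-sample minibatch estimate of E_x[ KL(q(z|x) || p(z)) ], so with
# alpha_mi = beta_tc = 1 the loss reduces to a standard VAE ELBO estimate;
# beta_tc > 1 only up-weights the total-correlation component.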
# ---------------------------------------------------------------------------
# β-TCVAE backward
# ---------------------------------------------------------------------------
def btcvae_loss_backward(
x: np.ndarray,
x_recon: np.ndarray,
z_mu: np.ndarray,
z_logvar: np.ndarray,
z: np.ndarray,
eps: np.ndarray,
beta_tc: float = 4.0,
alpha_mi: float = 1.0,
free_bits: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute gradients of β-TCVAE loss w.r.t. x_recon, z_mu, z_logvar.
Parameters
----------
eps: (B, D) — noise used in reparameterisation: z = z_mu + eps*exp(0.5*z_logvar)
Returns
-------
d_recon : (B, C, T) grad w.r.t. x_recon
d_z_mu : (B, D) grad w.r.t. z_mu
d_z_logvar: (B, D) grad w.r.t. z_logvar
"""
B, D = z.shape
var = np.exp(z_logvar) # (B, D)
std = np.exp(0.5 * z_logvar) # (B, D)
# ---- grad w.r.t. x_recon (reconstruction MSE) ----
# L_recon = mean_{b,c,t} (x-x_recon)^2 = (1/(B*C*T)) * ||x - x_recon||^2
# dL/dx_recon = -2*(x - x_recon) / (B*C*T)
n_elements = x.size
d_recon = -2.0 * (x - x_recon) / n_elements # shape (B, C, T)
# ---- KL terms via minibatch estimator ----
# log q(z_i | x_i) (B,D) terms:
err = (z - z_mu) # (B, D)
log_qz_x_d = -0.5 * (np.log(2 * np.pi) + z_logvar + err ** 2 / var)
# Minibatch terms
z_i = z[:, None, :]
mu_j = z_mu[None, :, :]
lv_j = z_logvar[None, :, :]
diff_ij = z_i - mu_j # (B, B, D)
var_j = np.exp(lv_j) # (B, B, D)
log_q_ij_d = -0.5 * (np.log(2 * np.pi) + lv_j + diff_ij ** 2 / var_j) # (B,B,D)
# Weight matrices for log_qz and log_qz_prod
# w_qz[i, j] = softmax_j[ sum_d log_q_ij_d[i, j, :] ]
lq_sum = log_q_ij_d.sum(-1) # (B, B)
lq_sum_max = lq_sum.max(1, keepdims=True)
w_qz = np.exp(lq_sum - lq_sum_max)
w_qz /= w_qz.sum(1, keepdims=True) + 1e-30 # (B, B) softmax over j
# w_prod[i, j, d] = softmax_j[ log_q_ij_d[i, j, d] ] for each d separately
lq_max_d = log_q_ij_d.max(1, keepdims=True) # (B, 1, D)
w_prod = np.exp(log_q_ij_d - lq_max_d)
w_prod /= w_prod.sum(1, keepdims=True) + 1e-30 # (B, B, D)
# ---- d(MI + TC + dimKL) / d(z_mu) and d(z_logvar) ----
# MI    = mean_i [ log_qz_x[i]    - log_qz[i]      ]
# TC    = mean_i [ log_qz[i]      - log_qz_prod[i] ]
# dimKL = mean_i [ log_qz_prod[i] - log_pz[i]      ]
# so each log term carries a fixed coefficient in the loss:
#   dL/dtheta = (1/B) * [ alpha * d_log_qz_x
#                         + (beta - alpha) * d_log_qz
#                         + (1 - beta)     * d_log_qz_prod
#                         -                  d_log_pz ]
# Each log term depends on the encoder outputs along two paths:
#   (a) directly, through z_mu[j] / z_logvar[j] inside q(z_i | x_j):
#       d log_qz_x[i]    / d z_mu[i,d]     = +err[i,d] / var[i,d]
#       d log_qz_x[i]    / d z_logvar[i,d] = -0.5 + 0.5 * err[i,d]^2 / var[i,d]
#       d log_qz[i]      / d z_mu[j,d]     = +w_qz[i,j]     * diff_ij[i,j,d] / var_j[i,j,d]
#       d log_qz_prod[i] / d z_mu[j,d]     = +w_prod[i,j,d] * diff_ij[i,j,d] / var_j[i,j,d]
#       where w_qz / w_prod are the softmax weights coming out of the logsumexp
#       (gradients w.r.t. z_mu[j] / z_logvar[j] are aggregated by summing over i);
#   (b) through the sample z = z_mu + eps*std (reparameterization):
#       dz/d_mu = 1 and dz/d_logvar = 0.5 * eps * std.
# The reconstruction path through the decoder is NOT handled here: d_recon is
# returned so the caller can backprop the decoder separately and fold the
# resulting dL/dz into d_z_mu / d_z_logvar with the same identities.
# Strategy: accumulate dL/dz from the MI/TC/dimKL terms (path b), push it
# through the reparameterization, then add the direct path-(a) terms.
# Compute dL/dz (path b) term by term:
# 1. Reconstruction path (decoder): excluded; handled by the caller via d_recon.
# 2. From log_qz_x[i] = sum_d log_qz_x_d[i,d]:
#    d log_qz_x[i] / d z[i,d] = -(z[i,d] - mu[i,d]) / var[i,d] = -err/var
d_logqzx_d_z = -err / var # (B, D)
# 3. From log_qz[i] via z: (z[i,d] enters log_qz[i] through diff_ij[i,:,d])
# d log_qz[i] / d z[i, d] = sum_j w_qz[i,j] * (- diff_ij[i,j,d] / var_j[i,j,d])
#    (diff_ij[i,j,d] = z[i,d] - mu_j[d], so d(diff_ij)/d(z[i,d]) = 1)
d_logqz_d_z = -(w_qz[:, :, None] * diff_ij / var_j).sum(1) # (B, D)
# 4. From log_qz_prod[i] via z:
# d log_qz_prod[i] / d z[i,d] = sum_j w_prod[i,j,d] * (-diff_ij[i,j,d]/var_j[i,j,d])
d_logqzprod_d_z = -(w_prod * diff_ij / var_j).sum(1) # (B, D)
# 5. From log_pz: d log_pz / d z = -z
d_logpz_d_z = -z # (B, D)
# Combine with coefficients — MI + TC terms only; dimKL handled separately below
# when free_bits > 0: skip ALL MI/TC aggregate-posterior terms to prevent
# collapse. Only the analytical per-dim KL (masked) + direct encoder term
# (alpha_mi * err/var below) remain. This is a plain VAE-with-free-bits mode.
# when free_bits == 0: full beta-TCVAE with minibatch TC/MI decomposition.
if free_bits > 0.0:
dL_dz = np.zeros((B, D), dtype=np.float64) # no MI/TC reparameterization gradient
else:
dL_dz = (1.0 / B) * (
alpha_mi * d_logqzx_d_z
- alpha_mi * d_logqz_d_z
+ beta_tc * d_logqz_d_z
- beta_tc * d_logqzprod_d_z
+ d_logqzprod_d_z
- d_logpz_d_z
) # (B, D)
# Gradient through z_mu and z_logvar VIA z (reparameterization):
d_z_mu = dL_dz.copy()
d_z_logvar = dL_dz * 0.5 * eps * std
# Direct gradient w.r.t. z_mu from log q(z|x):
d_z_mu += (alpha_mi / B) * err / var
# Direct gradient w.r.t. z_logvar from log q(z|x):
d_z_logvar += (alpha_mi / B) * (-0.5 + 0.5 * err ** 2 / var)
# Gradient of log_qz and log_qz_prod w.r.t. z_mu[j] directly
d_logqz_d_mu_direct = (
(w_qz[:, :, None] * diff_ij / var_j).sum(0) # (B, D)
)
d_logqzprod_d_mu_direct = (
(w_prod * diff_ij / var_j).sum(0) # (B, D)
)
if free_bits > 0.0:
pass # no MI/TC direct-to-mu gradient in plain-VAE-with-free-bits mode
else:
d_z_mu += (1.0 / B) * (
-alpha_mi * d_logqz_d_mu_direct
+ beta_tc * d_logqz_d_mu_direct
- beta_tc * d_logqzprod_d_mu_direct
+ d_logqzprod_d_mu_direct
)
d_logqz_d_lv_direct = (
w_qz[:, :, None] * (-0.5 + 0.5 * diff_ij ** 2 / var_j)
).sum(0) # (B, D)
d_logqzprod_d_lv_direct = (
w_prod * (-0.5 + 0.5 * diff_ij ** 2 / var_j)
).sum(0) # (B, D)
if free_bits > 0.0:
# no MI/TC direct-to-logvar gradient in plain-VAE-with-free-bits mode
# Analytical free-bits dimKL gradient: only for dims above the floor
kl_analytic = 0.5 * (z_mu ** 2 + var - 1.0 - z_logvar) # (B, D)
kl_per_dim = kl_analytic.mean(0) # (D,)
mask = (kl_per_dim >= free_bits).astype(np.float64) # (D,) — 1 = penalised
d_z_mu += mask[None, :] * z_mu / B
d_z_logvar += mask[None, :] * 0.5 * (var - 1.0) / B
else:
d_z_logvar += (1.0 / B) * (
-alpha_mi * d_logqz_d_lv_direct
+ beta_tc * d_logqz_d_lv_direct
- beta_tc * d_logqzprod_d_lv_direct
+ d_logqzprod_d_lv_direct
)
return d_recon, d_z_mu, d_z_logvar
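# Finite-difference sketch for the analytic z_mu gradient of the MI/TC/dimKL
# terms (illustrative only: not called anywhere, and shapes, seed and step size
# are arbitrary assumptions). x_recon is held fixed, so the decoder path is
# deliberately excluded, matching what btcvae_loss_backward returns.
def _check_btcvae_grad_mu(seed: int = 0, h: float = 1e-5) -> float:
    """Return |numeric - analytic| for one element of dL/dz_mu."""
    rng = np.random.RandomState(seed)
    B, D = 8, 4
    x = rng.randn(B, 2, 6)
    x_recon = rng.randn(B, 2, 6)
    z_mu = rng.randn(B, D)
    z_logvar = 0.1 * rng.randn(B, D)
    eps = rng.randn(B, D)
    z = z_mu + eps * np.exp(0.5 * z_logvar)
    _, d_mu, _ = btcvae_loss_backward(x, x_recon, z_mu, z_logvar, z, eps)

    def loss_at(mu: np.ndarray) -> float:
        zz = mu + eps * np.exp(0.5 * z_logvar)     # same noise, recomputed sample
        return btcvae_loss(x, x_recon, mu, z_logvar, zz)[0]

    idx = (0, 0)                                   # probe a single element
    mp, mm = z_mu.copy(), z_mu.copy()
    mp[idx] += h
    mm[idx] -= h
    num = (loss_at(mp) - loss_at(mm)) / (2 * h)
    return float(abs(num - d_mu[idx]))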