""" ConvNeXt-1D β-TCVAE — pure numpy (no PyTorch required). ======================================================== Architecture ------------ Input : (B, C_in=8, T=32) — 8 eigenvalue channels × 32 timestep window Stem : PointwiseProj(8→32) + LayerNorm Stage 0: 3 × ConvNeXtBlock1D(32, dw_k=7) Pool : AdaptiveAvgPool(16) + PointwiseProj(32→64) Stage 1: 3 × ConvNeXtBlock1D(64, dw_k=7) Head : GlobalAvgPool → Linear(64→32) for z_mu / z_logvar Decoder (mirrored): Expand : Linear(32→64) → repeat 16 times along T axis Stage 1D: 3 × ConvNeXtBlock1D(64, dw_k=7) Upsample: RepeatInterleave(×2) + PointwiseProj(64→32) Stage 0D: 3 × ConvNeXtBlock1D(32, dw_k=7) Out : PointwiseProj(32→8) Objective: β-TCVAE (Chen et al. 2018, "Isolating Sources of Disentanglement") β_tc=4.0, α_mi=1.0. Minibatch TC estimator. Gradient implementation notes ------------------------------ All depthwise convolutions use the im2col trick so that gradients reduce to standard matmul / einsum — no custom numerical differentiation required. DWConv1d forward : windows[b,c,t,k] = x_pad[b,c,t+k] out[b,c,t] = sum_k windows * w[c,k] + b[c] DWConv1d backward: grad_w[c] = einsum('btk,bt->k', windows_c, grad_out_c) over B col2im for grad_x (accumulate k-shifted copies of grad_out * w) β-TCVAE TC gradient: computed via minibatch weight matrix w[i,j] ∝ q(z_i|x_j) """ from __future__ import annotations import json import numpy as np from scipy.special import erf try: from numba import njit as _njit _NUMBA = True except ImportError: _NUMBA = False def _njit(fn): # no-op decorator return fn __all__ = ['ConvNeXtVAE', 'btcvae_loss', 'btcvae_loss_backward'] # --------------------------------------------------------------------------- # Utility # --------------------------------------------------------------------------- def _kaiming(fan_in: int, fan_out: int, rng: np.random.RandomState) -> np.ndarray: std = np.sqrt(2.0 / fan_in) return rng.randn(fan_in, fan_out) * std def _gelu(x: np.ndarray) -> np.ndarray: return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0))) def _gelu_grad(x: np.ndarray) -> np.ndarray: # Two allocations instead of three: out= avoids the final x*pdf temporary cdf = 0.5 * (1.0 + erf(x * (1.0 / np.sqrt(2.0)))) pdf = np.exp(-0.5 * x * x) pdf *= 1.0 / np.sqrt(2.0 * np.pi) np.multiply(x, pdf, out=pdf) # pdf now holds x * pdf, in-place pdf += cdf # pdf now holds cdf + x*pdf return pdf def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray: a_max = a.max(axis=axis, keepdims=True) out = np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True)) + a_max return out.squeeze(axis) # --------------------------------------------------------------------------- # Parameter container — stores weight, gradient, Adam state # --------------------------------------------------------------------------- class Param: __slots__ = ('data', 'grad', '_m', '_v', 't') def __init__(self, data: np.ndarray): self.data = data.astype(np.float64) self.grad = np.zeros_like(self.data) self._m = np.zeros_like(self.data) self._v = np.zeros_like(self.data) self.t = 0 def zero_grad(self): self.grad.fill(0.0) def adam_step(self, lr: float, β1: float = 0.9, β2: float = 0.999, eps: float = 1e-8, wd: float = 0.0): self.t += 1 g = self.grad if wd > 0: g = g + wd * self.data self._m = β1 * self._m + (1 - β1) * g self._v = β2 * self._v + (1 - β2) * g * g m_hat = self._m / (1 - β1 ** self.t) v_hat = self._v / (1 - β2 ** self.t) self.data -= lr * m_hat / (np.sqrt(v_hat) + eps) # --------------------------------------------------------------------------- # im2col helpers # 
--------------------------------------------------------------------------- def _im2col1d(x: np.ndarray, k: int, pad: int) -> np.ndarray: """ x : (B, C, T) returns: (B, C, T, k) — zero-padded windows Uses np.lib.stride_tricks.as_strided for zero-copy window view, then immediately makes a contiguous copy so writes to the returned array never corrupt x_pad. """ B, C, T = x.shape x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad))) # stride-trick view: step one element along T axis per k-position shape = (B, C, T, k) s0, s1, s2 = x_pad.strides strides = (s0, s1, s2, s2) windows = np.lib.stride_tricks.as_strided(x_pad, shape=shape, strides=strides) return windows.copy() # contiguous copy — safe for in-place ops downstream @_njit(cache=True) def _col2im1d_kernel(g_pad, grad_win, T, k): """Numba-JIT accumulation loop for col2im.""" for i in range(k): g_pad[:, :, i:i + T] += grad_win[:, :, :, i] def _col2im1d(grad_win: np.ndarray, T: int, k: int, pad: int) -> np.ndarray: """ grad_win: (B, C, T, k) returns : (B, C, T) """ B, C = grad_win.shape[0], grad_win.shape[1] g_pad = np.zeros((B, C, T + 2 * pad)) _col2im1d_kernel(g_pad, grad_win, T, k) return g_pad[:, :, pad:pad + T] # --------------------------------------------------------------------------- # Layer: DWConv1d (depthwise conv via im2col) # --------------------------------------------------------------------------- class DWConv1d: def __init__(self, C: int, k: int, rng: np.random.RandomState): self.C, self.k, self.pad = C, k, k // 2 self.w = Param(rng.randn(C, k) * 0.02) self.b = Param(np.zeros(C)) self._cache: tuple | None = None def params(self) -> list[Param]: return [self.w, self.b] def forward(self, x: np.ndarray) -> np.ndarray: """x: (B, C, T) → (B, C, T)""" wins = _im2col1d(x, self.k, self.pad) # (B, C, T, k) self._cache = (wins,) # einsum with optimize: uses BLAS when possible out = np.einsum('bctk,ck->bct', wins, self.w.data, optimize=True) out += self.b.data[None, :, None] return out def backward(self, grad_out: np.ndarray) -> np.ndarray: """grad_out: (B, C, T) → grad_x: (B, C, T)""" wins, = self._cache B, C, T = grad_out.shape # grad w: einsum over B and T self.w.grad += np.einsum('bct,bctk->ck', grad_out, wins, optimize=True) # grad b self.b.grad += grad_out.sum(axis=(0, 2)) # (C,) # grad x via col2im grad_win = np.einsum('bct,ck->bctk', grad_out, self.w.data, optimize=True) return _col2im1d(grad_win, T, self.k, self.pad) # --------------------------------------------------------------------------- # Layer: LayerNorm (over last axis, applied to (B, T, C) or (B, C)) # --------------------------------------------------------------------------- class LayerNorm: EPS = 1e-6 def __init__(self, D: int, rng: np.random.RandomState | None = None): self.D = D self.gamma = Param(np.ones(D)) self.beta = Param(np.zeros(D)) self._cache: tuple | None = None def params(self) -> list[Param]: return [self.gamma, self.beta] def forward(self, x: np.ndarray) -> np.ndarray: """Normalize last axis.""" mean = x.mean(axis=-1, keepdims=True) var = x.var(axis=-1, keepdims=True) x_hat = (x - mean) / np.sqrt(var + self.EPS) self._cache = (x_hat, var) return self.gamma.data * x_hat + self.beta.data def backward(self, dy: np.ndarray) -> np.ndarray: x_hat, var = self._cache D = x_hat.shape[-1] orig_shape = dy.shape dy_2d = dy.reshape(-1, D) xhat_2d = x_hat.reshape(-1, D) var_2d = var.reshape(-1, 1) self.gamma.grad += (dy_2d * xhat_2d).sum(0) self.beta.grad += dy_2d.sum(0) dx_hat = dy_2d * self.gamma.data # (N, D) std_inv = 1.0 / np.sqrt(var_2d + self.EPS) # (N, 1) dx 
= std_inv * (dx_hat - dx_hat.mean(-1, keepdims=True) - xhat_2d * (dx_hat * xhat_2d).mean(-1, keepdims=True)) return dx.reshape(orig_shape) # --------------------------------------------------------------------------- # Layer: Linear # --------------------------------------------------------------------------- class Linear: def __init__(self, in_f: int, out_f: int, rng: np.random.RandomState, bias: bool = True): self.W = Param(_kaiming(in_f, out_f, rng)) self.b = Param(np.zeros(out_f)) if bias else None self._x_cache: np.ndarray | None = None def params(self) -> list[Param]: return [self.W] + ([self.b] if self.b is not None else []) def forward(self, x: np.ndarray) -> np.ndarray: self._x_cache = x out = x @ self.W.data if self.b is not None: out = out + self.b.data return out def backward(self, dy: np.ndarray) -> np.ndarray: x = self._x_cache x2 = x.reshape(-1, self.W.data.shape[0]) dy2 = dy.reshape(-1, self.W.data.shape[1]) self.W.grad += x2.T @ dy2 if self.b is not None: self.b.grad += dy2.sum(0) return (dy2 @ self.W.data.T).reshape(x.shape) # --------------------------------------------------------------------------- # Block: ConvNeXtBlock1D # --------------------------------------------------------------------------- # Input/output: (B, C, T) # 1. DWConv1d(C, k) # 2. Permute → (B, T, C) # 3. LayerNorm(C) # 4. Linear(C → 4C) + GELU # 5. Linear(4C → C) # 6. Permute → (B, C, T) # 7. Skip class ConvNeXtBlock1D: def __init__(self, C: int, k: int, rng: np.random.RandomState): self.dwconv = DWConv1d(C, k, rng) self.ln = LayerNorm(C) self.fc1 = Linear(C, 4 * C, rng) self.fc2 = Linear(4 * C, C, rng) self._cache: dict | None = None def params(self) -> list[Param]: return (self.dwconv.params() + self.ln.params() + self.fc1.params() + self.fc2.params()) def forward(self, x: np.ndarray) -> np.ndarray: """x: (B, C, T) → (B, C, T)""" h = self.dwconv.forward(x) # (B, C, T) h = h.transpose(0, 2, 1) # (B, T, C) h = self.ln.forward(h) # (B, T, C) h_fc1 = self.fc1.forward(h) # (B, T, 4C) — pre-GELU act = _gelu(h_fc1) h = self.fc2.forward(act) # (B, T, C) h = h.transpose(0, 2, 1) # (B, C, T) self._pre_gelu = h_fc1 # cache for backward return x + h # skip def backward(self, grad_out: np.ndarray) -> np.ndarray: """grad_out: (B, C, T) → grad_x: (B, C, T)""" # skip: grad flows unchanged to x, AND through the block grad_block = grad_out.copy() # 6. permute (B, C, T) → (B, T, C) grad_block = grad_block.transpose(0, 2, 1) # 5. fc2 backward → grad w.r.t. GELU output (act), shape (B, T, 4C) grad_block = self.fc2.backward(grad_block) # 4. GELU backward (pre_gelu cached during forward) grad_block = grad_block * _gelu_grad(self._pre_gelu) # 4. fc1 backward grad_block = self.fc1.backward(grad_block) # 3. LN backward grad_block = self.ln.backward(grad_block) # 2. permute (B, T, C) → (B, C, T) grad_block = grad_block.transpose(0, 2, 1) # 1. DWConv backward grad_block = self.dwconv.backward(grad_block) return grad_out + grad_block # skip: grad_x = grad_out + grad_block # --------------------------------------------------------------------------- # Pool: AdaptiveAvgPool1D (fixed stride, e.g. 
T=32→T=16 stride=2) # --------------------------------------------------------------------------- class AvgPool1D: """Stride-2 average pooling over temporal axis.""" def __init__(self, stride: int = 2): self.stride = stride self._T_in: int | None = None def params(self) -> list[Param]: return [] def forward(self, x: np.ndarray) -> np.ndarray: """x: (B, C, T) → (B, C, T//stride)""" B, C, T = x.shape self._T_in = T T_out = T // self.stride # reshape and mean return x[:, :, :T_out * self.stride].reshape(B, C, T_out, self.stride).mean(-1) def backward(self, grad_out: np.ndarray) -> np.ndarray: B, C, T_out = grad_out.shape T_in = self._T_in # each output contributes 1/stride to each input in the pool g = grad_out[:, :, :, None] * np.ones((1, 1, 1, self.stride)) / self.stride g = g.reshape(B, C, T_out * self.stride) if T_in > T_out * self.stride: g = np.pad(g, ((0, 0), (0, 0), (0, T_in - T_out * self.stride))) return g # --------------------------------------------------------------------------- # Upsample: nearest-neighbor ×2 along temporal axis # --------------------------------------------------------------------------- class Upsample1D: """Nearest-neighbor upsampling ×factor along temporal axis.""" def __init__(self, factor: int = 2): self.factor = factor def params(self) -> list[Param]: return [] def forward(self, x: np.ndarray) -> np.ndarray: return np.repeat(x, self.factor, axis=-1) def backward(self, grad_out: np.ndarray) -> np.ndarray: f = self.factor B, C, T = grad_out.shape T_in = T // f return grad_out.reshape(B, C, T_in, f).sum(-1) # --------------------------------------------------------------------------- # PointwiseProj: 1×1 conv (= Linear applied per timestep) # --------------------------------------------------------------------------- class PointwiseProj: """(B, C_in, T) → (B, C_out, T) via shared Linear(C_in, C_out).""" def __init__(self, C_in: int, C_out: int, rng: np.random.RandomState): self.lin = Linear(C_in, C_out, rng) def params(self) -> list[Param]: return self.lin.params() def forward(self, x: np.ndarray) -> np.ndarray: """x: (B, C_in, T)""" B, C, T = x.shape x_t = x.transpose(0, 2, 1) # (B, T, C_in) out = self.lin.forward(x_t) # (B, T, C_out) return out.transpose(0, 2, 1) # (B, C_out, T) def backward(self, grad_out: np.ndarray) -> np.ndarray: g_t = grad_out.transpose(0, 2, 1) # (B, T, C_out) g_t = self.lin.backward(g_t) # (B, T, C_in) return g_t.transpose(0, 2, 1) # (B, C_in, T) # --------------------------------------------------------------------------- # Main model: ConvNeXtVAE # --------------------------------------------------------------------------- class ConvNeXtVAE: """ ConvNeXt-1D β-TCVAE. 
Parameters ---------- C_in : input channels (default 8) T_in : input timesteps (default 32) z_dim : latent dimensionality (default 32) base_ch : stem channel width (default 32) dw_k : depthwise kernel size (default 7) n_blocks : ConvNeXt blocks per stage (default 3) seed : RNG seed """ def __init__( self, C_in: int = 8, T_in: int = 32, z_dim: int = 32, base_ch: int = 32, dw_k: int = 7, n_blocks: int = 3, seed: int = 42, ): rng = np.random.RandomState(seed) self.C_in = C_in self.T_in = T_in self.z_dim = z_dim self.base_ch = base_ch ch1 = base_ch * 2 # 64 # ---- ENCODER ---- # Stem self.stem_proj = PointwiseProj(C_in, base_ch, rng) self.stem_ln = LayerNorm(base_ch) # Stage 0 self.stage0 = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)] # Pool + channel proj self.pool = AvgPool1D(stride=2) self.pool_ln = LayerNorm(base_ch) self.pool_proj = PointwiseProj(base_ch, ch1, rng) # Stage 1 self.stage1 = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)] # Global avg pool → z_mu, z_logvar self.head_mu = Linear(ch1, z_dim, rng) self.head_logvar = Linear(ch1, z_dim, rng) # ---- DECODER ---- T_bot = T_in // 2 # bottleneck temporal size (16 for T_in=32) self.dec_linear = Linear(z_dim, ch1, rng) # Stage 1D self.stage1d = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)] # Upsample + channel proj self.up = Upsample1D(factor=2) self.up_proj = PointwiseProj(ch1, base_ch, rng) self.up_ln = LayerNorm(base_ch) # Stage 0D self.stage0d = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)] # Output projection self.out_proj = PointwiseProj(base_ch, C_in, rng) # Cache for backward self._enc_cache: dict = {} self._dec_cache: dict = {} self._T_bot = T_bot # ------------------------------------------------------------------ def all_params(self) -> list[Param]: ps = [] ps += self.stem_proj.params() + self.stem_ln.params() for blk in self.stage0: ps += blk.params() ps += (self.pool_ln.params() + self.pool_proj.params()) for blk in self.stage1: ps += blk.params() ps += self.head_mu.params() + self.head_logvar.params() ps += self.dec_linear.params() for blk in self.stage1d: ps += blk.params() ps += self.up_proj.params() + self.up_ln.params() for blk in self.stage0d: ps += blk.params() ps += self.out_proj.params() return ps def zero_grad(self): for p in self.all_params(): p.zero_grad() def adam_step(self, lr: float, wd: float = 1e-4): for p in self.all_params(): p.adam_step(lr, wd=wd) def n_params(self) -> int: return sum(p.data.size for p in self.all_params()) # ------------------------------------------------------------------ # ENCODER # ------------------------------------------------------------------ def encode(self, x: np.ndarray): """ x: (B, C_in, T_in) Returns z_mu (B,z), z_logvar (B,z) """ c = self._enc_cache c['x'] = x # Stem h = self.stem_proj.forward(x) # (B, base_ch, T) h = h.transpose(0, 2, 1) # (B, T, base_ch) h = self.stem_ln.forward(h) # (B, T, base_ch) h = h.transpose(0, 2, 1) # (B, base_ch, T) c['after_stem'] = h # Stage 0 for blk in self.stage0: h = blk.forward(h) c['after_stage0'] = h # Pool + proj h = self.pool.forward(h) # (B, base_ch, T//2) h = h.transpose(0, 2, 1) # (B, T//2, base_ch) h = self.pool_ln.forward(h) # (B, T//2, base_ch) h = h.transpose(0, 2, 1) # (B, base_ch, T//2) h = self.pool_proj.forward(h) # (B, ch1, T//2) c['after_pool'] = h # Stage 1 for blk in self.stage1: h = blk.forward(h) c['after_stage1'] = h # Global avg pool h_gap = h.mean(axis=-1) # (B, ch1) c['gap'] = h_gap z_mu = self.head_mu.forward(h_gap) # (B, z_dim) z_logvar = 
self.head_logvar.forward(h_gap) # (B, z_dim) return z_mu, z_logvar # ------------------------------------------------------------------ # DECODER # ------------------------------------------------------------------ def decode(self, z: np.ndarray) -> np.ndarray: """ z: (B, z_dim) Returns x_recon: (B, C_in, T_in) """ c = self._dec_cache T_bot = self._T_bot h = self.dec_linear.forward(z) # (B, ch1) h = h[:, :, None] * np.ones((1, 1, T_bot)) # broadcast (B, ch1, T_bot) c['dec_expand'] = h for blk in self.stage1d: h = blk.forward(h) c['after_stage1d'] = h # Upsample + proj h = self.up.forward(h) # (B, ch1, T_in) h = self.up_proj.forward(h) # (B, base_ch, T_in) h = h.transpose(0, 2, 1) # (B, T_in, base_ch) h = self.up_ln.forward(h) # (B, T_in, base_ch) h = h.transpose(0, 2, 1) # (B, base_ch, T_in) c['after_up'] = h for blk in self.stage0d: h = blk.forward(h) x_recon = self.out_proj.forward(h) # (B, C_in, T_in) return x_recon # ------------------------------------------------------------------ # BACKWARD (encoder) # ------------------------------------------------------------------ def backward_encode(self, dz_mu: np.ndarray, dz_logvar: np.ndarray): """ Given grad w.r.t. z_mu and z_logvar, backprop through encoder. Returns grad w.r.t. x (not used further, just for completeness). """ c = self._enc_cache # head d_gap = self.head_mu.backward(dz_mu) + self.head_logvar.backward(dz_logvar) # global avg pool backward: distribute evenly over T axis h_after = c['after_stage1'] T_bot = h_after.shape[-1] d_h = d_gap[:, :, None] / T_bot * np.ones((1, 1, T_bot)) # (B, ch1, T_bot) # stage 1 backward for blk in reversed(self.stage1): d_h = blk.backward(d_h) # pool_proj backward d_h = self.pool_proj.backward(d_h) # (B, base_ch, T//2) # pool_ln backward d_h = d_h.transpose(0, 2, 1) d_h = self.pool_ln.backward(d_h) d_h = d_h.transpose(0, 2, 1) # pool backward d_h = self.pool.backward(d_h) # (B, base_ch, T) # stage 0 backward for blk in reversed(self.stage0): d_h = blk.backward(d_h) # stem backward d_h = d_h.transpose(0, 2, 1) d_h = self.stem_ln.backward(d_h) d_h = d_h.transpose(0, 2, 1) d_h = self.stem_proj.backward(d_h) return d_h # grad w.r.t. input x # ------------------------------------------------------------------ # BACKWARD (decoder) # ------------------------------------------------------------------ def backward_decode(self, d_recon: np.ndarray): """ d_recon: (B, C_in, T_in) — grad w.r.t. x_recon Returns grad w.r.t. z. 
""" c = self._dec_cache T_bot = self._T_bot # out_proj backward d_h = self.out_proj.backward(d_recon) # (B, base_ch, T_in) # stage 0d backward for blk in reversed(self.stage0d): d_h = blk.backward(d_h) # up_ln backward d_h = d_h.transpose(0, 2, 1) d_h = self.up_ln.backward(d_h) d_h = d_h.transpose(0, 2, 1) # up_proj backward d_h = self.up_proj.backward(d_h) # (B, ch1, T_in) # upsample backward d_h = self.up.backward(d_h) # (B, ch1, T_bot) # stage 1d backward for blk in reversed(self.stage1d): d_h = blk.backward(d_h) # expand backward: sum over temporal axis (it was broadcast) d_h_expand = d_h.sum(axis=-1) # (B, ch1) # dec_linear backward dz = self.dec_linear.backward(d_h_expand) # (B, z_dim) return dz # ------------------------------------------------------------------ # FORWARD convenience # ------------------------------------------------------------------ def forward(self, x: np.ndarray): """Full forward: encode + reparameterize + decode.""" z_mu, z_logvar = self.encode(x) eps = np.random.randn(*z_mu.shape) z = z_mu + eps * np.exp(0.5 * z_logvar) x_recon = self.decode(z) return x_recon, z_mu, z_logvar, z # ------------------------------------------------------------------ # SAVE / LOAD # ------------------------------------------------------------------ def save(self, path: str, norm_mean=None, norm_std=None, extra: dict | None = None, save_adam: bool = False): data: dict = {} params = self.all_params() names = self._param_names() for name, p in zip(names, params): data[name] = p.data.tolist() if save_adam: data[f'{name}__m'] = p._m.tolist() data[f'{name}__v'] = p._v.tolist() data[f'{name}__t'] = int(p.t) if norm_mean is not None: data['norm_mean'] = norm_mean.tolist() data['norm_std'] = norm_std.tolist() data['architecture'] = { 'C_in': self.C_in, 'T_in': self.T_in, 'z_dim': self.z_dim, 'base_ch': self.base_ch, } if extra: data.update(extra) with open(path, 'w') as f: json.dump(data, f) def load(self, path: str) -> dict: with open(path) as f: data = json.load(f) params = self.all_params() names = self._param_names() reserved = set(names) for name, p in zip(names, params): if name in data: p.data[:] = np.array(data[name], dtype=np.float64) if f'{name}__m' in data: p._m[:] = np.array(data[f'{name}__m'], dtype=np.float64) p._v[:] = np.array(data[f'{name}__v'], dtype=np.float64) p.t = int(data[f'{name}__t']) reserved.update({f'{name}__m', f'{name}__v', f'{name}__t'}) extras = {k: v for k, v in data.items() if k not in reserved and k != 'architecture'} return extras def _param_names(self) -> list[str]: """Return stable flat list of parameter name strings matching all_params().""" names = [] # stem for i, _ in enumerate(self.stem_proj.params()): names.append(f'stem_proj_p{i}') for i, _ in enumerate(self.stem_ln.params()): names.append(f'stem_ln_p{i}') # stage0 for bi, blk in enumerate(self.stage0): for i, _ in enumerate(blk.params()): names.append(f'stage0_b{bi}_p{i}') # pool for i, _ in enumerate(self.pool_ln.params()): names.append(f'pool_ln_p{i}') for i, _ in enumerate(self.pool_proj.params()): names.append(f'pool_proj_p{i}') # stage1 for bi, blk in enumerate(self.stage1): for i, _ in enumerate(blk.params()): names.append(f'stage1_b{bi}_p{i}') # head for i, _ in enumerate(self.head_mu.params()): names.append(f'head_mu_p{i}') for i, _ in enumerate(self.head_logvar.params()): names.append(f'head_logvar_p{i}') # dec_linear for i, _ in enumerate(self.dec_linear.params()): names.append(f'dec_linear_p{i}') # stage1d for bi, blk in enumerate(self.stage1d): for i, _ in enumerate(blk.params()): 
names.append(f'stage1d_b{bi}_p{i}') # up for i, _ in enumerate(self.up_proj.params()): names.append(f'up_proj_p{i}') for i, _ in enumerate(self.up_ln.params()): names.append(f'up_ln_p{i}') # stage0d for bi, blk in enumerate(self.stage0d): for i, _ in enumerate(blk.params()): names.append(f'stage0d_b{bi}_p{i}') # out_proj for i, _ in enumerate(self.out_proj.params()): names.append(f'out_proj_p{i}') return names # --------------------------------------------------------------------------- # β-TCVAE loss # --------------------------------------------------------------------------- def btcvae_loss( x: np.ndarray, x_recon: np.ndarray, z_mu: np.ndarray, z_logvar: np.ndarray, z: np.ndarray, beta_tc: float = 4.0, alpha_mi: float = 1.0, free_bits: float = 0.0, ) -> tuple[float, dict]: """ β-TCVAE loss with minibatch TC estimator. Parameters ---------- x : (B, C, T) — input x_recon : (B, C, T) — reconstruction z_mu : (B, D) z_logvar : (B, D) z : (B, D) — reparameterized sample beta_tc : weight on Total Correlation term alpha_mi : weight on Mutual Information term free_bits : minimum KL per latent dimension (nats). Replaces the minibatch dimKL term with analytical per-dim KL clamped from below at free_bits. Prevents posterior collapse. 0 = standard behaviour. Returns ------- loss : scalar info : dict with breakdown (recon, MI, TC, dimKL) """ B, D = z.shape # Reconstruction loss (per-sample mean, then mean over batch) recon = ((x - x_recon) ** 2).mean(axis=(-1, -2)).mean() # log q(z|x): (B,) log_qz_x = -0.5 * ( np.log(2 * np.pi) + z_logvar + (z - z_mu) ** 2 * np.exp(-z_logvar) ).sum(axis=1) # log p(z): (B,) — standard Gaussian log_pz = -0.5 * (np.log(2 * np.pi) + z ** 2).sum(axis=1) # Minibatch estimator for log q(z) and log prod_d q(z_d) z_i = z[:, None, :] # (B, 1, D) mu_j = z_mu[None, :, :] # (1, B, D) lv_j = z_logvar[None, :, :] # (1, B, D) # (B, B, D) log_q_ij_d = -0.5 * ( np.log(2 * np.pi) + lv_j + (z_i - mu_j) ** 2 * np.exp(-lv_j) ) # log q(z_i) = logsumexp_j[sum_d log_q_ij_d] - log B log_qz = _logsumexp(log_q_ij_d.sum(-1), axis=1) - np.log(B) # (B,) # log prod_d q(z_d[i]) = sum_d logsumexp_j[log_q_ij_d] - D*log(B) log_qz_prod = (_logsumexp(log_q_ij_d, axis=1) - np.log(B)).sum(-1) # (B,) MI = (log_qz_x - log_qz).mean() TC = (log_qz - log_qz_prod).mean() if free_bits > 0.0: # Analytical per-dim KL, clamped: sum_d max(free_bits, KL_d) kl_analytic = 0.5 * (z_mu ** 2 + np.exp(z_logvar) - 1.0 - z_logvar) kl_per_dim = kl_analytic.mean(0) # (D,) dimKL = float(np.maximum(free_bits, kl_per_dim).sum()) else: dimKL = float((log_qz_prod - log_pz).mean()) loss = recon + alpha_mi * MI + beta_tc * TC + dimKL return float(loss), { 'recon': float(recon), 'MI': float(MI), 'TC': float(TC), 'dimKL': dimKL, } # --------------------------------------------------------------------------- # β-TCVAE backward # --------------------------------------------------------------------------- def btcvae_loss_backward( x: np.ndarray, x_recon: np.ndarray, z_mu: np.ndarray, z_logvar: np.ndarray, z: np.ndarray, eps: np.ndarray, beta_tc: float = 4.0, alpha_mi: float = 1.0, free_bits: float = 0.0, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute gradients of β-TCVAE loss w.r.t. x_recon, z_mu, z_logvar. Parameters ---------- eps: (B, D) — noise used in reparameterisation: z = z_mu + eps*exp(0.5*z_logvar) Returns ------- d_recon : (B, C, T) grad w.r.t. x_recon d_z_mu : (B, D) grad w.r.t. z_mu d_z_logvar: (B, D) grad w.r.t. 
z_logvar """ B, D = z.shape var = np.exp(z_logvar) # (B, D) std = np.exp(0.5 * z_logvar) # (B, D) # ---- grad w.r.t. x_recon (reconstruction MSE) ---- # L_recon = mean_{b,c,t} (x-x_recon)^2 = (1/(B*C*T)) * ||x - x_recon||^2 # dL/dx_recon = -2*(x - x_recon) / (B*C*T) n_elements = x.size d_recon = -2.0 * (x - x_recon) / n_elements # shape (B, C, T) # ---- KL terms via minibatch estimator ---- # log q(z_i | x_i) (B,D) terms: err = (z - z_mu) # (B, D) log_qz_x_d = -0.5 * (np.log(2 * np.pi) + z_logvar + err ** 2 / var) # Minibatch terms z_i = z[:, None, :] mu_j = z_mu[None, :, :] lv_j = z_logvar[None, :, :] diff_ij = z_i - mu_j # (B, B, D) var_j = np.exp(lv_j) # (B, B, D) log_q_ij_d = -0.5 * (np.log(2 * np.pi) + lv_j + diff_ij ** 2 / var_j) # (B,B,D) # Weight matrices for log_qz and log_qz_prod # w_qz[i, j] = softmax_j[ sum_d log_q_ij_d[i, j, :] ] lq_sum = log_q_ij_d.sum(-1) # (B, B) lq_sum_max = lq_sum.max(1, keepdims=True) w_qz = np.exp(lq_sum - lq_sum_max) w_qz /= w_qz.sum(1, keepdims=True) + 1e-30 # (B, B) softmax over j # w_prod[i, j, d] = softmax_j[ log_q_ij_d[i, j, d] ] for each d separately lq_max_d = log_q_ij_d.max(1, keepdims=True) # (B, 1, D) w_prod = np.exp(log_q_ij_d - lq_max_d) w_prod /= w_prod.sum(1, keepdims=True) + 1e-30 # (B, B, D) # ---- d(MI + TC + dimKL) / d(z_logvar) and d(z_mu) ---- # MI = mean_i [ log_qz_x[i] - log_qz[i] ] # TC = mean_i [ log_qz[i] - log_qz_prod[i] ] # dimKL = mean_i [ log_qz_prod[i] - log_pz[i] ] # # Combined coefficient of each log term: # d/dtheta = (1/B) * [ alpha * d_log_qz_x - alpha * d_log_qz # + beta * d_log_qz - beta * d_log_qz_prod # + d_log_qz_prod - d_log_pz ] # d log_qz_x[i] / d z_mu[i, d] = err[i,d] / var[i,d] (positive) # d log_qz_x[i] / d z_logvar[i,d] = -0.5 + 0.5 * err[i,d]^2 / var[i,d] # d log_qz[i] / d z_mu[j, d] = -w_qz[i, j] * diff_ij[i, j, d] / var_j[i,j,d] # (gradient w.r.t. z_mu is aggregated over i by summing) # (note: this is through the *denominator* of q(z_i|x_j), not z itself) # d log_qz_prod[i,d] / d z_mu[j,d] = -w_prod[i,j,d] * diff_ij[i,j,d]/var_j[i,j,d] # Additionally, z = mu + eps*std, so d_log_qz/d_z_mu has a DIRECT path via z_i: # d log_qz[i] / d z[i, d] = sum_j w_qz[i,j] * (-(z[i,d] - mu_j[d]) / var_j[i,j,d]) # d z[i,d] / d mu[i,d] = 1 → direct path # For cleaner code, compute: # G_mu[j, d] = sum_i (alpha_mi + beta_tc) * dlog_qz/dmu_j[d] (through z_mu[j]) # + sum_i (1+beta_tc) * dlog_qz_prod/dmu_j[d] (through z_mu[j]) # Plus direct terms from d_log_qz_x # Direct terms from log q(z|x): # d log_qz_x[i,d] / d mu[i,d] = (z[i,d] - mu[i,d]) / var[i,d] # With reparameterization: d z/d mu = 1, so also flows through z # d log_qz_x[i,d] / d z[i,d] = -err[i,d] / var[i,d] → d mu[i,d] # So total: +err/var + (-err/var) from z-path = 0 for the z-direct path # Actually: d loss / d mu = d loss / d z * dz/dmu + d loss / d mu directly # We'll compute d_loss/d_z separately and use reparameterization to split. # ---- Gradient through reparameterization ---- # L is a function of z_mu, z_logvar, and z = z_mu + eps*std. # dL/d_mu[i,d] = dL/d_z[i,d] * 1 + (dL/d_mu directly) # dL/d_logvar[i,d]= dL/d_z[i,d] * 0.5*eps[i,d]*std[i,d] + (dL/d_logvar directly) # Compute dL/d_z[i,d]: # 1. Reconstruction path (decoder): computed elsewhere and passed in as dz_from_recon # (we return d_recon so caller can backprop decoder separately) # Here we only handle KL/TC terms which depend on z. # 2. From log_qz_x[i] = sum_d log_qz_x_d[i,d]: # d log_qz_x[i] / d z[i,d] = -(z[i,d] - mu[i,d]) / var[i,d] = -err/var d_logqzx_d_z = -err / var # (B, D) # 3. 
From log_qz[i] via z: (z[i,d] enters log_qz[i] through diff_ij[i,:,d]) # d log_qz[i] / d z[i, d] = sum_j w_qz[i,j] * (- diff_ij[i,j,d] / var_j[i,j,d]) # But wait: diff_ij[i,j,d] = z[i,d] - mu_j[d], so d(diff_ij)/d(z[i,d]) = 1 d_logqz_d_z = -(w_qz[:, :, None] * diff_ij / var_j).sum(1) # (B, D) # 4. From log_qz_prod[i] via z: # d log_qz_prod[i] / d z[i,d] = sum_j w_prod[i,j,d] * (-diff_ij[i,j,d]/var_j[i,j,d]) d_logqzprod_d_z = -(w_prod * diff_ij / var_j).sum(1) # (B, D) # 5. From log_pz: d log_pz / d z = -z d_logpz_d_z = -z # (B, D) # Combine with coefficients — MI + TC terms only; dimKL handled separately below # when free_bits > 0: skip ALL MI/TC aggregate-posterior terms to prevent # collapse. Only the analytical per-dim KL (masked) + direct encoder term # (alpha_mi * err/var below) remain. This is a plain VAE-with-free-bits mode. # when free_bits == 0: full beta-TCVAE with minibatch TC/MI decomposition. if free_bits > 0.0: dL_dz = np.zeros((B, D), dtype=np.float64) # no MI/TC reparameterization gradient else: dL_dz = (1.0 / B) * ( alpha_mi * d_logqzx_d_z - alpha_mi * d_logqz_d_z + beta_tc * d_logqz_d_z - beta_tc * d_logqzprod_d_z + d_logqzprod_d_z - d_logpz_d_z ) # (B, D) # Gradient through z_mu and z_logvar VIA z (reparameterization): d_z_mu = dL_dz.copy() d_z_logvar = dL_dz * 0.5 * eps * std # Direct gradient w.r.t. z_mu from log q(z|x): d_z_mu += (alpha_mi / B) * err / var # Direct gradient w.r.t. z_logvar from log q(z|x): d_z_logvar += (alpha_mi / B) * (-0.5 + 0.5 * err ** 2 / var) # Gradient of log_qz and log_qz_prod w.r.t. z_mu[j] directly d_logqz_d_mu_direct = ( (w_qz[:, :, None] * diff_ij / var_j).sum(0) # (B, D) ) d_logqzprod_d_mu_direct = ( (w_prod * diff_ij / var_j).sum(0) # (B, D) ) if free_bits > 0.0: pass # no MI/TC direct-to-mu gradient in plain-VAE-with-free-bits mode else: d_z_mu += (1.0 / B) * ( -alpha_mi * d_logqz_d_mu_direct + beta_tc * d_logqz_d_mu_direct - beta_tc * d_logqzprod_d_mu_direct + d_logqzprod_d_mu_direct ) d_logqz_d_lv_direct = ( w_qz[:, :, None] * (-0.5 + 0.5 * diff_ij ** 2 / var_j) ).sum(0) # (B, D) d_logqzprod_d_lv_direct = ( w_prod * (-0.5 + 0.5 * diff_ij ** 2 / var_j) ).sum(0) # (B, D) if free_bits > 0.0: # no MI/TC direct-to-logvar gradient in plain-VAE-with-free-bits mode # Analytical free-bits dimKL gradient: only for dims above the floor kl_analytic = 0.5 * (z_mu ** 2 + var - 1.0 - z_logvar) # (B, D) kl_per_dim = kl_analytic.mean(0) # (D,) mask = (kl_per_dim >= free_bits).astype(np.float64) # (D,) — 1 = penalised d_z_mu += mask[None, :] * z_mu / B d_z_logvar += mask[None, :] * 0.5 * (var - 1.0) / B else: d_z_logvar += (1.0 / B) * ( -alpha_mi * d_logqz_d_lv_direct + beta_tc * d_logqz_d_lv_direct - beta_tc * d_logqzprod_d_lv_direct + d_logqzprod_d_lv_direct ) return d_recon, d_z_mu, d_z_logvar
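# ---------------------------------------------------------------------------
# Example: finite-difference check of DWConv1d (illustrative sketch).
# Not part of the original API: a minimal self-check added for documentation,
# assuming only numpy and the DWConv1d class defined above. It compares the
# analytic grad_w produced by the im2col backward against a central-difference
# estimate on a tiny input; the returned value should be close to zero
# (finite-difference noise).
# ---------------------------------------------------------------------------
def _check_dwconv_grad(delta: float = 1e-6) -> float:
    rng = np.random.RandomState(0)
    layer = DWConv1d(C=3, k=5, rng=rng)
    x = rng.randn(2, 3, 11)

    # Scalar loss L = sum(out), so grad_out is all ones.
    out = layer.forward(x)
    layer.backward(np.ones_like(out))
    analytic = layer.w.grad.copy()

    numeric = np.zeros_like(analytic)
    for c in range(layer.C):
        for j in range(layer.k):
            w0 = layer.w.data[c, j]
            layer.w.data[c, j] = w0 + delta
            loss_p = layer.forward(x).sum()
            layer.w.data[c, j] = w0 - delta
            loss_m = layer.forward(x).sum()
            layer.w.data[c, j] = w0
            numeric[c, j] = (loss_p - loss_m) / (2 * delta)
    return float(np.abs(analytic - numeric).max())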
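# ---------------------------------------------------------------------------
# Example: finite-difference check of btcvae_loss_backward w.r.t. z_mu
# (illustrative sketch, not part of the original API). x_recon is held fixed,
# so only the MI/TC/dimKL paths (direct and through z = mu + eps*std) are
# exercised; the reconstruction path is carried by d_recon and is meant to be
# backpropagated through the decoder by the caller. Helper name and tensor
# sizes are arbitrary assumptions.
# ---------------------------------------------------------------------------
def _check_btcvae_mu_grad(delta: float = 1e-5) -> float:
    rng = np.random.RandomState(1)
    B, D = 6, 4
    x = rng.randn(B, 2, 5)
    x_recon = rng.randn(B, 2, 5)       # constant w.r.t. z_mu in this check
    z_mu = 0.5 * rng.randn(B, D)
    z_logvar = 0.1 * rng.randn(B, D)
    eps = rng.randn(B, D)

    def loss_at(mu: np.ndarray) -> float:
        z = mu + eps * np.exp(0.5 * z_logvar)
        return btcvae_loss(x, x_recon, mu, z_logvar, z)[0]

    z0 = z_mu + eps * np.exp(0.5 * z_logvar)
    _, d_mu, _ = btcvae_loss_backward(x, x_recon, z_mu, z_logvar, z0, eps)

    numeric = np.zeros_like(d_mu)
    for i in range(B):
        for d in range(D):
            mu_p = z_mu.copy()
            mu_p[i, d] += delta
            mu_m = z_mu.copy()
            mu_m[i, d] -= delta
            numeric[i, d] = (loss_at(mu_p) - loss_at(mu_m)) / (2 * delta)
    return float(np.abs(d_mu - numeric).max())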
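# ---------------------------------------------------------------------------
# Example: save/load round trip (illustrative sketch, not part of the original
# API). Serialises a model to JSON with ConvNeXtVAE.save, loads it into a
# freshly-initialised model of the same architecture, and checks that the
# (deterministic) encoder output is reproduced. Helper name is an assumption.
# ---------------------------------------------------------------------------
def _save_load_roundtrip_example() -> bool:
    import os
    import tempfile

    rng = np.random.RandomState(0)
    model = ConvNeXtVAE(seed=3)
    x = rng.randn(2, model.C_in, model.T_in)
    mu_before, _ = model.encode(x)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'vae.json')
        model.save(path)
        clone = ConvNeXtVAE(seed=99)   # different init, same architecture
        clone.load(path)

    mu_after, _ = clone.encode(x)
    # encode() draws no noise, so identical weights give identical z_mu
    return bool(np.allclose(mu_before, mu_after))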
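# ---------------------------------------------------------------------------
# Example: one full training step (illustrative sketch, not part of the
# original API). The wiring follows the docstrings above: btcvae_loss_backward
# returns d_recon plus the KL/TC gradients w.r.t. z_mu / z_logvar; the caller
# backprops the decoder to obtain dL/dz from the reconstruction path and folds
# it into z_mu / z_logvar through the reparameterization z = mu + eps*exp(0.5*lv).
# Function name and hyper-parameters are assumptions for illustration only.
# ---------------------------------------------------------------------------
def _train_step_example(model: ConvNeXtVAE, x: np.ndarray, lr: float = 1e-3,
                        beta_tc: float = 4.0, alpha_mi: float = 1.0):
    model.zero_grad()

    # forward (explicit reparameterization so eps is available for backward)
    z_mu, z_logvar = model.encode(x)
    eps = np.random.randn(*z_mu.shape)
    std = np.exp(0.5 * z_logvar)
    z = z_mu + eps * std
    x_recon = model.decode(z)

    loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z,
                             beta_tc=beta_tc, alpha_mi=alpha_mi)
    d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
        x, x_recon, z_mu, z_logvar, z, eps,
        beta_tc=beta_tc, alpha_mi=alpha_mi)

    # reconstruction path: decoder backward gives dL/dz, fold into mu / logvar
    dz = model.backward_decode(d_recon)
    d_z_mu = d_z_mu + dz
    d_z_logvar = d_z_logvar + dz * 0.5 * eps * std

    model.backward_encode(d_z_mu, d_z_logvar)
    model.adam_step(lr)
    return loss, info


if __name__ == '__main__':
    # Smoke test on random data with the default architecture (8 × 32 windows).
    vae = ConvNeXtVAE()
    batch = np.random.randn(4, 8, 32)
    step_loss, step_info = _train_step_example(vae, batch)
    print(f'params={vae.n_params()}  loss={step_loss:.4f}  '
          f'recon={step_info["recon"]:.4f}  TC={step_info["TC"]:.4f}')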