# DOLPHIN/nautilus_dolphin/dvae/convnext_dvae.py
"""
ConvNeXt-1D β-TCVAE, pure numpy (no PyTorch required).
========================================================
Architecture
------------
Input  : (B, C_in=8, T=32)  8 eigenvalue channels × 32 timestep window
Stem   : PointwiseProj(8→32) + LayerNorm
Stage 0: 3 × ConvNeXtBlock1D(32, dw_k=7)
Pool   : AdaptiveAvgPool(16) + PointwiseProj(32→64)
Stage 1: 3 × ConvNeXtBlock1D(64, dw_k=7)
Head   : GlobalAvgPool → Linear(64→32) for z_mu / z_logvar
Decoder (mirrored):
Expand  : Linear(32→64), repeated 16 times along the T axis
Stage 1D: 3 × ConvNeXtBlock1D(64, dw_k=7)
Upsample: RepeatInterleave(×2) + PointwiseProj(64→32)
Stage 0D: 3 × ConvNeXtBlock1D(32, dw_k=7)
Out     : PointwiseProj(32→8)
Objective: β-TCVAE (Chen et al. 2018, "Isolating Sources of Disentanglement")
           β_tc=4.0, α_mi=1.0. Minibatch TC estimator.
Gradient implementation notes
------------------------------
All depthwise convolutions use the im2col trick so that gradients reduce to
standard matmul / einsum; no custom numerical differentiation is required.
DWConv1d forward : windows[b,c,t,k] = x_pad[b,c,t+k]
                   out[b,c,t] = sum_k windows * w[c,k] + b[c]
DWConv1d backward: grad_w[c] = einsum('btk,bt->k', windows_c, grad_out_c) over B
                   col2im for grad_x (accumulate k-shifted copies of grad_out * w)
β-TCVAE TC gradient: computed via the minibatch weight matrix w[i,j] ∝ q(z_i|x_j)
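Example usage
-------------
A minimal end-to-end sketch (batch size and data are illustrative only):

>>> model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=32, seed=0)
>>> x = np.random.RandomState(0).randn(4, 8, 32)
>>> x_recon, z_mu, z_logvar, z = model.forward(x)
>>> loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z)
>>> x_recon.shape
(4, 8, 32)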
"""
from __future__ import annotations
import json
import numpy as np
from scipy.special import erf
try:
from numba import njit as _njit
_NUMBA = True
except ImportError:
    _NUMBA = False
    def _njit(*args, **kwargs):  # no-op fallback; supports both @_njit and @_njit(cache=True)
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        def _decorator(fn):
            return fn
        return _decorator
__all__ = ['ConvNeXtVAE', 'btcvae_loss', 'btcvae_loss_backward']
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
def _kaiming(fan_in: int, fan_out: int, rng: np.random.RandomState) -> np.ndarray:
std = np.sqrt(2.0 / fan_in)
return rng.randn(fan_in, fan_out) * std
def _gelu(x: np.ndarray) -> np.ndarray:
return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
def _gelu_grad(x: np.ndarray) -> np.ndarray:
# Two allocations instead of three: out= avoids the final x*pdf temporary
cdf = 0.5 * (1.0 + erf(x * (1.0 / np.sqrt(2.0))))
pdf = np.exp(-0.5 * x * x)
pdf *= 1.0 / np.sqrt(2.0 * np.pi)
np.multiply(x, pdf, out=pdf) # pdf now holds x * pdf, in-place
pdf += cdf # pdf now holds cdf + x*pdf
return pdf
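# For reference, the identity used in _gelu_grad above: d/dx [x * Φ(x)] = Φ(x) + x * φ(x),
# with Φ the standard normal CDF and φ its density. This matches the exact erf-based
# GELU in _gelu, not the tanh approximation.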
def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
a_max = a.max(axis=axis, keepdims=True)
out = np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True)) + a_max
return out.squeeze(axis)
# ---------------------------------------------------------------------------
# Parameter container — stores weight, gradient, Adam state
# ---------------------------------------------------------------------------
class Param:
__slots__ = ('data', 'grad', '_m', '_v', 't')
def __init__(self, data: np.ndarray):
self.data = data.astype(np.float64)
self.grad = np.zeros_like(self.data)
self._m = np.zeros_like(self.data)
self._v = np.zeros_like(self.data)
self.t = 0
def zero_grad(self):
self.grad.fill(0.0)
def adam_step(self, lr: float, β1: float = 0.9, β2: float = 0.999,
eps: float = 1e-8, wd: float = 0.0):
self.t += 1
g = self.grad
if wd > 0:
g = g + wd * self.data
self._m = β1 * self._m + (1 - β1) * g
self._v = β2 * self._v + (1 - β2) * g * g
m_hat = self._m / (1 - β1 ** self.t)
v_hat = self._v / (1 - β2 ** self.t)
self.data -= lr * m_hat / (np.sqrt(v_hat) + eps)
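# Note: weight decay above is classic L2 regularisation folded into the gradient before
# the Adam moments are updated (i.e. not decoupled AdamW-style decay).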
# ---------------------------------------------------------------------------
# im2col helpers
# ---------------------------------------------------------------------------
def _im2col1d(x: np.ndarray, k: int, pad: int) -> np.ndarray:
"""
x : (B, C, T)
returns: (B, C, T, k) zero-padded windows
Uses np.lib.stride_tricks.as_strided for zero-copy window view,
then immediately makes a contiguous copy so writes to the returned
array never corrupt x_pad.
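A small worked example (shapes and values chosen purely for illustration):
>>> x = np.arange(4.0).reshape(1, 1, 4)
>>> _im2col1d(x, k=3, pad=1)[0, 0]
array([[0., 0., 1.],
       [0., 1., 2.],
       [1., 2., 3.],
       [2., 3., 0.]])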
"""
B, C, T = x.shape
x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
# stride-trick view: step one element along T axis per k-position
shape = (B, C, T, k)
s0, s1, s2 = x_pad.strides
strides = (s0, s1, s2, s2)
windows = np.lib.stride_tricks.as_strided(x_pad, shape=shape, strides=strides)
return windows.copy() # contiguous copy — safe for in-place ops downstream
@_njit(cache=True)
def _col2im1d_kernel(g_pad, grad_win, T, k):
"""Numba-JIT accumulation loop for col2im."""
for i in range(k):
g_pad[:, :, i:i + T] += grad_win[:, :, :, i]
def _col2im1d(grad_win: np.ndarray, T: int, k: int, pad: int) -> np.ndarray:
"""
grad_win: (B, C, T, k)
returns : (B, C, T)
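Adjoint of _im2col1d. A small check with all-ones window gradients (illustrative):
>>> _col2im1d(np.ones((1, 1, 4, 3)), T=4, k=3, pad=1)[0, 0]
array([2., 3., 3., 2.])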
"""
B, C = grad_win.shape[0], grad_win.shape[1]
g_pad = np.zeros((B, C, T + 2 * pad))
_col2im1d_kernel(g_pad, grad_win, T, k)
return g_pad[:, :, pad:pad + T]
# ---------------------------------------------------------------------------
# Layer: DWConv1d (depthwise conv via im2col)
# ---------------------------------------------------------------------------
class DWConv1d:
def __init__(self, C: int, k: int, rng: np.random.RandomState):
self.C, self.k, self.pad = C, k, k // 2
self.w = Param(rng.randn(C, k) * 0.02)
self.b = Param(np.zeros(C))
self._cache: tuple | None = None
def params(self) -> list[Param]:
return [self.w, self.b]
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T)"""
wins = _im2col1d(x, self.k, self.pad) # (B, C, T, k)
self._cache = (wins,)
# einsum with optimize: uses BLAS when possible
out = np.einsum('bctk,ck->bct', wins, self.w.data, optimize=True)
out += self.b.data[None, :, None]
return out
def backward(self, grad_out: np.ndarray) -> np.ndarray:
"""grad_out: (B, C, T) → grad_x: (B, C, T)"""
wins, = self._cache
B, C, T = grad_out.shape
# grad w: einsum over B and T
self.w.grad += np.einsum('bct,bctk->ck', grad_out, wins, optimize=True)
# grad b
self.b.grad += grad_out.sum(axis=(0, 2)) # (C,)
# grad x via col2im
grad_win = np.einsum('bct,ck->bctk', grad_out, self.w.data, optimize=True)
return _col2im1d(grad_win, T, self.k, self.pad)
# ---------------------------------------------------------------------------
# Layer: LayerNorm (over last axis, applied to (B, T, C) or (B, C))
# ---------------------------------------------------------------------------
class LayerNorm:
EPS = 1e-6
def __init__(self, D: int, rng: np.random.RandomState | None = None):
self.D = D
self.gamma = Param(np.ones(D))
self.beta = Param(np.zeros(D))
self._cache: tuple | None = None
def params(self) -> list[Param]:
return [self.gamma, self.beta]
def forward(self, x: np.ndarray) -> np.ndarray:
"""Normalize last axis."""
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_hat = (x - mean) / np.sqrt(var + self.EPS)
self._cache = (x_hat, var)
return self.gamma.data * x_hat + self.beta.data
def backward(self, dy: np.ndarray) -> np.ndarray:
x_hat, var = self._cache
D = x_hat.shape[-1]
orig_shape = dy.shape
dy_2d = dy.reshape(-1, D)
xhat_2d = x_hat.reshape(-1, D)
var_2d = var.reshape(-1, 1)
self.gamma.grad += (dy_2d * xhat_2d).sum(0)
self.beta.grad += dy_2d.sum(0)
dx_hat = dy_2d * self.gamma.data # (N, D)
std_inv = 1.0 / np.sqrt(var_2d + self.EPS) # (N, 1)
dx = std_inv * (dx_hat
- dx_hat.mean(-1, keepdims=True)
- xhat_2d * (dx_hat * xhat_2d).mean(-1, keepdims=True))
return dx.reshape(orig_shape)
# ---------------------------------------------------------------------------
# Layer: Linear
# ---------------------------------------------------------------------------
class Linear:
def __init__(self, in_f: int, out_f: int, rng: np.random.RandomState,
bias: bool = True):
self.W = Param(_kaiming(in_f, out_f, rng))
self.b = Param(np.zeros(out_f)) if bias else None
self._x_cache: np.ndarray | None = None
def params(self) -> list[Param]:
return [self.W] + ([self.b] if self.b is not None else [])
def forward(self, x: np.ndarray) -> np.ndarray:
self._x_cache = x
out = x @ self.W.data
if self.b is not None:
out = out + self.b.data
return out
def backward(self, dy: np.ndarray) -> np.ndarray:
x = self._x_cache
x2 = x.reshape(-1, self.W.data.shape[0])
dy2 = dy.reshape(-1, self.W.data.shape[1])
self.W.grad += x2.T @ dy2
if self.b is not None:
self.b.grad += dy2.sum(0)
return (dy2 @ self.W.data.T).reshape(x.shape)
# ---------------------------------------------------------------------------
# Block: ConvNeXtBlock1D
# ---------------------------------------------------------------------------
# Input/output: (B, C, T)
# 1. DWConv1d(C, k)
# 2. Permute → (B, T, C)
# 3. LayerNorm(C)
# 4. Linear(C → 4C) + GELU
# 5. Linear(4C → C)
# 6. Permute → (B, C, T)
# 7. Skip
class ConvNeXtBlock1D:
def __init__(self, C: int, k: int, rng: np.random.RandomState):
self.dwconv = DWConv1d(C, k, rng)
self.ln = LayerNorm(C)
self.fc1 = Linear(C, 4 * C, rng)
self.fc2 = Linear(4 * C, C, rng)
self._cache: dict | None = None
def params(self) -> list[Param]:
return (self.dwconv.params() + self.ln.params() +
self.fc1.params() + self.fc2.params())
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T)"""
h = self.dwconv.forward(x) # (B, C, T)
h = h.transpose(0, 2, 1) # (B, T, C)
h = self.ln.forward(h) # (B, T, C)
h_fc1 = self.fc1.forward(h) # (B, T, 4C) — pre-GELU
act = _gelu(h_fc1)
h = self.fc2.forward(act) # (B, T, C)
h = h.transpose(0, 2, 1) # (B, C, T)
self._pre_gelu = h_fc1 # cache for backward
return x + h # skip
def backward(self, grad_out: np.ndarray) -> np.ndarray:
"""grad_out: (B, C, T) → grad_x: (B, C, T)"""
# skip: grad flows unchanged to x, AND through the block
grad_block = grad_out.copy()
# 6. permute (B, C, T) → (B, T, C)
grad_block = grad_block.transpose(0, 2, 1)
# 5. fc2 backward → grad w.r.t. GELU output (act), shape (B, T, 4C)
grad_block = self.fc2.backward(grad_block)
# 4. GELU backward (pre_gelu cached during forward)
grad_block = grad_block * _gelu_grad(self._pre_gelu)
# 4. fc1 backward
grad_block = self.fc1.backward(grad_block)
# 3. LN backward
grad_block = self.ln.backward(grad_block)
# 2. permute (B, T, C) → (B, C, T)
grad_block = grad_block.transpose(0, 2, 1)
# 1. DWConv backward
grad_block = self.dwconv.backward(grad_block)
return grad_out + grad_block # skip: grad_x = grad_out + grad_block
# ---------------------------------------------------------------------------
# Pool: AvgPool1D (fixed-stride average pooling, e.g. T=32→T=16 with stride=2)
# ---------------------------------------------------------------------------
class AvgPool1D:
"""Stride-2 average pooling over temporal axis."""
def __init__(self, stride: int = 2):
self.stride = stride
self._T_in: int | None = None
def params(self) -> list[Param]:
return []
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C, T) → (B, C, T//stride)"""
B, C, T = x.shape
self._T_in = T
T_out = T // self.stride
# reshape and mean
return x[:, :, :T_out * self.stride].reshape(B, C, T_out, self.stride).mean(-1)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
B, C, T_out = grad_out.shape
T_in = self._T_in
# each output contributes 1/stride to each input in the pool
g = grad_out[:, :, :, None] * np.ones((1, 1, 1, self.stride)) / self.stride
g = g.reshape(B, C, T_out * self.stride)
if T_in > T_out * self.stride:
g = np.pad(g, ((0, 0), (0, 0), (0, T_in - T_out * self.stride)))
return g
# ---------------------------------------------------------------------------
# Upsample: nearest-neighbor ×2 along temporal axis
# ---------------------------------------------------------------------------
class Upsample1D:
"""Nearest-neighbor upsampling ×factor along temporal axis."""
def __init__(self, factor: int = 2):
self.factor = factor
def params(self) -> list[Param]:
return []
def forward(self, x: np.ndarray) -> np.ndarray:
return np.repeat(x, self.factor, axis=-1)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
f = self.factor
B, C, T = grad_out.shape
T_in = T // f
return grad_out.reshape(B, C, T_in, f).sum(-1)
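# Summing (rather than averaging) over the f repeated positions above is the exact
# adjoint of the nearest-neighbour np.repeat used in forward().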
# ---------------------------------------------------------------------------
# PointwiseProj: 1×1 conv (= Linear applied per timestep)
# ---------------------------------------------------------------------------
class PointwiseProj:
"""(B, C_in, T) → (B, C_out, T) via shared Linear(C_in, C_out)."""
def __init__(self, C_in: int, C_out: int, rng: np.random.RandomState):
self.lin = Linear(C_in, C_out, rng)
def params(self) -> list[Param]:
return self.lin.params()
def forward(self, x: np.ndarray) -> np.ndarray:
"""x: (B, C_in, T)"""
B, C, T = x.shape
x_t = x.transpose(0, 2, 1) # (B, T, C_in)
out = self.lin.forward(x_t) # (B, T, C_out)
return out.transpose(0, 2, 1) # (B, C_out, T)
def backward(self, grad_out: np.ndarray) -> np.ndarray:
g_t = grad_out.transpose(0, 2, 1) # (B, T, C_out)
g_t = self.lin.backward(g_t) # (B, T, C_in)
return g_t.transpose(0, 2, 1) # (B, C_in, T)
# ---------------------------------------------------------------------------
# Main model: ConvNeXtVAE
# ---------------------------------------------------------------------------
class ConvNeXtVAE:
"""
ConvNeXt-1D β-TCVAE.
Parameters
----------
C_in : input channels (default 8)
T_in : input timesteps (default 32)
z_dim : latent dimensionality (default 32)
base_ch : stem channel width (default 32)
dw_k : depthwise kernel size (default 7)
n_blocks : ConvNeXt blocks per stage (default 3)
seed : RNG seed
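Examples
--------
A minimal shape check (batch size and data are illustrative):
>>> model = ConvNeXtVAE(seed=0)
>>> x = np.random.RandomState(0).randn(4, 8, 32)
>>> z_mu, z_logvar = model.encode(x)
>>> z_mu.shape, z_logvar.shape
((4, 32), (4, 32))
>>> model.decode(z_mu).shape
(4, 8, 32)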
"""
def __init__(
self,
C_in: int = 8,
T_in: int = 32,
z_dim: int = 32,
base_ch: int = 32,
dw_k: int = 7,
n_blocks: int = 3,
seed: int = 42,
):
rng = np.random.RandomState(seed)
self.C_in = C_in
self.T_in = T_in
self.z_dim = z_dim
self.base_ch = base_ch
ch1 = base_ch * 2 # 64
# ---- ENCODER ----
# Stem
self.stem_proj = PointwiseProj(C_in, base_ch, rng)
self.stem_ln = LayerNorm(base_ch)
# Stage 0
self.stage0 = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]
# Pool + channel proj
self.pool = AvgPool1D(stride=2)
self.pool_ln = LayerNorm(base_ch)
self.pool_proj = PointwiseProj(base_ch, ch1, rng)
# Stage 1
self.stage1 = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]
# Global avg pool → z_mu, z_logvar
self.head_mu = Linear(ch1, z_dim, rng)
self.head_logvar = Linear(ch1, z_dim, rng)
# ---- DECODER ----
T_bot = T_in // 2 # bottleneck temporal size (16 for T_in=32)
self.dec_linear = Linear(z_dim, ch1, rng)
# Stage 1D
self.stage1d = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]
# Upsample + channel proj
self.up = Upsample1D(factor=2)
self.up_proj = PointwiseProj(ch1, base_ch, rng)
self.up_ln = LayerNorm(base_ch)
# Stage 0D
self.stage0d = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]
# Output projection
self.out_proj = PointwiseProj(base_ch, C_in, rng)
# Cache for backward
self._enc_cache: dict = {}
self._dec_cache: dict = {}
self._T_bot = T_bot
# ------------------------------------------------------------------
def all_params(self) -> list[Param]:
ps = []
ps += self.stem_proj.params() + self.stem_ln.params()
for blk in self.stage0:
ps += blk.params()
ps += (self.pool_ln.params() + self.pool_proj.params())
for blk in self.stage1:
ps += blk.params()
ps += self.head_mu.params() + self.head_logvar.params()
ps += self.dec_linear.params()
for blk in self.stage1d:
ps += blk.params()
ps += self.up_proj.params() + self.up_ln.params()
for blk in self.stage0d:
ps += blk.params()
ps += self.out_proj.params()
return ps
def zero_grad(self):
for p in self.all_params():
p.zero_grad()
def adam_step(self, lr: float, wd: float = 1e-4):
for p in self.all_params():
p.adam_step(lr, wd=wd)
def n_params(self) -> int:
return sum(p.data.size for p in self.all_params())
# ------------------------------------------------------------------
# ENCODER
# ------------------------------------------------------------------
def encode(self, x: np.ndarray):
"""
x: (B, C_in, T_in)
Returns z_mu (B,z), z_logvar (B,z)
"""
c = self._enc_cache
c['x'] = x
# Stem
h = self.stem_proj.forward(x) # (B, base_ch, T)
h = h.transpose(0, 2, 1) # (B, T, base_ch)
h = self.stem_ln.forward(h) # (B, T, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T)
c['after_stem'] = h
# Stage 0
for blk in self.stage0:
h = blk.forward(h)
c['after_stage0'] = h
# Pool + proj
h = self.pool.forward(h) # (B, base_ch, T//2)
h = h.transpose(0, 2, 1) # (B, T//2, base_ch)
h = self.pool_ln.forward(h) # (B, T//2, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T//2)
h = self.pool_proj.forward(h) # (B, ch1, T//2)
c['after_pool'] = h
# Stage 1
for blk in self.stage1:
h = blk.forward(h)
c['after_stage1'] = h
# Global avg pool
h_gap = h.mean(axis=-1) # (B, ch1)
c['gap'] = h_gap
z_mu = self.head_mu.forward(h_gap) # (B, z_dim)
z_logvar = self.head_logvar.forward(h_gap) # (B, z_dim)
return z_mu, z_logvar
# ------------------------------------------------------------------
# DECODER
# ------------------------------------------------------------------
def decode(self, z: np.ndarray) -> np.ndarray:
"""
z: (B, z_dim)
Returns x_recon: (B, C_in, T_in)
"""
c = self._dec_cache
T_bot = self._T_bot
h = self.dec_linear.forward(z) # (B, ch1)
h = h[:, :, None] * np.ones((1, 1, T_bot)) # broadcast (B, ch1, T_bot)
c['dec_expand'] = h
for blk in self.stage1d:
h = blk.forward(h)
c['after_stage1d'] = h
# Upsample + proj
h = self.up.forward(h) # (B, ch1, T_in)
h = self.up_proj.forward(h) # (B, base_ch, T_in)
h = h.transpose(0, 2, 1) # (B, T_in, base_ch)
h = self.up_ln.forward(h) # (B, T_in, base_ch)
h = h.transpose(0, 2, 1) # (B, base_ch, T_in)
c['after_up'] = h
for blk in self.stage0d:
h = blk.forward(h)
x_recon = self.out_proj.forward(h) # (B, C_in, T_in)
return x_recon
# ------------------------------------------------------------------
# BACKWARD (encoder)
# ------------------------------------------------------------------
def backward_encode(self, dz_mu: np.ndarray, dz_logvar: np.ndarray):
"""
Given grad w.r.t. z_mu and z_logvar, backprop through encoder.
Returns grad w.r.t. x (not used further, just for completeness).
"""
c = self._enc_cache
# head
d_gap = self.head_mu.backward(dz_mu) + self.head_logvar.backward(dz_logvar)
# global avg pool backward: distribute evenly over T axis
h_after = c['after_stage1']
T_bot = h_after.shape[-1]
d_h = d_gap[:, :, None] / T_bot * np.ones((1, 1, T_bot)) # (B, ch1, T_bot)
# stage 1 backward
for blk in reversed(self.stage1):
d_h = blk.backward(d_h)
# pool_proj backward
d_h = self.pool_proj.backward(d_h) # (B, base_ch, T//2)
# pool_ln backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.pool_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
# pool backward
d_h = self.pool.backward(d_h) # (B, base_ch, T)
# stage 0 backward
for blk in reversed(self.stage0):
d_h = blk.backward(d_h)
# stem backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.stem_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
d_h = self.stem_proj.backward(d_h)
return d_h # grad w.r.t. input x
# ------------------------------------------------------------------
# BACKWARD (decoder)
# ------------------------------------------------------------------
def backward_decode(self, d_recon: np.ndarray):
"""
d_recon: (B, C_in, T_in) grad w.r.t. x_recon
Returns grad w.r.t. z.
"""
c = self._dec_cache
T_bot = self._T_bot
# out_proj backward
d_h = self.out_proj.backward(d_recon) # (B, base_ch, T_in)
# stage 0d backward
for blk in reversed(self.stage0d):
d_h = blk.backward(d_h)
# up_ln backward
d_h = d_h.transpose(0, 2, 1)
d_h = self.up_ln.backward(d_h)
d_h = d_h.transpose(0, 2, 1)
# up_proj backward
d_h = self.up_proj.backward(d_h) # (B, ch1, T_in)
# upsample backward
d_h = self.up.backward(d_h) # (B, ch1, T_bot)
# stage 1d backward
for blk in reversed(self.stage1d):
d_h = blk.backward(d_h)
# expand backward: sum over temporal axis (it was broadcast)
d_h_expand = d_h.sum(axis=-1) # (B, ch1)
# dec_linear backward
dz = self.dec_linear.backward(d_h_expand) # (B, z_dim)
return dz
# ------------------------------------------------------------------
# FORWARD convenience
# ------------------------------------------------------------------
def forward(self, x: np.ndarray):
"""Full forward: encode + reparameterize + decode."""
z_mu, z_logvar = self.encode(x)
eps = np.random.randn(*z_mu.shape)
z = z_mu + eps * np.exp(0.5 * z_logvar)
x_recon = self.decode(z)
return x_recon, z_mu, z_logvar, z
# ------------------------------------------------------------------
# SAVE / LOAD
# ------------------------------------------------------------------
def save(self, path: str, norm_mean=None, norm_std=None,
extra: dict | None = None, save_adam: bool = False):
data: dict = {}
params = self.all_params()
names = self._param_names()
for name, p in zip(names, params):
data[name] = p.data.tolist()
if save_adam:
data[f'{name}__m'] = p._m.tolist()
data[f'{name}__v'] = p._v.tolist()
data[f'{name}__t'] = int(p.t)
if norm_mean is not None:
data['norm_mean'] = norm_mean.tolist()
data['norm_std'] = norm_std.tolist()
data['architecture'] = {
'C_in': self.C_in, 'T_in': self.T_in, 'z_dim': self.z_dim,
'base_ch': self.base_ch,
}
if extra:
data.update(extra)
with open(path, 'w') as f:
json.dump(data, f)
def load(self, path: str) -> dict:
with open(path) as f:
data = json.load(f)
params = self.all_params()
names = self._param_names()
reserved = set(names)
for name, p in zip(names, params):
if name in data:
p.data[:] = np.array(data[name], dtype=np.float64)
if f'{name}__m' in data:
p._m[:] = np.array(data[f'{name}__m'], dtype=np.float64)
p._v[:] = np.array(data[f'{name}__v'], dtype=np.float64)
p.t = int(data[f'{name}__t'])
reserved.update({f'{name}__m', f'{name}__v', f'{name}__t'})
extras = {k: v for k, v in data.items()
if k not in reserved and k != 'architecture'}
return extras
def _param_names(self) -> list[str]:
"""Return stable flat list of parameter name strings matching all_params()."""
names = []
# stem
for i, _ in enumerate(self.stem_proj.params()):
names.append(f'stem_proj_p{i}')
for i, _ in enumerate(self.stem_ln.params()):
names.append(f'stem_ln_p{i}')
# stage0
for bi, blk in enumerate(self.stage0):
for i, _ in enumerate(blk.params()):
names.append(f'stage0_b{bi}_p{i}')
# pool
for i, _ in enumerate(self.pool_ln.params()):
names.append(f'pool_ln_p{i}')
for i, _ in enumerate(self.pool_proj.params()):
names.append(f'pool_proj_p{i}')
# stage1
for bi, blk in enumerate(self.stage1):
for i, _ in enumerate(blk.params()):
names.append(f'stage1_b{bi}_p{i}')
# head
for i, _ in enumerate(self.head_mu.params()):
names.append(f'head_mu_p{i}')
for i, _ in enumerate(self.head_logvar.params()):
names.append(f'head_logvar_p{i}')
# dec_linear
for i, _ in enumerate(self.dec_linear.params()):
names.append(f'dec_linear_p{i}')
# stage1d
for bi, blk in enumerate(self.stage1d):
for i, _ in enumerate(blk.params()):
names.append(f'stage1d_b{bi}_p{i}')
# up
for i, _ in enumerate(self.up_proj.params()):
names.append(f'up_proj_p{i}')
for i, _ in enumerate(self.up_ln.params()):
names.append(f'up_ln_p{i}')
# stage0d
for bi, blk in enumerate(self.stage0d):
for i, _ in enumerate(blk.params()):
names.append(f'stage0d_b{bi}_p{i}')
# out_proj
for i, _ in enumerate(self.out_proj.params()):
names.append(f'out_proj_p{i}')
return names
# ---------------------------------------------------------------------------
# β-TCVAE loss
# ---------------------------------------------------------------------------
def btcvae_loss(
x: np.ndarray,
x_recon: np.ndarray,
z_mu: np.ndarray,
z_logvar: np.ndarray,
z: np.ndarray,
beta_tc: float = 4.0,
alpha_mi: float = 1.0,
free_bits: float = 0.0,
) -> tuple[float, dict]:
"""
β-TCVAE loss with minibatch TC estimator.
Parameters
----------
x : (B, C, T) input
x_recon : (B, C, T) reconstruction
z_mu : (B, D)
z_logvar : (B, D)
z : (B, D) reparameterized sample
beta_tc : weight on Total Correlation term
alpha_mi : weight on Mutual Information term
free_bits : minimum KL per latent dimension (nats). Replaces the minibatch
dimKL term with analytical per-dim KL clamped from below at
free_bits. Prevents posterior collapse. 0 = standard behaviour.
Returns
-------
loss : scalar
info : dict with breakdown (recon, MI, TC, dimKL)
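Examples
--------
A standalone call on synthetic tensors (shapes and values are illustrative only):
>>> rng = np.random.RandomState(0)
>>> B, C, T, D = 4, 8, 32, 32
>>> x, x_recon = rng.randn(B, C, T), rng.randn(B, C, T)
>>> z_mu, z_logvar = rng.randn(B, D), 0.1 * rng.randn(B, D)
>>> z = z_mu + rng.randn(B, D) * np.exp(0.5 * z_logvar)
>>> loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z)
>>> sorted(info)
['MI', 'TC', 'dimKL', 'recon']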
"""
B, D = z.shape
# Reconstruction loss (per-sample mean, then mean over batch)
recon = ((x - x_recon) ** 2).mean(axis=(-1, -2)).mean()
# log q(z|x): (B,)
log_qz_x = -0.5 * (
np.log(2 * np.pi) + z_logvar + (z - z_mu) ** 2 * np.exp(-z_logvar)
).sum(axis=1)
# log p(z): (B,) — standard Gaussian
log_pz = -0.5 * (np.log(2 * np.pi) + z ** 2).sum(axis=1)
# Minibatch estimator for log q(z) and log prod_d q(z_d)
z_i = z[:, None, :] # (B, 1, D)
mu_j = z_mu[None, :, :] # (1, B, D)
lv_j = z_logvar[None, :, :] # (1, B, D)
# (B, B, D)
log_q_ij_d = -0.5 * (
np.log(2 * np.pi) + lv_j + (z_i - mu_j) ** 2 * np.exp(-lv_j)
)
# log q(z_i) = logsumexp_j[sum_d log_q_ij_d] - log B
log_qz = _logsumexp(log_q_ij_d.sum(-1), axis=1) - np.log(B) # (B,)
# log prod_d q(z_d[i]) = sum_d logsumexp_j[log_q_ij_d] - D*log(B)
log_qz_prod = (_logsumexp(log_q_ij_d, axis=1) - np.log(B)).sum(-1) # (B,)
MI = (log_qz_x - log_qz).mean()
TC = (log_qz - log_qz_prod).mean()
if free_bits > 0.0:
# Analytical per-dim KL, clamped: sum_d max(free_bits, KL_d)
kl_analytic = 0.5 * (z_mu ** 2 + np.exp(z_logvar) - 1.0 - z_logvar)
kl_per_dim = kl_analytic.mean(0) # (D,)
dimKL = float(np.maximum(free_bits, kl_per_dim).sum())
else:
dimKL = float((log_qz_prod - log_pz).mean())
loss = recon + alpha_mi * MI + beta_tc * TC + dimKL
return float(loss), {
'recon': float(recon),
'MI': float(MI),
'TC': float(TC),
'dimKL': dimKL,
}
# ---------------------------------------------------------------------------
# β-TCVAE backward
# ---------------------------------------------------------------------------
def btcvae_loss_backward(
x: np.ndarray,
x_recon: np.ndarray,
z_mu: np.ndarray,
z_logvar: np.ndarray,
z: np.ndarray,
eps: np.ndarray,
beta_tc: float = 4.0,
alpha_mi: float = 1.0,
free_bits: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute gradients of β-TCVAE loss w.r.t. x_recon, z_mu, z_logvar.
Parameters
----------
eps: (B, D) noise used in reparameterisation: z = z_mu + eps*exp(0.5*z_logvar)
Returns
-------
d_recon : (B, C, T) grad w.r.t. x_recon
d_z_mu : (B, D) grad w.r.t. z_mu
d_z_logvar: (B, D) grad w.r.t. z_logvar
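Examples
--------
A sketch (not taken from the training code) of how these gradients can be combined
with the model's manual backward passes through the reparameterization split:
>>> model = ConvNeXtVAE(seed=0)
>>> rng = np.random.RandomState(0)
>>> x = rng.randn(4, 8, 32)
>>> model.zero_grad()
>>> z_mu, z_logvar = model.encode(x)
>>> eps = rng.randn(*z_mu.shape)
>>> z = z_mu + eps * np.exp(0.5 * z_logvar)
>>> x_recon = model.decode(z)
>>> d_recon, d_mu, d_lv = btcvae_loss_backward(x, x_recon, z_mu, z_logvar, z, eps)
>>> dz = model.backward_decode(d_recon)              # dL/dz from the decoder path
>>> d_mu = d_mu + dz                                 # dz/d z_mu = 1
>>> d_lv = d_lv + dz * 0.5 * eps * np.exp(0.5 * z_logvar)
>>> _ = model.backward_encode(d_mu, d_lv)
>>> model.adam_step(lr=1e-3)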
"""
B, D = z.shape
var = np.exp(z_logvar) # (B, D)
std = np.exp(0.5 * z_logvar) # (B, D)
# ---- grad w.r.t. x_recon (reconstruction MSE) ----
# L_recon = mean_{b,c,t} (x-x_recon)^2 = (1/(B*C*T)) * ||x - x_recon||^2
# dL/dx_recon = -2*(x - x_recon) / (B*C*T)
n_elements = x.size
d_recon = -2.0 * (x - x_recon) / n_elements # shape (B, C, T)
# ---- KL terms via minibatch estimator ----
# log q(z_i | x_i) (B,D) terms:
err = (z - z_mu) # (B, D)
log_qz_x_d = -0.5 * (np.log(2 * np.pi) + z_logvar + err ** 2 / var)
# Minibatch terms
z_i = z[:, None, :]
mu_j = z_mu[None, :, :]
lv_j = z_logvar[None, :, :]
diff_ij = z_i - mu_j # (B, B, D)
var_j = np.exp(lv_j) # (B, B, D)
log_q_ij_d = -0.5 * (np.log(2 * np.pi) + lv_j + diff_ij ** 2 / var_j) # (B,B,D)
# Weight matrices for log_qz and log_qz_prod
# w_qz[i, j] = softmax_j[ sum_d log_q_ij_d[i, j, :] ]
lq_sum = log_q_ij_d.sum(-1) # (B, B)
lq_sum_max = lq_sum.max(1, keepdims=True)
w_qz = np.exp(lq_sum - lq_sum_max)
w_qz /= w_qz.sum(1, keepdims=True) + 1e-30 # (B, B) softmax over j
# w_prod[i, j, d] = softmax_j[ log_q_ij_d[i, j, d] ] for each d separately
lq_max_d = log_q_ij_d.max(1, keepdims=True) # (B, 1, D)
w_prod = np.exp(log_q_ij_d - lq_max_d)
w_prod /= w_prod.sum(1, keepdims=True) + 1e-30 # (B, B, D)
# ---- d(MI + TC + dimKL) / d(z_logvar) and d(z_mu) ----
# MI = mean_i [ log_qz_x[i] - log_qz[i] ]
# TC = mean_i [ log_qz[i] - log_qz_prod[i] ]
# dimKL = mean_i [ log_qz_prod[i] - log_pz[i] ]
#
# Combined coefficient of each log term:
# d/dtheta = (1/B) * [ alpha * d_log_qz_x - alpha * d_log_qz
# + beta * d_log_qz - beta * d_log_qz_prod
# + d_log_qz_prod - d_log_pz ]
# Per-term derivatives (err = z - mu, var = exp(z_logvar)):
#   d log_qz_x[i] / d z_mu[i,d]      = err[i,d] / var[i,d]
#   d log_qz_x[i] / d z_logvar[i,d]  = -0.5 + 0.5 * err[i,d]^2 / var[i,d]
#   d log_qz[i] / d z_mu[j,d]        = w_qz[i,j] * diff_ij[i,j,d] / var_j[i,j,d]
#     (aggregated over i by summing; this path runs through the density
#      q(z_i|x_j), not through z itself)
#   d log_qz_prod[i,d] / d z_mu[j,d] = w_prod[i,j,d] * diff_ij[i,j,d] / var_j[i,j,d]
# Additionally, z = mu + eps*std, so every log term also has a DIRECT path via z_i:
#   d log_qz[i] / d z[i,d] = sum_j w_qz[i,j] * (-(z[i,d] - mu_j[d]) / var_j[i,j,d])
#   d z[i,d] / d mu[i,d]   = 1
# Rather than expanding each coefficient by hand, we compute dL/dz once and split it
# via the reparameterization (next section), then add the direct density terms w.r.t.
# z_mu and z_logvar. Note that for z_mu the direct log q(z|x) term (+err/var) and its
# through-z counterpart (-err/var) cancel; for z_logvar they do not.
# ---- Gradient through reparameterization ----
# L is a function of z_mu, z_logvar, and z = z_mu + eps*std.
# dL/d_mu[i,d] = dL/d_z[i,d] * 1 + (dL/d_mu directly)
# dL/d_logvar[i,d]= dL/d_z[i,d] * 0.5*eps[i,d]*std[i,d] + (dL/d_logvar directly)
# Compute dL/d_z[i,d]:
# 1. Reconstruction path (decoder): not handled here. We return d_recon so the caller
#    can backprop the decoder separately and fold its dL/dz contribution into
#    d_z_mu / d_z_logvar. Below we only handle the MI/TC/dimKL terms, which depend on z.
# 2. From log_qz_x[i] = sum_d log_qz_x_d[i,d]:
# d log_qz_x[i] / d z[i,d] = -(z[i,d] - mu[i,d]) / var[i,d] = -err/var
d_logqzx_d_z = -err / var # (B, D)
# 3. From log_qz[i] via z: (z[i,d] enters log_qz[i] through diff_ij[i,:,d])
# d log_qz[i] / d z[i, d] = sum_j w_qz[i,j] * (- diff_ij[i,j,d] / var_j[i,j,d])
# (diff_ij[i,j,d] = z[i,d] - mu_j[d], so d(diff_ij[i,j,d]) / d(z[i,d]) = 1)
d_logqz_d_z = -(w_qz[:, :, None] * diff_ij / var_j).sum(1) # (B, D)
# 4. From log_qz_prod[i] via z:
# d log_qz_prod[i] / d z[i,d] = sum_j w_prod[i,j,d] * (-diff_ij[i,j,d]/var_j[i,j,d])
d_logqzprod_d_z = -(w_prod * diff_ij / var_j).sum(1) # (B, D)
# 5. From log_pz: d log_pz / d z = -z
d_logpz_d_z = -z # (B, D)
# Combine with coefficients:
# free_bits > 0 : skip ALL minibatch MI/TC aggregate-posterior terms (to prevent
#   posterior collapse); only the analytical per-dim KL (masked, applied further
#   below) and the direct encoder term (alpha_mi * err/var below) remain. This is
#   effectively a plain VAE with free bits.
# free_bits == 0: full β-TCVAE with the minibatch MI/TC/dimKL decomposition.
if free_bits > 0.0:
dL_dz = np.zeros((B, D), dtype=np.float64) # no MI/TC reparameterization gradient
else:
dL_dz = (1.0 / B) * (
alpha_mi * d_logqzx_d_z
- alpha_mi * d_logqz_d_z
+ beta_tc * d_logqz_d_z
- beta_tc * d_logqzprod_d_z
+ d_logqzprod_d_z
- d_logpz_d_z
) # (B, D)
# Gradient through z_mu and z_logvar VIA z (reparameterization):
d_z_mu = dL_dz.copy()
d_z_logvar = dL_dz * 0.5 * eps * std
# Direct gradient w.r.t. z_mu from log q(z|x):
d_z_mu += (alpha_mi / B) * err / var
# Direct gradient w.r.t. z_logvar from log q(z|x):
d_z_logvar += (alpha_mi / B) * (-0.5 + 0.5 * err ** 2 / var)
# Gradient of log_qz and log_qz_prod w.r.t. z_mu[j] directly
d_logqz_d_mu_direct = (
(w_qz[:, :, None] * diff_ij / var_j).sum(0) # (B, D)
)
d_logqzprod_d_mu_direct = (
(w_prod * diff_ij / var_j).sum(0) # (B, D)
)
if free_bits > 0.0:
pass # no MI/TC direct-to-mu gradient in plain-VAE-with-free-bits mode
else:
d_z_mu += (1.0 / B) * (
-alpha_mi * d_logqz_d_mu_direct
+ beta_tc * d_logqz_d_mu_direct
- beta_tc * d_logqzprod_d_mu_direct
+ d_logqzprod_d_mu_direct
)
d_logqz_d_lv_direct = (
w_qz[:, :, None] * (-0.5 + 0.5 * diff_ij ** 2 / var_j)
).sum(0) # (B, D)
d_logqzprod_d_lv_direct = (
w_prod * (-0.5 + 0.5 * diff_ij ** 2 / var_j)
).sum(0) # (B, D)
if free_bits > 0.0:
# no MI/TC direct-to-logvar gradient in plain-VAE-with-free-bits mode
# Analytical free-bits dimKL gradient: only for dims above the floor
kl_analytic = 0.5 * (z_mu ** 2 + var - 1.0 - z_logvar) # (B, D)
kl_per_dim = kl_analytic.mean(0) # (D,)
mask = (kl_per_dim >= free_bits).astype(np.float64) # (D,) — 1 = penalised
d_z_mu += mask[None, :] * z_mu / B
d_z_logvar += mask[None, :] * 0.5 * (var - 1.0) / B
else:
d_z_logvar += (1.0 / B) * (
-alpha_mi * d_logqz_d_lv_direct
+ beta_tc * d_logqz_d_lv_direct
- beta_tc * d_logqzprod_d_lv_direct
+ d_logqzprod_d_lv_direct
)
return d_recon, d_z_mu, d_z_logvar