1049 lines
39 KiB
Python
1049 lines
39 KiB
Python
|
|
"""
|
|||
|
|
ConvNeXt-1D β-TCVAE — pure numpy (no PyTorch required).
|
|||
|
|
========================================================
|
|||
|
|
|
|||
|
|
Architecture
|
|||
|
|
------------
|
|||
|
|
Input : (B, C_in=8, T=32) — 8 eigenvalue channels × 32 timestep window
|
|||
|
|
Stem : PointwiseProj(8→32) + LayerNorm
|
|||
|
|
Stage 0: 3 × ConvNeXtBlock1D(32, dw_k=7)
|
|||
|
|
Pool : AdaptiveAvgPool(16) + PointwiseProj(32→64)
|
|||
|
|
Stage 1: 3 × ConvNeXtBlock1D(64, dw_k=7)
|
|||
|
|
Head : GlobalAvgPool → Linear(64→32) for z_mu / z_logvar
|
|||
|
|
|
|||
|
|
Decoder (mirrored):
|
|||
|
|
Expand : Linear(32→64) → repeat 16 times along T axis
|
|||
|
|
Stage 1D: 3 × ConvNeXtBlock1D(64, dw_k=7)
|
|||
|
|
Upsample: RepeatInterleave(×2) + PointwiseProj(64→32)
|
|||
|
|
Stage 0D: 3 × ConvNeXtBlock1D(32, dw_k=7)
|
|||
|
|
Out : PointwiseProj(32→8)
|
|||
|
|
|
|||
|
|
Objective: β-TCVAE (Chen et al. 2018, "Isolating Sources of Disentanglement")
|
|||
|
|
β_tc=4.0, α_mi=1.0. Minibatch TC estimator.
|
|||
|
|
|
|||
|
|
Gradient implementation notes
|
|||
|
|
------------------------------
|
|||
|
|
All depthwise convolutions use the im2col trick so that gradients reduce to
|
|||
|
|
standard matmul / einsum — no custom numerical differentiation required.
|
|||
|
|
|
|||
|
|
DWConv1d forward : windows[b,c,t,k] = x_pad[b,c,t+k]
|
|||
|
|
out[b,c,t] = sum_k windows * w[c,k] + b[c]
|
|||
|
|
DWConv1d backward: grad_w[c] = einsum('btk,bt->k', windows_c, grad_out_c) over B
|
|||
|
|
col2im for grad_x (accumulate k-shifted copies of grad_out * w)
|
|||
|
|
|
|||
|
|
β-TCVAE TC gradient: computed via minibatch weight matrix w[i,j] ∝ q(z_i|x_j)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
import json
|
|||
|
|
import numpy as np
|
|||
|
|
from scipy.special import erf
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from numba import njit as _njit
|
|||
|
|
_NUMBA = True
|
|||
|
|
except ImportError:
|
|||
|
|
_NUMBA = False
|
|||
|
|
def _njit(fn): # no-op decorator
|
|||
|
|
return fn
|
|||
|
|
|
|||
|
|
__all__ = ['ConvNeXtVAE', 'btcvae_loss', 'btcvae_loss_backward']
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Utility
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
def _kaiming(fan_in: int, fan_out: int, rng: np.random.RandomState) -> np.ndarray:
|
|||
|
|
std = np.sqrt(2.0 / fan_in)
|
|||
|
|
return rng.randn(fan_in, fan_out) * std
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _gelu(x: np.ndarray) -> np.ndarray:
|
|||
|
|
return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _gelu_grad(x: np.ndarray) -> np.ndarray:
|
|||
|
|
# Two allocations instead of three: out= avoids the final x*pdf temporary
|
|||
|
|
cdf = 0.5 * (1.0 + erf(x * (1.0 / np.sqrt(2.0))))
|
|||
|
|
pdf = np.exp(-0.5 * x * x)
|
|||
|
|
pdf *= 1.0 / np.sqrt(2.0 * np.pi)
|
|||
|
|
np.multiply(x, pdf, out=pdf) # pdf now holds x * pdf, in-place
|
|||
|
|
pdf += cdf # pdf now holds cdf + x*pdf
|
|||
|
|
return pdf
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
|
|||
|
|
a_max = a.max(axis=axis, keepdims=True)
|
|||
|
|
out = np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True)) + a_max
|
|||
|
|
return out.squeeze(axis)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Parameter container — stores weight, gradient, Adam state
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class Param:
    """Trainable tensor bundling its value, gradient, and Adam state."""

    __slots__ = ('data', 'grad', '_m', '_v', 't')

    def __init__(self, data: np.ndarray):
        self.data = data.astype(np.float64)   # master copy, always float64
        self.grad = np.zeros_like(self.data)  # accumulated gradient
        self._m = np.zeros_like(self.data)    # Adam first-moment estimate
        self._v = np.zeros_like(self.data)    # Adam second-moment estimate
        self.t = 0                            # Adam step counter

    def zero_grad(self):
        """Reset the accumulated gradient in place."""
        self.grad.fill(0.0)

    def adam_step(self, lr: float, β1: float = 0.9, β2: float = 0.999,
                  eps: float = 1e-8, wd: float = 0.0):
        """Apply one Adam update; wd > 0 folds L2 decay into the gradient."""
        self.t += 1
        g = self.grad if wd <= 0 else self.grad + wd * self.data
        self._m = β1 * self._m + (1 - β1) * g
        self._v = β2 * self._v + (1 - β2) * g * g
        # bias-corrected moment estimates
        m_hat = self._m / (1 - β1 ** self.t)
        v_hat = self._v / (1 - β2 ** self.t)
        self.data -= lr * m_hat / (np.sqrt(v_hat) + eps)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# im2col helpers
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
def _im2col1d(x: np.ndarray, k: int, pad: int) -> np.ndarray:
|
|||
|
|
"""
|
|||
|
|
x : (B, C, T)
|
|||
|
|
returns: (B, C, T, k) — zero-padded windows
|
|||
|
|
|
|||
|
|
Uses np.lib.stride_tricks.as_strided for zero-copy window view,
|
|||
|
|
then immediately makes a contiguous copy so writes to the returned
|
|||
|
|
array never corrupt x_pad.
|
|||
|
|
"""
|
|||
|
|
B, C, T = x.shape
|
|||
|
|
x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
|
|||
|
|
# stride-trick view: step one element along T axis per k-position
|
|||
|
|
shape = (B, C, T, k)
|
|||
|
|
s0, s1, s2 = x_pad.strides
|
|||
|
|
strides = (s0, s1, s2, s2)
|
|||
|
|
windows = np.lib.stride_tricks.as_strided(x_pad, shape=shape, strides=strides)
|
|||
|
|
return windows.copy() # contiguous copy — safe for in-place ops downstream
|
|||
|
|
|
|||
|
|
|
|||
|
|
@_njit(cache=True)
def _col2im1d_kernel(g_pad, grad_win, T, k):
    """Numba-JIT accumulation loop for col2im.

    Scatter-adds each of the k tap-columns of grad_win (B, C, T, k) into the
    padded gradient buffer g_pad (B, C, T + 2*pad), shifted by the tap index.
    Mutates g_pad in place; returns nothing.
    """
    # Loop over taps, not timesteps: k is small (e.g. 7) while T is large,
    # so each iteration is one big vectorized add.
    for i in range(k):
        g_pad[:, :, i:i + T] += grad_win[:, :, :, i]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _col2im1d(grad_win: np.ndarray, T: int, k: int, pad: int) -> np.ndarray:
    """
    Adjoint of _im2col1d: scatter-add windowed gradients back to the signal.

    grad_win: (B, C, T, k)
    returns : (B, C, T)
    """
    B, C = grad_win.shape[:2]
    padded = np.zeros((B, C, T + 2 * pad))
    # accumulate k shifted copies into the padded buffer (in place)
    _col2im1d_kernel(padded, grad_win, T, k)
    # strip the padding to recover the original temporal extent
    return padded[:, :, pad:pad + T]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Layer: DWConv1d (depthwise conv via im2col)
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class DWConv1d:
    """Depthwise 1-D convolution: one independent k-tap filter per channel.

    Both passes go through im2col, so all gradients reduce to einsums.
    The forward window tensor is cached for reuse in backward.
    """

    def __init__(self, C: int, k: int, rng: np.random.RandomState):
        self.C, self.k, self.pad = C, k, k // 2
        self.w = Param(rng.randn(C, k) * 0.02)   # (C, k) per-channel taps
        self.b = Param(np.zeros(C))              # (C,) per-channel bias
        self._cache: tuple | None = None

    def params(self) -> list[Param]:
        return [self.w, self.b]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T)"""
        windows = _im2col1d(x, self.k, self.pad)   # (B, C, T, k)
        self._cache = (windows,)
        # einsum with optimize: uses BLAS when possible
        y = np.einsum('bctk,ck->bct', windows, self.w.data, optimize=True)
        y += self.b.data[None, :, None]
        return y

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """grad_out: (B, C, T) → grad_x: (B, C, T)"""
        windows, = self._cache
        T = grad_out.shape[2]
        # weight gradient: contract the windows against grad_out over B and T
        self.w.grad += np.einsum('bct,bctk->ck', grad_out, windows, optimize=True)
        # bias gradient: plain sum over batch and time
        self.b.grad += grad_out.sum(axis=(0, 2))
        # input gradient: broadcast the taps back onto windows, then col2im
        grad_win = np.einsum('bct,ck->bctk', grad_out, self.w.data, optimize=True)
        return _col2im1d(grad_win, T, self.k, self.pad)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Layer: LayerNorm (over last axis, applied to (B, T, C) or (B, C))
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class LayerNorm:
    """Layer normalisation over the trailing axis with a learnable affine."""

    EPS = 1e-6

    def __init__(self, D: int, rng: np.random.RandomState | None = None):
        self.D = D
        self.gamma = Param(np.ones(D))    # learnable scale
        self.beta = Param(np.zeros(D))    # learnable shift
        self._cache: tuple | None = None

    def params(self) -> list[Param]:
        return [self.gamma, self.beta]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Normalize the last axis to zero mean / unit variance, then affine."""
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        x_hat = (x - mu) / np.sqrt(var + self.EPS)
        self._cache = (x_hat, var)        # keep normalized input for backward
        return self.gamma.data * x_hat + self.beta.data

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Standard LN backward; returns grad w.r.t. the input."""
        x_hat, var = self._cache
        D = x_hat.shape[-1]
        shape_in = dy.shape

        # collapse all leading axes so the reductions are plain 2-D ops
        g = dy.reshape(-1, D)
        xh = x_hat.reshape(-1, D)
        v = var.reshape(-1, 1)

        self.gamma.grad += (g * xh).sum(0)
        self.beta.grad += g.sum(0)

        dxh = g * self.gamma.data                  # (N, D)
        inv_std = 1.0 / np.sqrt(v + self.EPS)      # (N, 1)
        # remove the mean component and the x_hat-projection component
        dx = inv_std * (dxh
                        - dxh.mean(-1, keepdims=True)
                        - xh * (dxh * xh).mean(-1, keepdims=True))
        return dx.reshape(shape_in)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Layer: Linear
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class Linear:
    """Dense layer y = x @ W (+ b) applied to the trailing axis of x."""

    def __init__(self, in_f: int, out_f: int, rng: np.random.RandomState,
                 bias: bool = True):
        self.W = Param(_kaiming(in_f, out_f, rng))          # (in_f, out_f)
        self.b = Param(np.zeros(out_f)) if bias else None   # (out_f,) or absent
        self._x_cache: np.ndarray | None = None

    def params(self) -> list[Param]:
        ps = [self.W]
        if self.b is not None:
            ps.append(self.b)
        return ps

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Matrix-multiply the trailing axis; cache x for backward."""
        self._x_cache = x
        y = x @ self.W.data
        return y if self.b is None else y + self.b.data

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Accumulate W/b grads and return grad w.r.t. the input."""
        x = self._x_cache
        in_f, out_f = self.W.data.shape
        # flatten leading axes so the grads are plain matmuls
        x_flat = x.reshape(-1, in_f)
        dy_flat = dy.reshape(-1, out_f)
        self.W.grad += x_flat.T @ dy_flat
        if self.b is not None:
            self.b.grad += dy_flat.sum(0)
        return (dy_flat @ self.W.data.T).reshape(x.shape)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Block: ConvNeXtBlock1D
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Input/output: (B, C, T)
|
|||
|
|
# 1. DWConv1d(C, k)
|
|||
|
|
# 2. Permute → (B, T, C)
|
|||
|
|
# 3. LayerNorm(C)
|
|||
|
|
# 4. Linear(C → 4C) + GELU
|
|||
|
|
# 5. Linear(4C → C)
|
|||
|
|
# 6. Permute → (B, C, T)
|
|||
|
|
# 7. Skip
|
|||
|
|
|
|||
|
|
class ConvNeXtBlock1D:
    """ConvNeXt residual block for (B, C, T) tensors.

    Pipeline: DWConv → permute → LayerNorm → Linear(C→4C) → GELU →
    Linear(4C→C) → permute back, all wrapped in an identity skip.
    """

    def __init__(self, C: int, k: int, rng: np.random.RandomState):
        self.dwconv = DWConv1d(C, k, rng)
        self.ln = LayerNorm(C)
        self.fc1 = Linear(C, 4 * C, rng)     # channel expansion
        self.fc2 = Linear(4 * C, C, rng)     # channel contraction
        self._cache: dict | None = None

    def params(self) -> list[Param]:
        ps = self.dwconv.params() + self.ln.params()
        ps += self.fc1.params() + self.fc2.params()
        return ps

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T)"""
        y = self.dwconv.forward(x)           # (B, C, T)
        y = y.transpose(0, 2, 1)             # (B, T, C): channels-last for LN/MLP
        y = self.ln.forward(y)
        pre_act = self.fc1.forward(y)        # (B, T, 4C), pre-GELU
        y = self.fc2.forward(_gelu(pre_act))
        y = y.transpose(0, 2, 1)             # back to (B, C, T)
        self._pre_gelu = pre_act             # cached for the GELU backward
        return x + y                         # residual connection

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """grad_out: (B, C, T) → grad_x: (B, C, T)"""
        # branch path: retrace the forward ops in reverse order
        g = grad_out.copy().transpose(0, 2, 1)   # (B, T, C)
        g = self.fc2.backward(g)                 # (B, T, 4C)
        g = g * _gelu_grad(self._pre_gelu)       # through the GELU
        g = self.fc1.backward(g)                 # (B, T, C)
        g = self.ln.backward(g)
        g = g.transpose(0, 2, 1)                 # (B, C, T)
        g = self.dwconv.backward(g)
        # identity path: the skip passes grad_out through unchanged
        return grad_out + g
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Pool: AdaptiveAvgPool1D (fixed stride, e.g. T=32→T=16 stride=2)
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class AvgPool1D:
    """Non-overlapping average pooling (kernel == stride) over the time axis."""

    def __init__(self, stride: int = 2):
        self.stride = stride
        self._T_in: int | None = None   # remembered for backward padding

    def params(self) -> list[Param]:
        return []

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T//stride)"""
        B, C, T = x.shape
        self._T_in = T
        s = self.stride
        T_out = T // s
        # drop any ragged tail, then average each length-s group
        return x[:, :, :T_out * s].reshape(B, C, T_out, s).mean(-1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Spread each output gradient evenly over its pooled inputs."""
        B, C, T_out = grad_out.shape
        s = self.stride
        g = np.repeat(grad_out / s, s, axis=-1)   # (B, C, T_out*s)
        tail = self._T_in - T_out * s
        if tail > 0:
            # inputs in the truncated tail never contributed → zero gradient
            g = np.pad(g, ((0, 0), (0, 0), (0, tail)))
        return g
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Upsample: nearest-neighbor ×2 along temporal axis
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class Upsample1D:
    """Nearest-neighbour temporal upsampling by an integer factor."""

    def __init__(self, factor: int = 2):
        self.factor = factor

    def params(self) -> list[Param]:
        return []

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Duplicate every timestep `factor` times along the last axis."""
        return np.repeat(x, self.factor, axis=-1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Adjoint of repeat: sum the gradient over each duplicated group."""
        B, C, T_up = grad_out.shape
        f = self.factor
        return grad_out.reshape(B, C, T_up // f, f).sum(-1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# PointwiseProj: 1×1 conv (= Linear applied per timestep)
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class PointwiseProj:
    """1×1 convolution: (B, C_in, T) → (B, C_out, T) via a shared Linear.

    Equivalent to applying the same Linear(C_in, C_out) at every timestep.
    """

    def __init__(self, C_in: int, C_out: int, rng: np.random.RandomState):
        self.lin = Linear(C_in, C_out, rng)

    def params(self) -> list[Param]:
        return self.lin.params()

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C_in, T) → (B, C_out, T)"""
        # channels-last so the shared Linear acts on the trailing axis
        return self.lin.forward(x.transpose(0, 2, 1)).transpose(0, 2, 1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Route the gradient through the shared Linear, restoring the layout."""
        return self.lin.backward(grad_out.transpose(0, 2, 1)).transpose(0, 2, 1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# Main model: ConvNeXtVAE
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
class ConvNeXtVAE:
    """
    ConvNeXt-1D β-TCVAE.

    Encoder: stem (PointwiseProj + LN) → stage0 blocks → AvgPool + LN +
    proj → stage1 blocks → global average pool → mu / logvar heads.
    Decoder mirrors the encoder (Linear expand → stage1d → upsample + proj
    + LN → stage0d → output proj).

    Parameters
    ----------
    C_in : input channels (default 8)
    T_in : input timesteps (default 32)
    z_dim : latent dimensionality (default 32)
    base_ch : stem channel width (default 32)
    dw_k : depthwise kernel size (default 7)
    n_blocks : ConvNeXt blocks per stage (default 3)
    seed : RNG seed
    """

    def __init__(
        self,
        C_in: int = 8,
        T_in: int = 32,
        z_dim: int = 32,
        base_ch: int = 32,
        dw_k: int = 7,
        n_blocks: int = 3,
        seed: int = 42,
    ):
        # one seeded RNG drives every weight initialisation, in declaration order
        rng = np.random.RandomState(seed)
        self.C_in = C_in
        self.T_in = T_in
        self.z_dim = z_dim
        self.base_ch = base_ch
        ch1 = base_ch * 2  # 64

        # ---- ENCODER ----
        # Stem
        self.stem_proj = PointwiseProj(C_in, base_ch, rng)
        self.stem_ln = LayerNorm(base_ch)

        # Stage 0
        self.stage0 = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]

        # Pool + channel proj
        self.pool = AvgPool1D(stride=2)
        self.pool_ln = LayerNorm(base_ch)
        self.pool_proj = PointwiseProj(base_ch, ch1, rng)

        # Stage 1
        self.stage1 = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]

        # Global avg pool → z_mu, z_logvar
        self.head_mu = Linear(ch1, z_dim, rng)
        self.head_logvar = Linear(ch1, z_dim, rng)

        # ---- DECODER ----
        T_bot = T_in // 2  # bottleneck temporal size (16 for T_in=32)
        self.dec_linear = Linear(z_dim, ch1, rng)

        # Stage 1D
        self.stage1d = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]

        # Upsample + channel proj
        self.up = Upsample1D(factor=2)
        self.up_proj = PointwiseProj(ch1, base_ch, rng)
        self.up_ln = LayerNorm(base_ch)

        # Stage 0D
        self.stage0d = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]

        # Output projection
        self.out_proj = PointwiseProj(base_ch, C_in, rng)

        # Cache for backward
        self._enc_cache: dict = {}
        self._dec_cache: dict = {}
        self._T_bot = T_bot

    # ------------------------------------------------------------------
    def all_params(self) -> list[Param]:
        """Flat list of every Param, in a fixed order matching _param_names()."""
        ps = []
        ps += self.stem_proj.params() + self.stem_ln.params()
        for blk in self.stage0:
            ps += blk.params()
        ps += (self.pool_ln.params() + self.pool_proj.params())
        for blk in self.stage1:
            ps += blk.params()
        ps += self.head_mu.params() + self.head_logvar.params()
        ps += self.dec_linear.params()
        for blk in self.stage1d:
            ps += blk.params()
        ps += self.up_proj.params() + self.up_ln.params()
        for blk in self.stage0d:
            ps += blk.params()
        ps += self.out_proj.params()
        return ps

    def zero_grad(self):
        """Clear accumulated gradients on every parameter."""
        for p in self.all_params():
            p.zero_grad()

    def adam_step(self, lr: float, wd: float = 1e-4):
        """One Adam update for every parameter (wd = L2 weight decay)."""
        for p in self.all_params():
            p.adam_step(lr, wd=wd)

    def n_params(self) -> int:
        """Total scalar parameter count."""
        return sum(p.data.size for p in self.all_params())

    # ------------------------------------------------------------------
    # ENCODER
    # ------------------------------------------------------------------
    def encode(self, x: np.ndarray):
        """
        x: (B, C_in, T_in)
        Returns z_mu (B,z), z_logvar (B,z)

        Side effect: fills self._enc_cache with intermediate activations
        needed by backward_encode.
        """
        c = self._enc_cache
        c['x'] = x

        # Stem: project channels, then LayerNorm in channels-last layout
        h = self.stem_proj.forward(x)  # (B, base_ch, T)
        h = h.transpose(0, 2, 1)       # (B, T, base_ch)
        h = self.stem_ln.forward(h)    # (B, T, base_ch)
        h = h.transpose(0, 2, 1)       # (B, base_ch, T)
        c['after_stem'] = h

        # Stage 0
        for blk in self.stage0:
            h = blk.forward(h)
        c['after_stage0'] = h

        # Pool + proj (LN again needs channels-last)
        h = self.pool.forward(h)       # (B, base_ch, T//2)
        h = h.transpose(0, 2, 1)       # (B, T//2, base_ch)
        h = self.pool_ln.forward(h)    # (B, T//2, base_ch)
        h = h.transpose(0, 2, 1)       # (B, base_ch, T//2)
        h = self.pool_proj.forward(h)  # (B, ch1, T//2)
        c['after_pool'] = h

        # Stage 1
        for blk in self.stage1:
            h = blk.forward(h)
        c['after_stage1'] = h

        # Global avg pool over the temporal axis
        h_gap = h.mean(axis=-1)        # (B, ch1)
        c['gap'] = h_gap

        z_mu = self.head_mu.forward(h_gap)          # (B, z_dim)
        z_logvar = self.head_logvar.forward(h_gap)  # (B, z_dim)
        return z_mu, z_logvar

    # ------------------------------------------------------------------
    # DECODER
    # ------------------------------------------------------------------
    def decode(self, z: np.ndarray) -> np.ndarray:
        """
        z: (B, z_dim)
        Returns x_recon: (B, C_in, T_in)

        Side effect: fills self._dec_cache for backward_decode.
        """
        c = self._dec_cache
        T_bot = self._T_bot

        h = self.dec_linear.forward(z)               # (B, ch1)
        # broadcast the latent code across the bottleneck temporal axis
        h = h[:, :, None] * np.ones((1, 1, T_bot))   # (B, ch1, T_bot)
        c['dec_expand'] = h

        for blk in self.stage1d:
            h = blk.forward(h)
        c['after_stage1d'] = h

        # Upsample + proj (LN in channels-last layout)
        h = self.up.forward(h)        # (B, ch1, T_in)
        h = self.up_proj.forward(h)   # (B, base_ch, T_in)
        h = h.transpose(0, 2, 1)      # (B, T_in, base_ch)
        h = self.up_ln.forward(h)     # (B, T_in, base_ch)
        h = h.transpose(0, 2, 1)      # (B, base_ch, T_in)
        c['after_up'] = h

        for blk in self.stage0d:
            h = blk.forward(h)

        x_recon = self.out_proj.forward(h)  # (B, C_in, T_in)
        return x_recon

    # ------------------------------------------------------------------
    # BACKWARD (encoder)
    # ------------------------------------------------------------------
    def backward_encode(self, dz_mu: np.ndarray, dz_logvar: np.ndarray):
        """
        Given grad w.r.t. z_mu and z_logvar, backprop through encoder.
        Accumulates parameter gradients as a side effect.
        Returns grad w.r.t. x (not used further, just for completeness).
        """
        c = self._enc_cache

        # head: both heads read the same GAP features, so their input grads add
        d_gap = self.head_mu.backward(dz_mu) + self.head_logvar.backward(dz_logvar)

        # global avg pool backward: distribute evenly over T axis
        h_after = c['after_stage1']
        T_bot = h_after.shape[-1]
        d_h = d_gap[:, :, None] / T_bot * np.ones((1, 1, T_bot))  # (B, ch1, T_bot)

        # stage 1 backward (reverse order of forward)
        for blk in reversed(self.stage1):
            d_h = blk.backward(d_h)

        # pool_proj backward
        d_h = self.pool_proj.backward(d_h)  # (B, base_ch, T//2)
        # pool_ln backward (channels-last, mirroring forward)
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.pool_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        # pool backward
        d_h = self.pool.backward(d_h)  # (B, base_ch, T)

        # stage 0 backward
        for blk in reversed(self.stage0):
            d_h = blk.backward(d_h)

        # stem backward
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.stem_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.stem_proj.backward(d_h)

        return d_h  # grad w.r.t. input x

    # ------------------------------------------------------------------
    # BACKWARD (decoder)
    # ------------------------------------------------------------------
    def backward_decode(self, d_recon: np.ndarray):
        """
        d_recon: (B, C_in, T_in) — grad w.r.t. x_recon
        Accumulates parameter gradients as a side effect.
        Returns grad w.r.t. z.
        """
        c = self._dec_cache
        T_bot = self._T_bot

        # out_proj backward
        d_h = self.out_proj.backward(d_recon)  # (B, base_ch, T_in)

        # stage 0d backward
        for blk in reversed(self.stage0d):
            d_h = blk.backward(d_h)

        # up_ln backward (channels-last, mirroring forward)
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.up_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        # up_proj backward
        d_h = self.up_proj.backward(d_h)  # (B, ch1, T_in)
        # upsample backward
        d_h = self.up.backward(d_h)       # (B, ch1, T_bot)

        # stage 1d backward
        for blk in reversed(self.stage1d):
            d_h = blk.backward(d_h)

        # expand backward: sum over temporal axis (it was broadcast)
        d_h_expand = d_h.sum(axis=-1)  # (B, ch1)

        # dec_linear backward
        dz = self.dec_linear.backward(d_h_expand)  # (B, z_dim)
        return dz

    # ------------------------------------------------------------------
    # FORWARD convenience
    # ------------------------------------------------------------------
    def forward(self, x: np.ndarray):
        """Full forward: encode + reparameterize + decode.

        NOTE(review): the reparameterization noise uses the GLOBAL numpy RNG
        (np.random.randn), not the seeded RandomState from __init__ — so runs
        are not reproducible from `seed` alone; confirm this is intended.
        """
        z_mu, z_logvar = self.encode(x)
        eps = np.random.randn(*z_mu.shape)
        z = z_mu + eps * np.exp(0.5 * z_logvar)  # z = mu + eps * sigma
        x_recon = self.decode(z)
        return x_recon, z_mu, z_logvar, z

    # ------------------------------------------------------------------
    # SAVE / LOAD
    # ------------------------------------------------------------------
    def save(self, path: str, norm_mean=None, norm_std=None,
             extra: dict | None = None, save_adam: bool = False):
        """Serialize all weights (and optionally Adam state) to JSON.

        NOTE(review): norm_std is assumed non-None whenever norm_mean is
        given — passing norm_mean alone raises AttributeError; verify callers.
        """
        data: dict = {}
        params = self.all_params()
        names = self._param_names()
        for name, p in zip(names, params):
            data[name] = p.data.tolist()
            if save_adam:
                # optimizer state lives under derived keys next to the weight
                data[f'{name}__m'] = p._m.tolist()
                data[f'{name}__v'] = p._v.tolist()
                data[f'{name}__t'] = int(p.t)
        if norm_mean is not None:
            data['norm_mean'] = norm_mean.tolist()
            data['norm_std'] = norm_std.tolist()
        data['architecture'] = {
            'C_in': self.C_in, 'T_in': self.T_in, 'z_dim': self.z_dim,
            'base_ch': self.base_ch,
        }
        if extra:
            data.update(extra)
        with open(path, 'w') as f:
            json.dump(data, f)

    def load(self, path: str) -> dict:
        """Load weights (and Adam state if present) from JSON.

        Missing keys are silently skipped, so a partial checkpoint loads
        what it can. Returns any extra (non-parameter) entries.
        """
        with open(path) as f:
            data = json.load(f)
        params = self.all_params()
        names = self._param_names()
        reserved = set(names)
        for name, p in zip(names, params):
            if name in data:
                # in-place copy preserves array identity referenced elsewhere
                p.data[:] = np.array(data[name], dtype=np.float64)
            if f'{name}__m' in data:
                p._m[:] = np.array(data[f'{name}__m'], dtype=np.float64)
                p._v[:] = np.array(data[f'{name}__v'], dtype=np.float64)
                p.t = int(data[f'{name}__t'])
            reserved.update({f'{name}__m', f'{name}__v', f'{name}__t'})
        extras = {k: v for k, v in data.items()
                  if k not in reserved and k != 'architecture'}
        return extras

    def _param_names(self) -> list[str]:
        """Return stable flat list of parameter name strings matching all_params()."""
        names = []
        # stem
        for i, _ in enumerate(self.stem_proj.params()):
            names.append(f'stem_proj_p{i}')
        for i, _ in enumerate(self.stem_ln.params()):
            names.append(f'stem_ln_p{i}')
        # stage0
        for bi, blk in enumerate(self.stage0):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage0_b{bi}_p{i}')
        # pool
        for i, _ in enumerate(self.pool_ln.params()):
            names.append(f'pool_ln_p{i}')
        for i, _ in enumerate(self.pool_proj.params()):
            names.append(f'pool_proj_p{i}')
        # stage1
        for bi, blk in enumerate(self.stage1):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage1_b{bi}_p{i}')
        # head
        for i, _ in enumerate(self.head_mu.params()):
            names.append(f'head_mu_p{i}')
        for i, _ in enumerate(self.head_logvar.params()):
            names.append(f'head_logvar_p{i}')
        # dec_linear
        for i, _ in enumerate(self.dec_linear.params()):
            names.append(f'dec_linear_p{i}')
        # stage1d
        for bi, blk in enumerate(self.stage1d):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage1d_b{bi}_p{i}')
        # up
        for i, _ in enumerate(self.up_proj.params()):
            names.append(f'up_proj_p{i}')
        for i, _ in enumerate(self.up_ln.params()):
            names.append(f'up_ln_p{i}')
        # stage0d
        for bi, blk in enumerate(self.stage0d):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage0d_b{bi}_p{i}')
        # out_proj
        for i, _ in enumerate(self.out_proj.params()):
            names.append(f'out_proj_p{i}')
        return names
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# β-TCVAE loss
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
def btcvae_loss(
    x: np.ndarray,
    x_recon: np.ndarray,
    z_mu: np.ndarray,
    z_logvar: np.ndarray,
    z: np.ndarray,
    beta_tc: float = 4.0,
    alpha_mi: float = 1.0,
    free_bits: float = 0.0,
) -> tuple[float, dict]:
    """
    β-TCVAE objective (Chen et al. 2018) with the minibatch TC estimator.

    Parameters
    ----------
    x : (B, C, T) — input window
    x_recon : (B, C, T) — decoder reconstruction
    z_mu : (B, D) — posterior means
    z_logvar : (B, D) — posterior log-variances
    z : (B, D) — reparameterized latent sample
    beta_tc : weight on the Total Correlation term
    alpha_mi : weight on the Mutual Information term
    free_bits : minimum KL per latent dimension (nats). When > 0, the
        minibatch dimKL estimate is replaced by the analytical per-dim
        KL clamped from below at free_bits (posterior-collapse guard);
        0 keeps the standard minibatch estimator.

    Returns
    -------
    loss : scalar total
    info : dict with the breakdown (recon, MI, TC, dimKL)
    """
    batch, _ = z.shape
    log2pi = np.log(2 * np.pi)

    # Mean-squared reconstruction error: per-sample mean, then batch mean.
    recon = ((x - x_recon) ** 2).mean(axis=(-1, -2)).mean()

    # log q(z_i | x_i), summed over latent dims → (B,)
    log_qz_x = -0.5 * (
        log2pi + z_logvar + (z - z_mu) ** 2 * np.exp(-z_logvar)
    ).sum(axis=1)

    # log p(z_i) under the standard-normal prior → (B,)
    log_pz = -0.5 * (log2pi + z ** 2).sum(axis=1)

    # Per-dim log-density of every sample z_i under every posterior q(.|x_j):
    # pairwise[i, j, d] = log N(z[i,d]; mu[j,d], var[j,d]) → (B, B, D)
    pairwise = -0.5 * (
        log2pi
        + z_logvar[None, :, :]
        + (z[:, None, :] - z_mu[None, :, :]) ** 2 * np.exp(-z_logvar[None, :, :])
    )

    # Minibatch estimators:
    #   log q(z_i)        = logsumexp_j sum_d pairwise - log B
    #   log prod_d q(z_d) = sum_d (logsumexp_j pairwise - log B)
    log_qz = _logsumexp(pairwise.sum(-1), axis=1) - np.log(batch)  # (B,)
    log_qz_prod = (_logsumexp(pairwise, axis=1) - np.log(batch)).sum(-1)  # (B,)

    MI = (log_qz_x - log_qz).mean()
    TC = (log_qz - log_qz_prod).mean()

    if free_bits > 0.0:
        # Analytical per-dim KL vs N(0, I), batch-averaged then floored:
        # sum_d max(free_bits, KL_d).
        per_dim_kl = (
            0.5 * (z_mu ** 2 + np.exp(z_logvar) - 1.0 - z_logvar)
        ).mean(0)  # (D,)
        dimKL = float(np.maximum(free_bits, per_dim_kl).sum())
    else:
        dimKL = float((log_qz_prod - log_pz).mean())

    total = recon + alpha_mi * MI + beta_tc * TC + dimKL

    return float(total), {
        'recon': float(recon),
        'MI': float(MI),
        'TC': float(TC),
        'dimKL': dimKL,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
# β-TCVAE backward
|
|||
|
|
# ---------------------------------------------------------------------------
|
|||
|
|
|
|||
|
|
def btcvae_loss_backward(
    x: np.ndarray,
    x_recon: np.ndarray,
    z_mu: np.ndarray,
    z_logvar: np.ndarray,
    z: np.ndarray,
    eps: np.ndarray,
    beta_tc: float = 4.0,
    alpha_mi: float = 1.0,
    free_bits: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Gradients of the β-TCVAE loss w.r.t. x_recon, z_mu and z_logvar.

    Two modes, matching the forward loss:

    * free_bits == 0 — full minibatch MI/TC/dimKL decomposition. Gradients
      flow both through the reparameterized sample z = z_mu + eps*std and
      directly into z_mu / z_logvar via q(z_i | x_j).
    * free_bits > 0 — plain-VAE-with-free-bits mode: only the analytical
      per-dim KL (masked at the floor) plus the direct encoder term remain.
      All minibatch MI/TC terms are skipped — both their gradients AND the
      O(B²·D) pairwise computation they would require (the original code
      computed those tensors and discarded them in this mode).

    Parameters
    ----------
    x, x_recon : (B, C, T) — input and reconstruction
    z_mu, z_logvar : (B, D) — posterior parameters
    z : (B, D) — reparameterized sample, z = z_mu + eps * exp(0.5*z_logvar)
    eps : (B, D) — noise used in the reparameterization
    beta_tc : weight on the Total Correlation term
    alpha_mi : weight on the Mutual Information term
    free_bits : per-dimension KL floor in nats; 0 disables free-bits mode

    Returns
    -------
    d_recon : (B, C, T) grad w.r.t. x_recon
    d_z_mu : (B, D) grad w.r.t. z_mu
    d_z_logvar : (B, D) grad w.r.t. z_logvar
    """
    B, D = z.shape
    var = np.exp(z_logvar)        # (B, D)
    std = np.exp(0.5 * z_logvar)  # (B, D)
    err = z - z_mu                # (B, D)

    # ---- reconstruction: L_recon = (1/(B*C*T)) * ||x - x_recon||^2 ----
    # dL/dx_recon = -2*(x - x_recon) / (B*C*T)
    d_recon = -2.0 * (x - x_recon) / x.size  # (B, C, T)

    # ---- direct encoder term from alpha * mean_i log q(z_i|x_i) ----
    # d log_qz_x / d mu     =  err/var
    # d log_qz_x / d logvar = -0.5 + 0.5*err^2/var
    # (present in BOTH modes; the z-path of log_qz_x cancels this for mu in
    # the full decomposition, but the convention here keeps it explicit)
    d_z_mu = (alpha_mi / B) * err / var
    d_z_logvar = (alpha_mi / B) * (-0.5 + 0.5 * err ** 2 / var)

    if free_bits > 0.0:
        # Analytical free-bits dimKL: sum_d max(free_bits, mean_i KL_d).
        # Only dims ABOVE the floor receive gradient (subgradient of max).
        kl_per_dim = (0.5 * (z_mu ** 2 + var - 1.0 - z_logvar)).mean(0)  # (D,)
        mask = (kl_per_dim >= free_bits).astype(np.float64)  # 1 = penalised
        # d KL_d / d mu = mu ; d KL_d / d logvar = 0.5*(var - 1); /B from mean_i
        d_z_mu += mask[None, :] * z_mu / B
        d_z_logvar += mask[None, :] * 0.5 * (var - 1.0) / B
        return d_recon, d_z_mu, d_z_logvar

    # =================== full minibatch β-TCVAE mode ===================
    # Pairwise per-dim log-density: log_q_ij_d[i,j,d] = log N(z_i,d; mu_j,d, var_j,d)
    diff_ij = z[:, None, :] - z_mu[None, :, :]  # (B, B, D)
    lv_j = z_logvar[None, :, :]                 # (1, B, D)
    var_j = np.exp(lv_j)                        # (1, B, D), broadcasts to (B,B,D)
    log_q_ij_d = -0.5 * (np.log(2 * np.pi) + lv_j + diff_ij ** 2 / var_j)

    # Softmax weights over j — these are exactly d logsumexp_j / d(arg_j):
    # w_qz[i,j] for log q(z_i) (dims summed), w_prod[i,j,d] per dim.
    lq_sum = log_q_ij_d.sum(-1)  # (B, B)
    w_qz = np.exp(lq_sum - lq_sum.max(1, keepdims=True))
    w_qz /= w_qz.sum(1, keepdims=True) + 1e-30  # (B, B)

    w_prod = np.exp(log_q_ij_d - log_q_ij_d.max(1, keepdims=True))
    w_prod /= w_prod.sum(1, keepdims=True) + 1e-30  # (B, B, D)

    # ---- dL/dz (reparameterization path) ----
    # log_qz_x : d/dz[i,d] = -err/var
    # log_qz   : d/dz[i,d] = -sum_j w_qz[i,j] * diff_ij/var_j
    # log_prod : d/dz[i,d] = -sum_j w_prod[i,j,d] * diff_ij/var_j
    # log_pz   : d/dz[i,d] = -z
    d_logqzx_d_z = -err / var                                    # (B, D)
    d_logqz_d_z = -(w_qz[:, :, None] * diff_ij / var_j).sum(1)   # (B, D)
    d_logqzprod_d_z = -(w_prod * diff_ij / var_j).sum(1)         # (B, D)
    d_logpz_d_z = -z                                             # (B, D)

    # Loss = recon + alpha*MI + beta*TC + dimKL, so the log terms carry:
    #   log_qz_x : +alpha ;  log_qz : (beta - alpha)
    #   log_prod : (1 - beta) ;  log_pz : -1
    dL_dz = (1.0 / B) * (
        alpha_mi * d_logqzx_d_z
        + (beta_tc - alpha_mi) * d_logqz_d_z
        + (1.0 - beta_tc) * d_logqzprod_d_z
        - d_logpz_d_z
    )  # (B, D)

    # Reparameterization: z = mu + eps*std → dz/dmu = 1, dz/dlogvar = 0.5*eps*std
    d_z_mu += dL_dz
    d_z_logvar += dL_dz * 0.5 * eps * std

    # ---- direct gradients into mu_j / logvar_j through q(z_i | x_j) ----
    # d log_q_ij_d / d mu_j     = +diff_ij/var_j
    # d log_q_ij_d / d logvar_j = -0.5 + 0.5*diff_ij^2/var_j
    # Sum over i (each i's loss touches every j's parameters); result indexed by j.
    g_mu_qz = (w_qz[:, :, None] * diff_ij / var_j).sum(0)  # (B, D)
    g_mu_prod = (w_prod * diff_ij / var_j).sum(0)          # (B, D)
    d_z_mu += (1.0 / B) * (
        (beta_tc - alpha_mi) * g_mu_qz + (1.0 - beta_tc) * g_mu_prod
    )

    g_lv_qz = (w_qz[:, :, None] * (-0.5 + 0.5 * diff_ij ** 2 / var_j)).sum(0)
    g_lv_prod = (w_prod * (-0.5 + 0.5 * diff_ij ** 2 / var_j)).sum(0)
    d_z_logvar += (1.0 / B) * (
        (beta_tc - alpha_mi) * g_lv_qz + (1.0 - beta_tc) * g_lv_prod
    )

    return d_recon, d_z_mu, d_z_logvar
|