Includes core prod + GREEN/BLUE subsystems: - prod/ (BLUE harness, configs, scripts, docs) - nautilus_dolphin/ (GREEN Nautilus-native impl + dvae/ preserved) - adaptive_exit/ (AEM engine + models/bucket_assignments.pkl) - Observability/ (EsoF advisor, TUI, dashboards) - external_factors/ (EsoF producer) - mc_forewarning_qlabs_fork/ (MC regime/envelope) Excludes runtime caches, logs, backups, and reproducible artifacts per .gitignore.
1049 lines
39 KiB
Python
Executable File
"""
|
||
ConvNeXt-1D β-TCVAE — pure numpy (no PyTorch required).
|
||
========================================================
|
||
|
||
Architecture
|
||
------------
|
||
Input : (B, C_in=8, T=32) — 8 eigenvalue channels × 32 timestep window
|
||
Stem : PointwiseProj(8→32) + LayerNorm
|
||
Stage 0: 3 × ConvNeXtBlock1D(32, dw_k=7)
|
||
Pool : AdaptiveAvgPool(16) + PointwiseProj(32→64)
|
||
Stage 1: 3 × ConvNeXtBlock1D(64, dw_k=7)
|
||
Head : GlobalAvgPool → Linear(64→32) for z_mu / z_logvar
|
||
|
||
Decoder (mirrored):
|
||
Expand : Linear(32→64) → repeat 16 times along T axis
|
||
Stage 1D: 3 × ConvNeXtBlock1D(64, dw_k=7)
|
||
Upsample: RepeatInterleave(×2) + PointwiseProj(64→32)
|
||
Stage 0D: 3 × ConvNeXtBlock1D(32, dw_k=7)
|
||
Out : PointwiseProj(32→8)
|
||
|
||
Objective: β-TCVAE (Chen et al. 2018, "Isolating Sources of Disentanglement")
|
||
β_tc=4.0, α_mi=1.0. Minibatch TC estimator.
|
||
|
||
Gradient implementation notes
|
||
------------------------------
|
||
All depthwise convolutions use the im2col trick so that gradients reduce to
|
||
standard matmul / einsum — no custom numerical differentiation required.
|
||
|
||
DWConv1d forward : windows[b,c,t,k] = x_pad[b,c,t+k]
|
||
out[b,c,t] = sum_k windows * w[c,k] + b[c]
|
||
DWConv1d backward: grad_w[c] = einsum('btk,bt->k', windows_c, grad_out_c) over B
|
||
col2im for grad_x (accumulate k-shifted copies of grad_out * w)
|
||
|
||
β-TCVAE TC gradient: computed via minibatch weight matrix w[i,j] ∝ q(z_i|x_j)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import json
|
||
import numpy as np
|
||
from scipy.special import erf
|
||
|
||
try:
|
||
from numba import njit as _njit
|
||
_NUMBA = True
|
||
except ImportError:
|
||
_NUMBA = False
|
||
def _njit(fn): # no-op decorator
|
||
return fn
|
||
|
||
__all__ = ['ConvNeXtVAE', 'btcvae_loss', 'btcvae_loss_backward']
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Utility
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _kaiming(fan_in: int, fan_out: int, rng: np.random.RandomState) -> np.ndarray:
|
||
std = np.sqrt(2.0 / fan_in)
|
||
return rng.randn(fan_in, fan_out) * std
|
||
|
||
|
||
def _gelu(x: np.ndarray) -> np.ndarray:
|
||
return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
|
||
|
||
|
||
def _gelu_grad(x: np.ndarray) -> np.ndarray:
|
||
# Two allocations instead of three: out= avoids the final x*pdf temporary
|
||
cdf = 0.5 * (1.0 + erf(x * (1.0 / np.sqrt(2.0))))
|
||
pdf = np.exp(-0.5 * x * x)
|
||
pdf *= 1.0 / np.sqrt(2.0 * np.pi)
|
||
np.multiply(x, pdf, out=pdf) # pdf now holds x * pdf, in-place
|
||
pdf += cdf # pdf now holds cdf + x*pdf
|
||
return pdf
|
||
|
||
|
||
def _logsumexp(a: np.ndarray, axis: int) -> np.ndarray:
|
||
a_max = a.max(axis=axis, keepdims=True)
|
||
out = np.log(np.exp(a - a_max).sum(axis=axis, keepdims=True)) + a_max
|
||
return out.squeeze(axis)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Parameter container — stores weight, gradient, Adam state
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class Param:
    """Trainable tensor bundled with its gradient and Adam optimizer state."""

    __slots__ = ('data', 'grad', '_m', '_v', 't')

    def __init__(self, data: np.ndarray):
        self.data = data.astype(np.float64)
        self.grad = np.zeros_like(self.data)
        self._m = np.zeros_like(self.data)  # first moment (EMA of grads)
        self._v = np.zeros_like(self.data)  # second moment (EMA of squared grads)
        self.t = 0                          # step counter for bias correction

    def zero_grad(self):
        """Reset the accumulated gradient to zero, in place."""
        self.grad.fill(0.0)

    def adam_step(self, lr: float, β1: float = 0.9, β2: float = 0.999,
                  eps: float = 1e-8, wd: float = 0.0):
        """Apply one Adam update; `wd` adds L2 weight decay to the gradient."""
        self.t += 1
        g = self.grad if wd <= 0 else self.grad + wd * self.data
        self._m = β1 * self._m + (1.0 - β1) * g
        self._v = β2 * self._v + (1.0 - β2) * np.square(g)
        # Bias-corrected moments.
        bias1 = 1.0 - β1 ** self.t
        bias2 = 1.0 - β2 ** self.t
        self.data -= lr * (self._m / bias1) / (np.sqrt(self._v / bias2) + eps)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# im2col helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _im2col1d(x: np.ndarray, k: int, pad: int) -> np.ndarray:
|
||
"""
|
||
x : (B, C, T)
|
||
returns: (B, C, T, k) — zero-padded windows
|
||
|
||
Uses np.lib.stride_tricks.as_strided for zero-copy window view,
|
||
then immediately makes a contiguous copy so writes to the returned
|
||
array never corrupt x_pad.
|
||
"""
|
||
B, C, T = x.shape
|
||
x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
|
||
# stride-trick view: step one element along T axis per k-position
|
||
shape = (B, C, T, k)
|
||
s0, s1, s2 = x_pad.strides
|
||
strides = (s0, s1, s2, s2)
|
||
windows = np.lib.stride_tricks.as_strided(x_pad, shape=shape, strides=strides)
|
||
return windows.copy() # contiguous copy — safe for in-place ops downstream
|
||
|
||
|
||
@_njit(cache=True)
def _col2im1d_kernel(g_pad, grad_win, T, k):
    """Numba-JIT accumulation loop for col2im.

    Scatter-adds each of the k window slots of grad_win (B, C, T, k) back
    into the padded gradient buffer g_pad (B, C, T + 2*pad), shifted by the
    slot offset i — the adjoint of the im2col window extraction.
    Mutates g_pad in place; returns nothing.
    """
    # Loop over the (small) kernel dimension; each iteration is a vectorized
    # slice-add, so the work per iteration is O(B*C*T).
    for i in range(k):
        g_pad[:, :, i:i + T] += grad_win[:, :, :, i]
|
||
|
||
|
||
def _col2im1d(grad_win: np.ndarray, T: int, k: int, pad: int) -> np.ndarray:
    """
    Adjoint of _im2col1d: fold (B, C, T, k) window gradients back to (B, C, T).

    Accumulates into a zero-initialized padded buffer via the JIT kernel,
    then crops the padding off.
    """
    B, C = grad_win.shape[:2]
    padded = np.zeros((B, C, T + 2 * pad))
    _col2im1d_kernel(padded, grad_win, T, k)
    return padded[:, :, pad:pad + T]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer: DWConv1d (depthwise conv via im2col)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class DWConv1d:
    """Depthwise 1-D convolution: one k-tap filter per channel, same-padded.

    Implemented via im2col so both forward and backward reduce to einsums.
    """

    def __init__(self, C: int, k: int, rng: np.random.RandomState):
        self.C = C
        self.k = k
        self.pad = k // 2
        self.w = Param(rng.randn(C, k) * 0.02)  # per-channel filter taps (C, k)
        self.b = Param(np.zeros(C))             # per-channel bias
        self._cache: tuple | None = None        # im2col windows kept for backward

    def params(self) -> list[Param]:
        return [self.w, self.b]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T)."""
        windows = _im2col1d(x, self.k, self.pad)  # (B, C, T, k)
        self._cache = (windows,)
        # einsum with optimize=True routes through BLAS where possible.
        y = np.einsum('bctk,ck->bct', windows, self.w.data, optimize=True)
        y += self.b.data[None, :, None]
        return y

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """grad_out: (B, C, T) → grad wrt x; accumulates w/b gradients."""
        (windows,) = self._cache
        T = grad_out.shape[2]
        # Weight gradient: contract over batch and time.
        self.w.grad += np.einsum('bct,bctk->ck', grad_out, windows, optimize=True)
        # Bias gradient: plain sum over batch and time.
        self.b.grad += grad_out.sum(axis=(0, 2))
        # Input gradient: expand to window layout, then col2im scatter-add.
        g_win = np.einsum('bct,ck->bctk', grad_out, self.w.data, optimize=True)
        return _col2im1d(g_win, T, self.k, self.pad)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer: LayerNorm (over last axis, applied to (B, T, C) or (B, C))
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class LayerNorm:
    """Layer normalization over the trailing axis, with learned scale/shift."""

    EPS = 1e-6

    def __init__(self, D: int, rng: np.random.RandomState | None = None):
        self.D = D
        self.gamma = Param(np.ones(D))   # learned scale
        self.beta = Param(np.zeros(D))   # learned shift
        self._cache: tuple | None = None

    def params(self) -> list[Param]:
        return [self.gamma, self.beta]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Normalize the last axis of x, then apply gamma * x_hat + beta."""
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        x_hat = (x - mu) / np.sqrt(var + self.EPS)
        self._cache = (x_hat, var)
        return self.gamma.data * x_hat + self.beta.data

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Standard LN backward pass; accumulates gamma/beta gradients."""
        x_hat, var = self._cache
        D = x_hat.shape[-1]
        out_shape = dy.shape

        # Flatten every leading axis so the math is plain 2-D.
        dy_f = dy.reshape(-1, D)
        xh_f = x_hat.reshape(-1, D)
        inv_std = 1.0 / np.sqrt(var.reshape(-1, 1) + self.EPS)

        self.gamma.grad += np.sum(dy_f * xh_f, axis=0)
        self.beta.grad += np.sum(dy_f, axis=0)

        dxh = dy_f * self.gamma.data
        mean_dxh = dxh.mean(axis=-1, keepdims=True)
        mean_dxh_xh = (dxh * xh_f).mean(axis=-1, keepdims=True)
        dx = inv_std * (dxh - mean_dxh - xh_f * mean_dxh_xh)
        return dx.reshape(out_shape)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer: Linear
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class Linear:
    """Dense layer y = x @ W (+ b), applied over the trailing feature axis."""

    def __init__(self, in_f: int, out_f: int, rng: np.random.RandomState,
                 bias: bool = True):
        self.W = Param(_kaiming(in_f, out_f, rng))
        self.b = Param(np.zeros(out_f)) if bias else None
        self._x_cache: np.ndarray | None = None  # input saved for backward

    def params(self) -> list[Param]:
        ps = [self.W]
        if self.b is not None:
            ps.append(self.b)
        return ps

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._x_cache = x
        y = x @ self.W.data
        return y if self.b is None else y + self.b.data

    def backward(self, dy: np.ndarray) -> np.ndarray:
        """Accumulate W/b grads and return grad wrt the cached input."""
        x = self._x_cache
        in_f, out_f = self.W.data.shape
        # Collapse all leading axes so grads are simple matmuls.
        x_flat = x.reshape(-1, in_f)
        dy_flat = dy.reshape(-1, out_f)
        self.W.grad += x_flat.T @ dy_flat
        if self.b is not None:
            self.b.grad += dy_flat.sum(axis=0)
        return (dy_flat @ self.W.data.T).reshape(x.shape)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Block: ConvNeXtBlock1D
|
||
# ---------------------------------------------------------------------------
|
||
# Input/output: (B, C, T)
|
||
# 1. DWConv1d(C, k)
|
||
# 2. Permute → (B, T, C)
|
||
# 3. LayerNorm(C)
|
||
# 4. Linear(C → 4C) + GELU
|
||
# 5. Linear(4C → C)
|
||
# 6. Permute → (B, C, T)
|
||
# 7. Skip
|
||
|
||
class ConvNeXtBlock1D:
    """ConvNeXt block for 1-D data: DWConv → LN → MLP(4×, GELU) → residual.

    Operates on (B, C, T); LayerNorm and the MLP run channel-last internally.
    """

    def __init__(self, C: int, k: int, rng: np.random.RandomState):
        self.dwconv = DWConv1d(C, k, rng)
        self.ln = LayerNorm(C)
        self.fc1 = Linear(C, 4 * C, rng)   # expand
        self.fc2 = Linear(4 * C, C, rng)   # contract
        self._cache: dict | None = None

    def params(self) -> list[Param]:
        ps = []
        for layer in (self.dwconv, self.ln, self.fc1, self.fc2):
            ps.extend(layer.params())
        return ps

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T), with a residual connection."""
        y = self.dwconv.forward(x).transpose(0, 2, 1)  # (B, T, C)
        y = self.ln.forward(y)
        pre_act = self.fc1.forward(y)                  # (B, T, 4C), pre-GELU
        y = self.fc2.forward(_gelu(pre_act))           # (B, T, C)
        self._pre_gelu = pre_act                       # cached for backward
        return x + y.transpose(0, 2, 1)                # back to (B, C, T) + skip

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """grad_out: (B, C, T) → grad_x: (B, C, T).

        The skip connection passes grad_out through unchanged AND routes it
        through the block layers in reverse.
        """
        g = grad_out.transpose(0, 2, 1)      # undo final permute → (B, T, C)
        g = self.fc2.backward(g)             # → grad wrt GELU output
        g = g * _gelu_grad(self._pre_gelu)   # GELU backward
        g = self.fc1.backward(g)
        g = self.ln.backward(g)
        g = self.dwconv.backward(g.transpose(0, 2, 1))  # back to (B, C, T)
        return grad_out + g                  # add the skip path
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pool: AdaptiveAvgPool1D (fixed stride, e.g. T=32→T=16 stride=2)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class AvgPool1D:
    """Non-overlapping average pooling along the temporal axis."""

    def __init__(self, stride: int = 2):
        self.stride = stride
        self._T_in: int | None = None  # remembered so backward can re-pad

    def params(self) -> list[Param]:
        return []

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C, T) → (B, C, T // stride); any trailing remainder is dropped."""
        B, C, T = x.shape
        self._T_in = T
        s = self.stride
        T_out = T // s
        trimmed = x[:, :, :T_out * s]
        return trimmed.reshape(B, C, T_out, s).mean(axis=-1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Spread each output gradient uniformly (1/stride) over its window."""
        B, C, T_out = grad_out.shape
        s = self.stride
        g = np.repeat(grad_out / s, s, axis=-1)  # (B, C, T_out * s)
        tail = self._T_in - T_out * s
        if tail > 0:
            # Positions dropped in forward receive zero gradient.
            g = np.pad(g, ((0, 0), (0, 0), (0, tail)))
        return g
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Upsample: nearest-neighbor ×2 along temporal axis
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class Upsample1D:
    """Nearest-neighbour temporal upsampling by an integer factor."""

    def __init__(self, factor: int = 2):
        self.factor = factor

    def params(self) -> list[Param]:
        return []

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Repeat every timestep `factor` times along the last axis."""
        return np.repeat(x, self.factor, axis=-1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Adjoint of repeat: sum gradients within each factor-sized group."""
        B, C, T = grad_out.shape
        grouped = grad_out.reshape(B, C, T // self.factor, self.factor)
        return grouped.sum(axis=-1)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# PointwiseProj: 1×1 conv (= Linear applied per timestep)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class PointwiseProj:
    """1×1 convolution: one Linear(C_in → C_out) shared across all timesteps."""

    def __init__(self, C_in: int, C_out: int, rng: np.random.RandomState):
        self.lin = Linear(C_in, C_out, rng)

    def params(self) -> list[Param]:
        return self.lin.params()

    def forward(self, x: np.ndarray) -> np.ndarray:
        """x: (B, C_in, T) → (B, C_out, T)."""
        channels_last = x.transpose(0, 2, 1)         # (B, T, C_in)
        projected = self.lin.forward(channels_last)  # (B, T, C_out)
        return projected.transpose(0, 2, 1)

    def backward(self, grad_out: np.ndarray) -> np.ndarray:
        """Mirror of forward: permute, Linear backward, permute back."""
        g = self.lin.backward(grad_out.transpose(0, 2, 1))  # (B, T, C_in)
        return g.transpose(0, 2, 1)                         # (B, C_in, T)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main model: ConvNeXtVAE
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class ConvNeXtVAE:
    """
    ConvNeXt-1D β-TCVAE.

    Symmetric encoder/decoder built from ConvNeXtBlock1D stages; the encoder
    produces a diagonal-Gaussian posterior (z_mu, z_logvar), the decoder maps
    a latent sample back to (B, C_in, T_in). Backward passes are hand-written
    and accumulate gradients into each layer's Param objects.

    Parameters
    ----------
    C_in : input channels (default 8)
    T_in : input timesteps (default 32)
    z_dim : latent dimensionality (default 32)
    base_ch : stem channel width (default 32)
    dw_k : depthwise kernel size (default 7)
    n_blocks : ConvNeXt blocks per stage (default 3)
    seed : RNG seed
    """

    def __init__(
        self,
        C_in: int = 8,
        T_in: int = 32,
        z_dim: int = 32,
        base_ch: int = 32,
        dw_k: int = 7,
        n_blocks: int = 3,
        seed: int = 42,
    ):
        # One RandomState drives every layer init, so `seed` fully determines
        # the starting weights.
        rng = np.random.RandomState(seed)
        self.C_in = C_in
        self.T_in = T_in
        self.z_dim = z_dim
        self.base_ch = base_ch
        ch1 = base_ch * 2  # 64

        # ---- ENCODER ----
        # Stem
        self.stem_proj = PointwiseProj(C_in, base_ch, rng)
        self.stem_ln = LayerNorm(base_ch)

        # Stage 0
        self.stage0 = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]

        # Pool + channel proj
        self.pool = AvgPool1D(stride=2)
        self.pool_ln = LayerNorm(base_ch)
        self.pool_proj = PointwiseProj(base_ch, ch1, rng)

        # Stage 1
        self.stage1 = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]

        # Global avg pool → z_mu, z_logvar
        self.head_mu = Linear(ch1, z_dim, rng)
        self.head_logvar = Linear(ch1, z_dim, rng)

        # ---- DECODER ----
        T_bot = T_in // 2  # bottleneck temporal size (16 for T_in=32)
        self.dec_linear = Linear(z_dim, ch1, rng)

        # Stage 1D
        self.stage1d = [ConvNeXtBlock1D(ch1, dw_k, rng) for _ in range(n_blocks)]

        # Upsample + channel proj
        self.up = Upsample1D(factor=2)
        self.up_proj = PointwiseProj(ch1, base_ch, rng)
        self.up_ln = LayerNorm(base_ch)

        # Stage 0D
        self.stage0d = [ConvNeXtBlock1D(base_ch, dw_k, rng) for _ in range(n_blocks)]

        # Output projection
        self.out_proj = PointwiseProj(base_ch, C_in, rng)

        # Cache for backward (intermediate activations stored by encode/decode)
        self._enc_cache: dict = {}
        self._dec_cache: dict = {}
        self._T_bot = T_bot

    # ------------------------------------------------------------------
    def all_params(self) -> list[Param]:
        """Flat parameter list; order must match _param_names() exactly."""
        ps = []
        ps += self.stem_proj.params() + self.stem_ln.params()
        for blk in self.stage0:
            ps += blk.params()
        ps += (self.pool_ln.params() + self.pool_proj.params())
        for blk in self.stage1:
            ps += blk.params()
        ps += self.head_mu.params() + self.head_logvar.params()
        ps += self.dec_linear.params()
        for blk in self.stage1d:
            ps += blk.params()
        ps += self.up_proj.params() + self.up_ln.params()
        for blk in self.stage0d:
            ps += blk.params()
        ps += self.out_proj.params()
        return ps

    def zero_grad(self):
        """Reset every parameter's accumulated gradient."""
        for p in self.all_params():
            p.zero_grad()

    def adam_step(self, lr: float, wd: float = 1e-4):
        """One Adam update on every parameter (with weight decay wd)."""
        for p in self.all_params():
            p.adam_step(lr, wd=wd)

    def n_params(self) -> int:
        """Total scalar parameter count."""
        return sum(p.data.size for p in self.all_params())

    # ------------------------------------------------------------------
    # ENCODER
    # ------------------------------------------------------------------
    def encode(self, x: np.ndarray):
        """
        x: (B, C_in, T_in)
        Returns z_mu (B,z), z_logvar (B,z)

        Side effect: stores intermediate activations in self._enc_cache for
        backward_encode().
        """
        c = self._enc_cache
        c['x'] = x

        # Stem
        h = self.stem_proj.forward(x)       # (B, base_ch, T)
        h = h.transpose(0, 2, 1)            # (B, T, base_ch)
        h = self.stem_ln.forward(h)         # (B, T, base_ch)
        h = h.transpose(0, 2, 1)            # (B, base_ch, T)
        c['after_stem'] = h

        # Stage 0
        for blk in self.stage0:
            h = blk.forward(h)
        c['after_stage0'] = h

        # Pool + proj
        h = self.pool.forward(h)            # (B, base_ch, T//2)
        h = h.transpose(0, 2, 1)            # (B, T//2, base_ch)
        h = self.pool_ln.forward(h)         # (B, T//2, base_ch)
        h = h.transpose(0, 2, 1)            # (B, base_ch, T//2)
        h = self.pool_proj.forward(h)       # (B, ch1, T//2)
        c['after_pool'] = h

        # Stage 1
        for blk in self.stage1:
            h = blk.forward(h)
        c['after_stage1'] = h

        # Global avg pool
        h_gap = h.mean(axis=-1)             # (B, ch1)
        c['gap'] = h_gap

        z_mu = self.head_mu.forward(h_gap)          # (B, z_dim)
        z_logvar = self.head_logvar.forward(h_gap)  # (B, z_dim)
        return z_mu, z_logvar

    # ------------------------------------------------------------------
    # DECODER
    # ------------------------------------------------------------------
    def decode(self, z: np.ndarray) -> np.ndarray:
        """
        z: (B, z_dim)
        Returns x_recon: (B, C_in, T_in)
        """
        c = self._dec_cache
        T_bot = self._T_bot

        h = self.dec_linear.forward(z)                # (B, ch1)
        # Broadcast the latent projection along the bottleneck time axis.
        h = h[:, :, None] * np.ones((1, 1, T_bot))    # broadcast (B, ch1, T_bot)
        c['dec_expand'] = h

        for blk in self.stage1d:
            h = blk.forward(h)
        c['after_stage1d'] = h

        # Upsample + proj
        h = self.up.forward(h)        # (B, ch1, T_in)
        h = self.up_proj.forward(h)   # (B, base_ch, T_in)
        h = h.transpose(0, 2, 1)      # (B, T_in, base_ch)
        h = self.up_ln.forward(h)     # (B, T_in, base_ch)
        h = h.transpose(0, 2, 1)      # (B, base_ch, T_in)
        c['after_up'] = h

        for blk in self.stage0d:
            h = blk.forward(h)

        x_recon = self.out_proj.forward(h)  # (B, C_in, T_in)
        return x_recon

    # ------------------------------------------------------------------
    # BACKWARD (encoder)
    # ------------------------------------------------------------------
    def backward_encode(self, dz_mu: np.ndarray, dz_logvar: np.ndarray):
        """
        Given grad w.r.t. z_mu and z_logvar, backprop through encoder.
        Returns grad w.r.t. x (not used further, just for completeness).

        Must be called after encode() — relies on self._enc_cache.
        """
        c = self._enc_cache

        # head
        d_gap = self.head_mu.backward(dz_mu) + self.head_logvar.backward(dz_logvar)

        # global avg pool backward: distribute evenly over T axis
        h_after = c['after_stage1']
        T_bot = h_after.shape[-1]
        d_h = d_gap[:, :, None] / T_bot * np.ones((1, 1, T_bot))  # (B, ch1, T_bot)

        # stage 1 backward
        for blk in reversed(self.stage1):
            d_h = blk.backward(d_h)

        # pool_proj backward
        d_h = self.pool_proj.backward(d_h)  # (B, base_ch, T//2)
        # pool_ln backward (channel-last, hence the permutes)
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.pool_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        # pool backward
        d_h = self.pool.backward(d_h)       # (B, base_ch, T)

        # stage 0 backward
        for blk in reversed(self.stage0):
            d_h = blk.backward(d_h)

        # stem backward
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.stem_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.stem_proj.backward(d_h)

        return d_h  # grad w.r.t. input x

    # ------------------------------------------------------------------
    # BACKWARD (decoder)
    # ------------------------------------------------------------------
    def backward_decode(self, d_recon: np.ndarray):
        """
        d_recon: (B, C_in, T_in) — grad w.r.t. x_recon
        Returns grad w.r.t. z.

        Must be called after decode() — relies on cached activations.
        """
        c = self._dec_cache
        T_bot = self._T_bot

        # out_proj backward
        d_h = self.out_proj.backward(d_recon)  # (B, base_ch, T_in)

        # stage 0d backward
        for blk in reversed(self.stage0d):
            d_h = blk.backward(d_h)

        # up_ln backward
        d_h = d_h.transpose(0, 2, 1)
        d_h = self.up_ln.backward(d_h)
        d_h = d_h.transpose(0, 2, 1)
        # up_proj backward
        d_h = self.up_proj.backward(d_h)  # (B, ch1, T_in)
        # upsample backward
        d_h = self.up.backward(d_h)       # (B, ch1, T_bot)

        # stage 1d backward
        for blk in reversed(self.stage1d):
            d_h = blk.backward(d_h)

        # expand backward: sum over temporal axis (it was broadcast)
        d_h_expand = d_h.sum(axis=-1)  # (B, ch1)

        # dec_linear backward
        dz = self.dec_linear.backward(d_h_expand)  # (B, z_dim)
        return dz

    # ------------------------------------------------------------------
    # FORWARD convenience
    # ------------------------------------------------------------------
    def forward(self, x: np.ndarray):
        """Full forward: encode + reparameterize + decode.

        NOTE(review): eps is drawn from numpy's global RNG (not the seeded
        RandomState) and is not cached; a caller wanting exact gradients via
        btcvae_loss_backward presumably has to supply its own eps — confirm.
        """
        z_mu, z_logvar = self.encode(x)
        eps = np.random.randn(*z_mu.shape)
        z = z_mu + eps * np.exp(0.5 * z_logvar)
        x_recon = self.decode(z)
        return x_recon, z_mu, z_logvar, z

    # ------------------------------------------------------------------
    # SAVE / LOAD
    # ------------------------------------------------------------------
    def save(self, path: str, norm_mean=None, norm_std=None,
             extra: dict | None = None, save_adam: bool = False):
        """Serialize all weights (and optionally Adam state) to JSON at `path`.

        NOTE(review): assumes norm_std is provided whenever norm_mean is —
        norm_std.tolist() would raise if it were None. Confirm callers.
        """
        data: dict = {}
        params = self.all_params()
        names = self._param_names()
        for name, p in zip(names, params):
            data[name] = p.data.tolist()
            if save_adam:
                data[f'{name}__m'] = p._m.tolist()
                data[f'{name}__v'] = p._v.tolist()
                data[f'{name}__t'] = int(p.t)
        if norm_mean is not None:
            data['norm_mean'] = norm_mean.tolist()
            data['norm_std'] = norm_std.tolist()
        data['architecture'] = {
            'C_in': self.C_in, 'T_in': self.T_in, 'z_dim': self.z_dim,
            'base_ch': self.base_ch,
        }
        if extra:
            data.update(extra)
        with open(path, 'w') as f:
            json.dump(data, f)

    def load(self, path: str) -> dict:
        """Load weights saved by save(); returns any non-parameter extras.

        Extras include e.g. 'norm_mean'/'norm_std' when present, since only
        parameter names (and their Adam keys) and 'architecture' are filtered.
        """
        with open(path) as f:
            data = json.load(f)
        params = self.all_params()
        names = self._param_names()
        reserved = set(names)
        for name, p in zip(names, params):
            if name in data:
                p.data[:] = np.array(data[name], dtype=np.float64)
            if f'{name}__m' in data:
                p._m[:] = np.array(data[f'{name}__m'], dtype=np.float64)
                p._v[:] = np.array(data[f'{name}__v'], dtype=np.float64)
                p.t = int(data[f'{name}__t'])
                reserved.update({f'{name}__m', f'{name}__v', f'{name}__t'})
        extras = {k: v for k, v in data.items()
                  if k not in reserved and k != 'architecture'}
        return extras

    def _param_names(self) -> list[str]:
        """Return stable flat list of parameter name strings matching all_params()."""
        names = []
        # stem
        for i, _ in enumerate(self.stem_proj.params()):
            names.append(f'stem_proj_p{i}')
        for i, _ in enumerate(self.stem_ln.params()):
            names.append(f'stem_ln_p{i}')
        # stage0
        for bi, blk in enumerate(self.stage0):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage0_b{bi}_p{i}')
        # pool
        for i, _ in enumerate(self.pool_ln.params()):
            names.append(f'pool_ln_p{i}')
        for i, _ in enumerate(self.pool_proj.params()):
            names.append(f'pool_proj_p{i}')
        # stage1
        for bi, blk in enumerate(self.stage1):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage1_b{bi}_p{i}')
        # head
        for i, _ in enumerate(self.head_mu.params()):
            names.append(f'head_mu_p{i}')
        for i, _ in enumerate(self.head_logvar.params()):
            names.append(f'head_logvar_p{i}')
        # dec_linear
        for i, _ in enumerate(self.dec_linear.params()):
            names.append(f'dec_linear_p{i}')
        # stage1d
        for bi, blk in enumerate(self.stage1d):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage1d_b{bi}_p{i}')
        # up
        for i, _ in enumerate(self.up_proj.params()):
            names.append(f'up_proj_p{i}')
        for i, _ in enumerate(self.up_ln.params()):
            names.append(f'up_ln_p{i}')
        # stage0d
        for bi, blk in enumerate(self.stage0d):
            for i, _ in enumerate(blk.params()):
                names.append(f'stage0d_b{bi}_p{i}')
        # out_proj
        for i, _ in enumerate(self.out_proj.params()):
            names.append(f'out_proj_p{i}')
        return names
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# β-TCVAE loss
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def btcvae_loss(
    x: np.ndarray,
    x_recon: np.ndarray,
    z_mu: np.ndarray,
    z_logvar: np.ndarray,
    z: np.ndarray,
    beta_tc: float = 4.0,
    alpha_mi: float = 1.0,
    free_bits: float = 0.0,
) -> tuple[float, dict]:
    """
    β-TCVAE loss with the minibatch Total-Correlation estimator.

    Parameters
    ----------
    x : (B, C, T) — input
    x_recon : (B, C, T) — reconstruction
    z_mu : (B, D)
    z_logvar : (B, D)
    z : (B, D) — reparameterized sample
    beta_tc : weight on the Total Correlation term
    alpha_mi : weight on the Mutual Information term
    free_bits : minimum KL per latent dimension (nats). When > 0, the
                minibatch dimKL term is replaced with the analytical per-dim
                KL clamped from below at free_bits (guards against posterior
                collapse). 0 = standard behaviour.

    Returns
    -------
    loss : scalar
    info : dict with breakdown (recon, MI, TC, dimKL)
    """
    B, D = z.shape
    log2pi = np.log(2 * np.pi)

    # Reconstruction: per-sample mean squared error, then batch mean.
    recon = np.mean((x - x_recon) ** 2, axis=(-1, -2)).mean()

    # log q(z|x) per sample (diagonal Gaussian): (B,)
    log_qz_x = -0.5 * np.sum(
        log2pi + z_logvar + np.exp(-z_logvar) * (z - z_mu) ** 2, axis=1)

    # log p(z) under the standard-normal prior: (B,)
    log_pz = -0.5 * np.sum(log2pi + z ** 2, axis=1)

    # Pairwise component densities log q(z_i | x_j): (B, B, D)
    diff = z[:, None, :] - z_mu[None, :, :]
    lv = z_logvar[None, :, :]
    log_q_ij_d = -0.5 * (log2pi + lv + diff ** 2 * np.exp(-lv))

    # log q(z_i) = logsumexp_j[ sum_d log_q_ij_d ] - log B  (stable, inlined)
    joint = log_q_ij_d.sum(-1)                                   # (B, B)
    jm = joint.max(axis=1, keepdims=True)
    log_qz = np.log(np.exp(joint - jm).sum(axis=1)) + jm[:, 0] - np.log(B)

    # log prod_d q(z_d[i]) = sum_d logsumexp_j[log_q_ij_d] - D*log(B)
    dm = log_q_ij_d.max(axis=1, keepdims=True)                   # (B, 1, D)
    lse_d = np.log(np.exp(log_q_ij_d - dm).sum(axis=1)) + dm[:, 0, :]
    log_qz_prod = (lse_d - np.log(B)).sum(-1)                    # (B,)

    MI = float(np.mean(log_qz_x - log_qz))
    TC = float(np.mean(log_qz - log_qz_prod))

    if free_bits > 0.0:
        # Analytical per-dim KL, clamped: sum_d max(free_bits, KL_d)
        kl_d = 0.5 * (z_mu ** 2 + np.exp(z_logvar) - 1.0 - z_logvar)
        dimKL = float(np.maximum(free_bits, kl_d.mean(0)).sum())
    else:
        dimKL = float(np.mean(log_qz_prod - log_pz))

    loss = recon + alpha_mi * MI + beta_tc * TC + dimKL

    return float(loss), {
        'recon': float(recon),
        'MI': MI,
        'TC': TC,
        'dimKL': dimKL,
    }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# β-TCVAE backward
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def btcvae_loss_backward(
    x: np.ndarray,
    x_recon: np.ndarray,
    z_mu: np.ndarray,
    z_logvar: np.ndarray,
    z: np.ndarray,
    eps: np.ndarray,
    beta_tc: float = 4.0,
    alpha_mi: float = 1.0,
    free_bits: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Gradients of the β-TCVAE loss w.r.t. x_recon, z_mu, z_logvar.

    Loss (matching the forward pass):
        L = MSE(x, x_recon) + alpha_mi*MI + beta_tc*TC + dimKL
    where MI/TC use the minibatch estimator of q(z) and prod_d q(z_d).

    Two modes:
    * free_bits == 0 : full β-TCVAE — all minibatch MI/TC terms are
      backpropagated, both through z (reparameterization) and directly
      through the (z_mu, z_logvar) that parameterize each q(z|x_j).
    * free_bits > 0  : plain-VAE-with-free-bits — the aggregate-posterior
      MI/TC gradients are dropped entirely (prevents posterior collapse);
      only the direct encoder term and the per-dim analytical KL, masked
      at the free-bits floor, contribute.

    Parameters
    ----------
    z   : (B, D) sampled latents
    eps : (B, D) reparameterization noise, z = z_mu + eps * exp(0.5*z_logvar)

    Returns
    -------
    d_recon   : (B, C, T) grad w.r.t. x_recon
    d_z_mu    : (B, D)    grad w.r.t. z_mu
    d_z_logvar: (B, D)    grad w.r.t. z_logvar
    """
    B, D = z.shape
    sigma2 = np.exp(z_logvar)        # per-sample variance (B, D)
    sigma = np.exp(0.5 * z_logvar)   # per-sample std-dev  (B, D)

    # ---- reconstruction: L_recon = mean_{b,c,t}(x - x_recon)^2 ----
    # dL/dx_recon = -2 (x - x_recon) / numel
    n_elements = x.size
    d_recon = -2.0 * (x - x_recon) / n_elements  # (B, C, T)

    # ---- shared minibatch density pieces ----
    resid = (z - z_mu)                     # (B, D) residual of each sample
    delta = z[:, None, :] - z_mu[None, :, :]   # (B, B, D) z_i - mu_j
    lv_cols = z_logvar[None, :, :]             # (1, B, D) logvar of column j
    sigma2_cols = np.exp(lv_cols)              # (1, B, D) broadcasts to (B,B,D)
    # per-dim log q(z_d[i] | x_j)
    log_q_pair = -0.5 * (np.log(2 * np.pi) + lv_cols + delta ** 2 / sigma2_cols)

    # Softmax over j of the joint log density: post_w[i, j] ∝ q(z_i | x_j).
    # These are exactly d log_qz[i] / d log_q(z_i|x_j) by the logsumexp rule.
    joint = log_q_pair.sum(-1)                       # (B, B)
    post_w = np.exp(joint - joint.max(1, keepdims=True))
    post_w /= post_w.sum(1, keepdims=True) + 1e-30   # (B, B)

    # Per-dimension softmax: dim_w[i, j, d] ∝ q(z_d[i] | x_j).
    dim_w = np.exp(log_q_pair - log_q_pair.max(1, keepdims=True))
    dim_w /= dim_w.sum(1, keepdims=True) + 1e-30     # (B, B, D)

    # ---- z-path gradients of each log term ----
    # d log q(z|x)[i]    / d z[i,d] = -resid/var   (Gaussian score in z)
    # d log q(z)[i]      / d z[i,d] = -sum_j w[i,j] * delta[i,j,d]/var_j
    # d log Πq(z_d)[i]   / d z[i,d] = per-dim weighted version of the above
    # d log p(z)[i]      / d z[i,d] = -z[i,d]      (standard-normal prior)
    g_qzx_z = -resid / sigma2                                    # (B, D)
    g_qz_z = -(post_w[:, :, None] * delta / sigma2_cols).sum(1)  # (B, D)
    g_prod_z = -(dim_w * delta / sigma2_cols).sum(1)             # (B, D)
    g_pz_z = -z                                                  # (B, D)

    plain_vae = free_bits > 0.0

    if plain_vae:
        # Free-bits mode: no MI/TC reparameterization gradient at all.
        dL_dz = np.zeros((B, D), dtype=np.float64)
    else:
        # Coefficient of each log term across MI + beta*TC + dimKL:
        #   alpha*(log_qz_x - log_qz) + beta*(log_qz - log_qz_prod)
        #   + (log_qz_prod - log_pz), averaged over the batch.
        dL_dz = (1.0 / B) * (
            alpha_mi * g_qzx_z
            - alpha_mi * g_qz_z
            + beta_tc * g_qz_z
            - beta_tc * g_prod_z
            + g_prod_z
            - g_pz_z
        )  # (B, D)

    # ---- reparameterization: z = mu + eps*sigma ----
    # dz/dmu = 1, dz/dlogvar = 0.5 * eps * sigma
    d_z_mu = dL_dz.copy()
    d_z_logvar = dL_dz * 0.5 * eps * sigma

    # Direct encoder term from log q(z|x) (MI numerator), present in both modes:
    #   d log q(z|x) / d mu     = +resid/var
    #   d log q(z|x) / d logvar = -0.5 + 0.5*resid^2/var
    d_z_mu += (alpha_mi / B) * resid / sigma2
    d_z_logvar += (alpha_mi / B) * (-0.5 + 0.5 * resid ** 2 / sigma2)

    if plain_vae:
        # Analytical free-bits dimKL gradient. Only dims whose batch-mean KL
        # sits at or above the floor are penalised (subgradient of the max).
        kl_dim = (0.5 * (z_mu ** 2 + sigma2 - 1.0 - z_logvar)).mean(0)  # (D,)
        active = (kl_dim >= free_bits).astype(np.float64)               # (D,)
        d_z_mu += active[None, :] * z_mu / B
        d_z_logvar += active[None, :] * 0.5 * (sigma2 - 1.0) / B
    else:
        # Direct gradients through the (mu_j, logvar_j) parameterizing each
        # mixture component of the aggregate posterior. Note the POSITIVE
        # sign: d log q(z_i|x_j) / d mu_j = +delta[i,j]/var_j; the i-axis is
        # summed because every row's log_qz touches every column j.
        g_qz_mu = (post_w[:, :, None] * delta / sigma2_cols).sum(0)   # (B, D)
        g_prod_mu = (dim_w * delta / sigma2_cols).sum(0)              # (B, D)
        d_z_mu += (1.0 / B) * (
            -alpha_mi * g_qz_mu
            + beta_tc * g_qz_mu
            - beta_tc * g_prod_mu
            + g_prod_mu
        )

        # d log q(z_i|x_j) / d logvar_j = -0.5 + 0.5*delta^2/var_j
        g_qz_lv = (
            post_w[:, :, None] * (-0.5 + 0.5 * delta ** 2 / sigma2_cols)
        ).sum(0)  # (B, D)
        g_prod_lv = (
            dim_w * (-0.5 + 0.5 * delta ** 2 / sigma2_cols)
        ).sum(0)  # (B, D)
        d_z_logvar += (1.0 / B) * (
            -alpha_mi * g_qz_lv
            + beta_tc * g_qz_lv
            - beta_tc * g_prod_lv
            + g_prod_lv
        )

    return d_recon, d_z_mu, d_z_logvar
|