333 lines
11 KiB
Python
333 lines
11 KiB
Python
|
|
"""
|
||
|
|
Unit tests for ConvNeXt-1D β-TCVAE (pure numpy).
|
||
|
|
|
||
|
|
Tests
|
||
|
|
-----
|
||
|
|
T1 DWConv1d forward shape correct
|
||
|
|
T2 DWConv1d backward numerical gradient check
|
||
|
|
T3 LayerNorm forward/backward numerical gradient check
|
||
|
|
T4 ConvNeXtBlock1D output shape == input shape (skip preserved)
|
||
|
|
T5 ConvNeXtBlock1D backward numerical gradient check
|
||
|
|
T6 ConvNeXtVAE forward: output shapes correct
|
||
|
|
T7 ConvNeXtVAE forward/backward: loss is finite, grads finite
|
||
|
|
T8 Loss decreases over 20 steps on small batch (gradient descent works)
|
||
|
|
T9 Save/load round-trip: weights preserved
|
||
|
|
T10 β-TCVAE loss backward: numerical gradient check on z_mu/z_logvar
|
||
|
|
|
||
|
|
Run: python test_convnext_dvae.py
|
||
|
|
"""
|
||
|
|
|
||
|
|
import sys
|
||
|
|
import os
|
||
|
|
import tempfile
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
# ensure local import
|
||
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
||
|
|
from convnext_dvae import (
|
||
|
|
ConvNeXtVAE, DWConv1d, LayerNorm, ConvNeXtBlock1D,
|
||
|
|
btcvae_loss, btcvae_loss_backward
|
||
|
|
)
|
||
|
|
|
||
|
|
# Shared deterministic RNG for the whole test module, plus accumulators
# that collect the names of passed/failed checks for the final summary.
RNG = np.random.RandomState(0)
PASS = []
FAIL = []


def check(name: str, cond: bool, detail: str = ''):
    """Record one assertion result.

    Prints a one-line PASS/FAIL verdict and appends *name* to the matching
    global bucket; *detail* is shown only on failure.
    """
    bucket = PASS if cond else FAIL
    print(f' PASS {name}' if cond else f' FAIL {name} {detail}')
    bucket.append(name)
|
||
|
|
|
||
|
|
|
||
|
|
def num_grad(f, x, eps=1e-5):
    """Numeric gradient of scalar f at x.

    Central differences, perturbing one element of x at a time and
    restoring it afterwards (x is left unchanged on return).
    """
    grad = np.zeros_like(x)
    # np.ndindex walks indices in C order, matching nditer's multi_index order.
    for idx in np.ndindex(x.shape):
        saved = x[idx]
        x[idx] = saved + eps
        f_plus = f(x)
        x[idx] = saved - eps
        f_minus = f(x)
        x[idx] = saved
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
def test_t1_dwconv_shape():
    """T1: depthwise conv output keeps the (B, C, T) input shape."""
    print('\nT1: DWConv1d forward shape')
    conv = DWConv1d(8, 7, RNG)
    inp = RNG.randn(4, 8, 32)
    out = conv.forward(inp)
    check('output_shape', out.shape == (4, 8, 32), str(out.shape))
|
||
|
|
|
||
|
|
|
||
|
|
def test_t2_dwconv_grad():
    """T2: analytic DWConv1d weight and input grads match central differences."""
    print('\nT2: DWConv1d backward numerical check')
    rng2 = np.random.RandomState(1)
    layer = DWConv1d(4, 3, rng2)
    x = rng2.randn(2, 4, 8).astype(np.float64) * 0.5

    def fwd(w_flat):
        # Load flattened weights into the layer, then return sum(output) so
        # d(sum)/dw equals the gradient produced by backward(ones).
        layer.w.data[:] = w_flat.reshape(layer.w.data.shape)
        out = layer.forward(x)
        return float(out.sum())

    # Analytic gradients at the unperturbed weights.
    layer.forward(x)
    layer.backward(np.ones((2, 4, 8)))

    w0 = layer.w.data.copy()
    ng = num_grad(lambda w: fwd(w.ravel()), w0.ravel())
    # BUGFIX: num_grad restores its own buffer but never re-invokes fwd, so
    # the layer was left holding the last perturbed weight value. Restore w0
    # so the dx check below runs at the same point as the analytic gradient.
    layer.w.data[:] = w0
    ag = layer.w.grad.ravel()
    rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8)
    check('dw_max_rel_err', rel.max() < 0.01, f'max_rel={rel.max():.4f}')

    # also check dx
    def fwd_x(xf):
        return float(layer.forward(xf.reshape(2, 4, 8)).sum())

    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    # Fresh forward so backward() differentiates through an up-to-date cache.
    layer.forward(x)
    gx = layer.backward(np.ones((2, 4, 8))).ravel()
    rel_x = np.abs(gx - ng_x) / (np.abs(ng_x) + 1e-8)
    check('dx_max_rel_err', rel_x.max() < 0.01, f'max_rel_x={rel_x.max():.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
def test_t3_layernorm_grad():
    """T3: analytic LayerNorm gamma and input grads match central differences."""
    print('\nT3: LayerNorm backward numerical check')
    rng2 = np.random.RandomState(2)
    layer = LayerNorm(8)
    x = rng2.randn(4, 12, 8).astype(np.float64)

    def fwd_gamma(gf):
        # Install candidate gamma and return sum(output), so d(sum)/dgamma
        # equals the gradient from backward(ones).
        layer.gamma.data[:] = gf
        return float(layer.forward(x).sum())

    # Analytic gamma gradient at the unperturbed parameters.
    layer.forward(x)
    layer.backward(np.ones((4, 12, 8)))
    ag = layer.gamma.grad.copy()
    g0 = layer.gamma.data.copy()
    ng = num_grad(lambda gf: fwd_gamma(gf), g0.copy())
    # BUGFIX: num_grad restores its own buffer but never re-invokes fwd_gamma,
    # so the layer was left holding the last perturbed gamma value. Restore it.
    layer.gamma.data[:] = g0
    rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8)
    check('dgamma_max_rel_err', rel.max() < 0.01, f'{rel.max():.4f}')

    def fwd_x(xf):
        return float(layer.forward(xf.reshape(4, 12, 8)).sum())

    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    # BUGFIX: the original called backward() against whatever cache remained
    # from num_grad's last internal (perturbed-gamma) forward. Rerun forward
    # with the restored parameters before differentiating.
    layer.forward(x)
    dx = layer.backward(np.ones((4, 12, 8))).ravel()
    rel_x = np.abs(dx - ng_x) / (np.abs(ng_x) + 1e-8)
    check('dx_max_rel_err', rel_x.max() < 0.01, f'{rel_x.max():.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
def test_t4_block_shape():
    """T4: a ConvNeXt block is shape-preserving but not the identity."""
    print('\nT4: ConvNeXtBlock1D shape preserved')
    rng2 = np.random.RandomState(3)
    block = ConvNeXtBlock1D(16, 7, rng2)
    inp = rng2.randn(4, 16, 32)
    out = block.forward(inp)
    check('shape_unchanged', out.shape == inp.shape, str(out.shape))
    # the residual branch must add a nonzero transform on top of the skip
    check('skip_active', not np.allclose(out, inp))
|
||
|
|
|
||
|
|
|
||
|
|
def test_t5_block_grad():
    """T5: ConvNeXtBlock1D input gradient matches central differences."""
    print('\nT5: ConvNeXtBlock1D backward numerical check (small block)')
    rng2 = np.random.RandomState(4)
    block = ConvNeXtBlock1D(4, 3, rng2)
    x = rng2.randn(2, 4, 8).astype(np.float64) * 0.3

    def scalar_out(flat):
        # sum(output) makes backward(ones) the matching analytic gradient
        return float(block.forward(flat.reshape(2, 4, 8)).sum())

    block.forward(x)
    analytic = block.backward(np.ones((2, 4, 8))).ravel()
    numeric = num_grad(scalar_out, x.ravel().copy())
    rel = np.abs(analytic - numeric) / (np.abs(numeric) + 1e-8)
    check('dx_max_rel_err', rel.max() < 0.02, f'{rel.max():.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
def test_t6_vae_shapes():
    """T6: full VAE forward produces correctly shaped, finite outputs."""
    print('\nT6: ConvNeXtVAE forward shapes')
    rng2 = np.random.RandomState(5)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=5)
    batch = rng2.randn(4, 8, 32)
    x_recon, z_mu, z_logvar, z = model.forward(batch)
    # shape checks
    for label, ok, detail in (
        ('x_recon_shape', x_recon.shape == (4, 8, 32), str(x_recon.shape)),
        ('z_mu_shape', z_mu.shape == (4, 16), str(z_mu.shape)),
        ('z_logvar_shape', z_logvar.shape == (4, 16), str(z_logvar.shape)),
    ):
        check(label, ok, detail)
    # finiteness checks
    check('x_recon_finite', np.all(np.isfinite(x_recon)))
    check('z_mu_finite', np.all(np.isfinite(z_mu)))
|
||
|
|
|
||
|
|
|
||
|
|
def test_t7_vae_backward():
    """T7: one full forward/backward pass yields a finite loss and finite,
    non-trivial gradients on every parameter."""
    print('\nT7: ConvNeXtVAE backward: grads finite, loss finite')
    rng2 = np.random.RandomState(6)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=6)
    x = rng2.randn(8, 8, 32)
    eps_noise = rng2.randn(8, 16)

    # Forward with an explicitly stored reparameterization noise so the
    # loss backward can reuse the exact same sample.
    z_mu, z_logvar = model.encode(x)
    z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
    x_recon = model.decode(z)

    loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=2.0)
    check('loss_finite', np.isfinite(loss), f'loss={loss:.4f}')
    for k, v in info.items():
        check(f'{k}_finite', np.isfinite(v), f'{k}={v:.4f}')

    d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
        x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=2.0
    )

    # Backprop: decoder first, then fold dL/dz into the encoder's mu path.
    model.zero_grad()
    dz = model.backward_decode(d_recon)
    model.backward_encode(d_z_mu + dz, d_z_logvar)

    grads = [p.grad for p in model.all_params()]
    check('grads_no_nan', not any(np.isnan(g).any() for g in grads))
    check('grads_no_inf', not any(np.isinf(g).any() for g in grads))
    check('grads_not_all_zero', not all(np.all(g == 0) for g in grads))
|
||
|
|
|
||
|
|
|
||
|
|
def test_t8_loss_decreases():
    """T8: Adam on the full model reduces the loss within 20 steps."""
    print('\nT8: Loss decreases over 20 gradient steps')
    rng2 = np.random.RandomState(7)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=7)
    x = rng2.randn(16, 8, 32).astype(np.float64)

    lr = 1e-3
    losses = []
    for step in range(20):
        # deterministic per-step reparameterization noise
        eps_noise = np.random.RandomState(step).randn(16, 16)

        # forward
        z_mu, z_logvar = model.encode(x)
        z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
        x_recon = model.decode(z)
        loss, _ = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=1.0, alpha_mi=1.0)
        losses.append(loss)

        # backward + update
        d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
            x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=1.0, alpha_mi=1.0
        )
        model.zero_grad()
        dz = model.backward_decode(d_recon)
        model.backward_encode(d_z_mu + dz, d_z_logvar)
        model.adam_step(lr)

    first5 = np.mean(losses[:5])
    last5 = np.mean(losses[-5:])
    print(f' loss[0:5]={first5:.4f} loss[15:20]={last5:.4f}')
    check('loss_decreasing', last5 < first5, f'{last5:.4f} vs {first5:.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
def test_t9_save_load():
    """T9: weights and normalization extras survive a save/load cycle."""
    print('\nT9: Save/load round-trip')
    rng2 = np.random.RandomState(8)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=8)
    x = rng2.randn(4, 8, 32)

    # deterministic forward: encode then decode with fixed z_mu (no sampling)
    z_mu1, _ = model.encode(x)
    x_recon1 = model.decode(z_mu1)

    nm = rng2.randn(8).astype(np.float64)
    ns = np.abs(rng2.randn(8)).astype(np.float64) + 0.1

    # only need the path; the file is written by model.save below
    with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp:
        path = tmp.name

    try:
        model.save(path, norm_mean=nm, norm_std=ns)
        # fresh model with a different seed: weights must come from the file
        reloaded = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=0)
        extras = reloaded.load(path)

        check('norm_mean_saved', 'norm_mean' in extras)
        check('norm_std_saved', 'norm_std' in extras)

        z_mu2, _ = reloaded.encode(x)
        x_recon2 = reloaded.decode(z_mu2)
        check('recon_reproduced', np.allclose(x_recon1, x_recon2, atol=1e-10))
        check('z_mu_reproduced', np.allclose(z_mu1, z_mu2, atol=1e-10))
    finally:
        os.unlink(path)
|
||
|
|
|
||
|
|
|
||
|
|
def test_t10_btcvae_grad():
    """T10: btcvae_loss_backward matches numeric grads w.r.t. z_mu/z_logvar."""
    print('\nT10: beta-TCVAE backward numerical gradient check (z_mu, z_logvar)')
    rng2 = np.random.RandomState(9)
    B, D = 8, 6
    C, T = 4, 8
    x = rng2.randn(B, C, T).astype(np.float64)
    z_mu = rng2.randn(B, D).astype(np.float64) * 0.5
    z_lv = rng2.randn(B, D).astype(np.float64) * 0.3 - 1.0  # negative -> small var
    eps = rng2.randn(B, D).astype(np.float64)
    z = z_mu + eps * np.exp(0.5 * z_lv)

    # Use a fixed simple decoder output to avoid recomputing model
    x_recon = rng2.randn(B, C, T).astype(np.float64) * 0.5

    def loss_at(mu, lv):
        # the reparameterized sample tracks whichever parameter is perturbed
        sample = mu + eps * np.exp(0.5 * lv)
        value, _ = btcvae_loss(x, x_recon, mu, lv, sample, beta_tc=2.0, alpha_mi=1.0)
        return value

    _, d_z_mu, d_z_lv = btcvae_loss_backward(
        x, x_recon, z_mu, z_lv, z, eps, beta_tc=2.0, alpha_mi=1.0
    )

    ng_mu = num_grad(
        lambda flat: loss_at(flat.reshape(B, D), z_lv), z_mu.ravel().copy()
    ).reshape(B, D)
    ng_lv = num_grad(
        lambda flat: loss_at(z_mu, flat.reshape(B, D)), z_lv.ravel().copy()
    ).reshape(B, D)

    rel_mu = np.abs(d_z_mu - ng_mu) / (np.abs(ng_mu) + 1e-6)
    rel_lv = np.abs(d_z_lv - ng_lv) / (np.abs(ng_lv) + 1e-6)

    print(f' d_z_mu max_rel={rel_mu.max():.4f} mean_rel={rel_mu.mean():.4f}')
    print(f' d_z_lv max_rel={rel_lv.max():.4f} mean_rel={rel_lv.mean():.4f}')
    check('d_z_mu_rel_err', rel_mu.max() < 0.05, f'{rel_mu.max():.4f}')
    check('d_z_lv_rel_err', rel_lv.max() < 0.05, f'{rel_lv.max():.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    print('=' * 60)
    print('ConvNeXt-1D beta-TCVAE unit tests')
    print('=' * 60)

    # Run every test in order; each reports its results via check().
    for run_test in (
        test_t1_dwconv_shape,
        test_t2_dwconv_grad,
        test_t3_layernorm_grad,
        test_t4_block_shape,
        test_t5_block_grad,
        test_t6_vae_shapes,
        test_t7_vae_backward,
        test_t8_loss_decreases,
        test_t9_save_load,
        test_t10_btcvae_grad,
    ):
        run_test()

    print()
    print('=' * 60)
    print(f'Results: {len(PASS)} PASS / {len(FAIL)} FAIL')
    if FAIL:
        print('FAILED:', FAIL)
        sys.exit(1)
    else:
        print('All tests passed.')
|