# DOLPHIN/nautilus_dolphin/dvae/test_convnext_dvae.py
"""
Unit tests for ConvNeXt-1D β-TCVAE (pure numpy).
Tests
-----
T1 DWConv1d forward shape correct
T2 DWConv1d backward numerical gradient check
T3 LayerNorm forward/backward numerical gradient check
T4 ConvNeXtBlock1D output shape == input shape (skip preserved)
T5 ConvNeXtBlock1D backward numerical gradient check
T6 ConvNeXtVAE forward: output shapes correct
T7 ConvNeXtVAE forward/backward: loss is finite, grads finite
T8 Loss decreases over 20 steps on small batch (gradient descent works)
T9 Save/load round-trip: weights preserved
T10 β-TCVAE loss backward: numerical gradient check on z_mu/z_logvar
Run: python test_convnext_dvae.py
"""
import sys
import os
import tempfile
import numpy as np
# ensure local import
sys.path.insert(0, os.path.dirname(__file__))
from convnext_dvae import (
    ConvNeXtVAE, DWConv1d, LayerNorm, ConvNeXtBlock1D,
    btcvae_loss, btcvae_loss_backward
)
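# NOTE: these tests infer the convnext_dvae API from usage alone (layers
# expose forward/backward, parameters carry .data/.grad, and the model
# provides encode/decode, backward_encode/backward_decode, zero_grad,
# all_params, adam_step, and save/load); convnext_dvae itself is not
# redefined here.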
RNG = np.random.RandomState(0)
PASS = []
FAIL = []
def check(name: str, cond: bool, detail: str = ''):
    if cond:
        print(f' PASS {name}')
        PASS.append(name)
    else:
        print(f' FAIL {name} {detail}')
        FAIL.append(name)
def num_grad(f, x, eps=1e-5):
    """Numeric gradient of scalar f at x via central differences."""
    g = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        # perturb one element at a time: (f(x+eps) - f(x-eps)) / (2*eps)
        x[idx] = orig + eps
        fp = f(x)
        x[idx] = orig - eps
        fm = f(x)
        x[idx] = orig
        g[idx] = (fp - fm) / (2 * eps)
        it.iternext()
    return g
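# Illustrative sanity check for num_grad (not one of the numbered tests):
# central differences are exact for quadratics up to float error, so for
# f(v) = sum(v ** 2) the numeric gradient should match 2 * v:
#   v = np.array([1.0, 2.0, 3.0])
#   assert np.allclose(num_grad(lambda u: float((u ** 2).sum()), v), 2 * v)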
# ---------------------------------------------------------------------------
def test_t1_dwconv_shape():
    print('\nT1: DWConv1d forward shape')
    layer = DWConv1d(8, 7, RNG)
    x = RNG.randn(4, 8, 32)
    y = layer.forward(x)
    check('output_shape', y.shape == (4, 8, 32), str(y.shape))
def test_t2_dwconv_grad():
    print('\nT2: DWConv1d backward numerical check')
    rng2 = np.random.RandomState(1)
    layer = DWConv1d(4, 3, rng2)
    x = rng2.randn(2, 4, 8).astype(np.float64) * 0.5
    def fwd(w_flat):
        layer.w.data[:] = w_flat.reshape(layer.w.data.shape)
        out = layer.forward(x)
        return float(out.sum())
    layer.forward(x)
    layer.backward(np.ones((2, 4, 8)))
    w0 = layer.w.data.copy()
    ng = num_grad(lambda w: fwd(w.ravel()), w0.ravel())
    ag = layer.w.grad.ravel()
    rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8)
    check('dw_max_rel_err', rel.max() < 0.01, f'max_rel={rel.max():.4f}')
    # also check dx
    def fwd_x(xf):
        return float(layer.forward(xf.reshape(2, 4, 8)).sum())
    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    layer.forward(x)  # re-run forward so the backward cache matches x
    gx = layer.backward(np.ones((2, 4, 8))).ravel()
    rel_x = np.abs(gx - ng_x) / (np.abs(ng_x) + 1e-8)
    check('dx_max_rel_err', rel_x.max() < 0.01, f'max_rel_x={rel_x.max():.4f}')
def test_t3_layernorm_grad():
    print('\nT3: LayerNorm backward numerical check')
    rng2 = np.random.RandomState(2)
    layer = LayerNorm(8)
    x = rng2.randn(4, 12, 8).astype(np.float64)
    def fwd_gamma(gf):
        layer.gamma.data[:] = gf
        return float(layer.forward(x).sum())
    layer.forward(x)
    layer.backward(np.ones((4, 12, 8)))
    ag = layer.gamma.grad.copy()
    ng = num_grad(lambda gf: fwd_gamma(gf), layer.gamma.data.copy())
    rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8)
    check('dgamma_max_rel_err', rel.max() < 0.01, f'{rel.max():.4f}')
    def fwd_x(xf):
        return float(layer.forward(xf.reshape(4, 12, 8)).sum())
    dx = layer.backward(np.ones((4, 12, 8))).ravel()
    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    rel_x = np.abs(dx - ng_x) / (np.abs(ng_x) + 1e-8)
    check('dx_max_rel_err', rel_x.max() < 0.01, f'{rel_x.max():.4f}')
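# For reference, assuming the standard LayerNorm over the last axis,
#   y = gamma * (x - mean) / sqrt(var + eps) + beta = gamma * x_hat + beta,
# the parameter gradients checked above reduce to
#   dgamma = sum over (batch, time) of dy * x_hat,   dbeta = sum of dy.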
def test_t4_block_shape():
    print('\nT4: ConvNeXtBlock1D shape preserved')
    rng2 = np.random.RandomState(3)
    blk = ConvNeXtBlock1D(16, 7, rng2)
    x = rng2.randn(4, 16, 32)
    y = blk.forward(x)
    check('shape_unchanged', y.shape == x.shape, str(y.shape))
    check('skip_active', not np.allclose(y, x))  # residual branch must contribute, so y != x
def test_t5_block_grad():
    print('\nT5: ConvNeXtBlock1D backward numerical check (small block)')
    rng2 = np.random.RandomState(4)
    blk = ConvNeXtBlock1D(4, 3, rng2)
    x = rng2.randn(2, 4, 8).astype(np.float64) * 0.3
    def fwd_x(xf):
        return float(blk.forward(xf.reshape(2, 4, 8)).sum())
    blk.forward(x)
    dx_an = blk.backward(np.ones((2, 4, 8))).ravel()
    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    rel = np.abs(dx_an - ng_x) / (np.abs(ng_x) + 1e-8)
    # looser tolerance than T2/T3: the composed block stacks several ops,
    # so finite-difference error accumulates
    check('dx_max_rel_err', rel.max() < 0.02, f'{rel.max():.4f}')
def test_t6_vae_shapes():
    print('\nT6: ConvNeXtVAE forward shapes')
    rng2 = np.random.RandomState(5)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=5)
    x = rng2.randn(4, 8, 32)
    x_recon, z_mu, z_logvar, z = model.forward(x)
    check('x_recon_shape', x_recon.shape == (4, 8, 32), str(x_recon.shape))
    check('z_mu_shape', z_mu.shape == (4, 16), str(z_mu.shape))
    check('z_logvar_shape', z_logvar.shape == (4, 16), str(z_logvar.shape))
    check('x_recon_finite', np.all(np.isfinite(x_recon)))
    check('z_mu_finite', np.all(np.isfinite(z_mu)))
def test_t7_vae_backward():
    print('\nT7: ConvNeXtVAE backward: grads finite, loss finite')
    rng2 = np.random.RandomState(6)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=6)
    x = rng2.randn(8, 8, 32)
    eps_noise = rng2.randn(8, 16)
    # manual forward with stored eps so the backward pass sees the exact
    # noise used in the reparameterization z = mu + eps * exp(0.5 * logvar)
    z_mu, z_logvar = model.encode(x)
    z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
    x_recon = model.decode(z)
    loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=2.0)
    check('loss_finite', np.isfinite(loss), f'loss={loss:.4f}')
    for k, v in info.items():
        check(f'{k}_finite', np.isfinite(v), f'{k}={v:.4f}')
    d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
        x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=2.0
    )
    model.zero_grad()
    dz = model.backward_decode(d_recon)
    # the decoder gradient w.r.t. z feeds the mu path directly (dz/dmu = 1)
    model.backward_encode(d_z_mu + dz, d_z_logvar)
    all_grads = [p.grad for p in model.all_params()]
    any_nan = any(np.any(np.isnan(g)) for g in all_grads)
    any_inf = any(np.any(np.isinf(g)) for g in all_grads)
    all_zero = all(np.all(g == 0) for g in all_grads)
    check('grads_no_nan', not any_nan)
    check('grads_no_inf', not any_inf)
    check('grads_not_all_zero', not all_zero)
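# Reparameterization refresher (standard VAE math, included for reference):
# with z = mu + eps * sigma and sigma = exp(0.5 * logvar), any upstream
# gradient dL/dz maps back as
#   dL/dmu     += dL/dz
#   dL/dlogvar += dL/dz * eps * 0.5 * sigma
# btcvae_loss_backward receives eps_noise, presumably so it can apply this
# mapping to the z-dependent loss terms internally (assumed from its
# signature; the implementation lives in convnext_dvae).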
def test_t8_loss_decreases():
    print('\nT8: Loss decreases over 20 gradient steps')
    rng2 = np.random.RandomState(7)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=7)
    x = rng2.randn(16, 8, 32).astype(np.float64)
    losses = []
    lr = 1e-3
    for step in range(20):
        # fresh but reproducible sampling noise each step
        rng2_eps = np.random.RandomState(step)
        eps_noise = rng2_eps.randn(16, 16)
        z_mu, z_logvar = model.encode(x)
        z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
        x_recon = model.decode(z)
        loss, _ = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=1.0, alpha_mi=1.0)
        losses.append(loss)
        d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
            x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=1.0, alpha_mi=1.0
        )
        model.zero_grad()
        dz = model.backward_decode(d_recon)
        model.backward_encode(d_z_mu + dz, d_z_logvar)
        model.adam_step(lr)
    # compare 5-step averages to smooth out sampling noise in the loss
    first5 = np.mean(losses[:5])
    last5 = np.mean(losses[-5:])
    print(f' loss[0:5]={first5:.4f} loss[15:20]={last5:.4f}')
    check('loss_decreasing', last5 < first5, f'{last5:.4f} vs {first5:.4f}')
def test_t9_save_load():
    print('\nT9: Save/load round-trip')
    rng2 = np.random.RandomState(8)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=8)
    x = rng2.randn(4, 8, 32)
    # deterministic forward: encode then decode with fixed z_mu (no sampling)
    z_mu1, _ = model.encode(x)
    x_recon1 = model.decode(z_mu1)
    nm = rng2.randn(8).astype(np.float64)
    ns = np.abs(rng2.randn(8)).astype(np.float64) + 0.1
    with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f:
        path = f.name
    try:
        model.save(path, norm_mean=nm, norm_std=ns)
        # model2 uses a different seed, so load() must overwrite the fresh init
        model2 = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=0)
        extras = model2.load(path)
        check('norm_mean_saved', 'norm_mean' in extras)
        check('norm_std_saved', 'norm_std' in extras)
        z_mu2, _ = model2.encode(x)
        x_recon2 = model2.decode(z_mu2)
        check('recon_reproduced', np.allclose(x_recon1, x_recon2, atol=1e-10))
        check('z_mu_reproduced', np.allclose(z_mu1, z_mu2, atol=1e-10))
    finally:
        os.unlink(path)
def test_t10_btcvae_grad():
    print('\nT10: beta-TCVAE backward numerical gradient check (z_mu, z_logvar)')
    rng2 = np.random.RandomState(9)
    B, D = 8, 6
    C, T = 4, 8
    x = rng2.randn(B, C, T).astype(np.float64)
    z_mu = rng2.randn(B, D).astype(np.float64) * 0.5
    z_lv = rng2.randn(B, D).astype(np.float64) * 0.3 - 1.0  # negative → small variance
    eps = rng2.randn(B, D).astype(np.float64)
    z = z_mu + eps * np.exp(0.5 * z_lv)
    # use a fixed decoder output so only the loss gradients are under test
    x_recon = rng2.randn(B, C, T).astype(np.float64) * 0.5
    def loss_fn_mu(mu_flat):
        mu = mu_flat.reshape(B, D)
        z_ = mu + eps * np.exp(0.5 * z_lv)
        l, _ = btcvae_loss(x, x_recon, mu, z_lv, z_, beta_tc=2.0, alpha_mi=1.0)
        return l
    def loss_fn_lv(lv_flat):
        lv = lv_flat.reshape(B, D)
        z_ = z_mu + eps * np.exp(0.5 * lv)
        l, _ = btcvae_loss(x, x_recon, z_mu, lv, z_, beta_tc=2.0, alpha_mi=1.0)
        return l
    _, d_z_mu, d_z_lv = btcvae_loss_backward(
        x, x_recon, z_mu, z_lv, z, eps, beta_tc=2.0, alpha_mi=1.0
    )
    ng_mu = num_grad(loss_fn_mu, z_mu.ravel().copy()).reshape(B, D)
    ng_lv = num_grad(loss_fn_lv, z_lv.ravel().copy()).reshape(B, D)
    rel_mu = np.abs(d_z_mu - ng_mu) / (np.abs(ng_mu) + 1e-6)
    rel_lv = np.abs(d_z_lv - ng_lv) / (np.abs(ng_lv) + 1e-6)
    print(f' d_z_mu max_rel={rel_mu.max():.4f} mean_rel={rel_mu.mean():.4f}')
    print(f' d_z_lv max_rel={rel_lv.max():.4f} mean_rel={rel_lv.mean():.4f}')
    check('d_z_mu_rel_err', rel_mu.max() < 0.05, f'{rel_mu.max():.4f}')
    check('d_z_lv_rel_err', rel_lv.max() < 0.05, f'{rel_lv.max():.4f}')
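# Background for T7/T8/T10 (the standard beta-TCVAE decomposition of Chen
# et al. 2018, assumed to match btcvae_loss, which is defined elsewhere):
# the aggregate KL term of the ELBO splits into three parts,
#   E_x[KL(q(z|x) || p(z))] = MI(x; z) + TC(z) + sum_j KL(q(z_j) || p(z_j)),
# and the loss weights them separately,
#   L = recon + alpha_mi * MI + beta_tc * TC + dimension-wise KL,
# which is why the loss functions above expose alpha_mi and beta_tc knobs.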
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    print('=' * 60)
    print('ConvNeXt-1D beta-TCVAE unit tests')
    print('=' * 60)
    test_t1_dwconv_shape()
    test_t2_dwconv_grad()
    test_t3_layernorm_grad()
    test_t4_block_shape()
    test_t5_block_grad()
    test_t6_vae_shapes()
    test_t7_vae_backward()
    test_t8_loss_decreases()
    test_t9_save_load()
    test_t10_btcvae_grad()
    print()
    print('=' * 60)
    print(f'Results: {len(PASS)} PASS / {len(FAIL)} FAIL')
    if FAIL:
        print('FAILED:', FAIL)
        sys.exit(1)
    else:
        print('All tests passed.')