""" Unit tests for ConvNeXt-1D β-TCVAE (pure numpy). Tests ----- T1 DWConv1d forward shape correct T2 DWConv1d backward numerical gradient check T3 LayerNorm forward/backward numerical gradient check T4 ConvNeXtBlock1D output shape == input shape (skip preserved) T5 ConvNeXtBlock1D backward numerical gradient check T6 ConvNeXtVAE forward: output shapes correct T7 ConvNeXtVAE forward/backward: loss is finite, grads finite T8 Loss decreases over 20 steps on small batch (gradient descent works) T9 Save/load round-trip: weights preserved T10 β-TCVAE loss backward: numerical gradient check on z_mu/z_logvar Run: python test_convnext_dvae.py """ import sys import os import tempfile import numpy as np # ensure local import sys.path.insert(0, os.path.dirname(__file__)) from convnext_dvae import ( ConvNeXtVAE, DWConv1d, LayerNorm, ConvNeXtBlock1D, btcvae_loss, btcvae_loss_backward ) RNG = np.random.RandomState(0) PASS = [] FAIL = [] def check(name: str, cond: bool, detail: str = ''): if cond: print(f' PASS {name}') PASS.append(name) else: print(f' FAIL {name} {detail}') FAIL.append(name) def num_grad(f, x, eps=1e-5): """Numeric gradient of scalar f at x.""" g = np.zeros_like(x) it = np.nditer(x, flags=['multi_index']) while not it.finished: idx = it.multi_index orig = x[idx] x[idx] = orig + eps fp = f(x) x[idx] = orig - eps fm = f(x) x[idx] = orig g[idx] = (fp - fm) / (2 * eps) it.iternext() return g # --------------------------------------------------------------------------- def test_t1_dwconv_shape(): print('\nT1: DWConv1d forward shape') layer = DWConv1d(8, 7, RNG) x = RNG.randn(4, 8, 32) y = layer.forward(x) check('output_shape', y.shape == (4, 8, 32), str(y.shape)) def test_t2_dwconv_grad(): print('\nT2: DWConv1d backward numerical check') rng2 = np.random.RandomState(1) layer = DWConv1d(4, 3, rng2) x = rng2.randn(2, 4, 8).astype(np.float64) * 0.5 def fwd(w_flat): layer.w.data[:] = w_flat.reshape(layer.w.data.shape) out = layer.forward(x) return float(out.sum()) layer.forward(x) layer.backward(np.ones((2, 4, 8))) w0 = layer.w.data.copy() ng = num_grad(lambda w: fwd(w.ravel()), w0.ravel()) ag = layer.w.grad.ravel() rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8) check('dw_max_rel_err', rel.max() < 0.01, f'max_rel={rel.max():.4f}') # also check dx def fwd_x(xf): return float(layer.forward(xf.reshape(2, 4, 8)).sum()) ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy()) layer.forward(x) gx = layer.backward(np.ones((2, 4, 8))).ravel() rel_x = np.abs(gx - ng_x) / (np.abs(ng_x) + 1e-8) check('dx_max_rel_err', rel_x.max() < 0.01, f'max_rel_x={rel_x.max():.4f}') def test_t3_layernorm_grad(): print('\nT3: LayerNorm backward numerical check') rng2 = np.random.RandomState(2) layer = LayerNorm(8) x = rng2.randn(4, 12, 8).astype(np.float64) def fwd_gamma(gf): layer.gamma.data[:] = gf return float(layer.forward(x).sum()) layer.forward(x) layer.backward(np.ones((4, 12, 8))) ag = layer.gamma.grad.copy() ng = num_grad(lambda gf: fwd_gamma(gf), layer.gamma.data.copy()) rel = np.abs(ag - ng) / (np.abs(ng) + 1e-8) check('dgamma_max_rel_err', rel.max() < 0.01, f'{rel.max():.4f}') def fwd_x(xf): return float(layer.forward(xf.reshape(4, 12, 8)).sum()) dx = layer.backward(np.ones((4, 12, 8))).ravel() ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy()) rel_x = np.abs(dx - ng_x) / (np.abs(ng_x) + 1e-8) check('dx_max_rel_err', rel_x.max() < 0.01, f'{rel_x.max():.4f}') def test_t4_block_shape(): print('\nT4: ConvNeXtBlock1D shape preserved') rng2 = np.random.RandomState(3) blk = ConvNeXtBlock1D(16, 7, rng2) x = 
    y = blk.forward(x)
    check('shape_unchanged', y.shape == x.shape, str(y.shape))
    check('skip_active', not np.allclose(y, x))  # block changes values


def test_t5_block_grad():
    print('\nT5: ConvNeXtBlock1D backward numerical check (small block)')
    rng2 = np.random.RandomState(4)
    blk = ConvNeXtBlock1D(4, 3, rng2)
    x = rng2.randn(2, 4, 8).astype(np.float64) * 0.3

    def fwd_x(xf):
        return float(blk.forward(xf.reshape(2, 4, 8)).sum())

    blk.forward(x)
    dx_an = blk.backward(np.ones((2, 4, 8))).ravel()
    ng_x = num_grad(lambda xf: fwd_x(xf), x.ravel().copy())
    rel = np.abs(dx_an - ng_x) / (np.abs(ng_x) + 1e-8)
    check('dx_max_rel_err', rel.max() < 0.02, f'{rel.max():.4f}')


def test_t6_vae_shapes():
    print('\nT6: ConvNeXtVAE forward shapes')
    rng2 = np.random.RandomState(5)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=5)
    x = rng2.randn(4, 8, 32)
    x_recon, z_mu, z_logvar, z = model.forward(x)
    check('x_recon_shape',   x_recon.shape == (4, 8, 32), str(x_recon.shape))
    check('z_mu_shape',      z_mu.shape == (4, 16), str(z_mu.shape))
    check('z_logvar_shape',  z_logvar.shape == (4, 16), str(z_logvar.shape))
    check('x_recon_finite',  np.all(np.isfinite(x_recon)))
    check('z_mu_finite',     np.all(np.isfinite(z_mu)))


def test_t7_vae_backward():
    print('\nT7: ConvNeXtVAE backward: grads finite, loss finite')
    rng2 = np.random.RandomState(6)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=6)
    x = rng2.randn(8, 8, 32)
    eps_noise = rng2.randn(8, 16)
    # manual forward with stored eps
    z_mu, z_logvar = model.encode(x)
    z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
    x_recon = model.decode(z)
    loss, info = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=2.0)
    check('loss_finite', np.isfinite(loss), f'loss={loss:.4f}')
    for k, v in info.items():
        check(f'{k}_finite', np.isfinite(v), f'{k}={v:.4f}')
    d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
        x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=2.0
    )
    model.zero_grad()
    dz = model.backward_decode(d_recon)
    model.backward_encode(d_z_mu + dz, d_z_logvar)
    all_grads = [p.grad for p in model.all_params()]
    any_nan = any(np.any(np.isnan(g)) for g in all_grads)
    any_inf = any(np.any(np.isinf(g)) for g in all_grads)
    all_zero = all(np.all(g == 0) for g in all_grads)
    check('grads_no_nan', not any_nan)
    check('grads_no_inf', not any_inf)
    check('grads_not_all_zero', not all_zero)


def test_t8_loss_decreases():
    print('\nT8: Loss decreases over 20 gradient steps')
    rng2 = np.random.RandomState(7)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=7)
    x = rng2.randn(16, 8, 32).astype(np.float64)
    losses = []
    lr = 1e-3
    for step in range(20):
        rng2_eps = np.random.RandomState(step)
        eps_noise = rng2_eps.randn(16, 16)
        z_mu, z_logvar = model.encode(x)
        z = z_mu + eps_noise * np.exp(0.5 * z_logvar)
        x_recon = model.decode(z)
        loss, _ = btcvae_loss(x, x_recon, z_mu, z_logvar, z, beta_tc=1.0, alpha_mi=1.0)
        losses.append(loss)
        d_recon, d_z_mu, d_z_logvar = btcvae_loss_backward(
            x, x_recon, z_mu, z_logvar, z, eps_noise, beta_tc=1.0, alpha_mi=1.0
        )
        model.zero_grad()
        dz = model.backward_decode(d_recon)
        model.backward_encode(d_z_mu + dz, d_z_logvar)
        model.adam_step(lr)
    first5 = np.mean(losses[:5])
    last5 = np.mean(losses[-5:])
    print(f'  loss[0:5]={first5:.4f}  loss[15:20]={last5:.4f}')
    check('loss_decreasing', last5 < first5, f'{last5:.4f} vs {first5:.4f}')


def test_t9_save_load():
    print('\nT9: Save/load round-trip')
    rng2 = np.random.RandomState(8)
    model = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=8)
    x = rng2.randn(4, 8, 32)
    # deterministic forward: encode then decode with fixed z_mu (no sampling)
    z_mu1, _ = model.encode(x)
    x_recon1 = model.decode(z_mu1)
    nm = rng2.randn(8).astype(np.float64)
    ns = np.abs(rng2.randn(8)).astype(np.float64) + 0.1
    with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f:
        path = f.name
    try:
        model.save(path, norm_mean=nm, norm_std=ns)
        model2 = ConvNeXtVAE(C_in=8, T_in=32, z_dim=16, base_ch=16, n_blocks=2, seed=0)
        extras = model2.load(path)
        check('norm_mean_saved', 'norm_mean' in extras)
        check('norm_std_saved', 'norm_std' in extras)
        z_mu2, _ = model2.encode(x)
        x_recon2 = model2.decode(z_mu2)
        check('recon_reproduced', np.allclose(x_recon1, x_recon2, atol=1e-10))
        check('z_mu_reproduced', np.allclose(z_mu1, z_mu2, atol=1e-10))
    finally:
        os.unlink(path)


def test_t10_btcvae_grad():
    print('\nT10: beta-TCVAE backward numerical gradient check (z_mu, z_logvar)')
    rng2 = np.random.RandomState(9)
    B, D = 8, 6
    C, T = 4, 8
    x = rng2.randn(B, C, T).astype(np.float64)
    z_mu = rng2.randn(B, D).astype(np.float64) * 0.5
    z_lv = rng2.randn(B, D).astype(np.float64) * 0.3 - 1.0  # negative → small var
    eps = rng2.randn(B, D).astype(np.float64)
    z = z_mu + eps * np.exp(0.5 * z_lv)
    # Use a fixed simple decoder output to avoid recomputing the model
    x_recon = rng2.randn(B, C, T).astype(np.float64) * 0.5

    def loss_fn_mu(mu_flat):
        mu = mu_flat.reshape(B, D)
        z_ = mu + eps * np.exp(0.5 * z_lv)
        l, _ = btcvae_loss(x, x_recon, mu, z_lv, z_, beta_tc=2.0, alpha_mi=1.0)
        return l

    def loss_fn_lv(lv_flat):
        lv = lv_flat.reshape(B, D)
        z_ = z_mu + eps * np.exp(0.5 * lv)
        l, _ = btcvae_loss(x, x_recon, z_mu, lv, z_, beta_tc=2.0, alpha_mi=1.0)
        return l

    _, d_z_mu, d_z_lv = btcvae_loss_backward(
        x, x_recon, z_mu, z_lv, z, eps, beta_tc=2.0, alpha_mi=1.0
    )
    ng_mu = num_grad(loss_fn_mu, z_mu.ravel().copy()).reshape(B, D)
    ng_lv = num_grad(loss_fn_lv, z_lv.ravel().copy()).reshape(B, D)
    rel_mu = np.abs(d_z_mu - ng_mu) / (np.abs(ng_mu) + 1e-6)
    rel_lv = np.abs(d_z_lv - ng_lv) / (np.abs(ng_lv) + 1e-6)
    print(f'  d_z_mu  max_rel={rel_mu.max():.4f}  mean_rel={rel_mu.mean():.4f}')
    print(f'  d_z_lv  max_rel={rel_lv.max():.4f}  mean_rel={rel_lv.mean():.4f}')
    check('d_z_mu_rel_err', rel_mu.max() < 0.05, f'{rel_mu.max():.4f}')
    check('d_z_lv_rel_err', rel_lv.max() < 0.05, f'{rel_lv.max():.4f}')


# ---------------------------------------------------------------------------

if __name__ == '__main__':
    print('=' * 60)
    print('ConvNeXt-1D beta-TCVAE unit tests')
    print('=' * 60)
    test_t1_dwconv_shape()
    test_t2_dwconv_grad()
    test_t3_layernorm_grad()
    test_t4_block_shape()
    test_t5_block_grad()
    test_t6_vae_shapes()
    test_t7_vae_backward()
    test_t8_loss_decreases()
    test_t9_save_load()
    test_t10_btcvae_grad()
    print()
    print('=' * 60)
    print(f'Results: {len(PASS)} PASS / {len(FAIL)} FAIL')
    if FAIL:
        print('FAILED:', FAIL)
        sys.exit(1)
    else:
        print('All tests passed.')