|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +import torch.nn as nn |
| 4 | +from transformers.pytorch_utils import Conv1D as HFConv1D |
| 5 | + |
| 6 | +from bergson.collector.collector import HookCollectorBase |
| 7 | +from bergson.hessians.kfac import CovarianceCollector |
| 8 | +from bergson.utils.utils import get_device |
| 9 | + |
# Feature sizes shared by every test in this module. They are deliberately
# different so that a transposed (in, out) vs. (out, in) shape is detectable.
IN_DIM: int = 4  # width of the tensor fed into the Conv1D layer
OUT_DIM: int = 6  # width of the tensor produced by the Conv1D layer
| 12 | + |
| 13 | + |
class TinyConv1DModel(nn.Module):
    """Two-layer toy model chaining an HFConv1D into an nn.Linear.

    ``conv`` maps IN_DIM -> OUT_DIM features and ``linear`` maps
    OUT_DIM -> IN_DIM features back, so the model contains one layer of
    each type for the hook/collector tests to discover.
    """

    def __init__(self):
        super().__init__()
        # HFConv1D takes (nf=out, nx=in), the reverse argument order of
        # nn.Linear's (in_features, out_features).
        self.conv = HFConv1D(nf=OUT_DIM, nx=IN_DIM)
        self.linear = nn.Linear(in_features=OUT_DIM, out_features=IN_DIM)

    def forward(self, x):
        """Apply ``conv`` and then ``linear`` to ``x``."""
        hidden = self.conv(x)
        return self.linear(hidden)
| 24 | + |
| 25 | + |
def test_discover_targets_normalizes_conv1d_weight_shape():
    """discover_targets must report (out, in) shapes for Conv1D and Linear."""
    net = TinyConv1DModel()

    # Precondition: HFConv1D stores its weight as (in, out). This is the
    # storage convention the test guards against.
    assert net.conv.weight.shape == (IN_DIM, OUT_DIM)

    targets = HookCollectorBase.discover_targets(net)

    # Pull the shape (middle element of each 3-tuple entry) for both layers
    # before asserting anything about either of them.
    reported = {}
    for layer_name in ("conv", "linear"):
        _, layer_shape, _ = targets[layer_name]
        reported[layer_name] = layer_shape

    # Both layer types must follow the nn.Linear (out, in) convention.
    assert reported["conv"] == torch.Size([OUT_DIM, IN_DIM])
    assert reported["linear"] == torch.Size([IN_DIM, OUT_DIM])
| 39 | + |
| 40 | + |
def test_covariance_collector_on_conv1d(tmp_path):
    """Run one forward/backward pass through the hooks on a Conv1D layer."""
    # CovarianceCollector accumulates on get_device(rank), so the model and
    # every input tensor are placed on that same device.
    device = get_device(0)
    net = TinyConv1DModel().to(device)
    collector = CovarianceCollector(
        model=net,
        dtype=torch.float32,
        path=str(tmp_path),
    )

    # For the Conv1D layer, A_cov is [in, in] and S_cov is [out, out].
    a_cov_shape = (IN_DIM, IN_DIM)
    s_cov_shape = (OUT_DIM, OUT_DIM)
    assert collector.A_cov_dict["conv"].shape == a_cov_shape
    assert collector.S_cov_dict["conv"].shape == s_cov_shape

    batch_size, seq_len = 2, 3
    inputs = torch.randn(batch_size, seq_len, IN_DIM, device=device)
    valid_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)

    with collector.with_batch(valid_mask):
        net(inputs).sum().backward()

    # The forward hook should have accumulated A^T A over the valid positions.
    valid_inputs = inputs[valid_mask]
    expected_a_cov = valid_inputs.mT @ valid_inputs
    torch.testing.assert_close(collector.A_cov_dict["conv"], expected_a_cov)

    # The backward hook should have accumulated a non-zero G^T G whose size
    # is the layer's output dimension.
    s_cov = collector.S_cov_dict["conv"]
    assert s_cov.shape == s_cov_shape
    assert s_cov.abs().sum() > 0
| 69 | + |
| 70 | + |
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports test failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
0 commit comments