Skip to content

Commit 23fc6e1

Browse files
committed
Add regression test for cov init shape issue
1 parent 6632bd1 commit 23fc6e1

1 file changed

Lines changed: 72 additions & 0 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
from transformers.pytorch_utils import Conv1D as HFConv1D
5+
6+
from bergson.collector.collector import HookCollectorBase
7+
from bergson.hessians.kfac import CovarianceCollector
8+
from bergson.utils.utils import get_device
9+
10+
IN_DIM = 4
11+
OUT_DIM = 6
12+
13+
14+
class TinyConv1DModel(nn.Module):
15+
"""Minimal model mixing HFConv1D and nn.Linear with the same in/out dims."""
16+
17+
def __init__(self):
18+
super().__init__()
19+
self.conv = HFConv1D(OUT_DIM, IN_DIM) # HFConv1D(nf=out, nx=in)
20+
self.linear = nn.Linear(OUT_DIM, IN_DIM)
21+
22+
def forward(self, x):
23+
return self.linear(self.conv(x))
24+
25+
26+
def test_discover_targets_normalizes_conv1d_weight_shape():
27+
model = TinyConv1DModel()
28+
29+
# Sanity-check the HFConv1D storage convention this test guards against
30+
assert model.conv.weight.shape == (IN_DIM, OUT_DIM)
31+
32+
target_info = HookCollectorBase.discover_targets(model)
33+
34+
# Both layer types must report (out, in) in the nn.Linear convention
35+
_, conv_shape, _ = target_info["conv"]
36+
_, linear_shape, _ = target_info["linear"]
37+
assert conv_shape == torch.Size([OUT_DIM, IN_DIM])
38+
assert linear_shape == torch.Size([IN_DIM, OUT_DIM])
39+
40+
41+
def test_covariance_collector_on_conv1d(tmp_path):
42+
"""End-to-end hook pass over a Conv1D layer."""
43+
# CovarianceCollector accumulates on get_device(rank); keep everything there
44+
device = get_device(0)
45+
model = TinyConv1DModel().to(device)
46+
collector = CovarianceCollector(
47+
model=model, dtype=torch.float32, path=str(tmp_path)
48+
)
49+
50+
# A_cov is [in, in], S_cov is [out, out] for the Conv1D layer
51+
assert collector.A_cov_dict["conv"].shape == (IN_DIM, IN_DIM)
52+
assert collector.S_cov_dict["conv"].shape == (OUT_DIM, OUT_DIM)
53+
54+
n, s = 2, 3
55+
x = torch.randn(n, s, IN_DIM, device=device)
56+
mask = torch.ones(n, s, dtype=torch.bool, device=device)
57+
58+
with collector.with_batch(mask):
59+
out = model(x)
60+
out.sum().backward()
61+
62+
# Forward hook accumulated A^T A over valid positions
63+
a = x[mask]
64+
torch.testing.assert_close(collector.A_cov_dict["conv"], a.mT @ a)
65+
66+
# Backward hook accumulated G^T G with the right (out) dimension
67+
assert collector.S_cov_dict["conv"].shape == (OUT_DIM, OUT_DIM)
68+
assert collector.S_cov_dict["conv"].abs().sum() > 0
69+
70+
71+
if __name__ == "__main__":
72+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)