Skip to content

Commit edb0975

Browse files
committed
unit tests for sequential calibrate
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 2285fba commit edb0975

File tree

1 file changed

+356
-0
lines changed

1 file changed

+356
-0
lines changed
Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for sequential_calibrate and LayerActivationCollector."""
17+
18+
import pytest
19+
import torch
20+
import torch.nn as nn
21+
22+
from modelopt.torch.quantization.model_calib import sequential_calibrate
23+
from modelopt.torch.quantization.utils import LayerActivationCollector
24+
25+
26+
class _DecoderBlock(nn.Module):
27+
"""Minimal transformer decoder block."""
28+
29+
def __init__(self, dim=16):
30+
super().__init__()
31+
self.attn = nn.Linear(dim, dim, bias=False)
32+
self.ffn = nn.Sequential(
33+
nn.Linear(dim, dim * 4, bias=False),
34+
nn.ReLU(),
35+
nn.Linear(dim * 4, dim, bias=False),
36+
)
37+
self.norm = nn.LayerNorm(dim)
38+
39+
def forward(self, x, **kwargs):
40+
x = x + self.attn(self.norm(x))
41+
x = x + self.ffn(x)
42+
return x
43+
44+
45+
class _SimpleTransformerModel(nn.Module):
    """Toy model exposing ``model.layers`` (ModuleList) -- the simplest pattern recognised by get_decoder_layers."""

    def __init__(self, n_layers=3, dim=16):
        super().__init__()
        # Registration order (layers, then embed) is preserved for seeded-init reproducibility.
        blocks = [_DecoderBlock(dim) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.embed = nn.Embedding(32, dim)

    def forward(self, x, **kwargs):
        hidden = self.embed(x)
        for block in self.layers:
            hidden = block(hidden)
        return hidden
58+
59+
60+
class _FlatMLP(nn.Module):
61+
"""No decoder-layer structure -- should be rejected by sequential_calibrate."""
62+
63+
def __init__(self, dim=16):
64+
super().__init__()
65+
self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
66+
67+
def forward(self, x):
68+
return self.net(x)
69+
70+
71+
class _SimpleTwoLayerModel(nn.Module):
72+
"""Minimal model with explicit layers for activation-collection tests."""
73+
74+
def __init__(self, dim=16):
75+
super().__init__()
76+
self.layers = nn.ModuleList(
77+
[nn.Linear(dim, dim, bias=False), nn.Linear(dim, dim, bias=False)]
78+
)
79+
80+
def forward(self, x):
81+
for layer in self.layers:
82+
x = layer(x)
83+
return x
84+
85+
86+
def _make_model_and_data(n_layers=3, dim=16, n_batches=2, batch_size=4):
    """Return a seeded toy transformer plus a list of random token batches.

    RNG order matters: seed first, then build the model, then draw the tokens,
    so both weights and data are reproducible across calls.
    """
    torch.manual_seed(42)
    net = _SimpleTransformerModel(n_layers=n_layers, dim=dim)
    batches = [torch.randint(0, 32, (batch_size, 8)) for _ in range(n_batches)]
    return net, batches
91+
92+
93+
def _run_forward(model, data):
94+
for batch in data:
95+
model(batch)
96+
97+
98+
# LayerActivationCollector tests
99+
100+
101+
def test_collector_collects_correct_number_of_inputs():
    """One collected entry per batch fed through the model."""
    torch.manual_seed(0)
    net = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(net)
    batches = [torch.randn(2, 8) for _ in range(3)]

    def drive(module):
        for sample in batches:
            module(sample)

    captured = collector.get_input_activations(net.layers[0], drive)
    assert len(captured) == 3
113+
114+
115+
def test_collector_activations_match_expected():
    """First layer should receive the raw input data."""
    torch.manual_seed(0)
    net = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(net)
    sample = torch.randn(2, 8)

    entries = collector.get_input_activations(net.layers[0], lambda m: m(sample))
    positional, _keyword = entries[0]
    assert torch.allclose(positional[0], sample)
129+
130+
131+
def test_collector_second_layer_receives_transformed_input():
    """Second layer should receive first layer's output, not raw input."""
    torch.manual_seed(0)
    net = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(net)
    sample = torch.randn(2, 8)

    # Compute the reference output before collection so both see identical weights.
    want = net.layers[0](sample)

    entries = collector.get_input_activations(net.layers[1], lambda m: m(sample))
    positional, _keyword = entries[0]
    assert torch.allclose(positional[0], want)
145+
146+
147+
def test_collector_forward_is_restored_after_collection():
    """Collection must leave no patched attributes behind on model or layer."""
    net = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(net)

    collector.get_input_activations(net.layers[0], lambda m: m(torch.randn(2, 8)))

    for target, attr in (
        (net, "_original_forward"),
        (net.layers[0], "inputs"),
        (net.layers[0], "_original_forward"),
    ):
        assert not hasattr(target, attr)
159+
160+
161+
def test_collector_cleanup_on_forward_loop_error():
    """Patching should be cleaned up even if forward_loop raises."""
    net = _SimpleTwoLayerModel(dim=8)
    collector = LayerActivationCollector(net)

    def exploding_loop(_module):
        raise RuntimeError("intentional error")

    with pytest.raises(RuntimeError, match="intentional error"):
        collector.get_input_activations(net.layers[0], exploding_loop)

    assert not hasattr(net, "_original_forward")
    assert not hasattr(net.layers[0], "inputs")
174+
175+
176+
# sequential_calibrate tests
177+
178+
179+
def test_seq_calib_raises_on_none_forward_loop():
    """A missing forward_loop is rejected up front with a ValueError."""
    net, _unused = _make_model_and_data(n_layers=2)
    with pytest.raises(ValueError, match="forward_loop must not be None"):
        sequential_calibrate(net, forward_loop=None, calib_func=lambda *a, **kw: None)
187+
188+
189+
def test_seq_calib_raises_on_unrecognized_model():
    """Models without a recognisable decoder-layer list are rejected."""
    with pytest.raises(ValueError, match="Could not find transformer layers"):
        sequential_calibrate(
            _FlatMLP(),
            forward_loop=lambda m: m(torch.randn(2, 16)),
            calib_func=lambda *a, **kw: None,
        )
197+
198+
199+
def test_seq_calib_func_called_per_layer():
    """calib_func fires exactly once per decoder layer."""
    net, batches = _make_model_and_data(n_layers=4)
    calls = []

    sequential_calibrate(
        net,
        forward_loop=lambda m: _run_forward(m, batches),
        calib_func=lambda layer, forward_loop, **kw: calls.append(layer),
    )

    assert len(calls) == 4
213+
214+
215+
def test_seq_calib_func_receives_correct_layer():
    """Layers are visited in model order, each passed to calib_func by identity."""
    net, batches = _make_model_and_data(n_layers=3)
    visited = []

    sequential_calibrate(
        net,
        forward_loop=lambda m: _run_forward(m, batches),
        calib_func=lambda layer, forward_loop, **kw: visited.append(layer),
    )

    for idx, real_layer in enumerate(net.layers):
        assert visited[idx] is real_layer
230+
231+
232+
def test_seq_calib_kwargs_forwarded():
    """Extra keyword arguments reach calib_func untouched, for every layer."""
    net, batches = _make_model_and_data(n_layers=2)
    seen_kwargs = []

    sequential_calibrate(
        net,
        forward_loop=lambda m: _run_forward(m, batches),
        calib_func=lambda layer, forward_loop, **kw: seen_kwargs.append(kw),
        alpha=0.5,
        method="max",
    )

    assert len(seen_kwargs) == 2
    for kw in seen_kwargs:
        assert kw["alpha"] == 0.5
        assert kw["method"] == "max"
251+
252+
253+
def test_seq_calib_layer_forward_loop_runs_all_batches():
    """The per-layer forward loop passed to calib_func should replay all batches."""
    n_batches = 5
    net, batches = _make_model_and_data(n_layers=2, n_batches=n_batches)
    per_layer_counts = []

    def count_batches(layer, forward_loop, **kwargs):
        # Temporarily wrap layer.forward with a counting spy, replay, then restore.
        hits = []
        real_forward = layer.forward

        def spy(*args, **kw):
            hits.append(None)
            return real_forward(*args, **kw)

        layer.forward = spy
        forward_loop(layer)
        layer.forward = real_forward
        per_layer_counts.append(len(hits))

    sequential_calibrate(
        net,
        forward_loop=lambda m: _run_forward(m, batches),
        calib_func=count_batches,
    )

    for count in per_layer_counts:
        assert count == n_batches
280+
281+
282+
def test_seq_calib_does_not_alter_weights():
    """sequential_calibrate itself should not modify model weights."""
    net, batches = _make_model_and_data(n_layers=3)
    snapshot = {name: param.clone() for name, param in net.named_parameters()}

    sequential_calibrate(
        net,
        forward_loop=lambda m: _run_forward(m, batches),
        calib_func=lambda layer, forward_loop, **kw: None,
    )

    for name, param in net.named_parameters():
        assert torch.equal(param, snapshot[name]), f"Weight {name} was modified"
295+
296+
297+
def test_seq_calib_activations_update_across_layers():
    """Subsequent layers should see activations transformed by prior layers."""
    torch.manual_seed(0)
    net = _SimpleTransformerModel(n_layers=2, dim=16)
    token_batches = [torch.randint(0, 32, (2, 4))]

    inputs_by_layer = {}

    def record_inputs(layer, forward_loop, **kwargs):
        # Spy on layer.forward to capture each positional input, then restore.
        captured = []
        real_forward = layer.forward

        def spy(*args, **kw):
            captured.append(args[0].clone())
            return real_forward(*args, **kw)

        layer.forward = spy
        forward_loop(layer)
        layer.forward = real_forward

        inputs_by_layer[list(net.layers).index(layer)] = captured

    sequential_calibrate(
        net,
        forward_loop=lambda m: [m(t) for t in token_batches],
        calib_func=record_inputs,
    )

    assert not torch.allclose(inputs_by_layer[0][0], inputs_by_layer[1][0]), (
        "Layer 1 should receive different activations than layer 0"
    )
329+
330+
331+
def test_seq_calib_empty_forward_loop():
    """If forward_loop feeds no data, calib_func still gets called with an empty replay."""
    net = _SimpleTransformerModel(n_layers=2, dim=16)
    replay_sizes = []

    def check_empty_replay(layer, forward_loop, **kwargs):
        # Spy on layer.forward; with no data the replay should never invoke it.
        hits = []
        real_forward = layer.forward

        def spy(*args, **kw):
            hits.append(None)
            return real_forward(*args, **kw)

        layer.forward = spy
        forward_loop(layer)
        layer.forward = real_forward
        replay_sizes.append(len(hits))

    sequential_calibrate(
        net,
        forward_loop=lambda m: None,
        calib_func=check_empty_replay,
    )

    for size in replay_sizes:
        assert size == 0

0 commit comments

Comments
 (0)