-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathtest_huggingface.py
More file actions
292 lines (235 loc) · 10.1 KB
/
test_huggingface.py
File metadata and controls
292 lines (235 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
from _test_utils.torch.misc import set_seed
from _test_utils.torch.transformers_models import (
create_tiny_llama_dir,
get_tiny_gpt_oss,
get_tiny_llama,
get_tiny_qwen3_5,
get_tiny_qwen3_5_moe,
get_tiny_qwen3_moe,
tf_modelopt_state_and_output_tester,
)
from packaging.version import Version
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantLinear, QuantModuleRegistry
from modelopt.torch.quantization.plugins.huggingface import (
get_homogeneous_hf_decoder_layers,
is_homogeneous_hf_model,
)
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
pytest.importorskip("transformers")
import transformers
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.dbrx.configuration_dbrx import DbrxConfig, DbrxFFNConfig
from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU, DbrxExperts, DbrxFFN
class HFModel(nn.Module):
    """Tiny two-layer MLP (3 -> 5 -> 5) built from HuggingFace ``Conv1D`` layers."""

    def __init__(self):
        super().__init__()
        # Conv1D takes (out_features, in_features) -- the reverse of nn.Linear's
        # (in_features, out_features) argument order.
        layers = [
            transformers.pytorch_utils.Conv1D(5, 3),
            nn.ReLU(),
            transformers.pytorch_utils.Conv1D(5, 5),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Conv1D/ReLU/Conv1D stack."""
        return self.net(x)
class PytorchModel(nn.Module):
    """Reference MLP (3 -> 5 -> 5) mirroring ``HFModel`` but using ``QuantLinear`` layers."""

    def __init__(self):
        super().__init__()
        layers = [QuantLinear(3, 5), nn.ReLU(), QuantLinear(5, 5)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the QuantLinear/ReLU/QuantLinear stack."""
        return self.net(x)
def test_convert_conv1d():
    """Check HF ``Conv1D`` quant conversion: quantizers are attached and numerics match.

    With all quantizers disabled, the converted model must match the original
    ``HFModel``. With input/weight quantizers enabled, it must match a
    ``QuantLinear``-based ``PytorchModel`` that is loaded from the converted
    model's state dict.
    """
    set_seed()
    assert transformers.pytorch_utils.Conv1D in QuantModuleRegistry

    model_ref = HFModel()
    model_test = HFModel()
    model_test.load_state_dict(model_ref.state_dict())

    mtq.replace_quant_module(model_test)
    conv1d_modules = [
        module
        for module in model_test.modules()
        if isinstance(module, transformers.pytorch_utils.Conv1D)
    ]
    # HFModel has exactly two Conv1D layers; guard against the checks below
    # silently iterating over nothing.
    assert len(conv1d_modules) == 2
    for module in conv1d_modules:
        assert hasattr(module, "input_quantizer")
        assert hasattr(module, "weight_quantizer")
        assert hasattr(module, "output_quantizer")

    # With every quantizer disabled, the converted model must be numerically unchanged.
    mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False})
    x = torch.randn(2, 3)
    out_1 = model_ref(x)
    out_2 = model_test(x)
    assert torch.allclose(out_1, out_2)

    # Enable input/weight quantization and compare against a QuantLinear model
    # loaded from the same state dict. This presumably relies on the Conv1D quant
    # plugin exposing nn.Linear-layout weights -- NOTE(review): confirm.
    mtq.set_quantizer_attributes_partial(model_test, "*input_quantizer", {"enable": True})
    mtq.set_quantizer_attributes_partial(model_test, "*weight_quantizer", {"enable": True})
    model_ref = PytorchModel()
    model_ref.load_state_dict(model_test.state_dict())
    out_1 = model_ref(x)
    out_2 = model_test(x)
    assert torch.allclose(out_1, out_2)
@pytest.mark.skipif(
    Version(transformers.__version__) < Version("5.0"),
    reason="test_dbrx is not supported for transformers<5.0",
)
def test_dbrx():
    """DBRX expert GLU conversion: fused weights are split per expert and outputs match."""
    assert DbrxExperts in QuantModuleRegistry
    assert DbrxExpertGLU in QuantModuleRegistry

    config = DbrxConfig(
        ffn_config=DbrxFFNConfig(ffn_hidden_size=8, moe_num_experts=2, hidden_size=32),
        hidden_size=32,
    )
    model_ref = DbrxFFN(config)
    model_test = DbrxFFN(config)

    glu_ref = model_ref.experts.mlp
    # Overwrite the fused expert weights with fresh random values (w1, then v1, then w2).
    with torch.no_grad():
        for fused_name in ("w1", "v1", "w2"):
            getattr(glu_ref, fused_name).copy_(torch.randn(16, 32))
    model_test.load_state_dict(model_ref.state_dict())

    mtq.replace_quant_module(model_test)
    glu_test = model_test.experts.mlp

    # Each fused parameter must be replaced by its per-expert ``*_linear`` counterpart.
    for fused_name, split_name in (("w1", "w1_linear"), ("v1", "v1_linear"), ("w2", "w2_linear")):
        assert hasattr(glu_test, split_name) and not hasattr(glu_test, fused_name)

    # Split weights are stored transposed (W = w1[i].T) so that F.linear(x, W)
    # reproduces transformers 5.0's raw matmul x @ w1[i].
    stacked_w1 = torch.concat([linear.weight.T for linear in glu_test.w1_linear], dim=0)
    assert torch.allclose(stacked_w1, glu_ref.w1)

    mtq.set_quantizer_attributes_partial(model_test, "*", {"enable": False})
    # In transformers 5.0, the FFN input dimension is ffn_hidden_size (not hidden_size)
    x = torch.randn(1, 4, 8)
    ref_out = model_ref(x)
    test_out = model_test(x)
    assert torch.allclose(ref_out[0], test_out[0])
@pytest.mark.parametrize("method", ["gradient", "kl_div"])
@pytest.mark.parametrize("model_provider", [get_tiny_llama, get_tiny_qwen3_moe])
def test_autoquantize_huggingface(model_provider, method):
    """Run AutoQuantize end-to-end on tiny HF models with both scoring methods."""
    if model_provider == get_tiny_qwen3_moe and Version(torch.__version__) < Version("2.9"):
        pytest.skip("torch 2.8 grouped_mm is CUDA-only")

    model = model_provider()
    input_ids = model.dummy_inputs["input_ids"]
    use_gradient = method == "gradient"

    def forward_step(model, batch):
        # The gradient method consumes the full output (its ``.loss``); kl_div uses logits.
        output = model(**batch)
        return output if use_gradient else output.logits

    # Escalate a failure to enable gradient checkpointing from a warning to an error.
    warnings.filterwarnings(
        "error", message="AutoQuantize: Error enabling gradient checkpointing for huggingface model"
    )

    # Only the gradient-based method is expected to emit the checkpointing warning.
    if use_gradient:
        expected_warning = pytest.warns(
            UserWarning,
            match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ",
        )
    else:
        expected_warning = nullcontext()

    batches = [{"input_ids": input_ids, "labels": input_ids} for _ in range(2)]
    with expected_warning:
        _quantized_model, _search_state = mtq.auto_quantize(
            model,
            constraints={"effective_bits": 11.0},
            quantization_formats=[mtq.INT8_DEFAULT_CFG],
            data_loader=batches,
            forward_step=forward_step,
            loss_func=lambda output, data: output.loss,
            num_calib_steps=2,
            num_score_steps=2,
            verbose=True,
            method=method,
        )
@pytest.mark.parametrize(
    ("model_cls", "quant_config"),
    [
        (LlamaForCausalLM, mtq.INT4_AWQ_CFG),
        (AutoModelForCausalLM, mtq.INT4_AWQ_CFG),
    ],
)
def test_quantized_transformers_save_restore(tmp_path, model_cls, quant_config):
    """Quantize and compress a tiny Llama, save it, reload it, and compare state and outputs.

    Raises:
        ValueError: if ``quant_config`` is not supported by this test, or if the
            INT4 AWQ config has no ``*weight_quantizer`` entry to adapt.
    """
    tiny_llama_dir = create_tiny_llama_dir(tmp_path, dtype=torch.float32)

    # Adapt the config to the tiny model by shrinking the weight-quantizer block size
    # (presumably the default block size does not fit the tiny model's dims).
    if quant_config == mtq.INT4_AWQ_CFG:
        import copy

        # Deep-copy so the shared module-level mtq.INT4_AWQ_CFG is never mutated.
        quant_config = copy.deepcopy(quant_config)
        for entry in quant_config["quant_cfg"]:
            if entry["quantizer_name"] == "*weight_quantizer":
                entry.setdefault("cfg", {})["block_sizes"] = {-1: 16}
                break
        else:
            # Without this, a config layout change would silently keep the default block size.
            raise ValueError("No '*weight_quantizer' entry found in INT4_AWQ_CFG['quant_cfg']")
    else:
        raise ValueError(f"Unsupported quant_config: {quant_config}")

    model_ref = model_cls.from_pretrained(tiny_llama_dir)
    mtq.quantize(model_ref, quant_config, lambda model: model(**model.dummy_inputs))
    mtq.compress(model_ref)
    model_ref.save_pretrained(tiny_llama_dir / "modelopt_model")
    # save_pretrained is expected to also write the modelopt state next to the weights.
    assert os.path.exists(tiny_llama_dir / "modelopt_model/modelopt_state.pth")

    model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")
    tf_modelopt_state_and_output_tester(model_ref, model_test)
def test_is_homogeneous_hf_model_llama():
    """A tiny Llama model should be reported as homogeneous."""
    assert is_homogeneous_hf_model(get_tiny_llama())
def test_is_homogeneous_hf_model_gpt_oss():
    """A single-layer tiny GPT-OSS model should be reported as homogeneous."""
    assert is_homogeneous_hf_model(get_tiny_gpt_oss(num_hidden_layers=1))
def test_hf_decoder_discoverer_registration_path():
    """The HF homogeneous-decoder discoverer is registered and used by LayerActivationCollector."""
    model = get_tiny_llama()

    # The (support-check, discoverer) pair must appear in the collector's registry.
    registered = LayerActivationCollector._decoder_layer_support
    assert any(
        support_fn is is_homogeneous_hf_model and discover_fn is get_homogeneous_hf_decoder_layers
        for support_fn, discover_fn in registered
    )

    # Dispatch through the collector must yield the very same object as calling the
    # discoverer directly.
    via_collector = LayerActivationCollector.get_decoder_layers(model)
    direct = get_homogeneous_hf_decoder_layers(model)
    assert via_collector is direct
def test_qwen3_5_hybrid_attention_quantize():
    """FP8 on hybrid Qwen3.5: ``linear_attn`` quantizers are off, ``self_attn`` weights are on."""
    model = get_tiny_qwen3_5()
    mtq.quantize(model, mtq.FP8_DEFAULT_CFG, lambda m: m(**m.dummy_inputs))

    quantized_modules = (
        (name, module)
        for name, module in model.named_modules()
        if hasattr(module, "weight_quantizer")
    )
    for name, module in quantized_modules:
        if "linear_attn" in name:
            # Both the weight and the input quantizers must be disabled here.
            assert not module.weight_quantizer.is_enabled, (
                f"linear_attn module {name} should have weight_quantizer disabled"
            )
            assert not module.input_quantizer.is_enabled, (
                f"linear_attn module {name} should have input_quantizer disabled"
            )
        elif "self_attn" in name and "layernorm" not in name:
            assert module.weight_quantizer.is_enabled, (
                f"self_attn module {name} should have weight_quantizer enabled"
            )
@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.9"),
    reason="torch 2.8 grouped_mm is CUDA-only",
)
def test_qwen3_5_moe_experts_not_quantized():
    """An ``*experts*`` disable rule appended to FP8 must leave MoE expert weights unquantized."""
    model = get_tiny_qwen3_5_moe()

    import copy

    # Work on a copy so the shared mtq.FP8_DEFAULT_CFG is not modified.
    quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    quant_cfg["quant_cfg"].append({"quantizer_name": "*experts*", "enable": False})
    mtq.quantize(model, quant_cfg, lambda m: m(**m.dummy_inputs))

    expert_modules = (
        (name, module)
        for name, module in model.named_modules()
        if hasattr(module, "weight_quantizer") and "experts" in name
    )
    for name, module in expert_modules:
        assert not module.weight_quantizer.is_enabled, (
            f"expert module {name} should have weight_quantizer disabled"
        )