Skip to content

Commit 5ab4986

Browse files
feat: Add comprehensive tests for TurboModel architecture resolution and quantization state tracking
1 parent 208ad76 commit 5ab4986

4 files changed

Lines changed: 728 additions & 0 deletions

File tree

Lines changed: 266 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,266 @@
1+
from types import SimpleNamespace
2+
from unittest.mock import Mock
3+
4+
import transformers
5+
6+
from quantllm.core.turbo_model import TurboModel
7+
import quantllm.core.turbo_model as turbo_model_module
8+
9+
10+
class _DummySmartConfig(SimpleNamespace):
11+
def print_summary(self):
12+
return None
13+
14+
15+
def _make_smart_config():
    """Return a fresh ``_DummySmartConfig`` with the defaults these tests share.

    The stub reports ``bits=16`` / ``effective_loading_bits=16``, a
    ``"float16"`` dtype, the ``"cpu"`` device, and every optional feature
    flag switched off. A new object is built on each call, so a test may
    mutate its copy without affecting other tests.
    """
    return _DummySmartConfig(
        bits=16,
        effective_loading_bits=16,
        dtype="float16",
        cpu_offload=False,
        device="cpu",
        gradient_checkpointing=False,
        use_flash_attention=False,
        compile_model=False,
    )
26+
27+
28+
def _make_tokenizer():
29+
return SimpleNamespace(pad_token=None, eos_token="</s>", eos_token_id=2)
30+
31+
32+
def test_resolve_model_type_detects_common_patterns():
    """``resolve_model_type`` maps known model ids to a family, else ``None``."""
    assert TurboModel.resolve_model_type("meta-llama/Llama-3.2-3B") == "llama"
    # Newer Qwen names still fall back to the qwen2 base family.
    assert TurboModel.resolve_model_type("Qwen/Qwen3-8B") == "qwen2"
    # An id matching no known pattern must not be guessed.
    assert TurboModel.resolve_model_type("org/custom-arch-1b") is None
37+
38+
39+
def test_register_architecture_maps_new_model_to_base_family(monkeypatch):
    """A name registered with ``base_model_type`` resolves to that base family."""
    # Swap in empty registries so registrations made here are undone by
    # monkeypatch at teardown and do not leak into other tests.
    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
    monkeypatch.setattr(TurboModel, "_model_class_registry", {})

    TurboModel.register_architecture("newmodel", base_model_type="llama")

    assert TurboModel.resolve_model_type("org/newmodel-7b") == "llama"
45+
46+
47+
def test_registered_class_fallback_is_used(monkeypatch):
    """A registered model class is used when the auto-model loader fails.

    The fake ``AutoModelForCausalLM.from_pretrained`` always raises
    ``ValueError``; the test then expects the class registered for the
    ``"llama"`` family to have been asked to load the model instead.
    """
    # Start from empty registries; monkeypatch restores the originals.
    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
    monkeypatch.setattr(TurboModel, "_model_class_registry", {})

    # Replace config detection, tokenizer loading, and config loading with
    # in-memory stubs.
    monkeypatch.setattr(
        turbo_model_module.SmartConfig,
        "detect",
        lambda *args, **kwargs: _make_smart_config(),
    )
    monkeypatch.setattr(
        turbo_model_module.AutoTokenizer,
        "from_pretrained",
        lambda *args, **kwargs: _make_tokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: SimpleNamespace(
            model_type="newmodel",
            quantization_config=None,
        ),
    )

    class _FakeAutoModel:
        """Auto-model stub whose ``from_pretrained`` always fails."""

        @staticmethod
        def from_pretrained(*args, **kwargs):
            raise ValueError("Unrecognized configuration class")

        @staticmethod
        def from_config(*args, **kwargs):
            return SimpleNamespace(config=SimpleNamespace(model_type="llama"))

    # Records whether the registered class was used to load the model.
    registered_call = Mock()

    def _registered_from_pretrained(cls, *args, **kwargs):
        registered_call()
        return SimpleNamespace(config=SimpleNamespace(model_type="llama"))

    class _RegisteredModel:
        """Model class registered for the ``"llama"`` family."""

        from_pretrained = classmethod(_registered_from_pretrained)

    monkeypatch.setattr(
        turbo_model_module,
        "AutoModelForCausalLM",
        _FakeAutoModel,
    )

    TurboModel.register_architecture("newmodel", base_model_type="llama")
    TurboModel.register_architecture("llama", model_class=_RegisteredModel)

    loaded = TurboModel.from_pretrained(
        "org/newmodel-7b",
        quantize=False,
        verbose=False,
    )

    # assert_called() fails with a descriptive message if the fallback
    # class was never invoked.
    registered_call.assert_called()
    assert loaded.model.config.model_type == "llama"
104+
105+
106+
def test_from_pretrained_supports_from_config_only(monkeypatch):
    """``from_config_only=True`` builds via ``from_config``, not ``from_pretrained``.

    The fake auto-model records which of its two constructors was called;
    the test asserts only ``from_config`` ran.
    """
    # Start from empty registries; monkeypatch restores the originals.
    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
    monkeypatch.setattr(TurboModel, "_model_class_registry", {})

    # Replace config detection, tokenizer loading, and config loading with
    # in-memory stubs.
    monkeypatch.setattr(
        turbo_model_module.SmartConfig,
        "detect",
        lambda *args, **kwargs: _make_smart_config(),
    )
    monkeypatch.setattr(
        turbo_model_module.AutoTokenizer,
        "from_pretrained",
        lambda *args, **kwargs: _make_tokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: SimpleNamespace(
            model_type="llama",
            quantization_config=None,
        ),
    )

    class _FakeAutoModel:
        """Auto-model stub that flags which constructor was invoked."""

        called_from_pretrained = False
        called_from_config = False

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            cls.called_from_pretrained = True
            return SimpleNamespace(config=SimpleNamespace(model_type="llama"))

        @classmethod
        def from_config(cls, *args, **kwargs):
            cls.called_from_config = True
            return SimpleNamespace(config=SimpleNamespace(model_type="llama"))

    monkeypatch.setattr(
        turbo_model_module,
        "AutoModelForCausalLM",
        _FakeAutoModel,
    )

    loaded = TurboModel.from_pretrained(
        "org/llama-like-7b",
        quantize=False,
        verbose=False,
        from_config_only=True,
    )

    assert _FakeAutoModel.called_from_pretrained is False
    assert _FakeAutoModel.called_from_config is True
    assert loaded.model.config.model_type == "llama"
158+
159+
160+
def test_trust_remote_code_warns_for_unregistered_architecture(monkeypatch, caplog):
    """Loading an unregistered architecture with remote code logs a warning.

    The stubbed config reports ``model_type="newmodel"`` and nothing is
    registered for it. With ``trust_remote_code=True`` the load must still
    succeed and the captured log must contain the expected warning text.
    """
    # Start from empty registries; monkeypatch restores the originals.
    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
    monkeypatch.setattr(TurboModel, "_model_class_registry", {})

    # Replace config detection, tokenizer loading, and config loading with
    # in-memory stubs.
    monkeypatch.setattr(
        turbo_model_module.SmartConfig,
        "detect",
        lambda *args, **kwargs: _make_smart_config(),
    )
    monkeypatch.setattr(
        turbo_model_module.AutoTokenizer,
        "from_pretrained",
        lambda *args, **kwargs: _make_tokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: SimpleNamespace(
            model_type="newmodel",
            quantization_config=None,
        ),
    )

    class _FakeAutoModel:
        """Auto-model stub that loads only when given a ``config`` keyword."""

        @staticmethod
        def from_pretrained(*args, **kwargs):
            if "config" in kwargs:
                return SimpleNamespace(config=SimpleNamespace(model_type="llama"))
            raise ValueError("Unrecognized configuration class")

    monkeypatch.setattr(
        turbo_model_module,
        "AutoModelForCausalLM",
        _FakeAutoModel,
    )

    with caplog.at_level("WARNING"):
        loaded = TurboModel.from_pretrained(
            "org/newmodel-7b",
            quantize=False,
            verbose=False,
            base_model_fallback=True,
            trust_remote_code=True,
        )

    assert loaded.model.config.model_type == "llama"
    assert (
        "trust_remote_code=True is enabled for unregistered architecture 'newmodel'"
        in caplog.text
    )
209+
210+
211+
def test_quantization_kwargs_are_preserved_during_fallback(monkeypatch):
    """Quantization kwargs reach both the first load attempt and the retry.

    The fake auto-model raises ``ValueError`` on its first call and succeeds
    on the second. Both recorded calls must carry the sentinel
    ``quantization_config`` returned by the patched
    ``_get_quantization_kwargs``.
    """
    # Start from empty registries; monkeypatch restores the originals.
    monkeypatch.setattr(TurboModel, "_architecture_registry", {})
    monkeypatch.setattr(TurboModel, "_model_class_registry", {})

    # Same stub config as the other tests, but reporting 4 bits.
    # NOTE(review): presumably this steers the loader onto its quantized
    # path — confirm against TurboModel.from_pretrained.
    smart_config = _make_smart_config()
    smart_config.bits = 4

    monkeypatch.setattr(
        turbo_model_module.SmartConfig,
        "detect",
        lambda *args, **kwargs: smart_config,
    )
    monkeypatch.setattr(
        turbo_model_module.AutoTokenizer,
        "from_pretrained",
        lambda *args, **kwargs: _make_tokenizer(),
    )
    monkeypatch.setattr(
        transformers.AutoConfig,
        "from_pretrained",
        lambda *args, **kwargs: SimpleNamespace(
            model_type="newmodel",
            quantization_config=None,
        ),
    )
    # A recognisable sentinel stands in for a real quantization config so
    # the assertions below can check it was forwarded unchanged.
    monkeypatch.setattr(
        TurboModel,
        "_get_quantization_kwargs",
        classmethod(lambda cls, cfg: {"quantization_config": "nf4-sentinel"}),
    )

    # Keyword arguments of every from_pretrained call, in call order.
    calls = []

    class _FakeAutoModel:
        """Auto-model stub that fails once, then succeeds."""

        @staticmethod
        def from_pretrained(*args, **kwargs):
            calls.append(kwargs)
            if len(calls) == 1:
                raise ValueError("Unrecognized configuration class")
            return SimpleNamespace(config=SimpleNamespace(model_type="llama"))

    monkeypatch.setattr(
        turbo_model_module,
        "AutoModelForCausalLM",
        _FakeAutoModel,
    )

    loaded = TurboModel.from_pretrained(
        "org/newmodel-7b",
        quantize=True,
        verbose=False,
        base_model_fallback=True,
    )

    assert loaded.model.config.model_type == "llama"
    assert len(calls) == 2
    assert calls[0]["quantization_config"] == "nf4-sentinel"
    assert calls[1]["quantization_config"] == "nf4-sentinel"

0 commit comments

Comments
 (0)