Skip to content

Commit cf1be01

Browse files
authored
Add qwen3 and llava_next adapter tests (#1331)
* adding llava_next and qwen_3 tests * update docstring and add missing tests
1 parent 961f971 commit cf1be01

2 files changed

Lines changed: 771 additions & 0 deletions

File tree

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
"""Unit tests for LlavaNextArchitectureAdapter.
2+
3+
LlavaNext shares its module hierarchy with the base Llava adapter (HF's forward
4+
handles high-res tiling internally), so these tests assert that the subclass
5+
preserves the inherited config, component mapping, weight conversions, and
6+
that the factory routes the LlavaNext architecture key to it.
7+
"""
8+
9+
from types import SimpleNamespace
from typing import Any, Optional

import pytest

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    CLIPVisionEncoderBridge,
    EmbeddingBridge,
    GatedMLPBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    SiglipVisionEncoderBridge,
    UnembeddingBridge,
    VisionProjectionBridge,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)
from transformer_lens.model_bridge.supported_architectures.llava import (
    LlavaArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.llava_next import (
    LlavaNextArchitectureAdapter,
)
39+
40+
41+
def _make_cfg(
    n_heads: int = 8,
    n_key_value_heads: Optional[int] = 4,
    d_model: int = 64,
    n_layers: int = 2,
    d_vocab: int = 100,
    n_ctx: int = 128,
    vision_model_type: str = "clip_vision_model",
) -> TransformerBridgeConfig:
    """Build a minimal TransformerBridgeConfig with a vision sub-config attached.

    Args:
        n_heads: Number of attention heads; also used to derive
            ``d_head`` as ``d_model // n_heads``.
        n_key_value_heads: Number of key/value heads. ``None`` is accepted so
            tests can exercise the configuration without this value set.
        d_model: Model width passed to the config.
        n_layers: Number of layers passed to the config.
        d_vocab: Vocabulary size passed to the config.
        n_ctx: Context length passed to the config.
        vision_model_type: Value stored as ``vision_config.model_type``.

    Returns:
        A config whose ``architecture`` is
        ``"LlavaNextForConditionalGeneration"`` and whose ``vision_config`` is
        a ``SimpleNamespace`` with fixed ``hidden_size=128``,
        ``num_hidden_layers=4`` and ``num_attention_heads=8``.
    """
    cfg = TransformerBridgeConfig(
        d_model=d_model,
        d_head=d_model // n_heads,
        n_layers=n_layers,
        n_ctx=n_ctx,
        n_heads=n_heads,
        n_key_value_heads=n_key_value_heads,
        d_vocab=d_vocab,
        default_prepend_bos=True,
        architecture="LlavaNextForConditionalGeneration",
    )
    # A plain namespace stands in for the vision config object in these tests;
    # only the four attributes below are provided.
    cfg.vision_config = SimpleNamespace(
        model_type=vision_model_type,
        hidden_size=128,
        num_hidden_layers=4,
        num_attention_heads=8,
    )
    return cfg
69+
70+
71+
@pytest.fixture
def cfg() -> TransformerBridgeConfig:
    """Provide the bridge config produced by ``_make_cfg`` with all defaults."""
    default_cfg = _make_cfg()
    return default_cfg
74+
75+
76+
@pytest.fixture
def adapter(cfg: TransformerBridgeConfig) -> LlavaNextArchitectureAdapter:
    """Provide a LlavaNext adapter constructed from the ``cfg`` fixture."""
    llava_next_adapter = LlavaNextArchitectureAdapter(cfg)
    return llava_next_adapter
79+
80+
81+
class TestLlavaNextInheritance:
    """Pin the subclass relationship to LlavaArchitectureAdapter.

    The inherited surface is the contract under test here, so that a future
    accidental override in the subclass is caught.
    """

    def test_subclass_of_llava(self) -> None:
        """The LlavaNext adapter class derives from the base Llava adapter."""
        base_cls = LlavaArchitectureAdapter
        derived_cls = LlavaNextArchitectureAdapter
        assert issubclass(derived_cls, base_cls)

    def test_instance_is_also_llava(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """An adapter instance also passes an isinstance check against the base."""
        is_base_instance = isinstance(adapter, LlavaArchitectureAdapter)
        assert is_base_instance
93+
94+
95+
class TestLlavaNextAdapterConfig:
    """Config values expected on ``adapter.cfg`` after adapter construction.

    Covers the multimodal flag, the vision-config fields copied onto the
    config (vision_hidden_size, vision_num_layers, vision_num_heads), and the
    language-model settings (RMS normalization, rotary positions, gated MLP,
    eager attention, and the n_key_value_heads value).
    """

    def test_is_multimodal(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.is_multimodal is True

    def test_normalization_type(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.normalization_type == "RMS"

    def test_positional_embedding_type(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.positional_embedding_type == "rotary"

    def test_final_rms(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.final_rms is True

    def test_gated_mlp(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.gated_mlp is True

    def test_uses_rms_norm(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.uses_rms_norm is True

    def test_attn_only(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.attn_only is False

    def test_attn_implementation(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.attn_implementation == "eager"

    def test_eps_attr(self, adapter: LlavaNextArchitectureAdapter) -> None:
        config = adapter.cfg
        assert config.eps_attr == "variance_epsilon"

    def test_n_key_value_heads_preserved(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """The value supplied by the default test config (4) is kept."""
        config = adapter.cfg
        assert config.n_key_value_heads == 4

    def test_vision_config_propagated(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """Vision sub-config values (128 / 4 / 8) appear on the top-level config."""
        config = adapter.cfg
        assert config.vision_hidden_size == 128
        assert config.vision_num_layers == 4
        assert config.vision_num_heads == 8
135+
136+
137+
class TestLlavaNextAdapterComponentMapping:
    """Component-mapping tests.

    Each test looks up one entry of ``adapter.component_mapping`` (or one of
    the block submodules) and checks its bridge class and its ``name`` string.
    """

    @staticmethod
    def _mapping(adapter: LlavaNextArchitectureAdapter) -> dict[str, Any]:
        """Return the adapter's component mapping after asserting it is set."""
        component_mapping = adapter.component_mapping
        assert component_mapping is not None
        return component_mapping

    def test_vision_encoder_clip_default(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """With the default vision model type, the encoder is a CLIP bridge."""
        vision_encoder = self._mapping(adapter)["vision_encoder"]
        assert isinstance(vision_encoder, CLIPVisionEncoderBridge)
        assert vision_encoder.name == "model.vision_tower"

    def test_vision_encoder_siglip_when_configured(self) -> None:
        """A siglip vision model type selects the Siglip encoder bridge."""
        siglip_adapter = LlavaNextArchitectureAdapter(
            _make_cfg(vision_model_type="siglip_vision_model")
        )
        vision_encoder = self._mapping(siglip_adapter)["vision_encoder"]
        assert isinstance(vision_encoder, SiglipVisionEncoderBridge)

    def test_vision_projector(self, adapter: LlavaNextArchitectureAdapter) -> None:
        projector = self._mapping(adapter)["vision_projector"]
        assert isinstance(projector, VisionProjectionBridge)
        assert projector.name == "model.multi_modal_projector"

    def test_embed(self, adapter: LlavaNextArchitectureAdapter) -> None:
        embed = self._mapping(adapter)["embed"]
        assert isinstance(embed, EmbeddingBridge)
        assert embed.name == "model.language_model.embed_tokens"

    def test_rotary_emb(self, adapter: LlavaNextArchitectureAdapter) -> None:
        rotary = self._mapping(adapter)["rotary_emb"]
        assert isinstance(rotary, RotaryEmbeddingBridge)
        assert rotary.name == "model.language_model.rotary_emb"

    def test_blocks(self, adapter: LlavaNextArchitectureAdapter) -> None:
        blocks = self._mapping(adapter)["blocks"]
        assert isinstance(blocks, BlockBridge)
        assert blocks.name == "model.language_model.layers"

    def test_ln_final(self, adapter: LlavaNextArchitectureAdapter) -> None:
        ln_final = self._mapping(adapter)["ln_final"]
        assert isinstance(ln_final, RMSNormalizationBridge)
        assert ln_final.name == "model.language_model.norm"

    def test_unembed(self, adapter: LlavaNextArchitectureAdapter) -> None:
        unembed = self._mapping(adapter)["unembed"]
        assert isinstance(unembed, UnembeddingBridge)
        assert unembed.name == "lm_head"

    def test_block_ln1(self, adapter: LlavaNextArchitectureAdapter) -> None:
        ln1 = self._mapping(adapter)["blocks"].submodules["ln1"]
        assert isinstance(ln1, RMSNormalizationBridge)
        assert ln1.name == "input_layernorm"

    def test_block_ln2(self, adapter: LlavaNextArchitectureAdapter) -> None:
        ln2 = self._mapping(adapter)["blocks"].submodules["ln2"]
        assert isinstance(ln2, RMSNormalizationBridge)
        assert ln2.name == "post_attention_layernorm"

    def test_block_attn(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """Attention bridge type, its name, and the q/k/v/o projection names."""
        attn = self._mapping(adapter)["blocks"].submodules["attn"]
        assert isinstance(attn, PositionEmbeddingsAttentionBridge)
        assert attn.name == "self_attn"
        # Checked in the same q, k, v, o order as separate assertions would be.
        expected_names = (
            ("q", "q_proj"),
            ("k", "k_proj"),
            ("v", "v_proj"),
            ("o", "o_proj"),
        )
        for sub_key, sub_name in expected_names:
            assert attn.submodules[sub_key].name == sub_name

    def test_block_mlp(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """MLP bridge type, its name, and the gate/in/out projection names."""
        mlp = self._mapping(adapter)["blocks"].submodules["mlp"]
        assert isinstance(mlp, GatedMLPBridge)
        assert mlp.name == "mlp"
        expected_names = (
            ("gate", "gate_proj"),
            ("in", "up_proj"),
            ("out", "down_proj"),
        )
        for sub_key, sub_name in expected_names:
            assert mlp.submodules[sub_key].name == sub_name
216+
217+
218+
# ---------------------------------------------------------------------------
219+
# Weight conversion tests
220+
# ---------------------------------------------------------------------------
221+
222+
223+
class TestLlavaNextAdapterWeightConversions:
    """Weight-conversion tests.

    Checks the entries of ``adapter.weight_processing_conversions`` for the
    attention q/k/v/o weights: which keys exist, which rearrange pattern is
    used, and which head count is bound to the ``n`` axis.
    """

    Q_KEY = "blocks.{i}.attn.q.weight"
    K_KEY = "blocks.{i}.attn.k.weight"
    V_KEY = "blocks.{i}.attn.v.weight"
    O_KEY = "blocks.{i}.attn.o.weight"

    @staticmethod
    def _conversions(adapter: LlavaNextArchitectureAdapter) -> Any:
        """Return the adapter's conversion table after asserting it is set."""
        table = adapter.weight_processing_conversions
        assert table is not None
        return table

    @classmethod
    def _rearrange(cls, adapter: LlavaNextArchitectureAdapter, key: str) -> Any:
        """Look up ``key`` and return its tensor conversion.

        Asserts that the entry is a ParamProcessingConversion and that its
        tensor conversion is a RearrangeTensorConversion.
        """
        entry = cls._conversions(adapter)[key]
        assert isinstance(entry, ParamProcessingConversion)
        tensor_conversion = entry.tensor_conversion
        assert isinstance(tensor_conversion, RearrangeTensorConversion)
        return tensor_conversion

    def test_four_conversion_keys(self, adapter: LlavaNextArchitectureAdapter) -> None:
        table = self._conversions(adapter)
        assert len(table) == 4

    def test_qkvo_keys_present(self, adapter: LlavaNextArchitectureAdapter) -> None:
        table = self._conversions(adapter)
        for expected_key in (self.Q_KEY, self.K_KEY, self.V_KEY, self.O_KEY):
            assert expected_key in table

    def test_q_uses_n_heads(self, adapter: LlavaNextArchitectureAdapter) -> None:
        rearrange = self._rearrange(adapter, self.Q_KEY)
        assert rearrange.pattern == "(n h) m -> n m h"
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_heads

    def test_k_uses_n_key_value_heads(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """GQA: the K rearrange binds ``n`` to n_key_value_heads."""
        rearrange = self._rearrange(adapter, self.K_KEY)
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_key_value_heads

    def test_v_uses_n_key_value_heads(self, adapter: LlavaNextArchitectureAdapter) -> None:
        """GQA: the V rearrange binds ``n`` to n_key_value_heads."""
        rearrange = self._rearrange(adapter, self.V_KEY)
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_key_value_heads

    def test_k_falls_back_to_n_heads_when_no_gqa(self) -> None:
        """Without n_key_value_heads, K must use n_heads."""
        no_gqa_adapter = LlavaNextArchitectureAdapter(_make_cfg(n_key_value_heads=None))
        rearrange = self._rearrange(no_gqa_adapter, self.K_KEY)
        assert rearrange.axes_lengths["n"] == no_gqa_adapter.cfg.n_heads

    def test_o_pattern(self, adapter: LlavaNextArchitectureAdapter) -> None:
        rearrange = self._rearrange(adapter, self.O_KEY)
        assert rearrange.pattern == "m (n h) -> n h m"
        assert rearrange.axes_lengths["n"] == adapter.cfg.n_heads
289+
290+
291+
class TestLlavaNextFactoryRegistration:
    """LlavaNext factory registration tests.

    The factory and package imports are done inside each test rather than at
    module import time.
    """

    ARCH_KEY = "LlavaNextForConditionalGeneration"

    def test_factory_key_registered(self) -> None:
        """The architecture key is present in SUPPORTED_ARCHITECTURES."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        assert self.ARCH_KEY in SUPPORTED_ARCHITECTURES

    def test_factory_returns_llava_next_adapter(self) -> None:
        """Selecting an adapter for the LlavaNext key yields the LlavaNext class."""
        from transformer_lens.factories.architecture_adapter_factory import (
            ArchitectureAdapterFactory,
        )

        bridge_cfg = _make_cfg()
        bridge_cfg.architecture = self.ARCH_KEY
        selected = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg)
        assert isinstance(selected, LlavaNextArchitectureAdapter)

    def test_factory_key_distinct_from_base_llava(self) -> None:
        """LlavaNext must not be aliased to base Llava in the registry."""
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        registered_cls = SUPPORTED_ARCHITECTURES[self.ARCH_KEY]
        assert registered_cls is LlavaNextArchitectureAdapter

    def test_import_from_init(self) -> None:
        """The package-level export is the same class object as the module's."""
        from transformer_lens.model_bridge.supported_architectures import (
            LlavaNextArchitectureAdapter as FromInit,
        )

        assert FromInit is LlavaNextArchitectureAdapter

0 commit comments

Comments
 (0)