From dd0acb26ee2e9422cf2d7ac9f955b7f024353271 Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Tue, 26 May 2026 22:54:32 +0100 Subject: [PATCH 1/2] adding llava_next and qwen_3 tests --- .../test_llava_next_adapter.py | 326 ++++++++++++++++++ .../test_qwen3_adapter.py | 311 +++++++++++++++++ 2 files changed, 637 insertions(+) create mode 100644 tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py create mode 100644 tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py diff --git a/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py new file mode 100644 index 000000000..ed35d7953 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py @@ -0,0 +1,326 @@ +"""Unit tests for LlavaNextArchitectureAdapter. + +LlavaNext shares its module hierarchy with the base Llava adapter (HF's forward +handles high-res tiling internally), so these tests assert that the subclass +preserves the inherited config, component mapping, weight conversions, and +that the factory routes the LlavaNext architecture key to it. +""" + +from types import SimpleNamespace +from typing import Any + +import pytest + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + CLIPVisionEncoderBridge, + EmbeddingBridge, + GatedMLPBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + SiglipVisionEncoderBridge, + UnembeddingBridge, + VisionProjectionBridge, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) +from transformer_lens.model_bridge.supported_architectures.llava import ( + LlavaArchitectureAdapter, +) +from transformer_lens.model_bridge.supported_architectures.llava_next import ( + LlavaNextArchitectureAdapter, +) + + +def _make_cfg( + n_heads: int = 8, + n_key_value_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_vocab: int = 100, + n_ctx: int = 128, + vision_model_type: str = "clip_vision_model", +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig with a vision sub-config attached.""" + cfg = TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + n_key_value_heads=n_key_value_heads, + d_vocab=d_vocab, + default_prepend_bos=True, + architecture="LlavaNextForConditionalGeneration", + ) + cfg.vision_config = SimpleNamespace( + model_type=vision_model_type, + hidden_size=128, + num_hidden_layers=4, + num_attention_heads=8, + ) + return cfg + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> LlavaNextArchitectureAdapter: + return LlavaNextArchitectureAdapter(cfg) + + +class TestLlavaNextInheritance: + """ + Documentation for subclass relationship + """ + + def test_subclass_of_llava(self) -> None: + assert issubclass(LlavaNextArchitectureAdapter, LlavaArchitectureAdapter) + + def test_instance_is_also_llava(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert isinstance(adapter, LlavaArchitectureAdapter) + + +class TestLlavaNextAdapterConfig: + """ + Config attribute tests + """ + + def test_is_multimodal(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.is_multimodal is True + + def test_normalization_type(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_uses_rms_norm(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_attn_only(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_attn_implementation(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.attn_implementation == "eager" + + def test_eps_attr(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.eps_attr == "variance_epsilon" + + def test_n_key_value_heads_preserved(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.n_key_value_heads == 4 + + def test_vision_config_propagated(self, adapter: LlavaNextArchitectureAdapter) -> None: + assert adapter.cfg.vision_hidden_size == 128 + assert adapter.cfg.vision_num_layers == 4 + assert adapter.cfg.vision_num_heads == 8 + + +class TestLlavaNextAdapterComponentMapping: + """ + Testcases for setup component mapping + """ + + @staticmethod + def _mapping(adapter: LlavaNextArchitectureAdapter) -> dict[str, Any]: + mapping = adapter.component_mapping + assert mapping is not None + return mapping + + def test_vision_encoder_clip_default(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["vision_encoder"], CLIPVisionEncoderBridge) + assert mapping["vision_encoder"].name == "model.vision_tower" + + def test_vision_encoder_siglip_when_configured(self) -> None: + cfg = _make_cfg(vision_model_type="siglip_vision_model") + adapter = LlavaNextArchitectureAdapter(cfg) + mapping = adapter.component_mapping + assert mapping is not None + assert isinstance(mapping["vision_encoder"], SiglipVisionEncoderBridge) + + def test_vision_projector(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["vision_projector"], VisionProjectionBridge) + assert mapping["vision_projector"].name == "model.multi_modal_projector" + + def test_embed(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["embed"], EmbeddingBridge) + assert mapping["embed"].name == "model.language_model.embed_tokens" + + def test_rotary_emb(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["rotary_emb"], RotaryEmbeddingBridge) + assert mapping["rotary_emb"].name == "model.language_model.rotary_emb" + + def test_blocks(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["blocks"], BlockBridge) + assert mapping["blocks"].name == "model.language_model.layers" + + def test_ln_final(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["ln_final"], RMSNormalizationBridge) + assert mapping["ln_final"].name == "model.language_model.norm" + + def test_unembed(self, adapter: LlavaNextArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["unembed"], UnembeddingBridge) + assert mapping["unembed"].name == "lm_head" + + def test_block_ln1(self, adapter: LlavaNextArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + assert blocks.submodules["ln1"].name == "input_layernorm" + + def test_block_ln2(self, adapter: LlavaNextArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + assert blocks.submodules["ln2"].name == "post_attention_layernorm" + + def test_block_attn(self, adapter: LlavaNextArchitectureAdapter) -> None: + attn = self._mapping(adapter)["blocks"].submodules["attn"] + assert isinstance(attn, PositionEmbeddingsAttentionBridge) + assert attn.name == "self_attn" + assert attn.submodules["q"].name == "q_proj" + assert attn.submodules["k"].name == "k_proj" + assert attn.submodules["v"].name == "v_proj" + assert attn.submodules["o"].name == "o_proj" + + def test_block_mlp(self, adapter: LlavaNextArchitectureAdapter) -> None: + mlp = self._mapping(adapter)["blocks"].submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert mlp.name == "mlp" + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" + + +# --------------------------------------------------------------------------- +# Weight conversion tests +# --------------------------------------------------------------------------- + + +class TestLlavaNextAdapterWeightConversions: + """ + Testcases for accurate weights conversions + """ + + def test_four_conversion_keys(self, adapter: LlavaNextArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert len(convs) == 4 + + def test_qkvo_keys_present(self, adapter: LlavaNextArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + for key in [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + ]: + assert key in convs + + def test_q_uses_n_heads(self, adapter: LlavaNextArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_k_uses_n_key_value_heads(self, adapter: LlavaNextArchitectureAdapter) -> None: + """GQA: K is split along n_key_value_heads.""" + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_v_uses_n_key_value_heads(self, adapter: LlavaNextArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_k_falls_back_to_n_heads_when_no_gqa(self) -> None: + """Without n_key_value_heads, K must use n_heads.""" + cfg = _make_cfg(n_key_value_heads=None) + adapter = LlavaNextArchitectureAdapter(cfg) + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_o_pattern(self, adapter: LlavaNextArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestLlavaNextFactoryRegistration: + """ + Lllava Next factory Registration Tests + """ + + def test_factory_key_registered(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "LlavaNextForConditionalGeneration" in SUPPORTED_ARCHITECTURES + + def test_factory_returns_llava_next_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + cfg.architecture = "LlavaNextForConditionalGeneration" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, LlavaNextArchitectureAdapter) + + def test_factory_key_distinct_from_base_llava(self) -> None: + """LlavaNext must not be aliased to base Llava in the registry.""" + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert ( + SUPPORTED_ARCHITECTURES["LlavaNextForConditionalGeneration"] + is LlavaNextArchitectureAdapter + ) + + def test_import_from_init(self) -> None: + from transformer_lens.model_bridge.supported_architectures import ( + LlavaNextArchitectureAdapter as FromInit, + ) + + assert FromInit is LlavaNextArchitectureAdapter diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py new file mode 100644 index 000000000..22c84e9b8 --- /dev/null +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py @@ -0,0 +1,311 @@ +"""Unit tests for Qwen3ArchitectureAdapter. + +Tests cover: +- Config attributes +- Component mapping structure and HF module names (incl. q_norm/k_norm) +- Weight conversion keys/types (GQA: k/v use n_key_value_heads) +- _preprocess_gated_q_proj static helper (gated q_proj slicing) +- Factory registration +""" + +from typing import Any + +import pytest +import torch + +from transformer_lens.config import TransformerBridgeConfig +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) +from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( + PositionEmbeddingsAttentionBridge, +) +from transformer_lens.model_bridge.supported_architectures.qwen3 import ( + Qwen3ArchitectureAdapter, +) + + +def _make_cfg( + n_heads: int = 8, + n_key_value_heads: int = 4, + d_model: int = 64, + n_layers: int = 2, + d_vocab: int = 100, + n_ctx: int = 128, +) -> TransformerBridgeConfig: + """Minimal TransformerBridgeConfig for Qwen3 adapter tests.""" + return TransformerBridgeConfig( + d_model=d_model, + d_head=d_model // n_heads, + n_layers=n_layers, + n_ctx=n_ctx, + n_heads=n_heads, + n_key_value_heads=n_key_value_heads, + d_vocab=d_vocab, + default_prepend_bos=False, + architecture="Qwen3ForCausalLM", + ) + + +@pytest.fixture +def cfg() -> TransformerBridgeConfig: + return _make_cfg() + + +@pytest.fixture +def adapter(cfg: TransformerBridgeConfig) -> Qwen3ArchitectureAdapter: + return Qwen3ArchitectureAdapter(cfg) + + +class TestQwen3AdapterConfig: + """ + Config attribute tests + """ + + def test_normalization_type(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.normalization_type == "RMS" + + def test_positional_embedding_type(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.positional_embedding_type == "rotary" + + def test_final_rms(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.final_rms is True + + def test_gated_mlp(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.gated_mlp is True + + def test_attn_only(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.attn_only is False + + def test_uses_rms_norm(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.uses_rms_norm is True + + def test_default_prepend_bos_false(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.default_prepend_bos is False + + def test_attn_implementation_eager(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.attn_implementation == "eager" + + def test_n_key_value_heads_preserved(self, adapter: Qwen3ArchitectureAdapter) -> None: + assert adapter.cfg.n_key_value_heads == 4 + + +class TestQwen3AdapterComponentMapping: + """ + Testcases for component mapping setup + """ + + @staticmethod + def _mapping(adapter: Qwen3ArchitectureAdapter) -> dict[str, Any]: + mapping = adapter.component_mapping + assert mapping is not None + return mapping + + def test_embed_type_and_name(self, adapter: Qwen3ArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["embed"], EmbeddingBridge) + assert mapping["embed"].name == "model.embed_tokens" + + def test_rotary_emb(self, adapter: Qwen3ArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["rotary_emb"], RotaryEmbeddingBridge) + assert mapping["rotary_emb"].name == "model.rotary_emb" + + def test_blocks_type_and_name(self, adapter: Qwen3ArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["blocks"], BlockBridge) + assert mapping["blocks"].name == "model.layers" + + def test_ln_final(self, adapter: Qwen3ArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["ln_final"], RMSNormalizationBridge) + assert mapping["ln_final"].name == "model.norm" + + def test_unembed(self, adapter: Qwen3ArchitectureAdapter) -> None: + mapping = self._mapping(adapter) + assert isinstance(mapping["unembed"], UnembeddingBridge) + assert mapping["unembed"].name == "lm_head" + + def test_ln1(self, adapter: Qwen3ArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln1"], RMSNormalizationBridge) + assert blocks.submodules["ln1"].name == "input_layernorm" + + def test_ln2(self, adapter: Qwen3ArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["ln2"], RMSNormalizationBridge) + assert blocks.submodules["ln2"].name == "post_attention_layernorm" + + def test_attn_type_and_name(self, adapter: Qwen3ArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert isinstance(blocks.submodules["attn"], PositionEmbeddingsAttentionBridge) + assert blocks.submodules["attn"].name == "self_attn" + + def test_attn_qkvo_names(self, adapter: Qwen3ArchitectureAdapter) -> None: + attn = self._mapping(adapter)["blocks"].submodules["attn"] + assert attn.submodules["q"].name == "q_proj" + assert attn.submodules["k"].name == "k_proj" + assert attn.submodules["v"].name == "v_proj" + assert attn.submodules["o"].name == "o_proj" + + def test_attn_qk_norms(self, adapter: Qwen3ArchitectureAdapter) -> None: + """Qwen3-specific Q/K head norms.""" + attn = self._mapping(adapter)["blocks"].submodules["attn"] + assert isinstance(attn.submodules["q_norm"], RMSNormalizationBridge) + assert attn.submodules["q_norm"].name == "q_norm" + assert isinstance(attn.submodules["k_norm"], RMSNormalizationBridge) + assert attn.submodules["k_norm"].name == "k_norm" + + def test_mlp(self, adapter: Qwen3ArchitectureAdapter) -> None: + mlp = self._mapping(adapter)["blocks"].submodules["mlp"] + assert isinstance(mlp, GatedMLPBridge) + assert mlp.name == "mlp" + assert mlp.submodules["gate"].name == "gate_proj" + assert mlp.submodules["in"].name == "up_proj" + assert mlp.submodules["out"].name == "down_proj" + + def test_no_linear_attn_when_dense(self, adapter: Qwen3ArchitectureAdapter) -> None: + blocks = self._mapping(adapter)["blocks"] + assert "linear_attn" not in blocks.submodules + + +class TestQwen3AdapterWeightConversions: + """ + Weights conversion tests + """ + + def test_four_conversion_keys(self, adapter: Qwen3ArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + assert len(convs) == 4 + + def test_qkvo_keys_present(self, adapter: Qwen3ArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + for key in [ + "blocks.{i}.attn.q.weight", + "blocks.{i}.attn.k.weight", + "blocks.{i}.attn.v.weight", + "blocks.{i}.attn.o.weight", + ]: + assert key in convs + + def test_q_uses_n_heads(self, adapter: Qwen3ArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.q.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "(n h) m -> n m h" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + def test_k_uses_n_key_value_heads(self, adapter: Qwen3ArchitectureAdapter) -> None: + """GQA: K is split along n_key_value_heads, not n_heads.""" + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.k.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_v_uses_n_key_value_heads(self, adapter: Qwen3ArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.v.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_key_value_heads + + def test_o_pattern(self, adapter: Qwen3ArchitectureAdapter) -> None: + convs = adapter.weight_processing_conversions + assert convs is not None + conv = convs["blocks.{i}.attn.o.weight"] + assert isinstance(conv, ParamProcessingConversion) + assert isinstance(conv.tensor_conversion, RearrangeTensorConversion) + assert conv.tensor_conversion.pattern == "m (n h) -> n h m" + assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads + + +class TestPreprocessGatedQProj: + """ + Tests for _preprocess_gated_q_proj + """ + + def test_slices_query_half(self) -> None: + """Interleaved [query, gate] rows per head must be reduced to query-only.""" + n_heads, d_head, d_model = 4, 8, 16 + # Build q_proj.weight as (n_heads, d_head*2, d_model): query=1.0, gate=9.0 + w = torch.empty(n_heads, d_head * 2, d_model) + w[:, :d_head, :] = 1.0 + w[:, d_head:, :] = 9.0 + w_flat = w.reshape(n_heads * d_head * 2, d_model) + + state_dict = {"model.layers.0.self_attn.q_proj.weight": w_flat.clone()} + out = Qwen3ArchitectureAdapter._preprocess_gated_q_proj(state_dict, n_heads, d_head) + + result = out["model.layers.0.self_attn.q_proj.weight"] + assert result.shape == (n_heads * d_head, d_model) + assert torch.all(result == 1.0), "gate rows must be dropped" + + def test_only_q_proj_keys_modified(self) -> None: + n_heads, d_head, d_model = 2, 4, 8 + q_w = torch.ones(n_heads * d_head * 2, d_model) + other = torch.full((d_model, d_model), 7.0) + state_dict = { + "model.layers.0.self_attn.q_proj.weight": q_w, + "model.layers.0.self_attn.k_proj.weight": other.clone(), + "model.layers.0.mlp.gate_proj.weight": other.clone(), + } + out = Qwen3ArchitectureAdapter._preprocess_gated_q_proj(state_dict, n_heads, d_head) + assert torch.equal(out["model.layers.0.self_attn.k_proj.weight"], other) + assert torch.equal(out["model.layers.0.mlp.gate_proj.weight"], other) + + def test_multiple_layers(self) -> None: + n_heads, d_head, d_model = 2, 4, 8 + state_dict = { + f"model.layers.{i}.self_attn.q_proj.weight": torch.ones(n_heads * d_head * 2, d_model) + for i in range(3) + } + out = Qwen3ArchitectureAdapter._preprocess_gated_q_proj(state_dict, n_heads, d_head) + for i in range(3): + assert out[f"model.layers.{i}.self_attn.q_proj.weight"].shape == ( + n_heads * d_head, + d_model, + ) + + +class TestQwen3FactoryRegistration: + """Factory registeration Tests""" + + def test_factory_key_registered(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + SUPPORTED_ARCHITECTURES, + ) + + assert "Qwen3ForCausalLM" in SUPPORTED_ARCHITECTURES + + def test_factory_returns_qwen3_adapter(self) -> None: + from transformer_lens.factories.architecture_adapter_factory import ( + ArchitectureAdapterFactory, + ) + + cfg = _make_cfg() + cfg.architecture = "Qwen3ForCausalLM" + adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg) + assert isinstance(adapter, Qwen3ArchitectureAdapter) + + def test_import_from_init(self) -> None: + from transformer_lens.model_bridge.supported_architectures import ( + Qwen3ArchitectureAdapter as FromInit, + ) + + assert FromInit is Qwen3ArchitectureAdapter From 5981131afa596ca7628607e46504e3e8d09f82b8 Mon Sep 17 00:00:00 2001 From: Sunny Joshi Date: Tue, 26 May 2026 23:18:28 +0100 Subject: [PATCH 2/2] update docstring and add missing tests --- .../test_llava_next_adapter.py | 13 +- .../test_qwen3_adapter.py | 153 ++++++++++++++++-- 2 files changed, 150 insertions(+), 16 deletions(-) diff --git a/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py index ed35d7953..afc194cb7 100644 --- a/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_llava_next_adapter.py @@ -79,8 +79,10 @@ def adapter(cfg: TransformerBridgeConfig) -> LlavaNextArchitectureAdapter: class TestLlavaNextInheritance: - """ - Documentation for subclass relationship + + """Subclass relationship to LlavaArchitectureAdapter. The class body is + `pass`; the inherited surface is the contract worth pinning so a future + accidental override is caught. """ def test_subclass_of_llava(self) -> None: @@ -91,9 +93,10 @@ def test_instance_is_also_llava(self, adapter: LlavaNextArchitectureAdapter) -> class TestLlavaNextAdapterConfig: - """ - Config attribute tests - """ + """Multimodal config flags, vision-config propagation + (vision_hidden_size, vision_num_layers, vision_num_heads), and + language-model config defaults (RMSNorm, rotary, gated MLP, eager + attention, GQA via n_key_value_heads).""" def test_is_multimodal(self, adapter: LlavaNextArchitectureAdapter) -> None: assert adapter.cfg.is_multimodal is True diff --git a/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py b/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py index 22c84e9b8..07e0fe979 100644 --- a/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py +++ b/tests/unit/model_bridge/supported_architectures/test_qwen3_adapter.py @@ -7,7 +7,7 @@ - _preprocess_gated_q_proj static helper (gated q_proj slicing) - Factory registration """ - +from types import SimpleNamespace from typing import Any import pytest @@ -26,6 +26,9 @@ RotaryEmbeddingBridge, UnembeddingBridge, ) +from transformer_lens.model_bridge.generalized_components.gated_delta_net import ( + GatedDeltaNetBridge, +) from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( PositionEmbeddingsAttentionBridge, ) @@ -67,8 +70,8 @@ def adapter(cfg: TransformerBridgeConfig) -> Qwen3ArchitectureAdapter: class TestQwen3AdapterConfig: - """ - Config attribute tests + """Adapter config defaults: RMSNorm, rotary, gated MLP, eager attention, + default_prepend_bos=False, and GQA propagation via n_key_value_heads. """ def test_normalization_type(self, adapter: Qwen3ArchitectureAdapter) -> None: @@ -101,7 +104,8 @@ def test_n_key_value_heads_preserved(self, adapter: Qwen3ArchitectureAdapter) -> class TestQwen3AdapterComponentMapping: """ - Testcases for component mapping setup + Component-mapping structure, bridge types, including the Qwen3-specific per-head q_norm / k_norm and the dense + (non-hybrid) shape with no linear_attn submodule. """ @staticmethod @@ -179,9 +183,8 @@ def test_no_linear_attn_when_dense(self, adapter: Qwen3ArchitectureAdapter) -> N class TestQwen3AdapterWeightConversions: - """ - Weights conversion tests - """ + """QKVO weight conversions with GQA-aware head counts: + Q uses n_heads; K and V use n_key_value_heads.""" def test_four_conversion_keys(self, adapter: Qwen3ArchitectureAdapter) -> None: convs = adapter.weight_processing_conversions @@ -236,9 +239,10 @@ def test_o_pattern(self, adapter: Qwen3ArchitectureAdapter) -> None: class TestPreprocessGatedQProj: - """ - Tests for _preprocess_gated_q_proj - """ + """Numerical correctness of the _preprocess_gated_q_proj static helper + on synthetic interleaved [query, gate] rows: asserts query-half slicing, + that unrelated state-dict keys are untouched, and that the rewrite + applies across all matching layers.""" def test_slices_query_half(self) -> None: """Interleaved [query, gate] rows per head must be reduced to query-only.""" @@ -283,8 +287,135 @@ def test_multiple_layers(self) -> None: ) +class TestQwen3HybridConstructor: + """The hybrid=True constructor branch on the base class. The Qwen3_5 / + Qwen3Next subclasses exercise this path transitively; pinning it here + surfaces regressions in the base contract: + - linear_attn (GatedDeltaNetBridge) submodule appears alongside the + full-attention branch + - supports_fold_ln flips to False + - weight_processing_conversions is cleared + """ + + @pytest.fixture + def hybrid_adapter(self) -> Qwen3ArchitectureAdapter: + return Qwen3ArchitectureAdapter(_make_cfg(), hybrid=True) + + def test_supports_fold_ln_disabled(self, hybrid_adapter: Qwen3ArchitectureAdapter) -> None: + assert hybrid_adapter.supports_fold_ln is False + + def test_weight_processing_conversions_empty( + self, hybrid_adapter: Qwen3ArchitectureAdapter + ) -> None: + assert hybrid_adapter.weight_processing_conversions == {} + + def test_linear_attn_submodule_present(self, hybrid_adapter: Qwen3ArchitectureAdapter) -> None: + mapping = hybrid_adapter.component_mapping + assert mapping is not None + blocks = mapping["blocks"] + assert "linear_attn" in blocks.submodules + assert isinstance(blocks.submodules["linear_attn"], GatedDeltaNetBridge) + assert blocks.submodules["linear_attn"].name == "linear_attn" + + def test_attn_submodule_still_present(self, hybrid_adapter: Qwen3ArchitectureAdapter) -> None: + """Hybrid keeps full attention alongside linear_attn (both optional).""" + mapping = hybrid_adapter.component_mapping + assert mapping is not None + blocks = mapping["blocks"] + assert "attn" in blocks.submodules + assert isinstance(blocks.submodules["attn"], PositionEmbeddingsAttentionBridge) + + def test_dense_default_has_conversions(self, cfg: TransformerBridgeConfig) -> None: + """Sanity contrast: dense (hybrid=False) keeps the QKVO conversions.""" + dense = Qwen3ArchitectureAdapter(cfg) + assert dense.weight_processing_conversions + assert len(dense.weight_processing_conversions) == 4 + + +class _StubAttnBlock: + """Stand-in for a bridge block with an .attn that records set_rotary_emb.""" + + def __init__(self) -> None: + self.attn = SimpleNamespace(_rotary=None) + # set_rotary_emb mimics the PositionEmbeddingBridgeMixin contract. + self.attn.set_rotary_emb = lambda r: setattr(self.attn, "_rotary", r) + # Mirror nn.Module._modules so the adapter's `"attn" in block._modules` check passes. + self._modules = {"attn": self.attn} + + +class TestQwen3SetupComponentTesting: + """ + Setup_component_testing wiring: + - forces eager attention on both the top-level HF config and each + per-layer self_attn.config + - calls set_rotary_emb on each bridge block's attention + - tolerates bridge_model=None (no-op for bridge wiring) + - swallows get_generalized_component lookup failures on the template + (the documented (ValueError, AttributeError, KeyError) net) + """ + + def _make_fake_attn(self, layer_idx: int) -> SimpleNamespace: + """Per-layer self_attn with a mutable .config to assert eager flip.""" + return SimpleNamespace(config=SimpleNamespace(_attn_implementation="sdpa")) + + def _make_fake_hf_model(self, n_layers: int = 2) -> SimpleNamespace: + """Minimal hf_model stub exposing the attributes setup_component_testing walks.""" + layers = [SimpleNamespace(self_attn=self._make_fake_attn(i)) for i in range(n_layers)] + sentinel_rotary = SimpleNamespace(_id="rotary-sentinel") + return SimpleNamespace( + config=SimpleNamespace(_attn_implementation="sdpa"), + model=SimpleNamespace(rotary_emb=sentinel_rotary, layers=layers), + ) + + def test_flips_top_level_attn_implementation_to_eager( + self, adapter: Qwen3ArchitectureAdapter + ) -> None: + hf = self._make_fake_hf_model() + adapter.setup_component_testing(hf) + assert hf.config._attn_implementation == "eager" + + def test_flips_per_layer_attn_implementation_to_eager( + self, adapter: Qwen3ArchitectureAdapter + ) -> None: + hf = self._make_fake_hf_model(n_layers=3) + adapter.setup_component_testing(hf) + for layer in hf.model.layers: + assert layer.self_attn.config._attn_implementation == "eager" + + def test_wires_rotary_on_bridge_blocks(self, adapter: Qwen3ArchitectureAdapter) -> None: + hf = self._make_fake_hf_model() + bridge_blocks = [_StubAttnBlock(), _StubAttnBlock()] + bridge_model = SimpleNamespace(blocks=bridge_blocks) + adapter.setup_component_testing(hf, bridge_model=bridge_model) + for block in bridge_blocks: + assert block.attn._rotary is hf.model.rotary_emb + + def test_skips_bridge_wiring_when_bridge_model_none( + self, adapter: Qwen3ArchitectureAdapter + ) -> None: + """No bridge_model → must not raise; eager flips still apply.""" + hf = self._make_fake_hf_model() + adapter.setup_component_testing(hf, bridge_model=None) + assert hf.config._attn_implementation == "eager" + + def test_swallows_template_lookup_failure( + self, adapter: Qwen3ArchitectureAdapter, monkeypatch: pytest.MonkeyPatch + ) -> None: + """get_generalized_component may raise; setup_component_testing must + not propagate (caught by the (ValueError, AttributeError, KeyError) net).""" + + def _raise(_self: Any, _path: str) -> None: + raise KeyError("blocks.0.attn") + + monkeypatch.setattr(Qwen3ArchitectureAdapter, "get_generalized_component", _raise) + hf = self._make_fake_hf_model() + # Must not raise. + adapter.setup_component_testing(hf) + + class TestQwen3FactoryRegistration: - """Factory registeration Tests""" + """Factory registration and dispatch via select_architecture_adapter, + plus the import-from-__init__ guard.""" def test_factory_key_registered(self) -> None: from transformer_lens.factories.architecture_adapter_factory import (