77 SUPPORTED_ARCHITECTURES ,
88 ArchitectureAdapterFactory ,
99)
10- from transformer_lens .loading_from_pretrained import get_pretrained_model_config
1110from transformer_lens .model_bridge .supported_architectures .llava import (
1211 LlavaArchitectureAdapter ,
1312)
@@ -89,21 +88,21 @@ def test_vision_projector_path(self, adapter):
8988 assert adapter .component_mapping ["vision_projector" ].name == "model.multi_modal_projector"
9089
9190 def test_embed_path (self , adapter ):
92- assert adapter .component_mapping ["embed" ].name == "model.language_model.model. embed_tokens"
91+ assert adapter .component_mapping ["embed" ].name == "model.language_model.embed_tokens"
9392
9493 def test_rotary_emb_path (self , adapter ):
9594 assert (
96- adapter .component_mapping ["rotary_emb" ].name == "model.language_model.model. rotary_emb"
95+ adapter .component_mapping ["rotary_emb" ].name == "model.language_model.rotary_emb"
9796 )
9897
9998 def test_blocks_path (self , adapter ):
100- assert adapter .component_mapping ["blocks" ].name == "model.language_model.model. layers"
99+ assert adapter .component_mapping ["blocks" ].name == "model.language_model.layers"
101100
102101 def test_ln_final_path (self , adapter ):
103- assert adapter .component_mapping ["ln_final" ].name == "model.language_model.model. norm"
102+ assert adapter .component_mapping ["ln_final" ].name == "model.language_model.norm"
104103
105104 def test_unembed_path (self , adapter ):
106- assert adapter .component_mapping ["unembed" ].name == "model.language_model. lm_head"
105+ assert adapter .component_mapping ["unembed" ].name == "lm_head"
107106
108107 def test_weight_processing_conversions_exist (self , adapter ):
109108 assert "blocks.{i}.attn.q.weight" in adapter .weight_processing_conversions
@@ -117,34 +116,3 @@ def test_no_norm_offset_conversions(self, adapter):
117116 assert "ln1" not in key
118117 assert "ln2" not in key
119118 assert "ln_final" not in key
120-
121-
122- class TestLlavaConfigGeneration :
123- """Test that get_pretrained_model_config generates correct configs for LLava."""
124-
125- def test_llava_7b_config (self ):
126- cfg = get_pretrained_model_config ("llava-hf/llava-1.5-7b-hf" )
127- assert cfg .d_model == 4096
128- assert cfg .n_heads == 32
129- assert cfg .n_layers == 32
130- assert cfg .d_mlp == 11008
131- assert cfg .d_vocab == 32064
132- assert cfg .act_fn == "silu"
133- assert cfg .normalization_type == "RMS"
134- assert cfg .original_architecture == "LlavaForConditionalGeneration"
135-
136- def test_llava_13b_config (self ):
137- cfg = get_pretrained_model_config ("llava-hf/llava-1.5-13b-hf" )
138- assert cfg .d_model == 5120
139- assert cfg .n_heads == 40
140- assert cfg .n_layers == 40
141- assert cfg .d_mlp == 13824
142- assert cfg .d_vocab == 32064
143- assert cfg .act_fn == "silu"
144- assert cfg .normalization_type == "RMS"
145- assert cfg .original_architecture == "LlavaForConditionalGeneration"
146-
147- def test_llava_architecture_detection (self ):
148- """Test that 'llava' in model name triggers correct architecture."""
149- cfg = get_pretrained_model_config ("llava-hf/llava-1.5-7b-hf" )
150- assert cfg .original_architecture == "LlavaForConditionalGeneration"
0 commit comments