Skip to content

Commit ed12762

Browse files
committed
Fixing LLaVA tests that were focused on HookedTransformer, which doesn't support LLaVA
1 parent 4012537 commit ed12762

1 file changed

Lines changed: 5 additions & 37 deletions

File tree

tests/unit/test_llava_config.py

Lines changed: 5 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
SUPPORTED_ARCHITECTURES,
88
ArchitectureAdapterFactory,
99
)
10-
from transformer_lens.loading_from_pretrained import get_pretrained_model_config
1110
from transformer_lens.model_bridge.supported_architectures.llava import (
1211
LlavaArchitectureAdapter,
1312
)
@@ -89,21 +88,21 @@ def test_vision_projector_path(self, adapter):
8988
assert adapter.component_mapping["vision_projector"].name == "model.multi_modal_projector"
9089

9190
def test_embed_path(self, adapter):
92-
assert adapter.component_mapping["embed"].name == "model.language_model.model.embed_tokens"
91+
assert adapter.component_mapping["embed"].name == "model.language_model.embed_tokens"
9392

9493
def test_rotary_emb_path(self, adapter):
9594
assert (
96-
adapter.component_mapping["rotary_emb"].name == "model.language_model.model.rotary_emb"
95+
adapter.component_mapping["rotary_emb"].name == "model.language_model.rotary_emb"
9796
)
9897

9998
def test_blocks_path(self, adapter):
100-
assert adapter.component_mapping["blocks"].name == "model.language_model.model.layers"
99+
assert adapter.component_mapping["blocks"].name == "model.language_model.layers"
101100

102101
def test_ln_final_path(self, adapter):
103-
assert adapter.component_mapping["ln_final"].name == "model.language_model.model.norm"
102+
assert adapter.component_mapping["ln_final"].name == "model.language_model.norm"
104103

105104
def test_unembed_path(self, adapter):
106-
assert adapter.component_mapping["unembed"].name == "model.language_model.lm_head"
105+
assert adapter.component_mapping["unembed"].name == "lm_head"
107106

108107
def test_weight_processing_conversions_exist(self, adapter):
109108
assert "blocks.{i}.attn.q.weight" in adapter.weight_processing_conversions
@@ -117,34 +116,3 @@ def test_no_norm_offset_conversions(self, adapter):
117116
assert "ln1" not in key
118117
assert "ln2" not in key
119118
assert "ln_final" not in key
120-
121-
122-
class TestLlavaConfigGeneration:
123-
"""Test that get_pretrained_model_config generates correct configs for LLava."""
124-
125-
def test_llava_7b_config(self):
126-
cfg = get_pretrained_model_config("llava-hf/llava-1.5-7b-hf")
127-
assert cfg.d_model == 4096
128-
assert cfg.n_heads == 32
129-
assert cfg.n_layers == 32
130-
assert cfg.d_mlp == 11008
131-
assert cfg.d_vocab == 32064
132-
assert cfg.act_fn == "silu"
133-
assert cfg.normalization_type == "RMS"
134-
assert cfg.original_architecture == "LlavaForConditionalGeneration"
135-
136-
def test_llava_13b_config(self):
137-
cfg = get_pretrained_model_config("llava-hf/llava-1.5-13b-hf")
138-
assert cfg.d_model == 5120
139-
assert cfg.n_heads == 40
140-
assert cfg.n_layers == 40
141-
assert cfg.d_mlp == 13824
142-
assert cfg.d_vocab == 32064
143-
assert cfg.act_fn == "silu"
144-
assert cfg.normalization_type == "RMS"
145-
assert cfg.original_architecture == "LlavaForConditionalGeneration"
146-
147-
def test_llava_architecture_detection(self):
148-
"""Test that 'llava' in model name triggers correct architecture."""
149-
cfg = get_pretrained_model_config("llava-hf/llava-1.5-7b-hf")
150-
assert cfg.original_architecture == "LlavaForConditionalGeneration"

0 commit comments

Comments (0)