@@ -42,53 +42,6 @@ def bridge_and_adapter(self, model_name, device):
4242 bridge = TransformerBridge .boot_transformers (model_name , device = device )
4343 return bridge , bridge .adapter , bridge .cfg
4444
45- @pytest .mark .skip (
46- reason = "API not implemented - ProcessWeights doesn't convert to TL format keys"
47- )
48- def test_processing_with_architecture_adapter (
49- self , raw_hf_model_and_state_dict , bridge_and_adapter
50- ):
51- """Test ProcessWeights.process_weights with architecture adapter."""
52- raw_hf_model , raw_state_dict = raw_hf_model_and_state_dict
53- bridge , adapter , cfg = bridge_and_adapter
54-
55- # Preprocess weights first (this converts to TL format with split Q/K/V)
56- preprocessed_state_dict = adapter .preprocess_weights (raw_state_dict )
57-
58- # Process with architecture adapter
59- processed_with_adapter = ProcessWeights .process_weights (
60- state_dict = preprocessed_state_dict ,
61- cfg = cfg ,
62- adapter = adapter ,
63- fold_ln = False ,
64- center_writing_weights = False ,
65- center_unembed = False ,
66- fold_value_biases = False ,
67- )
68-
69- # Verify processing occurred
70- assert len (processed_with_adapter ) > 0 , "Should process weights with adapter"
71-
72- # Check for TransformerLens-style keys (after preprocessing)
73- # These should be in format like: blocks.0.attn.W_Q, blocks.0.attn.W_K, etc.
74- tl_keys = [
75- k
76- for k in processed_with_adapter .keys ()
77- if any (
78- pattern in k for pattern in [".W_Q" , ".W_K" , ".W_V" , ".W_O" , ".b_Q" , ".b_K" , ".b_V" ]
79- )
80- ]
81- assert (
82- len (tl_keys ) > 0
83- ), "Should have TransformerLens-style attention keys after preprocessing"
84-
85- # Check that expected TL-style keys exist
86- expected_patterns = ["blocks.0.attn.W_Q" , "blocks.0.attn.W_K" , "blocks.0.attn.W_V" ]
87- for pattern in expected_patterns :
88- assert any (
89- pattern in k for k in processed_with_adapter .keys ()
90- ), f"Should have { pattern } in processed weights"
91-
9245 def test_processing_without_architecture_adapter (
9346 self , raw_hf_model_and_state_dict , bridge_and_adapter
9447 ):
@@ -116,125 +69,6 @@ def test_processing_without_architecture_adapter(
11669 ]
11770 assert len (hf_keys ) > 0 , "Should have HF-style keys without adapter"
11871
119- @pytest .mark .skip (reason = "API not implemented - adapter.preprocess_weights doesn't split Q/K/V" )
120- def test_processing_with_different_flags (self , raw_hf_model_and_state_dict , bridge_and_adapter ):
121- """Test processing with different flag combinations."""
122- raw_hf_model , raw_state_dict = raw_hf_model_and_state_dict
123- bridge , adapter , cfg = bridge_and_adapter
124-
125- # Preprocess weights first
126- preprocessed_state_dict = adapter .preprocess_weights (raw_state_dict )
127-
128- # Test processing with all flags enabled
129- processed_with_flags = ProcessWeights .process_weights (
130- state_dict = preprocessed_state_dict .copy (),
131- cfg = cfg ,
132- adapter = adapter ,
133- fold_ln = True ,
134- center_writing_weights = True ,
135- center_unembed = True ,
136- fold_value_biases = True ,
137- )
138-
139- # Test processing with all flags disabled
140- processed_without_flags = ProcessWeights .process_weights (
141- state_dict = preprocessed_state_dict .copy (),
142- cfg = cfg ,
143- adapter = adapter ,
144- fold_ln = False ,
145- center_writing_weights = False ,
146- center_unembed = False ,
147- fold_value_biases = False ,
148- )
149-
150- # Both should process successfully
151- assert len (processed_with_flags ) > 0 , "Should process weights with flags"
152- assert len (processed_without_flags ) > 0 , "Should process weights without flags"
153-
154- @pytest .mark .skip (reason = "API not implemented - adapter.preprocess_weights doesn't split Q/K/V" )
155- def test_architecture_divergence_handling (
156- self , raw_hf_model_and_state_dict , bridge_and_adapter
157- ):
158- """Test that adapter preprocessing changes the state dict format."""
159- raw_hf_model , raw_state_dict = raw_hf_model_and_state_dict
160- bridge , adapter , cfg = bridge_and_adapter
161-
162- # Preprocess with adapter (splits c_attn into Q/K/V)
163- preprocessed_with_adapter = adapter .preprocess_weights (raw_state_dict )
164-
165- # Process with adapter after preprocessing
166- processed_with_adapter = ProcessWeights .process_weights (
167- state_dict = preprocessed_with_adapter ,
168- cfg = cfg ,
169- adapter = adapter ,
170- fold_ln = True ,
171- center_writing_weights = True ,
172- center_unembed = True ,
173- fold_value_biases = True ,
174- )
175-
176- # Process without adapter (no preprocessing)
177- processed_without_adapter = ProcessWeights .process_weights (
178- state_dict = raw_state_dict .copy (),
179- cfg = cfg ,
180- adapter = None ,
181- fold_ln = True ,
182- center_writing_weights = True ,
183- center_unembed = True ,
184- fold_value_biases = True ,
185- )
186-
187- # Results should be different (different processing paths)
188- with_adapter_keys = set (processed_with_adapter .keys ())
189- without_adapter_keys = set (processed_without_adapter .keys ())
190-
191- # Should have some different keys due to different processing
192- assert (
193- with_adapter_keys != without_adapter_keys
194- ), "With and without adapter should produce different key sets"
195-
196- # With adapter should have split Q/K/V keys
197- tl_attn_keys = [
198- k for k in with_adapter_keys if any (p in k for p in [".W_Q" , ".W_K" , ".W_V" ])
199- ]
200- assert len (tl_attn_keys ) > 0 , "With adapter should have split Q/K/V keys"
201-
202- @pytest .mark .skip (reason = "API not implemented - adapter.preprocess_weights doesn't split Q/K/V" )
203- def test_custom_component_processing_integration (
204- self , raw_hf_model_and_state_dict , bridge_and_adapter
205- ):
206- """Test that adapter preprocessing splits QKV weights correctly."""
207- raw_hf_model , raw_state_dict = raw_hf_model_and_state_dict
208- bridge , adapter , cfg = bridge_and_adapter
209-
210- # Preprocess weights first - this is what splits Q/K/V
211- preprocessed_weights = adapter .preprocess_weights (raw_state_dict )
212-
213- # Process with adapter after preprocessing
214- processed_weights = ProcessWeights .process_weights (
215- state_dict = preprocessed_weights ,
216- cfg = cfg ,
217- adapter = adapter ,
218- fold_ln = False ,
219- center_writing_weights = False ,
220- center_unembed = False ,
221- fold_value_biases = False ,
222- )
223-
224- # Check for split Q/K/V weights (created by preprocessing)
225- custom_qkv_found = any (".W_Q" in k for k in processed_weights .keys ())
226-
227- assert custom_qkv_found , "Should have split QKV weights after preprocessing"
228-
229- # Verify that QKV splitting occurred for each layer
230- q_keys = [k for k in processed_weights .keys () if ".W_Q" in k ]
231- k_keys = [k for k in processed_weights .keys () if ".W_K" in k ]
232- v_keys = [k for k in processed_weights .keys () if ".W_V" in k ]
233-
234- assert len (q_keys ) > 0 , "Should have Q weight keys"
235- assert len (k_keys ) > 0 , "Should have K weight keys"
236- assert len (v_keys ) > 0 , "Should have V weight keys"
237-
23872 def test_computational_correctness_with_existing_pipeline (self , model_name , device ):
23973 """Test that centralized processing maintains computational correctness."""
24074 test_tokens = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]], dtype = torch .long )
0 commit comments