Skip to content

Commit 00cf3c2

Browse files
authored
Test coverage/cleaning up xfails (#1334)
* Cleaning up NSP & weight processing tests * Clean up old unused skipped tests * Ignoring output in demo notebook import cells
1 parent 64c6375 commit 00cf3c2

9 files changed

Lines changed: 23 additions & 551 deletions

File tree

demos/Othello_GPT.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
"metadata": {},
149149
"outputs": [],
150150
"source": [
151+
"# NBVAL_IGNORE_OUTPUT\n",
151152
"# Import stuff\n",
152153
"import torch\n",
153154
"import torch.nn as nn\n",
@@ -175,10 +176,11 @@
175176
},
176177
{
177178
"cell_type": "code",
178-
"execution_count": 52,
179+
"execution_count": null,
179180
"metadata": {},
180181
"outputs": [],
181182
"source": [
183+
"# NBVAL_IGNORE_OUTPUT\n",
182184
"import transformer_lens\n",
183185
"import transformer_lens.utilities as utils\n",
184186
"from transformer_lens.hook_points import HookPoint\n",

demos/Santa_Coder.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
},
5252
{
5353
"cell_type": "code",
54-
"execution_count": 2,
54+
"execution_count": null,
5555
"id": "da9f5a40",
5656
"metadata": {
5757
"execution": {
@@ -63,6 +63,7 @@
6363
},
6464
"outputs": [],
6565
"source": [
66+
"# NBVAL_IGNORE_OUTPUT\n",
6667
"# Import stuff\n",
6768
"import torch\n",
6869
"import torch.nn as nn\n",

docs/source/content/migrating_to_v3.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ These work identically on `TransformerBridge` and need no migration:
136136

137137
If your code only touches these APIs, the migration is genuinely just the loading call and (optionally) `enable_compatibility_mode`.
138138

139+
### BERT Next Sentence Prediction
140+
141+
`BertNextSentencePrediction` is not ported to `TransformerBridge`. Keep using `HookedEncoder` + `BertNextSentencePrediction` for NSP workflows. The bridge's BERT adapter does load NSP HuggingFace checkpoints (it rewires the unembed to `cls.seq_relationship`), but the high-level NSP API – sentence-pair tokenization, `[CLS]` pooling, "sequential"/"not sequential" decoding — is not exposed. If this is feature is something you'd like added to TransformerBridge, please file an issue.
142+
139143
### New in 3.x: streaming generation
140144

141145
Both `HookedTransformer` and `TransformerBridge` now expose `generate_stream`, which yields tokens progressively instead of returning the full completion at once:

tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def hook_fn(tensor, hook):
8484
)
8585
assert fired == {"resid_pre_0", "resid_post_0"}
8686

87-
@pytest.mark.xfail(reason="add_perma_hook not yet implemented on TransformerBridge")
8887
def test_perma_hook_persists_across_calls(self, bridge):
8988
"""A permanent hook fires on every forward pass until removed."""
9089
count = 0

tests/integration/test_centralized_weight_processing.py

Lines changed: 0 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -42,53 +42,6 @@ def bridge_and_adapter(self, model_name, device):
4242
bridge = TransformerBridge.boot_transformers(model_name, device=device)
4343
return bridge, bridge.adapter, bridge.cfg
4444

45-
@pytest.mark.skip(
46-
reason="API not implemented - ProcessWeights doesn't convert to TL format keys"
47-
)
48-
def test_processing_with_architecture_adapter(
49-
self, raw_hf_model_and_state_dict, bridge_and_adapter
50-
):
51-
"""Test ProcessWeights.process_weights with architecture adapter."""
52-
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
53-
bridge, adapter, cfg = bridge_and_adapter
54-
55-
# Preprocess weights first (this converts to TL format with split Q/K/V)
56-
preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict)
57-
58-
# Process with architecture adapter
59-
processed_with_adapter = ProcessWeights.process_weights(
60-
state_dict=preprocessed_state_dict,
61-
cfg=cfg,
62-
adapter=adapter,
63-
fold_ln=False,
64-
center_writing_weights=False,
65-
center_unembed=False,
66-
fold_value_biases=False,
67-
)
68-
69-
# Verify processing occurred
70-
assert len(processed_with_adapter) > 0, "Should process weights with adapter"
71-
72-
# Check for TransformerLens-style keys (after preprocessing)
73-
# These should be in format like: blocks.0.attn.W_Q, blocks.0.attn.W_K, etc.
74-
tl_keys = [
75-
k
76-
for k in processed_with_adapter.keys()
77-
if any(
78-
pattern in k for pattern in [".W_Q", ".W_K", ".W_V", ".W_O", ".b_Q", ".b_K", ".b_V"]
79-
)
80-
]
81-
assert (
82-
len(tl_keys) > 0
83-
), "Should have TransformerLens-style attention keys after preprocessing"
84-
85-
# Check that expected TL-style keys exist
86-
expected_patterns = ["blocks.0.attn.W_Q", "blocks.0.attn.W_K", "blocks.0.attn.W_V"]
87-
for pattern in expected_patterns:
88-
assert any(
89-
pattern in k for k in processed_with_adapter.keys()
90-
), f"Should have {pattern} in processed weights"
91-
9245
def test_processing_without_architecture_adapter(
9346
self, raw_hf_model_and_state_dict, bridge_and_adapter
9447
):
@@ -116,125 +69,6 @@ def test_processing_without_architecture_adapter(
11669
]
11770
assert len(hf_keys) > 0, "Should have HF-style keys without adapter"
11871

119-
@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
120-
def test_processing_with_different_flags(self, raw_hf_model_and_state_dict, bridge_and_adapter):
121-
"""Test processing with different flag combinations."""
122-
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
123-
bridge, adapter, cfg = bridge_and_adapter
124-
125-
# Preprocess weights first
126-
preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict)
127-
128-
# Test processing with all flags enabled
129-
processed_with_flags = ProcessWeights.process_weights(
130-
state_dict=preprocessed_state_dict.copy(),
131-
cfg=cfg,
132-
adapter=adapter,
133-
fold_ln=True,
134-
center_writing_weights=True,
135-
center_unembed=True,
136-
fold_value_biases=True,
137-
)
138-
139-
# Test processing with all flags disabled
140-
processed_without_flags = ProcessWeights.process_weights(
141-
state_dict=preprocessed_state_dict.copy(),
142-
cfg=cfg,
143-
adapter=adapter,
144-
fold_ln=False,
145-
center_writing_weights=False,
146-
center_unembed=False,
147-
fold_value_biases=False,
148-
)
149-
150-
# Both should process successfully
151-
assert len(processed_with_flags) > 0, "Should process weights with flags"
152-
assert len(processed_without_flags) > 0, "Should process weights without flags"
153-
154-
@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
155-
def test_architecture_divergence_handling(
156-
self, raw_hf_model_and_state_dict, bridge_and_adapter
157-
):
158-
"""Test that adapter preprocessing changes the state dict format."""
159-
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
160-
bridge, adapter, cfg = bridge_and_adapter
161-
162-
# Preprocess with adapter (splits c_attn into Q/K/V)
163-
preprocessed_with_adapter = adapter.preprocess_weights(raw_state_dict)
164-
165-
# Process with adapter after preprocessing
166-
processed_with_adapter = ProcessWeights.process_weights(
167-
state_dict=preprocessed_with_adapter,
168-
cfg=cfg,
169-
adapter=adapter,
170-
fold_ln=True,
171-
center_writing_weights=True,
172-
center_unembed=True,
173-
fold_value_biases=True,
174-
)
175-
176-
# Process without adapter (no preprocessing)
177-
processed_without_adapter = ProcessWeights.process_weights(
178-
state_dict=raw_state_dict.copy(),
179-
cfg=cfg,
180-
adapter=None,
181-
fold_ln=True,
182-
center_writing_weights=True,
183-
center_unembed=True,
184-
fold_value_biases=True,
185-
)
186-
187-
# Results should be different (different processing paths)
188-
with_adapter_keys = set(processed_with_adapter.keys())
189-
without_adapter_keys = set(processed_without_adapter.keys())
190-
191-
# Should have some different keys due to different processing
192-
assert (
193-
with_adapter_keys != without_adapter_keys
194-
), "With and without adapter should produce different key sets"
195-
196-
# With adapter should have split Q/K/V keys
197-
tl_attn_keys = [
198-
k for k in with_adapter_keys if any(p in k for p in [".W_Q", ".W_K", ".W_V"])
199-
]
200-
assert len(tl_attn_keys) > 0, "With adapter should have split Q/K/V keys"
201-
202-
@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
203-
def test_custom_component_processing_integration(
204-
self, raw_hf_model_and_state_dict, bridge_and_adapter
205-
):
206-
"""Test that adapter preprocessing splits QKV weights correctly."""
207-
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
208-
bridge, adapter, cfg = bridge_and_adapter
209-
210-
# Preprocess weights first - this is what splits Q/K/V
211-
preprocessed_weights = adapter.preprocess_weights(raw_state_dict)
212-
213-
# Process with adapter after preprocessing
214-
processed_weights = ProcessWeights.process_weights(
215-
state_dict=preprocessed_weights,
216-
cfg=cfg,
217-
adapter=adapter,
218-
fold_ln=False,
219-
center_writing_weights=False,
220-
center_unembed=False,
221-
fold_value_biases=False,
222-
)
223-
224-
# Check for split Q/K/V weights (created by preprocessing)
225-
custom_qkv_found = any(".W_Q" in k for k in processed_weights.keys())
226-
227-
assert custom_qkv_found, "Should have split QKV weights after preprocessing"
228-
229-
# Verify that QKV splitting occurred for each layer
230-
q_keys = [k for k in processed_weights.keys() if ".W_Q" in k]
231-
k_keys = [k for k in processed_weights.keys() if ".W_K" in k]
232-
v_keys = [k for k in processed_weights.keys() if ".W_V" in k]
233-
234-
assert len(q_keys) > 0, "Should have Q weight keys"
235-
assert len(k_keys) > 0, "Should have K weight keys"
236-
assert len(v_keys) > 0, "Should have V weight keys"
237-
23872
def test_computational_correctness_with_existing_pipeline(self, model_name, device):
23973
"""Test that centralized processing maintains computational correctness."""
24074
test_tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)

0 commit comments

Comments
 (0)