
Commit 42a2d52

Fix boolean 4D attention-mask handling in joint-QKV bridge attention reconstruction (#1198)
* initial bare fix for remote CI debugging of downstream package
* allow `makefile` to toggle `uv run --active` / `uv sync --active` behavior with `TL_UV_ACTIVE=1` instead of requiring manual edits for out-of-place virtualenv workflows
* Fix boolean 4D attention-mask handling in joint-QKV bridge attention
* minor docstring clarification
* integrate scale/upcast changes as well as stale combined qkv change
* minimize PR change surface by removing cosmetic docstring changes
* use min_dtype consistently in fallback causal tril mask paths
* apply format fix with new make mode
* align rotary reconstructed masking with shared helper
* adjust legacy hooks to atol/rtol 1e-6
1 parent 7e1bdc3 commit 42a2d52
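
For reference, a minimal standalone sketch of the mask normalization this commit's tests exercise: a boolean 4D mask is converted to an additive float mask using `torch.finfo(dtype).min` before it is added to the attention scores, mirroring the `_make_additive_mask` helper in the new unit tests. The function name `normalize_attention_mask` here is illustrative only, not the bridge's actual internal API.

```python
import torch

def normalize_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Illustrative sketch: turn a boolean keep-mask into an additive float mask.

    Adding a boolean mask directly contributes 0/1 to the scores, which barely
    perturbs them and silently disables masking. Converting True -> 0.0 and
    False -> torch.finfo(dtype).min reproduces HuggingFace's additive-mask
    semantics while keeping masked scores finite.
    """
    if attention_mask.dtype == torch.bool:
        min_dtype = torch.finfo(dtype).min
        return torch.where(
            attention_mask,
            torch.zeros((), dtype=dtype, device=attention_mask.device),
            torch.full((), min_dtype, dtype=dtype, device=attention_mask.device),
        )
    return attention_mask.to(dtype)

# Example: last key position masked out
bool_mask = torch.tensor([[[[True, True, False]]]])
additive = normalize_attention_mask(bool_mask, torch.float32)
# additive == [[[[0.0, 0.0, -3.4028e+38]]]]
```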

5 files changed

Lines changed: 329 additions & 46 deletions


makefile

Lines changed: 11 additions & 8 deletions
@@ -1,11 +1,14 @@
-RUN := uv run
+TL_UV_ACTIVE ?= 0
+ACTIVE_FLAG := $(if $(filter 1 true TRUE yes YES on ON,$(TL_UV_ACTIVE)), --active,)
+RUN := uv run$(ACTIVE_FLAG)
+UV_SYNC := uv sync$(ACTIVE_FLAG)
 
 # Rerun args for flaky tests (httpx timeouts during HF Hub downloads)
 # Remove this line when no longer needed
 RERUN_ARGS := --reruns 2 --reruns-delay 5
 
 dep:
-	uv sync
+	$(UV_SYNC)
 
 format:
 	$(RUN) pycln --all . --exclude "__init__.py"

@@ -59,12 +62,12 @@ notebook-test:
 	$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb $(RERUN_ARGS)
 
 test:
-	make unit-test
-	make integration-test
-	make acceptance-test
-	make benchmark-test
-	make docstring-test
-	make notebook-test
+	$(MAKE) unit-test
+	$(MAKE) integration-test
+	$(MAKE) acceptance-test
+	$(MAKE) benchmark-test
+	$(MAKE) docstring-test
+	$(MAKE) notebook-test
 
 docs-hot-reload:
 	$(RUN) docs-hot-reload

tests/integration/model_bridge/compatibility/test_legacy_hooks.py

Lines changed: 38 additions & 6 deletions
@@ -139,7 +139,15 @@ def test_cache_hook_names_present(self, transformer_bridge, prompt, expected_hooks
     def test_cache_hook_equality_with_hooked_transformer(
         self, transformer_bridge, hooked_transformer, prompt, expected_hooks
     ):
-        """Test that TransformerBridge cache values match HookedTransformer cache values."""
+        """Test that TransformerBridge cache values match HookedTransformer cache values.
+
+        Raw attention-score caches intentionally use different masked sentinels:
+        HookedTransformer stores ``-inf`` for masked causal positions, while
+        TransformerBridge preserves HuggingFace's finite additive mask
+        representation using ``torch.finfo(dtype).min``. The unmasked scores and
+        resulting attention pattern should still match within floating-point
+        precision.
+        """
         _, bridge_cache = transformer_bridge.run_with_cache(prompt)
         _, hooked_transformer_cache = hooked_transformer.run_with_cache(prompt)
 
@@ -157,11 +165,35 @@ def test_cache_hook_equality_with_hooked_transformer(
                 f"TransformerBridge shape {bridge_activation.shape}"
             )
 
-            # Allow for some numerical differences due to different implementations
-            # Use nanmean to handle -inf values in attention scores (which produce nan when subtracted)
-            mean_abs_diff = torch.nanmean(
-                torch.abs(hooked_transformer_activation - bridge_activation)
-            )
+            if hook == "blocks.0.attn.hook_attn_scores":
+                masked_positions = torch.isinf(hooked_transformer_activation)
+                unmasked_positions = ~masked_positions
+
+                assert torch.allclose(
+                    hooked_transformer_activation[unmasked_positions],
+                    bridge_activation[unmasked_positions],
+                    atol=1e-6,
+                    rtol=1e-6,
+                ), (
+                    "Unmasked attention scores should match within float32 " "numerical precision"
+                )
+
+                masked_bridge_values = bridge_activation[masked_positions]
+                min_dtype = torch.finfo(bridge_activation.dtype).min
+
+                assert masked_positions.any(), "Expected causal masking in attention scores"
+                assert torch.isfinite(masked_bridge_values).all(), (
+                    "TransformerBridge should keep masked attention scores finite "
+                    "to mirror HuggingFace additive masking semantics"
+                )
+                assert torch.all(masked_bridge_values == min_dtype), (
+                    "Masked TransformerBridge attention scores should use dtype min "
+                    "instead of HookedTransformer's -inf sentinel"
+                )
+                continue
+
+            # Remaining legacy-compatible hooks are finite on this prompt, mean abs diff suffices
+            mean_abs_diff = torch.abs(hooked_transformer_activation - bridge_activation).mean()
             assert mean_abs_diff < 0.5, (
                 f"Hook {hook} does not match between HookedTransformer and TransformerBridge. "
                 f"Mean absolute difference: {mean_abs_diff}"

tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py

Lines changed: 221 additions & 3 deletions
@@ -10,6 +10,83 @@
 class TestJointQKVAttention:
     """Test JointQKVAttentionBridge functionality."""
 
+    @classmethod
+    def _make_additive_mask(cls, boolean_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        min_dtype = torch.finfo(dtype).min
+        return torch.where(
+            boolean_mask,
+            torch.zeros((), dtype=dtype, device=boolean_mask.device),
+            torch.full((), min_dtype, dtype=dtype, device=boolean_mask.device),
+        )
+
+    @classmethod
+    def _make_reconstruct_attention_qkv(cls) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        q = torch.tensor(
+            [
+                [
+                    [[1.0, 0.0], [0.5, -0.5]],
+                    [[0.3, 0.7], [0.2, -0.1]],
+                    [[-0.4, 0.6], [0.1, 0.9]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        k = torch.tensor(
+            [
+                [
+                    [[0.9, 0.1], [0.2, -0.3]],
+                    [[0.5, 0.4], [0.3, 0.2]],
+                    [[-0.2, 0.8], [0.7, 0.1]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        v = torch.tensor(
+            [
+                [
+                    [[0.2, 1.0], [0.1, 0.6]],
+                    [[0.4, 0.3], [0.8, 0.2]],
+                    [[0.7, 0.5], [0.9, 0.4]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        return q, k, v
+
+    def _assert_non_4d_mask_preserves_causality(
+        self,
+        bridge,
+        *,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> None:
+        q, k, v = self._make_reconstruct_attention_qkv()
+        boolean_mask = torch.tensor([[True, True, False]])
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
+        reconstruct_kwargs = {}
+        if position_embeddings is not None:
+            reconstruct_kwargs["position_embeddings"] = position_embeddings
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+            **reconstruct_kwargs,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+            **reconstruct_kwargs,
+        )
+
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+        assert torch.all(bool_pattern[:, :, 0, 1:] == 0)
+        assert torch.all(bool_pattern[:, :, 1, 2] == 0)
+        assert torch.all(bool_pattern[..., 2] == 0)
+
     def test_q_hook_out_mutation_applied_in_forward_pass(self):
         """Test that mutations made to q.hook_out are applied in the forward pass result."""
 
@@ -33,7 +110,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -119,7 +197,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -205,7 +284,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -267,3 +347,141 @@ def v_hook_fn(v_output, hook):
         assert not torch.allclose(
             baseline_output, hooked_output
         ), "Output with v_hook_out mutation should be different from baseline"
+
+    def test_reconstruct_attention_boolean_mask_matches_additive_mask(self):
+        """Boolean 4D masks should be equivalent to additive masks.
+
+        This regression test covers the HuggingFace causal-mask path used by
+        TransformerBridge. Without the boolean-mask conversion in
+        ``_reconstruct_attention()``, boolean masks are added as ``0``/``1``
+        and produce substantively different scores and patterns than the equivalent additive
+        float mask.
+        """
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        q, k, v = self._make_reconstruct_attention_qkv()
+        boolean_mask = torch.tensor(
+            [[[[False, False, False], [False, True, False], [False, True, True]]]]
+        )
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+        )
+
+        assert torch.isfinite(bool_output).all()
+        assert torch.isfinite(bool_pattern).all()
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+
+    def test_rotary_reconstruct_attention_boolean_mask_matches_additive_mask(self):
+        """Rotary joint-QKV attention should treat boolean and additive masks identically."""
+
+        from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
+            JointQKVPositionEmbeddingsAttentionBridge,
+        )
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        q, k, v = self._make_reconstruct_attention_qkv()
+        boolean_mask = torch.tensor(
+            [[[[False, False, False], [False, True, False], [False, True, True]]]]
+        )
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
+        position_embeddings = (
+            torch.ones(1, 3, 2, dtype=torch.float32),
+            torch.zeros(1, 3, 2, dtype=torch.float32),
+        )
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+            position_embeddings=position_embeddings,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+            position_embeddings=position_embeddings,
+        )
+
+        assert torch.isfinite(bool_output).all()
+        assert torch.isfinite(bool_pattern).all()
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+
+    def test_reconstruct_attention_non_4d_mask_preserves_causality(self):
+        """Non-4D masks should still receive the local causal mask in the base bridge."""
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+
+        self._assert_non_4d_mask_preserves_causality(bridge)
+
+    def test_rotary_reconstruct_attention_non_4d_mask_preserves_causality(self):
+        """Rotary joint-QKV attention should match base masking semantics for non-4D masks."""
+
+        from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
+            JointQKVPositionEmbeddingsAttentionBridge,
+        )
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        position_embeddings = (
+            torch.ones(1, 3, 2, dtype=torch.float32),
+            torch.zeros(1, 3, 2, dtype=torch.float32),
+        )
+
+        self._assert_non_4d_mask_preserves_causality(
+            bridge,
+            position_embeddings=position_embeddings,
+        )
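
A standalone illustration (assumed example, not code from the bridge) of the failure mode these regression tests guard against: adding a boolean mask to the scores as 0/1 barely shifts them, whereas the converted additive mask removes the masked key from the softmax entirely.

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])
keep = torch.tensor([[True, True, False]])  # last key should be masked out

# Wrong: boolean mask added directly contributes 1.0 to kept keys, 0.0 to masked keys.
wrong = scores + keep.to(scores.dtype)
# Right: True -> 0.0, False -> finfo(dtype).min, matching the tests' _make_additive_mask helper.
right = scores + torch.where(
    keep,
    torch.zeros((), dtype=scores.dtype),
    torch.full((), torch.finfo(scores.dtype).min, dtype=scores.dtype),
)

print(torch.softmax(wrong, dim=-1))  # masked key still gets ~6% probability
print(torch.softmax(right, dim=-1))  # masked key gets exactly 0 probability
```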
