 class TestJointQKVAttention:
     """Test JointQKVAttentionBridge functionality."""
 
+    @classmethod
+    def _make_additive_mask(cls, boolean_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        min_dtype = torch.finfo(dtype).min
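+        # True (attend) -> 0.0, False (masked) -> most negative finite value,
+        # mirroring the additive-mask convention on the HuggingFace path exercised below.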
+        return torch.where(
+            boolean_mask,
+            torch.zeros((), dtype=dtype, device=boolean_mask.device),
+            torch.full((), min_dtype, dtype=dtype, device=boolean_mask.device),
+        )
+
+    @classmethod
+    def _make_reconstruct_attention_qkv(cls) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
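+        # Deterministic q/k/v fixtures of shape [1, 3, 2, 2]: batch 1, three positions,
+        # and two trailing dims matching the test config's n_heads=2 / d_head=2.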
+        q = torch.tensor(
+            [
+                [
+                    [[1.0, 0.0], [0.5, -0.5]],
+                    [[0.3, 0.7], [0.2, -0.1]],
+                    [[-0.4, 0.6], [0.1, 0.9]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        k = torch.tensor(
+            [
+                [
+                    [[0.9, 0.1], [0.2, -0.3]],
+                    [[0.5, 0.4], [0.3, 0.2]],
+                    [[-0.2, 0.8], [0.7, 0.1]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        v = torch.tensor(
+            [
+                [
+                    [[0.2, 1.0], [0.1, 0.6]],
+                    [[0.4, 0.3], [0.8, 0.2]],
+                    [[0.7, 0.5], [0.9, 0.4]],
+                ]
+            ],
+            dtype=torch.float32,
+        )
+        return q, k, v
+
+    def _assert_non_4d_mask_preserves_causality(
+        self,
+        bridge,
+        *,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> None:
+        q, k, v = self._make_reconstruct_attention_qkv()
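+        # [batch, seq] padding mask: the last position is padding and must stay masked out.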
+        boolean_mask = torch.tensor([[True, True, False]])
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
+        reconstruct_kwargs = {}
+        if position_embeddings is not None:
+            reconstruct_kwargs["position_embeddings"] = position_embeddings
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+            **reconstruct_kwargs,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+            **reconstruct_kwargs,
+        )
+
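+        # Causal structure must survive the padding mask: query 0 attends only to key 0,
+        # query 1 only to keys 0-1, and the padding key (position 2) is zeroed for every query.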
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+        assert torch.all(bool_pattern[:, :, 0, 1:] == 0)
+        assert torch.all(bool_pattern[:, :, 1, 2] == 0)
+        assert torch.all(bool_pattern[..., 2] == 0)
+
     def test_q_hook_out_mutation_applied_in_forward_pass(self):
         """Test that mutations made to q.hook_out are applied in the forward pass result."""
 
@@ -33,7 +110,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing; it does nothing because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -119,7 +197,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing; it does nothing because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -205,7 +284,8 @@ def forward(self, input):
         k_transformation = MockLinear(in_features=128, out_features=384)
         v_transformation = MockLinear(in_features=128, out_features=384)
 
-        split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
+        def split_qkv_matrix(_component):
+            return q_transformation, k_transformation, v_transformation
 
         # Create a mock attention layer for testing; it does nothing because we're only interested in the QKV components
         class MockAttention(torch.nn.Module):
@@ -267,3 +347,141 @@ def v_hook_fn(v_output, hook):
         assert not torch.allclose(
             baseline_output, hooked_output
         ), "Output with v_hook_out mutation should be different from baseline"
+
+    def test_reconstruct_attention_boolean_mask_matches_additive_mask(self):
+        """Boolean 4D masks should be equivalent to additive masks.
+
+        This regression test covers the HuggingFace causal-mask path used by
+        TransformerBridge. Without the boolean-mask conversion in
+        ``_reconstruct_attention()``, boolean masks are added as ``0``/``1``
+        and produce substantially different scores and patterns from the
+        equivalent additive float mask.
+        """
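+        # Illustrative failure mode: adding a raw 0/1 boolean mask shifts masked logits by at
+        # most 1.0, so softmax still assigns them real weight; finfo.min drives them to ~0.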
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        q, k, v = self._make_reconstruct_attention_qkv()
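+        # 4D boolean mask of shape [1, 1, 3, 3]; the helper converts False entries to finfo.min.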
+        boolean_mask = torch.tensor(
+            [[[[False, False, False], [False, True, False], [False, True, True]]]]
+        )
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+        )
+
+        assert torch.isfinite(bool_output).all()
+        assert torch.isfinite(bool_pattern).all()
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+
+    def test_rotary_reconstruct_attention_boolean_mask_matches_additive_mask(self):
+        """Rotary joint-QKV attention should treat boolean and additive masks identically."""
+
+        from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
+            JointQKVPositionEmbeddingsAttentionBridge,
+        )
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        q, k, v = self._make_reconstruct_attention_qkv()
+        boolean_mask = torch.tensor(
+            [[[[False, False, False], [False, True, False], [False, True, True]]]]
+        )
+        additive_mask = self._make_additive_mask(boolean_mask, q.dtype)
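+        # cos=1, sin=0 makes the rotary rotation a no-op (assuming the usual (cos, sin)
+        # tuple convention), so outputs stay directly comparable across mask types.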
+        position_embeddings = (
+            torch.ones(1, 3, 2, dtype=torch.float32),
+            torch.zeros(1, 3, 2, dtype=torch.float32),
+        )
+
+        bool_output, bool_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=boolean_mask,
+            position_embeddings=position_embeddings,
+        )
+        additive_output, additive_pattern = bridge._reconstruct_attention(
+            q.clone(),
+            k.clone(),
+            v.clone(),
+            attention_mask=additive_mask,
+            position_embeddings=position_embeddings,
+        )
+
+        assert torch.isfinite(bool_output).all()
+        assert torch.isfinite(bool_pattern).all()
+        assert torch.allclose(bool_output, additive_output)
+        assert torch.allclose(bool_pattern, additive_pattern)
+
+    def test_reconstruct_attention_non_4d_mask_preserves_causality(self):
+        """Non-4D masks should still receive the local causal mask in the base bridge."""
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+
+        self._assert_non_4d_mask_preserves_causality(bridge)
+
+    def test_rotary_reconstruct_attention_non_4d_mask_preserves_causality(self):
+        """Rotary joint-QKV attention should match base masking semantics for non-4D masks."""
+
+        from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
+            JointQKVPositionEmbeddingsAttentionBridge,
+        )
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        class MockOriginalAttention(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attn_dropout = torch.nn.Identity()
+
+        bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
+        bridge.add_module("_original_component", MockOriginalAttention())
+        position_embeddings = (
+            torch.ones(1, 3, 2, dtype=torch.float32),
+            torch.zeros(1, 3, 2, dtype=torch.float32),
+        )
+
+        self._assert_non_4d_mask_preserves_causality(
+            bridge,
+            position_embeddings=position_embeddings,
+        )