Skip to content

Commit b3b9934

Browse files
Kymi808brendanlongjlarson4SamuelePunzo
authored
Fix FactoredMatrix indexing returning empty result for -1 index (#1340)
* Fix type of HookedTransformerConfig.device (#1230) * Fix type of HookedTransformerConfig.device This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on #1219 * more cleanup * 3.0 CI Bugs (#1261) * Fixing `utils` imports * skip gated notebooks on PR from forks * Updating notebooks * Ensure LLaMA only runs when HF_TOKEN is available --------- Co-authored-by: jlarson4 <jonahalarson@comcast.net> * Fix TransformerBridge backward hook cleanup (#1324) * Fix TransformerBridge backward hook cleanup * Preserve backward hooks in run_with_cache * Fix FactoredMatrix indexing returning empty result for -1 index FactoredMatrix.__getitem__ converts an integer index `v` into the matrix (ldim/rdim) dimensions to `slice(v, v + 1)`. For `v == -1` this becomes `slice(-1, 0)`, which is an empty slice, so indexing the last row/column with a negative index silently returns a (0, ...) tensor instead of the requested element. Other negative indices (-2, -3, ...) are unaffected because `v + 1` stays negative. Use `None` as the slice stop when `v == -1` so the final element is kept. Adds regression tests covering negative indices on each matrix dimension. --------- Co-authored-by: Brendan Long <self@brendanlong.com> Co-authored-by: jlarson4 <jonahalarson@comcast.net> Co-authored-by: Samuele_Punzo <90847990+SamuelePunzo@users.noreply.github.com>
1 parent f676d8a commit b3b9934

6 files changed

Lines changed: 95 additions & 14 deletions

File tree

tests/acceptance/model_bridge/compatibility/test_backward_hooks.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,42 @@ def sum_bridge_grads(grad, hook=None):
5151
f"Gradient sums should be identical but differ by "
5252
f"{abs(hooked_grad_sum - bridge_grad_sum).item():.6f}"
5353
)
54+
55+
56+
def test_transformer_bridge_hooks_context_cleans_up_backward_hooks(
57+
gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing
58+
):
59+
"""Regression test for backward-hook cleanup on context exit."""
60+
hooked_model = gpt2_hooked_unprocessed
61+
bridge_model = gpt2_bridge_compat_no_processing
62+
hooked_hook = hooked_model.blocks[0].hook_resid_post
63+
bridge_hook = bridge_model.blocks[0].hook_resid_post
64+
test_input = torch.tensor([[1, 2, 3]])
65+
66+
def noop_backward_hook(grad, hook=None):
67+
return None
68+
69+
hooked_model.zero_grad()
70+
with hooked_model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", noop_backward_hook)]):
71+
hooked_model(test_input).sum().backward()
72+
73+
bridge_model.zero_grad()
74+
with bridge_model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", noop_backward_hook)]):
75+
bridge_model(test_input).sum().backward()
76+
77+
assert not hooked_hook.has_hooks(dir="bwd", including_permanent=False)
78+
assert not bridge_hook.has_hooks(dir="bwd", including_permanent=False)
79+
80+
81+
def test_transformer_bridge_reset_hooks_removes_backward_hooks(gpt2_bridge_compat_no_processing):
82+
"""Regression test for bridge reset_hooks removing backward hooks."""
83+
bridge_model = gpt2_bridge_compat_no_processing
84+
backward_hook = bridge_model.blocks[0].hook_resid_post
85+
86+
backward_hook.add_hook(lambda grad, hook=None: None, dir="bwd")
87+
88+
assert backward_hook.has_hooks(dir="bwd", including_permanent=False)
89+
90+
bridge_model.reset_hooks()
91+
92+
assert not backward_hook.has_hooks(dir="bwd", including_permanent=False)

tests/acceptance/model_bridge/compatibility/test_run_with_cache.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,21 @@ def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing
7777
assert torch.allclose(
7878
cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5
7979
)
80+
81+
82+
def test_transformer_bridge_run_with_cache_preserves_existing_backward_hooks(
83+
gpt2_bridge_compat_no_processing,
84+
):
85+
"""run_with_cache should not remove unrelated backward hooks on the same HookPoint."""
86+
bridge_model = gpt2_bridge_compat_no_processing
87+
target_hook = bridge_model.blocks[0].hook_resid_post
88+
89+
target_hook.add_hook(lambda grad, hook=None: None, dir="bwd")
90+
91+
assert target_hook.has_hooks(dir="bwd", including_permanent=False)
92+
93+
bridge_model.run_with_cache(torch.tensor([[1, 2, 3]]), names_filter="blocks.0.hook_resid_post")
94+
95+
assert target_hook.has_hooks(dir="bwd", including_permanent=False)
96+
97+
bridge_model.reset_hooks()

tests/unit/factored_matrix/test_get_item.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,25 @@ def test_index_dimension_get_element(sample_factored_matrix):
5151
assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 0, 1])
5252

5353

54+
def test_index_dimension_get_line_negative(sample_factored_matrix):
55+
# Negative index into the row (ldim) of the matrix. `idx == -1` previously
56+
# produced an empty slice(-1, 0) and returned a (0, ...) tensor.
57+
result = sample_factored_matrix[0, 0, 0, -1]
58+
assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, -1])
59+
60+
61+
def test_index_dimension_get_element_negative(sample_factored_matrix):
62+
# Negative index into the column (rdim) of the matrix.
63+
result = sample_factored_matrix[0, 0, 0, 0, -1]
64+
assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, 0, -1])
65+
66+
67+
def test_index_dimension_get_element_both_negative(sample_factored_matrix):
68+
# Negative index into both matrix dimensions at once.
69+
result = sample_factored_matrix[0, 0, 0, -1, -1]
70+
assert_close(result.AB.squeeze(), sample_factored_matrix.AB[0, 0, 0, -1, -1])
71+
72+
5473
def test_index_dimension_too_big(sample_factored_matrix):
5574
with pytest.raises(Exception):
5675
_ = sample_factored_matrix[1, 1, 1, 1, 1, 1]

transformer_lens/FactoredMatrix.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,12 @@ def _convert_to_slice(self, sequence: Union[Tuple, List], idx: int) -> Tuple:
288288
if isinstance(idx, int):
289289
sequence = list(sequence)
290290
if isinstance(sequence[idx], int):
291-
sequence[idx] = slice(sequence[idx], sequence[idx] + 1)
291+
value = sequence[idx]
292+
# `value + 1` selects the single requested element, except when
293+
# value == -1: there `value + 1 == 0` yields the empty slice(-1, 0).
294+
# Use `None` as the stop so the final element is kept.
295+
stop = value + 1 if value != -1 else None
296+
sequence[idx] = slice(value, stop)
292297
sequence = tuple(sequence)
293298

294299
return sequence

transformer_lens/model_bridge/bridge.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,7 +2114,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
21142114
raise e
21152115
finally:
21162116
for hp, _ in hooks:
2117-
hp.remove_hooks()
2117+
hp.remove_hooks(dir="fwd")
21182118
if self.compatibility_mode == True:
21192119
reverse_aliases = {}
21202120
for old_name, new_name in aliases.items():
@@ -2181,7 +2181,7 @@ def run_with_hooks(
21812181
Returns:
21822182
Model output
21832183
"""
2184-
added_hooks: List[Tuple[HookPoint, str]] = []
2184+
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []
21852185
effective_stop_layer = None
21862186
if stop_at_layer is not None and hasattr(self, "blocks"):
21872187
if stop_at_layer < 0:
@@ -2207,7 +2207,7 @@ def add_hook_to_point(
22072207
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
22082208
else:
22092209
hook_point.add_hook(hook_fn, dir=dir)
2210-
added_hooks.append((hook_point, name))
2210+
added_hooks.append((hook_point, dir))
22112211

22122212
if stop_at_layer is not None and hasattr(self, "blocks"):
22132213
if stop_at_layer < 0:
@@ -2276,8 +2276,8 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
22762276
return output
22772277
finally:
22782278
if reset_hooks_end:
2279-
for hook_point, name in added_hooks:
2280-
hook_point.remove_hooks()
2279+
for hook_point, direction in added_hooks:
2280+
hook_point.remove_hooks(dir=direction)
22812281

22822282
def _generate_tokens(
22832283
self,
@@ -3452,7 +3452,7 @@ def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts
34523452

34533453
@contextmanager
34543454
def _hooks_context():
3455-
added_hooks: List[Tuple[HookPoint, str]] = []
3455+
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []
34563456

34573457
def add_hook_to_point(
34583458
hook_point: HookPoint,
@@ -3468,7 +3468,7 @@ def add_hook_to_point(
34683468
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
34693469
else:
34703470
hook_point.add_hook(hook_fn, dir=dir)
3471-
added_hooks.append((hook_point, name))
3471+
added_hooks.append((hook_point, dir))
34723472

34733473
def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
34743474
direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
@@ -3501,8 +3501,8 @@ def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool
35013501
yield self
35023502
finally:
35033503
if reset_hooks_end:
3504-
for hook_point, name in added_hooks:
3505-
hook_point.remove_hooks()
3504+
for hook_point, direction in added_hooks:
3505+
hook_point.remove_hooks(dir=direction)
35063506

35073507
return _hooks_context()
35083508

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,12 @@ def remove_hooks(self, hook_name: str | None = None) -> None:
188188
hook_name: Name of the hook point to remove. If None, removes all hooks.
189189
"""
190190
if hook_name is None:
191-
self.hook_in.remove_hooks()
192-
self.hook_out.remove_hooks()
191+
self.hook_in.remove_hooks(dir="both")
192+
self.hook_out.remove_hooks(dir="both")
193193
elif hook_name == "output":
194-
self.hook_out.remove_hooks()
194+
self.hook_out.remove_hooks(dir="both")
195195
elif hook_name == "input":
196-
self.hook_in.remove_hooks()
196+
self.hook_in.remove_hooks(dir="both")
197197
else:
198198
raise ValueError(
199199
f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."

0 commit comments

Comments
 (0)