Skip to content

Commit 3ee411b

Browse files
authored
Fix TransformerBridge backward hook cleanup (#1324)
* Fix TransformerBridge backward hook cleanup
* Preserve backward hooks in run_with_cache
1 parent 5f7b02e commit 3ee411b

4 files changed

Lines changed: 70 additions & 13 deletions

File tree

tests/acceptance/model_bridge/compatibility/test_backward_hooks.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,42 @@ def sum_bridge_grads(grad, hook=None):
5151
f"Gradient sums should be identical but differ by "
5252
f"{abs(hooked_grad_sum - bridge_grad_sum).item():.6f}"
5353
)
54+
55+
56+
def test_transformer_bridge_hooks_context_cleans_up_backward_hooks(
    gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing
):
    """Check that ``hooks(bwd_hooks=...)`` leaves no backward hooks behind on exit.

    Both the hooked model and the bridge model run a forward/backward pass inside
    a ``hooks`` context that registers a do-nothing backward hook on
    ``blocks.0.hook_resid_post``. After both contexts have exited, neither hook
    point may report a non-permanent backward hook.
    """
    hook_name = "blocks.0.hook_resid_post"
    models = (gpt2_hooked_unprocessed, gpt2_bridge_compat_no_processing)
    # Resolve the hook points before any pass runs, in the same order the models are exercised.
    hook_points = [model.blocks[0].hook_resid_post for model in models]
    tokens = torch.tensor([[1, 2, 3]])

    def discard_grad(grad, hook=None):
        # Does nothing with the gradient and returns None.
        return None

    # Run every model first; assertions only start once both contexts have exited.
    for model in models:
        model.zero_grad()
        with model.hooks(bwd_hooks=[(hook_name, discard_grad)]):
            model(tokens).sum().backward()

    for hook_point in hook_points:
        assert not hook_point.has_hooks(dir="bwd", including_permanent=False)
79+
80+
81+
def test_transformer_bridge_reset_hooks_removes_backward_hooks(gpt2_bridge_compat_no_processing):
    """Check that ``reset_hooks`` on the bridge also clears backward hooks.

    A do-nothing backward hook is attached directly to
    ``blocks[0].hook_resid_post``; it must be reported before ``reset_hooks()``
    is called and must no longer be reported afterwards.
    """
    model = gpt2_bridge_compat_no_processing
    hook_point = model.blocks[0].hook_resid_post

    def discard_grad(grad, hook=None):
        # Does nothing with the gradient and returns None.
        return None

    hook_point.add_hook(discard_grad, dir="bwd")
    # Precondition: the hook really is registered, so the final check is meaningful.
    registered_before = hook_point.has_hooks(dir="bwd", including_permanent=False)
    assert registered_before

    model.reset_hooks()

    registered_after = hook_point.has_hooks(dir="bwd", including_permanent=False)
    assert not registered_after

tests/acceptance/model_bridge/compatibility/test_run_with_cache.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,21 @@ def test_run_with_cache_accepts_1d_tensor(self, gpt2_bridge_compat_no_processing
7777
assert torch.allclose(
7878
cache_1d["blocks.0.hook_mlp_out"], cache_2d["blocks.0.hook_mlp_out"], atol=1e-5
7979
)
80+
81+
82+
def test_transformer_bridge_run_with_cache_preserves_existing_backward_hooks(
    gpt2_bridge_compat_no_processing,
):
    """run_with_cache should not remove unrelated backward hooks on the same HookPoint.

    A backward hook is attached to ``blocks[0].hook_resid_post`` before calling
    ``run_with_cache`` with a ``names_filter`` naming that same hook point. The
    backward hook must still be reported once ``run_with_cache`` returns.
    """
    bridge_model = gpt2_bridge_compat_no_processing
    target_hook = bridge_model.blocks[0].hook_resid_post

    target_hook.add_hook(lambda grad, hook=None: None, dir="bwd")
    try:
        # Precondition: the hook is registered, so the post-run check is meaningful.
        assert target_hook.has_hooks(dir="bwd", including_permanent=False)

        bridge_model.run_with_cache(
            torch.tensor([[1, 2, 3]]), names_filter="blocks.0.hook_resid_post"
        )

        assert target_hook.has_hooks(dir="bwd", including_permanent=False)
    finally:
        # Clean up even when run_with_cache raises or an assertion fails, so the
        # hook added above is not left on the fixture model.
        # NOTE(review): fixture scope is not visible here; if it is shared across
        # tests, a leaked hook would affect later tests - confirm.
        bridge_model.reset_hooks()

transformer_lens/model_bridge/bridge.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,7 +2081,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
20812081
raise e
20822082
finally:
20832083
for hp, _ in hooks:
2084-
hp.remove_hooks()
2084+
hp.remove_hooks(dir="fwd")
20852085
if self.compatibility_mode == True:
20862086
reverse_aliases = {}
20872087
for old_name, new_name in aliases.items():
@@ -2148,7 +2148,7 @@ def run_with_hooks(
21482148
Returns:
21492149
Model output
21502150
"""
2151-
added_hooks: List[Tuple[HookPoint, str]] = []
2151+
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []
21522152
effective_stop_layer = None
21532153
if stop_at_layer is not None and hasattr(self, "blocks"):
21542154
if stop_at_layer < 0:
@@ -2174,7 +2174,7 @@ def add_hook_to_point(
21742174
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
21752175
else:
21762176
hook_point.add_hook(hook_fn, dir=dir)
2177-
added_hooks.append((hook_point, name))
2177+
added_hooks.append((hook_point, dir))
21782178

21792179
if stop_at_layer is not None and hasattr(self, "blocks"):
21802180
if stop_at_layer < 0:
@@ -2243,8 +2243,8 @@ def wrapped_hook_fn(tensor, hook, _orig_fn=original_hook_fn):
22432243
return output
22442244
finally:
22452245
if reset_hooks_end:
2246-
for hook_point, name in added_hooks:
2247-
hook_point.remove_hooks()
2246+
for hook_point, direction in added_hooks:
2247+
hook_point.remove_hooks(dir=direction)
22482248

22492249
def _generate_tokens(
22502250
self,
@@ -3306,7 +3306,7 @@ def hooks(self, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts
33063306

33073307
@contextmanager
33083308
def _hooks_context():
3309-
added_hooks: List[Tuple[HookPoint, str]] = []
3309+
added_hooks: List[Tuple[HookPoint, Literal["fwd", "bwd"]]] = []
33103310

33113311
def add_hook_to_point(
33123312
hook_point: HookPoint,
@@ -3322,7 +3322,7 @@ def add_hook_to_point(
33223322
hook_point.add_hook(hook_fn, dir=dir, alias_names=alias_names_list)
33233323
else:
33243324
hook_point.add_hook(hook_fn, dir=dir)
3325-
added_hooks.append((hook_point, name))
3325+
added_hooks.append((hook_point, dir))
33263326

33273327
def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
33283328
direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
@@ -3355,8 +3355,8 @@ def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool
33553355
yield self
33563356
finally:
33573357
if reset_hooks_end:
3358-
for hook_point, name in added_hooks:
3359-
hook_point.remove_hooks()
3358+
for hook_point, direction in added_hooks:
3359+
hook_point.remove_hooks(dir=direction)
33603360

33613361
return _hooks_context()
33623362

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,12 @@ def remove_hooks(self, hook_name: str | None = None) -> None:
188188
hook_name: Name of the hook point to remove. If None, removes all hooks.
189189
"""
190190
if hook_name is None:
191-
self.hook_in.remove_hooks()
192-
self.hook_out.remove_hooks()
191+
self.hook_in.remove_hooks(dir="both")
192+
self.hook_out.remove_hooks(dir="both")
193193
elif hook_name == "output":
194-
self.hook_out.remove_hooks()
194+
self.hook_out.remove_hooks(dir="both")
195195
elif hook_name == "input":
196-
self.hook_in.remove_hooks()
196+
self.hook_in.remove_hooks(dir="both")
197197
else:
198198
raise ValueError(
199199
f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."

0 commit comments

Comments
 (0)