Skip to content

Commit 34e6dc4

Browse files
Fix run_with_cache(device=...) permanently moving the model (#1345)
* Fix run_with_cache(device=...) permanently moving the model

  The single-device branch moved the model and inputs to cache_device with no restore, leaving non-CPU models silently migrated and cfg.device stale. The move was redundant since make_cache_hook already offloads each captured activation, matching ActivationCache.to and the legacy get_caching_hooks contract. Flatten the conditional, add a regression test asserting original_model.to is not invoked, and document the device kwarg.

* Retire cache_dict workaround in return_cache device offload

  With the run_with_cache model-move fixed, the TransformerBridge.generate return_cache device offload can use a run_with_cache(device=device) passthrough. The offload now happens at capture time, reducing peak memory. Drop the cache_dict direct-write and its justifying comment, and simplify the offload test to a device-landing check.
1 parent 84e90e1 commit 34e6dc4

3 files changed

Lines changed: 64 additions & 44 deletions

File tree

tests/integration/model_bridge/test_bridge_generate_return_cache.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
this does not change what is being tested.
1111
"""
1212

13-
import warnings
14-
1513
import pytest
1614
import torch
1715

@@ -149,25 +147,20 @@ def test_names_filter_scopes_cache(self, bridge):
149147
assert set(cache.cache_dict) == set(ref.cache_dict)
150148
assert len(cache.cache_dict) < 20
151149

152-
def test_device_offload_no_spurious_warning(self, bridge):
153-
"""device= offloads cache tensors (cpu here) without ActivationCache.to's move_model warning."""
150+
def test_device_offload_lands_on_requested_device(self, bridge):
151+
"""device= offloads cache tensors to the requested device."""
154152
tokens = bridge.to_tokens("The quick brown")
155-
with warnings.catch_warnings(record=True) as caught:
156-
warnings.simplefilter("always")
157-
with torch.no_grad():
158-
_, cache = bridge.generate(
159-
tokens,
160-
max_new_tokens=4,
161-
do_sample=False,
162-
return_type="tokens",
163-
return_cache=True,
164-
device="cpu",
165-
use_past_kv_cache=False,
166-
)
153+
with torch.no_grad():
154+
_, cache = bridge.generate(
155+
tokens,
156+
max_new_tokens=4,
157+
do_sample=False,
158+
return_type="tokens",
159+
return_cache=True,
160+
device="cpu",
161+
use_past_kv_cache=False,
162+
)
167163
assert str(cache["blocks.0.hook_resid_post"].device) == "cpu"
168-
assert not any("move_model" in str(w.message) for w in caught), [
169-
str(w.message) for w in caught
170-
]
171164

172165

173166
class TestGenerateReturnCacheGuards:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Regression coverage for `TransformerBridge.run_with_cache(device=...)`.
2+
3+
The `device=` kwarg is a cache-offload knob: cached activations are stored on
4+
that device, but the underlying model and inputs must stay where the caller
5+
put them, matching `ActivationCache.to` and the legacy `get_caching_hooks`
6+
("device to store on") contract.
7+
"""
8+
9+
from unittest.mock import patch
10+
11+
import pytest
12+
13+
14+
@pytest.fixture()
15+
def bridge(distilgpt2_bridge):
16+
"""Alias the session fixture for concise test signatures."""
17+
return distilgpt2_bridge
18+
19+
20+
def test_run_with_cache_device_does_not_move_model(bridge):
21+
"""`run_with_cache(device=...)` must not relocate the underlying model.
22+
23+
CPU runners cannot reproduce the original cross-device crash directly
24+
(`to('cpu')` is a no-op there), so we spy on `original_model.to` with
25+
`Mock(wraps=...)` and assert it isn't invoked during the call. That
26+
catches the regression on any platform.
27+
"""
28+
with patch.object(bridge.original_model, "to", wraps=bridge.original_model.to) as to_spy:
29+
_, cache = bridge.run_with_cache(bridge.to_tokens("hello"), device="cpu")
30+
31+
assert to_spy.call_count == 0, (
32+
f"run_with_cache(device=...) moved the underlying model "
33+
f"({to_spy.call_count} call(s): {to_spy.call_args_list})."
34+
)
35+
# And the cache itself still lands on the requested device.
36+
assert next(iter(cache.values())).device.type == "cpu"

transformer_lens/model_bridge/bridge.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,6 +1968,8 @@ def run_with_cache(
19681968
remove_batch_dim: Whether to remove batch dimension
19691969
names_filter: Filter for which activations to cache (str, list of str, or callable)
19701970
stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
1971+
device: Where to store cached activations (matches ActivationCache.to;
1972+
does not move the model). When omitted, each cached activation stays on the device of the layer that produced it.
19711973
**kwargs: Additional arguments
19721974
# type: ignore[name-defined]
19731975
Returns:
@@ -2075,25 +2077,19 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
20752077
hook_dict[block_hook_name].add_hook(stop_hook)
20762078
hooks.append((hook_dict[block_hook_name], block_hook_name))
20772079
filtered_kwargs = kwargs.copy()
2078-
if cache_device is not None:
2079-
if getattr(self.cfg, "n_devices", 1) > 1:
2080-
# Moving a dispatched model to a single device collapses accelerate's
2081-
# split and breaks its routing hooks. The cache will stay spread across
2082-
# the per-layer devices; callers can .to(cache_device) on cache entries
2083-
# after the fact if they need a single-device cache.
2084-
warnings.warn(
2085-
f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
2086-
f"across {self.cfg.n_devices} devices via device_map. Cached activations "
2087-
"will remain on their per-layer devices.",
2088-
stacklevel=2,
2089-
)
2090-
else:
2091-
self.original_model = self.original_model.to(cache_device)
2092-
if processed_args and isinstance(processed_args[0], torch.Tensor):
2093-
processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:])
2094-
for key, value in filtered_kwargs.items():
2095-
if isinstance(value, torch.Tensor):
2096-
filtered_kwargs[key] = value.to(cache_device)
2080+
# `cache_device` is honored by `make_cache_hook` above (`tensor.detach().to(cache_device)`);
2081+
# the model and inputs stay where the caller put them, matching `ActivationCache.to`.
2082+
if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
2083+
# Moving a dispatched model to a single device collapses accelerate's
2084+
# split and breaks its routing hooks. The cache will stay spread across
2085+
# the per-layer devices; callers can .to(cache_device) on cache entries
2086+
# after the fact if they need a single-device cache.
2087+
warnings.warn(
2088+
f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
2089+
f"across {self.cfg.n_devices} devices via device_map. Cached activations "
2090+
"will remain on their per-layer devices.",
2091+
stacklevel=2,
2092+
)
20972093
try:
20982094
if "output_attentions" not in filtered_kwargs:
20992095
filtered_kwargs["output_attentions"] = True
@@ -2858,12 +2854,7 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...
28582854
# cache is identical to run_with_cache(output_tokens) - all hook points, including
28592855
# attention patterns. The guards above restrict this to single-sequence, decoder-only
28602856
# text generation (see issue #697).
2861-
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter)
2862-
if device is not None:
2863-
# Offload the cached activations to `device`. We move cache_dict directly rather
2864-
# than calling ActivationCache.to(device), which currently emits a spurious
2865-
# move_model DeprecationWarning.
2866-
cache.cache_dict = {key: value.to(device) for key, value in cache.cache_dict.items()}
2857+
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
28672858
return result, cache
28682859

28692860
@torch.no_grad()

0 commit comments

Comments
 (0)