Skip to content

Commit fcccda3

Browse files
authored
Arm backend: Add optional ToDevicePass (#18230)
A pass to move a graph_module to a device correctly. This is needed on models containing ops with "device" kwargs, which are not moved when `model.to(device=...)` is called. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent c11ba1b commit fcccda3

File tree

6 files changed

+305
-1
lines changed

6 files changed

+305
-1
lines changed

backends/arm/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,17 @@ List of model specific and optional passes:
308308
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
309309
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
310310

311+
- ToDevicePass
312+
- This is a utility for moving an already-quantized or already-decomposed GraphModule to another device.
313+
- It is intended to be used immediately before rerunning / retracing / `torch.export.export(...)`.
314+
- Functionalities:
315+
- Calls `.to(device)` on the GraphModule and rewrites explicit `device=` kwargs on `call_function` nodes to a user-specified device.
316+
- Useful when manually moving an already-quantized or already-decomposed graph module to another device for validation, since some constant-producing nodes may still carry an export-time device kwarg.
317+
- Example usage:
318+
- `from executorch.exir.passes import ToDevicePass`
319+
- `graph_module = ToDevicePass("cpu")(graph_module).graph_module`
320+
- backends/arm/test/misc/test_post_quant_device_switch.py
321+
311322
## Help & Improvements
312323

313324
If you have problems or questions, or have suggestions for ways to improve the Arm backend, please reach out
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
from dataclasses import dataclass
8+
from typing import Callable
9+
10+
import pytest
11+
import torch
12+
import torch.nn.functional as F
13+
from executorch.backends.arm.quantizer import (
14+
get_symmetric_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.tosa import TosaSpecification
18+
from executorch.exir.passes import ToDevicePass
19+
from torch._subclasses.fake_tensor import FakeTensor
20+
from torchao.quantization.pt2e import move_exported_model_to_eval
21+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
22+
23+
24+
class AddAlpha(torch.nn.Module):
    """Module exercising `torch.add` with a non-default `alpha` kwarg."""

    def forward(self, x, y):
        return torch.add(x, y, alpha=2.0)
27+
28+
29+
class SubAlpha(torch.nn.Module):
    """Module exercising `torch.sub` with a non-default `alpha` kwarg."""

    def forward(self, x, y):
        return torch.sub(x, y, alpha=2.0)
32+
33+
34+
class SliceScatter(torch.nn.Module):
    """Module exercising `torch.slice_scatter` with explicit dim/start/end/step kwargs."""

    def forward(self, x, src):
        return torch.slice_scatter(x, src, dim=1, start=0, end=4, step=2)
37+
38+
39+
class MeanDim(torch.nn.Module):
    """Module exercising the dim-overload of `torch.mean` with `keepdim=True`."""

    def forward(self, x):
        return torch.mean(x, dim=(1,), keepdim=True)
42+
43+
44+
class MeanDefault(torch.nn.Module):
    """Module exercising the default (full-reduction) overload of `torch.mean`."""

    def forward(self, x):
        return torch.mean(x)
47+
48+
49+
class VarCorrection(torch.nn.Module):
    """Module exercising `torch.var` with explicit correction and keepdim kwargs."""

    def forward(self, x):
        return torch.var(x, dim=(2, 3), correction=1, keepdim=True)
52+
53+
54+
class VarDim(torch.nn.Module):
    """Module calling the `aten.var.dim` overload directly (dims, unbiased, keepdim)."""

    def forward(self, x):
        return torch.ops.aten.var.dim(x, [2, 3], 1, True)
57+
58+
59+
class DivTensorMode(torch.nn.Module):
    """Module exercising `torch.div` with `rounding_mode="trunc"`."""

    def forward(self, x, y):
        return torch.div(x, y, rounding_mode="trunc")
62+
63+
64+
class LeakyRelu(torch.nn.Module):
    """Module exercising `F.leaky_relu` with a non-default negative slope."""

    def forward(self, x):
        return F.leaky_relu(x, negative_slope=0.2)
67+
68+
69+
class AvgPool2d(torch.nn.Module):
    """Module exercising `F.avg_pool2d` with explicit kernel/stride/padding kwargs."""

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=2, stride=1, padding=1)
72+
73+
74+
class LayerNorm(torch.nn.Module):
    """Module wrapping `nn.LayerNorm(4)` without learnable affine parameters."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(4, elementwise_affine=False)

    def forward(self, x):
        return self.layer_norm(x)
81+
82+
83+
class GroupNorm(torch.nn.Module):
    """Module wrapping `nn.GroupNorm(2, 4)` without learnable affine parameters."""

    def __init__(self):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(2, 4, affine=False)

    def forward(self, x):
        return self.group_norm(x)
90+
91+
92+
@dataclass(frozen=True)
class MetaRetraceCase:
    """One parametrized test case: a module factory, its inputs, and an aten op label."""

    # Human-readable pytest id for the case.
    name: str
    # Zero-arg factory producing a fresh module instance for the case.
    module_factory: Callable[[], torch.nn.Module]
    # Zero-arg factory producing example inputs for export and calibration.
    inputs_factory: Callable[[], tuple[torch.Tensor, ...]]
    # Aten op name associated with the case (not asserted by the test in this view).
    aten_op: str
98+
99+
100+
# Parametrization table: each entry pairs a module with example inputs and the
# aten op it exercises, covering ops that carry kwargs through export.
_TEST_CASES = [
    MetaRetraceCase(
        "add_alpha",
        AddAlpha,
        lambda: (torch.randn(2, 3), torch.randn(2, 3)),
        "aten.add.Tensor",
    ),
    MetaRetraceCase(
        "sub_alpha",
        SubAlpha,
        lambda: (torch.randn(2, 3), torch.randn(2, 3)),
        "aten.sub.Tensor",
    ),
    MetaRetraceCase(
        "slice_scatter",
        SliceScatter,
        lambda: (torch.randn(2, 4), torch.randn(2, 2)),
        "aten.slice_scatter.default",
    ),
    MetaRetraceCase(
        "mean_dim",
        MeanDim,
        lambda: (torch.randn(2, 3, 4),),
        "aten.mean.dim",
    ),
    MetaRetraceCase(
        "mean_default",
        MeanDefault,
        lambda: (torch.randn(2, 3, 4),),
        "aten.mean.default",
    ),
    MetaRetraceCase(
        "var_correction",
        VarCorrection,
        lambda: (torch.randn(2, 3, 4, 4),),
        "aten.var.correction",
    ),
    MetaRetraceCase(
        "var_dim",
        VarDim,
        lambda: (torch.randn(2, 3, 4, 4),),
        "aten.var.dim",
    ),
    MetaRetraceCase(
        "div_tensor_mode",
        DivTensorMode,
        lambda: (torch.randn(2, 3), torch.randn(2, 3) + 1.0),
        "aten.div.Tensor_mode",
    ),
    MetaRetraceCase(
        "leaky_relu",
        LeakyRelu,
        lambda: (torch.randn(2, 3),),
        "aten.leaky_relu.default",
    ),
    MetaRetraceCase(
        "avg_pool2d",
        AvgPool2d,
        lambda: (torch.randn(1, 3, 4, 4),),
        "aten.avg_pool2d.default",
    ),
    MetaRetraceCase(
        "layer_norm",
        LayerNorm,
        lambda: (torch.randn(2, 3, 4),),
        "aten.layer_norm.default",
    ),
    MetaRetraceCase(
        "group_norm",
        GroupNorm,
        lambda: (torch.randn(2, 4, 3, 3),),
        "aten.group_norm.default",
    ),
]
174+
175+
176+
def _make_quantizer() -> TOSAQuantizer:
    """Build a TOSA INT quantizer with a global symmetric per-tensor config."""
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(tosa_spec)
    symmetric_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(symmetric_config)
    return quantizer
180+
181+
182+
def _iter_fake_tensors(meta_val):
183+
if isinstance(meta_val, FakeTensor):
184+
yield meta_val
185+
return
186+
187+
if isinstance(meta_val, (list, tuple)):
188+
for item in meta_val:
189+
yield from _iter_fake_tensors(item)
190+
191+
192+
def _to_meta_inputs(
193+
example_inputs: tuple[torch.Tensor, ...],
194+
) -> tuple[torch.Tensor, ...]:
195+
return tuple(inp.to(device="meta") for inp in example_inputs)
196+
197+
198+
@pytest.mark.parametrize("case", _TEST_CASES, ids=[case.name for case in _TEST_CASES])
def test_post_quant_device_switch_no_target(case: MetaRetraceCase) -> None:
    """Test that moving a model to another device after quantization works.

    Quantizes the module via PT2E QAT flow, moves it to the meta device with
    ToDevicePass, runs it, retraces it, and checks that every FakeTensor in
    the retraced graph's node metadata lives on the meta device.
    """
    module = case.module_factory().train()
    example_inputs = case.inputs_factory()

    # Quantize module
    exported = torch.export.export(module, example_inputs, strict=True)
    prepared = prepare_qat_pt2e(copy.deepcopy(exported.graph_module), _make_quantizer())
    # One calibration forward pass before converting.
    prepared(*example_inputs)
    prepared = move_exported_model_to_eval(prepared)
    quantized_module = convert_pt2e(prepared)

    # Move and test running the model with other device.
    meta_inputs = _to_meta_inputs(example_inputs)
    meta_module = ToDevicePass("meta")(quantized_module).graph_module
    meta_module(*meta_inputs)

    # Retrace module using meta device to check all fake tensors are moved.
    meta_module = torch.export.export(meta_module, meta_inputs, strict=True)

    # Validate transformation: gather (device, node) for every FakeTensor
    # recorded in the traced nodes' "val" metadata.
    fake_tensor_devices = [
        (str(fake_tensor.device), str(node))
        for node in meta_module.graph.nodes
        for fake_tensor in _iter_fake_tensors(node.meta.get("val"))
    ]

    assert fake_tensor_devices, "Expected traced graph to contain FakeTensor metadata"
    assert all(device == "meta" for device, _ in fake_tensor_devices), (
        "Expected all traced FakeTensors to use the meta device, got "
        f"{fake_tensor_devices}"
    )

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def define_arm_tests():
4848
"misc/test_bn_relu_folding_qat.py",
4949
"misc/test_custom_partition.py",
5050
"misc/test_debug_hook.py",
51+
"misc/test_post_quant_device_switch.py",
5152
# "misc/test_dim_order.py", (TODO - T238390249)
5253
]
5354

exir/passes/BUCK

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ fbcode_target(_kind = runtime.python_library,
3232
":spec_prop_pass",
3333
":sym_shape_eval_pass",
3434
":sym_to_tensor_pass",
35+
":to_device_pass",
3536
":weights_to_outputs_pass",
3637
":reinplace_pass",
3738
"//caffe2:torch",
@@ -92,6 +93,17 @@ fbcode_target(_kind = runtime.python_library,
9293
],
9394
)
9495

96+
# Library for the standalone ToDevicePass (moves a GraphModule and its
# explicit device kwargs to a target device).
fbcode_target(_kind = runtime.python_library,
    name = "to_device_pass",
    srcs = [
        "to_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:pass_base",
    ],
)
106+
95107
fbcode_target(_kind = runtime.python_library,
96108
name = "weights_to_outputs_pass",
97109
srcs = [

exir/passes/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -30,7 +31,6 @@
3031
to_out_variant,
3132
to_scratch_op,
3233
)
33-
3434
from executorch.exir.pass_base import ExportPass
3535
from executorch.exir.pass_manager import PassManager, PassType
3636
from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -59,6 +59,8 @@
5959
from executorch.exir.passes.spec_prop_pass import SpecPropPass
6060
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
6161
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
62+
63+
from executorch.exir.passes.to_device_pass import ToDevicePass
6264
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
6365
from torch import fx
6466
from torch._subclasses import FakeTensor
@@ -71,6 +73,7 @@
7173
"ConstPropPass",
7274
"QuantFusionPass",
7375
"OpReplacePass",
76+
"ToDevicePass",
7477
"EdgeToBackendOpsPass",
7578
"MemoryFormatOpsPass",
7679
"MemoryPlanningPass",

exir/passes/to_device_pass.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
11+
12+
class ToDevicePass(ExportPass):
    """Move a GraphModule to a given device.

    Calls `.to(device)` on the module and, additionally, rewrites any
    explicit `device=` kwarg on `call_function` nodes so it points at the
    requested device.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def __init__(self, device: str | torch.device, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Normalize to torch.device so string inputs compare consistently.
        self.device = torch.device(device)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Move parameters/buffers first; explicit node kwargs are handled below.
        graph_module = graph_module.to(self.device)

        rewrote_any_kwarg = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if "device" not in node.kwargs:
                continue
            if node.kwargs["device"] == self.device:
                continue
            node.update_kwarg("device", self.device)
            rewrote_any_kwarg = True

        # Recompile only when the graph itself changed; `.to(...)` alone does
        # not require regenerating the forward code.
        if rewrote_any_kwarg:
            graph_module.recompile()

        # The module was moved regardless of kwarg rewrites, so report modified.
        return PassResult(graph_module, True)

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Reimplement __call__ to avoid the Optional[PassResult] type hint."""
        return self.call(graph_module)

0 commit comments

Comments
 (0)