
Commit 82af6a5

Arm backend: Data layout cast bugfixes

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Change-Id: Iff630a5bd12d05119f14be37f819bdb28586a6d2

1 parent b48d5d2

2 files changed: 62 additions & 10 deletions

backends/arm/_passes/insert_data_layout_casts_pass.py

Lines changed: 29 additions & 8 deletions
@@ -9,7 +9,7 @@
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm.tosa.specification import get_context_spec
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, NodeMetadata


 class InsertDataLayoutCastsPass(ArmPass):
@@ -55,6 +55,7 @@ class InsertDataLayoutCastsPass(ArmPass):
     }

     _int_to_fp_map = {
+        torch.int8: torch.float16,  # This doubles the size after casting, but is very unlikely to occur in practice since int8 is only ever used by LOGICAL_SHIFT and CAST/RESCALE ops in PRO-FP.
         torch.int16: torch.float16,
         torch.int32: torch.float32,
     }
@@ -63,9 +64,15 @@ def call_operator(self, op, args, kwargs, meta):
         if op not in self.targeted_ops:
             return super().call_operator(op, args, kwargs, meta)

-        dtype = args[0].data.dtype
-        spec = get_context_spec()
+        if op in self._concat_ops:
+            # Cast to largest dtype
+            dtypes = [arg.data.dtype for arg in args[0]]
+            dtype_sizes = [dtype.itemsize for dtype in dtypes]
+            dtype = dtypes[dtype_sizes.index(max(dtype_sizes))]
+        else:
+            dtype = args[0].data.dtype

+        spec = get_context_spec()
         dtype_is_integer = not dtype.is_floating_point and dtype != torch.bool
         if dtype_is_integer and not spec.support_integer():
             supported_dtype = self._int_to_fp_map.get(dtype, None)
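
The dtype-selection change above is the core of the bugfix: for concat-style ops the pass previously read only the first input's dtype, whereas it now picks the widest dtype across all inputs. A minimal standalone sketch of that selection logic (the helper name is hypothetical, not part of the patch):

```python
import torch

def widest_dtype(dtypes: list[torch.dtype]) -> torch.dtype:
    # Mirrors the selection in the pass: pick the dtype with the largest
    # element size; ties resolve to the first occurrence, as index() does.
    sizes = [d.itemsize for d in dtypes]
    return dtypes[sizes.index(max(sizes))]

assert widest_dtype([torch.int16, torch.int32]) == torch.int32
assert widest_dtype([torch.int32, torch.int16]) == torch.int32
```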
@@ -93,16 +100,30 @@ def call_operator(self, op, args, kwargs, meta):
             for arg in args[0]:
                 x_casted.append(
                     super().call_operator(
-                        self._cast_op, (arg,), {"dtype": supported_dtype}, meta
+                        self._cast_op,
+                        (arg,),
+                        {"dtype": supported_dtype},
+                        NodeMetadata(arg.node.meta),
+                        updated=True,
                     )
                 )
-            y_casted = super().call_operator(op, (x_casted,), kwargs, meta)
+            y_casted = super().call_operator(
+                op, (x_casted, *args[1:]), kwargs, meta, updated=True
+            )

         else:
             x_casted = super().call_operator(
-                self._cast_op, (args[0],), {"dtype": supported_dtype}, meta
+                self._cast_op,
+                (args[0],),
+                {"dtype": supported_dtype},
+                NodeMetadata(args[0].node.meta),
+                updated=True,
+            )
+            y_casted = super().call_operator(
+                op, (x_casted, *args[1:]), kwargs, meta, updated=True
             )
-            y_casted = super().call_operator(op, (x_casted, *args[1:]), kwargs, meta)

-        y = super().call_operator(self._cast_op, (y_casted,), {"dtype": dtype}, meta)
+        y = super().call_operator(
+            self._cast_op, (y_casted,), {"dtype": dtype}, meta, updated=True
+        )
         return y
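
Besides threading NodeMetadata and updated=True through each newly created node, the hunk above fixes the recreated concat call: the old code rebuilt the op as op(x_casted), dropping trailing positional arguments such as cat's dim, while the new code forwards *args[1:]. A quick eager-mode illustration (plain torch, independent of the pass) of why dropping dim changes the result:

```python
import torch

x = torch.zeros(1, 4)
y = torch.zeros(1, 4)

# With dim=1 the tensors are joined along columns; if dim is dropped,
# cat defaults to dim=0 and the output shape changes.
assert torch.cat([x, y], dim=1).shape == (1, 8)
assert torch.cat([x, y]).shape == (2, 4)
```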

backends/arm/test/passes/test_insert_data_layout_casts_pass.py

Lines changed: 33 additions & 2 deletions
@@ -11,7 +11,7 @@


 def _collect_cast_dtypes(
-    pipeline: PassPipeline[tuple[torch.Tensor]],
+    pipeline: PassPipeline[tuple[torch.Tensor, ...]],
 ) -> list[torch.dtype]:
     exported_program = pipeline.tester.get_artifact(
         StageType.RUN_PASSES
@@ -34,10 +34,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(2, 2)


+class CatModule(torch.nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return torch.cat([x, y], dim=1)
+
+
 def test_insert_data_layout_casts_no_target_view_fp_profile_inserts_casts() -> None:
     test_data = (torch.arange(4, dtype=torch.int32).reshape(1, 4),)

-    pipeline = PassPipeline[tuple[torch.Tensor]](
+    pipeline = PassPipeline[tuple[torch.Tensor, ...]](
         ViewModule(),
         test_data,
         quantize=False,
@@ -78,3 +83,29 @@ def test_insert_data_layout_casts_no_target_view_fp_profile_skips_supported_dtyp
         pass_list=[InsertDataLayoutCastsPass],
     )
     pipeline.run()
+
+
+def test_insert_data_layout_casts_no_target_cat_fp_profile_inserts_casts() -> None:
+    test_data = (
+        torch.arange(4, dtype=torch.int32).reshape(1, 4),
+        torch.arange(4, dtype=torch.int32).reshape(1, 4),
+    )
+
+    pipeline = PassPipeline[tuple[torch.Tensor, ...]](
+        CatModule(),
+        test_data,
+        quantize=False,
+        ops_before_pass={
+            "executorch_exir_dialects_edge__ops_aten_cat_default": 1,
+        },
+        ops_after_pass={
+            "executorch_exir_dialects_edge__ops_aten_cat_default": 1,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
+        },
+        pass_list=[InsertDataLayoutCastsPass],
+    )
+    pipeline.run()
+
+    cast_dtypes = _collect_cast_dtypes(pipeline)
+    assert cast_dtypes.count(torch.float32) == 2
+    assert cast_dtypes.count(torch.int32) == 1
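
The asserted counts follow from the cast sandwich the pass builds: each int32 input of cat is cast to float32 (two float32 casts), cat runs in float32, and one cast restores the original int32 output dtype, giving the three to_dim_order_copy ops expected above. A hedged eager-mode sketch of the resulting semantics (plain torch, not the pass itself):

```python
import torch

x = torch.arange(4, dtype=torch.int32).reshape(1, 4)
y = torch.arange(4, dtype=torch.int32).reshape(1, 4)

# Two input casts up to float32, cat in float32, one cast back to int32.
out = torch.cat([x.to(torch.float32), y.to(torch.float32)], dim=1).to(torch.int32)
assert torch.equal(out, torch.cat([x, y], dim=1))
```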
