Skip to content

Commit 2137894

Browse files
authored
Arm backend: Fix sum call signature in DecomposeSumPass (#19546)
exir_ops.edge.aten.sum.dim_IntList requires the second arg to be a list[int]. The DecomposeSumPass and op_sum, however, used int. Seems like the operator itself could handle it, but later passes tripped on it in some cases. Since the documented signature is list[int], modify DecomposeSumPass and op_sum to use that.
1 parent 9ccbc4a commit 2137894

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def call_operator(self, op, args, kwargs, meta):
7878
for dim in dims:
7979
input_node = super().call_operator(
8080
sum_op,
81-
(input_node, dim, True),
81+
(input_node, [dim], True),
8282
kwargs,
8383
meta,
8484
updated=True,

backends/arm/operators/op_sum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343

4444
tensor = inputs[0]
4545
input_shape = list(tensor.shape)
46-
dim = int(inputs[1].number % len(input_shape))
46+
dim = int(inputs[1].special[0] % len(input_shape))
4747

4848
attr = ts.TosaSerializerAttribute()
4949
attr.ReduceSumAttribute(dim)

0 commit comments

Comments
 (0)