Skip to content

Commit 7324ed4

Browse files
Arm backend: Preserve inputs for pow zero decomposition (pytorch#19637)
Keep pow(x, 0) dependent on the original input when decomposing it to ones, so the TOSA graph still expects the user input. Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent acf1ad9 commit 7324ed4

3 files changed

Lines changed: 9 additions & 6 deletions

File tree

backends/arm/_passes/decompose_int_pow_pass.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@ def call_operator(self, op, args, kwargs, meta):
3232
x = args[0]
3333
exp = args[1]
3434

35-
# Handle zero first and return early
3635
if exp == 0:
37-
# return a tensor of ones with the same shape as x
38-
return super().call_operator(
36+
zeros = super().call_operator(
37+
exir_ops.edge.aten.sub.Tensor, (x, x), {}, meta, True
38+
)
39+
ones = super().call_operator(
3940
exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True
4041
)
42+
return super().call_operator(
43+
exir_ops.edge.aten.add.Tensor, (zeros, ones), {}, meta, True
44+
)
4145

4246
if not isinstance(exp, int):
4347
return super().call_operator(op, args, kwargs, meta)

backends/arm/test/ops/test_pow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t):
147147

148148
x_fail_FP = {
149149
"exp_two": "TOSA constraints: If x <0 .",
150-
"exp_zero": "MLETORCH-2041 : Invalid inputs.",
151150
}
152151

153152

backends/arm/test/passes/test_decompose_int_pow_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,18 @@ def get_inputs(self) -> input_t:
5959
def test_decompose_int_pow_tosa_FP(data: TestParam) -> None:
6060
module_with_inputs, nbr_muls = data
6161
module = cast(torch.nn.Module, module_with_inputs)
62+
pow_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"
6263
pipeline = PassPipeline[input_t](
6364
module,
6465
module_with_inputs.get_inputs(),
6566
quantize=False,
6667
ops_before_pass={
67-
"executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1,
68+
pow_op: 1,
6869
},
6970
ops_not_before_pass=[],
7071
ops_after_pass={
7172
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls,
7273
},
73-
ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"],
7474
pass_list=[DecomposeIntPowPass],
7575
)
7676
pipeline.run()

0 commit comments

Comments
 (0)