Skip to content

Commit e983693

Browse files
authored
Cortex-M backend: Verify output shape before rewriting AdaptiveAvgPool (pytorch#19935)
The pass only does a naive rewrite, so check that the output shape actually matches after the rewrite before doing it. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent aa8a182 commit e983693

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

backends/cortex_m/passes/decompose_mean_pass.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,21 @@ def call_operator(
2525
meta: NodeMetadata,
2626
) -> ProxyValue:
2727
if op == torch.ops.aten.adaptive_avg_pool2d.default:
28-
op = torch.ops.aten.avg_pool2d.default
29-
input_tensor = cast(torch.Tensor, args[0])
30-
shape = input_tensor.data.shape
28+
input_tensor = cast(ProxyValue, args[0]).to_tensor()
29+
shape = input_tensor.shape
3130
stride = [1, 1]
3231
kernel_size = [shape[-2], shape[-1]]
33-
args = (args[0], kernel_size, stride, [0, 0], 0, 0)
3432

33+
new_args = (args[0], kernel_size, stride, [0, 0], 0, 0)
34+
35+
adaptive_output = torch.ops.aten.adaptive_avg_pool2d.default(
36+
input_tensor, *args[1:]
37+
)
38+
avg_pool_output = torch.ops.aten.avg_pool2d.default(
39+
input_tensor, *new_args[1:]
40+
)
41+
42+
if adaptive_output.shape == avg_pool_output.shape:
43+
new_op = torch.ops.aten.avg_pool2d.default
44+
return super().call_operator(new_op, new_args, kwargs, meta)
3545
return super().call_operator(op, args, kwargs, meta)

0 commit comments

Comments
 (0)