Skip to content

Commit 466e5e3

Browse files
Use correct full-op in replace_scalar_with_tensor pass (pytorch#16637)
Use torch.ops.aten.full.default for torch ops and exir_ops.edge.aten.full.default for exir ops. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent f5afa0b commit 466e5e3

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

backends/transforms/replace_scalar_with_tensor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1313
from executorch.exir.pass_base import ExportPass
14+
from torch._ops import OpOverload
1415

1516

1617
class ReplaceScalarWithTensorArgPass(ExportPass):
@@ -46,6 +47,11 @@ def __init__(
4647
super().__init__()
4748

4849
def get_replacement(self, op, args, kwargs, meta):
50+
if isinstance(op, OpOverload):
51+
full_op = torch.ops.aten.full.default
52+
else:
53+
full_op = exir_ops.edge.aten.full.default
54+
4955
return super().call_operator(
5056
# Replace with .Tensor variant.
5157
op=self.scalar_to_tensor_ops[op],
@@ -54,10 +60,10 @@ def get_replacement(self, op, args, kwargs, meta):
5460
args[0],
5561
# Scalar arg - replace with aten.full tensor.
5662
super().call_operator(
57-
exir_ops.edge.aten.full.default,
63+
full_op,
5864
args=(
5965
(1,),
60-
args[1],
66+
float(args[1]),
6167
),
6268
kwargs={
6369
"dtype": args[0].to_tensor().dtype,

0 commit comments

Comments
 (0)