Skip to content

Commit a14c946

Browse files
authored
Support whole-number float exponents in ReplacePowWithMulPass (#18851)
Differential Revision: D100695654 Pull Request resolved: #18851
1 parent 2f339f0 commit a14c946

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

backends/cadence/aot/replace_ops.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,19 +2266,25 @@ def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...]
22662266
class ReplacePowWithMulPass(RemoveOrReplacePassInterface):
22672267
"""
22682268
Replace the pow op with successive mul ops when the exponent is an
2269-
integer between 2 and 4 (inclusive).
2269+
integer between 2 and 4 (inclusive). Float exponents that are whole
2270+
numbers (e.g., 2.0, 3.0, 4.0) are also accepted.
22702271
"""
22712272

22722273
@property
22732274
def targets(self) -> list[EdgeOpOverload]:
22742275
return [exir_ops.edge.aten.pow.Tensor_Scalar]
22752276

22762277
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
2277-
# Check if we have at least 2 args and the exponent is an int
2278-
if len(node.args) < 2 or not isinstance(node.args[1], int):
2278+
# Check if we have at least 2 args and the exponent is an int or float
2279+
if len(node.args) < 2 or not isinstance(node.args[1], (int, float)):
22792280
return False
22802281

2281-
exponent = cast(int, node.args[1])
2282+
exponent_val = node.args[1]
2283+
if isinstance(exponent_val, float):
2284+
if not exponent_val.is_integer():
2285+
return False
2286+
exponent_val = int(exponent_val)
2287+
exponent = cast(int, exponent_val)
22822288

22832289
# Only replace if exponent is between 2 and 4 (inclusive)
22842290
if exponent < 2 or exponent > 4:

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,8 +1926,8 @@ def test_replace_split_with_sizes_with_slice(self) -> None:
19261926
2,
19271927
)
19281928

1929-
@expand([[2], [3], [4]])
1930-
def test_replace_pow_with_mul(self, exponent: int) -> None:
1929+
@expand([[2], [3], [4], [2.0], [3.0], [4.0]])
1930+
def test_replace_pow_with_mul(self, exponent: int | float) -> None:
19311931
x_input = torch.randn(2, 1, 64)
19321932
x = x_input
19331933
original_gm = single_op_builder(
@@ -1956,13 +1956,15 @@ def test_replace_pow_with_mul(self, exponent: int) -> None:
19561956
graph_after_passes,
19571957
exir_ops.edge.aten.mul.Tensor,
19581958
),
1959-
exponent - 1,
1959+
int(exponent) - 1,
19601960
)
19611961

19621962
@expand(
19631963
[
19641964
[1],
19651965
[1.5],
1966+
[5.0],
1967+
[0.5],
19661968
]
19671969
)
19681970
def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None:

0 commit comments

Comments
 (0)