@@ -2266,19 +2266,25 @@ def _get_split_sizes(self, node: torch.fx.Node) -> Optional[list[tuple[int, ...]
22662266class ReplacePowWithMulPass (RemoveOrReplacePassInterface ):
22672267 """
22682268 Replace the pow op with successive mul ops when the exponent is an
2269- integer between 2 and 4 (inclusive).
2269+ integer between 2 and 4 (inclusive). Float exponents that are whole
2270+ numbers (e.g., 2.0, 3.0, 4.0) are also accepted.
22702271 """
22712272
22722273 @property
22732274 def targets (self ) -> list [EdgeOpOverload ]:
22742275 return [exir_ops .edge .aten .pow .Tensor_Scalar ]
22752276
22762277 def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
2277- # Check if we have at least 2 args and the exponent is an int
2278- if len (node .args ) < 2 or not isinstance (node .args [1 ], int ):
2278+ # Check if we have at least 2 args and the exponent is an int or float
2279+ if len (node .args ) < 2 or not isinstance (node .args [1 ], ( int , float ) ):
22792280 return False
22802281
2281- exponent = cast (int , node .args [1 ])
2282+ exponent_val = node .args [1 ]
2283+ if isinstance (exponent_val , float ):
2284+ if not exponent_val .is_integer ():
2285+ return False
2286+ exponent_val = int (exponent_val )
2287+ exponent = cast (int , exponent_val )
22822288
22832289 # Only replace if exponent is between 2 and 4 (inclusive)
22842290 if exponent < 2 or exponent > 4 :
0 commit comments