Skip to content

Commit 0919746

Browse files
[QNN] Fix LiftConstantScalarOperands to handle aten.pow.Scalar (#18994)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 84e1aed commit 0919746

3 files changed

Lines changed: 55 additions & 2 deletions

File tree

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class TensorOpInfo:
4848
aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
4949
aten.sub.Tensor: TensorOpInfo(aten.sub.Tensor, False, False),
5050
aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
51+
aten.pow.Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
5152
# The scalar number arg[1] is missing when using default. Result in a corner case to deal
5253
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
5354
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
@@ -86,11 +87,13 @@ def _build_tensor_constant(
8687
) -> TensorConstant:
8788
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
8889
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
90+
first_arg = node.args[0]
8991
tensor = torch.tensor(
9092
const_val,
9193
dtype=(
92-
node.args[0].meta["val"].dtype
93-
if not is_float_tensor(node)
94+
first_arg.meta["val"].dtype
95+
if isinstance(first_arg, fx.Node)
96+
and not is_float_tensor(node)
9497
and (info := SCALAR_OPS.get(node.target))
9598
and not info.use_self_dtype
9699
else node.meta["val"].dtype

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,15 @@ def forward(self, x):
18291829
return torch.pow(x, self.exponent)
18301830

18311831

1832+
class PowScalar(torch.nn.Module):
    """Module exercising ``aten.pow.Scalar``: a Python-number base raised
    to a tensor exponent (note the scalar is arg0, the tensor is arg1)."""

    def __init__(self, base=2.0):
        super().__init__()
        # Scalar (non-tensor) base operand of pow; the input tensor
        # supplied to forward() acts as the exponent.
        self._scalar_base = base

    def forward(self, x):
        return torch.ops.aten.pow.Scalar(self._scalar_base, x)
1839+
1840+
18321841
class PReLUDefault(torch.nn.Module):
18331842
def __init__(self):
18341843
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,26 @@ def test_qnn_backend_pow_tensor_scalar(self):
17131713
index += 1
17141714
self.lower_module_and_test_output(module, sample_input)
17151715

1716+
def test_qnn_backend_pow_scalar(self):
    """Lower aten.pow.Scalar (fp path) for several scalar-base values."""
    combinations = [
        {
            QCOM_MODULE: [
                PowScalar(),  # base=2.0, default  # noqa: F405
                PowScalar(3.0),  # base=3.0, common case  # noqa: F405
                PowScalar(9),  # base=9, integer exp case  # noqa: F405
                PowScalar(0.5),  # base=0.5, fractional case  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
        },
    ]
    # Flatten (module, sample_input) pairs; enumerate reproduces the
    # original per-subtest index counter.
    cases = (
        (mod, inp)
        for entry in combinations
        for mod in entry[QCOM_MODULE]
        for inp in entry[QCOM_SAMPLE_INPUTS]
    )
    for idx, (mod, inp) in enumerate(cases):
        with self.subTest(i=idx):
            self.lower_module_and_test_output(mod, inp)
1735+
17161736
def test_qnn_backend_prelu(self):
17171737
test_comb = [
17181738
{
@@ -4229,6 +4249,27 @@ def test_qnn_backend_pow_tensor_scalar(self):
42294249
qdq_module = self.get_qdq_module(module, sample_input)
42304250
self.lower_module_and_test_output(qdq_module, sample_input)
42314251

4252+
def test_qnn_backend_pow_scalar(self):
    """Lower aten.pow.Scalar through quantize-dequantize for several bases."""
    combinations = [
        {
            QCOM_MODULE: [
                PowScalar(),  # base=2.0, default  # noqa: F405
                PowScalar(3.0),  # base=3.0, common case  # noqa: F405
                PowScalar(9),  # base=9, integer exp case  # noqa: F405
                PowScalar(0.5),  # base=0.5, fractional case  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
        },
    ]
    # Flatten (module, sample_input) pairs; enumerate reproduces the
    # original per-subtest index counter.
    cases = (
        (mod, inp)
        for entry in combinations
        for mod in entry[QCOM_MODULE]
        for inp in entry[QCOM_SAMPLE_INPUTS]
    )
    for idx, (mod, inp) in enumerate(cases):
        with self.subTest(i=idx):
            qdq = self.get_qdq_module(mod, inp)
            self.lower_module_and_test_output(qdq, inp)
4272+
42324273
def test_qnn_backend_prelu(self):
42334274
test_comb = [
42344275
{

0 commit comments

Comments
 (0)