Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions backends/qualcomm/_passes/lift_constant_scalar_operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TensorOpInfo:
aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
aten.sub.Tensor: TensorOpInfo(aten.sub.Tensor, False, False),
aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
aten.pow.Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add/extend a unit test that exercises aten.pow.Scalar so this regression stays fixed (e.g., export a tiny module using torch.pow(2.0, x), run LiftConstantScalarOperands, and assert the node target becomes aten.pow.Tensor_Tensor and the pass doesn’t crash when args[0] is a Python float).

Copilot uses AI. Check for mistakes.
# The scalar number arg[1] is missing when using default. Result in a corner case to deal
aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False),
Expand Down Expand Up @@ -86,11 +87,13 @@ def _build_tensor_constant(
) -> TensorConstant:
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
first_arg = node.args[0]
tensor = torch.tensor(
const_val,
dtype=(
node.args[0].meta["val"].dtype
if not is_float_tensor(node)
first_arg.meta["val"].dtype
if isinstance(first_arg, fx.Node)
and not is_float_tensor(node)
and (info := SCALAR_OPS.get(node.target))
and not info.use_self_dtype
else node.meta["val"].dtype
Expand Down
9 changes: 9 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,15 @@ def forward(self, x):
return torch.pow(x, self.exponent)


class PowScalar(torch.nn.Module):
def __init__(self, base=2.0):
super().__init__()
self._base = base

def forward(self, x):
return torch.ops.aten.pow.Scalar(self._base, x)


class PReLUDefault(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
41 changes: 41 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,26 @@ def test_qnn_backend_pow_tensor_scalar(self):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_pow_scalar(self):
test_comb = [
{
QCOM_MODULE: [
PowScalar(), # base=2.0, default # noqa: F405
PowScalar(3.0), # base=3.0, common case # noqa: F405
PowScalar(9), # base=9, integer exp case # noqa: F405
PowScalar(0.5), # base=0.5, fractional case # noqa: F405
Comment on lines +1722 to +1723
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inline comments describe these values as exponent-related (e.g., "integer exp case"), but in PowScalar the varying constructor argument is the scalar base (the exponent is the tensor input). Update the comments to refer to the base to avoid confusion when reading/fixing test failures.

Suggested change
PowScalar(9), # base=9, integer exp case # noqa: F405
PowScalar(0.5), # base=0.5, fractional case # noqa: F405
PowScalar(9), # base=9, integer base case # noqa: F405
PowScalar(0.5), # base=0.5, fractional base case # noqa: F405

Copilot uses AI. Check for mistakes.
],
QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
},
]
index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_prelu(self):
test_comb = [
{
Expand Down Expand Up @@ -4229,6 +4249,27 @@ def test_qnn_backend_pow_tensor_scalar(self):
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_pow_scalar(self):
test_comb = [
{
QCOM_MODULE: [
PowScalar(), # base=2.0, default # noqa: F405
PowScalar(3.0), # base=3.0, common case # noqa: F405
PowScalar(9), # base=9, integer exp case # noqa: F405
PowScalar(0.5), # base=0.5, fractional case # noqa: F405
Comment on lines +4258 to +4259
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the floating-point test: these comments describe exponent cases, but the varied constructor argument in PowScalar(...) is the scalar base. Please adjust the comments to refer to the base to prevent misunderstanding.

Suggested change
PowScalar(9), # base=9, integer exp case # noqa: F405
PowScalar(0.5), # base=0.5, fractional case # noqa: F405
PowScalar(9), # base=9, integer base case # noqa: F405
PowScalar(0.5), # base=0.5, fractional base case # noqa: F405

Copilot uses AI. Check for mistakes.
],
QCOM_SAMPLE_INPUTS: [(torch.rand(10, 10) + 0.1,)],
},
]
index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_prelu(self):
test_comb = [
{
Expand Down
Loading