Skip to content

Commit 129c687

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for _cdist_forward core ATen op (pytorch#20195)
### Summary

Added support for the `_cdist_forward` core ATen op by reusing the existing implementation for `CDist`. `_cdist_forward` is the internal ATen variant of `torch.cdist` that `torch.export` produces, so this change simply adds it as a target of the existing `CDist` decomposition pass, plus a few small additions to make sure the pass is registered in the correct pipelines.

### Test plan

```
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_cdist_forward --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_cdist_forward --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
```

cc @cccclai @cbilgin @abhinaykukkadapu
1 parent 5526971 commit 129c687

5 files changed

Lines changed: 34 additions & 2 deletions

File tree

backends/qualcomm/_passes/decompose_cdist.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@ class DecomposeCDist(ExportPass):
3636
Decompose for math equivalent op.
3737
"""
3838

39+
cdist_targets = {
40+
torch.ops.aten.cdist.default,
41+
torch.ops.aten._cdist_forward.default,
42+
}
43+
3944
def __init__(self) -> None:
4045
super().__init__()
4146

4247
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4348
graph = graph_module.graph
4449
for node in graph.nodes:
4550
model = CDist()
46-
if torch.ops.aten.cdist.default == node.target:
51+
if node.target in self.cdist_targets:
4752
if len(node.args) > 2:
4853
assert (
4954
node.args[2] == 2

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def get_default_pass_activations(cls):
126126
(DecomposeAny, True),
127127
(DecomposeAtan2, True),
128128
(DecomposeColIm, True),
129+
(DecomposeCDist, True),
129130
(DecomposeFill, True),
130131
(DecomposeLogVariants, True),
131132
(DecomposeMaxPool3d, True),
@@ -278,6 +279,7 @@ def get_passes_dependency_for_capture_program(cls):
278279
DecomposeAny: [RemoveRedundancy],
279280
DecomposeAtan2: [RemoveRedundancy],
280281
DecomposeColIm: [FoldQDQ],
282+
DecomposeCDist: [RemoveRedundancy],
281283
DecomposeFill: [RemoveRedundancy],
282284
DecomposeLinalgVectorNorm: [RemoveRedundancy],
283285
DecomposeLogVariants: [RemoveRedundancy],

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ The following PyTorch operators are supported through decomposition or annotatio
502502
| `aten.any` | `DecomposeAny` |
503503
| `aten.atan2.default`, `aten.atan2.out` | `DecomposeAtan2` |
504504
| `aten.add` (with alpha), `aten.sub` (with alpha) | `DecomposeBinaryAlpha` |
505-
| `aten.cdist` | `DecomposeCDist` |
505+
| `aten.cdist`, `aten._cdist_forward` | `DecomposeCDist` |
506506
| `aten.im2col`, `aten.col2im` | `DecomposeColIm` |
507507
| `aten.einsum` | `DecomposeEinsum` |
508508
| `aten.special_expm1` | `DecomposeExpM1` |

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,14 @@ def forward(self, x, y):
408408
return torch.cdist(x, y, p=2)
409409

410410

411+
class CDistForward(torch.nn.Module):
412+
def __init__(self):
413+
super().__init__()
414+
415+
def forward(self, x, y):
416+
return torch.ops.aten._cdist_forward.default(x, y, 2.0, None)
417+
418+
411419
class Ceil(torch.nn.Module):
412420
def __init__(self):
413421
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,14 @@ def test_qnn_backend_cdist(self):
428428
)
429429
self.lower_module_and_test_output(module, sample_input)
430430

431+
def test_qnn_backend_cdist_forward(self):
432+
module = CDistForward() # noqa: F405
433+
sample_input = (
434+
torch.randn(1, 125, 256),
435+
torch.randn(1, 2048, 256),
436+
)
437+
self.lower_module_and_test_output(module, sample_input)
438+
431439
def test_qnn_backend_channel_shuffle(self):
432440
module = ChannelShuffle(2) # noqa: F405
433441
sample_input = (torch.randn(1, 4, 3, 3),)
@@ -3159,6 +3167,15 @@ def test_qnn_backend_cdist(self):
31593167
module = self.get_qdq_module(module, sample_input)
31603168
self.lower_module_and_test_output(module, sample_input)
31613169

3170+
def test_qnn_backend_cdist_forward(self):
3171+
module = CDistForward() # noqa: F405
3172+
sample_input = (
3173+
torch.randn(1, 125, 256),
3174+
torch.randn(1, 2048, 256),
3175+
)
3176+
module = self.get_qdq_module(module, sample_input)
3177+
self.lower_module_and_test_output(module, sample_input)
3178+
31623179
def test_qnn_backend_channel_shuffle(self):
31633180
module = ChannelShuffle(2) # noqa: F405
31643181
sample_input = (torch.randn(1, 4, 3, 3),)

0 commit comments

Comments
 (0)