Skip to content

Commit 13225ba

Browse files
committed
Update on "[ET Device Support] Schema changes: device info on Tensor and buffer-level device array"
This diff adds device placement information to the ExecuTorch schema so that tensor-level device type information can be represented, which is a prerequisite for the upcoming tensor_parser updates. This is part of the Phase 1 implementation to make ET device types work end to end (E2E) without user-specified device placement. Design doc: https://docs.google.com/document/d/1lwd9BlohmwkN5EEvRulO_b-XnZBwv1nMb5l2K3jfuwA/edit?tab=t.0#heading=h.o6anuvkix4bu Differential Revision: [D93635657](https://our.internmc.facebook.com/intern/diff/D93635657/) [ghstack-poisoned]
2 parents d15c925 + c8cb794 commit 13225ba

61 files changed

Lines changed: 2289 additions & 3277 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/test_backend.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ if [[ "$FLOW" == *qnn* ]]; then
4646
export LD_LIBRARY_PATH"=$QNN_X86_LIB_DIR:$QNN_SDK_ROOT/lib/x86_64-linux-clang/:${LD_LIBRARY_PATH:-}"
4747

4848
# TODO Get SDK root from install scripts
49-
EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT"
49+
EXTRA_BUILD_ARGS+=" -DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=$QNN_SDK_ROOT -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON"
5050
fi
5151

5252
if [[ "$FLOW" == *vulkan* ]]; then

.github/workflows/_unittest.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,11 @@ jobs:
5858
if: ${{ inputs.build-tool == 'cmake' }}
5959
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
6060
with:
61-
submodules: 'recursive'
6261
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
6362
timeout: 120
6463
script: |
64+
git config --global http.sslBackend openssl
65+
git submodule update --init --recursive
6566
conda init powershell
6667
6768
powershell -Command "& {

.github/workflows/build-presets.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ jobs:
113113
with:
114114
job-name: build
115115
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
116-
submodules: recursive
117116
timeout: 90
118117
script: |
119118
set -eux
119+
git config --global http.sslBackend openssl
120+
git submodule update --init --recursive
120121
conda init powershell
121122
powershell -Command "& {
122123
Set-PSDebug -Trace 1

.github/workflows/cuda-windows.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ jobs:
127127
runner: windows.g5.4xlarge.nvidia.gpu
128128
gpu-arch-type: cuda
129129
gpu-arch-version: 12.8
130-
submodules: recursive
131130
download-artifact: ${{ matrix.model_repo }}-${{ matrix.model_name }}-cuda-windows-${{ matrix.quant }}
132131
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
133132
script: |
133+
git config --global http.sslBackend openssl
134+
git submodule update --init --recursive
134135
conda init powershell
135136
powershell -Command "& {
136137
Set-PSDebug -Trace 1

.github/workflows/trunk.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,10 +1123,11 @@ jobs:
11231123
model: [mv3, resnet50, vit, mobilebert, emformer_transcribe]
11241124
backend: [portable, xnnpack-q8]
11251125
with:
1126-
submodules: 'recursive'
11271126
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
11281127
timeout: 60
11291128
script: |
1129+
git config --global http.sslBackend openssl
1130+
git submodule update --init --recursive
11301131
conda init powershell
11311132
11321133
powershell -Command "& {

.github/workflows/windows-msvc.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ jobs:
2222
name: build-windows-msvc
2323
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
2424
with:
25-
submodules: 'recursive'
2625
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
2726
timeout: 60
2827
script: |
28+
git config --global http.sslBackend openssl
29+
git submodule update --init --recursive
2930
conda init powershell
3031
powershell -Command "& {
3132
Set-PSDebug -Trace 1

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .decompose_glu_pass import DecomposeGluPass # noqa
5353
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5454
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
55+
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
5556
from .decompose_index_select_to_gather_pass import ( # noqa
5657
DecomposeIndexSelectToGatherPass,
5758
)
@@ -94,6 +95,7 @@
9495
from .decompose_tril_pass import DecomposeTrilPass # noqa
9596
from .decompose_unfold_to_gather_pass import DecomposeUnfoldToGatherPass # noqa
9697
from .decompose_var_pass import DecomposeVarPass # noqa
98+
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
9799
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
98100
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
99101
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
DecomposeGluPass,
6060
DecomposeGroupedConvPass,
6161
DecomposeGroupNormPass,
62+
DecomposeIndexCopyPass,
6263
DecomposeIndexSelectToGatherPass,
6364
DecomposeIndexTensorToGatherPass,
6465
DecomposeIntPowPass,
@@ -92,6 +93,7 @@
9293
DecomposeTrilPass,
9394
DecomposeUnfoldToGatherPass,
9495
DecomposeVarPass,
96+
DecomposeWhereScalarOtherPass,
9597
DecorateFp32toInt32CastingPass,
9698
FoldAndAnnotateQParamsPass,
9799
FuseBatchNorm2dPass,
@@ -320,6 +322,7 @@ def _tosa_pipeline(
320322
[
321323
ReplaceScalarWithTensorByProfilePass(),
322324
RewriteLeLtToGeGtPass(),
325+
DecomposeLeakyReLUPass(), # Emits full_like, so it must run before ConvertFullLikeToFullPass
323326
ConvertFullLikeToFullPass(),
324327
MatchArgDtypePass(),
325328
UnsqueezeScalarPlaceholdersPass(exported_program),
@@ -340,7 +343,6 @@ def _tosa_pipeline(
340343
FuseBatchNorm2dPass(exported_program),
341344
ConvertMmToBmmPass(),
342345
DecomposeGluPass(),
343-
DecomposeLeakyReLUPass(),
344346
DecomposeDivPass(),
345347
# _safe_softmax results in a ReduceMax
346348
# which is not currently supported by TOSA in U55
@@ -418,6 +420,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
418420
# Transformation passes (pre scalar -> tensor)
419421
self.add_passes(
420422
[
423+
DecomposeIndexCopyPass(tfa_pass=True),
421424
DecomposeSelectScatterPass(tfa_pass=True),
422425
DecomposeSliceScatterPass(tfa_pass=True),
423426
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
@@ -434,6 +437,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
434437
DecomposeRemainderPass(tfa_pass=True),
435438
DecomposeFloorDividePass(tfa_pass=True),
436439
DecomposeDivTensorModePass(tfa_pass=True),
440+
DecomposeWhereScalarOtherPass(tfa_pass=True),
437441
]
438442
)
439443

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
10+
from executorch.backends.arm._passes.insert_int32_casts_after_int64_placeholders import (
11+
InsertInt32CastsAfterInt64PlaceholdersPass,
12+
)
13+
from executorch.exir.pass_base import ExportPass
14+
15+
16+
class DecomposeIndexCopyPass(GetDecompositionPass):
17+
"""Decomposes aten.index_copy into aten.index_put, as well as its
18+
surrounding operators.
19+
20+
This pass is intended to be called in transform_for_annotation to prepare
21+
the graph for quantization. After quantization, this operator will be
22+
prepared for lowering to TOSA using the RewriteIndexPut pass.
23+
24+
"""
25+
26+
_passes_required_after: Set[Type[ExportPass]] = {
27+
InsertInt32CastsAfterInt64PlaceholdersPass
28+
}
29+
30+
targeted_ops = [
31+
torch.ops.aten.index_copy.default,
32+
torch.ops.aten.index_copy_.default,
33+
]

backends/arm/_passes/decompose_leaky_relu_pass.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
12+
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
13+
ConvertFullLikeToFullPass,
14+
)
1215
from executorch.exir.dialects._ops import ops as exir_ops
1316
from executorch.exir.pass_base import ExportPass
1417

@@ -20,14 +23,14 @@ def _get_leaky_relu_ops(op) -> tuple:
2023
if op in edge_ops:
2124
return (
2225
exir_ops.edge.aten.clamp.default,
23-
exir_ops.edge.aten.full.default,
26+
exir_ops.edge.aten.full_like.default,
2427
exir_ops.edge.aten.mul.Tensor,
2528
exir_ops.edge.aten.add.Tensor,
2629
)
2730
elif op in torch_ops:
2831
return (
2932
torch.ops.aten.clamp.default,
30-
torch.ops.aten.full.default,
33+
torch.ops.aten.full_like.default,
3134
torch.ops.aten.mul.Tensor,
3235
torch.ops.aten.add.Tensor,
3336
)
@@ -42,33 +45,31 @@ class DecomposeLeakyReLUPass(ArmPass):
4245
Example:
4346
%op1 = clamp(x,0,None) (equivalent to max(0,x))
4447
%op2 = clamp(x,None,0) (equivalent to min(0,x))
45-
%op3 = full(x.shape,slope)
48+
%op3 = full_like(x,slope)
4649
%op4 = mul(%op3,%op2)
4750
%op5 = add(%op1,%op4)
4851
4952
"""
5053

51-
_passes_required_after: Set[Type[ExportPass]] = set()
54+
_passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass}
5255

5356
def call_operator(self, op, args, kwargs, meta):
5457
if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta):
5558
return super().call_operator(op, args, kwargs, meta)
5659

5760
x = args[0]
5861
slope = args[1] if len(args) > 1 else 0.01
59-
dtype = x.node.meta["val"].dtype
60-
device = x.node.meta["val"].device
61-
clamp, full, mul, add = _get_leaky_relu_ops(op)
62+
clamp, full_like, mul, add = _get_leaky_relu_ops(op)
6263
op1 = super().call_operator(
6364
op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
6465
)
6566
op2 = super().call_operator(
6667
op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta
6768
)
6869
op3 = super().call_operator(
69-
op=full,
70-
args=(x.node.meta["val"].shape, slope),
71-
kwargs={"dtype": dtype, "device": device},
70+
op=full_like,
71+
args=(x, slope),
72+
kwargs={},
7273
meta=meta,
7374
)
7475
op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)

0 commit comments

Comments
 (0)