Skip to content

Commit 37b6c6a

Browse files
committed
Update base for Update on "[ET Device Support] Schema changes: device info on Tensor and buffer-level device array"
This diff adds device placement information to the ExecuTorch schema so that tensor-level device type information can be represented; this is a prerequisite for the upcoming tensor_parser updates. It is part of the Phase 1 implementation that makes ExecuTorch (ET) device types work end-to-end (E2E) without user-specified device placement. Design doc: https://docs.google.com/document/d/1lwd9BlohmwkN5EEvRulO_b-XnZBwv1nMb5l2K3jfuwA/edit?tab=t.0#heading=h.o6anuvkix4bu Differential Revision: [D93635657](https://our.internmc.facebook.com/intern/diff/D93635657/) [ghstack-poisoned]
2 parents d062e75 + 090af6c commit 37b6c6a

60 files changed

Lines changed: 1889 additions & 3007 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/unittest-linux-cmake.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ if ! python -c "import tosa_serializer" >/dev/null 2>&1; then
1919
TOSA_SERIALIZATION_DIR="${TOSA_TOOLS_DIR}/serialization"
2020
fi
2121

22+
# NOTE: This will be removed once tosa-tools is installed via PyPI
23+
python -m pip install pybind11==2.10.4
2224
CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 \
2325
python -m pip install --no-dependencies \
2426
"${TOSA_SERIALIZATION_DIR}"
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Build Cadence
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- main
8+
- release/*
9+
workflow_dispatch:
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
13+
cancel-in-progress: true
14+
15+
jobs:
16+
cpu-x86:
17+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
18+
permissions:
19+
id-token: write
20+
contents: read
21+
with:
22+
job-name: build
23+
runner: linux.2xlarge
24+
docker-image: ci-image:executorch-ubuntu-22.04-clang12
25+
submodules: recursive
26+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
27+
timeout: 90
28+
script: |
29+
set -eux
30+
# The generic Linux job chooses to use the base env, not the one set up by the image
31+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
32+
conda activate "${CONDA_ENV}"
33+
34+
./install_requirements.sh > /dev/null
35+
bash backends/cadence/build_cadence_runner.sh

backends/arm/MODELS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Some popular torch.nn.functional models (NN functional)
1313
- Some popular torch.nn.modules models (NN modules)
1414
- Some popular torch ops (Torch Functions)
15+
- T5 (T5 for conditional generation)
1516
- Neural Super Sampler (NSS)
1617
- Phi-3
1718
- ResNet 18

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from .decompose_index_select_to_gather_pass import ( # noqa
5656
DecomposeIndexSelectToGatherPass,
5757
)
58+
from .decompose_index_tensor_to_gather_pass import ( # noqa
59+
DecomposeIndexTensorToGatherPass,
60+
)
5861
from .decompose_int16_activation_conv_pass import ( # noqa
5962
DecomposeConvWithInt16ActivationPass,
6063
)

backends/arm/_passes/accumulate_index_put_pass.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import torch
88

99
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import (
11+
DecomposeIndexTensorToGatherPass,
12+
)
13+
from executorch.backends.arm._passes.rewrite_index_put_pass import RewriteIndexPutPass
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115
from executorch.exir.pass_base import ExportPass
1216

@@ -33,7 +37,10 @@ class AccumulateIndexPutPass(ArmPass):
3337
for the index_put op.
3438
"""
3539

36-
_passes_required_after: Set[Type[ExportPass]] = set()
40+
_passes_required_after: Set[Type[ExportPass]] = {
41+
DecomposeIndexTensorToGatherPass,
42+
RewriteIndexPutPass,
43+
}
3744

3845
def call_operator(self, op, args, kwargs, meta):
3946
if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta):

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
DecomposeGroupedConvPass,
6161
DecomposeGroupNormPass,
6262
DecomposeIndexSelectToGatherPass,
63+
DecomposeIndexTensorToGatherPass,
6364
DecomposeIntPowPass,
6465
DecomposeLayerNormPass,
6566
DecomposeLeakyReLUPass,
@@ -307,6 +308,9 @@ def _tosa_pipeline(
307308
DecomposeEmbeddingPass(),
308309
DecomposeIndexSelectToGatherPass(),
309310
DecomposeStridedSliceCopyPass(),
311+
DecomposeSliceScatterPass(),
312+
AccumulateIndexPutPass(),
313+
DecomposeIndexTensorToGatherPass(),
310314
Conv1dUnsqueezePass(),
311315
]
312316
)
@@ -329,8 +333,6 @@ def _tosa_pipeline(
329333
# Node transformation passes (post scalar-removal)
330334
self.add_passes(
331335
[
332-
DecomposeSliceScatterPass(),
333-
AccumulateIndexPutPass(),
334336
RewriteIndexPutPass(),
335337
RewriteBoolBitwiseToLogicalPass(),
336338
DecomposeRemainderPass(),

backends/arm/_passes/arm_pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_get_control_flow_submodules,
2222
get_control_flow_submodules,
2323
)
24+
from executorch.exir.pass_base import NodeMetadata
2425

2526
from torch._export.utils import (
2627
get_buffer,
@@ -202,6 +203,14 @@ def insert_q_dq_pair(
202203
return dq
203204

204205

206+
def meta_without_qparams(meta: NodeMetadata) -> NodeMetadata:
207+
"""Return a copy of NodeMetadata with input/output qparams cleared."""
208+
plain_meta_dict = dict(meta.data)
209+
plain_meta_dict["input_qparams"] = {}
210+
plain_meta_dict["output_qparams"] = {}
211+
return NodeMetadata(plain_meta_dict)
212+
213+
205214
def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
206215
"""Returns a FakeTensor from the meta field of 'node'.
207216

0 commit comments

Comments
 (0)