Skip to content

Commit dce6eca

Browse files
authored
Merge branch 'main' into main
2 parents bae4e37 + ec31735 commit dce6eca

128 files changed

Lines changed: 4963 additions & 514 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/unittest-macos-cmake.sh

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@ set -eux
1212
export TORCHINDUCTOR_CACHE_DIR="$(mktemp -d "${RUNNER_TEMP:-/tmp}/torchinductor_cache_XXXXXX")"
1313
trap 'rm -rf "${TORCHINDUCTOR_CACHE_DIR}"' EXIT
1414

15-
# Run pytest with coverage
16-
${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml
15+
# TODO(SS-JIA): AOTI tests hang on macOS CI runners — the thread blocks in
16+
# native C/C++ code (dlopen / inductor compilation) so faulthandler cannot
17+
# even produce a traceback. Diagnosis ongoing in #19886.
18+
AOTI_SKIPS=(
19+
--ignore=examples/models/llama3_2_vision/preprocess/test_preprocess.py
20+
--ignore=examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
21+
--ignore=examples/models/llama3_2_vision/text_decoder/test/test_text_decoder.py
22+
--deselect=extension/llm/modules/test/test_position_embeddings.py::TilePositionalEmbeddingTest::test_tile_positional_embedding_aoti
23+
--deselect=extension/llm/modules/test/test_position_embeddings.py::TiledTokenPositionalEmbeddingTest::test_tiled_token_positional_embedding_aoti
24+
--deselect=extension/llm/modules/test/test_attention.py::AttentionTest::test_attention_aoti
25+
)
26+
27+
${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml "${AOTI_SKIPS[@]}"
1728
# Run gtest
1829
LLVM_PROFDATA="xcrun llvm-profdata" LLVM_COV="xcrun llvm-cov" \
1930
${CONDA_RUN} test/run_oss_cpp_tests.sh

.github/scripts/propose_ghstack_orig_pr.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,9 @@ def extract_stack_from_body(pr_body: str) -> List[int]:
5252
"""
5353

5454
prs = []
55-
ghstack_begin = (
56-
"Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):"
57-
)
5855
ghstack_begin_seen = False
5956
for line in pr_body.splitlines():
60-
if ghstack_begin in line:
57+
if line.startswith("Stack from [ghstack]"):
6158
ghstack_begin_seen = True
6259
if not ghstack_begin_seen:
6360
continue

.github/workflows/_unittest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ jobs:
4949
python-version: '3.11'
5050
submodules: 'recursive'
5151
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
52+
timeout: 120
5253
script: |
5354
set -eux
5455
# This is needed to get the prebuilt PyTorch wheel from S3

.github/workflows/mlx.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ jobs:
8080
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
8181
echo "::endgroup::"
8282
83+
echo "::group::Run tq_norm op tests"
84+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v
85+
echo "::endgroup::"
86+
87+
echo "::group::Run tq4_compress op tests"
88+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v
89+
echo "::endgroup::"
90+
91+
echo "::group::Run tq_dequant op tests"
92+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v
93+
echo "::endgroup::"
94+
8395
test-mlx-qwen35-moe:
8496
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
8597
with:

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
from . import arm_pass_utils # noqa
8-
from .arm_pass import ArmPass # noqa # usort: skip
8+
from .arm_pass import ArmOpTargetedPass, ArmPass # noqa # usort: skip
99
from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa
1010
from .broadcast_args_pass import BroadcastArgsPass # noqa
1111
from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa

backends/arm/_passes/accumulate_index_put_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88

9-
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes import ArmOpTargetedPass
1010
from executorch.backends.arm._passes.decompose_index_tensor_to_gather_pass import (
1111
DecomposeIndexTensorToGatherPass,
1212
)
@@ -32,7 +32,7 @@ def get_ops(op):
3232
raise RuntimeError(f"Can't get index_put decomposition for op {op}")
3333

3434

35-
class AccumulateIndexPutPass(ArmPass):
35+
class AccumulateIndexPutPass(ArmOpTargetedPass):
3636
"""This pass adjusts the values arg when the accumulate arg is set to true
3737
for the index_put op.
3838
"""
@@ -41,9 +41,11 @@ class AccumulateIndexPutPass(ArmPass):
4141
DecomposeIndexTensorToGatherPass,
4242
RewriteIndexPutPass,
4343
}
44+
target_ops = aten_ops + edge_ops
45+
check_allowed_to_transform = True
4446

4547
def call_operator(self, op, args, kwargs, meta):
46-
if op not in (aten_ops + edge_ops) or not self.allowed_to_transform(meta):
48+
if op not in self.target_ops or not self.allowed_to_transform(meta):
4749
return super().call_operator(op, args, kwargs, meta)
4850

4951
source, indices, values = args[:3]

backends/arm/_passes/arm_pass.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import copy
88
import traceback
99
from abc import abstractmethod
10+
from collections.abc import Collection
1011
from typing import Any, List, Optional, Set, Type
1112

1213
import torch
1314
from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY
1415
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1516
from executorch.exir.dialects._ops import ops as exir_ops
1617
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
17-
from torch.fx import GraphModule
18+
from torch.fx import GraphModule, Node
1819
from torch.fx.passes.infra.pass_base import PassResult
1920
from torch.utils import _pytree as pytree
2021

@@ -191,3 +192,99 @@ def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]):
191192
meta=meta,
192193
updated=True,
193194
)
195+
196+
def should_run_pass(self, graph_module: GraphModule) -> bool:
    """Return whether this pass should run on the graph module.

    The base implementation always answers ``True``. Subclasses can
    override this to cheaply skip the pass before ``call()`` starts the
    normal ``ExportPass`` retracing path.

    Args:
        graph_module (GraphModule): The graph module to inspect.

    Returns:
        bool: True when the pass should run.

    """
    return True
210+
211+
def __call__(self, graph_module: GraphModule) -> PassResult | None:
212+
self.requires(graph_module)
213+
if not self.should_run_pass(graph_module):
214+
self.ensures(graph_module)
215+
return PassResult(graph_module, False)
216+
res = self.call(graph_module)
217+
self.ensures(graph_module)
218+
return res
219+
220+
221+
class ArmOpTargetedPass(ArmPass):
    """Base class for passes that only transform selected operators.

    Subclasses set ``target_ops`` to the call_function targets they can
    transform. If the current graph and nested control-flow subgraphs do not
    contain any target, the pass returns immediately without paying the default
    ExportPass retracing cost.

    Set ``check_allowed_to_transform`` to ``True`` when the target pre-scan
    should also apply ``allowed_to_transform()`` to matching target nodes. This
    is useful for TFA passes whose ``call_operator()`` leaves disallowed target
    nodes unchanged. If all matching targets are disallowed, the pass can
    return before entering the normal ``ExportPass`` path.

    """

    # call_function targets the subclass can transform. With the empty
    # default, has_target_node() finds nothing and returns False.
    target_ops: Collection[Any] = ()
    # When True, a matching node only counts if allowed_to_transform()
    # accepts its metadata.
    check_allowed_to_transform = False

    def has_target_node(self, graph_module: GraphModule) -> bool:
        """Return whether the graph module tree contains a target node.

        The given module is scanned first, then every child that is itself
        a ``GraphModule``, recursively. The scan stops at the first match.

        Args:
            graph_module (GraphModule): The graph module tree to inspect.

        Returns:
            bool: True if a matching call_function node is present.

        """
        # Modules are tracked by identity so each one is scanned at most
        # once, even if it is reachable through several parents.
        visited_graph_modules: set[int] = set()

        def target_node_can_trigger_pass(node: Node) -> bool:
            # Without the opt-in flag, any matching target is enough.
            return bool(
                not self.check_allowed_to_transform
                or self.allowed_to_transform(node.meta)
            )

        def graph_has_target(module: GraphModule) -> bool:
            module_key = id(module)
            if module_key in visited_graph_modules:
                return False
            visited_graph_modules.add(module_key)

            found_here = any(
                target_node_can_trigger_pass(node)
                for target in self.target_ops
                for node in module.graph.find_nodes(
                    op="call_function",
                    target=target,
                    sort=False,
                )
            )
            if found_here:
                return True

            # Descend into nested GraphModule children only.
            return any(
                isinstance(child, GraphModule) and graph_has_target(child)
                for child in module.children()
            )

        return graph_has_target(graph_module)

    def should_run_pass(self, graph_module: GraphModule) -> bool:
        """Return whether this pass has a target node to transform.

        Args:
            graph_module (GraphModule): The graph module tree to inspect.

        Returns:
            bool: True when a matching target node is present.

        """
        return self.has_target_node(graph_module)

backends/arm/_passes/canonicalize_gather_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from typing import Set, Type
77

88
import torch
9-
from executorch.backends.arm._passes import ArmPass
9+
from executorch.backends.arm._passes import ArmOpTargetedPass
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass
1212

1313

14-
class CanonicalizeGatherPass(ArmPass):
14+
class CanonicalizeGatherPass(ArmOpTargetedPass):
1515
"""Canonicalize gather so it can be lowered to TOSA.GATHER via the backend
1616
dialect.
1717
@@ -40,10 +40,10 @@ class CanonicalizeGatherPass(ArmPass):
4040

4141
_passes_required_after: Set[Type[ExportPass]] = set()
4242

43-
_TARGET_OPS = {exir_ops.edge.aten.gather.default}
43+
target_ops = {exir_ops.edge.aten.gather.default}
4444

4545
def call_operator(self, op, args, kwargs, meta):
46-
if op not in self._TARGET_OPS:
46+
if op not in self.target_ops:
4747
return super().call_operator(op, args, kwargs, meta)
4848

4949
# edge.aten.gather.default: (x, dim, index) with kw-only sparse_grad

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from typing import Set, Type
1010

11-
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes import ArmOpTargetedPass
1212

1313
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1414
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
@@ -17,7 +17,7 @@
1717
from executorch.exir.pass_base import ExportPass
1818

1919

20-
class Conv1dUnsqueezePass(ArmPass):
20+
class Conv1dUnsqueezePass(ArmOpTargetedPass):
2121
"""This pass is used to change conv1d ops into conv2d since TOSA only
2222
supports 2d and 3d convolution.
2323
@@ -34,9 +34,10 @@ class Conv1dUnsqueezePass(ArmPass):
3434
RewriteConvPass,
3535
SizeAdjustInputPass,
3636
}
37+
target_ops = (exir_ops.edge.aten.convolution.default,)
3738

3839
def call_operator(self, op, args, kwargs, meta):
39-
if op != exir_ops.edge.aten.convolution.default:
40+
if op not in self.target_ops:
4041
return super().call_operator(op, args, kwargs, meta)
4142
stride = list(args[3])
4243
if len(stride) != 1:

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.arm._passes.arm_pass import ArmPass
12+
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
1313
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
1414
UnsqueezeBeforeRepeatPass,
1515
)
@@ -51,7 +51,7 @@ def calculate_multiples(args):
5151
return multiples, expanded_rank != len(input_shape)
5252

5353

54-
class ConvertExpandCopyToRepeatPass(ArmPass):
54+
class ConvertExpandCopyToRepeatPass(ArmOpTargetedPass):
5555
"""Replace expand copy with repeat since it is a repeat that can only repeat
5656
singleton dimensions.
5757
"""
@@ -60,9 +60,10 @@ class ConvertExpandCopyToRepeatPass(ArmPass):
6060

6161
expand_copy = exir_ops.edge.aten.expand_copy.default
6262
repeat = exir_ops.edge.aten.repeat.default
63+
target_ops = (expand_copy,)
6364

6465
def call_operator(self, op, args, kwargs, meta):
65-
if op != self.expand_copy:
66+
if op not in self.target_ops:
6667
return super().call_operator(op, args, kwargs, meta)
6768

6869
multiples, changes_rank = calculate_multiples(args)

0 commit comments

Comments
 (0)