
Commit d0abdbb

Merge branch 'main' into reorder-macos-torch-install
2 parents 987557d + fa857bd

30 files changed: 297 additions & 71 deletions

.ci/scripts/utils.sh

Lines changed: 29 additions & 1 deletion
@@ -105,7 +105,11 @@ install_pytorch_and_domains() {
   fi
   local python_version=$(python -c 'import platform; v=platform.python_version_tuple(); print(f"{v[0]}{v[1]}")')
   local torch_release=$(cat version.txt)
-  local torch_short_hash=${TORCH_VERSION:0:7}
+  # Download key must match the upload key below (basename of dist/*.whl,
+  # which always carries setup.py's resolved +gitHASH). Branch-ref pins
+  # like `release/2.11` would otherwise produce `+gitrelease` here and
+  # never hit the cache.
+  local torch_short_hash=$(git rev-parse --short=7 HEAD)
   local torch_wheel_path="cached_artifacts/pytorch/executorch/pytorch_wheels/${system_name}/${python_version}"
   local torch_wheel_name="torch-${torch_release}%2Bgit${torch_short_hash}-cp${python_version}-cp${python_version}-${platform:-}.whl"

@@ -131,6 +135,30 @@ install_pytorch_and_domains() {
   USE_DISTRIBUTED=1 python setup.py bdist_wheel
   pip install "$(echo dist/*.whl)"

+  # Invariant: the basename setup.py just produced must match the cache
+  # URL we'd reconstruct on the next run. If they diverge (someone edits
+  # torch_wheel_name above, or PyTorch renames its wheels), the cache
+  # will silently miss and every macOS run will fall back to a ~30-min
+  # source build. Fail loudly so the regression is caught immediately.
+  shopt -s nullglob
+  local built_wheels=(dist/*.whl)
+  shopt -u nullglob
+  if [[ ${#built_wheels[@]} -ne 1 ]]; then
+    echo "ERROR: expected exactly 1 wheel in dist/, found ${#built_wheels[@]}" >&2
+    exit 1
+  fi
+  local built_wheel_name
+  built_wheel_name=$(basename "${built_wheels[0]}")
+  local expected_wheel_name="${torch_wheel_name//\%2B/+}"
+  if [[ "${built_wheel_name}" != "${expected_wheel_name}" ]]; then
+    echo "ERROR: built torch wheel name does not match cache URL key:" >&2
+    echo "  built:    ${built_wheel_name}" >&2
+    echo "  expected: ${expected_wheel_name}" >&2
+    echo "Fix torch_wheel_name construction in install_pytorch_and_domains" >&2
+    echo "in .ci/scripts/utils.sh" >&2
+    exit 1
+  fi
+
   # Only AWS runners have access to S3
   if command -v aws && [[ -z "${GITHUB_RUNNER:-}" ]]; then
     for wheel_path in dist/*.whl; do
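The `${torch_wheel_name//\%2B/+}` substitution above is the whole translation between the two keys: the cache URL percent-encodes the `+` in the local version tag, while the wheel on disk keeps it literal. A minimal Python sketch of that round trip (the release, hash, and platform values are made up for illustration):

# Hypothetical values standing in for the shell variables above.
torch_release = "2.11.0a0"
torch_short_hash = "fa857bd"  # output of: git rev-parse --short=7 HEAD
py = "311"
plat = "macosx_11_0_arm64"

# Wheel basename as setup.py produces it (literal '+').
built = f"torch-{torch_release}+git{torch_short_hash}-cp{py}-cp{py}-{plat}.whl"
# Cache URL key as torch_wheel_name constructs it ('%2B' in place of '+').
url_key = built.replace("+", "%2B")

# The shell invariant check, inverted: decoding the URL key must recover
# exactly the built basename, or the next run's download misses the cache.
assert url_key.replace("%2B", "+") == built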

backends/arm/test/passes/test_rewrite_conv_pass.py

Lines changed: 7 additions & 0 deletions
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
+
 import pytest
 import torch
 import torch.nn as nn
@@ -34,6 +36,8 @@
 from torch.export import Dim, export
 from torch.export.exported_program import _get_shape_env

+_VGF_ENABLED = "LAVAPIPE_LIB_PATH" in os.environ
+

 class TinyConvReluCat(nn.Module):
     def __init__(self, conv1_bias: bool = True) -> None:
@@ -214,6 +218,7 @@ def test_rewrite_conv_tosa_FP():
     pipeline.run()


+@pytest.mark.skipif(not _VGF_ENABLED, reason="VGF not enabled")
 def test_fold_and_annotate_q_params_vgf_quant_preserves_output_qparams_on_non_fuseable_clamp() -> (
     None
 ):
@@ -228,6 +233,7 @@ def test_fold_and_annotate_q_params_vgf_quant_preserves_output_qparams_on_non_fu
     assert clamp.meta["output_qparams"]


+@pytest.mark.skipif(not _VGF_ENABLED, reason="VGF not enabled")
 def test_rewrite_conv_vgf_quant_handles_non_fuseable_conv_clamp_cat_branch() -> None:
     exported_program = _export_quantized(TinyConvReluCat())
     compile_spec = _compile_spec()
@@ -239,6 +245,7 @@ def test_rewrite_conv_vgf_quant_handles_non_fuseable_conv_clamp_cat_branch() ->
     )


+@pytest.mark.skipif(not _VGF_ENABLED, reason="VGF not enabled")
 def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> None:
     exported_program = _export_quantized(TinyConvReluCat(conv1_bias=False))
     edge_program = to_edge(
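The `_VGF_ENABLED` flag keys off `LAVAPIPE_LIB_PATH`, which CI exports when the Lavapipe software Vulkan driver is available. A sketch of the same gating pattern with a named marker, so future VGF tests do not repeat the `skipif` line (the `requires_vgf` alias is illustrative and not part of this diff):

import os

import pytest

# Same condition as the diff: VGF tests need the Lavapipe driver path.
_VGF_ENABLED = "LAVAPIPE_LIB_PATH" in os.environ

# Reusable marker for any test that exercises the VGF backend.
requires_vgf = pytest.mark.skipif(not _VGF_ENABLED, reason="VGF not enabled")


@requires_vgf
def test_example_vgf_path() -> None:
    # Body elided; this runs only when LAVAPIPE_LIB_PATH is set.
    assert _VGF_ENABLED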

backends/arm/test/targets.bzl

Lines changed: 3 additions & 2 deletions
@@ -84,11 +84,12 @@ def define_arm_tests():
            "EMULATION_LAYER_TENSOR_JSON": "$(location fbsource//third-party/arm-ml-emulation-layer/v0.9.0/src:VkLayer_Tensor_json)",
            "EMULATION_LAYER_GRAPH_JSON": "$(location fbsource//third-party/arm-ml-emulation-layer/v0.9.0/src:VkLayer_Graph_json)",
        } if _ENABLE_VGF else {}),
-       preload_deps = [] if runtime.is_oss or not _ENABLE_VGF else [
+       preload_deps = [
            "//executorch/kernels/quantized:custom_ops_generated_lib",
+       ] + ([] if runtime.is_oss or not _ENABLE_VGF else [
            "fbsource//third-party/khronos:vulkan",
            "//executorch/backends/arm/runtime:vgf_backend",
-       ],
+       ]),
        deps = [
            "//executorch/backends/arm/test:arm_tester" if runtime.is_oss else "//executorch/backends/arm/test/tester/fb:arm_tester_fb",
            "//executorch/backends/arm/test:conftest",

backends/cadence/aot/functions.yaml

Lines changed: 1 addition & 1 deletion
@@ -399,7 +399,7 @@
   - arg_meta: null
     kernel_name: impl::generic::quantized_conv1d_ncl_per_tensor_out

-- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
     kernel_name: impl::generic::quantized_conv1d_nlc_per_tensor_out

backends/cadence/aot/functions_hifi.yaml

Lines changed: 1 addition & 1 deletion
@@ -574,7 +574,7 @@
   - arg_meta: null
     kernel_name: impl::HiFi::quantized_conv1d_ncl_per_tensor_out

-- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
   - arg_meta: null
     kernel_name: impl::HiFi::quantized_conv1d_nlc_per_tensor_out

backends/cadence/aot/ops_registrations.py

Lines changed: 3 additions & 2 deletions
@@ -263,10 +263,10 @@ def register_fake(
     "quantized_conv1d_nlc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_conv1d_nlc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+    "quantized_conv1d_nlc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
     "quantized_depthwise_conv1d_ncl.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
@@ -1305,6 +1305,7 @@ def quantized_conv1d_nlc_per_tensor_meta(
     output_zero_point: int,
     out_multiplier: int,
     out_shift: int,
+    offset: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(bias.dtype == torch.int32, lambda: "expected int32")
     # NLC format: input is [N, L, C], weight is [OC, K, IC/groups]
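The meta (fake) kernel only has to produce an output of the right shape and dtype; the new trailing `offset` argument affects neither, which is why it can default to `None` without breaking existing call sites. A standalone sketch of the NLC shape rule the meta function encodes (names here are illustrative, not the Cadence implementation):

from typing import Optional, Sequence

import torch


def conv1d_nlc_out_shape(
    input: torch.Tensor,   # [N, L, C]
    weight: torch.Tensor,  # [OC, K, IC/groups]
    stride: Sequence[int],
    padding: Sequence[int],
    dilation: Sequence[int],
    offset: Optional[torch.Tensor] = None,  # trailing optional: ignored for shapes
) -> list[int]:
    n, length, _ = input.shape
    oc, k, _ = weight.shape
    # Standard convolution output-length formula.
    l_out = (length + 2 * padding[0] - dilation[0] * (k - 1) - 1) // stride[0] + 1
    return [n, l_out, oc]


# [2, 16, 8] input with a width-3 kernel, stride 1, no padding -> [2, 14, 4].
assert conv1d_nlc_out_shape(
    torch.empty(2, 16, 8), torch.empty(4, 3, 8), [1], [0], [1]
) == [2, 14, 4]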

backends/cadence/aot/reorder_ops.py

Lines changed: 91 additions & 18 deletions
@@ -721,15 +721,16 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class PropagateSlice(RemoveOrReplacePassInterface):
-    """Propagate slice_copy before unary element-wise ops when the cost
-    model indicates it reduces total data movement.
+    """Propagate slice_copy before element-wise ops when the cost model
+    indicates it reduces total data movement.

     Supported ops (extensible via dispatch table):
-    - quantize_per_tensor: element-wise, slice passes through unchanged
-    - dequantize_per_tensor: element-wise, slice passes through unchanged
+    - quantize_per_tensor: unary element-wise
+    - dequantize_per_tensor: unary element-wise
+    - add.Tensor: binary with broadcast — slices non-broadcasting inputs
+    - mul.Tensor: binary with broadcast — slices non-broadcasting inputs

-    Handles any slice dim and any step size. Runs in the iterative pass
-    loop — chains are handled by repeated application.
+    Handles any slice dim and any step size.
     """

     def __init__(self) -> None:
@@ -740,16 +741,28 @@ def __init__(self) -> None:
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             exir_ops.edge.cadence.dequantize_per_tensor.default,
         ]
+        binary_targets = [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+        ]
         self._dispatch: dict[
             EdgeOpOverload,
             tuple[
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
                 Callable[[torch.fx.Node, torch.fx.Node], bool],
             ],
-        ] = {
-            t: (self._should_swap_elementwise, self._swap_elementwise_slice)
-            for t in elementwise_targets
-        }
+        ] = {}
+        for t in elementwise_targets:
+            self._dispatch[t] = (
+                self._should_swap_elementwise,
+                self._swap_elementwise_slice,
+            )
+
+        for t in binary_targets:
+            self._dispatch[t] = (
+                self._should_swap_binary_elementwise,
+                self._swap_binary_elementwise_slice,
+            )

     @property
     def targets(self) -> list[EdgeOpOverload]:
@@ -765,19 +778,21 @@ def _should_swap_elementwise(
     def _swap_elementwise_slice(
         self, op_node: torch.fx.Node, slice_node: torch.fx.Node
     ) -> bool:
-        op_input = op_node.args[0]
-        assert isinstance(op_input, torch.fx.Node)
+        op_input = get_arg(op_node, "input", torch.fx.Node)
         graph = slice_node.graph

-        slice_args = slice_node.args[1:]
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)

         with graph.inserting_before(op_node):
             new_slice = graph.call_function(
                 exir_ops.edge.aten.slice_copy.Tensor,
-                args=(op_input, *slice_args),
+                args=(op_input, slice_dim, slice_start, slice_end, slice_step),
             )
             new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
-                op_input.meta["val"], *slice_args
+                op_input.meta["val"], slice_dim, slice_start, slice_end, slice_step
             )

         new_args = list(op_node.args)
@@ -805,10 +820,68 @@ def _swap_elementwise_slice(
         graph.erase_node(op_node)
         return True

-    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
-        parent = node.args[0]
-        if not isinstance(parent, torch.fx.Node):
+    def _should_swap_binary_elementwise(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        if lhs.meta["val"].shape == rhs.meta["val"].shape:
             return False
+        full_size = prod(op_node.meta["val"].shape)
+        sliced_size = prod(slice_node.meta["val"].shape)
+        return sliced_size < full_size
+
+    def _swap_binary_elementwise_slice(
+        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
+    ) -> bool:
+        lhs, rhs = op_node.args[0], op_node.args[1]
+        assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
+        graph = slice_node.graph
+
+        slice_dim = get_arg(slice_node, "dim", int)
+        slice_start = get_arg(slice_node, "start")
+        slice_end = get_arg(slice_node, "end")
+        slice_step = get_arg(slice_node, "step", int)
+
+        output_shape = op_node.meta["val"].shape
+
+        new_args = list(op_node.args)
+        with graph.inserting_before(op_node):
+            for i, inp in enumerate([lhs, rhs]):
+                if inp.meta["val"].shape[slice_dim] == output_shape[slice_dim]:
+                    new_slice = graph.call_function(
+                        exir_ops.edge.aten.slice_copy.Tensor,
+                        args=(inp, slice_dim, slice_start, slice_end, slice_step),
+                    )
+                    new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
+                        inp.meta["val"], slice_dim, slice_start, slice_end, slice_step
+                    )
+                    new_args[i] = new_slice
+
+            target = cast(EdgeOpOverload, op_node.target)
+            new_op = graph.call_function(
+                target,
+                args=tuple(new_args),
+                kwargs=op_node.kwargs,
+            )
+            new_op.meta["val"] = target(
+                *[
+                    a.meta["val"] if isinstance(a, torch.fx.Node) else a
+                    for a in new_args
+                ],
+                **{
+                    k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
+                    for k, v in op_node.kwargs.items()
+                },
+            )
+
+        slice_node.replace_all_uses_with(new_op)
+        graph.erase_node(slice_node)
+        graph.erase_node(op_node)
+        return True
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        parent = get_arg(node, "input", torch.fx.Node)
         if len(parent.users) != 1:
             return False
         if not isinstance(parent.target, EdgeOpOverload):
backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 97 additions & 0 deletions
@@ -927,3 +927,100 @@ def test_unsupported_parent_not_swapped(self) -> None:
         result = PropagateSlice().call(gm)

         self.assertFalse(result.modified)
+
+    def test_swap_broadcast_mul_slice_on_broadcast_dim(self) -> None:
+        """[1,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=0, step=2)
+        Only the [4,1,1,1] input should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 1, 1, 1])
+
+        mul_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mul.Tensor
+        )
+        self.assertEqual(len(mul_nodes), 1)
+        self.assertEqual(list(mul_nodes[0].meta["val"].shape), [2, 60, 1, 1])
+
+    def test_swap_broadcast_add_lhs_broadcasts(self) -> None:
+        """[1,60,4,4] + [4,60,4,4] → [4,60,4,4] → slice(dim=0, step=2)
+        Only the [4,60,4,4] (rhs) should be sliced."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(1, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "b")
+
+    def test_swap_broadcast_mul_slice_on_non_broadcast_dim(self) -> None:
+        """[4,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=1, start=0, end=30)
+        Only the [4,60,1,1] (lhs) should be sliced since rhs has dim1=1."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 1, 1))
+        b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
+        mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(mul, 1, 0, 30, 1),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertTrue(result.modified)
+
+        slice_nodes = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slice_nodes), 1)
+        self.assertEqual(slice_nodes[0].args[0].name, "a")
+        self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 1, 1])
+
+    def test_no_swap_binary_same_shape(self) -> None:
+        """Same-shape binary ops are not swapped (no broadcast)."""
+        builder = GraphBuilder()
+        a = builder.placeholder("a", torch.randn(4, 60, 4, 4))
+        b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
+        add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
+        sliced = builder.call_operator(
+            exir_ops.edge.aten.slice_copy.Tensor,
+            args=(add, 0, 0, 4, 2),
+        )
+        builder.output([sliced])
+        gm = builder.get_graph_module()
+
+        result = PropagateSlice().call(gm)
+
+        self.assertFalse(result.modified)