Skip to content

Commit deb1c34

Browse files
authored
Merge branch 'cuda-graph' into fused-deltanet-decode
2 parents 1c73738 + aa7bb82 commit deb1c34

30 files changed

Lines changed: 2521 additions & 197 deletions

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def call(self, graph_module: torch.fx.GraphModule):
465465
Entry point for the pass: annotate spatial ranks, compute dim orders,
466466
insert bridging transposes, and forward to child passes.
467467
"""
468+
graph_module.graph.eliminate_dead_code()
468469
nodes = list(graph_module.graph.nodes)
469470
for node in nodes:
470471
if not self._is_ok_for_annotation(node):

backends/arm/test/misc/test_const_shape.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,22 @@
55

66
from typing import Set, Type
77

8+
import executorch.backends.arm.tosa.dialect # noqa: F401
9+
import pytest
810
import torch
11+
import tosa_serializer as ts
912
from executorch.backends.arm._passes.arm_pass import ArmPass
13+
from executorch.backends.arm._passes.to_tosa_memory_format_pass import (
14+
ToTosaMemoryFormatPass,
15+
)
16+
from executorch.backends.arm.operators.node_visitor import get_node_visitors
17+
from executorch.backends.arm.process_node import process_call_function
1018
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
19+
from executorch.backends.arm.tosa.specification import (
20+
TosaLoweringContext,
21+
TosaSpecification,
22+
)
23+
from executorch.backends.test.graph_builder import GraphBuilder
1124
from executorch.exir import to_edge
1225
from executorch.exir.dialects._ops import ops as exir_ops
1326
from executorch.exir.pass_base import ExportPass
@@ -54,3 +67,76 @@ def forward(self, x):
5467
assert const_shape_nodes
5568
for n in const_shape_nodes:
5669
assert n.meta[TosaSpecialDtype.meta_key()] == TosaSpecialDtype.SHAPE
70+
71+
72+
def _graph_module_with_unused_const_shape():
73+
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
74+
builder = GraphBuilder()
75+
builder.call_operator(exir_ops.backend.tosa.CONST_SHAPE.default, ([1],))
76+
live_const = builder.call_operator(
77+
exir_ops.backend.tosa.CONST_SHAPE.default, ([3],)
78+
)
79+
builder.output([live_const])
80+
graph_module = ExportPass().call(builder.get_graph_module()).graph_module
81+
for node in graph_module.graph.nodes:
82+
if node.op == "call_function":
83+
node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
84+
return graph_module
85+
86+
87+
def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None:
88+
output_node = next(node for node in graph_module.graph.nodes if node.op == "output")
89+
output_node.meta["tosa_dim_order"] = (0,)
90+
dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),))
91+
tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported)
92+
tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node)
93+
94+
95+
def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule):
96+
tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape")
97+
node_visitors = get_node_visitors(None, tosa_spec)
98+
tosa_graph = ts.TosaSerializer(
99+
"",
100+
targetMajor=tosa_spec.version.major,
101+
targetMinor=tosa_spec.version.minor,
102+
targetPatch=tosa_spec.version.micro,
103+
targetDraft=True,
104+
)
105+
106+
for node in graph_module.graph.nodes:
107+
if node.op == "call_function":
108+
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
109+
110+
return tosa_graph
111+
112+
113+
def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization():
114+
graph_module = _graph_module_with_unused_const_shape()
115+
_propagate_shape_dim_orders_from_users(graph_module)
116+
117+
const_shape_nodes = [
118+
node
119+
for node in graph_module.graph.nodes
120+
if node.op == "call_function"
121+
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
122+
]
123+
dead_const_shape, live_const_shape = const_shape_nodes
124+
125+
assert dead_const_shape.users == {}
126+
assert "tosa_dim_order" not in dead_const_shape.meta
127+
assert live_const_shape.meta["tosa_dim_order"] == (0,)
128+
129+
with pytest.raises(KeyError, match="tosa_dim_order"):
130+
_serialize_graph_module_to_tosa(graph_module)
131+
132+
graph_module.graph.eliminate_dead_code()
133+
graph_module.recompile()
134+
135+
remaining_const_shape = next(
136+
node
137+
for node in graph_module.graph.nodes
138+
if node.op == "call_function"
139+
and node.target == exir_ops.backend.tosa.CONST_SHAPE.default
140+
)
141+
assert remaining_const_shape.meta["tosa_dim_order"] == (0,)
142+
assert _serialize_graph_module_to_tosa(graph_module)

backends/arm/test/targets.bzl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest")
33
load("@bazel_skylib//lib:paths.bzl", "paths")
44
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
55

6-
_ENABLE_VGF = True
6+
_ENABLE_VGF = False # Disabled: memfd_create blocked by seccomp on Sandcastle causes segfaults before Python pre-flight check can run
77

88
def define_arm_tests():
99
# TODO [fbonly] Add more tests
@@ -72,6 +72,7 @@ def define_arm_tests():
7272
resources = ["conftest.py"],
7373
compile = "with-source",
7474
typing = False,
75+
skip_on_mode_mac = True,
7576
env = {} if runtime.is_oss else ({
7677
"MODEL_CONVERTER_PATH": "$(location fbsource//third-party/pypi/ai-ml-sdk-model-converter/0.8.0:model-converter-bin)",
7778
"MODEL_CONVERTER_LIB_DIR": "$(location fbsource//third-party/nvidia-nsight-systems:linux-x86_64)/host-linux-x64",
@@ -81,12 +82,11 @@ def define_arm_tests():
8182
"EMULATION_LAYER_TENSOR_JSON": "$(location fbsource//third-party/arm-ml-emulation-layer/v0.9.0/src:VkLayer_Tensor_json)",
8283
"EMULATION_LAYER_GRAPH_JSON": "$(location fbsource//third-party/arm-ml-emulation-layer/v0.9.0/src:VkLayer_Graph_json)",
8384
} if _ENABLE_VGF else {}),
84-
preload_deps = [
85+
preload_deps = [] if runtime.is_oss or not _ENABLE_VGF else [
8586
"//executorch/kernels/quantized:custom_ops_generated_lib",
86-
] + ([] if runtime.is_oss or not _ENABLE_VGF else [
8787
"fbsource//third-party/khronos:vulkan",
8888
"//executorch/backends/arm/runtime:vgf_backend",
89-
]),
89+
],
9090
deps = [
9191
"//executorch/backends/arm/test:arm_tester" if runtime.is_oss else "//executorch/backends/arm/test/tester/fb:arm_tester_fb",
9292
"//executorch/backends/arm/test:conftest",

backends/cadence/aot/ops_registrations.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -689,11 +689,11 @@ def register_fake(
689689
)
690690

691691
lib.define(
692-
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale) -> Tensor"
692+
"quantized_w8a32_gru(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden) -> Tensor"
693693
)
694694

695695
lib.define(
696-
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_i_scale, Tensor bias_hidden, float b_h_scale, *, Tensor(a!) out) -> Tensor(a!)"
696+
"quantized_w8a32_gru.out(Tensor inputs, Tensor hidden, Tensor weights_inputs, float w_i_scale, Tensor weights_hidden, float w_h_scale, Tensor bias_inputs, float b_scale, Tensor bias_hidden, *, Tensor(a!) out) -> Tensor(a!)"
697697
)
698698

699699
lib.define(
@@ -3060,11 +3060,20 @@ def quantized_w8a32_gru_meta(
30603060
weights_hidden: torch.Tensor,
30613061
w_h_scale: float,
30623062
bias_inputs: torch.Tensor,
3063-
b_i_scale: float,
3063+
b_scale: float,
30643064
bias_hidden: torch.Tensor,
3065-
b_h_scale: float,
30663065
) -> torch.Tensor:
3067-
return hidden.new_empty((2, *hidden.shape), dtype=torch.float32)
3066+
seq_len = inputs.shape[1]
3067+
assert seq_len == 1
3068+
# inputs comes in shape [batch, seq_len, input_size]
3069+
# hidden comes in shape [batch, seq_len, hidden_size]
3070+
# weights_inputs comes in shape [3 * hidden_size, input_size]
3071+
# weights_hidden comes in shape [3 * hidden_size, hidden_size]
3072+
# output comes in empty with shape [2, batch, seq_len, hidden_size]
3073+
# The first dimension stacks the output and the new hidden state
3074+
return hidden.new_empty(
3075+
(2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), dtype=torch.float32
3076+
)
30683077

30693078

30703079
@register_fake("cadence::slice_scatter_")

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -438,26 +438,36 @@ def get_args_and_kwargs_mixed_w8a32_conv(
438438
torch.ops.aten.permute.default,
439439
(other_inputs[0], [0, 2, 1]), # NCL -> NLC
440440
)
441-
assert "val" in other_inputs[0].meta, "Missing val metadata on input node"
442-
original_val = other_inputs[0].meta["val"]
443-
assert original_val.fake_mode is not None, "fake_mode is None on input node"
444-
with original_val.fake_mode:
445-
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
446-
original_val, [0, 2, 1]
447-
)
441+
# Propagate val metadata for transposed_inputs
442+
if "val" in other_inputs[0].meta:
443+
original_val = other_inputs[0].meta["val"]
444+
fake_mode = original_val.fake_mode
445+
if fake_mode is not None:
446+
with fake_mode:
447+
transposed_val = torch.ops.aten.permute.default(original_val, [0, 2, 1])
448+
transposed_inputs.meta["val"] = transposed_val
449+
else:
450+
transposed_inputs.meta["val"] = torch.ops.aten.permute.default(
451+
original_val, [0, 2, 1]
452+
)
448453
copy_node_metadata(transposed_inputs, other_inputs[0])
449454

450455
transposed_weights = graph_module.graph.call_function(
451456
torch.ops.aten.permute.default,
452457
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
453458
)
454-
assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node"
455-
original_val = weights_inputs[0].meta["val"]
456-
assert original_val.fake_mode is not None, "fake_mode is None on weight node"
457-
with original_val.fake_mode:
458-
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
459-
original_val, [2, 0, 1]
460-
)
459+
# Propagate val metadata for transposed_weights
460+
if "val" in weights_inputs[0].meta:
461+
original_val = weights_inputs[0].meta["val"]
462+
fake_mode = original_val.fake_mode
463+
if fake_mode is not None:
464+
with fake_mode:
465+
transposed_val = torch.ops.aten.permute.default(original_val, [2, 0, 1])
466+
transposed_weights.meta["val"] = transposed_val
467+
else:
468+
transposed_weights.meta["val"] = torch.ops.aten.permute.default(
469+
original_val, [2, 0, 1]
470+
)
461471
copy_node_metadata(transposed_weights, weights_inputs[0])
462472

463473
args = (
@@ -511,12 +521,10 @@ def get_args_and_kwargs_mixed_w8a32_gru(
511521
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
512522
# Stride, padding, dilation, groups not supported yet
513523

514-
assert len(dequants_weights) == 2
515524
assert len(dequants_biases) == 2
516525
w_i_scale = dequants_weights[0].args[1]
517526
w_h_scale = dequants_weights[1].args[1]
518-
b_i_scale = dequants_biases[0].args[1]
519-
b_h_scale = dequants_biases[1].args[1]
527+
b_scale = dequants_biases[0].args[1]
520528

521529
args = (
522530
other_inputs[0],
@@ -526,9 +534,8 @@ def get_args_and_kwargs_mixed_w8a32_gru(
526534
weights_inputs[1],
527535
w_h_scale,
528536
bias_inputs[0],
529-
b_i_scale,
537+
b_scale,
530538
bias_inputs[1],
531-
b_h_scale,
532539
)
533540
kwargs = {}
534541

backends/cadence/aot/quantizer/patterns.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def get_anchors(
718718
)
719719

720720
cnn_weights = conv_layer.args[1]
721-
if hasattr(cnn_weights.meta, "tensor_meta"):
721+
if "tensor_meta" in cnn_weights.meta:
722722
cnn_weights_shape = cnn_weights.meta["tensor_meta"].shape
723723
# Bail if the channels are not multiple of 4 (SIMD)
724724
if cnn_weights_shape[0] % 4 != 0:
@@ -744,6 +744,18 @@ def get_anchors(
744744
conv_layer,
745745
)
746746

747+
inputs = conv_layer.args[0]
748+
if "tensor_meta" in inputs.meta:
749+
inputs_shape = inputs.meta["tensor_meta"].shape
750+
# Bail if length != kernel size - Not yet supported
751+
if inputs_shape[-1] != cnn_weights_shape[2]:
752+
return (
753+
PartitionAnchors(
754+
empty=True,
755+
),
756+
conv_layer,
757+
)
758+
747759
return (
748760
PartitionAnchors(
749761
inputs=[],
@@ -777,14 +789,16 @@ def get_anchors(
777789
)
778790

779791
# Bail if input or states are not multiple of 4 (SIMD)
780-
if gru_layer.args[0].meta["tensor_meta"].shape[-1] % 4 != 0:
792+
tensor_meta_0 = gru_layer.args[0].meta.get("tensor_meta", None)
793+
if tensor_meta_0 is None or tensor_meta_0.shape[-1] % 4 != 0:
781794
return (
782795
PartitionAnchors(
783796
empty=True,
784797
),
785798
gru_layer,
786799
)
787-
if gru_layer.args[1].meta["tensor_meta"].shape[-1] % 4 != 0:
800+
tensor_meta_1 = gru_layer.args[1].meta.get("tensor_meta", None)
801+
if tensor_meta_1 is None or tensor_meta_1.shape[-1] % 4 != 0:
788802
return (
789803
PartitionAnchors(
790804
empty=True,
@@ -799,13 +813,26 @@ def __init__(self, args, meta):
799813

800814
wrapper = Wrapper(tuple(gru_layer.args[2]), gru_layer.meta)
801815

816+
# Using SharedQuantizationSpec so that bias_hh has the same observer as bias_ih
817+
# Both biases get the same quantization scale to match the cpp operator
818+
bias_ih_node = wrapper.args[2]
819+
bias_ih_edge = (bias_ih_node, gru_layer)
820+
shared_bias_qspec = SharedQuantizationSpec(edge_or_node=bias_ih_edge)
821+
802822
return (
803823
PartitionAnchors(
804824
inputs=[],
805825
# pyre-fixme[6]: Expected `List[Tuple[Node, int]]` but got `List[Tuple[Wrapper, int]]`.
806826
weights=[(wrapper, 0), (wrapper, 1)],
807827
# pyre-fixme[6]: Expected `List[Union[Tuple[Node, int], Tuple[Node, int, DerivedQuantizationSpec]]]` but got `List[Tuple[Wrapper, int]]`.
808-
biases=[(wrapper, 2), (wrapper, 3)],
828+
biases=[
829+
(wrapper, 2), # bias_ih gets normal qspec
830+
(
831+
wrapper,
832+
3,
833+
shared_bias_qspec,
834+
), # bias_hh shares observer with bias_ih
835+
],
809836
output=[],
810837
others=[(gru_layer, 0), (gru_layer, 1)],
811838
),

backends/cadence/aot/ref_implementations.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,9 +1257,8 @@ def quantized_w8a32_gru(
12571257
weights_hidden: torch.Tensor,
12581258
w_h_scale: float,
12591259
bias_inputs: torch.Tensor,
1260-
b_i_scale: float,
1260+
b_scale: float,
12611261
bias_hidden: torch.Tensor,
1262-
b_h_scale: float,
12631262
) -> torch.Tensor:
12641263
assert weights_inputs.dtype == torch.int8
12651264
assert weights_hidden.dtype == torch.int8
@@ -1288,10 +1287,8 @@ def quantized_w8a32_gru(
12881287
dequant_weights_inputs = weights_inputs.float() * w_i_scale
12891288
dequant_weights_hidden = weights_hidden.float() * w_h_scale
12901289

1291-
# C++ implementation averages the two bias scales
1292-
avg_bias_scale = (b_i_scale + b_h_scale) / 2
1293-
dequant_bias_inputs = bias_inputs.float() * avg_bias_scale
1294-
dequant_bias_hidden = bias_hidden.float() * avg_bias_scale
1290+
dequant_bias_inputs = bias_inputs.float() * b_scale
1291+
dequant_bias_hidden = bias_hidden.float() * b_scale
12951292

12961293
gi = F.linear(inputs, dequant_weights_inputs, dequant_bias_inputs)
12971294
gh = F.linear(hidden, dequant_weights_hidden, dequant_bias_hidden)
@@ -1310,8 +1307,14 @@ def quantized_w8a32_gru(
13101307

13111308
assert new_hidden.shape == original_hidden_shape
13121309

1313-
new_hidden = new_hidden.view(original_hidden_shape)
1314-
return torch.stack([new_hidden, new_hidden], dim=0)
1310+
batch_size = inputs.shape[0]
1311+
input_dim = inputs.shape[1]
1312+
hidden_dim = hidden.shape[-1]
1313+
1314+
new_hidden_expanded = new_hidden.unsqueeze(1).expand(
1315+
batch_size, input_dim, hidden_dim
1316+
)
1317+
return torch.stack([new_hidden_expanded, new_hidden_expanded], dim=0)
13151318

13161319

13171320
@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")

0 commit comments

Comments
 (0)