Skip to content

Commit 55b382d

Browse files
author
ssjia
committed
Update on "[ET-VK] Update fused SDPA operator to support ViT attention"
This diff extends the ET-VK fused SDPA operator so it can be used for the ViT attention blocks in the EdgeTAM ViT-S encoder. The main correctness problem is that QK^T dot products in ViT attention can exceed the fp16 max (65504), so fp32 accumulation is required. **fp16 overflow fix**: The intermediate `attn_weights` buffer is now always fp32 regardless of input dtype. Previously the QK shader accumulated in fp32 but stored to an fp16 buffer, causing overflow. The softmax shader reads fp32 attention weights and writes fp16 softmax output (safe since values are in [0, 1]). **Texture support**: The QK and AV shaders support both buffer and texture3d storage for Q/K/V/output. The intermediate `attn_weights` and `attn_weights_softmax` tensors now inherit the storage type of the input/output (q_projected for the LLM path, out for the fused path), so the entire fused SDPA pipeline runs in a uniform storage type and no SDPA-internal layout transitions are needed. Differential Revision: [D102360200](https://our.internmc.facebook.com/intern/diff/D102360200/) [ghstack-poisoned]
2 parents 7a7df9e + 2458318 commit 55b382d

67 files changed

Lines changed: 3042 additions & 649 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/docker/build.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/bin/bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
4+
# Copyright 2026 Arm Limited and/or its affiliates.
45
#
56
# This source code is licensed under the BSD-style license found in the
67
# LICENSE file in the root directory of this source tree.
@@ -94,11 +95,6 @@ BUILD_DOCS=1
9495
# Copy requirements-lintrunner.txt from root to here
9596
cp ../../requirements-lintrunner.txt ./
9697

97-
# Copy arm setup script from root to here
98-
# TODO(huydhn): Figure out a way to rebuild the Docker image automatically
99-
# with a new image hash when the content here is updated
100-
cp -r ../../examples/arm/ ./arm
101-
10298
docker build \
10399
--no-cache \
104100
--progress=plain \

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -459,8 +460,10 @@ class ET_EXPERIMENTAL MetalBackend final
459460

460461
ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs);
461462

463+
size_t n_io_sum = 0;
462464
ET_CHECK_OR_RETURN_ERROR(
463-
n_inputs + n_outputs == args.size(),
465+
!c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
466+
n_io_sum == args.size(),
464467
InvalidArgument,
465468
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
466469
n_inputs,

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Set, Type
7+
from collections.abc import Mapping
8+
from typing import Sequence, Set, Type
89

910
import torch._export.utils
1011
import torch.fx
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1920
FuseEqualPlaceholdersPass,
2021
)
22+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
2123
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2224
from executorch.backends.transforms.utils import (
2325
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5355
super().__init__(*args, **kwargs)
5456
self.exported_program = exported_program
5557

58+
@staticmethod
59+
def _is_tosa_dialect_op(target) -> bool:
60+
target_str = str(target)
61+
return (
62+
"executorch.exir.dialects.backend._ops.tosa." in target_str
63+
or "<EdgeOpOverload: tosa." in target_str
64+
)
65+
66+
@staticmethod
67+
def _arg_contains_symbolic_shape(arg) -> bool:
68+
if isinstance(arg, torch.fx.Node):
69+
if meta_has_shape_mark(arg.meta):
70+
return True
71+
return FuseConstantArgsPass._arg_contains_symbolic_shape(
72+
arg.meta.get("val")
73+
)
74+
if isinstance(arg, torch.SymInt):
75+
return True
76+
if isinstance(arg, Mapping):
77+
return any(
78+
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
79+
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
80+
for k, v in arg.items()
81+
)
82+
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
83+
return any(
84+
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
85+
)
86+
return False
87+
5688
def _propagate_special_dtype(self, from_nodes, to_node, data):
5789
"""Propagate special dtype meta if it exists."""
5890
special_dtypes = set()
@@ -142,13 +174,13 @@ def call(self, graph_module):
142174
for node in graph_module.graph.nodes:
143175
if node.op != "call_function":
144176
continue
145-
if node.target in [
146-
exir_ops.backend.tosa.MATMUL.default,
147-
exir_ops.backend.tosa.RESCALE.default,
148-
exir_ops.backend.tosa.RESIZE.default,
149-
exir_ops.backend.tosa.TABLE.default,
150-
exir_ops.backend.tosa.TRANSPOSE.default,
151-
]:
177+
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
178+
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
179+
if (
180+
self._is_tosa_dialect_op(node.target)
181+
or self._arg_contains_symbolic_shape(node.args)
182+
or self._arg_contains_symbolic_shape(node.kwargs)
183+
):
152184
continue
153185

154186
input_nodes = node.all_input_nodes
@@ -164,7 +196,6 @@ def call(self, graph_module):
164196
)
165197
if not all(input_nodes_constant):
166198
continue
167-
168199
try:
169200
did_fuse = self._fuse_nodes(node)
170201
if did_fuse:

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
get_input_qparams,
2222
get_output_qparams,
2323
)
24+
from executorch.backends.arm._passes.symbolic_value_range import (
25+
evaluate_symbolic_expr_values,
26+
)
2427
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2528
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2629
from executorch.backends.arm.tosa.specification import get_context_shape_env
@@ -83,16 +86,22 @@ def _adjust_pad_if_needed(
8386

8487
if isinstance(mod_remainder, torch.SymInt):
8588
shape_env = get_context_shape_env()
86-
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
87-
mod_remainder_upper = int(value_ranges.upper)
89+
exact_values = evaluate_symbolic_expr_values(
90+
mod_remainder.node.expr, shape_env
91+
)
92+
if exact_values is not None:
93+
mod_remainder_upper = max(exact_values)
94+
else:
95+
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
96+
mod_remainder_upper = int(value_ranges.upper)
8897
if mod_remainder_upper == 0:
8998
mod_remainder = 0
9099
else:
91100
mod_remainder_upper = mod_remainder
92101

93102
if mod_remainder_upper > pad:
94103
raise RuntimeError(
95-
"This case should be handled by the SizeAdjustInputPass, is it enabled?"
104+
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
96105
)
97106
return pad - mod_remainder
98107

backends/arm/_passes/size_adjust_input_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import cast, Sequence, Set, Type, TypeAlias
66

7+
import torch
78
import torch.fx
89
from executorch.backends.arm._passes import ArmPass
910
from executorch.backends.arm._passes.arm_pass_utils import (
@@ -12,6 +13,9 @@
1213
)
1314
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1415
from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass
16+
from executorch.backends.arm._passes.symbolic_value_range import (
17+
evaluate_symbolic_expr_values,
18+
)
1519
from executorch.backends.arm.tosa.specification import get_context_shape_env
1620
from executorch.exir.dialects._ops import ops as exir_ops
1721
from executorch.exir.pass_base import ExportPass, PassResult
@@ -49,6 +53,9 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
4953
"""Returns whether an int or SymInt is greater than another value."""
5054
if isinstance(input, torch.SymInt):
5155
shape_env = get_context_shape_env()
56+
exact_values = evaluate_symbolic_expr_values(input.node.expr, shape_env)
57+
if exact_values is not None:
58+
return max(exact_values) > other
5259
value_ranges = shape_env.bound_sympy(input.node.expr)
5360
return value_ranges.upper > other
5461
else:
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional
7+
8+
import sympy # type: ignore[import-untyped]
9+
import torch
10+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
11+
from torch.utils._sympy.interp import sympy_interp
12+
13+
_MAX_SET_SIZE = 256
14+
_ExactValues = Optional[frozenset[sympy.Basic]]
15+
16+
17+
def _expr_to_int(sym_expr: sympy.Basic) -> Optional[int]:
18+
if isinstance(sym_expr, int):
19+
return sym_expr
20+
if isinstance(sym_expr, sympy.Integer):
21+
return int(sym_expr)
22+
if getattr(sym_expr, "is_integer", False) and sym_expr.is_number:
23+
return int(sym_expr)
24+
return None
25+
26+
27+
def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues:
28+
value_range = shape_env.var_to_range.get(symbol)
29+
if value_range is None or not value_range.is_int:
30+
return None
31+
32+
lower = _expr_to_int(value_range.lower)
33+
upper = _expr_to_int(value_range.upper)
34+
if lower is None or upper is None or upper < lower:
35+
return None
36+
if upper - lower + 1 > _MAX_SET_SIZE:
37+
return None
38+
39+
return frozenset(sympy.Integer(value) for value in range(lower, upper + 1))
40+
41+
42+
def _map_values(values: _ExactValues, fn) -> _ExactValues:
43+
if values is None:
44+
return None
45+
46+
result = {sympy.simplify(fn(value)) for value in values}
47+
if len(result) > _MAX_SET_SIZE:
48+
return None
49+
return frozenset(result)
50+
51+
52+
def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues:
53+
if lhs is None or rhs is None:
54+
return None
55+
if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE:
56+
return None
57+
58+
result = {sympy.simplify(fn(a, b)) for a in lhs for b in rhs}
59+
if len(result) > _MAX_SET_SIZE:
60+
return None
61+
return frozenset(result)
62+
63+
64+
class _ExactValueAnalysis:
65+
@staticmethod
66+
def constant(value, dtype) -> frozenset[sympy.Basic]:
67+
return frozenset({sympy.sympify(value)})
68+
69+
@staticmethod
70+
def add(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
71+
return _combine_values(lhs, rhs, lambda a, b: a + b)
72+
73+
@staticmethod
74+
def mul(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
75+
return _combine_values(lhs, rhs, lambda a, b: a * b)
76+
77+
@staticmethod
78+
def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
79+
if rhs is None or any(value == 0 for value in rhs):
80+
return None
81+
return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b))
82+
83+
@staticmethod
84+
def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
85+
return _combine_values(lhs, rhs, lambda a, b: a**b)
86+
87+
@staticmethod
88+
def floor_to_int(values: _ExactValues, dtype) -> _ExactValues:
89+
return _map_values(values, sympy.floor)
90+
91+
@staticmethod
92+
def sym_sum(args: list[_ExactValues]) -> _ExactValues:
93+
acc: _ExactValues = frozenset({sympy.Integer(0)})
94+
for arg in args:
95+
acc = _ExactValueAnalysis.add(acc, arg)
96+
if acc is None:
97+
return None
98+
return acc
99+
100+
101+
def evaluate_symbolic_expr_values(
102+
expr: sympy.Basic | torch.SymInt,
103+
shape_env: ShapeEnv,
104+
) -> Optional[set[int]]:
105+
"""Return a best-effort finite set of possible integer values.
106+
107+
The helper first relies on ``bound_sympy`` for cheap singleton detection.
108+
When interval bounds are not precise enough, it falls back to a small
109+
exact-set analysis over bounded symbols using ``sympy_interp``.
110+
111+
"""
112+
root_expr = sympy.simplify(
113+
expr.node.expr if isinstance(expr, torch.SymInt) else expr
114+
)
115+
value_range = shape_env.bound_sympy(root_expr)
116+
if value_range.is_int and value_range.is_singleton():
117+
singleton = _expr_to_int(value_range.lower)
118+
return {singleton} if singleton is not None else None
119+
120+
exact_values = sympy_interp(
121+
_ExactValueAnalysis,
122+
{
123+
symbol: _symbol_values(symbol, shape_env)
124+
for symbol in root_expr.free_symbols
125+
},
126+
root_expr,
127+
missing_handler=lambda symbol: _symbol_values(symbol, shape_env),
128+
)
129+
if exact_values is None:
130+
return None
131+
132+
result: set[int] = set()
133+
for value in exact_values:
134+
integer_value = _expr_to_int(sympy.simplify(value))
135+
if integer_value is None:
136+
return None
137+
result.add(integer_value)
138+
return result

0 commit comments

Comments
 (0)