forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecompose_sdpa.py
More file actions
120 lines (101 loc) · 4.73 KB
/
decompose_sdpa.py
File metadata and controls
120 lines (101 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import math
from typing import Set, Type
import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
class DecomposeScaledDotProductAttention(ExportPass):
"""
Decompose from scaled_dot_product_attention to multiple nodes.
"""
_passes_required_after: Set[Type[ExportPass]] = set()
def __init__(self, allow_non_fake_inputs: bool = True) -> None:
super().__init__()
# With allow_non_fake_inputs=False, we don't get _unsafe_view ops
# in the graph, we allow disabling it here.
self._allow_non_fake_inputs = allow_non_fake_inputs
def call(
self, graph_module: torch.fx.GraphModule, allow_non_fake_inputs: bool = True
) -> PassResult:
graph = graph_module.graph
for node in list(graph.nodes):
if node.target != torch.ops.aten.scaled_dot_product_attention.default:
continue
self._decompose_sdpa_node(graph_module, node, allow_non_fake_inputs)
graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
def _decompose_sdpa_node(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.Node,
allow_non_fake_inputs: bool,
) -> None:
graph = graph_module.graph
input_tensors = (input_node.meta["val"] for input_node in node.all_input_nodes)
scale = node.kwargs.get("scale", None)
# refer to pytorch/test/test_decomp.py
decomposed_module = make_fx(
node.target,
decomposition_table=get_decompositions( # pyre-fixme[6]
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
]
),
tracing_mode="fake",
_allow_non_fake_inputs=allow_non_fake_inputs,
)(*input_tensors)
with graph.inserting_before(node):
name_to_input_tensor_map = {}
for i, arg in enumerate(node.args):
name_to_input_tensor_map[f"arg{i}_1"] = arg
decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
last_decomposed_node = None
# Create a mapping from input nodes in decomposed module to original nodes.
# In decomposed module, there are only input tensors for placeholder op.
for decomposed_node in decomposed_module.graph.nodes:
if decomposed_node.op == "placeholder":
decomposed_node_to_subgraph_node[decomposed_node] = (
name_to_input_tensor_map[decomposed_node.name]
)
if decomposed_node.op == "output":
last_decomposed_node = decomposed_node.args[0]
# Copy node from decompose graph module
for decomposed_node in decomposed_module.graph.nodes:
node.meta["nn_module_stack"] = decomposed_node.meta.get(
"nn_module_stack"
)
if decomposed_node.op == "placeholder":
continue
if decomposed_node.op == "output" and last_decomposed_node is not None:
for user in node.users.copy():
user.replace_input_with(
node,
decomposed_node_to_subgraph_node[last_decomposed_node],
)
continue
if scale is not None and decomposed_node.target in [
torch.ops.aten.mul.Scalar
]:
new_args = list(decomposed_node.args)
# Based on the implementation of _scaled_dot_product_attention_math,
# the scale is applied to q and k before matmul.
# refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
new_args[1] = math.sqrt(scale)
decomposed_node.args = tuple(new_args)
subgraph_node = graph.node_copy(
decomposed_node,
arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
)
subgraph_node.meta["source_fn_stack"] = [
(subgraph_node, subgraph_node.target)
]
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
graph.erase_node(node)