@@ -22,6 +22,11 @@ class DecomposeScaledDotProductAttention(ExportPass):
2222 """
2323
2424 _passes_required_after : Set [Type [ExportPass ]] = set ()
# (name, default) pairs for the optional SDPA arguments that follow the
# first three positional arguments. The order matters:
# _canonicalize_sdpa_call walks this table with positional indices starting
# at 3 and falls back to the keyword of the same name, then to the default.
_SDPA_OPTIONAL_ARGS = (
    ("attn_mask", None),
    ("dropout_p", 0.0),
    ("is_causal", False),
)
2530
2631 def __init__ (self , allow_non_fake_inputs : bool = True ) -> None :
2732 super ().__init__ ()
@@ -42,19 +47,137 @@ def call(
4247 graph_module .recompile ()
4348 return PassResult (graph_module , True )
4449
50+ @staticmethod
51+ def _extract_arg_value (arg : object ) -> object :
52+ if isinstance (arg , torch .fx .Node ):
53+ if "val" not in arg .meta :
54+ raise RuntimeError (f"Missing meta['val'] for SDPA arg node: { arg .name } " )
55+ return arg .meta ["val" ]
56+ return arg
57+
58+ @classmethod
59+ def _canonicalize_sdpa_call (
60+ cls , node : torch .fx .Node
61+ ) -> tuple [tuple [object , ...], object , object ]:
62+ input_args = list (node .args )
63+ input_kwargs = dict (node .kwargs )
64+
65+ canonical_args = list (input_args [:3 ])
66+ for arg_index , (arg_name , default ) in enumerate (
67+ cls ._SDPA_OPTIONAL_ARGS , start = 3
68+ ):
69+ if len (input_args ) > arg_index :
70+ canonical_args .append (input_args [arg_index ])
71+ else :
72+ canonical_args .append (input_kwargs .pop (arg_name , default ))
73+
74+ raw_scale = input_kwargs .pop ("scale" , None )
75+ canonical_args .append (raw_scale )
76+ scale = cls ._extract_arg_value (raw_scale )
77+ enable_gqa = cls ._extract_arg_value (input_kwargs .pop ("enable_gqa" , False ))
78+ if input_kwargs :
79+ raise RuntimeError (
80+ "Unsupported kwargs for scaled_dot_product_attention: "
81+ f"{ ', ' .join (sorted (input_kwargs .keys ()))} "
82+ )
83+
84+ return tuple (canonical_args ), scale , enable_gqa
85+
86+ @staticmethod
87+ def _copy_decomposed_graph (
88+ graph : torch .fx .Graph ,
89+ node : torch .fx .Node ,
90+ decomposed_module : torch .fx .GraphModule ,
91+ canonical_inputs : tuple [object , ...],
92+ scale : object ,
93+ ) -> None :
94+ decomposed_node_to_subgraph_node : dict [torch .fx .Node , torch .fx .Node ] = {}
95+ last_decomposed_node = None
96+ placeholder_nodes = [
97+ decomposed_node
98+ for decomposed_node in decomposed_module .graph .nodes
99+ if decomposed_node .op == "placeholder"
100+ ]
101+ if len (placeholder_nodes ) != len (canonical_inputs ):
102+ raise RuntimeError (
103+ "Unexpected placeholder count when decomposing "
104+ "scaled_dot_product_attention"
105+ )
106+ for decomposed_node , arg in zip (placeholder_nodes , canonical_inputs ):
107+ decomposed_node_to_subgraph_node [decomposed_node ] = arg
108+
109+ for decomposed_node in decomposed_module .graph .nodes :
110+ if decomposed_node .op == "output" :
111+ last_decomposed_node = decomposed_node .args [0 ]
112+
113+ for decomposed_node in decomposed_module .graph .nodes :
114+ decomposed_node .meta ["nn_module_stack" ] = node .meta .get ("nn_module_stack" )
115+ if decomposed_node .op == "placeholder" :
116+ continue
117+
118+ if decomposed_node .op == "output" and last_decomposed_node is not None :
119+ for user in node .users .copy ():
120+ user .replace_input_with (
121+ node ,
122+ decomposed_node_to_subgraph_node [last_decomposed_node ],
123+ )
124+ continue
125+
126+ if scale is not None and decomposed_node .target in [
127+ torch .ops .aten .mul .Scalar
128+ ]:
129+ new_args = list (decomposed_node .args )
130+ new_args [1 ] = math .sqrt (scale )
131+ decomposed_node .args = tuple (new_args )
132+
133+ subgraph_node = graph .node_copy (
134+ decomposed_node ,
135+ arg_transform = lambda x : decomposed_node_to_subgraph_node [x ],
136+ )
137+ subgraph_node .meta ["source_fn_stack" ] = [
138+ (subgraph_node , subgraph_node .target )
139+ ]
140+ decomposed_node_to_subgraph_node [decomposed_node ] = subgraph_node
141+
45142 def _decompose_sdpa_node (
46143 self ,
47144 graph_module : torch .fx .GraphModule ,
48145 node : torch .fx .Node ,
49146 allow_non_fake_inputs : bool ,
50147 ) -> None :
51148 graph = graph_module .graph
52- input_tensors = (input_node .meta ["val" ] for input_node in node .all_input_nodes )
53- scale = node .kwargs .get ("scale" , None )
149+
150+ canonical_inputs , scale , enable_gqa = self ._canonicalize_sdpa_call (node )
151+ input_tensors = tuple (self ._extract_arg_value (arg ) for arg in canonical_inputs )
152+
def _sdpa_with_gqa(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """Call aten SDPA after tiling K/V heads up to the number of Q heads.

    Inputs of rank >= 4 are read as ``(..., H, T, D)`` with heads on
    dim -3. When Q has more heads than K, each K/V head is repeated
    ``Hq // Hk`` times, consecutively, so head counts match. Inputs of
    rank < 4 carry no head axis and are passed through untouched.

    ``enable_gqa`` is taken from the enclosing scope and forwarded as-is.

    Raises:
        ValueError: If the Q head count is not a multiple of the K head
            count.
    """
    # Rank-3 inputs have no head axis; dim 1 is the sequence length there
    # and must not be mistaken for a head-count mismatch.
    if q.dim() >= 4 and k.dim() >= 4:
        Hq = q.shape[-3]
        Hk = k.shape[-3]
        if Hq != Hk:
            # LLaMA-style GQA: tile K and V heads to match Q
            if Hq % Hk != 0:
                raise ValueError(f"GQA mismatch: Hq={Hq}, Hk={Hk}")
            r = Hq // Hk

            def _tile_heads(t):
                # Use the tensor's own trailing sizes: V's head dim may
                # differ from K's, so K's (Tk, D) cannot be reused for V.
                *lead, heads, seq, dim = t.shape
                return (
                    t.unsqueeze(-3)
                    .expand(*lead, heads, r, seq, dim)
                    .reshape(*lead, heads * r, seq, dim)
                )

            k = _tile_heads(k)
            v = _tile_heads(v)
    return torch.ops.aten.scaled_dot_product_attention.default(
        q,
        k,
        v,
        attn_mask,
        dropout_p,
        is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
54177
55178 # refer to pytorch/test/test_decomp.py
56179 decomposed_module = make_fx (
57- node . target ,
180+ _sdpa_with_gqa ,
58181 decomposition_table = get_decompositions ( # pyre-fixme[6]
59182 [
60183 torch .ops .aten ._scaled_dot_product_flash_attention_for_cpu .default ,
@@ -65,56 +188,8 @@ def _decompose_sdpa_node(
65188 )(* input_tensors )
66189
67190 with graph .inserting_before (node ):
68- name_to_input_tensor_map = {}
69- for i , arg in enumerate (node .args ):
70- name_to_input_tensor_map [f"arg{ i } _1" ] = arg
71-
72- decomposed_node_to_subgraph_node : dict [torch .fx .Node , torch .fx .Node ] = {}
73- last_decomposed_node = None
74- # Create a mapping from input nodes in decomposed module to original nodes.
75- # In decomposed module, there are only input tensors for placeholder op.
76- for decomposed_node in decomposed_module .graph .nodes :
77- if decomposed_node .op == "placeholder" :
78- decomposed_node_to_subgraph_node [decomposed_node ] = (
79- name_to_input_tensor_map [decomposed_node .name ]
80- )
81-
82- if decomposed_node .op == "output" :
83- last_decomposed_node = decomposed_node .args [0 ]
84-
85- # Copy node from decompose graph module
86- for decomposed_node in decomposed_module .graph .nodes :
87- node .meta ["nn_module_stack" ] = decomposed_node .meta .get (
88- "nn_module_stack"
89- )
90- if decomposed_node .op == "placeholder" :
91- continue
92-
93- if decomposed_node .op == "output" and last_decomposed_node is not None :
94- for user in node .users .copy ():
95- user .replace_input_with (
96- node ,
97- decomposed_node_to_subgraph_node [last_decomposed_node ],
98- )
99- continue
100-
101- if scale is not None and decomposed_node .target in [
102- torch .ops .aten .mul .Scalar
103- ]:
104- new_args = list (decomposed_node .args )
105- # Based on the implementation of _scaled_dot_product_attention_math,
106- # the scale is applied to q and k before matmul.
107- # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
108- new_args [1 ] = math .sqrt (scale )
109- decomposed_node .args = tuple (new_args )
110-
111- subgraph_node = graph .node_copy (
112- decomposed_node ,
113- arg_transform = lambda x : decomposed_node_to_subgraph_node [x ],
114- )
115- subgraph_node .meta ["source_fn_stack" ] = [
116- (subgraph_node , subgraph_node .target )
117- ]
118- decomposed_node_to_subgraph_node [decomposed_node ] = subgraph_node
191+ self ._copy_decomposed_graph (
192+ graph , node , decomposed_module , canonical_inputs , scale
193+ )
119194
120195 graph .erase_node (node )
0 commit comments