@@ -42,19 +42,100 @@ def call(
4242 graph_module .recompile ()
4343 return PassResult (graph_module , True )
4444
45+ @staticmethod
46+ def _extract_input_tensors (node : torch .fx .Node ) -> tuple [object , ...]:
47+ def _extract_arg_value (arg ):
48+ if isinstance (arg , torch .fx .Node ):
49+ if "val" not in arg .meta :
50+ raise RuntimeError (
51+ f"Missing meta['val'] for SDPA arg node: { arg .name } "
52+ )
53+ return arg .meta ["val" ]
54+ return arg
55+
56+ return tuple (_extract_arg_value (arg ) for arg in node .args )
57+
    @staticmethod
    def _copy_decomposed_graph(
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        decomposed_module: torch.fx.GraphModule,
        scale: object,
    ) -> None:
        """Inline the traced decomposition of ``node`` into ``graph``.

        Copies every non-placeholder node of ``decomposed_module.graph`` into
        ``graph`` (the caller wraps this in ``graph.inserting_before(node)``),
        wires placeholders back to the original node's args, and redirects all
        users of ``node`` to the copied output. ``node`` itself is NOT erased
        here — the caller does that afterwards.

        Args:
            graph: destination fx graph (the one containing ``node``).
            node: the original SDPA node being replaced.
            decomposed_module: make_fx-traced decomposition of ``node``.
            scale: optional ``scale`` kwarg of the SDPA call; when set, it is
                patched into the decomposition's ``aten.mul.Scalar`` nodes.
        """
        # make_fx names its placeholders "arg0_1", "arg1_1", ... in positional
        # order, so the i-th placeholder maps onto node.args[i].
        name_to_input_tensor_map = {}
        for i, arg in enumerate(node.args):
            name_to_input_tensor_map[f"arg{i}_1"] = arg

        # Map: node in the decomposed graph -> corresponding node in `graph`.
        # Seeded with placeholder -> original-arg so arg_transform below can
        # resolve every input of a copied node.
        decomposed_node_to_subgraph_node: dict[torch.fx.Node, torch.fx.Node] = {}
        last_decomposed_node = None
        for decomposed_node in decomposed_module.graph.nodes:
            if decomposed_node.op == "placeholder":
                decomposed_node_to_subgraph_node[decomposed_node] = (
                    name_to_input_tensor_map[decomposed_node.name]
                )

            if decomposed_node.op == "output":
                # args[0] is the value(s) returned by the decomposed graph.
                last_decomposed_node = decomposed_node.args[0]

        # Copy the body of the decomposed graph into `graph`.
        for decomposed_node in decomposed_module.graph.nodes:
            # NOTE(review): this overwrites node.meta["nn_module_stack"] on
            # EVERY iteration, so only the last decomposed node's stack (often
            # None for placeholders/output) survives — looks unintentional;
            # confirm whether it should be set once or per copied node.
            node.meta["nn_module_stack"] = decomposed_node.meta.get("nn_module_stack")
            if decomposed_node.op == "placeholder":
                continue

            if decomposed_node.op == "output" and last_decomposed_node is not None:
                # Reroute every consumer of the original SDPA node to the
                # copy of the decomposition's output value.
                for user in node.users.copy():
                    user.replace_input_with(
                        node,
                        decomposed_node_to_subgraph_node[last_decomposed_node],
                    )
                continue

            if scale is not None and decomposed_node.target in [
                torch.ops.aten.mul.Scalar
            ]:
                # _scaled_dot_product_attention_math applies the scale to q and
                # k separately before matmul, hence sqrt(scale) per operand
                # (see pytorch aten/src/ATen/native/transformers/attention.cpp).
                new_args = list(decomposed_node.args)
                new_args[1] = math.sqrt(scale)
                decomposed_node.args = tuple(new_args)

            # Late-binding of the lambda over decomposed_node_to_subgraph_node
            # is safe: node_copy invokes arg_transform immediately.
            subgraph_node = graph.node_copy(
                decomposed_node,
                arg_transform=lambda x: decomposed_node_to_subgraph_node[x],
            )
            subgraph_node.meta["source_fn_stack"] = [
                (subgraph_node, subgraph_node.target)
            ]
            decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
108+
45109 def _decompose_sdpa_node (
46110 self ,
47111 graph_module : torch .fx .GraphModule ,
48112 node : torch .fx .Node ,
49113 allow_non_fake_inputs : bool ,
50114 ) -> None :
51115 graph = graph_module .graph
52- input_tensors = (input_node .meta ["val" ] for input_node in node .all_input_nodes )
116+
117+ input_tensors = self ._extract_input_tensors (node )
53118 scale = node .kwargs .get ("scale" , None )
54119
        def _sdpa_with_gqa(*args, **kwargs):
            """SDPA wrapper (traced by make_fx) that tiles K/V heads for GQA.

            Do not restructure: the exact op sequence here is what gets traced
            into the decomposed graph.
            """
            # args: (q, k, v, [attn_mask, dropout_p, is_causal, scale])
            q, k, v = args[:3]
            # Shapes: (B, H, T, D) — assumed 4-D batch-first layout; TODO confirm
            # against the callers of this pass.
            Hq = q.shape[1]
            Hk = k.shape[1]
            if Hq != Hk:
                # LLaMA-style GQA: tile K and V heads to match Q.
                # NOTE(review): assert is stripped under `python -O`; a raise
                # would be sturdier input validation.
                assert Hq % Hk == 0, f"GQA mismatch: Hq={Hq}, Hk={Hk}"
                r = Hq // Hk
                B, _, Tk, D = k.shape
                # (B, Hk, Tk, D) -> (B, Hk, r, Tk, D) -> (B, Hq, Tk, D):
                # each kv head is repeated r times along the head dim.
                k = k.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
                v = v.unsqueeze(2).expand(B, Hk, r, Tk, D).reshape(B, Hq, Tk, D)
                args = (q, k, v) + tuple(args[3:])
            return torch.ops.aten.scaled_dot_product_attention.default(*args, **kwargs)
135+
55136 # refer to pytorch/test/test_decomp.py
56137 decomposed_module = make_fx (
57- node . target ,
138+ _sdpa_with_gqa ,
58139 decomposition_table = get_decompositions ( # pyre-fixme[6]
59140 [
60141 torch .ops .aten ._scaled_dot_product_flash_attention_for_cpu .default ,
@@ -65,56 +146,6 @@ def _decompose_sdpa_node(
65146 )(* input_tensors )
66147
67148 with graph .inserting_before (node ):
68- name_to_input_tensor_map = {}
69- for i , arg in enumerate (node .args ):
70- name_to_input_tensor_map [f"arg{ i } _1" ] = arg
71-
72- decomposed_node_to_subgraph_node : dict [torch .fx .Node , torch .fx .Node ] = {}
73- last_decomposed_node = None
74- # Create a mapping from input nodes in decomposed module to original nodes.
75- # In decomposed module, there are only input tensors for placeholder op.
76- for decomposed_node in decomposed_module .graph .nodes :
77- if decomposed_node .op == "placeholder" :
78- decomposed_node_to_subgraph_node [decomposed_node ] = (
79- name_to_input_tensor_map [decomposed_node .name ]
80- )
81-
82- if decomposed_node .op == "output" :
83- last_decomposed_node = decomposed_node .args [0 ]
84-
85- # Copy node from decompose graph module
86- for decomposed_node in decomposed_module .graph .nodes :
87- node .meta ["nn_module_stack" ] = decomposed_node .meta .get (
88- "nn_module_stack"
89- )
90- if decomposed_node .op == "placeholder" :
91- continue
92-
93- if decomposed_node .op == "output" and last_decomposed_node is not None :
94- for user in node .users .copy ():
95- user .replace_input_with (
96- node ,
97- decomposed_node_to_subgraph_node [last_decomposed_node ],
98- )
99- continue
100-
101- if scale is not None and decomposed_node .target in [
102- torch .ops .aten .mul .Scalar
103- ]:
104- new_args = list (decomposed_node .args )
105- # Based on the implementation of _scaled_dot_product_attention_math,
106- # the scale is applied to q and k before matmul.
107- # refer to pytorch/aten/src/ATen/native/transformers/attention.cpp#L873
108- new_args [1 ] = math .sqrt (scale )
109- decomposed_node .args = tuple (new_args )
110-
111- subgraph_node = graph .node_copy (
112- decomposed_node ,
113- arg_transform = lambda x : decomposed_node_to_subgraph_node [x ],
114- )
115- subgraph_node .meta ["source_fn_stack" ] = [
116- (subgraph_node , subgraph_node .target )
117- ]
118- decomposed_node_to_subgraph_node [decomposed_node ] = subgraph_node
149+ self ._copy_decomposed_graph (graph , node , decomposed_module , scale )
119150
120151 graph .erase_node (node )
0 commit comments