33from typing import TYPE_CHECKING
44
55import torch
6- from torch .nn .functional import scaled_dot_product_attention
76
87from sglang .srt .layers .attention .base_attn_backend import AttentionBackend
98from sglang .srt .model_executor .forward_batch_info import ForwardBatch
@@ -50,187 +49,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
5049 max_extend_len = torch .max (forward_batch .extend_seq_lens ).item ()
5150 self .forward_metadata = (attn_logits , max_extend_len )
5251
53- def get_graph_seq_len_fill_value (self ):
54- return 1
55-
56-
57- def _run_sdpa_forward_extend (
58- self ,
59- query : torch .Tensor ,
60- output : torch .Tensor ,
61- k_cache : torch .Tensor ,
62- v_cache : torch .Tensor ,
63- req_to_token : torch .Tensor ,
64- req_pool_indices : torch .Tensor ,
65- seq_lens : torch .Tensor ,
66- extend_prefix_lens : torch .Tensor ,
67- extend_seq_lens : torch .Tensor ,
68- encoder_lens = None ,
69- scaling = None ,
70- enable_gqa = False ,
71- causal = False ,
72- is_cross_attn = False ,
73- ):
74- """Run the extend forward by using torch native sdpa op.
75-
76- Args:
77- query: [num_tokens, num_heads, head_size]
78- output: [num_tokens, num_heads, head_size]
79- k_cache: [max_total_num_tokens, num_heads, head_size]
80- v_cache: [max_total_num_tokens, num_heads, head_size]
81- req_to_token: [max_num_reqs, max_context_len]
82- req_pool_indices: [num_seqs]
83- encoder_lens: [num_seqs] or None
84- seq_lens: [num_seqs]
85- extend_prefix_lens: [num_seqs]
86- extend_seq_lens: [num_seqs]
87- scaling: float or None
88- enable_gqa: bool
89- causal: bool
90- is_cross_attn: bool
91-
92- Returns:
93- output: [num_tokens, num_heads, head_size]
94- """
95-
96- assert seq_lens .shape [0 ] == extend_prefix_lens .shape [0 ]
97- assert seq_lens .shape [0 ] == extend_seq_lens .shape [0 ]
98-
99- # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
100- query = query .movedim (0 , query .dim () - 2 )
101-
102- start_q , start_kv = 0 , 0
103- for seq_idx in range (seq_lens .shape [0 ]):
104- # TODO: this loop process a sequence per iter, this is inefficient.
105- # Need optimize the performance later.
106-
107- extend_seq_len_q = extend_seq_lens [seq_idx ]
108- prefill_seq_len_q = extend_prefix_lens [seq_idx ]
109-
110- seq_len_kv = seq_lens [seq_idx ]
111- end_q = start_q + extend_seq_len_q
112- if encoder_lens is not None :
113- start_kv = 0 if is_cross_attn else encoder_lens [seq_idx ]
114- end_kv = (
115- encoder_lens [seq_idx ] if is_cross_attn else start_kv + seq_len_kv
116- )
117- else :
118- start_kv = 0
119- end_kv = start_kv + seq_len_kv
120- per_req_query = query [:, start_q :end_q , :]
121- per_req_query_redudant = torch .empty (
122- (per_req_query .shape [0 ], seq_len_kv , per_req_query .shape [2 ]),
123- dtype = per_req_query .dtype ,
124- device = per_req_query .device ,
125- )
126-
127- per_req_query_redudant [:, prefill_seq_len_q :, :] = per_req_query
128-
129- # get key and value from cache. per_req_tokens contains the kv cache
130- # index for each token in the sequence.
131- req_pool_idx = req_pool_indices [seq_idx ]
132- per_req_tokens = req_to_token [req_pool_idx , start_kv :end_kv ]
133- per_req_key = k_cache [per_req_tokens ].movedim (0 , query .dim () - 2 )
134- per_req_value = v_cache [per_req_tokens ].movedim (0 , query .dim () - 2 )
135-
136- per_req_out_redudant = (
137- scaled_dot_product_attention (
138- per_req_query_redudant .unsqueeze (0 ),
139- per_req_key .unsqueeze (0 ),
140- per_req_value .unsqueeze (0 ),
141- enable_gqa = enable_gqa ,
142- scale = scaling ,
143- is_causal = causal ,
144- )
145- .squeeze (0 )
146- .movedim (query .dim () - 2 , 0 )
147- )
148- output [start_q :end_q , :, :] = per_req_out_redudant [prefill_seq_len_q :, :, :]
149- start_q , start_kv = end_q , end_kv
150- return output
151-
152- def _run_sdpa_forward_decode (
153- self ,
154- query : torch .Tensor ,
155- output : torch .Tensor ,
156- k_cache : torch .Tensor ,
157- v_cache : torch .Tensor ,
158- req_to_token : torch .Tensor ,
159- req_pool_indices : torch .Tensor ,
160- seq_lens : torch .Tensor ,
161- encoder_lens = None ,
162- scaling = None ,
163- enable_gqa = False ,
164- causal = False ,
165- is_cross_attn = False ,
166- ):
167- """Run the decode forward by using torch native sdpa op.
168-
169- Args:
170- query: [num_tokens, num_heads, head_size]
171- output: [num_tokens, num_heads, head_size]
172- k_cache: [max_total_num_tokens, num_heads, head_size]
173- v_cache: [max_total_num_tokens, num_heads, head_size]
174- req_to_token: [max_num_reqs, max_context_len],
175- req_pool_indices: [num_seqs],
176- seq_lens: [num_seqs]
177- encoder_lens: [num_seqs] or None
178- scaling: float or None
179- enable_gqa: bool
180- causal: bool
181- is_cross_attn: bool
182-
183- Returns:
184- output: [num_tokens, num_heads, head_size]
185- """
186-
187- # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
188- query = query .movedim (0 , query .dim () - 2 )
189-
190- start_q , start_kv = 0 , 0
191- for seq_idx in range (seq_lens .shape [0 ]):
192- # TODO: this loop process a sequence per iter, this is inefficient.
193- # Need optimize the performance later.
194-
195- seq_len_q = 1
196- seq_len_kv = seq_lens [seq_idx ]
197- end_q = start_q + seq_len_q
198- if encoder_lens is not None :
199- start_kv = 0 if is_cross_attn else encoder_lens [seq_idx ]
200- end_kv = (
201- encoder_lens [seq_idx ] if is_cross_attn else start_kv + seq_len_kv
202- )
203- else :
204- start_kv = 0
205- end_kv = start_kv + seq_len_kv
206-
207- per_req_query = query [:, start_q :end_q , :]
208-
209- # get key and value from cache. per_req_tokens contains the kv cache
210- # index for each token in the sequence.
211-
212- req_pool_idx = req_pool_indices [seq_idx ]
213- per_req_tokens = req_to_token [req_pool_idx , start_kv :end_kv ]
214- per_req_key = k_cache [per_req_tokens ].movedim (0 , query .dim () - 2 )
215- per_req_value = v_cache [per_req_tokens ].movedim (0 , query .dim () - 2 )
216-
217- per_req_out = (
218- scaled_dot_product_attention (
219- per_req_query .unsqueeze (0 ),
220- per_req_key .unsqueeze (0 ),
221- per_req_value .unsqueeze (0 ),
222- enable_gqa = enable_gqa ,
223- scale = scaling ,
224- is_causal = causal ,
225- )
226- .squeeze (0 )
227- .movedim (query .dim () - 2 , 0 )
228- )
229- output [start_q :end_q , :, :] = per_req_out
230- start_q , start_kv = end_q , end_kv
231-
232- return output
233-
23452 def forward_extend (
23553 self ,
23654 q ,
@@ -239,6 +57,7 @@ def forward_extend(
23957 layer : RadixAttention ,
24058 forward_batch : ForwardBatch ,
24159 save_kv_cache = True ,
60+ sk = None ,
24261 ):
24362 if layer .qk_head_dim != layer .v_head_dim :
24463 o = q .new_empty ((q .shape [0 ], layer .tp_q_head_num * layer .v_head_dim ))
@@ -255,46 +74,24 @@ def forward_extend(
25574 forward_batch .token_to_kv_pool .set_kv_buffer (layer , cache_loc , k , v )
25675
25776 _ , max_extend_len = self .forward_metadata
258- if k is not None :
259- assert v is not None
260- self .extend_attention_fwd (
261- q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim ),
262- k ,
263- v ,
264- o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim ),
265- forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
266- forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
267- forward_batch .req_to_token_pool .req_to_token ,
268- forward_batch .req_pool_indices ,
269- forward_batch .seq_lens ,
270- forward_batch .extend_seq_lens ,
271- forward_batch .extend_start_loc ,
272- max_extend_len ,
273- layer .scaling ,
274- layer .logit_cap ,
275- forward_batch .encoder_lens ,
276- )
277- else :
278- use_gqa = layer .tp_q_head_num != layer .tp_k_head_num
279- q_ = q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim )
280- o_ = o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim )
281-
282- self ._run_sdpa_forward_extend (
283- q_ ,
284- o_ ,
285- forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
286- forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
287- forward_batch .req_to_token_pool .req_to_token ,
288- forward_batch .req_pool_indices ,
289- forward_batch .seq_lens ,
290- forward_batch .extend_prefix_lens ,
291- forward_batch .extend_seq_lens ,
292- encoder_lens = forward_batch .encoder_lens ,
293- scaling = layer .scaling ,
294- enable_gqa = use_gqa ,
295- causal = not layer .is_cross_attention ,
296- is_cross_attn = layer .is_cross_attention ,
297- )
77+ self .extend_attention_fwd (
78+ q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim ),
79+ k ,
80+ v ,
81+ o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim ),
82+ forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
83+ forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
84+ forward_batch .req_to_token_pool .req_to_token ,
85+ forward_batch .req_pool_indices ,
86+ forward_batch .seq_lens ,
87+ forward_batch .extend_seq_lens ,
88+ forward_batch .extend_start_loc ,
89+ max_extend_len ,
90+ layer .scaling ,
91+ layer .logit_cap ,
92+ layer .is_cross_attention ,
93+ forward_batch .encoder_lens ,
94+ )
29895 return o
29996
30097 def forward_decode (
@@ -305,6 +102,7 @@ def forward_decode(
305102 layer : RadixAttention ,
306103 forward_batch : ForwardBatch ,
307104 save_kv_cache = True ,
105+ sk = None ,
308106 ):
309107 attn_logits , _ = self .forward_metadata
310108
@@ -319,45 +117,23 @@ def forward_decode(
319117 if not layer .is_cross_attention
320118 else forward_batch .encoder_out_cache_loc
321119 )
322- if k is not None :
323- assert v is not None
324- self .decode_attention_fwd (
325- q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim ),
326- forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
327- forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
328- o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim ),
329- k ,
330- v ,
331- cache_loc ,
332- attn_logits ,
333- forward_batch .req_to_token_pool .req_to_token ,
334- forward_batch .req_pool_indices ,
335- forward_batch .seq_lens ,
336- layer .scaling ,
337- layer .logit_cap ,
338- forward_batch .encoder_lens ,
339- )
340- else :
341- use_gqa = layer .tp_q_head_num != layer .tp_k_head_num
342-
343- q_ = q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim )
344- o_ = o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim )
345-
346- self ._run_sdpa_forward_decode (
347- q_ ,
348- o_ ,
349- forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
350- forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
351- forward_batch .req_to_token_pool .req_to_token ,
352- forward_batch .req_pool_indices ,
353- forward_batch .seq_lens ,
354- encoder_lens = forward_batch .encoder_lens ,
355- scaling = layer .scaling ,
356- enable_gqa = use_gqa ,
357- causal = False ,
358- is_cross_attn = layer .is_cross_attention ,
359- )
360-
120+ self .decode_attention_fwd (
121+ q .view (- 1 , layer .tp_q_head_num , layer .qk_head_dim ),
122+ forward_batch .token_to_kv_pool .get_key_buffer (layer .layer_id ),
123+ forward_batch .token_to_kv_pool .get_value_buffer (layer .layer_id ),
124+ o .view (- 1 , layer .tp_q_head_num , layer .v_head_dim ),
125+ k ,
126+ v ,
127+ cache_loc ,
128+ attn_logits ,
129+ forward_batch .req_to_token_pool .req_to_token ,
130+ forward_batch .req_pool_indices ,
131+ forward_batch .seq_lens ,
132+ layer .scaling ,
133+ layer .logit_cap ,
134+ layer .is_cross_attention ,
135+ forward_batch .encoder_lens ,
136+ )
361137 return o
362138
363139 def support_triton (self ):
0 commit comments