Skip to content

Commit 8e2ae3d

Browse files
authored
add torch_native_sink backend and enable decode with sink (mingfeima#57)
1 parent 514f1c9 commit 8e2ae3d

12 files changed

Lines changed: 795 additions & 476 deletions

File tree

python/sglang/srt/layers/activation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
135135
x = F.relu(x)
136136
return x * x
137137

138+
class SwiGLU(CustomOp):
139+
def forward_native(self, x: torch.Tensor, alpha: float = 1.702, pair_wise: bool = True) -> torch.Tensor:
140+
# reference implementation
141+
if not pair_wise:
142+
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
143+
else:
144+
x_glu, x_linear = x[..., ::2], x[..., 1::2]
145+
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
146+
return out_glu * (x_linear + 1) # Note that here add an extra bias of 1 to the linear layer
147+
148+
def forward_cuda(self, x: torch.Tensor, alpha: float = 1.702, pair_wise: bool = True) -> torch.Tensor:
149+
# TODO: Implement the CUDA kernel for SwiGLU in sgl-kernel
150+
return self.forward_native(x, alpha, pair_wise)
138151

139152
class QuickGELU(CustomOp):
140153
def forward_native(self, x: torch.Tensor) -> torch.Tensor:

python/sglang/srt/layers/attention/intel_amx_backend.py

Lines changed: 37 additions & 261 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import TYPE_CHECKING
44

55
import torch
6-
from torch.nn.functional import scaled_dot_product_attention
76

87
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
98
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -50,187 +49,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
5049
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
5150
self.forward_metadata = (attn_logits, max_extend_len)
5251

53-
def get_graph_seq_len_fill_value(self):
54-
return 1
55-
56-
57-
def _run_sdpa_forward_extend(
58-
self,
59-
query: torch.Tensor,
60-
output: torch.Tensor,
61-
k_cache: torch.Tensor,
62-
v_cache: torch.Tensor,
63-
req_to_token: torch.Tensor,
64-
req_pool_indices: torch.Tensor,
65-
seq_lens: torch.Tensor,
66-
extend_prefix_lens: torch.Tensor,
67-
extend_seq_lens: torch.Tensor,
68-
encoder_lens=None,
69-
scaling=None,
70-
enable_gqa=False,
71-
causal=False,
72-
is_cross_attn=False,
73-
):
74-
"""Run the extend forward by using torch native sdpa op.
75-
76-
Args:
77-
query: [num_tokens, num_heads, head_size]
78-
output: [num_tokens, num_heads, head_size]
79-
k_cache: [max_total_num_tokens, num_heads, head_size]
80-
v_cache: [max_total_num_tokens, num_heads, head_size]
81-
req_to_token: [max_num_reqs, max_context_len]
82-
req_pool_indices: [num_seqs]
83-
encoder_lens: [num_seqs] or None
84-
seq_lens: [num_seqs]
85-
extend_prefix_lens: [num_seqs]
86-
extend_seq_lens: [num_seqs]
87-
scaling: float or None
88-
enable_gqa: bool
89-
causal: bool
90-
is_cross_attn: bool
91-
92-
Returns:
93-
output: [num_tokens, num_heads, head_size]
94-
"""
95-
96-
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
97-
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
98-
99-
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
100-
query = query.movedim(0, query.dim() - 2)
101-
102-
start_q, start_kv = 0, 0
103-
for seq_idx in range(seq_lens.shape[0]):
104-
# TODO: this loop process a sequence per iter, this is inefficient.
105-
# Need optimize the performance later.
106-
107-
extend_seq_len_q = extend_seq_lens[seq_idx]
108-
prefill_seq_len_q = extend_prefix_lens[seq_idx]
109-
110-
seq_len_kv = seq_lens[seq_idx]
111-
end_q = start_q + extend_seq_len_q
112-
if encoder_lens is not None:
113-
start_kv = 0 if is_cross_attn else encoder_lens[seq_idx]
114-
end_kv = (
115-
encoder_lens[seq_idx] if is_cross_attn else start_kv + seq_len_kv
116-
)
117-
else:
118-
start_kv = 0
119-
end_kv = start_kv + seq_len_kv
120-
per_req_query = query[:, start_q:end_q, :]
121-
per_req_query_redudant = torch.empty(
122-
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
123-
dtype=per_req_query.dtype,
124-
device=per_req_query.device,
125-
)
126-
127-
per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
128-
129-
# get key and value from cache. per_req_tokens contains the kv cache
130-
# index for each token in the sequence.
131-
req_pool_idx = req_pool_indices[seq_idx]
132-
per_req_tokens = req_to_token[req_pool_idx, start_kv:end_kv]
133-
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
134-
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
135-
136-
per_req_out_redudant = (
137-
scaled_dot_product_attention(
138-
per_req_query_redudant.unsqueeze(0),
139-
per_req_key.unsqueeze(0),
140-
per_req_value.unsqueeze(0),
141-
enable_gqa=enable_gqa,
142-
scale=scaling,
143-
is_causal=causal,
144-
)
145-
.squeeze(0)
146-
.movedim(query.dim() - 2, 0)
147-
)
148-
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
149-
start_q, start_kv = end_q, end_kv
150-
return output
151-
152-
def _run_sdpa_forward_decode(
153-
self,
154-
query: torch.Tensor,
155-
output: torch.Tensor,
156-
k_cache: torch.Tensor,
157-
v_cache: torch.Tensor,
158-
req_to_token: torch.Tensor,
159-
req_pool_indices: torch.Tensor,
160-
seq_lens: torch.Tensor,
161-
encoder_lens=None,
162-
scaling=None,
163-
enable_gqa=False,
164-
causal=False,
165-
is_cross_attn=False,
166-
):
167-
"""Run the decode forward by using torch native sdpa op.
168-
169-
Args:
170-
query: [num_tokens, num_heads, head_size]
171-
output: [num_tokens, num_heads, head_size]
172-
k_cache: [max_total_num_tokens, num_heads, head_size]
173-
v_cache: [max_total_num_tokens, num_heads, head_size]
174-
req_to_token: [max_num_reqs, max_context_len],
175-
req_pool_indices: [num_seqs],
176-
seq_lens: [num_seqs]
177-
encoder_lens: [num_seqs] or None
178-
scaling: float or None
179-
enable_gqa: bool
180-
causal: bool
181-
is_cross_attn: bool
182-
183-
Returns:
184-
output: [num_tokens, num_heads, head_size]
185-
"""
186-
187-
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
188-
query = query.movedim(0, query.dim() - 2)
189-
190-
start_q, start_kv = 0, 0
191-
for seq_idx in range(seq_lens.shape[0]):
192-
# TODO: this loop process a sequence per iter, this is inefficient.
193-
# Need optimize the performance later.
194-
195-
seq_len_q = 1
196-
seq_len_kv = seq_lens[seq_idx]
197-
end_q = start_q + seq_len_q
198-
if encoder_lens is not None:
199-
start_kv = 0 if is_cross_attn else encoder_lens[seq_idx]
200-
end_kv = (
201-
encoder_lens[seq_idx] if is_cross_attn else start_kv + seq_len_kv
202-
)
203-
else:
204-
start_kv = 0
205-
end_kv = start_kv + seq_len_kv
206-
207-
per_req_query = query[:, start_q:end_q, :]
208-
209-
# get key and value from cache. per_req_tokens contains the kv cache
210-
# index for each token in the sequence.
211-
212-
req_pool_idx = req_pool_indices[seq_idx]
213-
per_req_tokens = req_to_token[req_pool_idx, start_kv:end_kv]
214-
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
215-
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
216-
217-
per_req_out = (
218-
scaled_dot_product_attention(
219-
per_req_query.unsqueeze(0),
220-
per_req_key.unsqueeze(0),
221-
per_req_value.unsqueeze(0),
222-
enable_gqa=enable_gqa,
223-
scale=scaling,
224-
is_causal=causal,
225-
)
226-
.squeeze(0)
227-
.movedim(query.dim() - 2, 0)
228-
)
229-
output[start_q:end_q, :, :] = per_req_out
230-
start_q, start_kv = end_q, end_kv
231-
232-
return output
233-
23452
def forward_extend(
23553
self,
23654
q,
@@ -239,6 +57,7 @@ def forward_extend(
23957
layer: RadixAttention,
24058
forward_batch: ForwardBatch,
24159
save_kv_cache=True,
60+
sk=None,
24261
):
24362
if layer.qk_head_dim != layer.v_head_dim:
24463
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -255,46 +74,24 @@ def forward_extend(
25574
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
25675

25776
_, max_extend_len = self.forward_metadata
258-
if k is not None:
259-
assert v is not None
260-
self.extend_attention_fwd(
261-
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
262-
k,
263-
v,
264-
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
265-
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
266-
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
267-
forward_batch.req_to_token_pool.req_to_token,
268-
forward_batch.req_pool_indices,
269-
forward_batch.seq_lens,
270-
forward_batch.extend_seq_lens,
271-
forward_batch.extend_start_loc,
272-
max_extend_len,
273-
layer.scaling,
274-
layer.logit_cap,
275-
forward_batch.encoder_lens,
276-
)
277-
else:
278-
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
279-
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
280-
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
281-
282-
self._run_sdpa_forward_extend(
283-
q_,
284-
o_,
285-
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
286-
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
287-
forward_batch.req_to_token_pool.req_to_token,
288-
forward_batch.req_pool_indices,
289-
forward_batch.seq_lens,
290-
forward_batch.extend_prefix_lens,
291-
forward_batch.extend_seq_lens,
292-
encoder_lens=forward_batch.encoder_lens,
293-
scaling=layer.scaling,
294-
enable_gqa=use_gqa,
295-
causal=not layer.is_cross_attention,
296-
is_cross_attn=layer.is_cross_attention,
297-
)
77+
self.extend_attention_fwd(
78+
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
79+
k,
80+
v,
81+
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
82+
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
83+
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
84+
forward_batch.req_to_token_pool.req_to_token,
85+
forward_batch.req_pool_indices,
86+
forward_batch.seq_lens,
87+
forward_batch.extend_seq_lens,
88+
forward_batch.extend_start_loc,
89+
max_extend_len,
90+
layer.scaling,
91+
layer.logit_cap,
92+
layer.is_cross_attention,
93+
forward_batch.encoder_lens,
94+
)
29895
return o
29996

30097
def forward_decode(
@@ -305,6 +102,7 @@ def forward_decode(
305102
layer: RadixAttention,
306103
forward_batch: ForwardBatch,
307104
save_kv_cache=True,
105+
sk=None,
308106
):
309107
attn_logits, _ = self.forward_metadata
310108

@@ -319,45 +117,23 @@ def forward_decode(
319117
if not layer.is_cross_attention
320118
else forward_batch.encoder_out_cache_loc
321119
)
322-
if k is not None:
323-
assert v is not None
324-
self.decode_attention_fwd(
325-
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
326-
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
327-
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
328-
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
329-
k,
330-
v,
331-
cache_loc,
332-
attn_logits,
333-
forward_batch.req_to_token_pool.req_to_token,
334-
forward_batch.req_pool_indices,
335-
forward_batch.seq_lens,
336-
layer.scaling,
337-
layer.logit_cap,
338-
forward_batch.encoder_lens,
339-
)
340-
else:
341-
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
342-
343-
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
344-
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
345-
346-
self._run_sdpa_forward_decode(
347-
q_,
348-
o_,
349-
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
350-
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
351-
forward_batch.req_to_token_pool.req_to_token,
352-
forward_batch.req_pool_indices,
353-
forward_batch.seq_lens,
354-
encoder_lens=forward_batch.encoder_lens,
355-
scaling=layer.scaling,
356-
enable_gqa=use_gqa,
357-
causal=False,
358-
is_cross_attn=layer.is_cross_attention,
359-
)
360-
120+
self.decode_attention_fwd(
121+
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
122+
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
123+
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
124+
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
125+
k,
126+
v,
127+
cache_loc,
128+
attn_logits,
129+
forward_batch.req_to_token_pool.req_to_token,
130+
forward_batch.req_pool_indices,
131+
forward_batch.seq_lens,
132+
layer.scaling,
133+
layer.logit_cap,
134+
layer.is_cross_attention,
135+
forward_batch.encoder_lens,
136+
)
361137
return o
362138

363139
def support_triton(self):

0 commit comments

Comments
 (0)