Skip to content

Commit 35399f0

Browse files
committed
update forward
1 parent 2120b72 commit 35399f0

3 files changed

Lines changed: 32 additions & 92 deletions

File tree

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 6 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from __future__ import annotations
1818

1919
import math
20-
import os
2120
import re
2221
from typing import Dict
2322

@@ -345,9 +344,6 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
345344

346345
self.prefix = prefix
347346

348-
prop = paddle.device.cuda.get_device_properties()
349-
self.prop = prop
350-
351347
@staticmethod
352348
def yarn_get_mscale(scale=1, mscale=1):
353349
""" """
@@ -366,22 +362,6 @@ def forward(
366362
fused_read_cache_and_interleave,
367363
)
368364

369-
<<<<<<< HEAD
370-
q_total_token_num = hidden_states.shape[0]
371-
=======
372-
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
373-
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
374-
375-
# Idle pass (e.g. CUDAGraph padding): skip all attention computation
376-
if not need_do_prefill and not need_do_decode:
377-
return self.o_proj(
378-
paddle.zeros(
379-
[hidden_states.shape[0], self.num_attention_heads_tp * self.v_head_dim],
380-
dtype=hidden_states.dtype,
381-
)
382-
)
383-
>>>>>>> 15a153a24 (update forward)
384-
385365
attn_out = None
386366
if self.use_gated_attn:
387367
gate_out = self.gate(hidden_states)
@@ -459,36 +439,6 @@ def forward(
459439
attn_out = fmha_out
460440

461441
if need_do_decode: # max_dec_len_this_time
462-
463-
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
464-
pass
465-
else:
466-
from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
467-
extract_decoder_token_from_q,
468-
insert_decoder_result_back,
469-
)
470-
471-
decoder_query_nope, cache_seqlens = extract_decoder_token_from_q(
472-
query_nope.reshape([0, -1]),
473-
forward_meta.cu_seqlens_q,
474-
forward_meta.seq_lens_encoder,
475-
forward_meta.seq_lens_decoder,
476-
)
477-
478-
decoder_query_pe, cache_seqlens = extract_decoder_token_from_q(
479-
query_pe.reshape([0, -1]),
480-
forward_meta.cu_seqlens_q,
481-
forward_meta.seq_lens_encoder,
482-
forward_meta.seq_lens_decoder,
483-
)
484-
assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0]
485-
assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0]
486-
487-
forward_meta.cache_seqlens = cache_seqlens
488-
489-
query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim])
490-
query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim])
491-
492442
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
493443

494444
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -517,17 +467,6 @@ def forward(
517467
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
518468
)
519469

520-
if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
521-
pass
522-
else:
523-
fmqa_out = insert_decoder_result_back(
524-
fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]),
525-
forward_meta.cu_seqlens_q,
526-
forward_meta.seq_lens_encoder,
527-
forward_meta.seq_lens_decoder,
528-
q_total_token_num,
529-
)
530-
531470
if need_do_prefill:
532471
merge_prefill_decode_output(
533472
attn_out,
@@ -1119,6 +1058,12 @@ def forward(
11191058
residual: paddle.Tensor,
11201059
):
11211060
""" """
1061+
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
1062+
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
1063+
1064+
if not need_do_prefill and not need_do_decode:
1065+
return hidden_states
1066+
11221067
if hidden_states.shape[0] > 0:
11231068
hidden_states, residual = self.input_layernorm(
11241069
hidden_states, residual_input=residual, forward_meta=forward_meta

tests/deterministic/test_triton_decode_attention.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def cosine_similarity(a, b):
6363
# ---------------------------------------------------------------------------
6464
# Reference implementation: naive decode attention (no paging)
6565
# ---------------------------------------------------------------------------
66-
def naive_decode_attention_ref(q, k_pages, v_pages, kv_indptr, kv_indices,
67-
sm_scale, kv_block_size):
66+
def naive_decode_attention_ref(q, k_pages, v_pages, kv_indptr, kv_indices, sm_scale, kv_block_size):
6867
"""
6968
Naive Python reference for decode attention with paged KV cache.
7069
@@ -154,7 +153,6 @@ def build_decode_test_data(
154153
np.random.seed(seed)
155154
paddle.seed(seed)
156155

157-
total_kv_len = sum(seq_lens)
158156
num_blocks_needed = sum((s + block_size - 1) // block_size for s in seq_lens)
159157
num_blocks = max(num_blocks_needed + 4, 8)
160158

@@ -194,7 +192,7 @@ def build_decode_test_data(
194192
attn_lse = paddle.empty([batch_size, num_heads, max_kv_splits], dtype="float32")
195193
o = paddle.empty([batch_size, num_heads, Lv], dtype=dtype)
196194

197-
sm_scale = head_dim_k ** -0.5
195+
sm_scale = head_dim_k**-0.5
198196

199197
return {
200198
"q": q,
@@ -290,8 +288,7 @@ def test_empty(self):
290288
ids=[c[0] for c in _DECODE_CASES],
291289
)
292290
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
293-
def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv,
294-
seq_lens, block_size, dtype):
291+
def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv, seq_lens, block_size, dtype):
295292
"""Triton decode attention output should match naive reference."""
296293
data = build_decode_test_data(
297294
batch_size=batch,
@@ -336,12 +333,10 @@ def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv,
336333
cos_sim = cosine_similarity(triton_out, ref_out)
337334

338335
atol = BF16_ATOL if dtype == "bfloat16" else FP16_ATOL
339-
assert max_diff < atol, (
340-
f"[{name}/{dtype}] max_diff={max_diff:.6f} exceeds atol={atol}"
341-
)
342-
assert cos_sim > COSINE_SIM_THRESHOLD, (
343-
f"[{name}/{dtype}] cos_sim={cos_sim:.6f} below threshold={COSINE_SIM_THRESHOLD}"
344-
)
336+
assert max_diff < atol, f"[{name}/{dtype}] max_diff={max_diff:.6f} exceeds atol={atol}"
337+
assert (
338+
cos_sim > COSINE_SIM_THRESHOLD
339+
), f"[{name}/{dtype}] cos_sim={cos_sim:.6f} below threshold={COSINE_SIM_THRESHOLD}"
345340

346341

347342
# ===========================================================================
@@ -380,10 +375,7 @@ def test_decode_attention_determinism():
380375
results.append(o.astype("float32").numpy())
381376

382377
for i in range(1, len(results)):
383-
np.testing.assert_array_equal(
384-
results[0], results[i],
385-
err_msg=f"Run 0 vs run {i} differ — non-deterministic!"
386-
)
378+
np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!")
387379

388380

389381
# ===========================================================================

tests/deterministic/test_triton_mla_cache_kernel.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,7 @@ def test_write_cache_determinism():
411411
results.append(cache.astype("float32").numpy())
412412

413413
for i in range(1, len(results)):
414-
np.testing.assert_array_equal(
415-
results[0], results[i],
416-
err_msg=f"Run 0 vs run {i} differ — non-deterministic!"
417-
)
414+
np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!")
418415

419416

420417
# ===========================================================================
@@ -429,17 +426,23 @@ def test_manual_baseline():
429426
latent_dim = kv_lora_rank + qk_rope_head_dim # 6
430427

431428
# 3 tokens, deterministic values
432-
compressed_kv = paddle.to_tensor([
433-
[1.0, 2.0, 3.0, 4.0],
434-
[5.0, 6.0, 7.0, 8.0],
435-
[9.0, 10.0, 11.0, 12.0],
436-
], dtype="float32")
437-
438-
k_pe = paddle.to_tensor([
439-
[0.1, 0.2],
440-
[0.3, 0.4],
441-
[0.5, 0.6],
442-
], dtype="float32")
429+
compressed_kv = paddle.to_tensor(
430+
[
431+
[1.0, 2.0, 3.0, 4.0],
432+
[5.0, 6.0, 7.0, 8.0],
433+
[9.0, 10.0, 11.0, 12.0],
434+
],
435+
dtype="float32",
436+
)
437+
438+
k_pe = paddle.to_tensor(
439+
[
440+
[0.1, 0.2],
441+
[0.3, 0.4],
442+
[0.5, 0.6],
443+
],
444+
dtype="float32",
445+
)
443446

444447
latent_cache = paddle.zeros([num_blocks, 1, block_size, latent_dim], dtype="float32")
445448

0 commit comments

Comments
 (0)