Skip to content

Commit fa5436d

Browse files
committed
update forward
1 parent 15a153a commit fa5436d

3 files changed

Lines changed: 32 additions & 43 deletions

File tree

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -362,18 +362,6 @@ def forward(
362362
fused_read_cache_and_interleave,
363363
)
364364

365-
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
366-
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
367-
368-
# Idle pass (e.g. CUDAGraph padding): skip all attention computation
369-
if not need_do_prefill and not need_do_decode:
370-
return self.o_proj(
371-
paddle.zeros(
372-
[hidden_states.shape[0], self.num_attention_heads_tp * self.v_head_dim],
373-
dtype=hidden_states.dtype,
374-
)
375-
)
376-
377365
attn_out = None
378366
if self.use_gated_attn:
379367
gate_out = self.gate(hidden_states)
@@ -1070,6 +1058,12 @@ def forward(
10701058
residual: paddle.Tensor,
10711059
):
10721060
""" """
1061+
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
1062+
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
1063+
1064+
if not need_do_prefill and not need_do_decode:
1065+
return hidden_states
1066+
10731067
if hidden_states.shape[0] > 0:
10741068
hidden_states, residual = self.input_layernorm(
10751069
hidden_states, residual_input=residual, forward_meta=forward_meta

tests/deterministic/test_triton_decode_attention.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def cosine_similarity(a, b):
6363
# ---------------------------------------------------------------------------
6464
# Reference implementation: naive decode attention (no paging)
6565
# ---------------------------------------------------------------------------
66-
def naive_decode_attention_ref(q, k_pages, v_pages, kv_indptr, kv_indices,
67-
sm_scale, kv_block_size):
66+
def naive_decode_attention_ref(q, k_pages, v_pages, kv_indptr, kv_indices, sm_scale, kv_block_size):
6867
"""
6968
Naive Python reference for decode attention with paged KV cache.
7069
@@ -154,7 +153,6 @@ def build_decode_test_data(
154153
np.random.seed(seed)
155154
paddle.seed(seed)
156155

157-
total_kv_len = sum(seq_lens)
158156
num_blocks_needed = sum((s + block_size - 1) // block_size for s in seq_lens)
159157
num_blocks = max(num_blocks_needed + 4, 8)
160158

@@ -194,7 +192,7 @@ def build_decode_test_data(
194192
attn_lse = paddle.empty([batch_size, num_heads, max_kv_splits], dtype="float32")
195193
o = paddle.empty([batch_size, num_heads, Lv], dtype=dtype)
196194

197-
sm_scale = head_dim_k ** -0.5
195+
sm_scale = head_dim_k**-0.5
198196

199197
return {
200198
"q": q,
@@ -290,8 +288,7 @@ def test_empty(self):
290288
ids=[c[0] for c in _DECODE_CASES],
291289
)
292290
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
293-
def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv,
294-
seq_lens, block_size, dtype):
291+
def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv, seq_lens, block_size, dtype):
295292
"""Triton decode attention output should match naive reference."""
296293
data = build_decode_test_data(
297294
batch_size=batch,
@@ -336,12 +333,10 @@ def test_decode_attention_correctness(name, batch, num_heads, kv_heads, Lk, Lv,
336333
cos_sim = cosine_similarity(triton_out, ref_out)
337334

338335
atol = BF16_ATOL if dtype == "bfloat16" else FP16_ATOL
339-
assert max_diff < atol, (
340-
f"[{name}/{dtype}] max_diff={max_diff:.6f} exceeds atol={atol}"
341-
)
342-
assert cos_sim > COSINE_SIM_THRESHOLD, (
343-
f"[{name}/{dtype}] cos_sim={cos_sim:.6f} below threshold={COSINE_SIM_THRESHOLD}"
344-
)
336+
assert max_diff < atol, f"[{name}/{dtype}] max_diff={max_diff:.6f} exceeds atol={atol}"
337+
assert (
338+
cos_sim > COSINE_SIM_THRESHOLD
339+
), f"[{name}/{dtype}] cos_sim={cos_sim:.6f} below threshold={COSINE_SIM_THRESHOLD}"
345340

346341

347342
# ===========================================================================
@@ -380,10 +375,7 @@ def test_decode_attention_determinism():
380375
results.append(o.astype("float32").numpy())
381376

382377
for i in range(1, len(results)):
383-
np.testing.assert_array_equal(
384-
results[0], results[i],
385-
err_msg=f"Run 0 vs run {i} differ — non-deterministic!"
386-
)
378+
np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!")
387379

388380

389381
# ===========================================================================

tests/deterministic/test_triton_mla_cache_kernel.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,7 @@ def test_write_cache_determinism():
411411
results.append(cache.astype("float32").numpy())
412412

413413
for i in range(1, len(results)):
414-
np.testing.assert_array_equal(
415-
results[0], results[i],
416-
err_msg=f"Run 0 vs run {i} differ — non-deterministic!"
417-
)
414+
np.testing.assert_array_equal(results[0], results[i], err_msg=f"Run 0 vs run {i} differ — non-deterministic!")
418415

419416

420417
# ===========================================================================
@@ -429,17 +426,23 @@ def test_manual_baseline():
429426
latent_dim = kv_lora_rank + qk_rope_head_dim # 6
430427

431428
# 3 tokens, deterministic values
432-
compressed_kv = paddle.to_tensor([
433-
[1.0, 2.0, 3.0, 4.0],
434-
[5.0, 6.0, 7.0, 8.0],
435-
[9.0, 10.0, 11.0, 12.0],
436-
], dtype="float32")
437-
438-
k_pe = paddle.to_tensor([
439-
[0.1, 0.2],
440-
[0.3, 0.4],
441-
[0.5, 0.6],
442-
], dtype="float32")
429+
compressed_kv = paddle.to_tensor(
430+
[
431+
[1.0, 2.0, 3.0, 4.0],
432+
[5.0, 6.0, 7.0, 8.0],
433+
[9.0, 10.0, 11.0, 12.0],
434+
],
435+
dtype="float32",
436+
)
437+
438+
k_pe = paddle.to_tensor(
439+
[
440+
[0.1, 0.2],
441+
[0.3, 0.4],
442+
[0.5, 0.6],
443+
],
444+
dtype="float32",
445+
)
443446

444447
latent_cache = paddle.zeros([num_blocks, 1, block_size, latent_dim], dtype="float32")
445448

0 commit comments

Comments
 (0)