Skip to content

Commit d73b302

Browse files
Merge pull request #813 from ROCm/zhiwei/fp4_acc
[acc fix] Change the environment set of DeepSeek-R1 FP4 scripts and port the Yuhua's fix
2 parents c49a21b + 82080ea commit d73b302

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

evaluation/deepseek_fp4/launch_deepseekr1_fp4_TP.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
export VLLM_USE_V1=1
2-
export VLLM_USE_TRITON_FLASH_ATTN=0
2+
export VLLM_USE_TRITON_FLASH_ATTN=1 # use triton mha
33
# export VLLM_LOGGING_LEVEL=DEBUG
44
export VLLM_RPC_TIMEOUT=1800000
55
export VLLM_ROCM_USE_AITER=1
66
export VLLM_ROCM_USE_AITER_MHA=0
7-
export VLLM_ROCM_USE_AITER_MLA=1
7+
export VLLM_ROCM_USE_AITER_MLA=0 # use triton mha
88
export VLLM_ROCM_USE_AITER_MOE=1
99
export VLLM_ROCM_USE_TRITON_ROPE=1 # add for acc
1010
export VLLM_DISABLE_COMPILE_CACHE=1
1111
# FIXME: for now disable fp4 asm gemm because of running issue
1212
export VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=0
13-
#export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # for now disable
13+
export VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0 # disable for acc
1414

1515
export TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1
1616
export TRITON_HIP_USE_ASYNC_COPY=1
@@ -37,11 +37,12 @@ vllm serve $model_path \
3737
--trust-remote-code \
3838
--no-enable-prefix-caching \
3939
--disable-log-requests \
40-
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
41-
--gpu_memory_utilization 0.8 \
40+
--enforce-eager \
41+
--gpu_memory_utilization 0.7 \
4242
--async-scheduling \
43+
--block-size 16 \
4344
--load-format fastsafetensors \
4445
--seed 123 2>&1 | tee log.server.log &
4546

46-
# --enforce-eager \
47+
# --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
4748
# --enable-expert-parallel \

vllm/v1/attention/backends/mla/triton_mla.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,21 @@ def _flash_attn_varlen_diff_headdims(
129129
q, k, v, softmax_scale=softmax_scale, **kwargs
130130
)
131131
else:
132-
return super()._flash_attn_varlen_diff_headdims(
133-
q,
134-
k,
135-
v,
136-
return_softmax_lse=return_softmax_lse,
132+
from aiter.ops.triton.mha import flash_attn_varlen_func
133+
134+
result = flash_attn_varlen_func(
135+
q=q,
136+
k=k,
137+
v=v,
138+
return_lse=return_softmax_lse,
137139
softmax_scale=softmax_scale,
138140
**kwargs,
139141
)
142+
if type(result) is tuple and return_softmax_lse:
143+
output, lse = result
144+
lse = lse.T.contiguous()
145+
return (output, lse)
146+
return result
140147

141148
def _forward_decode(
142149
self,

0 commit comments

Comments
 (0)