Skip to content

Commit 466651c

Browse files
blueswhenniushengxiaoshihaobai
authored
feat: deep_ep v2 (#1303)
Co-authored-by: niushengxiao <niushengxiao@sensetime.com> Co-authored-by: shihaobai <1798930569@qq.com>
1 parent 520c041 commit 466651c

36 files changed

Lines changed: 1046 additions & 295 deletions

File tree

docker/Dockerfile

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
ARG CUDA_VERSION=12.8.0
1+
ARG CUDA_VERSION=13.0.0
22
FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
33

44
ARG PYTHON_VERSION=3.10
55
ARG MAMBA_VERSION=24.7.1-0
6-
ARG VLLM_VERSION=0.16.0
6+
ARG VLLM_VERSION=0.21.0
7+
ARG NIXL_REF=v1.1.0
78
ARG FLASH_MLA_REF=47c35a7
9+
ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f
810
ARG TARGETPLATFORM
911
ARG ENABLE_DEEPEP=1
1012
ARG ENABLE_NIXL=1
1113
ARG ENABLE_CACHE=1
14+
ARG ENABLE_SM100=0
1215

1316
ENV PATH=/opt/conda/bin:$PATH \
1417
CONDA_PREFIX=/opt/conda
@@ -44,13 +47,18 @@ WORKDIR /root
4447

4548
COPY ./requirements.txt /lightllm/requirements.txt
4649
RUN pip install -U pip
47-
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
48-
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
49-
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
50+
RUN pip install --no-cache-dir \
51+
--extra-index-url https://download.pytorch.org/whl/cu130 \
52+
vllm==${VLLM_VERSION}
53+
RUN pip install -r /lightllm/requirements.txt --no-cache-dir \
54+
--extra-index-url https://download.pytorch.org/whl/cu130
55+
RUN export CPATH=/usr/local/cuda/targets/x86_64-linux/include/cccl:/usr/local/cuda/targets/x86_64-linux/include${CPATH:+:${CPATH}} && \
56+
git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
5057
cd /root/FlashMLA && \
5158
git checkout ${FLASH_MLA_REF} && \
5259
git submodule update --init --recursive && \
53-
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
60+
FLASH_MLA_DISABLE_SM100="$(if [ "${ENABLE_SM100}" = "1" ]; then echo 0; else echo 1; fi)" \
61+
pip install --no-cache-dir .
5462

5563
RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
5664

@@ -78,27 +86,20 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \
7886
RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \
7987
set -e; \
8088
ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \
81-
NVSHMEM_VERSION=3.3.9; \
82-
CUDA_ARCHS=90; \
83-
wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
84-
&& tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \
85-
&& cd nvshmem \
86-
&& rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
87-
&& NVSHMEM_SHMEM_SUPPORT=0 \
88-
NVSHMEM_UCX_SUPPORT=0 \
89-
NVSHMEM_USE_NCCL=0 \
90-
NVSHMEM_MPI_SUPPORT=0 \
91-
NVSHMEM_IBGDA_SUPPORT=1 \
92-
NVSHMEM_PMIX_SUPPORT=0 \
93-
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
94-
NVSHMEM_USE_GDRCOPY=1 \
95-
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \
96-
&& cmake --build build --target install -j64; \
97-
DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \
98-
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \
99-
cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \
89+
python -m pip install --upgrade --no-deps \
90+
"nvidia-nccl-cu13==2.30.4" \
91+
"nvidia-nvshmem-cu13==3.6.5"; \
92+
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \
93+
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \
94+
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \
95+
pip install --no-build-isolation .; \
10096
fi
10197

98+
RUN cd /root && git clone https://github.com/deepseek-ai/DeepGEMM.git && \
99+
cd DeepGEMM && git checkout ${DEEPGEMM_REF} && \
100+
git submodule update --init --recursive && \
101+
pip install --no-build-isolation .
102+
102103
RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
103104
apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \
104105
DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \
@@ -126,7 +127,7 @@ RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
126127
apt-get update && apt-get install -y pkg-config tmux net-tools && \
127128
cd /usr/local/src; \
128129
pip install --upgrade meson pybind11 patchelf; \
129-
git clone https://github.com/ai-dynamo/nixl.git -b main && \
130+
git clone https://github.com/ai-dynamo/nixl.git -b ${NIXL_REF} && \
130131
cd nixl && \
131132
rm -rf build && \
132133
mkdir build && \

docker/scripts/build.sh

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,23 @@ set -euo pipefail
1818
# --no-nixl Disable NIXL (default: enabled)
1919
# --no-cache Disable cache (default: enabled)
2020
# --lite Disable DEEPEP, NIXL and cache in one shot
21-
# --cuda-version <ver> CUDA version (default: 12.8.0)
21+
# --cuda-version <ver> CUDA version (default: 13.0.0)
2222
# --image-prefix <name> Image prefix (default: lightllm)
2323
# --image-tag <tag> Image tag (default: generated from enabled features)
24+
# --enable-sm100 Enable SM100 support (default: disabled)
2425
# -h / --help Show help
2526

2627
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
2728
cd "${ROOT_DIR}"
2829

2930
IMAGE_PREFIX="${IMAGE_PREFIX:-lightllm}"
30-
CUDA_VERSION="${CUDA_VERSION:-12.8.0}"
31+
CUDA_VERSION="${CUDA_VERSION:-13.0.0}"
3132
IMAGE_TAG="${IMAGE_TAG:-}"
3233

3334
ENABLE_DEEPEP="${ENABLE_DEEPEP:-1}"
3435
ENABLE_NIXL="${ENABLE_NIXL:-1}"
3536
ENABLE_CACHE="${ENABLE_CACHE:-1}"
37+
ENABLE_SM100="${ENABLE_SM100:-0}"
3638

3739
print_help() {
3840
sed -n '1,80p' "$0" | sed 's/^# \{0,1\}//'
@@ -43,6 +45,7 @@ while [[ $# -gt 0 ]]; do
4345
--no-deepep) ENABLE_DEEPEP=0 ;;
4446
--no-nixl) ENABLE_NIXL=0 ;;
4547
--no-cache) ENABLE_CACHE=0 ;;
48+
--enable-sm100) ENABLE_SM100=1 ;;
4649
--lite)
4750
ENABLE_DEEPEP=0
4851
ENABLE_NIXL=0
@@ -78,13 +81,16 @@ done
7881
# - Other combos: composed from enabled feature names
7982
if [[ -z "${IMAGE_TAG}" ]]; then
8083
tag_parts=()
84+
if [[ "${ENABLE_SM100}" -eq 1 ]]; then
85+
tag_parts+=("sm100")
86+
fi
8187
if [[ "${ENABLE_NIXL}" -eq 1 ]]; then
8288
tag_parts+=("nixl")
8389
fi
8490
if [[ "${ENABLE_DEEPEP}" -eq 1 ]]; then
8591
tag_parts+=("deepep")
8692
fi
87-
if [[ "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
93+
if [[ "${ENABLE_SM100}" -eq 0 && "${ENABLE_NIXL}" -eq 1 && "${ENABLE_DEEPEP}" -eq 1 && "${ENABLE_CACHE}" -eq 1 ]]; then
8894
IMAGE_TAG="cuda${CUDA_VERSION}"
8995
else
9096
prefix=""
@@ -100,6 +106,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \
100106
--build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \
101107
--build-arg ENABLE_NIXL="${ENABLE_NIXL}" \
102108
--build-arg ENABLE_CACHE="${ENABLE_CACHE}" \
109+
--build-arg ENABLE_SM100="${ENABLE_SM100}" \
103110
--progress=plain \
104111
-t "${IMAGE_PREFIX}:${IMAGE_TAG}" .
105-

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,14 @@ PD 分离模式参数
464464

465465
示例可以在 test/advanced_config/mixed_quantization/llamacls-mix-down.yaml 中找到。
466466

467+
.. option:: --expert_dtype
468+
469+
EP MoE 专家量化类型,可选值:
470+
471+
* ``fp8``
472+
* ``fp4``,仅支持 SM100 GPU
473+
* ``None`` (默认)
474+
467475
.. option:: --vit_quant_type
468476

469477
ViT 量化方法,可选值:

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,14 @@ Quantization Parameters
465465

466466
Examples can be found in test/advanced_config/mixed_quantization/llamacls-mix-down.yaml.
467467

468+
.. option:: --expert_dtype
469+
470+
Expert quantization dtype for EP MoE, optional values:
471+
472+
* ``fp8``
473+
* ``fp4``: SM100 GPUs only
474+
* ``None`` (default)
475+
468476
.. option:: --vit_quant_type
469477

470478
ViT quantization method, optional values:

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, kvargs):
8585
self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
8686
self.quant_type = kvargs.get("quant_type", "none")
8787
self.quant_cfg_path = kvargs.get("quant_cfg", None)
88+
self.expert_dtype = kvargs.get("expert_dtype", None)
8889
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
8990
self.tp_world_size_ = get_dp_world_size()
9091
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
@@ -156,7 +157,7 @@ def _verify_params(self):
156157
return
157158

158159
def _init_quant(self):
159-
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path)
160+
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path, self.expert_dtype)
160161
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
161162

162163
def _init_weights(self, start_layer_index=0):

lightllm/common/basemodel/layer_infer/cache_tensor_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class BufNode:
3333
inner_tensor: torch.Tensor
3434
shape_key: Tuple[int, torch.dtype]
3535
storage_weak_ptr: int
36+
free_use_count_bias: int = 0
3637
shape_to_tensor: Dict[Union[torch.Size, Iterable[int]], torch.Tensor] = field(default_factory=dict)
3738

3839
def __del__(self):
@@ -99,7 +100,8 @@ def alloc_tensor(
99100
# 回收可能消亡的 tensor
100101
for ptr in self.changed_ptr:
101102
t_buf_node = self.ptr_to_bufnode[ptr]
102-
if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
103+
free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor)
104+
if self.use_count(ptr) <= free_use_count:
103105
self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
104106
self.changed_ptr.clear()
105107

@@ -131,6 +133,7 @@ def alloc_tensor(
131133
self.ptr_to_bufnode[storage_weak_ptr] = buf_node
132134
if shape not in buf_node.shape_to_tensor:
133135
buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
136+
buf_node.free_use_count_bias = self.use_count(storage_weak_ptr) - (1 + len(buf_node.shape_to_tensor))
134137
mark_tensor = buf_node.shape_to_tensor[shape]
135138
ans = mark_tensor.data # 返回一个新的引用, 否则引用计数会无法判断
136139
ans.storage_weak_ptr = buf_node.storage_weak_ptr

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 36 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
from lightllm.distributed import dist_group_manager
55
from lightllm.common.triton_utils.autotuner import Autotuner
66
from lightllm.common.quantization.quantize_method import WeightPack
7-
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
7+
from lightllm.utils.envs_utils import (
8+
get_deepep_num_max_dispatch_tokens_per_rank_prefill,
9+
get_deepep_num_max_dispatch_tokens_per_rank_decode,
10+
)
811
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import (
9-
fused_experts_impl,
12+
fused_experts,
13+
get_ep_num_sms,
1014
masked_group_gemm,
11-
_deepgemm_grouped_fp8_nt_contiguous,
15+
deepgemm_grouped_fp8_nt_contiguous,
16+
quantize_fused_experts_input,
1217
)
1318
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import (
1419
per_token_group_quant_fp8,
@@ -72,23 +77,15 @@ def _fused_experts(
7277
router_logits: Optional[torch.Tensor] = None,
7378
is_prefill: Optional[bool] = None,
7479
):
75-
w13_weight, w13_scale = w13.weight, w13.weight_scale
76-
w2_weight, w2_scale = w2.weight, w2.weight_scale
77-
use_fp8_w8a8 = self.quant_method.method_name != "none"
78-
output = fused_experts_impl(
80+
output = fused_experts(
7981
hidden_states=input_tensor,
80-
w1=w13_weight,
81-
w2=w2_weight,
82+
w13=w13,
83+
w2=w2,
8284
topk_weights=topk_weights,
8385
topk_idx=topk_ids.to(torch.long),
8486
num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy
85-
buffer=dist_group_manager.ep_buffer,
87+
quant_method=self.quant_method,
8688
is_prefill=is_prefill,
87-
use_fp8_w8a8=use_fp8_w8a8,
88-
use_fp8_all2all=use_fp8_w8a8,
89-
use_int8_w8a16=False, # default to False
90-
w1_scale=w13_scale,
91-
w2_scale=w2_scale,
9289
previous_event=None, # for overlap
9390
)
9491
return output
@@ -118,13 +115,13 @@ def low_latency_dispatch(
118115
)
119116

120117
topk_idx = topk_idx.to(torch.long)
121-
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
118+
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode()
122119
use_fp8_w8a8 = self.quant_method.method_name != "none"
123-
recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
124-
hidden_states,
125-
topk_idx,
126-
num_max_dispatch_tokens_per_rank,
127-
self.total_expert_num_contain_redundancy,
120+
recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch(
121+
topk_idx=topk_idx,
122+
x=hidden_states,
123+
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
124+
num_experts=self.total_expert_num_contain_redundancy,
128125
use_fp8=use_fp8_w8a8,
129126
async_finish=False,
130127
return_recv_hook=True,
@@ -155,13 +152,8 @@ def select_experts_and_quant_input(
155152
num_expert_group=n_group,
156153
scoring_func=scoring_func,
157154
)
158-
w13_weight, w13_scale = w13.weight, w13.weight_scale
159-
block_size_k = 0
160-
if w13_weight.ndim == 3:
161-
block_size_k = w13_weight.shape[2] // w13_scale.shape[2]
162-
assert block_size_k == 128, "block_size_k must be 128"
163-
qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype)
164-
return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale)
155+
qinput_tensor = quantize_fused_experts_input(hidden_states, w13, self.quant_method)
156+
return topk_weights, topk_idx.to(torch.long), qinput_tensor
165157

166158
def dispatch(
167159
self,
@@ -171,38 +163,26 @@ def dispatch(
171163
overlap_event: Optional[Any] = None,
172164
):
173165
buffer = dist_group_manager.ep_buffer
174-
# get_dispatch_layout
175-
(
176-
num_tokens_per_rank,
177-
num_tokens_per_rdma_rank,
178-
num_tokens_per_expert,
179-
is_token_in_rank,
180-
previous_event,
181-
) = buffer.get_dispatch_layout(
182-
topk_idx,
183-
self.total_expert_num_contain_redundancy,
184-
previous_event=overlap_event,
185-
async_finish=True,
186-
allocate_on_comm_stream=True,
187-
)
188-
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
166+
num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill()
167+
recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch(
189168
qinput_tensor,
190169
topk_idx=topk_idx,
191170
topk_weights=topk_weights,
192-
num_tokens_per_rank=num_tokens_per_rank,
193-
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
194-
is_token_in_rank=is_token_in_rank,
195-
num_tokens_per_expert=num_tokens_per_expert,
196-
previous_event=previous_event,
197-
async_finish=True,
198-
allocate_on_comm_stream=True,
171+
num_experts=self.total_expert_num_contain_redundancy,
172+
num_max_tokens_per_rank=num_max_tokens_per_rank,
199173
expert_alignment=128,
174+
num_sms=get_ep_num_sms(),
175+
previous_event=overlap_event,
176+
async_with_compute_stream=True,
177+
allocate_on_comm_stream=True,
178+
do_cpu_sync=True,
179+
do_handle_copy=False,
200180
)
201181

202182
def hook():
203183
event.current_stream_wait()
204184

205-
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook
185+
return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook
206186

207187
def masked_group_gemm(
208188
self,
@@ -281,7 +261,7 @@ def prefilled_group_gemm(
281261
# groupgemm (contiguous layout)
282262
gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)
283263

284-
_deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices)
264+
deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices)
285265

286266
# silu_and_mul_fwd + qaunt
287267
# TODO fused kernel
@@ -295,7 +275,7 @@ def prefilled_group_gemm(
295275
# groupgemm (contiguous layout)
296276
gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
297277

298-
_deepgemm_grouped_fp8_nt_contiguous(
278+
deepgemm_grouped_fp8_nt_contiguous(
299279
(qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices
300280
)
301281
# gather and local reduce
@@ -319,7 +299,7 @@ def low_latency_combine(
319299
topk_weights: torch.Tensor,
320300
handle: Any,
321301
):
322-
combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine(
302+
combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine(
323303
gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True
324304
)
325305
return combined_x, hook
@@ -335,8 +315,9 @@ def combine(
335315
gemm_out_b,
336316
handle,
337317
topk_weights=None,
338-
async_finish=True,
318+
num_sms=get_ep_num_sms(),
339319
previous_event=overlap_event,
320+
async_with_compute_stream=True,
340321
allocate_on_comm_stream=True,
341322
)
342323

0 commit comments

Comments
 (0)