Skip to content

Commit 522d12c

Browse files
add deepep precision test (#6984)
1 parent 5780345 commit 522d12c

2 files changed

Lines changed: 189 additions & 0 deletions

File tree

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import unittest
2+
3+
import paddle
4+
import paddle.distributed as dist
5+
import paddle.distributed.communication.deep_ep as deep_ep
6+
from paddle.distributed import fleet
7+
8+
9+
class TestFusedMoE(unittest.TestCase):
10+
def setUp(self) -> None:
11+
pass
12+
13+
def test_fused_moe(self):
14+
num_ranks = dist.get_world_size()
15+
if num_ranks <= 1:
16+
return
17+
rank_id = dist.get_rank()
18+
paddle.seed(rank_id + 100)
19+
20+
strategy = fleet.DistributedStrategy()
21+
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": num_ranks, "pp_degree": 1}
22+
fleet.init(is_collective=True, strategy=strategy)
23+
24+
num_tokens, hidden, num_topk, num_experts = 64, 7168, 4, 64
25+
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
26+
27+
ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
28+
buffer = deep_ep.Buffer(
29+
ep_group,
30+
num_nvl_bytes=0,
31+
num_rdma_bytes=num_rdma_bytes,
32+
low_latency_mode=True,
33+
num_qps_per_rank=num_experts // num_ranks,
34+
)
35+
36+
x = paddle.randn(shape=[num_tokens, hidden], dtype="bfloat16")
37+
scores = paddle.randn([num_tokens, num_experts], dtype="float32").abs() + 1
38+
topk_info = paddle.topk(scores, num_topk, axis=-1, largest=True, sorted=False)
39+
topk_weight = topk_info[0]
40+
topk_idx = topk_info[1]
41+
42+
gather_x = []
43+
dist.all_gather(gather_x, x, ep_group)
44+
gather_x = paddle.stack(gather_x, axis=0)
45+
46+
gather_topk_idx = []
47+
dist.all_gather(gather_topk_idx, topk_idx, ep_group)
48+
gather_topk_idx = paddle.concat(gather_topk_idx, axis=0)
49+
50+
handle = None
51+
52+
num_tests = 10
53+
54+
for _ in range(num_tests):
55+
56+
dispatch_use_fp8 = False
57+
packed_recv_x, packed_recv_count, handle, event, hook = buffer.low_latency_dispatch(
58+
x,
59+
topk_idx,
60+
None, # expertwise_scale, used in w4a8.
61+
num_tokens,
62+
num_experts,
63+
use_fp8=dispatch_use_fp8,
64+
async_finish=False,
65+
return_recv_hook=True,
66+
)
67+
68+
if hook is not None:
69+
hook()
70+
if dispatch_use_fp8:
71+
fp8, scale = packed_recv_x[0], packed_recv_x[1]
72+
fp32 = fp8.cast("float32").reshape([0, 0, hidden // 128, 128])
73+
scale = scale.transpose([0, 2, 1]).reshape([0, 0, hidden // 128, 1])
74+
fp32 = fp32 * scale
75+
fp32 = fp32.reshape([0, 0, -1])
76+
77+
combined_hidden_states, _, _ = buffer.low_latency_combine(
78+
packed_recv_x,
79+
topk_idx,
80+
topk_weight,
81+
handle,
82+
zero_copy=False,
83+
async_finish=False,
84+
return_recv_hook=False,
85+
)
86+
87+
num_local_experts = num_experts // num_ranks
88+
start_ep_id = rank_id * num_local_experts
89+
end_ep_id = start_ep_id + num_local_experts
90+
91+
num_tokens_send_by_rdma = 0
92+
for token_id in range(topk_idx.shape[0]):
93+
for dst_expert_id in topk_idx[token_id].numpy().tolist():
94+
if dst_expert_id not in range(start_ep_id, end_ep_id):
95+
num_tokens_send_by_rdma += 1
96+
print("num_tokens_send_by_rdma:", num_tokens_send_by_rdma)
97+
98+
(recv_src_info, recv_layout_range, _, _) = handle
99+
100+
for ep_id in range(start_ep_id, end_ep_id):
101+
local_ep_id = ep_id - start_ep_id
102+
token_num_this_ep = packed_recv_count[local_ep_id].item()
103+
token_nums_per_rank = []
104+
begin_idx_per_rank = []
105+
for rank_id in range(num_ranks):
106+
tmp = recv_layout_range[local_ep_id, rank_id].item()
107+
begin_idx_per_rank.append(tmp >> 32)
108+
token_nums_per_rank.append(tmp & ((1 << 32) - 1))
109+
assert token_num_this_ep == sum(token_nums_per_rank)
110+
111+
for rank_id in range(num_ranks):
112+
begin_idx = begin_idx_per_rank[rank_id]
113+
end_idx = begin_idx + token_nums_per_rank[rank_id]
114+
for token_id in range(begin_idx, end_idx):
115+
token = packed_recv_x[local_ep_id, token_id, :]
116+
# 这个token来自rank_id,并且是他的第多少个token呢?
117+
src_token_id = recv_src_info[local_ep_id, token_id].item()
118+
src_token = gather_x[rank_id, src_token_id, :]
119+
# print(token - src_token)
120+
assert (src_token - token).abs().max().item() == 0
121+
122+
123+
if __name__ == "__main__":
124+
unittest.main()
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import subprocess
17+
import sys
18+
19+
20+
def test_launch():
21+
"""
22+
test_fused_moe
23+
"""
24+
current_dir = os.path.dirname(os.path.abspath(__file__))
25+
py_script = os.path.join(current_dir, "./test_hopper_ll_precision.py")
26+
27+
# 为了方便在PDC的环境下直接python运行这个脚本
28+
os.environ.pop("PADDLE_ELASTIC_JOB_ID", None)
29+
os.environ.pop("PADDLE_TRAINER_ENDPOINTS", None)
30+
os.environ.pop("DISTRIBUTED_TRAINER_ENDPOINTS", None)
31+
os.environ.pop("FLAGS_START_PORT", None)
32+
os.environ.pop("PADDLE_ELASTIC_TIMEOUT", None)
33+
34+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
35+
36+
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
37+
command = [
38+
sys.executable,
39+
"-m",
40+
"paddle.distributed.launch",
41+
"--gpus",
42+
"0,1",
43+
"--master",
44+
f"127.0.0.1:{FD_API_PORT}",
45+
"--nnodes",
46+
"1",
47+
"--rank",
48+
"0",
49+
py_script,
50+
]
51+
52+
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
53+
54+
try:
55+
stdout, stderr = process.communicate(timeout=400)
56+
return_code = process.returncode
57+
except subprocess.TimeoutExpired:
58+
process.kill()
59+
stdout, stderr = process.communicate()
60+
return_code = -1
61+
print(f"std_out: {stdout}")
62+
assert return_code in (0, 250, 255), f"Process exited with code {return_code}, stdout: {stdout}, stderr: {stderr}"
63+
64+
65+
test_launch()

0 commit comments

Comments
 (0)