
Commit b07a123

support moe model deterministic
1 parent a9135f8 commit b07a123

8 files changed

Lines changed: 153 additions & 63 deletions


xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,7 @@
     update_rollout_item,
 )
 from xtuner.v1.ray.environment.base_env import BaseEnvironment
+from xtuner.v1.ray.utils import build_deterministic_session_id, deterministic_item_sort_key
 from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, ray_method


@@ -93,14 +94,19 @@ async def generate(  # type: ignore[override]
         if extra_params is None:
             extra_params = {}
         if self.rollout_controller:
+            if XTUNER_DETERMINISTIC:
+                group_data_items = sorted(group_data_items, key=deterministic_item_sort_key)
             response_future = []
             for i, sample in enumerate(group_data_items):
                 rollout_extra_info = dict(sample.data.extra_info)
                 rollout_extra_info["root_id"] = sample.uid.root_id
                 rollout_extra_info["action_id"] = sample.uid.action_id
+                rollout_extra_info["observation_id"] = sample.uid.observation_id
                 update_sample_params = sample_params
+                session_id = None
                 if XTUNER_DETERMINISTIC:
                     update_sample_params.sampling_seed = self.rollout_cfg.random_seed + i
+                    session_id = build_deterministic_session_id(self.environment, sample)

                 if "partial_rollout_input_ids" in sample.env.rollout.extra_info:
                     input_ids_length = len(sample.data.input_ids) if sample.data.input_ids is not None else 0
@@ -125,6 +131,7 @@ async def generate(  # type: ignore[override]
                     input_ids=sample.data.input_ids,
                     sample_params=update_sample_params,
                     extra_params=extra_params,
+                    session_id=session_id,
                     extra_info=rollout_extra_info,
                 )
                 del rollout_extra_info
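
With XTUNER_DETERMINISTIC set, generate() seeds each sample with random_seed + i, which is only reproducible if the enumeration order is itself reproducible; sorting by deterministic_item_sort_key pins that order to the sample identity rather than to arrival order. A minimal sketch of the idea, using hypothetical stand-ins for xtuner's real rollout data items:

```python
# Order-independent seeding sketch; Uid and Item are hypothetical stand-ins
# for xtuner's real rollout data items.
from dataclasses import dataclass


@dataclass(frozen=True)
class Uid:
    root_id: int
    action_id: int
    observation_id: int
    version: int


@dataclass
class Item:
    uid: Uid


def sort_key(item: Item) -> tuple[int, int, int, int]:
    return (item.uid.root_id, item.uid.action_id, item.uid.observation_id, item.uid.version)


def assign_seeds(items: list[Item], random_seed: int) -> dict[Uid, int]:
    # After sorting, index i is a pure function of the uid tuple, so every
    # item gets the same seed no matter how the batch was ordered on arrival.
    return {s.uid: random_seed + i for i, s in enumerate(sorted(items, key=sort_key))}


a = [Item(Uid(0, 1, 0, 0)), Item(Uid(0, 0, 0, 0))]
assert assign_seeds(a, 1234) == assign_seeds(list(reversed(a)), 1234)
```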

xtuner/v1/ray/rollout/controller.py

Lines changed: 11 additions & 6 deletions
@@ -5,7 +5,6 @@
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
-from itertools import cycle
 from typing import Any, Dict, List, Optional, Union
 from uuid import uuid4

@@ -52,7 +51,6 @@ def __init__(

         # OrderedDict: key=session_id -> value=(worker, last_used_ts)
         self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict()
-        self._worker_cycler = cycle(self._workers)
         self._lock = asyncio.Lock()
         self.logger = get_logger()

@@ -80,7 +78,16 @@ def _evict_lru_to_capacity(self):
     def update_active_workers(self, worker_status: Dict[Any, bool]):
         self._workers = list(worker_status.items())
         self.logger.debug(f"SessionRouter update active workers: {self._workers}")
-        self._worker_cycler = cycle(self._workers)
+
+    def _get_healthy_workers(self) -> List[tuple[Any, bool]]:
+        return [worker for worker in self._workers if worker[1]]
+
+    def _select_worker_for_session(self, session_id: int) -> tuple[Any, bool]:
+        healthy_workers = self._get_healthy_workers()
+        if not healthy_workers:
+            raise RuntimeError("No healthy rollout workers available for SessionRouter.")
+        worker_idx = session_id % len(healthy_workers)
+        return healthy_workers[worker_idx]

     async def get_worker(self, session_id: int) -> Any:
         async with self._lock:
@@ -92,9 +99,7 @@ async def get_worker(self, session_id: int) -> Any:
             if worker[1]:  # worker is healthy
                 return worker[0]

-            worker = next(self._worker_cycler)
-            while worker[1] is False:
-                worker = next(self._worker_cycler)
+            worker = self._select_worker_for_session(session_id)
             self._map[session_id] = (worker, self._now())

             self._evict_lru_to_capacity()
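
The itertools.cycle this replaces carried hidden iterator state, so which worker a new session landed on depended on how many sessions had been created before it. Modulo selection is stateless: as long as the healthy-worker list is the same across runs, the same session_id always maps to the same worker. The tradeoff is that the mapping shifts when the healthy set changes, which matters less here than run-to-run reproducibility. The same logic in isolation:

```python
# Stateless session -> worker routing, mirroring _select_worker_for_session.
# Worker entries are (handle, healthy) pairs; handles here are just strings.
def select_worker(workers: list[tuple[str, bool]], session_id: int) -> str:
    healthy = [w for w in workers if w[1]]
    if not healthy:
        raise RuntimeError("No healthy rollout workers available.")
    return healthy[session_id % len(healthy)][0]


workers = [("w0", True), ("w1", False), ("w2", True)]
# Depends only on session_id and the healthy set, not on prior lookups.
assert select_worker(workers, 7) == select_worker(workers, 7) == "w2"
```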

xtuner/v1/ray/rollout/lmdeploy.py

Lines changed: 9 additions & 0 deletions
@@ -6,6 +6,7 @@

 import ray
 import requests
+import torch
 from ray.util.placement_group import placement_group_table

 from transformers import AutoTokenizer
@@ -208,6 +209,14 @@ def reset_prefix_cache(self):
         """It will be implemented for the LMDeploy worker in the future."""
         pass

+    def _decode_routed_experts(self, routed_experts: Any, meta_info: Dict[str, Any]):
+        if isinstance(routed_experts, str):
+            import base64
+
+            data = base64.b64decode(routed_experts)
+            return ray.cloudpickle.loads(data)
+        return torch.tensor(routed_experts)
+
     def _transform_rollout_config_to_server_configs(self) -> Namespace:
         """Transform the RolloutConfig into a Namespace suitable for the
         LMDeploy server.
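
The string branch assumes the server shipped the routed-experts tensor as a base64-wrapped cloudpickle payload; the fallback assumes a plain nested list of expert ids. A hypothetical round trip for the string path (the encode side is an assumption about the producer, not code from this commit):

```python
# Hypothetical producer/consumer round trip for the base64 + cloudpickle path.
import base64

import ray
import torch

original = torch.randint(0, 8, (5, 4, 2))  # [token, layer, expert]
payload = base64.b64encode(ray.cloudpickle.dumps(original)).decode("ascii")

decoded = ray.cloudpickle.loads(base64.b64decode(payload))
assert torch.equal(decoded, original)
```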

xtuner/v1/ray/rollout/sglang.py

Lines changed: 91 additions & 37 deletions
@@ -1,10 +1,13 @@
+import base64
 import os
 from typing import Any, Dict, List, Union

+import numpy as np
 import requests
+import torch
 from urllib3.exceptions import NewConnectionError

-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer
 from xtuner.v1.ray.config import RolloutConfig
 from xtuner.v1.utils import XTUNER_DETERMINISTIC

@@ -29,6 +32,11 @@ def __init__(
         self.endpoints["generate"] = "generate"
         self.endpoints["v1/chat/completions"] = "v1/chat/completions"
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True)
+        self.model_config = AutoConfig.from_pretrained(self.config.model_path, trust_remote_code=True)
+        text_config = getattr(self.model_config, "text_config", self.model_config)
+        self.model_type = getattr(text_config, "model_type", getattr(self.model_config, "model_type", None))
+        self.routed_experts_num_hidden_layers = getattr(text_config, "num_hidden_layers", None)
+        self.routed_experts_num_experts_per_tok = getattr(text_config, "num_experts_per_tok", None)
         self.api_keys = self.config.api_key
         self.model_name = self.config.model_name
         self.enable_return_routed_experts = self.config.enable_return_routed_experts
@@ -141,6 +149,37 @@ def reset_prefix_cache(self):
         self.flush_cache()
         return self._make_request("release_memory_occupation")

+    def _decode_routed_experts(self, routed_experts: Any, meta_info: Dict[str, Any]):
+        if not isinstance(routed_experts, str):
+            return super()._decode_routed_experts(routed_experts, meta_info)
+
+        prompt_tokens = meta_info.get("prompt_tokens", 0)
+        completion_tokens = meta_info.get("completion_tokens", 0)
+        num_tokens = prompt_tokens + completion_tokens - 1
+        assert num_tokens > 0, (
+            f"Unexpected routed_experts token count: prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}"
+        )
+        assert self.routed_experts_num_hidden_layers is not None, (
+            "num_hidden_layers is required to decode routed_experts"
+        )
+        assert self.routed_experts_num_experts_per_tok is not None, (
+            "num_experts_per_tok is required to decode routed_experts"
+        )
+
+        routed_experts_flat = np.frombuffer(base64.b64decode(routed_experts), dtype=np.int32)
+        expected_size = num_tokens * self.routed_experts_num_hidden_layers * self.routed_experts_num_experts_per_tok
+        assert routed_experts_flat.size == expected_size, (
+            f"Unexpected routed_experts size {routed_experts_flat.size}, expected {expected_size}. "
+            f"num_tokens={num_tokens}, num_hidden_layers={self.routed_experts_num_hidden_layers}, "
+            f"num_experts_per_tok={self.routed_experts_num_experts_per_tok}"
+        )
+        routed_experts_array = routed_experts_flat.reshape(
+            num_tokens,
+            self.routed_experts_num_hidden_layers,
+            self.routed_experts_num_experts_per_tok,
+        )
+        return torch.from_numpy(routed_experts_array.copy())
+
     def _transform_rollout_config_to_server_configs(self):
         # remove the CUDA_VISIBLE_DEVICES set by ray and use base_gpu_id
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
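
The decoder above treats the payload as a flat int32 buffer of num_tokens * num_hidden_layers * num_experts_per_tok entries, where num_tokens is prompt_tokens + completion_tokens - 1 (the server reports routing for every token except one, per its own convention). A simulated round trip under those assumptions, with the encode side standing in for what the SGLang server is assumed to send:

```python
# Simulated round trip for the flat int32 routed-experts payload.
import base64

import numpy as np
import torch

num_tokens, num_layers, experts_per_tok = 6, 4, 2
experts = np.random.randint(0, 64, size=(num_tokens, num_layers, experts_per_tok)).astype(np.int32)
payload = base64.b64encode(experts.tobytes()).decode("ascii")  # assumed server-side encoding

flat = np.frombuffer(base64.b64decode(payload), dtype=np.int32)
assert flat.size == num_tokens * num_layers * experts_per_tok
# .copy() matters: np.frombuffer returns a read-only view, and
# torch.from_numpy would otherwise warn and share that buffer.
decoded = torch.from_numpy(flat.reshape(num_tokens, num_layers, experts_per_tok).copy())
assert torch.equal(decoded, torch.from_numpy(experts))
```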
@@ -150,55 +189,70 @@ def _transform_rollout_config_to_server_configs(self):
         sglang_config_kwargs = {
             k.replace("sglang_", ""): v for k, v in extra_config.items() if k.startswith("sglang_")
         }
-        grammar_backend = sglang_config_kwargs.get(
-            "grammar_backend", None
-        )  # for intern-s1 series models, the grammar_backend has to be set to "none"
         log_level = sglang_config_kwargs.get("log_level", "error")
         log_level_http = sglang_config_kwargs.get("log_level_http", "error")
-        sglang_server_args = ServerArgs(model_path=self.config.model_path, trust_remote_code=True)
         num_gpus_per_engine = (
             self.config.expert_parallel_size
             if self.config.expert_parallel_size > 1
             else self.config.tensor_parallel_size
         )
-        sglang_server_args.host = self.host
-        sglang_server_args.port = self.server_port
-        sglang_server_args.nccl_port = self.nccl_port
-        sglang_server_args.dist_init_addr = self.dist_init_addr
-        sglang_server_args.base_gpu_id = self.rank % self.config.gpus_per_node
-        sglang_server_args.gpu_id_step = 1
-        sglang_server_args.nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node)
-        sglang_server_args.skip_server_warmup = True
-        sglang_server_args.mem_fraction_static = self.config.gpu_memory_utilization
-        # note: not needed when engines have dedicated GPUs; in colocated (shared-GPU) mode
-        # it must be set so that offload can actually release GPU memory
-        sglang_server_args.enable_memory_saver = True
-
+        tp_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.tensor_parallel_size
+        ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size
+        nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node)
+        node_rank = self.rank // self.config.gpus_per_node if nnodes > 1 else 0
+        init_kwargs = dict(
+            model_path=self.config.model_path,
+            trust_remote_code=True,
+            host=self.host,
+            port=self.server_port,
+            nccl_port=self.nccl_port,
+            dist_init_addr=self.dist_init_addr,
+            base_gpu_id=self.rank % self.config.gpus_per_node,
+            gpu_id_step=1,
+            nnodes=nnodes,
+            node_rank=node_rank,
+            skip_server_warmup=True,
+            mem_fraction_static=self.config.gpu_memory_utilization,
+            enable_memory_saver=True,
+            max_running_requests=self.config.rollout_max_batch_size_per_instance,
+            log_level=log_level,
+            log_level_http=log_level_http,
+            tp_size=tp_size,
+            ep_size=ep_size,
+        )
         if self.enable_return_routed_experts:
-            sglang_server_args.enable_return_routed_experts = True
-
-        sglang_server_args.max_running_requests = self.config.rollout_max_batch_size_per_instance
-        sglang_server_args.log_level = log_level
-        sglang_server_args.log_level_http = log_level_http
+            init_kwargs["enable_return_routed_experts"] = True
         if XTUNER_DETERMINISTIC:
-            sglang_server_args.enable_deterministic_inference = True
-            sglang_server_args.rl_on_policy_target = True
-            if self.config.expert_parallel_size > 1:
-                sglang_server_args.tp_size = num_gpus_per_engine
-                sglang_server_args.ep_size = num_gpus_per_engine
-            else:
-                sglang_server_args.tp_size = self.config.tensor_parallel_size
-                sglang_server_args.ep_size = self.config.expert_parallel_size
+            init_kwargs["enable_deterministic_inference"] = True
+            init_kwargs["rl_on_policy_target"] = "fsdp"
+            init_kwargs["attention_backend"] = "fa3"
+            init_kwargs["random_seed"] = self.config.random_seed
+            # SGLang's deterministic mode does not currently force-disable every
+            # performance-oriented runtime path. For long MoE rollouts we still
+            # observed rare trajectory divergence, so explicitly turn off the
+            # scheduler/cache/graph features that can perturb execution order.
+            init_kwargs["disable_radix_cache"] = True
+            init_kwargs["disable_overlap_schedule"] = True
+            init_kwargs["disable_cuda_graph"] = True

-        if grammar_backend is not None:
-            sglang_server_args.grammar_backend = grammar_backend
+        # Forward supported sglang_* extra configs to ServerArgs directly.
+        server_arg_fields = getattr(ServerArgs, "__dataclass_fields__", {})
+        for key, value in sglang_config_kwargs.items():
+            if key in server_arg_fields:
+                init_kwargs[key] = value
+            else:
+                self.logger.warning(f"Ignore unknown SGLang server arg: {key}={value!r}")
+
+        # Qwen3-MoE on sglang 0.5.9 can hit a native rotary + fused KV buffer
+        # incompatibility during server startup unless fused qk_norm_rope is enabled.
+        if self.model_type == "qwen3_moe" and "enable_fused_qk_norm_rope" not in sglang_config_kwargs:
+            init_kwargs["enable_fused_qk_norm_rope"] = True
+            self.logger.info("Auto enable SGLang enable_fused_qk_norm_rope for qwen3_moe.")

         if self.config.context_length is not None:
-            sglang_server_args.context_length = self.config.context_length
+            init_kwargs["context_length"] = self.config.context_length

-        if sglang_server_args.nnodes > 1:
-            sglang_server_args.node_rank = self.rank // self.config.gpus_per_node
-        else:
-            sglang_server_args.node_rank = 0
+        sglang_server_args = ServerArgs(**init_kwargs)

         return sglang_server_args
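
Building the arguments in a plain dict and constructing ServerArgs once replaces the old attribute-poking style, and checking keys against ServerArgs.__dataclass_fields__ lets unknown sglang_* extras degrade to a warning instead of a TypeError from the dataclass constructor. The same pattern in isolation, with a toy dataclass standing in for ServerArgs:

```python
# The dataclass-field filtering pattern used above; ToyServerArgs is a
# stand-in for sglang's real ServerArgs.
from dataclasses import dataclass


@dataclass
class ToyServerArgs:
    model_path: str
    log_level: str = "error"


extra = {"log_level": "info", "not_a_real_flag": 1}
init_kwargs = {"model_path": "/models/demo"}
fields = getattr(ToyServerArgs, "__dataclass_fields__", {})
for key, value in extra.items():
    if key in fields:
        init_kwargs[key] = value
    else:
        print(f"Ignore unknown server arg: {key}={value!r}")

args = ToyServerArgs(**init_kwargs)  # unknown keys can no longer raise TypeError
assert args.log_level == "info"
```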

xtuner/v1/ray/rollout/vllm.py

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,9 @@ def reset_prefix_cache(self):
         # todo
         pass

+    def _decode_routed_experts(self, routed_experts: Any, meta_info: Dict[str, Any]):
+        raise NotImplementedError
+
     def _transform_rollout_config_to_server_configs(self) -> Namespace:
         # use vllm FlexibleArgumentParser to parse the config
         # and return the args as the default server config

xtuner/v1/ray/rollout/worker.py

Lines changed: 9 additions & 16 deletions
@@ -13,8 +13,8 @@
 import numpy as np
 import ray
 import requests  # type: ignore[import-untyped]
-import torch
 from packaging.version import Version
+from ray import ObjectRef
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

 from transformers import AutoTokenizer
@@ -153,6 +153,9 @@ def init(self, dist_init_addr: str = ""):
         self.launch_server()
         return (self.rank, self.server_url)

+    def _decode_routed_experts(self, routed_experts: Any, meta_info: Dict[str, Any]) -> Any:
+        return routed_experts
+
     def set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]):
         self.engine_rank_mesh_array = engine_rank_mesh_array

@@ -360,7 +363,7 @@ async def rollout_task(
     ) -> RLRolloutResponseItem:
         uid = extra_info.get("action_id", str(uuid.uuid4()))
         action_id = extra_info.get("action_id", str(uuid.uuid4()))
-        root_id = extra_info.get("action_id", str(uuid.uuid4()))
+        root_id = extra_info.get("root_id", str(uuid.uuid4()))
         response = None
         cur_retry_times = 0

@@ -568,28 +571,18 @@ async def _handle_non_stream_response(
             routed_experts = response["meta_info"].pop("routed_experts")  # token[layer[expert]]
             if routed_experts is not None and not exist_history_routed_experts:
                 # no history routed experts yet; store the current ones first
-                if isinstance(routed_experts, str):
-                    import base64
-
-                    data = base64.b64decode(routed_experts)
-                    routed_experts = ray.cloudpickle.loads(data)
-                    del data
-                else:
-                    routed_experts = torch.tensor(routed_experts)  # n,layer,expert
+                routed_experts = self._decode_routed_experts(routed_experts, response["meta_info"])
+                if not isinstance(routed_experts, ObjectRef):
                     routed_experts = ray.put(routed_experts)
                 extra_info["routed_experts"] = routed_experts
             elif routed_experts is not None and exist_history_routed_experts:
                 # history routed experts exist; skip ray.put and concat directly
-                if isinstance(routed_experts, str):
-                    import base64
-
-                    data = base64.b64decode(routed_experts)
-                    routed_experts = ray.cloudpickle.loads(data)
+                routed_experts = self._decode_routed_experts(routed_experts, response["meta_info"])
+                if isinstance(routed_experts, ObjectRef):
                     cur_routed_experts = await routed_experts  # n,layer,expert
                     ray.internal.free(routed_experts, local_only=False)
                     del data
                 else:
-                    routed_experts = torch.tensor(routed_experts)  # n,layer,expert
                     cur_routed_experts = routed_experts

             history_routed_experts = await input_extra_info["routed_experts"]  # n, layer, expert

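With decoding dispatched through _decode_routed_experts, this handler only distinguishes "already an ObjectRef" from "an in-process tensor": fresh tensors are ray.put once, while on the partial-rollout path the history ref is awaited and the two expert traces are concatenated along the token axis. A sketch of that concat step, assuming both tensors are [token, layer, expert] slices of the same request:

```python
# History + current concat on the partial-rollout path (assumed shapes).
import torch

history = torch.randint(0, 8, (3, 4, 2))  # experts for the first 3 tokens
current = torch.randint(0, 8, (5, 4, 2))  # experts for the next 5 tokens
merged = torch.cat([history, current], dim=0)
assert merged.shape == (8, 4, 2)
```
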
xtuner/v1/ray/utils.py

Lines changed: 17 additions & 1 deletion
@@ -1,4 +1,5 @@
 import asyncio
+import hashlib
 import importlib
 import socket
 from asyncio import AbstractEventLoop, Task
@@ -71,7 +72,9 @@ def _is_port_available(check_socket: socket.socket, port: int) -> bool:


 @ray.remote
-def find_master_addr_and_port(nums=1, start_port=None, end_port=None):
+def find_master_addr_and_port(
+    nums: int = 1, start_port: Optional[int] = None, end_port: Optional[int] = None
+) -> tuple[str, int] | tuple[str, list[int]]:
     """Finds an available master address and a specified number of ports.

     This remote function gets the node's IP address and binds to one or more
@@ -219,3 +222,16 @@ def free_object_refs(refs: List[ObjectRef]) -> None:
             ray._private.internal_api.free(valid_refs, local_only=False)
         except Exception:
             ray.internal.free(valid_refs, local_only=False)
+def deterministic_item_sort_key(sample) -> tuple[int, int, int, int]:
+    return (
+        sample.uid.root_id,
+        sample.uid.action_id,
+        sample.uid.observation_id,
+        sample.uid.version,
+    )
+
+
+def build_deterministic_session_id(environment: str, sample) -> int:
+    session_key = f"{environment}|{sample.uid.root_id}|{sample.uid.action_id}|{sample.uid.observation_id}"
+    session_id = int.from_bytes(hashlib.sha256(session_key.encode("utf-8")).digest()[:8], "big")
+    return session_id or 1
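
build_deterministic_session_id folds the environment name and the sample's uid into a stable 64-bit integer, so the session-to-worker modulo routing in SessionRouter is reproducible across runs and hosts (unlike Python's built-in hash(), which is salted per process). The trailing "or 1" only guards the vanishingly unlikely all-zero digest prefix. The same computation inline:

```python
# Inline check of the hash-based session id (same recipe as above).
import hashlib

session_key = "single_turn|12|3|0"  # environment|root_id|action_id|observation_id (hypothetical values)
digest = hashlib.sha256(session_key.encode("utf-8")).digest()
session_id = int.from_bytes(digest[:8], "big")
assert session_id < 2**64
# Identical inputs give the identical id on any run or machine.
assert session_id == int.from_bytes(hashlib.sha256(session_key.encode()).digest()[:8], "big")
```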
