1+ import base64
12import os
23from typing import Any , Dict , List , Union
34
5+ import numpy as np
46import requests
7+ import torch
58from urllib3 .exceptions import NewConnectionError
69
7- from transformers import AutoTokenizer
10+ from transformers import AutoConfig , AutoTokenizer
811from xtuner .v1 .ray .config import RolloutConfig
912from xtuner .v1 .utils import XTUNER_DETERMINISTIC
1013
@@ -29,6 +32,11 @@ def __init__(
2932         self .endpoints ["generate" ] = "generate"
3033         self .endpoints ["v1/chat/completions" ] = "v1/chat/completions"
3134         self .tokenizer = AutoTokenizer .from_pretrained (self .config .model_path , trust_remote_code = True )
# NOTE(review): new lines below load the HF model config so routing metadata can
# be decoded later (see _decode_routed_experts). For multimodal checkpoints the
# relevant MoE fields live on the nested `text_config`, hence the getattr chain;
# plain text models fall back to the top-level config object.
35+         self .model_config = AutoConfig .from_pretrained (self .config .model_path , trust_remote_code = True )
36+         text_config = getattr (self .model_config , "text_config" , self .model_config )
37+         self .model_type = getattr (text_config , "model_type" , getattr (self .model_config , "model_type" , None ))
# Both attributes may legitimately stay None for dense (non-MoE) models; the
# decode path validates them before use.
38+         self .routed_experts_num_hidden_layers = getattr (text_config , "num_hidden_layers" , None )
39+         self .routed_experts_num_experts_per_tok = getattr (text_config , "num_experts_per_tok" , None )
3240         self .api_keys = self .config .api_key
3341         self .model_name = self .config .model_name
3442         self .enable_return_routed_experts = self .config .enable_return_routed_experts
@@ -141,6 +149,37 @@ def reset_prefix_cache(self):
141149 self .flush_cache ()
142150 return self ._make_request ("release_memory_occupation" )
143151
152+ def _decode_routed_experts (self , routed_experts : Any , meta_info : Dict [str , Any ]):
153+ if not isinstance (routed_experts , str ):
154+ return super ()._decode_routed_experts (routed_experts , meta_info )
155+
156+ prompt_tokens = meta_info .get ("prompt_tokens" , 0 )
157+ completion_tokens = meta_info .get ("completion_tokens" , 0 )
158+ num_tokens = prompt_tokens + completion_tokens - 1
159+ assert num_tokens > 0 , (
160+ f"Unexpected routed_experts token count: prompt_tokens={ prompt_tokens } , completion_tokens={ completion_tokens } "
161+ )
162+ assert self .routed_experts_num_hidden_layers is not None , (
163+ "num_hidden_layers is required to decode routed_experts"
164+ )
165+ assert self .routed_experts_num_experts_per_tok is not None , (
166+ "num_experts_per_tok is required to decode routed_experts"
167+ )
168+
169+ routed_experts_flat = np .frombuffer (base64 .b64decode (routed_experts ), dtype = np .int32 )
170+ expected_size = num_tokens * self .routed_experts_num_hidden_layers * self .routed_experts_num_experts_per_tok
171+ assert routed_experts_flat .size == expected_size , (
172+ f"Unexpected routed_experts size { routed_experts_flat .size } , expected { expected_size } . "
173+ f"num_tokens={ num_tokens } , num_hidden_layers={ self .routed_experts_num_hidden_layers } , "
174+ f"num_experts_per_tok={ self .routed_experts_num_experts_per_tok } "
175+ )
176+ routed_experts_array = routed_experts_flat .reshape (
177+ num_tokens ,
178+ self .routed_experts_num_hidden_layers ,
179+ self .routed_experts_num_experts_per_tok ,
180+ )
181+ return torch .from_numpy (routed_experts_array .copy ())
182+
144183     def _transform_rollout_config_to_server_configs (self ):
# NOTE(review): this span is a rendered git diff hunk, not plain source. The
# commit refactors ServerArgs construction from attribute-by-attribute mutation
# ("-" lines) into a single kwargs dict passed to the constructor ("+" lines).
# New-file lines 186-188 (where `extra_config` is bound) fall between the two
# hunks and are not visible here.
145184         # remove the CUDA_VISIBLE_DEVICES set by ray and use base_gpu_id
146185         os .environ .pop ("CUDA_VISIBLE_DEVICES" , None )
@@ -150,55 +189,70 @@ def _transform_rollout_config_to_server_configs(self):
# Strip the "sglang_" prefix so extra rollout options map onto ServerArgs names.
150189         sglang_config_kwargs = {
151190             k .replace ("sglang_" , "" ): v for k , v in extra_config .items () if k .startswith ("sglang_" )
152191         }
# The dedicated grammar_backend special case is removed; it is subsumed by the
# generic sglang_* forwarding loop added further down.
153-         grammar_backend = sglang_config_kwargs .get (
154-             "grammar_backend" , None
155-         )  # for intern-s1 series models, have to set the grammar_backend to "none"
156192         log_level = sglang_config_kwargs .get ("log_level" , "error" )
157193         log_level_http = sglang_config_kwargs .get ("log_level_http" , "error" )
158-         sglang_server_args = ServerArgs (model_path = self .config .model_path , trust_remote_code = True )
# EP > 1 implies one engine spans expert_parallel_size GPUs; otherwise TP rules.
159194         num_gpus_per_engine = (
160195             self .config .expert_parallel_size
161196             if self .config .expert_parallel_size > 1
162197             else self .config .tensor_parallel_size
163198         )
164-         sglang_server_args .host = self .host
165-         sglang_server_args .port = self .server_port
166-         sglang_server_args .nccl_port = self .nccl_port
167-         sglang_server_args .dist_init_addr = self .dist_init_addr
168-         sglang_server_args .base_gpu_id = self .rank % self .config .gpus_per_node
169-         sglang_server_args .gpu_id_step = 1
170-         sglang_server_args .nnodes = max (1 , num_gpus_per_engine // self .config .gpus_per_node )
171-         sglang_server_args .skip_server_warmup = True
172-         sglang_server_args .mem_fraction_static = self .config .gpu_memory_utilization
173-         # note: not needed in non-colocated mode; in colocated (shared-GPU) mode it must be set for offload, otherwise GPU memory cannot be released
174-         sglang_server_args .enable_memory_saver = True
175-
# New form: the tp/ep sizes and node layout are computed up front and the whole
# configuration is assembled in one dict. The ternaries mirror the removed
# if/else exactly (EP>1 forces tp_size == ep_size == num_gpus_per_engine).
199+         tp_size = num_gpus_per_engine if self .config .expert_parallel_size > 1 else self .config .tensor_parallel_size
200+         ep_size = num_gpus_per_engine if self .config .expert_parallel_size > 1 else self .config .expert_parallel_size
201+         nnodes = max (1 , num_gpus_per_engine // self .config .gpus_per_node )
202+         node_rank = self .rank // self .config .gpus_per_node if nnodes > 1 else 0
203+         init_kwargs = dict (
204+             model_path = self .config .model_path ,
205+             trust_remote_code = True ,
206+             host = self .host ,
207+             port = self .server_port ,
208+             nccl_port = self .nccl_port ,
209+             dist_init_addr = self .dist_init_addr ,
210+             base_gpu_id = self .rank % self .config .gpus_per_node ,
211+             gpu_id_step = 1 ,
212+             nnodes = nnodes ,
213+             node_rank = node_rank ,
214+             skip_server_warmup = True ,
215+             mem_fraction_static = self .config .gpu_memory_utilization ,
216+             enable_memory_saver = True ,
217+             max_running_requests = self .config .rollout_max_batch_size_per_instance ,
218+             log_level = log_level ,
219+             log_level_http = log_level_http ,
220+             tp_size = tp_size ,
221+             ep_size = ep_size ,
222+         )
176223         if self .enable_return_routed_experts :
177-             sglang_server_args .enable_return_routed_experts = True
178-
179-         sglang_server_args .max_running_requests = self .config .rollout_max_batch_size_per_instance
180-         sglang_server_args .log_level = log_level
181-         sglang_server_args .log_level_http = log_level_http
224+             init_kwargs ["enable_return_routed_experts" ] = True
182225         if XTUNER_DETERMINISTIC :
183-             sglang_server_args .enable_deterministic_inference = True
184-             sglang_server_args .rl_on_policy_target = True
185-             if self .config .expert_parallel_size > 1 :
186-                 sglang_server_args .tp_size = num_gpus_per_engine
187-                 sglang_server_args .ep_size = num_gpus_per_engine
188-             else :
189-                 sglang_server_args .tp_size = self .config .tensor_parallel_size
190-                 sglang_server_args .ep_size = self .config .expert_parallel_size
226+             init_kwargs ["enable_deterministic_inference" ] = True
# NOTE(review): rl_on_policy_target changes from True to the string "fsdp" and
# attention_backend/random_seed are newly pinned here — this is a behavior
# change riding along with the refactor; confirm the server expects these.
227+             init_kwargs ["rl_on_policy_target" ] = "fsdp"
228+             init_kwargs ["attention_backend" ] = "fa3"
229+             init_kwargs ["random_seed" ] = self .config .random_seed
230+             # SGLang's deterministic mode does not currently force-disable every
231+             # performance-oriented runtime path. For long MoE rollouts we still
232+             # observed rare trajectory divergence, so explicitly turn off the
233+             # scheduler/cache/graph features that can perturb execution order.
234+             init_kwargs ["disable_radix_cache" ] = True
235+             init_kwargs ["disable_overlap_schedule" ] = True
236+             init_kwargs ["disable_cuda_graph" ] = True
191237
192-         if grammar_backend is not None :
193-             sglang_server_args .grammar_backend = grammar_backend
238+         # Forward supported sglang_* extra configs to ServerArgs directly.
# NOTE(review): extra kwargs set earlier (e.g. tp_size) can be overridden by a
# user-supplied sglang_* key here, since this loop runs last — confirm intended.
239+         server_arg_fields = getattr (ServerArgs , "__dataclass_fields__" , {})
240+         for key , value in sglang_config_kwargs .items ():
241+             if key in server_arg_fields :
242+                 init_kwargs [key ] = value
243+             else :
244+                 self .logger .warning (f"Ignore unknown SGLang server arg: { key } ={ value !r} " )
245+
246+         # Qwen3-MoE in sglang 0.5.9 can hit native rotary + fused KV buffer incompatibility
247+         # during server startup unless fused qk_norm_rope is enabled.
248+         if self .model_type == "qwen3_moe" and "enable_fused_qk_norm_rope" not in sglang_config_kwargs :
249+             init_kwargs ["enable_fused_qk_norm_rope" ] = True
250+             self .logger .info ("Auto enable SGLang enable_fused_qk_norm_rope for qwen3_moe." )
194251
195252         if self .config .context_length is not None :
196-             sglang_server_args . context_length = self .config .context_length
253+             init_kwargs [ " context_length" ] = self .config .context_length
197254
198-         if sglang_server_args .nnodes > 1 :
199-             sglang_server_args .node_rank = self .rank // self .config .gpus_per_node
200-         else :
201-             sglang_server_args .node_rank = 0
# Construct once, after all kwargs are final.
255+         sglang_server_args = ServerArgs (** init_kwargs )
202256
203257         return sglang_server_args
204258
0 commit comments