@@ -115,37 +115,28 @@ async def wait_to_model_ready(self):
115115 self .model_rpc_servers = []
116116 # 用于 kv move 管理进程 和 推理进程进行task信息的交互。
117117 self .info_queue : mp .Queue = mp .Queue ()
118- self .rpc_event = multiprocessing .Event ()
119- self .rpc_finished_event = multiprocessing .Event ()
120118
121119 assert (self .world_size % self .nnodes ) == 0
122120 node_world_size = self .world_size // self .nnodes
123121
124122 # Create tasks for parallel startup
125123 tasks = []
126124 for rank_id in range (self .node_rank * node_world_size , (self .node_rank + 1 ) * node_world_size ):
125+ rank_in_node = rank_id % node_world_size
127126 task = asyncio .create_task (
128127 start_model_process (
129128 args = self .args ,
130129 rank = rank_id ,
131- rank_in_node = rank_id % node_world_size ,
130+ rank_in_node = rank_in_node ,
132131 node_world_size = node_world_size ,
133- rpc_event = self .rpc_event ,
134- rpc_finished_event = self .rpc_finished_event ,
135132 info_queue = self .info_queue ,
136133 router_lock = self .router_lock ,
137134 )
138135 )
139136 tasks .append (task )
140137
141138 # Wait for all tasks to complete in parallel
142- self .model_rpc_servers = await asyncio .gather (* tasks )
143-
144- self .model_rpc_client = ModelRpcClient (
145- rpc_event = self .rpc_event ,
146- rpc_finished_event = self .rpc_finished_event ,
147- )
148-
139+ self .model_rpc_clients = await asyncio .gather (* tasks )
149140 kvargs = {
150141 "args" : self .args ,
151142 "rank_id" : None , # 由后续处理填充真实数据
@@ -178,10 +169,19 @@ async def wait_to_model_ready(self):
178169 "pd_rpyc_ports" : self .args .pd_node_infer_rpyc_ports , # 非 pd 模式可以不设置
179170 }
180171
181- await self .model_rpc_client .init_model (kvargs = kvargs )
172+ # Call init_model on all model processes
173+ init_tasks = []
174+ for model_rpc_client in self .model_rpc_clients :
175+ init_tasks .append (model_rpc_client .init_model (kvargs = kvargs ))
176+ await asyncio .gather (* init_tasks )
182177
183178 if self .max_total_token_num is None :
184- self .max_total_token_num = await self .model_rpc_client .get_max_total_token_num ()
179+ _tasks = []
180+ for model_rpc_client in self .model_rpc_clients :
181+ _tasks .append (model_rpc_client .get_max_total_token_num ())
182+ _nums = await asyncio .gather (* _tasks )
183+ assert max (_nums ) == min (_nums ), "all rank must have same token num"
184+ self .max_total_token_num = _nums [0 ]
185185 self .args .max_total_token_num = self .max_total_token_num
186186 if not self .args .disable_dynamic_prompt_cache :
187187 self .radix_cache_client = RadixCacheReadOnlyClient (
0 commit comments