1919 RolloutConfig ,
2020 get_rollout_worker_base_cls ,
2121)
22- from .worker_registry import RolloutWorkerMetadata , RolloutWorkerRegistry
22+ from .worker_registry import RolloutWorkerRegistry
2323
2424
2525# Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller
@@ -61,19 +61,15 @@ def __init__(
6161 registry = self .registry ,
6262 worker_lifecycle_listeners = [self .proxy_manager ] if self .proxy_manager is not None else None ,
6363 )
64- self .health_manager .start ()
64+ self .health_manager .start_background_checks ()
6565
66- def get_rollout_metadata (self ) -> RolloutWorkerMetadata :
66+ def get_rollout_metadata (self ) -> dict :
6767 """Get information about the current rollout setup.
6868
6969 Returns:
70- dict: A dictionary containing the engine mesh list, server URL
71- dictionary, and the rollout configuration.
70+ Legacy trainer/update-weight rollout metadata dictionary.
7271 """
73- rollout_metadata = self .registry .training_metadata_snapshot ()
74- self .logger .info (f"Rollout worker server URLs: { rollout_metadata ['server_url_dict' ]} " )
75- self .logger .info (f"Rollout worker session server URLs: { rollout_metadata ['worker_session_url_dict' ]} " )
76- return rollout_metadata
72+ return self .registry .metadata ().to_legacy ()
7773
7874 def register_active_workers_to_proxy (self ) -> None :
7975 if self .proxy_manager is None :
@@ -133,7 +129,7 @@ def set_enable_partial_rollout(self, enable: bool) -> None:
133129 )
134130
135131 def pause_generation (self ):
136- self .health_manager .pause ()
132+ self .health_manager .pause_background_checks ()
137133 active_workers = self .registry .active_workers ()
138134 futures = [
139135 worker .actor .pause_generation .remote () # type: ignore[attr-defined]
@@ -164,7 +160,7 @@ async def restart_inactive_workers(self):
164160
165161 def continue_generation (self ):
166162 self ._broadcast_to_active_workers ("continue_generation" )
167- self .health_manager .resume ()
163+ self .health_manager .resume_background_checks ()
168164
169165 def offload (self ):
170166 self ._broadcast_to_active_workers ("offload" )
@@ -181,7 +177,7 @@ def onload_kvcache(self):
181177
182178 def shutdown (self ):
183179 """Shut down all rollout workers tracked by the controller."""
184- self .health_manager .stop ()
180+ self .health_manager .stop_background_checks ()
185181 actors = self .registry .all_actors ()
186182 ray .get (
187183 [actor .shutdown .remote (stop_session_server = True ) for actor in actors ], # type: ignore[attr-defined]
@@ -203,13 +199,16 @@ def _build_remote_worker_cls(self, worker_base_cls):
203199 },
204200 )(worker_base_cls )
205201
206- def _init_workers (self , placement_group : PlacementGroup ) -> RolloutWorkerRegistry :
202+ def _init_workers (
203+ self ,
204+ placement_group : PlacementGroup ,
205+ ) -> RolloutWorkerRegistry :
207206 """Initializes and configures the pool of RolloutWorker actors.
208207
209208 This method follows the same high-level flow as the legacy implementation:
210- create workers, initialize worker-local ports, build engine groups,
211- select workers that launch rollout servers, launch servers, and
212- expose request-entrypoint server URLs to rollout traffic.
209+ create workers, initialize worker-local ports, build the bound rollout
210+ topology, launch rollout servers, and expose request-entrypoint server
211+ URLs to rollout traffic.
213212
214213 Returns:
215214 A registry containing all server-process workers and the public
@@ -222,79 +221,59 @@ def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistr
222221 workers , rank_bundle_idx_list = AutoAcceleratorWorkers .from_placement_group (
223222 worker_cls , self .config , placement_group
224223 )
225- rank_to_actor = {rank : worker for (rank , _ ), worker in zip (rank_bundle_idx_list , workers )}
226-
227- # Reserve worker-local ports for all actors first. build_engine_launch_specs
228- # uses the returned addresses to bind each ServerProcessSpec to its
229- # logical engine rendezvous address; only server-process owners call init().
230- rank_to_dist_init_addr = {
231- rank : dist_init_addr
232- for (rank , _ ), dist_init_addr in zip (
233- rank_bundle_idx_list ,
234- ray .get ([worker .init_dist_port .remote () for worker in workers ]), # type: ignore[attr-defined]
235- )
224+ dist_init_results = ray .get (
225+ [
226+ worker .init_dist_port .remote () # type: ignore[attr-defined]
227+ for worker in workers
228+ ]
229+ )
230+ rank_to_worker = {
231+ rank : worker for worker , (rank , _dist_init_addr ) in zip (workers , dist_init_results , strict = True )
236232 }
233+ rank_to_dist_init_addr = dict (dist_init_results )
237234
238- # Build engine groups and server-process specs from the rank/bundle mapping.
239- engine_launch_specs = worker_base_cls .build_engine_launch_specs (
235+ rollout_topology = worker_base_cls .build_rollout_topology (
240236 self .config ,
241237 rank_bundle_idx_list ,
242238 rank_to_dist_init_addr ,
243239 )
244- # Keep the public metadata mesh compatible with origin/main. Backends
245- # may expose a different update-weight mesh than their internal launch
246- # topology, e.g. LMDeploy EP has one logical engine but one public entry
247- # per request-serving EP rank.
248- engine_rank_mesh_array = worker_base_cls .build_metadata_engine_rank_mesh_array (engine_launch_specs )
249-
250- # Launch every server process described by the backend-specific specs.
251- server_rank_to_url = dict (
252- ray .get (
253- [
254- rank_to_actor [server_process .worker_rank ].init .remote ( # type: ignore[attr-defined]
255- engine_launch_spec = engine_spec ,
256- )
257- for engine_spec in engine_launch_specs
258- for server_process in engine_spec .server_processes
259- ]
260- )
240+ server_launch_specs = rollout_topology .server_launch_specs ()
241+ server_workers = tuple (
242+ (launch_spec , rank_to_worker [launch_spec .worker_rank ]) for launch_spec in server_launch_specs
243+ )
244+
245+ ray .get (
246+ [
247+ worker .bind_server_launch_spec .remote (launch_spec ) # type: ignore[attr-defined]
248+ for launch_spec , worker in server_workers
249+ ]
261250 )
262- session_url_by_rank = dict (
251+ init_results = tuple (
263252 ray .get (
264253 [
265- (
266- rank_to_actor [server_process .worker_rank ].get_session_server_info .remote () # type: ignore[attr-defined]
267- )
268- for engine_spec in engine_launch_specs
269- for server_process in engine_spec .server_processes
254+ worker .init .remote () # type: ignore[attr-defined]
255+ for _launch_spec , worker in server_workers
270256 ]
271257 )
272258 )
259+ registry = RolloutWorkerRegistry (rollout_topology = rollout_topology , rollout_config = self .config )
260+ for init_result in init_results :
261+ registry .register_started_server (
262+ rank = init_result .rank ,
263+ actor = rank_to_worker [init_result .rank ],
264+ server_url = init_result .server_url ,
265+ session_url = init_result .session_url ,
266+ )
273267
274- registry = RolloutWorkerRegistry (
275- engine_rank_mesh_array = engine_rank_mesh_array ,
276- rollout_config = self .config ,
268+ rollout_metadata = registry .metadata ()
269+ legacy_metadata = rollout_metadata .to_legacy ()
270+ self .logger .info (
271+ "Rollout worker registry snapshot: "
272+ f"server_urls={ legacy_metadata ['server_url_dict' ]} , "
273+ f"session_urls={ legacy_metadata ['worker_session_url_dict' ]} , "
274+ f"server_process_urls={ [worker .url for worker in registry .all_workers ()]} , "
275+ f"lifecycle_groups={ registry .lifecycle_groups ()} "
277276 )
278- for engine_spec in engine_launch_specs :
279- for server_process in engine_spec .server_processes :
280- rank = server_process .worker_rank
281- url = server_rank_to_url [rank ]
282- session_url = session_url_by_rank .get (rank )
283- if server_process .accepts_rollout_requests and session_url is None :
284- raise RuntimeError (f"Rollout worker rank={ rank } did not return session server URL during init." )
285- registry .register_started_server (
286- rank = rank ,
287- actor = rank_to_actor [rank ],
288- server_url = url ,
289- session_url = session_url ,
290- lifecycle_group_ranks = engine_spec .server_worker_ranks ,
291- is_request_entrypoint = server_process .accepts_rollout_requests ,
292- )
293-
294- server_process_workers_info = registry .all_workers ()
295- self .logger .info (f"Rollout server-process worker URLs: { [info .url for info in server_process_workers_info ]} " )
296- lifecycle_groups = sorted ({info .lifecycle_group_ranks for info in server_process_workers_info })
297- self .logger .info (f"Rollout worker lifecycle groups: { lifecycle_groups } " )
298277 return registry
299278
300279
0 commit comments