Skip to content

Commit 53c3a19

Browse files
committed
Refactor rollout topology binding
1 parent e92678c commit 53c3a19

13 files changed

Lines changed: 616 additions & 636 deletions

tests/rl/test_multi_task_agent_loop_manager.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,6 @@ def _fake_agent_loop():
167167
rollout_ctl = MagicMock()
168168
rollout_ctl.continue_generation.remote = AsyncMock()
169169
rollout_ctl.pause_generation.remote = AsyncMock()
170-
rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
171170
agent_loop = MagicMock()
172171
agent_loop.rollout_ctl = rollout_ctl
173172
return agent_loop
@@ -177,7 +176,6 @@ def _fake_rollout_controller():
177176
rollout_controller = MagicMock()
178177
rollout_controller.continue_generation.remote = AsyncMock()
179178
rollout_controller.pause_generation.remote = AsyncMock()
180-
rollout_controller.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
181179
return rollout_controller
182180

183181

tests/rl/test_producer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None):
100100
mock_agent_loop = MagicMock()
101101
mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
102102
mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
103-
mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
104103

105104
async def mock_pause():
106105
await mock_agent_loop.rollout_ctl.pause_generation.remote()

tests/rl/test_rl_colocate_trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def _build_fake_rollout_controller():
7575
rollout_ctl = MagicMock()
7676
rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
7777
rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
78-
rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
7978
return rollout_ctl
8079

8180

tests/rl/test_rl_trainer_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(self):
104104
self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted")
105105
self.onload_weights = _RemoteMethod(return_value="weights_loaded")
106106
self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded")
107-
self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}})
107+
self.get_rollout_metadata = _RemoteMethod(return_value={})
108108
self.set_enable_partial_rollout = _RemoteMethod(return_value=None)
109109
self.validate_registered_workers_to_proxy = _RemoteMethod(return_value=_AwaitableValue(None))
110110

tests/rl/test_rollout_logic.py

Lines changed: 177 additions & 34 deletions
Large diffs are not rendered by default.

xtuner/v1/rl/rollout/controller.py

Lines changed: 54 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
RolloutConfig,
2020
get_rollout_worker_base_cls,
2121
)
22-
from .worker_registry import RolloutWorkerMetadata, RolloutWorkerRegistry
22+
from .worker_registry import RolloutWorkerRegistry
2323

2424

2525
# Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller
@@ -61,19 +61,15 @@ def __init__(
6161
registry=self.registry,
6262
worker_lifecycle_listeners=[self.proxy_manager] if self.proxy_manager is not None else None,
6363
)
64-
self.health_manager.start()
64+
self.health_manager.start_background_checks()
6565

66-
def get_rollout_metadata(self) -> RolloutWorkerMetadata:
66+
def get_rollout_metadata(self) -> dict:
6767
"""Get information about the current rollout setup.
6868
6969
Returns:
70-
dict: A dictionary containing the engine mesh list, server URL
71-
dictionary, and the rollout configuration.
70+
Legacy trainer/update-weight rollout metadata dictionary.
7271
"""
73-
rollout_metadata = self.registry.training_metadata_snapshot()
74-
self.logger.info(f"Rollout worker server URLs: {rollout_metadata['server_url_dict']}")
75-
self.logger.info(f"Rollout worker session server URLs: {rollout_metadata['worker_session_url_dict']}")
76-
return rollout_metadata
72+
return self.registry.metadata().to_legacy()
7773

7874
def register_active_workers_to_proxy(self) -> None:
7975
if self.proxy_manager is None:
@@ -133,7 +129,7 @@ def set_enable_partial_rollout(self, enable: bool) -> None:
133129
)
134130

135131
def pause_generation(self):
136-
self.health_manager.pause()
132+
self.health_manager.pause_background_checks()
137133
active_workers = self.registry.active_workers()
138134
futures = [
139135
worker.actor.pause_generation.remote() # type: ignore[attr-defined]
@@ -164,7 +160,7 @@ async def restart_inactive_workers(self):
164160

165161
def continue_generation(self):
166162
self._broadcast_to_active_workers("continue_generation")
167-
self.health_manager.resume()
163+
self.health_manager.resume_background_checks()
168164

169165
def offload(self):
170166
self._broadcast_to_active_workers("offload")
@@ -181,7 +177,7 @@ def onload_kvcache(self):
181177

182178
def shutdown(self):
183179
"""Shut down all rollout workers tracked by the controller."""
184-
self.health_manager.stop()
180+
self.health_manager.stop_background_checks()
185181
actors = self.registry.all_actors()
186182
ray.get(
187183
[actor.shutdown.remote(stop_session_server=True) for actor in actors], # type: ignore[attr-defined]
@@ -203,13 +199,16 @@ def _build_remote_worker_cls(self, worker_base_cls):
203199
},
204200
)(worker_base_cls)
205201

206-
def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistry:
202+
def _init_workers(
203+
self,
204+
placement_group: PlacementGroup,
205+
) -> RolloutWorkerRegistry:
207206
"""Initializes and configures the pool of RolloutWorker actors.
208207
209208
This method follows the same high-level flow as the legacy implementation:
210-
create workers, initialize worker-local ports, build engine groups,
211-
select workers that launch rollout servers, launch servers, and
212-
expose request-entrypoint server URLs to rollout traffic.
209+
create workers, initialize worker-local ports, build the bound rollout
210+
topology, launch rollout servers, and expose request-entrypoint server
211+
URLs to rollout traffic.
213212
214213
Returns:
215214
A registry containing all server-process workers and the public
@@ -222,79 +221,59 @@ def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistr
222221
workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
223222
worker_cls, self.config, placement_group
224223
)
225-
rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)}
226-
227-
# Reserve worker-local ports for all actors first. build_engine_launch_specs
228-
# uses the returned addresses to bind each ServerProcessSpec to its
229-
# logical engine rendezvous address; only server-process owners call init().
230-
rank_to_dist_init_addr = {
231-
rank: dist_init_addr
232-
for (rank, _), dist_init_addr in zip(
233-
rank_bundle_idx_list,
234-
ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined]
235-
)
224+
dist_init_results = ray.get(
225+
[
226+
worker.init_dist_port.remote() # type: ignore[attr-defined]
227+
for worker in workers
228+
]
229+
)
230+
rank_to_worker = {
231+
rank: worker for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True)
236232
}
233+
rank_to_dist_init_addr = dict(dist_init_results)
237234

238-
# Build engine groups and server-process specs from the rank/bundle mapping.
239-
engine_launch_specs = worker_base_cls.build_engine_launch_specs(
235+
rollout_topology = worker_base_cls.build_rollout_topology(
240236
self.config,
241237
rank_bundle_idx_list,
242238
rank_to_dist_init_addr,
243239
)
244-
# Keep the public metadata mesh compatible with origin/main. Backends
245-
# may expose a different update-weight mesh than their internal launch
246-
# topology, e.g. LMDeploy EP has one logical engine but one public entry
247-
# per request-serving EP rank.
248-
engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs)
249-
250-
# Launch every server process described by the backend-specific specs.
251-
server_rank_to_url = dict(
252-
ray.get(
253-
[
254-
rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined]
255-
engine_launch_spec=engine_spec,
256-
)
257-
for engine_spec in engine_launch_specs
258-
for server_process in engine_spec.server_processes
259-
]
260-
)
240+
server_launch_specs = rollout_topology.server_launch_specs()
241+
server_workers = tuple(
242+
(launch_spec, rank_to_worker[launch_spec.worker_rank]) for launch_spec in server_launch_specs
243+
)
244+
245+
ray.get(
246+
[
247+
worker.bind_server_launch_spec.remote(launch_spec) # type: ignore[attr-defined]
248+
for launch_spec, worker in server_workers
249+
]
261250
)
262-
session_url_by_rank = dict(
251+
init_results = tuple(
263252
ray.get(
264253
[
265-
(
266-
rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined]
267-
)
268-
for engine_spec in engine_launch_specs
269-
for server_process in engine_spec.server_processes
254+
worker.init.remote() # type: ignore[attr-defined]
255+
for _launch_spec, worker in server_workers
270256
]
271257
)
272258
)
259+
registry = RolloutWorkerRegistry(rollout_topology=rollout_topology, rollout_config=self.config)
260+
for init_result in init_results:
261+
registry.register_started_server(
262+
rank=init_result.rank,
263+
actor=rank_to_worker[init_result.rank],
264+
server_url=init_result.server_url,
265+
session_url=init_result.session_url,
266+
)
273267

274-
registry = RolloutWorkerRegistry(
275-
engine_rank_mesh_array=engine_rank_mesh_array,
276-
rollout_config=self.config,
268+
rollout_metadata = registry.metadata()
269+
legacy_metadata = rollout_metadata.to_legacy()
270+
self.logger.info(
271+
"Rollout worker registry snapshot: "
272+
f"server_urls={legacy_metadata['server_url_dict']}, "
273+
f"session_urls={legacy_metadata['worker_session_url_dict']}, "
274+
f"server_process_urls={[worker.url for worker in registry.all_workers()]}, "
275+
f"lifecycle_groups={registry.lifecycle_groups()}"
277276
)
278-
for engine_spec in engine_launch_specs:
279-
for server_process in engine_spec.server_processes:
280-
rank = server_process.worker_rank
281-
url = server_rank_to_url[rank]
282-
session_url = session_url_by_rank.get(rank)
283-
if server_process.accepts_rollout_requests and session_url is None:
284-
raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.")
285-
registry.register_started_server(
286-
rank=rank,
287-
actor=rank_to_actor[rank],
288-
server_url=url,
289-
session_url=session_url,
290-
lifecycle_group_ranks=engine_spec.server_worker_ranks,
291-
is_request_entrypoint=server_process.accepts_rollout_requests,
292-
)
293-
294-
server_process_workers_info = registry.all_workers()
295-
self.logger.info(f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info]}")
296-
lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info})
297-
self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}")
298277
return registry
299278

300279

xtuner/v1/rl/rollout/health_manager.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self._worker_health_failure_counts: dict[int, int] = {}
6565
self._stopped = False
6666

67-
def start(self) -> None:
67+
def start_background_checks(self) -> None:
6868
health_thread_alive = self._thread is not None and self._thread.is_alive()
6969
if health_thread_alive:
7070
return
@@ -73,11 +73,11 @@ def start(self) -> None:
7373
self._stop_event = threading.Event()
7474
self._pause_event = threading.Event()
7575
self._pause_event.set()
76-
self._thread = threading.Thread(target=self._run_loop, daemon=True)
76+
self._thread = threading.Thread(target=self._run_background_check_loop, daemon=True)
7777
self._thread.start()
78-
logger.info("RolloutHealthManager started.")
78+
logger.info("RolloutHealthManager background checks started.")
7979

80-
def stop(self) -> None:
80+
def stop_background_checks(self) -> None:
8181
thread = self._thread
8282
if not thread:
8383
return
@@ -99,19 +99,19 @@ def stop(self) -> None:
9999
self._thread = None
100100
self._stop_event = None
101101
self._pause_event = None
102-
logger.info("RolloutHealthManager stopped.")
102+
logger.info("RolloutHealthManager background checks stopped.")
103103

104-
def pause(self) -> None:
104+
def pause_background_checks(self) -> None:
105105
if self._pause_event is None:
106106
return
107107
self._pause_event.set()
108-
logger.info("RolloutHealthManager paused.")
108+
logger.info("RolloutHealthManager background checks paused.")
109109

110-
def resume(self) -> None:
110+
def resume_background_checks(self) -> None:
111111
if self._pause_event is None:
112112
return
113113
self._pause_event.clear()
114-
logger.info("RolloutHealthManager resumed.")
114+
logger.info("RolloutHealthManager background checks resumed.")
115115

116116
def _is_paused(self) -> bool:
117117
return self._pause_event is None or self._pause_event.is_set()
@@ -121,20 +121,20 @@ def _is_stopping(self) -> bool:
121121
return self._stopped or (self._stop_event is not None and self._stop_event.is_set())
122122

123123
@contextmanager
124-
def _background_health_checks_paused(self):
124+
def _background_checks_paused(self):
125125
was_paused = self._is_paused()
126126
if not was_paused:
127-
self.pause()
127+
self.pause_background_checks()
128128
try:
129129
yield
130130
finally:
131131
if not was_paused:
132-
self.resume()
132+
self.resume_background_checks()
133133

134134
def restart_inactive_workers(self) -> None:
135135
"""Synchronously restart inactive groups before the next sync-step
136136
weight update."""
137-
with self._background_health_checks_paused():
137+
with self._background_checks_paused():
138138
with self._operation_lock:
139139
failed_groups = list(self._registry.claim_inactive_groups_for_recovery())
140140
if not failed_groups:
@@ -217,7 +217,7 @@ def check_and_shutdown_inactive_workers(self) -> None:
217217
"""Fail-fast health-check active workers, mark failures inactive, and
218218
shut down every non-active group so shared resources can be reused by
219219
training."""
220-
with self._background_health_checks_paused():
220+
with self._background_checks_paused():
221221
self._check_and_deactivate_failed_worker_groups(fail_fast=True)
222222
with self._operation_lock:
223223
inactive_groups = list(self._registry.claim_inactive_groups_for_recovery())
@@ -248,7 +248,7 @@ def check_and_shutdown_inactive_workers(self) -> None:
248248
)
249249
)
250250

251-
def run_once(self) -> None:
251+
def run_periodic_health_check(self) -> None:
252252
logger.debug("RolloutHealthManager running health checks for all workers.")
253253
checked_active_count = self._check_and_deactivate_failed_worker_groups()
254254
if self._registry.active_workers() or self._is_stopping():
@@ -367,9 +367,9 @@ async def check_workers(workers: list[WorkerSnapshot]) -> list[bool]:
367367

368368
return [keep_active_by_rank[worker.rank] for worker in workers_to_check]
369369

370-
def _run_loop(self) -> None:
370+
def _run_background_check_loop(self) -> None:
371371
assert self._stop_event is not None and self._pause_event is not None
372-
logger.info("RolloutHealthManager loop started.")
372+
logger.info("RolloutHealthManager background check loop started.")
373373

374374
while not self._stop_event.is_set():
375375
while self._pause_event.is_set() and not self._stop_event.is_set():
@@ -385,13 +385,13 @@ def _run_loop(self) -> None:
385385
continue
386386

387387
try:
388-
self.run_once()
388+
self.run_periodic_health_check()
389389
except RuntimeError:
390390
if self._is_stopping():
391391
break
392-
logger.exception("RolloutHealthManager run_once failed.")
392+
logger.exception("RolloutHealthManager periodic health check failed.")
393393
except Exception:
394-
logger.exception("RolloutHealthManager run_once failed.")
394+
logger.exception("RolloutHealthManager periodic health check failed.")
395395

396396
def _shutdown_worker_group(
397397
self,
@@ -486,8 +486,8 @@ def _restart_worker_group(
486486
)
487487
init_results = ray.get(
488488
[
489-
# init() reuses the immutable launch spec cached on each actor
490-
# during controller startup, including placement bundles and dist addr.
489+
# init() reuses the server launch spec bound during
490+
# controller startup.
491491
worker.actor.init.remote() # type: ignore[attr-defined]
492492
for worker in group.workers
493493
],
@@ -505,11 +505,11 @@ def _restart_worker_group(
505505
return False
506506

507507
for worker, init_result in zip(group.workers, init_results):
508-
init_rank, init_url = init_result
509-
if init_rank != worker.rank or init_url != worker.url:
508+
if init_result.rank != worker.rank or init_result.server_url != worker.url:
510509
logger.error(
511510
f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, "
512-
f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}."
511+
f"init_rank={init_result.rank}, expected_url={worker.url}, "
512+
f"init_url={init_result.server_url}."
513513
)
514514
self._shutdown_worker_group(group, wait_server_down=False, best_effort=True)
515515
return False

0 commit comments

Comments
 (0)