Skip to content

Commit 6a84d4e

Browse files
daihaowz and fenghui authored
fix(infra): correct staleness capacity inflation after recovery (#1345)
* fix(infra): correct staleness capacity inflation after checkpoint recovery

  StalenessManager's accepted counter started at 0 while the version was restored to a high value by the recovery path. This caused the capacity formula to yield (max_staleness + recovered_version + 1) * batch_size instead of the intended (max_staleness + 1) * batch_size, allowing a burst of rollout submissions and unbounded staleness growth.

  Add on_version_recovered() to StalenessManager and call it from rl_trainer after recovery completes. The trainer accesses the staleness manager directly via the known concrete type (RolloutController in single-controller mode, workflow_executor in SPMD mode).

* fix(infra): clarify staleness recovery semantics and use public APIs

  Address review feedback on the staleness manager recovery path:

  - Document that on_version_recovered is expected to be called with running == 0 and explain the bound when it is not.
  - Reach the manager through the public staleness_manager properties on RolloutController and WorkflowExecutor instead of the private _staleness_manager attribute, avoiding coupling to internal layout.
  - Extend tests with the version=0 no-op case and a parametrized case with in-flight rollouts to verify accepted is set correctly.

---------

Co-authored-by: fenghui <dh183333@antgroup.com>
1 parent 9c2ec43 commit 6a84d4e

3 files changed

Lines changed: 75 additions & 0 deletions

File tree

areal/infra/staleness_manager.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ def get_capacity(self) -> int:
112112
capacity = min(concurrency_capacity, staleness_capacity)
113113
return capacity
114114

115+
def on_version_recovered(self, version: int) -> None:
    """Re-base the accepted count after checkpoint recovery.

    When a checkpoint is recovered, the version jumps from 0 to the
    recovered value. Without adjusting accepted, the capacity formula
    yields (max_staleness + version + 1) * batch_size instead of the
    intended (max_staleness + 1) * batch_size, causing a burst of
    submissions and unbounded staleness growth.

    Expected to be called during trainer init, before any rollouts are
    submitted, so running == 0. If running > 0 (unlikely in practice),
    accepted is still set correctly and the capacity formula remains
    bounded: (max_staleness + 1) * consumer_bs - running.

    Args:
        version: Model version restored from the checkpoint. Must be
            non-negative; 0 leaves accepted at 0.

    Raises:
        ValueError: If ``version`` is negative. A negative value would
            make accepted negative and inflate capacity, which is the
            exact failure this method exists to prevent.
    """
    # Validate before taking the lock so a bad call never mutates state.
    if version < 0:
        raise ValueError(
            f"recovered version must be non-negative, got {version}"
        )
    with self.lock:
        # Guard against a zero/negative batch size so accepted still
        # advances by at least one sample per recovered version.
        consumer_bs = max(1, self.consumer_batch_size)
        self.rollout_stat.accepted = version * consumer_bs
132+
115133
def on_rollout_enqueued(self) -> None:
116134
"""Callback when a rollout is enqueued as a pending input task.
117135

areal/trainer/rl_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,17 @@ def __init__(
368368
weight_update_meta=self.weight_update_meta,
369369
)
370370

371+
# After recovery, sync the staleness manager so its capacity formula
372+
# stays bounded despite the version jumping from 0 to recovery_version.
373+
if self.recover_info is not None:
374+
recovery_version = self.recover_info.last_step_info.global_step + 1
375+
if is_single_controller():
376+
sm = self.rollout.staleness_manager
377+
else:
378+
sm = self.rollout.workflow_executor.staleness_manager
379+
if sm is not None:
380+
sm.on_version_recovered(recovery_version)
381+
371382
self._config_perf_tracer()
372383
self._apply_initial_offload_policy()
373384

tests/test_staleness_manager.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,52 @@ def test_parametrized_version_progression(version):
766766
assert capacity == min(1000, expected_staleness_capacity)
767767

768768

769+
@pytest.mark.parametrize("recovered_version", [0, 5, 10, 50])
def test_on_version_recovered(recovered_version):
    """Capacity after recovery must not depend on the recovered version."""
    batch_size = 16
    staleness_limit = 2
    provider = MockVersionProvider(0)
    manager = StalenessManager(
        version_provider=provider,
        max_concurrent_rollouts=1000,
        consumer_batch_size=batch_size,
        max_staleness=staleness_limit,
    )

    # Mimic the recovery path: the version provider jumps to the restored
    # version first, then the manager is told about it.
    provider.set_version(recovered_version)
    manager.on_version_recovered(recovered_version)

    # The staleness window alone should determine capacity, i.e.
    # (max_staleness + 1) * consumer_batch_size for every recovered version.
    expected_capacity = (staleness_limit + 1) * batch_size
    assert manager.get_capacity() == expected_capacity
788+
789+
790+
@pytest.mark.parametrize("running", [1, 5, 16])
def test_on_version_recovered_with_running_rollouts(running):
    """Recovery with in-flight rollouts still yields a bounded capacity."""
    batch_size = 16
    staleness_limit = 2
    recovered_version = 10
    provider = MockVersionProvider(0)
    manager = StalenessManager(
        version_provider=provider,
        max_concurrent_rollouts=1000,
        consumer_batch_size=batch_size,
        max_staleness=staleness_limit,
    )

    # Put `running` rollouts in flight before recovery happens.
    in_flight = 0
    while in_flight < running:
        manager.on_rollout_enqueued()
        manager.on_rollout_submitted()
        in_flight += 1

    provider.set_version(recovered_version)
    manager.on_version_recovered(recovered_version)

    # In-flight rollouts eat into the staleness window:
    # (max_staleness + 1) * consumer_bs - running.
    expected_capacity = (staleness_limit + 1) * batch_size - running
    assert manager.get_capacity() == expected_capacity
813+
814+
769815
if __name__ == "__main__":
770816
# Run tests with pytest
771817
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)