Skip to content

Commit ac7723c

Browse files
authored
feat: support staleness-window in ReplayBufferNew (#2458)
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 9114a1d commit ac7723c

4 files changed

Lines changed: 259 additions & 24 deletions

File tree

nemo_rl/algorithms/async_utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBuffer, ReplayBufferNew
15+
from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBuffer
1616
from nemo_rl.algorithms.async_utils.trajectory_collector import AsyncTrajectoryCollector
1717

1818
__all__ = [
1919
"ReplayBuffer",
20-
"ReplayBufferNew",
2120
"AsyncTrajectoryCollector",
2221
]

nemo_rl/algorithms/async_utils/interfaces.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ def sample(
5151
"""
5252
...
5353

54-
def evict(self) -> None:
55-
"""Evict old trajectories."""
56-
...
57-
5854
def size(self) -> int:
5955
"""Return current buffer size."""
6056
...

nemo_rl/algorithms/async_utils/replay_buffer.py

Lines changed: 99 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
import threading as _threading
16-
from typing import Any, Optional
16+
from collections import Counter
17+
from typing import Any, Iterable, Optional
1718

1819
import ray
1920

@@ -87,6 +88,13 @@ def get_existing_target_weights(self) -> set[int]:
8788
with self._lock:
8889
return set(self.target_weight_versions)
8990

91+
def _remove_indices(self, indices: Iterable[int]) -> None:
92+
"""Remove trajectories at the given indices."""
93+
for idx in sorted(indices, reverse=True):
94+
self.trajectory_versions.pop(idx)
95+
self.target_weight_versions.pop(idx)
96+
self.trajectories.pop(idx)
97+
9098
def sample(
9199
self,
92100
num_prompt_groups: int,
@@ -113,8 +121,6 @@ def sample(
113121
print(f" {self.trajectory_versions=}")
114122

115123
# For debugging: check for unexpected old trajectories
116-
from collections import Counter
117-
118124
version_counts = Counter(self.trajectory_versions)
119125
print(f" {version_counts=}")
120126

@@ -180,8 +186,6 @@ def sample(
180186
f" ✅ Selected {len(selected)} trajectories all intended for step {current_weight_version}"
181187
)
182188

183-
from collections import Counter
184-
185189
sampled_weights = [self.trajectory_versions[i] for i in selected]
186190
avg_trajectory_age = current_weight_version - sum(sampled_weights) / len(
187191
sampled_weights
@@ -194,13 +198,9 @@ def sample(
194198
f"🎯 All selected trajectories target step {current_weight_version} (100% target match)"
195199
)
196200

197-
sampled_items = [self.trajectories[i] for i in selected]
198-
199201
# Remove selected items in reverse order to maintain correct indices
200-
for idx in sorted(selected, reverse=True):
201-
self.trajectory_versions.pop(idx)
202-
self.target_weight_versions.pop(idx)
203-
self.trajectories.pop(idx)
202+
sampled_items = [self.trajectories[i] for i in selected]
203+
self._remove_indices(selected)
204204
print(
205205
f"🗑️ Consumed and removed {len(selected)} groups from buffer, old buffer size: {total_trajectories}, new buffer size: {len(self.trajectories)}, new target weight versions {self.target_weight_versions}"
206206
)
@@ -210,11 +210,6 @@ def sample(
210210
"avg_trajectory_age": avg_trajectory_age,
211211
}
212212

213-
def evict(self) -> None:
214-
"""Evict old trajectories."""
215-
# Adding for backward compatibility.
216-
pass
217-
218213
def size(self) -> int:
219214
"""Return current buffer size."""
220215
with self._lock:
@@ -233,6 +228,93 @@ class ReplayBuffer(ReplayBufferImpl):
233228
pass
234229

235230

231+
# WIP: DO NOT USE - This class is WIP and may be changed without notice, please DO NOT USE it.
232+
# Will be replaced by TQReplayBuffer once TQ is ready.
236233
@ray.remote # pragma: no cover
237234
class ReplayBufferNew(ReplayBufferImpl):
238-
pass
235+
"""Staleness-window replay buffer.
236+
237+
-- WIP: DO NOT USE --
238+
This class is WIP and may be changed without notice, please DO NOT USE it.
239+
240+
Differences from ReplayBuffer:
241+
- _evict(): Stale rows (trainer_version - weight_version > max_staleness) are evicted
242+
at the start of every sample() call.
243+
- sample(): selects trajectories in freshest-first order (default) or FIFO order,
244+
controlled by the sample_freshest_first flag, from whatever remains in the buffer
245+
after eviction.
246+
247+
TODO: remove when cleaning up
248+
- max_age_steps won't be used in ReplayBufferNew;
249+
- self.target_weight_versions won't be used in ReplayBufferNew and will be removed
250+
when cleaning up. target_weight_versions gates generation on specific trainer steps,
251+
which causes generation pauses; ReplayBufferNew intentionally avoids this.
252+
- add this class to nemo_rl/algorithms/async_utils/__init__.py
253+
"""
254+
255+
def __init__(
256+
self, max_size: int, max_staleness: int, sample_freshest_first: bool = True
257+
):
258+
super().__init__(max_size)
259+
if max_staleness < 0:
260+
raise ValueError(f"max_staleness must be non-negative, got {max_staleness}")
261+
self.max_staleness = max_staleness
262+
# will move to StalenessSampler when we implement it
263+
self.sample_freshest_first = sample_freshest_first
264+
265+
def _evict(self, current_weight_version: int) -> None:
266+
"""Evict rows where trainer_version - weight_version > max_staleness.
267+
268+
Must be called with self._lock held.
269+
"""
270+
min_valid = current_weight_version - self.max_staleness
271+
stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
272+
self._remove_indices(stale)
273+
274+
def sample(
275+
self,
276+
num_prompt_groups: int,
277+
current_weight_version: int,
278+
max_age_steps: int,
279+
) -> Optional[dict[str, Any]]:
280+
"""Sample num_prompt_groups trajectories, freshest-first.
281+
282+
Stale rows are evicted before sampling, so the remaining trajectories have weight versions in [current_weight_version - self.max_staleness, current_weight_version].
283+
284+
Returns:
285+
Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None.
286+
"""
287+
with self._lock:
288+
self._evict(current_weight_version)
289+
290+
if not self.trajectories:
291+
return None
292+
293+
all_indices = range(len(self.trajectory_versions))
294+
if self.sample_freshest_first:
295+
all_indices = sorted(
296+
all_indices,
297+
key=lambda i: self.trajectory_versions[i],
298+
reverse=True,
299+
)
300+
301+
if len(all_indices) < num_prompt_groups:
302+
print(
303+
f"Insufficient trajectories: have {len(all_indices)}, "
304+
f"need {num_prompt_groups}. Waiting."
305+
)
306+
return None
307+
308+
selected = all_indices[:num_prompt_groups]
309+
sampled_weights = [self.trajectory_versions[i] for i in selected]
310+
avg_trajectory_age = current_weight_version - sum(sampled_weights) / len(
311+
sampled_weights
312+
)
313+
314+
sampled_items = [self.trajectories[i] for i in selected]
315+
self._remove_indices(selected)
316+
317+
return {
318+
"trajectories": sampled_items,
319+
"avg_trajectory_age": avg_trajectory_age,
320+
}

tests/unit/algorithms/test_async_utils.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
os.environ["RAY_TMPDIR"] = _temp_dir # Alternative env var
2929
os.environ["TMPDIR"] = _temp_dir # System temp dir
3030

31-
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer
31+
from nemo_rl.algorithms.async_utils import (
32+
AsyncTrajectoryCollector,
33+
ReplayBuffer,
34+
)
35+
from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew
3236
from nemo_rl.algorithms.grpo import MasterConfig
3337
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
3438
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
@@ -350,6 +354,160 @@ def test_replay_buffer_clear(self):
350354
ray.kill(buffer)
351355

352356

357+
class TestReplayBufferNew:
358+
"""Tests for ReplayBufferNew: staleness-window sampling via _evict + sample."""
359+
360+
def _make_traj(self, label: str) -> dict:
361+
return {"batch": {"data": label}, "rollout_metrics": {}}
362+
363+
def _add(self, buf, label: str, weight_version: int):
364+
return ray.get(
365+
buf.add.remote(
366+
self._make_traj(label),
367+
weight_version=weight_version,
368+
target_weight_version=0, # unused in ReplayBufferNew
369+
)
370+
)
371+
372+
def _sample(self, buf, num_groups: int, trainer_version: int):
373+
return ray.get(
374+
buf.sample.remote(
375+
num_prompt_groups=num_groups,
376+
current_weight_version=trainer_version,
377+
max_age_steps=0, # unused in ReplayBufferNew
378+
)
379+
)
380+
381+
# ------------------------------------------------------------------
382+
# Construction
383+
# ------------------------------------------------------------------
384+
385+
def test_invalid_max_staleness_raises(self):
386+
with pytest.raises(Exception):
387+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=-1)
388+
ray.get(buf.size.remote())
389+
390+
# ------------------------------------------------------------------
391+
# _evict (via sample)
392+
# ------------------------------------------------------------------
393+
394+
def test_stale_rows_evicted_before_sampling(self):
395+
"""Rows with age > max_staleness are removed before sample() selects."""
396+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=2)
397+
# age at trainer=4: gen_v=1 → 3 > 2 (stale), gen_v=3 → 1 ≤ 2 (valid)
398+
self._add(buf, "stale", weight_version=1)
399+
self._add(buf, "fresh", weight_version=3)
400+
401+
result = self._sample(buf, num_groups=1, trainer_version=4)
402+
403+
assert result is not None
404+
assert result["trajectories"][0]["batch"]["data"] == "fresh"
405+
assert ray.get(buf.size.remote()) == 0 # stale row also gone
406+
ray.kill(buf)
407+
408+
def test_all_stale_returns_none(self):
409+
"""sample() returns None when all rows are evicted as stale."""
410+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=1)
411+
self._add(buf, "a", weight_version=0)
412+
self._add(buf, "b", weight_version=1)
413+
414+
# trainer=5: both ages > 1
415+
result = self._sample(buf, num_groups=1, trainer_version=5)
416+
417+
assert result is None
418+
assert ray.get(buf.size.remote()) == 0
419+
ray.kill(buf)
420+
421+
def test_eviction_frees_capacity(self):
422+
"""Evicting a stale row allows a subsequent add() to succeed."""
423+
buf = ReplayBufferNew.remote(max_size=1, max_staleness=1)
424+
self._add(buf, "x", weight_version=1)
425+
assert self._add(buf, "x", weight_version=1) == "full"
426+
427+
# sample() at trainer=5 evicts the stale row (age 4 > 1)
428+
self._sample(buf, num_groups=1, trainer_version=5)
429+
430+
assert self._add(buf, "y", weight_version=4) == "success"
431+
ray.kill(buf)
432+
433+
def test_within_window_not_evicted(self):
434+
"""Rows whose age is within max_staleness are not evicted."""
435+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=3)
436+
self._add(buf, "x", weight_version=4)
437+
438+
# trainer=6: age = 6 - 4 = 2 ≤ 3 → should survive
439+
# should return None since there is only 1 row
440+
result = self._sample(buf, num_groups=2, trainer_version=6)
441+
assert result is None
442+
443+
# this sample should still be there
444+
assert ray.get(buf.size.remote()) == 1
445+
ray.kill(buf)
446+
447+
# ------------------------------------------------------------------
448+
# sample()
449+
# ------------------------------------------------------------------
450+
451+
@pytest.mark.parametrize("sample_freshest_first", [True, False])
452+
def test_sample_freshest_first(self, sample_freshest_first):
453+
"""sample() returns the freshest trajectories first."""
454+
buf = ReplayBufferNew.remote(
455+
max_size=10, max_staleness=5, sample_freshest_first=sample_freshest_first
456+
)
457+
for gen_v in [3, 4, 5]:
458+
self._add(buf, f"v{gen_v}", weight_version=gen_v)
459+
460+
result = self._sample(buf, num_groups=2, trainer_version=6)
461+
462+
assert result is not None
463+
data = [t["batch"]["data"] for t in result["trajectories"]]
464+
if sample_freshest_first:
465+
assert data == ["v5", "v4"]
466+
else:
467+
assert data == ["v3", "v4"]
468+
ray.kill(buf)
469+
470+
def test_sample_returns_none_when_insufficient(self):
471+
"""sample() returns None when fewer rows than requested remain after eviction."""
472+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
473+
self._add(buf, "only", weight_version=1)
474+
475+
result = self._sample(buf, num_groups=3, trainer_version=2)
476+
477+
assert result is None
478+
ray.kill(buf)
479+
480+
def test_sample_returns_none_on_empty_buffer(self):
481+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
482+
result = self._sample(buf, num_groups=1, trainer_version=1)
483+
assert result is None
484+
ray.kill(buf)
485+
486+
def test_sample_avg_trajectory_age(self):
487+
"""avg_trajectory_age is computed from the sampled generation versions."""
488+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
489+
# freshest first: gen 8 (age 2), gen 6 (age 4) → avg = 3.0
490+
for gen_v in [6, 8]:
491+
self._add(buf, f"v{gen_v}", weight_version=gen_v)
492+
493+
result = self._sample(buf, num_groups=2, trainer_version=10)
494+
495+
assert result is not None
496+
assert abs(result["avg_trajectory_age"] - 3.0) < 1e-6
497+
ray.kill(buf)
498+
499+
def test_sample_consumes_selected_rows(self):
500+
"""Rows returned by sample() are removed from the buffer."""
501+
buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
502+
for gen_v in [1, 2, 3]:
503+
self._add(buf, f"v{gen_v}", weight_version=gen_v)
504+
505+
self._sample(buf, num_groups=2, trainer_version=4)
506+
507+
assert ray.get(buf.size.remote()) == 1
508+
ray.kill(buf)
509+
510+
353511
class TestAsyncTrajectoryCollector:
354512
"""Test cases for AsyncTrajectoryCollector."""
355513

0 commit comments

Comments
 (0)