|
28 | 28 | os.environ["RAY_TMPDIR"] = _temp_dir # Alternative env var |
29 | 29 | os.environ["TMPDIR"] = _temp_dir # System temp dir |
30 | 30 |
|
31 | | -from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer |
| 31 | +from nemo_rl.algorithms.async_utils import ( |
| 32 | + AsyncTrajectoryCollector, |
| 33 | + ReplayBuffer, |
| 34 | +) |
| 35 | +from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew |
32 | 36 | from nemo_rl.algorithms.grpo import MasterConfig |
33 | 37 | from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType |
34 | 38 | from nemo_rl.distributed.batched_data_dict import BatchedDataDict |
@@ -350,6 +354,160 @@ def test_replay_buffer_clear(self): |
350 | 354 | ray.kill(buffer) |
351 | 355 |
|
352 | 356 |
|
class TestReplayBufferNew:
    """Tests for ReplayBufferNew: staleness-window sampling via _evict + sample.

    Each test creates its own ``ReplayBufferNew`` Ray actor and kills it in a
    ``finally`` clause, so a failing assertion does not leave an actor behind
    for the tests that run afterwards.

    Throughout these tests a row's "age" is
    ``trainer_version - weight_version`` and a row is expected to be evicted
    when its age exceeds ``max_staleness``.
    """

    def _make_traj(self, label: str) -> dict:
        """Build a minimal trajectory payload tagged with ``label``."""
        return {"batch": {"data": label}, "rollout_metrics": {}}

    def _add(self, buf, label: str, weight_version: int):
        """Add one trajectory generated at ``weight_version``; return add()'s status."""
        return ray.get(
            buf.add.remote(
                self._make_traj(label),
                weight_version=weight_version,
                target_weight_version=0,  # unused in ReplayBufferNew
            )
        )

    def _sample(self, buf, num_groups: int, trainer_version: int):
        """Sample ``num_groups`` rows as seen by a trainer at ``trainer_version``."""
        return ray.get(
            buf.sample.remote(
                num_prompt_groups=num_groups,
                current_weight_version=trainer_version,
                max_age_steps=0,  # unused in ReplayBufferNew
            )
        )

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_invalid_max_staleness_raises(self):
        """A negative max_staleness is rejected."""
        buf = None
        try:
            # The error may only surface on the first remote call rather than
            # at .remote() construction, so both live inside the raises block.
            with pytest.raises(Exception):
                buf = ReplayBufferNew.remote(max_size=10, max_staleness=-1)
                ray.get(buf.size.remote())
        finally:
            # If validation regresses and nothing raises, the actor is still
            # alive; make sure it does not outlive this test.
            if buf is not None:
                ray.kill(buf)

    # ------------------------------------------------------------------
    # _evict (via sample)
    # ------------------------------------------------------------------

    def test_stale_rows_evicted_before_sampling(self):
        """Rows with age > max_staleness are removed before sample() selects."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=2)
        try:
            # age at trainer=4: gen_v=1 → 3 > 2 (stale), gen_v=3 → 1 ≤ 2 (valid)
            self._add(buf, "stale", weight_version=1)
            self._add(buf, "fresh", weight_version=3)

            result = self._sample(buf, num_groups=1, trainer_version=4)

            assert result is not None
            assert result["trajectories"][0]["batch"]["data"] == "fresh"
            assert ray.get(buf.size.remote()) == 0  # stale row also gone
        finally:
            ray.kill(buf)

    def test_all_stale_returns_none(self):
        """sample() returns None when all rows are evicted as stale."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=1)
        try:
            self._add(buf, "a", weight_version=0)
            self._add(buf, "b", weight_version=1)

            # trainer=5: both ages > 1
            result = self._sample(buf, num_groups=1, trainer_version=5)

            assert result is None
            assert ray.get(buf.size.remote()) == 0
        finally:
            ray.kill(buf)

    def test_eviction_frees_capacity(self):
        """Evicting a stale row allows a subsequent add() to succeed."""
        buf = ReplayBufferNew.remote(max_size=1, max_staleness=1)
        try:
            self._add(buf, "x", weight_version=1)
            # Buffer holds max_size=1 rows, so a second add must be refused.
            assert self._add(buf, "x", weight_version=1) == "full"

            # sample() at trainer=5 evicts the stale row (age 4 > 1)
            self._sample(buf, num_groups=1, trainer_version=5)

            assert self._add(buf, "y", weight_version=4) == "success"
        finally:
            ray.kill(buf)

    def test_within_window_not_evicted(self):
        """Rows whose age is within max_staleness are not evicted."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=3)
        try:
            self._add(buf, "x", weight_version=4)

            # trainer=6: age = 6 - 4 = 2 ≤ 3 → should survive
            # should return None since there is only 1 row
            result = self._sample(buf, num_groups=2, trainer_version=6)
            assert result is None

            # this sample should still be there
            assert ray.get(buf.size.remote()) == 1
        finally:
            ray.kill(buf)

    # ------------------------------------------------------------------
    # sample()
    # ------------------------------------------------------------------

    @pytest.mark.parametrize("sample_freshest_first", [True, False])
    def test_sample_freshest_first(self, sample_freshest_first):
        """sample() ordering follows the sample_freshest_first flag.

        With the flag set, the newest generation versions come back first;
        with it cleared, the oldest come back first.
        """
        buf = ReplayBufferNew.remote(
            max_size=10,
            max_staleness=5,
            sample_freshest_first=sample_freshest_first,
        )
        try:
            for gen_v in [3, 4, 5]:
                self._add(buf, f"v{gen_v}", weight_version=gen_v)

            result = self._sample(buf, num_groups=2, trainer_version=6)

            assert result is not None
            data = [t["batch"]["data"] for t in result["trajectories"]]
            expected = ["v5", "v4"] if sample_freshest_first else ["v3", "v4"]
            assert data == expected
        finally:
            ray.kill(buf)

    def test_sample_returns_none_when_insufficient(self):
        """sample() returns None when fewer rows than requested remain after eviction."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
        try:
            self._add(buf, "only", weight_version=1)

            result = self._sample(buf, num_groups=3, trainer_version=2)

            assert result is None
        finally:
            ray.kill(buf)

    def test_sample_returns_none_on_empty_buffer(self):
        """sample() returns None when nothing has been added."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
        try:
            result = self._sample(buf, num_groups=1, trainer_version=1)
            assert result is None
        finally:
            ray.kill(buf)

    def test_sample_avg_trajectory_age(self):
        """avg_trajectory_age is computed from the sampled generation versions."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
        try:
            # Both rows are sampled at trainer=10:
            # gen 8 (age 2), gen 6 (age 4) → avg = 3.0
            for gen_v in [6, 8]:
                self._add(buf, f"v{gen_v}", weight_version=gen_v)

            result = self._sample(buf, num_groups=2, trainer_version=10)

            assert result is not None
            assert result["avg_trajectory_age"] == pytest.approx(3.0, abs=1e-6)
        finally:
            ray.kill(buf)

    def test_sample_consumes_selected_rows(self):
        """Rows returned by sample() are removed from the buffer."""
        buf = ReplayBufferNew.remote(max_size=10, max_staleness=5)
        try:
            for gen_v in [1, 2, 3]:
                self._add(buf, f"v{gen_v}", weight_version=gen_v)

            self._sample(buf, num_groups=2, trainer_version=4)

            # Three rows added, two sampled, none stale → one left.
            assert ray.get(buf.size.remote()) == 1
        finally:
            ray.kill(buf)
| 510 | + |
353 | 511 | class TestAsyncTrajectoryCollector: |
354 | 512 | """Test cases for AsyncTrajectoryCollector.""" |
355 | 513 |
|
|
0 commit comments