Skip to content

Commit bff0471

Browse files
docs(data-plane): consolidate README; drop stale plan/verl refs
Combine research/data_plane_api_lifecycle.md into nemo_rl/data_plane/README.md as the canonical reference. Move the rest of the data-plane research docs (integration plan, observability, mooncake status, prefetch/test plans, async-RL limitations, policy subclass plan, test SOP, tests/data_plane/README) to local-only nemo_rl/data_plane/docs/ — kept untracked, out of the PR. Comment cleanup across grpo_sync, sync_rollout_actor, tq_policy, and nemo_rl/data_plane/{interfaces,codec,driver_io,preshard,adapters/*}: strip dangling research/data_plane_integration_plan.md §1.2 pointers, defunct Stage 1/2/3/4/5/Phase 1/(P2)/(P3)/Tier N references, verl line-number cross-refs, and a (commit a085559) provenance line. Fix an incorrect adapters/noop.py docstring that claimed the factory returns NoOp on enabled=False (it actually raises). Retarget interfaces.py and tq_policy.py docstrings from the removed integration plan to the README. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent e5305e0 commit bff0471

18 files changed

Lines changed: 352 additions & 3080 deletions

nemo_rl/algorithms/grpo_sync.py

Lines changed: 19 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -13,19 +13,17 @@
1313
# limitations under the License.
1414
"""GRPO trainer — TransferQueue-mediated path (sync).
1515
16-
Sibling fork of ``nemo_rl.algorithms.grpo``. Mirrors verl's split between
17-
``main_ppo.py`` (legacy) and ``main_ppo_sync.py`` (TQ-only): each file
18-
has zero internal branching on whether TQ is engaged, and the example
19-
script chooses one or the other.
16+
Sibling fork of ``nemo_rl.algorithms.grpo``. Each file has zero
17+
internal branching on whether TQ is engaged; the example script
18+
chooses one or the other based on ``data_plane.enabled``.
2019
2120
Setup, helpers, and ``validate`` are re-imported from ``grpo``; only the
2221
training loop body is duplicated here so the per-step lifecycle hooks
2322
(register / seed-put / per-rank fetch / clear) can live in straight
2423
sequential code.
2524
2625
Parity with the legacy path is verified by running the same config
27-
against both entrypoints and diffing the wandb runs (Stage 5 of the
28-
data-plane integration plan).
26+
against both entrypoints and diffing the wandb runs.
2927
"""
3028

3129
from __future__ import annotations
@@ -261,8 +259,7 @@ def grpo_train_sync(
261259
# The actor owns the multi-turn rollout loop AND post-rollout
262260
# flatten / mask construction / prompt extraction / baseline-std /
263261
# TQ first-write. Bulk tensors stay actor-side until kv_batch_put;
264-
# driver receives only KVBatchMeta + small slice via Ray. See
265-
# research/data_plane_integration_plan.md §1.2.
262+
# driver receives only KVBatchMeta + small slice via Ray.
266263
rollout_actor = SyncRolloutActor.options(
267264
runtime_env=make_actor_runtime_env(
268265
"nemo_rl.experience.sync_rollout_actor.SyncRolloutActor"
@@ -414,7 +411,6 @@ def grpo_train_sync(
414411
# mask construction + prompt extraction + baseline/std,
415412
# writes bulk to TQ in one flat kv_batch_put, returns
416413
# only meta + small slice. Bulk never visits the driver.
417-
# See research/data_plane_integration_plan.md §1.2.
418414
dynamic_sampling_num_gen_batches += 1
419415
with timer.time("generation"):
420416
n_prompts = int(repeated_batch.size)
@@ -551,12 +547,12 @@ def grpo_train_sync(
551547

552548
print("▶ Computing logprobs...", flush=True)
553549
with timer.time("policy_and_reference_logprobs"):
554-
# Meta-driven worker dispatch (verl pattern). Workers
555-
# fetch their slice from TQ; logprob result is also
556-
# written back to TQ as ``prev_logprobs`` /
550+
# Meta-driven worker dispatch. Workers fetch their
551+
# slice from TQ; logprob result is also written back
552+
# to TQ as ``prev_logprobs`` /
557553
# ``reference_policy_logprobs`` columns under
558-
# ``meta.keys`` (worker write-back from PR-A.5) AND
559-
# returned to the driver via Ray for the next compute.
554+
# ``meta.keys`` AND returned to the driver via Ray
555+
# for the next compute.
560556
_prev_lp = policy.get_logprobs_from_meta(meta, timer=timer)
561557
prev_logprobs = _prev_lp["logprobs"]
562558

@@ -582,9 +578,8 @@ def grpo_train_sync(
582578
generation_logprobs = extras_bdd["generation_logprobs"]
583579
token_mask = extras_bdd["token_mask"]
584580

585-
# Thin BDD for the data-driven masking call. Mirrors
586-
# verl's ``_compute_old_log_prob`` pattern: take the
587-
# slice you need, transform, write delta back.
581+
# Thin BDD for the data-driven masking call: take
582+
# the slice you need, transform, write delta back.
588583
masking_data = BatchedDataDict[ClippedPGLossDataDict](
589584
{
590585
"token_mask": token_mask,
@@ -682,11 +677,10 @@ def grpo_train_sync(
682677
# Calibration needs input_ids + input_lengths +
683678
# multimodal fields. The actor wrote all of those
684679
# to TQ at rollout time; fetch them back as a
685-
# slice (driver-driven, data-driven — same shape
686-
# as verl's _compute_old_log_prob reshape: pull
687-
# what you compute against, transform, no need
688-
# to refetch the bulk schema). Logprob/mask/adv
689-
# columns added later are irrelevant here.
680+
# slice — pull what you compute against, transform,
681+
# no need to refetch the bulk schema. Logprob /
682+
# mask / adv columns added later are irrelevant
683+
# here.
690684
_calib_fields = [
691685
f for f in (meta.fields or [])
692686
if f not in (
@@ -957,9 +951,9 @@ def grpo_train_sync(
957951
if _log_input_ids is not None:
958952
log_data["token_ids"] = _log_input_ids.tolist()
959953
# NOTE: ``content`` (raw assistant text) is not stored in
960-
# TQ — the codec is tensor-only (Tier 1 of P3 in the
961-
# integration plan). When non-tensor logging matters,
962-
# plumb it through Ray return on rollout_to_tq's slice.
954+
# TQ — the codec is tensor-only. When non-tensor logging
955+
# matters, plumb it through Ray return on rollout_to_tq's
956+
# slice.
963957
logger.log_batched_dict_as_jsonl(
964958
log_data, f"train_data_step{total_steps + 1}.jsonl"
965959
)

0 commit comments

Comments (0)