Commit 1ba8f43

feat(grpo-sync): equivalency fixes + content via TQ object column
Brings the TQ-mediated GRPO trainer (grpo_train_sync) into parity with the legacy grpo_train path:

* Fix DS print KeyError on repeated_batch['total_reward'] — use the cumulative unfiltered_rewards tracker instead.
* Wrap scale_rewards/shaping/overlong/baseline-std in timer.time("reward_calculation") to match legacy timing dashboards.
* Warn when calculate_advantages_on_gpu is set under TQ (no-op since the slice is CPU-side).
* Add per-step generation-side metric hooks (snapshot_step_metrics on the first DS iter, clear_logger_metrics before each rollout) inside SyncRolloutActor.rollout_to_tq via a new first_iter kwarg.
* Plumb GDPO reward components through the rollout slice so GDPOAdvantageEstimator and scale_rewards see them.
* Plumb assistant text content through TQ as an object column (verl-style np.ndarray(dtype=object) → pack_object_array → uint8 jagged nested tensor); driver fetches it pre-kv_clear alongside input_ids via read_columns and writes it into train_data_step{N}.jsonl.

Also refactors _apply_dynamic_sampling to use BatchedDataDict's select_indices / from_batches / slice methods rather than open-coded helpers — slice_data is now a BatchedDataDict end-to-end.

Verified: tests/data_plane/unit/ 102 passed / 1 xfailed (Slurm 11653849); GRPO 1B mcore + DS + TQ on simple backend 5/5 steps (Slurm 11653848); on mooncake_cpu 5/5 steps with raised mooncake defaults (Slurm 11654191).

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>

feat(data-plane): make mooncake segment/buffer Hydra-overridable, raise defaults

global_segment_size and local_buffer_size were hardcoded at 128 GiB / 16 GiB. Multi-iter DAPO with large message_log object payloads exhausts mooncake_cpu's internal allocator headroom at those sizes, manifesting as RuntimeError: batch_get_tensor returned None for '<idx>@input_ids' partway through training (verified failure JOBID 11653282 on the 1n8g GRPO 1B + DS + TQ + mooncake_cpu recipe).

Both knobs now read from cfg.get(...), defaults raised to 512 GiB and 64 GiB respectively. Override per-recipe via +data_plane.global_segment_size=<bytes> / +data_plane.local_buffer_size=<bytes>. Lazy mmap, so RSS stays bounded by actual traffic.

Verified at the new defaults: 1n8g GRPO 1B + DS + TQ + mooncake_cpu runs 5/5 steps (JOBID 11654191).

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>

test(data-plane): add object×backend coverage and mooncake load repro

Closes the gap that hid the mooncake_cpu under-sized-segment failure: previously the codec round-trip (test_codec_object.py) tested pack_object_array in-process, and the smoke round-trip (test_smoke_round_trip_backends) ran tensor-only fields against both backends. Object fields × mooncake_cpu was untested.

* tests/data_plane/functional/test_tq_lifecycle.py:
  - test_object_round_trip_backends: np.ndarray(dtype=object) put → get → decode equality, parametrized over simple + mooncake_cpu.
  - test_object_and_tensor_mixed_round_trip_backends: mixed schema (tensor + object on the same partition) — regression guard for co-fetch tensor/object decode in a single read_columns call.
* research/mooncake_object_repro.{py,sbatch}: standalone Slurm-runnable reproducer that hammers a backend with object-heavy puts/gets in isolation (no rollout, no policy). Two modes: --mode=load (N iters × M object fields, fresh partition per iter) and --mode=schema (single put, mixed tensor + object). Lets us narrow future storage-layer failures to a tiny artifact for upstream triage.

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
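The object-column path named in the first message (np.ndarray(dtype=object) packed into a uint8 jagged buffer) can be pictured with a small round-trip sketch. Everything in it is an assumption made for illustration: pack_objects and unpack_objects are hypothetical names, pickle is an assumed serializer, and a flat uint8 buffer plus an offsets array stands in for the jagged nested tensor. The real pack_object_array may work differently.

```python
import pickle

import numpy as np


def pack_objects(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Serialize each element to bytes; return (flat uint8 buffer, offsets).

    Hypothetical helper, not the repository's pack_object_array.
    """
    blobs = [pickle.dumps(x) for x in arr]
    offsets = np.zeros(len(blobs) + 1, dtype=np.int64)
    offsets[1:] = np.cumsum([len(b) for b in blobs])
    flat = np.frombuffer(b"".join(blobs), dtype=np.uint8)
    return flat, offsets


def unpack_objects(flat: np.ndarray, offsets: np.ndarray) -> np.ndarray:
    """Inverse of pack_objects: slice each row out of the buffer and decode it."""
    out = np.empty(len(offsets) - 1, dtype=object)
    for i in range(len(out)):
        out[i] = pickle.loads(flat[offsets[i] : offsets[i + 1]].tobytes())
    return out


# One variable-length string per sample, as in the "content" column.
content = np.asarray(["hello", "a longer reply", ""], dtype=object)
flat, offsets = pack_objects(content)
restored = unpack_objects(flat, offsets)
```

Rows of different lengths share one contiguous byte buffer, which is why an object column can travel through a tensor-only transport once it is packed.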
1 parent 9606176 commit 1ba8f43

5 files changed

Lines changed: 286 additions & 74 deletions

nemo_rl/algorithms/grpo_sync.py

Lines changed: 92 additions & 66 deletions
@@ -57,6 +57,7 @@
 from nemo_rl.algorithms.reward_functions import apply_reward_shaping
 from nemo_rl.algorithms.utils import (
     calculate_baseline_and_std_per_prompt,
+    get_gdpo_reward_component_keys,
     log_generation_metrics_to_wandb,
     print_performance_metrics,
 )
@@ -81,7 +82,7 @@
 # on std != 0, accumulate survivors across iterations, slice on overflow.
 # Bulk in TQ untouched except for kv_clear of dropped/discarded uids.

-_DSlice = dict[str, torch.Tensor]
+_DSlice = BatchedDataDict[Any]


 def _apply_dynamic_sampling(
@@ -116,21 +117,14 @@ def _apply_dynamic_sampling(
 # Subset this iteration's survivors and merge into the running cache.
 if keep_idx:
     km = meta.subset(keep_idx)
-    ks: _DSlice = {
-        k: (v[keep_idx] if isinstance(v, torch.Tensor) else v)
-        for k, v in slice_data.items()
-    }
+    ks = slice_data.select_indices(keep_idx)
     ks["filtered_reward"] = ks["total_reward"]
     if pending_meta is None:
         pending_meta, pending_slice = km, ks
     else:
         assert pending_slice is not None
         pending_meta = pending_meta.concat(km)
-        pending_slice = {
-            k: (torch.cat([pending_slice[k], ks[k]])
-                if isinstance(ks[k], torch.Tensor) else ks[k])
-            for k in ks
-        }
+        pending_slice = BatchedDataDict.from_batches([pending_slice, ks])

 n = len(pending_meta.keys) if pending_meta is not None else 0
 if n < train_prompts_size:
@@ -150,10 +144,7 @@ def _apply_dynamic_sampling(
         partition_id=pending_meta.partition_id,
     )
     pending_meta = pending_meta.slice(0, train_prompts_size)
-    pending_slice = {
-        k: (v[:train_prompts_size] if isinstance(v, torch.Tensor) else v)
-        for k, v in pending_slice.items()
-    }
+    pending_slice = pending_slice.slice(0, train_prompts_size)
     ds_metrics["dynamic_sampling_num_discarded_valid_samples"] = n - train_prompts_size

 unfiltered_for_log = torch.cat(pending_unfiltered_rewards)[:train_prompts_size]
@@ -255,6 +246,17 @@ def grpo_train_sync(
         "constructs it via the policy_factory when data_plane.enabled=True."
     )

+# TQ-resident tensors live on CPU; baseline/std are computed on the
+# slice without a CUDA hop. The flag is a no-op here — warn so users
+# don't expect it to do anything.
+if master_config["grpo"].get("calculate_advantages_on_gpu"):
+    warnings.warn(
+        "grpo.calculate_advantages_on_gpu has no effect when "
+        "data_plane.enabled=true; baseline/std are computed on CPU "
+        "because TQ-resident tensors are CPU-side.",
+        stacklevel=2,
+    )
+
 # ── Sync rollout actor (rollout 1-hop put) ──────────────────────
 # The actor owns the multi-turn rollout loop AND post-rollout
 # flatten / mask construction / prompt extraction / baseline-std /
@@ -312,7 +314,7 @@ def grpo_train_sync(
 # legacy ``metrics["reward"]`` semantics (cumulative unfiltered
 # total_reward across all contributing iterations).
 pending_meta = None
-pending_slice: Optional[dict[str, torch.Tensor]] = None
+pending_slice: Optional[_DSlice] = None
 pending_unfiltered_rewards: list[torch.Tensor] = []
 dynamic_sampling_num_gen_batches = 0

@@ -420,18 +422,26 @@ def grpo_train_sync(
 # extraction + baseline/std + kv_batch_put + finish
 # generation + logger metrics — all bundled into one
 # round-trip.
+# ``first_iter`` is the actor's signal to call
+# ``policy_generation.snapshot_step_metrics()``.
+# ``dynamic_sampling_num_gen_batches`` is incremented
+# to 1 just above before this branch — keep these in
+# sync if either is renamed.
 (
     meta,
-    slice_data,
+    slice_extras,
     rollout_metrics,
     generation_logger_metrics,
 ) = ray.get(
     rollout_actor.rollout_to_tq.remote(
         repeated_batch,
         uids=uids,
         partition_id=policy._tq_partition_id,
+        first_iter=(dynamic_sampling_num_gen_batches == 1),
     )
 )
+slice_data: _DSlice = BatchedDataDict[Any](slice_extras)
+del slice_extras

 if not _should_log_nemo_gym_responses(master_config):
     for key in list(rollout_metrics):
@@ -450,27 +460,28 @@ def grpo_train_sync(
 # used to be on the driver, were briefly on the actor,
 # now back on the driver where they belong (no bulk
 # touched by any of these ops).
-slice_data = scale_rewards(
-    slice_data, master_config["grpo"]["reward_scaling"],
-)
-if master_config["grpo"]["reward_shaping"]["enabled"]:
-    slice_data = apply_reward_shaping(
-        slice_data, master_config["grpo"]["reward_shaping"],
+with timer.time("reward_calculation"):
+    slice_data = scale_rewards(
+        slice_data, master_config["grpo"]["reward_scaling"],
     )
-if master_config["grpo"]["overlong_filtering"]:
-    lm = slice_data["loss_multiplier"].clone()
-    lm[slice_data["truncated"]] = 0
-    slice_data["loss_multiplier"] = lm
-slice_data["baseline"], slice_data["std"] = (
-    calculate_baseline_and_std_per_prompt(
-        slice_data["prompt_ids_for_adv"],
-        slice_data["total_reward"],
-        torch.ones_like(slice_data["total_reward"]),
-        leave_one_out_baseline=master_config["grpo"][
-            "use_leave_one_out_baseline"
-        ],
+    if master_config["grpo"]["reward_shaping"]["enabled"]:
+        slice_data = apply_reward_shaping(
+            slice_data, master_config["grpo"]["reward_shaping"],
+        )
+    if master_config["grpo"]["overlong_filtering"]:
+        lm = slice_data["loss_multiplier"].clone()
+        lm[slice_data["truncated"]] = 0
+        slice_data["loss_multiplier"] = lm
+    slice_data["baseline"], slice_data["std"] = (
+        calculate_baseline_and_std_per_prompt(
+            slice_data["prompt_ids_for_adv"],
+            slice_data["total_reward"],
+            torch.ones_like(slice_data["total_reward"]),
+            leave_one_out_baseline=master_config["grpo"][
+                "use_leave_one_out_baseline"
+            ],
+        )
     )
-)

 # ── Dynamic sampling (DAPO non-zero-std filter) ────────
 # Slice-only; bulk in TQ untouched except for kv_clear
@@ -609,18 +620,21 @@ def grpo_train_sync(
 mask = token_mask * sample_mask.unsqueeze(-1)

 # Thin slice-shaped repeated_batch for compute_advantage.
-# The estimator only reads scalar/per-sample fields
-# (total_reward, baseline, std) plus the optional
-# filtered_reward when dynamic_sampling is engaged
-# (rejected at the actor for now — see
-# SyncRolloutActor.rollout_to_tq).
+# GRPO and Reinforce++ estimators ignore repeated_batch
+# (swallowed via **kwargs); GDPO reads the per-component
+# reward keys discovered by get_gdpo_reward_component_keys.
+# The actor plumbs those keys into ``slice_data`` so the
+# thin BDD here is byte-equivalent to legacy passing the
+# full repeated_batch.
 rb_for_adv = BatchedDataDict[Any](
     {
         "total_reward": rewards,
         "baseline": baseline,
         "std": std,
     }
 )
+for k in get_gdpo_reward_component_keys(slice_data):
+    rb_for_adv[k] = slice_data[k]
 advantages = adv_estimator.compute_advantage(
     prompt_ids=prompt_ids_for_adv,
     rewards=rewards,
@@ -699,16 +713,24 @@ def grpo_train_sync(
 )["layers"]
 POLICY_GENERATION_STALE = True

-# Stash input_ids before kv_clear so the late log_data
-# jsonl block (which logs token_ids) can use it. The clear
-# below removes meta.keys from TQ, so any post-clear
-# read_columns on this meta would fail.
+# Stash input_ids and content before kv_clear so the
+# late log_data jsonl block can use them. The clear below
+# removes meta.keys from TQ, so any post-clear
+# read_columns on this meta would fail. ``content`` is a
+# decoded object array (list[str]); read_columns handles
+# decoding via meta.extra_info[META_OBJECT_FIELDS].
 _log_input_ids: Optional[torch.Tensor] = None
+_log_content: Optional[np.ndarray] = None
 if not _should_log_nemo_gym_responses(master_config):
-    _log_input_ids = read_columns(
-        policy._dp_client, meta, select_fields=["input_ids"],
+    _log_select = ["input_ids"]
+    if "content" in (meta.fields or []):
+        _log_select.append("content")
+    _log_extras = read_columns(
+        policy._dp_client, meta, select_fields=_log_select,
         pad_value_dict=_pad_dict,
-    )["input_ids"]
+    )
+    _log_input_ids = _log_extras["input_ids"]
+    _log_content = _log_extras.get("content")

 # ── Step-end TQ cleanup ────────────────────────────────
 policy._dp_client.kv_clear(
@@ -784,19 +806,19 @@ def grpo_train_sync(
     metrics.update(
         {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()}
     )
+# Cumulative unfiltered total_reward across all DS iterations
+# (sliced to train_prompts_size). Falls back to filtered
+# rewards if apply_dynamic_sampling didn't provide it
+# (mid-step path). Hoisted once for reuse in metrics, jsonl,
+# and the per-step print below.
+unfiltered_rewards = (
+    unfiltered_rewards_for_logging
+    if unfiltered_rewards_for_logging is not None
+    else rewards
+)
 if master_config["grpo"]["use_dynamic_sampling"]:
     metrics["filtered_reward"] = rewards.numpy()
-    # Cumulative unfiltered total_reward across all
-    # contributing iterations — matches legacy
-    # ``metrics["reward"]`` semantics (sliced to
-    # train_prompts_size). Falls back to filtered if
-    # apply_dynamic_sampling didn't provide it (e.g.
-    # mid-step path).
-    metrics["reward"] = (
-        unfiltered_rewards_for_logging.numpy()
-        if unfiltered_rewards_for_logging is not None
-        else rewards.numpy()
-    )
+    metrics["reward"] = unfiltered_rewards.numpy()

 metrics.update(train_results["all_mb_metrics"])
 metrics.update(gen_step_metrics)
@@ -937,7 +959,13 @@ def grpo_train_sync(
 log_data: dict = {}
 if "agent_ref" in repeated_batch:
     log_data["agent_ref"] = repeated_batch["agent_ref"]
-log_data["rewards"] = rewards.tolist()
+if master_config["grpo"]["use_dynamic_sampling"]:
+    # Legacy semantics: ``rewards`` is unfiltered total_reward,
+    # ``filtered_rewards`` is the kept slice that's trained on.
+    log_data["rewards"] = unfiltered_rewards.tolist()
+    log_data["filtered_rewards"] = rewards.tolist()
+else:
+    log_data["rewards"] = rewards.tolist()
 log_data["input_lengths"] = input_lengths.tolist()
 log_data["token_loss_mask"] = token_mask.tolist()
 log_data["sample_loss_mask"] = sample_mask.tolist()
@@ -950,10 +978,10 @@ def grpo_train_sync(
 # outer ``if not _should_log_nemo_gym_responses`` branch.
 if _log_input_ids is not None:
     log_data["token_ids"] = _log_input_ids.tolist()
-    # NOTE: ``content`` (raw assistant text) is not stored in
-    # TQ — the codec is tensor-only. When non-tensor logging
-    # matters, plumb it through Ray return on rollout_to_tq's
-    # slice.
+# ``content`` (raw assistant text) is fetched from TQ as
+# an object-array column above (stashed before kv_clear).
+if _log_content is not None:
+    log_data["content"] = _log_content.tolist()
 logger.log_batched_dict_as_jsonl(
     log_data, f"train_data_step{total_steps + 1}.jsonl"
 )
@@ -1005,9 +1033,7 @@ def grpo_train_sync(
 print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
 if master_config["grpo"]["use_dynamic_sampling"]:
     print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
-    print(
-        f" • Avg Total Reward: {np.mean(repeated_batch['total_reward'].numpy()):.4f}"
-    )
+    print(f" • Avg Total Reward: {np.mean(unfiltered_rewards.numpy()):.4f}")
 else:
     print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
 print(
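
The dynamic-sampling bookkeeping refactored in this file (filter on std != 0, accumulate survivors across iterations, slice on overflow) can be sketched without BatchedDataDict or TQ. This is a minimal sketch under stated assumptions: plain dicts of NumPy arrays stand in for the slice, and select_indices, concat, and accumulate are hypothetical helpers written for this example, not the BatchedDataDict methods used in the diff.

```python
from typing import Optional

import numpy as np

Slice = dict[str, np.ndarray]


def select_indices(batch: Slice, idx: list[int]) -> Slice:
    """Keep only the rows listed in idx (hypothetical stand-in)."""
    return {k: v[idx] for k, v in batch.items()}


def concat(batches: list[Slice]) -> Slice:
    """Concatenate slices row-wise (hypothetical stand-in)."""
    return {k: np.concatenate([b[k] for b in batches]) for k in batches[0]}


def accumulate(
    pending: Optional[Slice], new: Slice, target: int
) -> tuple[Optional[Slice], bool, int]:
    """Merge survivors of `new` into `pending`.

    Returns (pending, ready, num_discarded). `ready` is True once at least
    `target` rows survived; the overflow beyond `target` is dropped.
    """
    keep = [i for i, s in enumerate(new["std"]) if s != 0]
    if keep:
        ks = select_indices(new, keep)
        pending = ks if pending is None else concat([pending, ks])
    n = 0 if pending is None else len(pending["std"])
    if n < target:
        return pending, False, 0
    return {k: v[:target] for k, v in pending.items()}, True, n - target


it1 = {"total_reward": np.array([1.0, 0.0, 0.5]), "std": np.array([0.3, 0.0, 0.2])}
it2 = {"total_reward": np.array([0.2, 0.9]), "std": np.array([0.1, 0.4])}

pending1, ready1, dropped1 = accumulate(None, it1, target=3)
pending2, ready2, dropped2 = accumulate(pending1, it2, target=3)
```

The first iteration leaves two survivors, so the batch is not ready; the second brings the total to four, of which three are kept and one is counted as discarded.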

nemo_rl/data_plane/adapters/transfer_queue.py

Lines changed: 14 additions & 5 deletions
@@ -243,16 +243,25 @@ def _init_tq(cfg: DataPlaneConfig) -> None:
 # including this driver). _init_tq only needs local_ip below
 # for the metadata/master server URLs (driver-bound).
 local_ip = _get_local_node_ip()
+# Mooncake virtual segment / local buffer sizing. Defaults sized
+# for production-scale rollouts (multi-iter DAPO, large
+# message_log object payloads); under-sized values cause
+# ``batch_get_tensor returned None`` once mooncake exhausts its
+# internal allocator headroom. Lazy-mmap'd, so RSS is bounded
+# by actual traffic. Override per-recipe via
+# ``data_plane.global_segment_size`` /
+# ``data_plane.local_buffer_size`` (bytes).
 overlay = {
     **controller_overlay,
     "backend": {
         "storage_backend": "MooncakeStore",
         "MooncakeStore": {
-            # Sized to match data-plane-bench's proven config
-            # (32-node / 48-node tests). 4 GiB / 1 GiB defaults
-            # are too small for production-scale rollouts.
-            "global_segment_size": 128 * 1024**3,
-            "local_buffer_size": 16 * 1024**3,
+            "global_segment_size": int(
+                cfg.get("global_segment_size", 512 * 1024**3)
+            ),
+            "local_buffer_size": int(
+                cfg.get("local_buffer_size", 64 * 1024**3)
+            ),
             # _init_tq runs on the driver only — driver IS the
             # head, so local_ip here is also the head's IP that
             # mooncake_master + the metadata server bind to.
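
The sizing change above is a read-with-default pattern. The sketch below assumes a plain dict in place of DataPlaneConfig; mooncake_sizes is a hypothetical helper, and only the two key names and the 512 GiB / 64 GiB defaults come from the diff. The int() coercion presumably guards against overrides that arrive as strings or floats, which is an inference, not something the commit states.

```python
GIB = 1024**3


def mooncake_sizes(cfg: dict) -> dict[str, int]:
    """Resolve segment/buffer sizes in bytes, falling back to the new defaults."""
    return {
        "global_segment_size": int(cfg.get("global_segment_size", 512 * GIB)),
        "local_buffer_size": int(cfg.get("local_buffer_size", 64 * GIB)),
    }


# No overrides: both knobs take their defaults.
defaults = mooncake_sizes({})

# Override one knob back to the old 128 GiB value, passed as a string
# to show the int() coercion; the other knob keeps its default.
overridden = mooncake_sizes({"global_segment_size": "137438953472"})
```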

nemo_rl/experience/sync_rollout_actor.py

Lines changed: 34 additions & 0 deletions
@@ -176,6 +176,7 @@ def rollout_to_tq(
     *,
     uids: list[str],
     partition_id: str,
+    first_iter: bool = True,
 ) -> tuple[
     KVBatchMeta,
     dict[str, Any],
@@ -192,18 +193,40 @@ def rollout_to_tq(
 bulk-touching ops — flatten / mask / prompt extraction — that
 require ``message_log`` and would otherwise force bulk onto the
 driver.
+
+Args:
+    input_batch: Per-step prompt batch (already repeat-interleaved).
+    uids: One uid per prompt; bulk keys are ``f"{uid}_g{i}"``.
+    partition_id: TQ partition target.
+    first_iter: True on the first DS iteration of a step. Drives
+        ``policy_generation.snapshot_step_metrics()`` so per-step
+        generation metrics align with the legacy
+        ``grpo.grpo_train`` path. Driver passes
+        ``dynamic_sampling_num_gen_batches == 1``.
 """
 # Lazy imports — avoid pulling grpo into this module at load.
 from nemo_rl.algorithms.grpo import (
     _extract_prompt_only_messages,
     _should_use_async_rollouts,
     _should_use_nemo_gym,
 )
+from nemo_rl.algorithms.utils import get_gdpo_reward_component_keys
 from nemo_rl.data.llm_message_utils import (
     add_loss_mask_to_message_log,
     batched_message_log_to_flat_message,
 )

+# Per-step generation-side metric hooks: snapshot once on the
+# first DS iter so backends with per-step deltas have a stable
+# anchor; clear accumulators before every rollout. Mirrors
+# legacy ``grpo_train``.
+if self.policy_generation is not None:
+    if first_iter and hasattr(
+        self.policy_generation, "snapshot_step_metrics"
+    ):
+        self.policy_generation.snapshot_step_metrics()
+    self.policy_generation.clear_logger_metrics()
+
 cfg = self.master_config
 common = dict(
     policy_generation=self.policy_generation,
@@ -268,6 +291,11 @@ def rollout_to_tq(
 for k, v in flat.get_multimodal_dict(as_tensors=False).items():
     if isinstance(v, torch.Tensor):
         bulk_batch[k] = v
+# ``content`` (raw assistant text per sample) — rides TQ as an
+# object array so the driver can fetch it back at jsonl time
+# (kv_first_write packs it via pack_object_array).
+if "content" in flat:
+    bulk_batch["content"] = np.asarray(flat["content"], dtype=object)

 # Type-driven dispatch (verl pattern): producer-emitted type IS
 # the schema. torch.Tensor and np.ndarray(object) pass through;
@@ -302,6 +330,12 @@ def rollout_to_tq(
     "input_lengths": input_lengths,
     "prompt_ids_for_adv": prompt_flat["token_ids"],
 }
+# GDPO multi-reward components: scale_rewards iterates these
+# keys driver-side and the GDPO advantage estimator reads them
+# from rb_for_adv. Plumb them through the slice rather than
+# forcing a separate TQ fetch.
+for k in get_gdpo_reward_component_keys(fb):
+    slice_extras[k] = fb[k]

 meta = kv_first_write(
     bulk_batch, uids=uids,
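
The first_iter contract between driver and actor (snapshot once per step, clear before every rollout) can be exercised in isolation. This is a sketch under stated assumptions: FakeGeneration is a hypothetical stand-in for a generation backend, and pre_rollout_hooks is an illustrative free function that mirrors the block added to rollout_to_tq rather than the actor method itself.

```python
class FakeGeneration:
    """Hypothetical generation backend that records which hooks ran."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def snapshot_step_metrics(self) -> None:
        self.calls.append("snapshot")

    def clear_logger_metrics(self) -> None:
        self.calls.append("clear")


def pre_rollout_hooks(policy_generation, first_iter: bool) -> None:
    """Snapshot on the first DS iteration only; clear before every rollout."""
    if policy_generation is None:
        return
    if first_iter and hasattr(policy_generation, "snapshot_step_metrics"):
        policy_generation.snapshot_step_metrics()
    policy_generation.clear_logger_metrics()


gen = FakeGeneration()
num_gen_batches = 0
for _ in range(3):  # three dynamic-sampling iterations within one step
    num_gen_batches += 1  # driver increments before launching the rollout
    pre_rollout_hooks(gen, first_iter=(num_gen_batches == 1))
```

Because the driver increments its counter before the call, comparing against 1 selects exactly the first iteration, which is the coupling the driver-side comment asks maintainers to keep in sync.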
