
Commit 8aa690a

docs(data-plane): drop api-lifecycle doc; realistic concrete examples
- Remove ``nemo_rl/data_plane/docs/data-plane-api-lifecycle.md``. It duplicated content already in the README (API surface, key invariant, E2E lifecycle, call counts, perf characterization), and the verl comparison section was out of scope. README + ``data-plane-async-proposal.md`` are now the only data-plane docs.
- Rewrite ``Concrete examples`` for production scale:
  * Call-shapes section walks through a real step: ``num_prompts_per_step=128, num_generations_per_prompt=4``, DP world = 8, prompts ≈ 512 tok, responses ≤ 1024 tok. Shows the full sequence (prepare_step → rollout → reward + DS → logprob + advantage → train + finish_step → val path with carry_keys), with realistic meta sizes and explicit per-stage code.
  * Sequence-length walkthrough scaled to typical math/code rollout lengths (4 prompts × 2 gens, lengths 361-1445 tok, DP=4, length-balanced packing produces shards within ~25% of each other).
  * ``make_sequence_length_divisible_by`` gotcha updated to a real TP×CP=8 example.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent 8cc9f7d commit 8aa690a

2 files changed

Lines changed: 147 additions & 391 deletions


nemo_rl/data_plane/README.md

Lines changed: 147 additions & 50 deletions
@@ -191,91 +191,188 @@ Worker write-backs append new columns under the same keys.
 
 ### Call shapes
 
-**Rollout produces (one Ray RPC, bundles 6 steps — see `rollout_to_tq` docstring):**
+A real step at production scale —
+`num_prompts_per_step=128, num_generations_per_prompt=4`, DP world = 8,
+prompt ≈ 512 tok, response ≤ 1024 tok. Final batch is `128 × 4 = 512`
+rows.
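
The batch arithmetic above can be checked with a short sketch. It assumes one uid per prompt and keys of the form `<uid>_g<k>`, which is what the example key values in this README suggest; `build_keys` is a hypothetical helper and not part of the data-plane API.

```python
import uuid


def build_keys(uids, group_size):
    """Expand one uid per prompt into one key per generation (hypothetical helper)."""
    return [f"{uid}_g{g}" for uid in uids for g in range(group_size)]


num_prompts, group_size, dp_world = 128, 4, 8

# One uid per prompt; every generation of that prompt shares the uid prefix.
uids = [str(uuid.uuid4()) for _ in range(num_prompts)]
keys = build_keys(uids, group_size)

rows = len(keys)                  # 128 prompts x 4 generations
rows_per_rank = rows // dp_world  # even split across DP ranks
print(rows, rows_per_rank)
```

With these numbers each DP rank handles 64 rows, which matches the "~64 keys" per worker quoted in the step comments.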
+
+**1. Step prepare + rollout** (driver — `grpo_train_sync` body):
 
 ```python
-# In grpo_sync.py — train path; full driver_carry returned
-uids = [str(uuid.uuid4()) for _ in range(n_prompts)]
-(meta, driver_carry, rollout_metrics, gen_metrics) = ray.get(
+# Open the per-step TQ partition. Cleared and reused across steps.
+policy.prepare_step(num_samples=512, group_size=4)
+
+# One Ray RPC bundles: clear gen metrics → rollout → flatten + mask →
+# kv_first_write of bulk to TQ → finish_generation → metrics snapshot.
+# The actor handles 6 stages internally; the driver gets back the
+# meta handle + a small per-row tensor slice.
+n_prompts = repeated_batch.size  # 512 (= 128 prompts × 4 gens)
+uids = [str(uuid.uuid4()) for _ in range(n_prompts // 4)]  # 128 uids
+meta, driver_carry, rollout_metrics, gen_metrics = ray.get(
     rollout_actor.rollout_to_tq.remote(
         repeated_batch,
         uids=uids,
-        partition_id=policy.tq_partition_id,
+        partition_id=policy.tq_partition_id,  # "train"
         first_iter=(dynamic_sampling_num_gen_batches == 1),
     )
 )
-# meta.keys = ["<uid>_g0", "<uid>_g1", …]
-# meta.sequence_lengths = [<actual content lengths>]
-# meta.fields = ["input_ids", "input_lengths", "generation_logprobs",
-#                "token_mask", "sample_mask", …multimodal extras…]
-# driver_carry = BDD with per-row tensors the driver needs
+# meta.keys ≈ ["a3f9_g0", "a3f9_g1", "a3f9_g2", "a3f9_g3",
+#              "b7c1_g0", …] (512 keys)
+# meta.sequence_lengths ≈ [847, 612, 1503, 989, 711, …] (actual lens)
+# meta.fields = ["input_ids", "input_lengths",
+#                "generation_logprobs", "token_mask",
+#                "sample_mask", …multimodal extras…]
+# driver_carry : BatchedDataDict of per-row tensors
 #                (total_reward, loss_multiplier, truncated,
 #                length, input_lengths, prompt_ids_for_adv,
-#                response_token_lengths, GDPO components).
+#                response_token_lengths, GDPO components)
+```
 
-# In validate_sync — val only needs total_reward; pass carry_keys to
-# avoid wasting Ray transfer on the rest.
-(meta, driver_carry, rollout_metrics, _) = ray.get(
-    rollout_actor.rollout_to_tq.remote(
-        val_batch, uids=uids, partition_id="val",
-        carry_keys=["total_reward"],  # slim — returns 1-key BDD
+**2. Reward + dynamic sampling** (driver, on `driver_carry` only):
+
+```python
+driver_carry = scale_rewards(driver_carry, cfg["grpo"]["reward_scaling"])
+if cfg["grpo"]["reward_shaping"]["enabled"]:
+    driver_carry = apply_reward_shaping(driver_carry, cfg["grpo"]["reward_shaping"])
+driver_carry["baseline"], driver_carry["std"] = (
+    calculate_baseline_and_std_per_prompt(
+        driver_carry["prompt_ids_for_adv"],
+        driver_carry["total_reward"],
+        torch.ones_like(driver_carry["total_reward"]),
+        leave_one_out_baseline=cfg["grpo"]["use_leave_one_out_baseline"],
     )
 )
+# Mirror std/baseline onto meta so dynamic sampling can filter on
+# meta alone (no tensor fetch).
+meta.stamp_tags(
+    {
+        "std": driver_carry["std"].tolist(),
+        "baseline": driver_carry["baseline"].tolist(),
+    }
+)
+
+# DAPO non-zero-std filter — drops rows where the prompt's reward
+# variance is zero, kv_clears their bulk, accumulates survivors
+# across iterations until train_prompts_size (512) is reached.
+if cfg["grpo"]["use_dynamic_sampling"]:
+    pending_meta, pending_carry, *_ = _apply_dynamic_sampling(
+        meta=meta, driver_carry=driver_carry,
+        pending_meta=pending_meta, pending_carry=pending_carry,
+        train_prompts_size=512,
+        num_gen_batches=dynamic_sampling_num_gen_batches,
+        max_gen_batches=cfg["grpo"]["dynamic_sampling_max_gen_batches"],
+        dp_client=policy.dp_client,
+    )
 ```
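
The per-prompt statistics and the zero-variance filter in step 2 can be sketched in plain Python. This is an illustration only: `per_prompt_baseline_and_std` is a hypothetical stand-in that uses the group mean and the population standard deviation, and the real `calculate_baseline_and_std_per_prompt` may compute these differently (for example when `leave_one_out_baseline` is enabled).

```python
from collections import defaultdict
from statistics import mean, pstdev


def per_prompt_baseline_and_std(prompt_ids, rewards):
    """Group rewards by prompt, then broadcast each group's mean/std back to its rows."""
    groups = defaultdict(list)
    for pid, reward in zip(prompt_ids, rewards):
        groups[pid].append(reward)
    stats = {pid: (mean(rs), pstdev(rs)) for pid, rs in groups.items()}
    baseline = [stats[pid][0] for pid in prompt_ids]
    std = [stats[pid][1] for pid in prompt_ids]
    return baseline, std


# Two prompts x 4 generations. The second prompt gets identical rewards,
# so its group carries no learning signal.
prompt_ids = ["a3f9"] * 4 + ["b7c1"] * 4
rewards = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]

baseline, std = per_prompt_baseline_and_std(prompt_ids, rewards)

# Non-zero-std filter: keep only rows whose prompt group has reward variance.
keep = [i for i, s in enumerate(std) if s > 0.0]
print(keep)
```

Because `std` is a per-row list, it can be mirrored onto the metadata as a tag, which is what lets the filter run without fetching any tensors.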
 
-**Driver appends a column (small delta, no bulk crosses):**
+**3. Logprob + advantage + write-back**:
 
 ```python
-adv_inputs = policy.read_from_dataplane(meta,
-    select_fields=["token_logprobs", "rewards"])
-advantages = compute_advantages(adv_inputs)
-policy.write_to_dataplane(meta, {"advantages": advantages})
+# Worker fan-out happens inside these. Per-DP-rank shard via
+# shard_meta_for_dp(meta, dp_world=8, …); each worker fetches its
+# ~64 keys via kv_batch_get and writes back the result column under
+# the same keys on the leader.
+prev_lp = policy.get_logprobs_from_meta(meta, timer=timer)["logprobs"]
+ref_lp = policy.get_reference_policy_logprobs_from_meta(meta, timer=timer)
+ref_lp = ref_lp["reference_logprobs"]
+
+# Driver-side per-token columns for masking. Tiny delta — just two
+# fields × 512 rows.
+extras = policy.read_from_dataplane(
+    meta,
+    select_fields=["generation_logprobs", "token_mask"],
+    pad_value_dict=_pad_dict,
+)
+advantages = adv_estimator.compute_advantage(
+    prompt_ids=driver_carry["prompt_ids_for_adv"],
+    rewards=rewards, mask=mask,
+    repeated_batch=adv_inputs,
+    logprobs_policy=prev_lp,
+    logprobs_reference=ref_lp,
+)
+
+# Write the per-token advantage + post-masking sample_mask back to TQ
+# under meta.keys so workers fetch the unified view in train.
+policy.write_to_dataplane(
+    meta,
+    fields={"advantages": advantages, "sample_mask": sample_mask},
+)
 ```
 
-**Worker fan-out + step end:**
+**4. Train + cleanup**:
 
 ```python
 train_results = policy.train_from_meta(meta, loss_fn=loss_fn, timer=timer)
-# (shard_meta_for_dp + Ray fan-out + worker fetch / leader write-back
-# all happen inside the policy method — see E2E diagram above.)
+policy.finish_step(meta)  # drop step's bulk from TQ
+```
+
+**5. Validation path** — slim `driver_carry` to skip ~1 MB/batch:
 
-policy.finish_step(meta)  # drop step's bulk from TQ
+```python
+# inside validate_sync; val_batch_size ≈ 64
+policy.prepare_val_partition(n_prompts, partition_id="val")
+meta, driver_carry, rollout_metrics, _ = ray.get(
+    rollout_actor.rollout_to_tq.remote(
+        val_batch, uids=uids, partition_id="val",
+        finish_generation=False,  # keep inference state warm
+        task_to_env_override=val_task_to_env,
+        carry_keys=["total_reward"],  # only field val consumes
+    )
+)
+total_rewards.extend(driver_carry["total_reward"].tolist())
+mlog_cols = policy.read_from_dataplane(
+    meta, select_fields=["turn_roles", "turn_contents"],
+)
+policy.finish_step(meta)
 ```
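
What `carry_keys` buys on the validation path can be illustrated with a plain dictionary. This is a sketch under assumptions: `slim_carry` is a hypothetical helper, and how `rollout_to_tq` actually applies `carry_keys` is not shown in this README.

```python
def slim_carry(carry, carry_keys=None):
    """Return only the requested per-row columns (hypothetical helper).

    carry_keys=None keeps every column, mirroring the train path.
    """
    if carry_keys is None:
        return dict(carry)
    missing = [k for k in carry_keys if k not in carry]
    if missing:
        raise KeyError(f"unknown carry keys: {missing}")
    return {k: carry[k] for k in carry_keys}


# Toy per-row columns standing in for the full driver_carry.
full_carry = {
    "total_reward": [1.0, 0.0, 1.0],
    "loss_multiplier": [1.0, 1.0, 1.0],
    "input_lengths": [847, 612, 1503],
}

val_carry = slim_carry(full_carry, carry_keys=["total_reward"])
print(sorted(val_carry))
```

Only the selected column would cross the Ray boundary in this picture; the remaining columns stay with the actor.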
 
 ### Sequence-length flow (seqpack / dynbatch)
 
-How `meta.sequence_lengths` routes samples to DP ranks. Worked example:
-2 prompts × 2 generations = 4 samples.
+How `meta.sequence_lengths` routes samples to DP ranks. Worked example
+sized to one production microbatch — 4 prompts × 2 generations = 8
+samples, DP world = 4, lengths typical of math/code rollouts.
 
 ```
-# Rollout actor produces flat sequences (prompt + response per row):
-# input_lengths[i] = prompt_len_i + response_len_i.
-sample 0 (u0_g0): prompt=3, response=4 → input_lengths=7
-sample 1 (u0_g1): prompt=3, response=2 → input_lengths=5
-sample 2 (u1_g0): prompt=2, response=6 → input_lengths=8
-sample 3 (u1_g1): prompt=2, response=3 → input_lengths=5
+# Rollout actor flattens prompt + response per sample.
+# input_lengths[i] = prompt_len_i + response_len_i (actual content,
+# unpadded).
+sample 0 (a3f9_g0): prompt=312, response= 892 → input_lengths=1204
+sample 1 (a3f9_g1): prompt=312, response= 187 → input_lengths= 499
+sample 2 (b7c1_g0): prompt=421, response= 1024 → input_lengths=1445 ← long
+sample 3 (b7c1_g1): prompt=421, response= 455 → input_lengths= 876
+sample 4 (c0d8_g0): prompt=148, response= 213 → input_lengths= 361 ← short
+sample 5 (c0d8_g1): prompt=148, response= 339 → input_lengths= 487
+sample 6 (d2e1_g0): prompt=276, response= 651 → input_lengths= 927
+sample 7 (d2e1_g1): prompt=276, response= 402 → input_lengths= 678
 
 # kv_first_write returns meta row-aligned with keys:
-meta.keys = ["u0_g0", "u0_g1", "u1_g0", "u1_g1"]
-meta.sequence_lengths = [ 7, 5, 8, 5 ]
-
-# shard_meta_for_dp slices both keys and sequence_lengths with the
-# same idx_list — driver-side, no TQ I/O. With 2 DP ranks + seqpack:
-rank 0: idx=[2, 1] → shard.keys=["u1_g0","u0_g1"] lens=[8,5] (=13)
-rank 1: idx=[0, 3] → shard.keys=["u0_g0","u1_g1"] lens=[7,5] (=12)
-
-# Each worker then fetches its slice from TQ:
-data = self._fetch(shard)  # kv_batch_get(keys=shard.keys, …)
+meta.keys = ["a3f9_g0", "a3f9_g1", "b7c1_g0", "b7c1_g1",
+             "c0d8_g0", "c0d8_g1", "d2e1_g0", "d2e1_g1"]
+meta.sequence_lengths = [ 1204, 499, 1445, 876,
+                          361, 487, 927, 678 ]
+
+# shard_meta_for_dp slices keys + sequence_lengths with the SAME
+# idx_list — driver-side, no TQ I/O. Length-balanced via seqpack:
+rank 0: idx=[2, 4] → keys=["b7c1_g0","c0d8_g0"] lens=[1445, 361] = 1806
+rank 1: idx=[0, 5] → keys=["a3f9_g0","c0d8_g1"] lens=[1204, 487] = 1691
+rank 2: idx=[6, 1] → keys=["d2e1_g0","a3f9_g1"] lens=[ 927, 499] = 1426
+rank 3: idx=[3, 7] → keys=["b7c1_g1","d2e1_g1"] lens=[ 876, 678] = 1554
+# Σ packed lengths per rank within ~25% — well-balanced.
+
+# Each worker fetches its own ~64 keys per step from TQ:
+data = self._fetch(shard)  # kv_batch_get(shard.keys, select_fields=…)
 ```
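
One simple heuristic that reproduces the rank assignment in the walkthrough is greedy longest-first placement with an equal row count per rank. The sketch below is an assumption about the balancing idea, not the actual `shard_meta_for_dp` or seqpack implementation; `balance_by_length` is a hypothetical name.

```python
def balance_by_length(lengths, dp_world):
    """Greedy longest-first sharding with an equal number of rows per rank.

    Each sample, longest first, goes to the least-loaded rank that still
    has a free slot. Returns (index lists per rank, token totals per rank).
    """
    per_rank, remainder = divmod(len(lengths), dp_world)
    assert remainder == 0, "sketch assumes rows divide evenly across ranks"
    shards = [[] for _ in range(dp_world)]
    totals = [0] * dp_world
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    for i in order:
        open_ranks = [r for r in range(dp_world) if len(shards[r]) < per_rank]
        rank = min(open_ranks, key=lambda r: totals[r])
        shards[rank].append(i)
        totals[rank] += lengths[i]
    return shards, totals


keys = ["a3f9_g0", "a3f9_g1", "b7c1_g0", "b7c1_g1",
        "c0d8_g0", "c0d8_g1", "d2e1_g0", "d2e1_g1"]
sequence_lengths = [1204, 499, 1445, 876, 361, 487, 927, 678]

shards, totals = balance_by_length(sequence_lengths, dp_world=4)

# Keys and lengths are sliced with the same index list, so they stay row-aligned.
for rank, idx in enumerate(shards):
    print(rank, idx, [keys[i] for i in idx], totals[rank])
```

On the eight lengths above this yields index lists `[2, 4]`, `[0, 5]`, `[6, 1]`, `[3, 7]` with totals 1806, 1691, 1426 and 1554 tokens, the same split as in the walkthrough.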
 
-**Gotcha — `make_sequence_length_divisible_by`**: `input_ids` is padded
-to a TP×CP multiple, but `input_lengths` is the actual content length.
-Seqpack balances on actual lengths; padding is reapplied per shard.
+**Gotcha — `make_sequence_length_divisible_by` (TP×CP alignment)**:
+`input_ids` is padded to a multiple of TP×CP at write time (e.g. 8 for
+TP=4, CP=2), but `input_lengths` is the actual content length. Seqpack
+balances on actual lengths; padding is reapplied per shard.
 
 ```
-input_ids: [1,2,3,4,5,6,7, 0,0] # padded to 9 (mult of 4)
-input_lengths: 7 # actual
-meta.sequence_lengths: 7 # what seqpack uses ✓
+# row with input_lengths=1204, TP×CP=8 → input_ids padded to 1208:
+input_ids: [t0, t1, …, t1203, 0, 0, 0, 0] # 1208 elems
+input_lengths: 1204 # actual
+meta.sequence_lengths: 1204 # what seqpack uses ✓
 ```
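
The padding arithmetic in this gotcha can be verified with a few lines. The helpers `round_up` and `pad_row` are hypothetical names for illustration, and a pad id of 0 is assumed from the example row.

```python
def round_up(n, multiple):
    """Smallest multiple of `multiple` that is >= n (integer ceiling division)."""
    return -(-n // multiple) * multiple


def pad_row(tokens, multiple, pad_id=0):
    """Pad a token list to a multiple of `multiple`; also return the true length."""
    target = round_up(len(tokens), multiple)
    return tokens + [pad_id] * (target - len(tokens)), len(tokens)


tp, cp = 4, 2
tokens = list(range(1, 1205))  # 1204 non-pad token ids

input_ids, input_length = pad_row(tokens, multiple=tp * cp)

# The padded width is what gets stored; the true length is what balancing uses.
print(len(input_ids), input_length)
```

Balancing on `len(input_ids)` instead of `input_length` would count pad tokens as content, which is the mistake the gotcha warns about.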
 
 ---
