@@ -191,91 +191,188 @@ Worker write-backs append new columns under the same keys.
191191
192192### Call shapes
193193
194- ** Rollout produces (one Ray RPC, bundles 6 steps — see ` rollout_to_tq ` docstring):**
194+ A real step at production scale —
195+ ` num_prompts_per_step=128, num_generations_per_prompt=4 ` , DP world = 8,
196+ prompt ≈ 512 tok, response ≤ 1024 tok. Final batch is ` 128 × 4 = 512 `
197+ rows.
198+
199+ ** 1. Step prepare + rollout** (driver — ` grpo_train_sync ` body):
195200
196201``` python
197- # In grpo_sync.py — train path; full driver_carry returned
198- uids = [str (uuid.uuid4()) for _ in range (n_prompts)]
199- (meta, driver_carry, rollout_metrics, gen_metrics) = ray.get(
202+ # Open the per-step TQ partition. Cleared and reused across steps.
203+ policy.prepare_step(num_samples = 512 , group_size = 4 )
204+
205+ # One Ray RPC bundles: clear gen metrics → rollout → flatten + mask →
206+ # kv_first_write of bulk to TQ → finish_generation → metrics snapshot.
207+ # The actor handles 6 stages internally; the driver gets back the
208+ # meta handle + a small per-row tensor slice.
209+ n_prompts = repeated_batch.size # 512 (= 128 prompts × 4 gens)
210+ uids = [str (uuid.uuid4()) for _ in range (n_prompts // 4 )] # 128 uids
211+ meta, driver_carry, rollout_metrics, gen_metrics = ray.get(
200212 rollout_actor.rollout_to_tq.remote(
201213 repeated_batch,
202214 uids = uids,
203- partition_id = policy.tq_partition_id,
215+ partition_id = policy.tq_partition_id, # "train"
204216 first_iter = (dynamic_sampling_num_gen_batches == 1 ),
205217 )
206218)
207- # meta.keys = ["<uid>_g0", "<uid>_g1", …]
208- # meta.sequence_lengths = [<actual content lengths>]
209- # meta.fields = ["input_ids", "input_lengths", "generation_logprobs",
210- # "token_mask", "sample_mask", …multimodal extras…]
211- # driver_carry = BDD with per-row tensors the driver needs
219+ # meta.keys ≈ ["a3f9_g0", "a3f9_g1", "a3f9_g2", "a3f9_g3",
220+ # "b7c1_g0", …] (512 keys)
221+ # meta.sequence_lengths ≈ [847, 612, 1503, 989, 711, …] (actual lens)
222+ # meta.fields = ["input_ids", "input_lengths",
223+ # "generation_logprobs", "token_mask",
224+ # "sample_mask", …multimodal extras…]
225+ # driver_carry : BatchedDataDict of per-row tensors
212226# (total_reward, loss_multiplier, truncated,
213227# length, input_lengths, prompt_ids_for_adv,
214- # response_token_lengths, GDPO components).
228+ # response_token_lengths, GDPO components)
229+ ```
215230
216- # In validate_sync — val only needs total_reward; pass carry_keys to
217- # avoid wasting Ray transfer on the rest.
218- (meta, driver_carry, rollout_metrics, _) = ray.get(
219- rollout_actor.rollout_to_tq.remote(
220- val_batch, uids = uids, partition_id = " val" ,
221- carry_keys = [" total_reward" ], # slim — returns 1-key BDD
231+ ** 2. Reward + dynamic sampling** (driver, on ` driver_carry ` only):
232+
233+ ``` python
234+ driver_carry = scale_rewards(driver_carry, cfg[" grpo" ][" reward_scaling" ])
235+ if cfg[" grpo" ][" reward_shaping" ][" enabled" ]:
236+ driver_carry = apply_reward_shaping(driver_carry, cfg[" grpo" ][" reward_shaping" ])
237+ driver_carry[" baseline" ], driver_carry[" std" ] = (
238+ calculate_baseline_and_std_per_prompt(
239+ driver_carry[" prompt_ids_for_adv" ],
240+ driver_carry[" total_reward" ],
241+ torch.ones_like(driver_carry[" total_reward" ]),
242+ leave_one_out_baseline = cfg[" grpo" ][" use_leave_one_out_baseline" ],
222243 )
223244)
245+ # Mirror std/baseline onto meta so dynamic sampling can filter on
246+ # meta alone (no tensor fetch).
247+ meta.stamp_tags(
248+ {
249+ " std" : driver_carry[" std" ].tolist(),
250+ " baseline" : driver_carry[" baseline" ].tolist(),
251+ }
252+ )
253+
254+ # DAPO non-zero-std filter — drops rows where the prompt's reward
255+ # variance is zero, kv_clears their bulk, accumulates survivors
256+ # across iterations until train_prompts_size (512) is reached.
257+ if cfg[" grpo" ][" use_dynamic_sampling" ]:
258+ pending_meta, pending_carry, * _ = _apply_dynamic_sampling(
259+ meta = meta, driver_carry = driver_carry,
260+ pending_meta = pending_meta, pending_carry = pending_carry,
261+ train_prompts_size = 512 ,
262+ num_gen_batches = dynamic_sampling_num_gen_batches,
263+ max_gen_batches = cfg[" grpo" ][" dynamic_sampling_max_gen_batches" ],
264+ dp_client = policy.dp_client,
265+ )
224266```
225267
226- ** Driver appends a column (small delta, no bulk crosses): **
268+ ** 3. Logprob + advantage + write-back ** :
227269
228270``` python
229- adv_inputs = policy.read_from_dataplane(meta,
230- select_fields = [" token_logprobs" , " rewards" ])
231- advantages = compute_advantages(adv_inputs)
232- policy.write_to_dataplane(meta, {" advantages" : advantages})
271+ # Worker fan-out happens inside these. Per-DP-rank shard via
272+ # shard_meta_for_dp(meta, dp_world=8, …); each worker fetches its
273+ # ~64 keys via kv_batch_get and writes back the result column under
274+ # the same keys on the leader.
275+ prev_lp = policy.get_logprobs_from_meta(meta, timer = timer)[" logprobs" ]
276+ ref_lp = policy.get_reference_policy_logprobs_from_meta(meta, timer = timer)
277+ ref_lp = ref_lp[" reference_logprobs" ]
278+
279+ # Driver-side per-token columns for masking. Tiny delta — just two
280+ # fields × 512 rows.
281+ extras = policy.read_from_dataplane(
282+ meta,
283+ select_fields = [" generation_logprobs" , " token_mask" ],
284+ pad_value_dict = _pad_dict,
285+ )
286+ advantages = adv_estimator.compute_advantage(
287+ prompt_ids = driver_carry[" prompt_ids_for_adv" ],
288+ rewards = rewards, mask = mask,
289+ repeated_batch = adv_inputs,
290+ logprobs_policy = prev_lp,
291+ logprobs_reference = ref_lp,
292+ )
293+
294+ # Write the per-token advantage + post-masking sample_mask back to TQ
295+ # under meta.keys so workers fetch the unified view in train.
296+ policy.write_to_dataplane(
297+ meta,
298+ fields = {" advantages" : advantages, " sample_mask" : sample_mask},
299+ )
233300```
234301
235- ** Worker fan-out + step end: **
302+ ** 4. Train + cleanup ** :
236303
237304``` python
238305train_results = policy.train_from_meta(meta, loss_fn = loss_fn, timer = timer)
239- # (shard_meta_for_dp + Ray fan-out + worker fetch / leader write-back
240- # all happen inside the policy method — see E2E diagram above.)
306+ policy.finish_step(meta) # drop step's bulk from TQ
307+ ```
308+
309+ ** 5. Validation path** — slim ` driver_carry ` to skip ~ 1 MB/batch:
241310
242- policy.finish_step(meta) # drop step's bulk from TQ
311+ ``` python
312+ # inside validate_sync; val_batch_size ≈ 64
313+ policy.prepare_val_partition(n_prompts, partition_id = " val" )
314+ meta, driver_carry, rollout_metrics, _ = ray.get(
315+ rollout_actor.rollout_to_tq.remote(
316+ val_batch, uids = uids, partition_id = " val" ,
317+ finish_generation = False , # keep inference state warm
318+ task_to_env_override = val_task_to_env,
319+ carry_keys = [" total_reward" ], # only field val consumes
320+ )
321+ )
322+ total_rewards.extend(driver_carry[" total_reward" ].tolist())
323+ mlog_cols = policy.read_from_dataplane(
324+ meta, select_fields = [" turn_roles" , " turn_contents" ],
325+ )
326+ policy.finish_step(meta)
243327```
244328
245329### Sequence-length flow (seqpack / dynbatch)
246330
247- How ` meta.sequence_lengths ` routes samples to DP ranks. Worked example:
248- 2 prompts × 2 generations = 4 samples.
331+ How ` meta.sequence_lengths ` routes samples to DP ranks. Worked example
332+ sized to one production microbatch — 4 prompts × 2 generations = 8
333+ samples, DP world = 4, lengths typical of math/code rollouts.
249334
250335```
251- # Rollout actor produces flat sequences (prompt + response per row):
252- # input_lengths[i] = prompt_len_i + response_len_i.
253- sample 0 (u0_g0): prompt=3, response=4 → input_lengths=7
254- sample 1 (u0_g1): prompt=3, response=2 → input_lengths=5
255- sample 2 (u1_g0): prompt=2, response=6 → input_lengths=8
256- sample 3 (u1_g1): prompt=2, response=3 → input_lengths=5
336+ # Rollout actor flattens prompt + response per sample.
337+ # input_lengths[i] = prompt_len_i + response_len_i (actual content,
338+ # unpadded).
339+ sample 0 (a3f9_g0): prompt=312, response= 892 → input_lengths=1204
340+ sample 1 (a3f9_g1): prompt=312, response= 187 → input_lengths= 499
341+ sample 2 (b7c1_g0): prompt=421, response= 1024 → input_lengths=1445 ← long
342+ sample 3 (b7c1_g1): prompt=421, response= 455 → input_lengths= 876
343+ sample 4 (c0d8_g0): prompt=148, response= 213 → input_lengths= 361 ← short
344+ sample 5 (c0d8_g1): prompt=148, response= 339 → input_lengths= 487
345+ sample 6 (d2e1_g0): prompt=276, response= 651 → input_lengths= 927
346+ sample 7 (d2e1_g1): prompt=276, response= 402 → input_lengths= 678
257347
258348# kv_first_write returns meta row-aligned with keys:
259- meta.keys = ["u0_g0", "u0_g1", "u1_g0", "u1_g1"]
260- meta.sequence_lengths = [ 7, 5, 8, 5 ]
261-
262- # shard_meta_for_dp slices both keys and sequence_lengths with the
263- # same idx_list — driver-side, no TQ I/O. With 2 DP ranks + seqpack:
264- rank 0: idx=[2, 1] → shard.keys=["u1_g0","u0_g1"] lens=[8,5] (=13)
265- rank 1: idx=[0, 3] → shard.keys=["u0_g0","u1_g1"] lens=[7,5] (=12)
266-
267- # Each worker then fetches its slice from TQ:
268- data = self._fetch(shard) # kv_batch_get(keys=shard.keys, …)
349+ meta.keys = ["a3f9_g0", "a3f9_g1", "b7c1_g0", "b7c1_g1",
350+ "c0d8_g0", "c0d8_g1", "d2e1_g0", "d2e1_g1"]
351+ meta.sequence_lengths = [ 1204, 499, 1445, 876,
352+ 361, 487, 927, 678 ]
353+
354+ # shard_meta_for_dp slices keys + sequence_lengths with the SAME
355+ # idx_list — driver-side, no TQ I/O. Length-balanced via seqpack:
356+ rank 0: idx=[2, 4] → keys=["b7c1_g0","c0d8_g0"] lens=[1445, 361] = 1806
357+ rank 1: idx=[0, 5] → keys=["a3f9_g0","c0d8_g1"] lens=[1204, 487] = 1691
358+ rank 2: idx=[6, 1] → keys=["d2e1_g0","a3f9_g1"] lens=[ 927, 499] = 1426
359+ rank 3: idx=[3, 7] → keys=["b7c1_g1","d2e1_g1"] lens=[ 876, 678] = 1554
360+ # Σ packed lengths per rank within ~25% — well-balanced.
361+
362+ # Each worker fetches its own ~64 keys per step from TQ:
363+ data = self._fetch(shard) # kv_batch_get(shard.keys, select_fields=…)
269364```
270365
271- ** Gotcha — ` make_sequence_length_divisible_by ` ** : ` input_ids ` is padded
272- to a TP×CP multiple, but ` input_lengths ` is the actual content length.
273- Seqpack balances on actual lengths; padding is reapplied per shard.
366+ ** Gotcha — ` make_sequence_length_divisible_by ` (TP×CP alignment)** :
367+ ` input_ids ` is padded to a multiple of TP×CP at write time (e.g. 8 for
368+ TP=4, CP=2), but ` input_lengths ` is the actual content length. Seqpack
369+ balances on actual lengths; padding is reapplied per shard.
274370
275371```
276- input_ids: [1,2,3,4,5,6,7, 0,0] # padded to 9 (mult of 4)
277- input_lengths: 7 # actual
278- meta.sequence_lengths: 7 # what seqpack uses ✓
372+ # row with input_lengths=1204, TP×CP=8 → input_ids padded to 1208:
373+ input_ids: [t0, t1, …, t1203, 0, 0, 0, 0] # 1208 elems
374+ input_lengths: 1204 # actual
375+ meta.sequence_lengths: 1204 # what seqpack uses ✓
279376```
280377
281378---
0 commit comments