@@ -97,12 +97,19 @@ def _apply_dynamic_sampling(
     max_gen_batches: int,
     dp_client: DataPlaneClient,
 ) -> tuple[
-    Optional[KVBatchMeta], Optional[_DSlice],
-    list[torch.Tensor], bool, dict[str, Any], Optional[torch.Tensor],
+    Optional[KVBatchMeta],
+    Optional[_DSlice],
+    list[torch.Tensor],
+    bool,
+    dict[str, Any],
+    Optional[torch.Tensor],
 ]:
-    """One iteration. Returns (pending_meta, pending_slice, pending_rewards,
+    """One iteration.
+
+    Returns (pending_meta, pending_slice, pending_rewards,
     is_complete, ds_metrics, unfiltered_for_log). When complete, the returned
-    pending_* IS the training batch."""
+    pending_* IS the training batch.
+    """
     # Cumulative unfiltered total_reward for legacy metrics["reward"]
     # parity. Reference-only append (no copy) — slice tensors are
     # produced fresh per iteration, not aliased to TQ-owned bulk.
@@ -145,7 +152,9 @@ def _apply_dynamic_sampling(
         )
         pending_meta = pending_meta.slice(0, train_prompts_size)
         pending_slice = pending_slice.slice(0, train_prompts_size)
-        ds_metrics["dynamic_sampling_num_discarded_valid_samples"] = n - train_prompts_size
+        ds_metrics["dynamic_sampling_num_discarded_valid_samples"] = (
+            n - train_prompts_size
+        )
 
         unfiltered_for_log = torch.cat(pending_unfiltered_rewards)[:train_prompts_size]
         return pending_meta, pending_slice, [], True, ds_metrics, unfiltered_for_log
@@ -404,9 +413,7 @@ def grpo_train_sync(
         # partition exists with the expected schema.
         policy.prepare_step(
             num_samples=int(repeated_batch.size),
-            group_size=master_config["grpo"][
-                "num_generations_per_prompt"
-            ],
+            group_size=master_config["grpo"]["num_generations_per_prompt"],
         )
 
         # ── Rollout 1-hop put: actor runs rollout + flatten +
@@ -462,11 +469,13 @@ def grpo_train_sync(
         # touched by any of these ops).
         with timer.time("reward_calculation"):
             slice_data = scale_rewards(
-                slice_data, master_config["grpo"]["reward_scaling"],
+                slice_data,
+                master_config["grpo"]["reward_scaling"],
             )
             if master_config["grpo"]["reward_shaping"]["enabled"]:
                 slice_data = apply_reward_shaping(
-                    slice_data, master_config["grpo"]["reward_shaping"],
+                    slice_data,
+                    master_config["grpo"]["reward_shaping"],
                 )
             if master_config["grpo"]["overlong_filtering"]:
                 lm = slice_data["loss_multiplier"].clone()
@@ -495,9 +504,11 @@ def grpo_train_sync(
                 * master_config["grpo"]["num_generations_per_prompt"]
             )
             (
-                pending_meta, pending_slice,
+                pending_meta,
+                pending_slice,
                 pending_unfiltered_rewards,
-                is_complete, ds_metrics,
+                is_complete,
+                ds_metrics,
                 unfiltered_rewards_for_logging,
             ) = _apply_dynamic_sampling(
                 meta=meta,
@@ -571,7 +582,8 @@ def grpo_train_sync(
571582 "skip_reference_policy_logprobs_calculation"
572583 ):
573584 _ref_lp = policy .get_reference_policy_logprobs_from_meta (
574- meta , timer = timer ,
585+ meta ,
586+ timer = timer ,
575587 )
576588 reference_policy_logprobs = _ref_lp ["reference_logprobs" ]
577589 else :
@@ -582,7 +594,8 @@ def grpo_train_sync(
         # output_ids, attention_mask, position_ids) stays in
         # TQ — workers will fetch it via ``train_presharded``.
         extras_bdd = read_columns(
-            policy._dp_client, meta,
+            policy._dp_client,
+            meta,
             select_fields=["generation_logprobs", "token_mask"],
             pad_value_dict=_pad_dict,
         )
@@ -658,7 +671,8 @@ def grpo_train_sync(
         # sample_mask under the same meta.keys so workers fetch
         # the union via train_presharded.
         write_columns(
-            policy._dp_client, meta,
+            policy._dp_client,
+            meta,
             fields={
                 "advantages": advantages,
                 "sample_mask": sample_mask,
@@ -696,20 +710,27 @@ def grpo_train_sync(
             # mask / adv columns added later are irrelevant
             # here.
             _calib_fields = [
-                f for f in (meta.fields or [])
-                if f not in (
-                    "generation_logprobs", "token_mask",
-                    "sample_mask", "prev_logprobs",
-                    "reference_policy_logprobs", "advantages",
+                f
+                for f in (meta.fields or [])
+                if f
+                not in (
+                    "generation_logprobs",
+                    "token_mask",
+                    "sample_mask",
+                    "prev_logprobs",
+                    "reference_policy_logprobs",
+                    "advantages",
                 )
             ]
             calibration_data = read_columns(
-                policy._dp_client, meta,
+                policy._dp_client,
+                meta,
                 select_fields=_calib_fields,
                 pad_value_dict=_pad_dict,
             )
             kv_scales_cache = policy.calibrate_qkv_fp8_scales(
-                calibration_data, include_q=True,
+                calibration_data,
+                include_q=True,
             )["layers"]
             POLICY_GENERATION_STALE = True
 
715736
@@ -726,15 +747,18 @@ def grpo_train_sync(
         if "content" in (meta.fields or []):
             _log_select.append("content")
         _log_extras = read_columns(
-            policy._dp_client, meta, select_fields=_log_select,
+            policy._dp_client,
+            meta,
+            select_fields=_log_select,
             pad_value_dict=_pad_dict,
         )
         _log_input_ids = _log_extras["input_ids"]
         _log_content = _log_extras.get("content")
 
         # ── Step-end TQ cleanup ────────────────────────────────
         policy._dp_client.kv_clear(
-            keys=meta.keys, partition_id=meta.partition_id,
+            keys=meta.keys,
+            partition_id=meta.partition_id,
         )
 
         is_last_step = total_steps + 1 >= max_num_steps
@@ -779,9 +803,7 @@ def grpo_train_sync(
 
         # advantages and token_mask are in scope from the
         # advantage / masking blocks above. No need to re-fetch.
-        response_advantages = torch.masked_select(
-            advantages, token_mask.bool()
-        )
+        response_advantages = torch.masked_select(advantages, token_mask.bool())
 
         memory_tracker.snapshot_start_of_stage("Metrics", dir())
         metrics = {
@@ -1033,7 +1055,9 @@ def grpo_train_sync(
             print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}")
             if master_config["grpo"]["use_dynamic_sampling"]:
                 print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}")
-                print(f" • Avg Total Reward: {np.mean(unfiltered_rewards.numpy()):.4f}")
+                print(
+                    f" • Avg Total Reward: {np.mean(unfiltered_rewards.numpy()):.4f}"
+                )
             else:
                 print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
             print(