From 49e6b55cbed52e68e92350aa2eac2be923a85b85 Mon Sep 17 00:00:00 2001 From: slikhite-1 Date: Wed, 4 Mar 2026 14:34:37 -0800 Subject: [PATCH 01/12] CISPO implementation Signed-off-by: slikhite-1 --- nemo_rl/algorithms/loss/loss_functions.py | 32 ++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 65127e02da..622e653df4 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -128,6 +128,10 @@ class ClippedPGLossConfig(BaseModel, extra="allow"): # NOTE: This should only be used when doing exactly one update per rollout # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size) force_on_policy_ratio: bool = False + # If True, add KL penalty to reward instead of loss (used by Reinforce++) + use_kl_in_reward: NotRequired[bool] + # If True, use CISPO (Clipped IS-weight Policy Optimization) from MiniMax-M1. 
+ use_cispo: NotRequired[bool] class ClippedPGLossDataDict(TypedDict): @@ -152,6 +156,7 @@ class ClippedPGLossFn(LossFunction): - GRPO - https://arxiv.org/abs/2402.03300 - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 - GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 + - CISPO (set use_cispo = True) - https://arxiv.org/abs/2506.13585 - Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout) Formula: @@ -171,6 +176,10 @@ class ClippedPGLossFn(LossFunction): For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) + Formula (CISPO): + L(θ) = E_t [ sg(clip(r_t(θ), 1-ε_low, 1+ε_high)) * A_t * log π_θ(a_t|s_t) ] + + Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which imposes an additional upper bound on the probability ratio when advantages are negative. This prevents excessive policy updates. $rA << 0$ -> $cA$(clipped) @@ -218,6 +227,23 @@ def __init__(self, cfg: ClippedPGLossConfig): "sequence-level importance sampling (e.g. 
GSPO) is mutually exclusive with token-level loss" ) + self.use_cispo = cfg.get("use_cispo", False) + if self.use_cispo: + assert not self.disable_ppo_ratio, ( + "use_cispo is incompatible with disable_ppo_ratio; " + "CISPO computes its own IS-weight-based policy gradient loss" + ) + assert not self.force_on_policy_ratio, ( + "use_cispo is incompatible with force_on_policy_ratio" + ) + assert not self.sequence_level_importance_ratios, ( + "use_cispo is incompatible with sequence_level_importance_ratios; " + "CISPO is a token-level loss function" + ) + assert self.ratio_clip_c is None, ( + "use_cispo is incompatible with ratio_clip_c; " + "ratio_clip_c is not supported when use_cispo=True" + ) if self.truncated_importance_sampling_type is not None: assert self.use_importance_sampling_correction, ( "truncated importance sampling is only supported when use_importance_sampling_correction is True" @@ -417,7 +443,11 @@ def __call__( # Determine which value to use for clipping (max for pessimistic estimate) clip_loss = torch.max(loss1, loss2) - + if self.use_cispo: + ratios_clamped = ratios.clamp( + 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max + ) + clip_loss = -advantages * ratios_clamped.detach() * curr_logprobs # Dual-clipping see https://arxiv.org/pdf/1912.09729 if self.ratio_clip_c is not None: assert self.ratio_clip_c > 1, ( From b3cc275d7ef974ea4931d9fcfbb35b3493f0a706 Mon Sep 17 00:00:00 2001 From: slikhite-1 Date: Wed, 11 Mar 2026 15:58:00 -0700 Subject: [PATCH 02/12] =?UTF-8?q?=E2=89=88test=20cases?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: slikhite-1 --- examples/configs/cispo_math_8B.yaml | 24 ++++++++ tests/unit/algorithms/test_loss_functions.py | 65 ++++++++++++++++++-- 2 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 examples/configs/cispo_math_8B.yaml diff --git a/examples/configs/cispo_math_8B.yaml b/examples/configs/cispo_math_8B.yaml new file mode 100644 index 
0000000000..34a53c20e1 --- /dev/null +++ b/examples/configs/cispo_math_8B.yaml @@ -0,0 +1,24 @@ +# CISPO Algorithm Configuration +# This algoritm implements the CISPO algorithm from MiniMax-M1 paper: https://arxiv.org/abs/2506.13585 +defaults: "grpo_math_1B.yaml" + + # ============================================================================ + # CISPO: Clipped IS-weight Policy Optimization + # CISPO clips the IS weight itself and applies stop-gradient, then multiplies by + # advantage and log-probability. + # ratio_clip_min / ratio_clip_max control the IS weight clipping bounds (ε_IS_low / ε_IS_high). + # The original paper sets ratio_clip_min to a large value (effectively no lower bound) and + # only tunes ratio_clip_max. Dual-clipping (ratio_clip_c) is ignored when use_cispo=True. + # See: https://arxiv.org/abs/2506.13585 + +loss_fn: + use_cispo: true + reference_policy_kl_penalty: 0.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + +checkpointing: + checkpoint_dir: "results/cispo" + +logger: + log_dir: "logs/cispo" \ No newline at end of file diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index e4ac4fab66..4a001a5f0e 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -623,11 +623,12 @@ def test_clipped_pg_loss_force_on_policy_ratio(): def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): - """Tests that force_on_policy_ratio ignores prev_logprobs from data and uses curr_logprobs instead. + """Tests that force_on_policy_ratio ignores prev_logprobs from data. - When force_on_policy_ratio=True, the loss function should use curr_logprobs.detach() - as prev_logprobs, so the actual prev_logprobs in data are irrelevant. This allows - skipping the expensive prev_logprobs computation upstream. 
+ When force_on_policy_ratio=True, the loss function should use + curr_logprobs.detach() as prev_logprobs, so the actual prev_logprobs in + data are irrelevant. This allows skipping the expensive prev_logprobs + computation upstream. """ if not torch.cuda.is_available(): pytest.skip("No GPU available") @@ -678,6 +679,62 @@ def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): assert metrics_1["probs_ratio"] == metrics_2["probs_ratio"] == 1.0 +def test_clipped_pg_loss_cispo(): + """Tests CISPO (Clipped IS-weight Policy Optimization) path in ClippedPGLossFn. + + Uses the same data pattern as test_clipped_pg_loss_ppo_clipping: ratios are + [0.5, 1.0, 1.5] and clamp to [0.8, 1.0, 1.2]. CISPO formula: + + L = -advantages * clip(ratio, 1-eps, 1+eps).detach() * curr_logprobs + + The IS weight is clipped and stop-gradiented; gradients flow only through + curr_logprobs. + """ + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_cispo"] = True + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target ratios: 0.5, 1.0, 1.5 -> after clip(0.2, 0.2): 0.8, 1.0, 1.2 + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand calculation: CISPO loss = -A * clip(r, 1-ε, 1+ε) * curr_lp (ratio stop-grad) --- + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] + ) # [0.8, 1.0, 1.2] + cispo_loss_per_token = -adv_masked * ratios_clamped * curr_lp_masked + expected_loss = 
torch.mean(cispo_loss_per_token) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) + + actual_loss, _ = loss_fn( + data=data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum( + data["sample_mask"].unsqueeze(-1) * data["token_mask"] + ), + **loss_input, + ) + torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"]) def test_calculate_kl(kl_type): """Tests KL calculations.""" From 32af27d96e0fa1f84b435874372f55b94d576451 Mon Sep 17 00:00:00 2001 From: slikhite-1 Date: Fri, 13 Mar 2026 14:29:14 -0700 Subject: [PATCH 03/12] docs Signed-off-by: slikhite-1 --- docs/guides/grpo.md | 14 ++++++++++++++ examples/configs/cispo_math_8B.yaml | 10 ++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 0cddc5c95d..0f1590d022 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -456,6 +456,20 @@ grpo: Set `overlong_filtering` to true when training on tasks where truncation at the maximum sequence length is expected, such as long-form reasoning or mathematical proofs. +#### CISPO (Clipped IS-weight Policy Optimization) + +CISPO introduced in [MiniMax-M1 paper](https://arxiv.org/abs/2506.13585) clips the importance sampling weight itself and applies stop-gradient. + +The loss is: + +$$ +L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \text{sg}\big(\text{clip}(r(\theta), 1-\varepsilon_{\text{low}}, 1+\varepsilon_{\text{high}})\big) \cdot A_t \cdot \log \pi_\theta(x) \Big] +$$ + +where $r(\theta) = \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$, $\text{sg}$ denotes stop-gradient, and $\varepsilon_{\text{low}}$, $\varepsilon_{\text{high}}$ are the IS-weight clipping bounds. 
Dual-clipping (`ratio_clip_c`) is ignored when CISPO is enabled. + +To use CISPO, set `loss_fn.use_cispo: true` in your config. Tune `ratio_clip_min` and `ratio_clip_max` (mapping to $\varepsilon_{\text{low}}$ and $\varepsilon_{\text{high}}$). It is recommended to use a large `ratio_clip_min` (e.g. 1.0) and tune `ratio_clip_max` (e.g. 0.8). Example: [examples/configs/cispo_math_8B.yaml](../../examples/configs/cispo_math_8B.yaml). + #### Top-p and top-k filtering The implementation aligns with vLLM’s top-p and top-k filtering by applying an equivalent process to the logits. diff --git a/examples/configs/cispo_math_8B.yaml b/examples/configs/cispo_math_8B.yaml index 34a53c20e1..e3e0ceb273 100644 --- a/examples/configs/cispo_math_8B.yaml +++ b/examples/configs/cispo_math_8B.yaml @@ -1,5 +1,4 @@ # CISPO Algorithm Configuration -# This algoritm implements the CISPO algorithm from MiniMax-M1 paper: https://arxiv.org/abs/2506.13585 defaults: "grpo_math_1B.yaml" # ============================================================================ @@ -14,9 +13,12 @@ defaults: "grpo_math_1B.yaml" loss_fn: use_cispo: true reference_policy_kl_penalty: 0.0 - ratio_clip_min: 0.2 - ratio_clip_max: 0.2 - + ratio_clip_min: 1.0 # set to very high + ratio_clip_max: 0.8 #tune as per experiments + token_level_loss: true + force_on_policy_ratio: false + ratio_clip_c: null + checkpointing: checkpoint_dir: "results/cispo" From 617bc93328b8a0c704e14e81ed684139a941bfa1 Mon Sep 17 00:00:00 2001 From: slikhite-1 Date: Wed, 1 Apr 2026 15:01:02 -0700 Subject: [PATCH 04/12] removed assertion Signed-off-by: slikhite-1 --- nemo_rl/algorithms/loss/loss_functions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 622e653df4..246625bf85 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -231,10 +231,6 @@ def __init__(self, cfg: 
ClippedPGLossConfig): if self.use_cispo: assert not self.disable_ppo_ratio, ( "use_cispo is incompatible with disable_ppo_ratio; " - "CISPO computes its own IS-weight-based policy gradient loss" - ) - assert not self.force_on_policy_ratio, ( - "use_cispo is incompatible with force_on_policy_ratio" ) assert not self.sequence_level_importance_ratios, ( "use_cispo is incompatible with sequence_level_importance_ratios; " @@ -446,7 +442,7 @@ def __call__( if self.use_cispo: ratios_clamped = ratios.clamp( 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max - ) + ) clip_loss = -advantages * ratios_clamped.detach() * curr_logprobs # Dual-clipping see https://arxiv.org/pdf/1912.09729 if self.ratio_clip_c is not None: From 96f9b2e70b74bad34d02526e62b4d71b4d39ef97 Mon Sep 17 00:00:00 2001 From: slikhite-1 Date: Wed, 1 Apr 2026 16:26:21 -0700 Subject: [PATCH 05/12] assert removed Signed-off-by: slikhite-1 --- docs/guides/grpo.md | 14 -------------- examples/configs/cispo_math_8B.yaml | 3 +-- nemo_rl/algorithms/loss/loss_functions.py | 3 +++ 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 0f1590d022..0cddc5c95d 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -456,20 +456,6 @@ grpo: Set `overlong_filtering` to true when training on tasks where truncation at the maximum sequence length is expected, such as long-form reasoning or mathematical proofs. -#### CISPO (Clipped IS-weight Policy Optimization) - -CISPO introduced in [MiniMax-M1 paper](https://arxiv.org/abs/2506.13585) clips the importance sampling weight itself and applies stop-gradient. 
- -The loss is: - -$$ -L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \text{sg}\big(\text{clip}(r(\theta), 1-\varepsilon_{\text{low}}, 1+\varepsilon_{\text{high}})\big) \cdot A_t \cdot \log \pi_\theta(x) \Big] -$$ - -where $r(\theta) = \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$, $\text{sg}$ denotes stop-gradient, and $\varepsilon_{\text{low}}$, $\varepsilon_{\text{high}}$ are the IS-weight clipping bounds. Dual-clipping (`ratio_clip_c`) is ignored when CISPO is enabled. - -To use CISPO, set `loss_fn.use_cispo: true` in your config. Tune `ratio_clip_min` and `ratio_clip_max` (mapping to $\varepsilon_{\text{low}}$ and $\varepsilon_{\text{high}}$). It is recommended to use a large `ratio_clip_min` (e.g. 1.0) and tune `ratio_clip_max` (e.g. 0.8). Example: [examples/configs/cispo_math_8B.yaml](../../examples/configs/cispo_math_8B.yaml). - #### Top-p and top-k filtering The implementation aligns with vLLM’s top-p and top-k filtering by applying an equivalent process to the logits. diff --git a/examples/configs/cispo_math_8B.yaml b/examples/configs/cispo_math_8B.yaml index e3e0ceb273..f5d315ca47 100644 --- a/examples/configs/cispo_math_8B.yaml +++ b/examples/configs/cispo_math_8B.yaml @@ -1,8 +1,7 @@ -# CISPO Algorithm Configuration defaults: "grpo_math_1B.yaml" # ============================================================================ - # CISPO: Clipped IS-weight Policy Optimization + # CISPO: Clipped Importance Sampling Policy Optimization # CISPO clips the IS weight itself and applies stop-gradient, then multiplies by # advantage and log-probability. # ratio_clip_min / ratio_clip_max control the IS weight clipping bounds (ε_IS_low / ε_IS_high). 
diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 246625bf85..3d280b7cb9 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -239,6 +239,9 @@ def __init__(self, cfg: ClippedPGLossConfig): assert self.ratio_clip_c is None, ( "use_cispo is incompatible with ratio_clip_c; " "ratio_clip_c is not supported when use_cispo=True" + if self.truncated_importance_sampling_ratio is not None: + assert self.use_importance_sampling_correction, ( + "truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True" ) if self.truncated_importance_sampling_type is not None: assert self.use_importance_sampling_correction, ( From bfa408f4b64caf09d2a580d41fb16bec2bda77a7 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Tue, 19 May 2026 03:30:23 +0000 Subject: [PATCH 06/12] initial fix of the previous PR, add many test cases now, and will remove / check later Signed-off-by: pengdurice --- examples/configs/cispo_math_8B.yaml | 25 ---- examples/configs/grpo_math_1B.yaml | 6 + ...qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml | 87 +++++++++++++ ...-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml | 77 +++++++++++ ...lica-qwen3-30ba3b-2n8g-megatron-cispo.yaml | 95 ++++++++++++++ ...plica-qwen3-30ba3b-2n8g-megatron-dapo.yaml | 88 +++++++++++++ ...plica-qwen3-30ba3b-2n8g-megatron-grpo.yaml | 98 ++++++++++++++ ...n2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml | 59 +++++++++ nemo_rl/algorithms/loss/loss_functions.py | 122 ++++++++++++++++-- ...b-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh | 37 ++++++ ...ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh | 42 ++++++ ...eplica-qwen3-30ba3b-2n8g-megatron-cispo.sh | 32 +++++ ...replica-qwen3-30ba3b-2n8g-megatron-dapo.sh | 32 +++++ ...replica-qwen3-30ba3b-2n8g-megatron-grpo.sh | 33 +++++ ...wen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh | 48 +++++++ tests/test_suites/nightly.txt | 3 + tests/unit/algorithms/test_loss_functions.py | 29 
++++- 17 files changed, 875 insertions(+), 38 deletions(-) delete mode 100644 examples/configs/cispo_math_8B.yaml create mode 100644 examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml create mode 100644 examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml create mode 100644 examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml create mode 100644 examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml create mode 100644 examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml create mode 100644 examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml create mode 100755 tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh create mode 100755 tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh create mode 100755 tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh create mode 100755 tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh create mode 100755 tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh create mode 100755 tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh diff --git a/examples/configs/cispo_math_8B.yaml b/examples/configs/cispo_math_8B.yaml deleted file mode 100644 index f5d315ca47..0000000000 --- a/examples/configs/cispo_math_8B.yaml +++ /dev/null @@ -1,25 +0,0 @@ -defaults: "grpo_math_1B.yaml" - - # ============================================================================ - # CISPO: Clipped Importance Sampling Policy Optimization - # CISPO clips the IS weight itself and applies stop-gradient, then multiplies by - # advantage and log-probability. - # ratio_clip_min / ratio_clip_max control the IS weight clipping bounds (ε_IS_low / ε_IS_high). 
- # The original paper sets ratio_clip_min to a large value (effectively no lower bound) and - # only tunes ratio_clip_max. Dual-clipping (ratio_clip_c) is ignored when use_cispo=True. - # See: https://arxiv.org/abs/2506.13585 - -loss_fn: - use_cispo: true - reference_policy_kl_penalty: 0.0 - ratio_clip_min: 1.0 # set to very high - ratio_clip_max: 0.8 #tune as per experiments - token_level_loss: true - force_on_policy_ratio: false - ratio_clip_c: null - -checkpointing: - checkpoint_dir: "results/cispo" - -logger: - log_dir: "logs/cispo" \ No newline at end of file diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4e2b8241f2..375c7b8685 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -71,6 +71,12 @@ loss_fn: token_level_loss: true force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt) use_kl_in_reward: false # Reinforce++: add KL penalty to reward instead of loss + use_cispo: false # CISPO (https://arxiv.org/abs/2506.13585): clipped IS-weight policy optimization + # Optional CISPO-style diagnostics. Cheap; works on GRPO/DAPO/CISPO arms. + # See ClippedPGLossConfig in nemo_rl/algorithms/loss/loss_functions.py. + cispo_diagnostics: false + cispo_diag_grpo_eps: 0.2 # baseline GRPO eps for would_clip_frac + cispo_diag_low_prob_threshold: 0.05 # proxy threshold for rare reflective tokens checkpointing: enabled: true diff --git a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml new file mode 100644 index 0000000000..75cf5b6c62 --- /dev/null +++ b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml @@ -0,0 +1,87 @@ +# A/B treatment arm: CISPO (Clipped IS-weight Policy Optimization, +# arXiv:2506.13585) on Qwen2.5-Math-1.5B-Instruct. 
+# +# CISPO replaces GRPO's hard PPO clip + advantage product with a stop-gradient +# clipped importance weight applied to the log-probability: +# +# L_CISPO = -A_t * sg(clip(r_t, 1 - eps_low, 1 + eps_high)) * log pi(a_t) +# +# Pair with: +# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml +# Everything except the loss-fn block is identical between the two arms. +# +# Off-policy regime (where CISPO and the hard PPO clip diverge most): +# * 32 prompts x 16 generations = 512 trajectories per step +# * train_global_batch_size = 128 -> 4 gradient updates per rollout +# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071) +# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the +# KL regularizer is not a confounder) +# +# NOT in the CISPO PR - this is a local research-validation artifact. +# (The PR ships the on-policy machinery-smoke recipe at +# examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml.) +defaults: ../../grpo_math_1B.yaml + +grpo: + max_num_steps: 100 + val_period: 10 + val_at_start: true + val_at_end: true + max_val_samples: 256 + val_batch_size: 256 + seed: 42 # matched-pair: identical RNG to GRPO arm + +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + train_global_batch_size: 128 # off-policy: 4 grad updates / rollout + train_micro_batch_size: 4 + logprob_batch_size: 4 + max_total_sequence_length: 1024 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 1024 + +data: + max_input_seq_length: 512 + +loss_fn: + # CISPO treatment arm. Paper-recommended clip: very loose lower (no + # effective lower clip), tighter upper. 
With nemo-rl's parameterisation + # (lower = 1 - ratio_clip_min, upper = 1 + ratio_clip_max): + # ratio_clip_min = 1.0 -> lower bound = 0.0 (ratios are positive, so this + # is effectively unclipped below) + # ratio_clip_max = 0.8 -> upper bound = 1.8 + use_cispo: true + reference_policy_kl_penalty: 0.0 # matched to the GRPO arm for fairness + reference_policy_kl_type: k3 + ratio_clip_min: 1.0 + ratio_clip_max: 0.8 + ratio_clip_c: null # dual clipping MUST be off for CISPO + token_level_loss: true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + force_on_policy_ratio: false + +checkpointing: + enabled: false # research run; skip checkpoint I/O + +logger: + log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: true + wandb: + project: nemo-rl + name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml new file mode 100644 index 0000000000..0f9f396114 --- /dev/null +++ b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml @@ -0,0 +1,77 @@ +# A/B baseline arm: vanilla GRPO with the standard hard PPO clip. +# +# This recipe is the *control* arm in a back-to-back A/B comparison meant to +# isolate the effect of swapping the hard PPO clip for CISPO's clipped IS- +# weight surrogate (arXiv:2506.13585). Pair with: +# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml +# Everything except the loss-fn block is identical between the two arms. 
+# +# Off-policy regime (where CISPO and the hard PPO clip diverge most): +# * 32 prompts x 16 generations = 512 trajectories per step +# * train_global_batch_size = 128 -> 4 gradient updates per rollout +# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071) +# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the +# KL regularizer is not a confounder) +# * token-level loss, sampling temperature inherited from base +# +# NOT in the CISPO PR - this is a local research-validation artifact. +defaults: ../../grpo_math_1B.yaml + +grpo: + max_num_steps: 100 + val_period: 10 + val_at_start: true + val_at_end: true + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + train_global_batch_size: 128 # off-policy: 4 grad updates / rollout + train_micro_batch_size: 4 + logprob_batch_size: 4 + max_total_sequence_length: 1024 + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 1024 + +data: + max_input_seq_length: 512 + +loss_fn: + # GRPO control arm: standard hard PPO clip at +/- 0.2. 
+ use_cispo: false + reference_policy_kl_penalty: 0.0 # matched to the CISPO arm for fairness + reference_policy_kl_type: k3 + ratio_clip_min: 0.2 # PPO clip lower bound = 1 - 0.2 = 0.8 + ratio_clip_max: 0.2 # PPO clip upper bound = 1 + 0.2 = 1.2 + ratio_clip_c: null + token_level_loss: true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + force_on_policy_ratio: false + +checkpointing: + enabled: false # research run; skip checkpoint I/O + +logger: + log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: true + wandb: + project: nemo-rl + name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml new file mode 100644 index 0000000000..99387ad07e --- /dev/null +++ b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml @@ -0,0 +1,95 @@ +# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), CISPO arm. +# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the +# loss_fn block differs across arms. +# +# CISPO (MiniMax-M1 §3.1) clips the IS weight as a stop-gradient +# coefficient instead of clipping the policy ratio. Gradients flow +# through log pi for *every* token, including the rare reflective +# tokens ("However", "Wait", "Recheck") that GRPO/DAPO would zero out. +# +# L_CISPO = -A * sg(clip(r, 1 - eps_low, 1 + eps_high)) * log pi(a) +# +# Per ms-swift's CISPO recipe and ScaleRL (arXiv:2510.13786), we use a +# very loose lower clip and a much wider upper clip (eps_high = 5.0). 
+defaults: ../../grpo_math_qwen30ba3b_megatron.yaml + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 200 + val_period: 20 + val_at_start: true + val_at_end: true + max_val_samples: 128 + val_batch_size: 128 + +policy: + model_name: Qwen/Qwen3-30B-A3B + train_global_batch_size: 128 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + sequence_packing: + enabled: true + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + megatron_cfg: + enabled: true + converter_type: LlamaForCausalLM + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: true + empty_unused_memory_level: 1 + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + optimizer: + lr: 3.0e-7 + min_lr: 3.0e-8 + scheduler: + lr_decay_iters: 500 + lr_warmup_iters: 10 + lr_warmup_init: 3.0e-8 + env_vars: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False + generation: + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + colocated: + enabled: true + +loss_fn: + # ---- CISPO arm ---- + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: k3 + ratio_clip_min: 1.0 # lower bound = 0; effectively unclipped + ratio_clip_max: 5.0 # eps_high = 5.0 (ms-swift / ScaleRL) + ratio_clip_c: null # dual clipping MUST be off for CISPO + token_level_loss: true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + force_on_policy_ratio: false + use_cispo: true + cispo_diagnostics: true + cispo_diag_grpo_eps: 0.2 # measure GRPO-equivalent clip rate + cispo_diag_low_prob_threshold: 0.05 + +checkpointing: + enabled: false + +logger: + log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: nemo-rl + name: 
cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo + +cluster: + gpus_per_node: 8 + num_nodes: 2 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml new file mode 100644 index 0000000000..1e0208b56c --- /dev/null +++ b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml @@ -0,0 +1,88 @@ +# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), DAPO arm. +# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the +# loss_fn block differs across arms. +# +# DAPO ("Clip-Higher", https://arxiv.org/abs/2503.14476): asymmetric clip +# with a tighter lower bound and a looser upper bound. +defaults: ../../grpo_math_qwen30ba3b_megatron.yaml + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 200 + val_period: 20 + val_at_start: true + val_at_end: true + max_val_samples: 128 + val_batch_size: 128 + +policy: + model_name: Qwen/Qwen3-30B-A3B + train_global_batch_size: 128 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + sequence_packing: + enabled: true + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + megatron_cfg: + enabled: true + converter_type: LlamaForCausalLM + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: true + empty_unused_memory_level: 1 + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + optimizer: + lr: 3.0e-7 + min_lr: 3.0e-8 + scheduler: + lr_decay_iters: 500 + lr_warmup_iters: 10 + lr_warmup_init: 3.0e-8 + env_vars: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False + generation: + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + colocated: + enabled: true + +loss_fn: + # ---- DAPO ("Clip-Higher") 
arm ---- + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: k3 + ratio_clip_min: 0.2 # eps_low - identical to GRPO + ratio_clip_max: 0.28 # eps_high - DAPO "Clip-Higher" + ratio_clip_c: null + token_level_loss: true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + force_on_policy_ratio: false + use_cispo: false + cispo_diagnostics: true + cispo_diag_grpo_eps: 0.2 + cispo_diag_low_prob_threshold: 0.05 + +checkpointing: + enabled: false + +logger: + log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: nemo-rl + name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo + +cluster: + gpus_per_node: 8 + num_nodes: 2 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml new file mode 100644 index 0000000000..d4fe6df370 --- /dev/null +++ b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml @@ -0,0 +1,98 @@ +# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), GRPO arm. +# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe: +# grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml +# Only the loss_fn block (and logger names) differs. +# +# Three-way A/B/C: this is the GRPO baseline; DAPO and CISPO are at +# cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml +# cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml +# Submit via cispo_mm1_replica.slurm with ARM=grpo|dapo|cispo. +# +# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=128 +# -> 4 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, +# token-level loss, sampling temperature 1.0. +# NOT in the PR; local research artifact. 
+defaults: ../../grpo_math_qwen30ba3b_megatron.yaml + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 200 + val_period: 20 + val_at_start: true + val_at_end: true + max_val_samples: 128 + val_batch_size: 128 + +policy: + model_name: Qwen/Qwen3-30B-A3B + train_global_batch_size: 128 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + sequence_packing: + enabled: true + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + megatron_cfg: + enabled: true + converter_type: LlamaForCausalLM + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: true + empty_unused_memory_level: 1 + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + optimizer: + lr: 3.0e-7 + min_lr: 3.0e-8 + scheduler: + lr_decay_iters: 500 + lr_warmup_iters: 10 + lr_warmup_init: 3.0e-8 + env_vars: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False + generation: + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + colocated: + enabled: true + +loss_fn: + # ---- GRPO baseline arm ---- (only the loss_fn block differs across arms) + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: k3 + ratio_clip_min: 0.2 # standard PPO/GRPO epsilon + ratio_clip_max: 0.2 + ratio_clip_c: null + token_level_loss: true + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + force_on_policy_ratio: false + use_cispo: false + # Shared CISPO-style diagnostics across all 3 arms so the GRPO baseline + # also reports grpo_would_clip_frac (the gap CISPO claims to close). 
+ cispo_diagnostics: true + cispo_diag_grpo_eps: 0.2 + cispo_diag_low_prob_threshold: 0.05 + +checkpointing: + enabled: false + +logger: + log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: nemo-rl + name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo + +cluster: + gpus_per_node: 8 + num_nodes: 2 diff --git a/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml new file mode 100644 index 0000000000..04bbc7a16c --- /dev/null +++ b/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml @@ -0,0 +1,59 @@ +# CISPO (Clipped IS-weight Policy Optimization) on Qwen2.5-Math-1.5B-Instruct. +# +# CISPO replaces GRPO's hard PPO clip + advantage product with a stop-gradient +# clipped importance weight applied to the log-probability: +# +# L_CISPO = -A_t * sg(clip(r_t, 1-eps_low, 1+eps_high)) * log pi_theta(a_t|s_t) +# +# - ratio_clip_min / ratio_clip_max are reused as eps_low / eps_high. +# - The paper (MiniMax-M1, arXiv:2506.13585) recommends a very loose lower +# bound (effectively no lower clip) and a tighter upper bound. We use +# ratio_clip_min=1.0 (lower bound = 1 - 1 = 0, i.e. no effective clipping +# below since ratios are positive) and ratio_clip_max=0.8 (upper bound = 1.8). +# - KL beta is set to 0 to match the paper. +# - Dual-clipping (ratio_clip_c) is incompatible with CISPO and asserted off. +# +# Mirrors the existing grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3 recipe +# for everything other than the loss formulation. 
+defaults: ../../grpo_math_1B.yaml + +grpo: + max_num_steps: 450 + +checkpointing: + checkpoint_dir: results/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 + +loss_fn: + use_cispo: true + reference_policy_kl_penalty: 0.0 + ratio_clip_min: 1.0 # eps_low: effectively no lower-bound clipping + ratio_clip_max: 0.8 # eps_high: upper bound = 1.8 + ratio_clip_c: null # dual clipping MUST be off for CISPO + +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 + +data: + max_input_seq_length: 512 + +logger: + log_dir: logs/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 + +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 3d280b7cb9..cdd3b11db2 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch @@ -128,10 +129,28 @@ class ClippedPGLossConfig(BaseModel, extra="allow"): # NOTE: This should only be used when doing exactly one update per rollout # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size) force_on_policy_ratio: bool = False - # If True, add KL penalty to reward instead of loss (used by Reinforce++) - use_kl_in_reward: NotRequired[bool] # If True, use CISPO (Clipped IS-weight Policy Optimization) from MiniMax-M1. 
- use_cispo: NotRequired[bool] + use_cispo: bool = False + # If True, log per-step CISPO diagnostic metrics that quantify *why* + # CISPO is (or isn't) helping vs hard-clipped GRPO/DAPO. Off by default + # because the percentile / boolean-mask reductions add a small amount + # of per-step overhead. See ClippedPGLossFn.__call__ for the full list. + # Useful in any arm (GRPO / DAPO / CISPO) - the same metrics let the + # GRPO baseline tell you how much gradient signal it's losing to its + # own hard clip, which is precisely the gap CISPO claims to close. + cispo_diagnostics: bool = False + # The hard-clip epsilon to use when computing the "what fraction of + # tokens would standard GRPO have zeroed the gradient on?" diagnostic. + # Defaults to 0.2 (the original PPO/GRPO value). Independent from + # ratio_clip_min/ratio_clip_max so we can probe the GRPO-equivalent + # behaviour even on a CISPO run with epsilon_high=5.0. + cispo_diag_grpo_eps: float = 0.2 + # Probability threshold under which a token is counted as "low-prob" + # (a coarse, tokenizer-free proxy for rare reflective tokens like + # "However", "Wait", "Recheck" - see MiniMax-M1 paper §3.1). + # Defaults to 0.05; we log the fraction of generated tokens whose + # behaviour-policy probability is below this. + cispo_diag_low_prob_threshold: float = 0.05 class ClippedPGLossDataDict(TypedDict): @@ -227,18 +246,28 @@ def __init__(self, cfg: ClippedPGLossConfig): "sequence-level importance sampling (e.g. 
GSPO) is mutually exclusive with token-level loss" ) - self.use_cispo = cfg.get("use_cispo", False) + self.use_cispo = cfg.use_cispo if self.use_cispo: assert not self.disable_ppo_ratio, ( "use_cispo is incompatible with disable_ppo_ratio; " - ) - assert not self.sequence_level_importance_ratios, ( - "use_cispo is incompatible with sequence_level_importance_ratios; " - "CISPO is a token-level loss function" + "CISPO needs the pi_theta/pi_theta_old ratio but disable_ppo_ratio removes it" ) assert self.ratio_clip_c is None, ( - "use_cispo is incompatible with ratio_clip_c; " - "ratio_clip_c is not supported when use_cispo=True" + "use_cispo is incompatible with dual clipping (ratio_clip_c); " + "the dual-clip block runs after the CISPO loss assembly and would " + "silently overwrite it. Set ratio_clip_c=null when use_cispo=True." + ) + # CISPO-style diagnostics. Off by default to avoid extra reductions. + self.cispo_diagnostics = cfg.cispo_diagnostics + self.cispo_diag_grpo_eps = cfg.cispo_diag_grpo_eps + self.cispo_diag_low_prob_threshold = cfg.cispo_diag_low_prob_threshold + assert self.cispo_diag_grpo_eps > 0, ( + f"cispo_diag_grpo_eps must be positive, got {self.cispo_diag_grpo_eps}" + ) + assert 0.0 < self.cispo_diag_low_prob_threshold < 1.0, ( + "cispo_diag_low_prob_threshold must be a probability in (0, 1), " + f"got {self.cispo_diag_low_prob_threshold}" + ) if self.truncated_importance_sampling_ratio is not None: assert self.use_importance_sampling_correction, ( "truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True" @@ -637,6 +666,78 @@ def __call__( probs_ratio_clamped_min = float("inf") probs_ratio_clamped_max = float("-inf") + # CISPO-style diagnostics. Designed to be tokenizer-free and cheap: + # all reductions are over the same (mask) we already build above. We + # log even on GRPO/DAPO arms so the gap CISPO claims to close can be + # *measured directly* on the baseline (the would_clip_frac). 
+ cispo_diag_metrics: dict[str, float] = {} + if self.cispo_diagnostics: + with torch.no_grad(): + eps = self.cispo_diag_grpo_eps + detached_ratios = ratios.detach() + # "Would standard GRPO have zeroed this token's gradient?" + # - positive-advantage tokens lose their gradient when r > 1+eps + # - negative-advantage tokens lose their gradient when r < 1-eps + adv_pos = (advantages > 0).float() + adv_neg = (advantages < 0).float() + would_clip_pos = adv_pos * (detached_ratios > 1.0 + eps).float() + would_clip_neg = adv_neg * (detached_ratios < 1.0 - eps).float() + grpo_would_clip_frac = masked_mean( + would_clip_pos + would_clip_neg, + mask, + global_normalization_factor=global_valid_toks, + ).item() + grpo_would_clip_pos_frac = masked_mean( + would_clip_pos, + mask, + global_normalization_factor=global_valid_toks, + ).item() + grpo_would_clip_neg_frac = masked_mean( + would_clip_neg, + mask, + global_normalization_factor=global_valid_toks, + ).item() + + if masked_ratios.numel() > 0: + r_t_p50 = torch.quantile(masked_ratios, 0.50).item() + r_t_p95 = torch.quantile(masked_ratios, 0.95).item() + r_t_p99 = torch.quantile(masked_ratios, 0.99).item() + else: + r_t_p50 = r_t_p95 = r_t_p99 = float("nan") + + # Coarse, tokenizer-free proxy for "rare reflective tokens": + # tokens whose behaviour-policy probability was below the + # threshold. CISPO's central claim is that these are exactly + # the tokens GRPO's hard clip throws away. + low_thr = math.log(self.cispo_diag_low_prob_threshold) + low_prob_token = (prev_logprobs < low_thr).float() + low_prob_token_frac = masked_mean( + low_prob_token, + mask, + global_normalization_factor=global_valid_toks, + ).item() + # Joint fraction of valid tokens that are both would-be-clipped + # and "low-prob". A high value relative to grpo_would_clip_frac + # is the smoking-gun confirmation of the paper's diagnosis. 
+ would_clip_and_low_prob = masked_mean( + (would_clip_pos + would_clip_neg) * low_prob_token, + mask, + global_normalization_factor=global_valid_toks, + ).item() + + cispo_diag_metrics = { + "cispo_diag/grpo_would_clip_frac": grpo_would_clip_frac, + "cispo_diag/grpo_would_clip_pos_frac": grpo_would_clip_pos_frac, + "cispo_diag/grpo_would_clip_neg_frac": grpo_would_clip_neg_frac, + "cispo_diag/r_t_p50": r_t_p50, + "cispo_diag/r_t_p95": r_t_p95, + "cispo_diag/r_t_p99": r_t_p99, + "cispo_diag/low_prob_token_frac": low_prob_token_frac, + "cispo_diag/would_clip_and_low_prob_frac": would_clip_and_low_prob, + "cispo_diag/grpo_eps": eps, + "cispo_diag/low_prob_threshold": self.cispo_diag_low_prob_threshold, + } + # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized # by either sequence or token count, depending on particular metric. # To get the true metric, you'll need to sum over the microbatch. @@ -659,6 +760,7 @@ def __call__( "num_valid_samples": sample_mask.sum().item(), "approx_entropy": seq_entropy_approx.item(), **_is_filter_metrics, + **cispo_diag_metrics, }, ) diff --git a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh new file mode 100755 index 0000000000..7210b9527b --- /dev/null +++ b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh @@ -0,0 +1,37 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# CISPO-vs-GRPO A/B: treatment arm (CISPO loss with paper-default clip). +# See examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml. +# NOT in the CISPO PR - local research-validation only. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=90 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Dump TB to JSON for offline A/B comparison. +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# NOTE: no `tests/check_metrics.py` thresholds here (see the grpo arm's .sh +# for rationale). diff --git a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh new file mode 100755 index 0000000000..7c3c0248a8 --- /dev/null +++ b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh @@ -0,0 +1,42 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# CISPO-vs-GRPO A/B: control arm (vanilla GRPO with hard PPO clip). +# See examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml. +# NOT in the CISPO PR - local research-validation only. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=90 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Dump TB to JSON so the A/B runs can be compared offline (e.g. with +# `python tests/json_dump_tb_logs.py --diff`). +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# NOTE: no `tests/check_metrics.py` thresholds here. This is a research +# A/B - the goal is to *compare* the two arms, not gate either of them +# against absolute numbers. Inspect train/reward, validation/reward, +# train/token_mult_prob_error, and train/probs_ratio_clamped_frac side by +# side (wandb group=cispo-ab, or feed both metrics.json files to your +# preferred diff tool). diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh new file mode 100755 index 0000000000..c8afbedb97 --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -0,0 +1,32 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 replication study, CISPO arm (2n8g). +# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml. +# NOT in the CISPO PR - local research artifact. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh new file mode 100755 index 0000000000..9d76f80ec3 --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh @@ -0,0 +1,32 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 replication study, DAPO ("Clip-Higher") arm (2n8g). +# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml. +# NOT in the CISPO PR - local research artifact. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh new file mode 100755 index 0000000000..c6ef91cdce --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh @@ -0,0 +1,33 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 replication study, GRPO baseline arm (2n8g sized to match the +# proven SAPO sister recipe). +# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml. +# NOT in the CISPO PR - local research artifact. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh new file mode 100755 index 0000000000..b2bbfe1fb5 --- /dev/null +++ b/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh @@ -0,0 +1,48 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# CISPO (Clipped IS-weight Policy Optimization, arXiv:2506.13585) on +# Qwen2.5-Math-1.5B-Instruct. Replaces GRPO's hard PPO clip with a +# stop-gradient clipped importance weight applied to the log-probability. +# See examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=450 +MAX_STEPS=450 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=120 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'data["train/token_mult_prob_error"]["450"] < 1.1' \ + 'mean(data["timing/train/total_step_time"], 2) < 25' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 0ceb4b13fd..e816031b73 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -56,6 +56,9 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh # Deepscaler (GSPO) tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh +# CISPO (Clipped IS-weight Policy Optimization, arXiv:2506.13585) +tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh + # HelpSteer3 tests # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571 # tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.sh diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 4a001a5f0e..ea22552bb2 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -679,6 +679,30 @@ def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): assert metrics_1["probs_ratio"] == metrics_2["probs_ratio"] == 1.0 +@pytest.mark.parametrize( + "incompatible_flag,value", + [ + ("disable_ppo_ratio", True), + ("ratio_clip_c", 3.0), + ], +) +def test_clipped_pg_loss_cispo_incompatibility_asserts(incompatible_flag, value): + """CISPO must reject configs that conflict with its semantics. + + - disable_ppo_ratio removes the pi_theta / pi_theta_old ratio that CISPO + uses as the importance weight, so they are mutually exclusive. + - ratio_clip_c (dual clipping) runs after the CISPO loss assembly inside + ClippedPGLossFn and would silently overwrite it. + """ + cfg = ClippedPGLossConfig( + reference_policy_kl_penalty=0.0, + use_cispo=True, + **{incompatible_flag: value}, + ) + with pytest.raises(AssertionError): + ClippedPGLossFn(cfg) + + def test_clipped_pg_loss_cispo(): """Tests CISPO (Clipped IS-weight Policy Optimization) path in ClippedPGLossFn. 
@@ -696,8 +720,7 @@ def test_clipped_pg_loss_cispo(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - cfg = deepcopy(basic_pg_loss_test_config) - cfg["use_cispo"] = True + cfg = ClippedPGLossConfig(reference_policy_kl_penalty=0.0, use_cispo=True) loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -713,7 +736,7 @@ def test_clipped_pg_loss_cispo(): # --- Hand calculation: CISPO loss = -A * clip(r, 1-ε, 1+ε) * curr_lp (ratio stop-grad) --- ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] ratios_clamped = torch.clamp( - ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] + ratios, 1.0 - cfg.ratio_clip_min, 1.0 + cfg.ratio_clip_max ) # [0.8, 1.0, 1.2] cispo_loss_per_token = -adv_masked * ratios_clamped * curr_lp_masked expected_loss = torch.mean(cispo_loss_per_token) From fa3b68a68657aaf567d462db49114a10c630be6a Mon Sep 17 00:00:00 2001 From: pengdurice Date: Wed, 20 May 2026 14:14:35 +0000 Subject: [PATCH 07/12] only include used tests and will do more clean up later Signed-off-by: pengdurice --- ...qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml | 87 ------------------ ...-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml | 77 ---------------- ...icy-qwen3-30ba3b-2n8g-megatron-cispo.yaml} | 8 +- ...licy-qwen3-30ba3b-2n8g-megatron-grpo.yaml} | 16 ++-- ...plica-qwen3-30ba3b-2n8g-megatron-dapo.yaml | 88 ------------------- ...n2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml | 59 ------------- ...b-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh | 37 -------- ...ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh | 42 --------- ...olicy-qwen3-30ba3b-2n8g-megatron-cispo.sh} | 4 +- ...policy-qwen3-30ba3b-2n8g-megatron-grpo.sh} | 4 +- ...replica-qwen3-30ba3b-2n8g-megatron-dapo.sh | 32 ------- ...wen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh | 48 ---------- 12 files changed, 16 insertions(+), 486 deletions(-) delete mode 100644 
examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml delete mode 100644 examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml rename examples/configs/recipes/llm/{cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml => cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml} (91%) rename examples/configs/recipes/llm/{cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml => cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml} (84%) delete mode 100644 examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml delete mode 100644 examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml delete mode 100755 tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh delete mode 100755 tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh rename tests/test_suites/llm/{cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh => cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh} (84%) rename tests/test_suites/llm/{cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh => cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh} (82%) delete mode 100755 tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh delete mode 100755 tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh diff --git a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml deleted file mode 100644 index 75cf5b6c62..0000000000 --- a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml +++ /dev/null @@ -1,87 +0,0 @@ -# A/B treatment arm: CISPO (Clipped IS-weight Policy Optimization, -# arXiv:2506.13585) on Qwen2.5-Math-1.5B-Instruct. 
-# -# CISPO replaces GRPO's hard PPO clip + advantage product with a stop-gradient -# clipped importance weight applied to the log-probability: -# -# L_CISPO = -A_t * sg(clip(r_t, 1 - eps_low, 1 + eps_high)) * log pi(a_t) -# -# Pair with: -# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml -# Everything except the loss-fn block is identical between the two arms. -# -# Off-policy regime (where CISPO and the hard PPO clip diverge most): -# * 32 prompts x 16 generations = 512 trajectories per step -# * train_global_batch_size = 128 -> 4 gradient updates per rollout -# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071) -# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the -# KL regularizer is not a confounder) -# -# NOT in the CISPO PR - this is a local research-validation artifact. -# (The PR ships the on-policy machinery-smoke recipe at -# examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml.) -defaults: ../../grpo_math_1B.yaml - -grpo: - max_num_steps: 100 - val_period: 10 - val_at_start: true - val_at_end: true - max_val_samples: 256 - val_batch_size: 256 - seed: 42 # matched-pair: identical RNG to GRPO arm - -policy: - model_name: Qwen/Qwen2.5-Math-1.5B-Instruct - tokenizer: - name: Qwen/Qwen2.5-Math-1.5B-Instruct - train_global_batch_size: 128 # off-policy: 4 grad updates / rollout - train_micro_batch_size: 4 - logprob_batch_size: 4 - max_total_sequence_length: 1024 - dynamic_batching: - enabled: true - sequence_packing: - enabled: false - make_sequence_length_divisible_by: 1 - generation: - max_new_tokens: 512 - vllm_cfg: - max_model_len: 1024 - -data: - max_input_seq_length: 512 - -loss_fn: - # CISPO treatment arm. Paper-recommended clip: very loose lower (no - # effective lower clip), tighter upper. 
With nemo-rl's parameterisation - # (lower = 1 - ratio_clip_min, upper = 1 + ratio_clip_max): - # ratio_clip_min = 1.0 -> lower bound = 0.0 (ratios are positive, so this - # is effectively unclipped below) - # ratio_clip_max = 0.8 -> upper bound = 1.8 - use_cispo: true - reference_policy_kl_penalty: 0.0 # matched to the GRPO arm for fairness - reference_policy_kl_type: k3 - ratio_clip_min: 1.0 - ratio_clip_max: 0.8 - ratio_clip_c: null # dual clipping MUST be off for CISPO - token_level_loss: true - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - force_on_policy_ratio: false - -checkpointing: - enabled: false # research run; skip checkpoint I/O - -logger: - log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: true - wandb: - project: nemo-rl - name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo - -cluster: - gpus_per_node: 8 - num_nodes: 1 diff --git a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml b/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml deleted file mode 100644 index 0f9f396114..0000000000 --- a/examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml +++ /dev/null @@ -1,77 +0,0 @@ -# A/B baseline arm: vanilla GRPO with the standard hard PPO clip. -# -# This recipe is the *control* arm in a back-to-back A/B comparison meant to -# isolate the effect of swapping the hard PPO clip for CISPO's clipped IS- -# weight surrogate (arXiv:2506.13585). Pair with: -# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml -# Everything except the loss-fn block is identical between the two arms. 
-# -# Off-policy regime (where CISPO and the hard PPO clip diverge most): -# * 32 prompts x 16 generations = 512 trajectories per step -# * train_global_batch_size = 128 -> 4 gradient updates per rollout -# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071) -# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the -# KL regularizer is not a confounder) -# * token-level loss, sampling temperature inherited from base -# -# NOT in the CISPO PR - this is a local research-validation artifact. -defaults: ../../grpo_math_1B.yaml - -grpo: - max_num_steps: 100 - val_period: 10 - val_at_start: true - val_at_end: true - max_val_samples: 256 - val_batch_size: 256 - seed: 42 - -policy: - model_name: Qwen/Qwen2.5-Math-1.5B-Instruct - tokenizer: - name: Qwen/Qwen2.5-Math-1.5B-Instruct - train_global_batch_size: 128 # off-policy: 4 grad updates / rollout - train_micro_batch_size: 4 - logprob_batch_size: 4 - max_total_sequence_length: 1024 - dynamic_batching: - enabled: true - sequence_packing: - enabled: false - make_sequence_length_divisible_by: 1 - generation: - max_new_tokens: 512 - vllm_cfg: - max_model_len: 1024 - -data: - max_input_seq_length: 512 - -loss_fn: - # GRPO control arm: standard hard PPO clip at +/- 0.2. 
- use_cispo: false - reference_policy_kl_penalty: 0.0 # matched to the CISPO arm for fairness - reference_policy_kl_type: k3 - ratio_clip_min: 0.2 # PPO clip lower bound = 1 - 0.2 = 0.8 - ratio_clip_max: 0.2 # PPO clip upper bound = 1 + 0.2 = 1.2 - ratio_clip_c: null - token_level_loss: true - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - force_on_policy_ratio: false - -checkpointing: - enabled: false # research run; skip checkpoint I/O - -logger: - log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: true - wandb: - project: nemo-rl - name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo - -cluster: - gpus_per_node: 8 - num_nodes: 1 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml similarity index 91% rename from examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml rename to examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml index 99387ad07e..4ad484ed18 100644 --- a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml +++ b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml @@ -1,4 +1,4 @@ -# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), CISPO arm. +# MiniMax-M1 high-off-policy study (https://arxiv.org/abs/2506.13585), CISPO arm. # Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the # loss_fn block differs across arms. 
# @@ -25,7 +25,7 @@ grpo: policy: model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 128 + train_global_batch_size: 32 train_micro_batch_size: 1 logprob_batch_size: 1 max_total_sequence_length: 4096 @@ -82,13 +82,13 @@ checkpointing: enabled: false logger: - log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo + log_dir: logs/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo wandb_enabled: true tensorboard_enabled: true monitor_gpus: false wandb: project: nemo-rl - name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo + name: cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo cluster: gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml similarity index 84% rename from examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml rename to examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml index d4fe6df370..6ba1491f58 100644 --- a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml +++ b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml @@ -1,15 +1,15 @@ -# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), GRPO arm. +# MiniMax-M1 high-off-policy study (https://arxiv.org/abs/2506.13585), GRPO arm. # Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe: # grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml # Only the loss_fn block (and logger names) differs. # # Three-way A/B/C: this is the GRPO baseline; DAPO and CISPO are at -# cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml -# cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml +# cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-dapo.yaml +# cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml # Submit via cispo_mm1_replica.slurm with ARM=grpo|dapo|cispo. 
# -# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=128 -# -> 4 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, +# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=32 +# -> 16 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, # token-level loss, sampling temperature 1.0. # NOT in the PR; local research artifact. defaults: ../../grpo_math_qwen30ba3b_megatron.yaml @@ -26,7 +26,7 @@ grpo: policy: model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 128 + train_global_batch_size: 32 train_micro_batch_size: 1 logprob_batch_size: 1 max_total_sequence_length: 4096 @@ -85,13 +85,13 @@ checkpointing: enabled: false logger: - log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo + log_dir: logs/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo wandb_enabled: true tensorboard_enabled: true monitor_gpus: false wandb: project: nemo-rl - name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo + name: cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo cluster: gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml b/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml deleted file mode 100644 index 1e0208b56c..0000000000 --- a/examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml +++ /dev/null @@ -1,88 +0,0 @@ -# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), DAPO arm. -# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the -# loss_fn block differs across arms. -# -# DAPO ("Clip-Higher", https://arxiv.org/abs/2503.14476): asymmetric clip -# with a tighter lower bound and a looser upper bound. 
-defaults: ../../grpo_math_qwen30ba3b_megatron.yaml - -grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_num_steps: 200 - val_period: 20 - val_at_start: true - val_at_end: true - max_val_samples: 128 - val_batch_size: 128 - -policy: - model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 128 - train_micro_batch_size: 1 - logprob_batch_size: 1 - max_total_sequence_length: 4096 - sequence_packing: - enabled: true - algorithm: modified_first_fit_decreasing - sequence_length_round: 64 - megatron_cfg: - enabled: true - converter_type: LlamaForCausalLM - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - expert_model_parallel_size: 8 - sequence_parallel: true - empty_unused_memory_level: 1 - freeze_moe_router: true - moe_router_dtype: fp64 - moe_router_load_balancing_type: none - moe_router_bias_update_rate: 0.0 - optimizer: - lr: 3.0e-7 - min_lr: 3.0e-8 - scheduler: - lr_decay_iters: 500 - lr_warmup_iters: 10 - lr_warmup_init: 3.0e-8 - env_vars: - PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False - generation: - vllm_cfg: - tensor_parallel_size: 4 - gpu_memory_utilization: 0.7 - enforce_eager: false - colocated: - enabled: true - -loss_fn: - # ---- DAPO ("Clip-Higher") arm ---- - reference_policy_kl_penalty: 0.0 - reference_policy_kl_type: k3 - ratio_clip_min: 0.2 # eps_low - identical to GRPO - ratio_clip_max: 0.28 # eps_high - DAPO "Clip-Higher" - ratio_clip_c: null - token_level_loss: true - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - force_on_policy_ratio: false - use_cispo: false - cispo_diagnostics: true - cispo_diag_grpo_eps: 0.2 - cispo_diag_low_prob_threshold: 0.05 - -checkpointing: - enabled: false - -logger: - log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: false - wandb: - project: nemo-rl - name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo - -cluster: - gpus_per_node: 8 - num_nodes: 2 
diff --git a/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml deleted file mode 100644 index 04bbc7a16c..0000000000 --- a/examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml +++ /dev/null @@ -1,59 +0,0 @@ -# CISPO (Clipped IS-weight Policy Optimization) on Qwen2.5-Math-1.5B-Instruct. -# -# CISPO replaces GRPO's hard PPO clip + advantage product with a stop-gradient -# clipped importance weight applied to the log-probability: -# -# L_CISPO = -A_t * sg(clip(r_t, 1-eps_low, 1+eps_high)) * log pi_theta(a_t|s_t) -# -# - ratio_clip_min / ratio_clip_max are reused as eps_low / eps_high. -# - The paper (MiniMax-M1, arXiv:2506.13585) recommends a very loose lower -# bound (effectively no lower clip) and a tighter upper bound. We use -# ratio_clip_min=1.0 (lower bound = 1 - 1 = 0, i.e. no effective clipping -# below since ratios are positive) and ratio_clip_max=0.8 (upper bound = 1.8). -# - KL beta is set to 0 to match the paper. -# - Dual-clipping (ratio_clip_c) is incompatible with CISPO and asserted off. -# -# Mirrors the existing grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3 recipe -# for everything other than the loss formulation. 
-defaults: ../../grpo_math_1B.yaml - -grpo: - max_num_steps: 450 - -checkpointing: - checkpoint_dir: results/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 - -loss_fn: - use_cispo: true - reference_policy_kl_penalty: 0.0 - ratio_clip_min: 1.0 # eps_low: effectively no lower-bound clipping - ratio_clip_max: 0.8 # eps_high: upper bound = 1.8 - ratio_clip_c: null # dual clipping MUST be off for CISPO - -policy: - model_name: Qwen/Qwen2.5-Math-1.5B-Instruct - tokenizer: - name: Qwen/Qwen2.5-Math-1.5B-Instruct - dynamic_batching: - enabled: true - sequence_packing: - enabled: false - make_sequence_length_divisible_by: 1 - generation: - max_new_tokens: 512 - vllm_cfg: - max_model_len: 512 - -data: - max_input_seq_length: 512 - -logger: - log_dir: logs/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 - wandb_enabled: true - tensorboard_enabled: true - wandb: - project: nemo-rl - name: cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 - -cluster: - gpus_per_node: 8 diff --git a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh deleted file mode 100755 index 7210b9527b..0000000000 --- a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# CISPO-vs-GRPO A/B: treatment arm (CISPO loss with paper-default clip). -# See examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml. -# NOT in the CISPO PR - local research-validation only. 
- -# ===== BEGIN CONFIG ===== -NUM_NODES=1 -STEPS_PER_RUN=100 -MAX_STEPS=100 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=90 -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.monitor_gpus=True \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -# Dump TB to JSON for offline A/B comparison. -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -# NOTE: no `tests/check_metrics.py` thresholds here (see the grpo arm's .sh -# for rationale). diff --git a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh b/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh deleted file mode 100755 index 7c3c0248a8..0000000000 --- a/tests/test_suites/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# CISPO-vs-GRPO A/B: control arm (vanilla GRPO with hard PPO clip). -# See examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml. -# NOT in the CISPO PR - local research-validation only. 
- -# ===== BEGIN CONFIG ===== -NUM_NODES=1 -STEPS_PER_RUN=100 -MAX_STEPS=100 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=90 -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.monitor_gpus=True \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -# Dump TB to JSON so the A/B runs can be compared offline (e.g. with -# `python tests/json_dump_tb_logs.py --diff`). -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -# NOTE: no `tests/check_metrics.py` thresholds here. This is a research -# A/B - the goal is to *compare* the two arms, not gate either of them -# against absolute numbers. Inspect train/reward, validation/reward, -# train/token_mult_prob_error, and train/probs_ratio_clamped_frac side by -# side (wandb group=cispo-ab, or feed both metrics.json files to your -# preferred diff tool). diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh similarity index 84% rename from tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh rename to tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh index c8afbedb97..a39ed495f6 100755 --- a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.sh +++ b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -2,8 +2,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) source $SCRIPT_DIR/common.env -# MiniMax-M1 replication study, CISPO arm (2n8g). 
-# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo.yaml. +# MiniMax-M1 high-off-policy study, CISPO arm (2n8g). +# See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. # NOT in the CISPO PR - local research artifact. # ===== BEGIN CONFIG ===== diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh similarity index 82% rename from tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh rename to tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh index c6ef91cdce..aba2134cfa 100755 --- a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.sh +++ b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh @@ -2,9 +2,9 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) source $SCRIPT_DIR/common.env -# MiniMax-M1 replication study, GRPO baseline arm (2n8g sized to match the +# MiniMax-M1 high-off-policy study, GRPO baseline arm (2n8g sized to match the # proven SAPO sister recipe). -# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-grpo.yaml. +# See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml. # NOT in the CISPO PR - local research artifact. # ===== BEGIN CONFIG ===== diff --git a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh b/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh deleted file mode 100755 index 9d76f80ec3..0000000000 --- a/tests/test_suites/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# MiniMax-M1 replication study, DAPO ("Clip-Higher") arm (2n8g). 
-# See examples/configs/recipes/llm/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo.yaml. -# NOT in the CISPO PR - local research artifact. - -# ===== BEGIN CONFIG ===== -NUM_NODES=2 -STEPS_PER_RUN=500 -MAX_STEPS=500 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=$((24 * 60)) -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh deleted file mode 100755 index b2bbfe1fb5..0000000000 --- a/tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# CISPO (Clipped IS-weight Policy Optimization, arXiv:2506.13585) on -# Qwen2.5-Math-1.5B-Instruct. Replaces GRPO's hard PPO clip with a -# stop-gradient clipped importance weight applied to the log-probability. -# See examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml. 
- -# ===== BEGIN CONFIG ===== -NUM_NODES=1 -STEPS_PER_RUN=450 -MAX_STEPS=450 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up -NUM_MINUTES=120 -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -# Run the experiment -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.monitor_gpus=True \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=True \ - checkpointing.checkpoint_dir=$CKPT_DIR \ - $@ \ - 2>&1 | tee $RUN_LOG - -# Convert tensorboard logs to json -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -# Only run metrics if the target step is reached -if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then - uv run tests/check_metrics.py $JSON_METRICS \ - 'median(data["train/token_mult_prob_error"]) < 1.1' \ - 'data["train/token_mult_prob_error"]["450"] < 1.1' \ - 'mean(data["timing/train/total_step_time"], 2) < 25' - - # Clean up checkpoint directory after successful run to save space. 
- rm -rf "$CKPT_DIR" -fi From 333759ee29a10c6fa08beb829fde4122eb3b7c0a Mon Sep 17 00:00:00 2001 From: pengdurice Date: Wed, 20 May 2026 21:38:09 +0000 Subject: [PATCH 08/12] Fix CISPO rebase cleanup issues Signed-off-by: pengdurice --- ...licy-qwen3-30ba3b-2n8g-megatron-cispo.yaml | 2 +- ...olicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml | 6 ++---- nemo_rl/algorithms/loss/loss_functions.py | 20 ++++++++----------- ...policy-qwen3-30ba3b-2n8g-megatron-cispo.sh | 5 ++--- ...fpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh | 5 ++--- tests/test_suites/nightly.txt | 3 --- tests/unit/algorithms/test_loss_functions.py | 16 ++++++++++----- 7 files changed, 26 insertions(+), 31 deletions(-) diff --git a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml index 4ad484ed18..bf3e896906 100644 --- a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml +++ b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml @@ -16,7 +16,7 @@ defaults: ../../grpo_math_qwen30ba3b_megatron.yaml grpo: num_prompts_per_step: 32 num_generations_per_prompt: 16 - max_num_steps: 200 + max_num_steps: 100 val_period: 20 val_at_start: true val_at_end: true diff --git a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml index 6ba1491f58..f1a19e9b69 100644 --- a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml +++ b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml @@ -3,10 +3,8 @@ # grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml # Only the loss_fn block (and logger names) differs. 
# -# Three-way A/B/C: this is the GRPO baseline; DAPO and CISPO are at -# cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-dapo.yaml +# Paired with: # cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml -# Submit via cispo_mm1_replica.slurm with ARM=grpo|dapo|cispo. # # Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=32 # -> 16 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, @@ -17,7 +15,7 @@ defaults: ../../grpo_math_qwen30ba3b_megatron.yaml grpo: num_prompts_per_step: 32 num_generations_per_prompt: 16 - max_num_steps: 200 + max_num_steps: 100 val_period: 20 val_at_start: true val_at_end: true diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index cdd3b11db2..b504324a8a 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -252,6 +252,14 @@ def __init__(self, cfg: ClippedPGLossConfig): "use_cispo is incompatible with disable_ppo_ratio; " "CISPO needs the pi_theta/pi_theta_old ratio but disable_ppo_ratio removes it" ) + assert not self.force_on_policy_ratio, ( + "use_cispo is incompatible with force_on_policy_ratio; " + "forcing ratio=1 removes the clipped IS-weight that CISPO optimizes" + ) + assert not self.sequence_level_importance_ratios, ( + "use_cispo is incompatible with sequence_level_importance_ratios; " + "CISPO uses token-level importance weights" + ) assert self.ratio_clip_c is None, ( "use_cispo is incompatible with dual clipping (ratio_clip_c); " "the dual-clip block runs after the CISPO loss assembly and would " @@ -698,13 +706,6 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - if masked_ratios.numel() > 0: - r_t_p50 = torch.quantile(masked_ratios, 0.50).item() - r_t_p95 = torch.quantile(masked_ratios, 0.95).item() - r_t_p99 = torch.quantile(masked_ratios, 0.99).item() - else: - r_t_p50 = r_t_p95 = r_t_p99 = float("nan") - # Coarse, tokenizer-free proxy for "rare 
reflective tokens": # tokens whose behaviour-policy probability was below the # threshold. CISPO's central claim is that these are exactly @@ -729,13 +730,8 @@ def __call__( "cispo_diag/grpo_would_clip_frac": grpo_would_clip_frac, "cispo_diag/grpo_would_clip_pos_frac": grpo_would_clip_pos_frac, "cispo_diag/grpo_would_clip_neg_frac": grpo_would_clip_neg_frac, - "cispo_diag/r_t_p50": r_t_p50, - "cispo_diag/r_t_p95": r_t_p95, - "cispo_diag/r_t_p99": r_t_p99, "cispo_diag/low_prob_token_frac": low_prob_token_frac, "cispo_diag/would_clip_and_low_prob_frac": would_clip_and_low_prob, - "cispo_diag/grpo_eps": eps, - "cispo_diag/low_prob_threshold": self.cispo_diag_low_prob_threshold, } # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized diff --git a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh index a39ed495f6..c10cb9237f 100755 --- a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh +++ b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -4,12 +4,11 @@ source $SCRIPT_DIR/common.env # MiniMax-M1 high-off-policy study, CISPO arm (2n8g). # See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. -# NOT in the CISPO PR - local research artifact. 
# ===== BEGIN CONFIG ===== NUM_NODES=2 -STEPS_PER_RUN=500 -MAX_STEPS=500 +STEPS_PER_RUN=100 +MAX_STEPS=100 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) NUM_MINUTES=$((24 * 60)) # ===== END CONFIG ===== diff --git a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh index aba2134cfa..c307c00189 100755 --- a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh +++ b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh @@ -5,12 +5,11 @@ source $SCRIPT_DIR/common.env # MiniMax-M1 high-off-policy study, GRPO baseline arm (2n8g sized to match the # proven SAPO sister recipe). # See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml. -# NOT in the CISPO PR - local research artifact. # ===== BEGIN CONFIG ===== NUM_NODES=2 -STEPS_PER_RUN=500 -MAX_STEPS=500 +STEPS_PER_RUN=100 +MAX_STEPS=100 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) NUM_MINUTES=$((24 * 60)) # ===== END CONFIG ===== diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index e816031b73..0ceb4b13fd 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -56,9 +56,6 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh # Deepscaler (GSPO) tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh -# CISPO (Clipped IS-weight Policy Optimization, arXiv:2506.13585) -tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh - # HelpSteer3 tests # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571 # tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-8n8g-fsdp2tp8cp4.sh diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index ea22552bb2..96475fca17 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ 
b/tests/unit/algorithms/test_loss_functions.py @@ -680,24 +680,30 @@ def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): @pytest.mark.parametrize( - "incompatible_flag,value", + "incompatible_config", [ - ("disable_ppo_ratio", True), - ("ratio_clip_c", 3.0), + {"disable_ppo_ratio": True}, + {"force_on_policy_ratio": True}, + {"ratio_clip_c": 3.0}, + {"sequence_level_importance_ratios": True, "token_level_loss": False}, ], ) -def test_clipped_pg_loss_cispo_incompatibility_asserts(incompatible_flag, value): +def test_clipped_pg_loss_cispo_incompatibility_asserts(incompatible_config): """CISPO must reject configs that conflict with its semantics. - disable_ppo_ratio removes the pi_theta / pi_theta_old ratio that CISPO uses as the importance weight, so they are mutually exclusive. + - force_on_policy_ratio makes every ratio 1.0, removing CISPO's clipped + importance-weight behavior. + - sequence_level_importance_ratios changes the token-level IS weights that + CISPO is defined over. - ratio_clip_c (dual clipping) runs after the CISPO loss assembly inside ClippedPGLossFn and would silently overwrite it. 
""" cfg = ClippedPGLossConfig( reference_policy_kl_penalty=0.0, use_cispo=True, - **{incompatible_flag: value}, + **incompatible_config, ) with pytest.raises(AssertionError): ClippedPGLossFn(cfg) From b08111256a7e99fdbb5d983f65ef76c3238e9717 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 21 May 2026 21:45:18 +0000 Subject: [PATCH 09/12] add async yaml and sh files Signed-off-by: pengdurice --- ...licy-qwen3-30ba3b-2n8g-megatron-cispo.yaml | 103 +++++++++++++++++ ...olicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml | 104 ++++++++++++++++++ ...policy-qwen3-30ba3b-2n8g-megatron-cispo.sh | 31 ++++++ ...fpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh | 32 ++++++ 4 files changed, 270 insertions(+) create mode 100644 examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml create mode 100644 examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml create mode 100755 tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh create mode 100755 tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh diff --git a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml new file mode 100644 index 0000000000..a0d248ed4f --- /dev/null +++ b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml @@ -0,0 +1,103 @@ +# MiniMax-M1 async lag-1 high-off-policy study (https://arxiv.org/abs/2506.13585), CISPO arm. +# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the +# loss_fn block differs across arms. +# +# CISPO (MiniMax-M1 §3.1) clips the IS weight as a stop-gradient +# coefficient instead of clipping the policy ratio. 
Gradients flow +# through log pi for *every* token, including the rare reflective +# tokens ("However", "Wait", "Recheck") that GRPO/DAPO would zero out. +# +# L_CISPO = -A * sg(clip(r, 1 - eps_low, 1 + eps_high)) * log pi(a) +# +# Per ms-swift's CISPO recipe and ScaleRL (arXiv:2510.13786), we use a +# very loose lower clip and a much wider upper clip (eps_high = 5.0). +defaults: ../../grpo_math_qwen30ba3b_megatron.yaml + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 100 + val_period: 20 + val_at_start: true + val_at_end: true + max_val_samples: 128 + val_batch_size: 128 + async_grpo: + enabled: true + max_trajectory_age_steps: 1 + in_flight_weight_updates: true + +policy: + model_name: Qwen/Qwen3-30B-A3B + train_global_batch_size: 32 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + sequence_packing: + enabled: true + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + megatron_cfg: + enabled: true + converter_type: LlamaForCausalLM + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: true + empty_unused_memory_level: 1 + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + optimizer: + lr: 3.0e-7 + min_lr: 3.0e-8 + scheduler: + lr_decay_iters: 500 + lr_warmup_iters: 10 + lr_warmup_init: 3.0e-8 + env_vars: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False + generation: + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + async_engine: true + colocated: + enabled: false + resources: + num_nodes: 1 + gpus_per_node: 8 + +loss_fn: + # ---- CISPO arm ---- + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: k3 + ratio_clip_min: 1.0 # lower bound = 0; effectively unclipped + ratio_clip_max: 5.0 # eps_high = 5.0 (ms-swift / ScaleRL) + ratio_clip_c: null # dual clipping MUST be off for CISPO 
+ token_level_loss: true + use_importance_sampling_correction: true + sequence_level_importance_ratios: false + force_on_policy_ratio: false + use_cispo: true + cispo_diagnostics: true + cispo_diag_grpo_eps: 0.2 # measure GRPO-equivalent clip rate + cispo_diag_low_prob_threshold: 0.05 + +checkpointing: + enabled: false + +logger: + log_dir: logs/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: nemo-rl + name: cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo + +cluster: + gpus_per_node: 8 + num_nodes: 3 diff --git a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml new file mode 100644 index 0000000000..d793048e2e --- /dev/null +++ b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml @@ -0,0 +1,104 @@ +# MiniMax-M1 async lag-1 high-off-policy study (https://arxiv.org/abs/2506.13585), GRPO arm. +# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe: +# grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml +# Only the loss_fn block (and logger names) differs. +# +# Paired with: +# cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml +# +# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=32 +# -> 16 gradient updates per rollout (4x the 4-update SAPO/GSPO Sec 5.1 setting). KL beta=0, +# token-level loss, sampling temperature 1.0. +# NOT in the PR; local research artifact. 
+defaults: ../../grpo_math_qwen30ba3b_megatron.yaml + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 100 + val_period: 20 + val_at_start: true + val_at_end: true + max_val_samples: 128 + val_batch_size: 128 + async_grpo: + enabled: true + max_trajectory_age_steps: 1 + in_flight_weight_updates: true + +policy: + model_name: Qwen/Qwen3-30B-A3B + train_global_batch_size: 32 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + sequence_packing: + enabled: true + algorithm: modified_first_fit_decreasing + sequence_length_round: 64 + megatron_cfg: + enabled: true + converter_type: LlamaForCausalLM + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: true + empty_unused_memory_level: 1 + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_load_balancing_type: none + moe_router_bias_update_rate: 0.0 + optimizer: + lr: 3.0e-7 + min_lr: 3.0e-8 + scheduler: + lr_decay_iters: 500 + lr_warmup_iters: 10 + lr_warmup_init: 3.0e-8 + env_vars: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False + generation: + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + async_engine: true + colocated: + enabled: false + resources: + num_nodes: 1 + gpus_per_node: 8 + +loss_fn: + # ---- GRPO baseline arm ---- (only the loss_fn block differs across arms) + reference_policy_kl_penalty: 0.0 + reference_policy_kl_type: k3 + ratio_clip_min: 0.2 # standard PPO/GRPO epsilon + ratio_clip_max: 0.2 + ratio_clip_c: null + token_level_loss: true + use_importance_sampling_correction: true + sequence_level_importance_ratios: false + force_on_policy_ratio: false + use_cispo: false + # Shared CISPO-style diagnostics across both arms so the GRPO baseline + # also reports grpo_would_clip_frac (the gap CISPO claims to close). 
+ cispo_diagnostics: true + cispo_diag_grpo_eps: 0.2 + cispo_diag_low_prob_threshold: 0.05 + +checkpointing: + enabled: false + +logger: + log_dir: logs/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: nemo-rl + name: cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo + +cluster: + gpus_per_node: 8 + num_nodes: 3 diff --git a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh new file mode 100755 index 0000000000..c638f4d46c --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -0,0 +1,31 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 async lag-1 high-off-policy study, CISPO arm (2n8g). +# See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh new file mode 100755 index 0000000000..b000ae00b0 --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh @@ -0,0 +1,32 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 async lag-1 high-off-policy study, GRPO baseline arm (2n8g sized to match the +# proven SAPO sister recipe). +# See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml. 
+ +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS From 9d44d976e20a31bdfe0a1aa9be30dd0e53a6cfdb Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 21 May 2026 23:28:14 +0000 Subject: [PATCH 10/12] clean up some sh and yaml files, add one nightly and to doc Signed-off-by: pengdurice --- docs/about/algorithms/grpo.md | 2 + ...olicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml | 104 ------------------ ...licy-qwen3-30ba3b-2n8g-megatron-cispo.yaml | 95 ---------------- ...olicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml | 96 ---------------- ...policy-qwen3-30ba3b-2n8g-megatron-cispo.sh | 5 +- ...fpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh | 32 ------ ...policy-qwen3-30ba3b-2n8g-megatron-cispo.sh | 31 ------ ...fpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh | 32 ------ tests/test_suites/nightly.txt | 3 + 9 files changed, 8 insertions(+), 392 deletions(-) delete mode 100644 examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml delete mode 100644 examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml delete mode 100644 examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml delete mode 100755 tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh delete mode 100755 tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh delete mode 100755 
tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh diff --git a/docs/about/algorithms/grpo.md b/docs/about/algorithms/grpo.md index 7cd2f65254..3e57af30f3 100644 --- a/docs/about/algorithms/grpo.md +++ b/docs/about/algorithms/grpo.md @@ -4,6 +4,8 @@ We provide a reference GRPO configuration for math benchmarks using the [OpenIns You can read about the details of the GRPO implementation [here](../../guides/grpo.md). +Related GRPO-family objectives are documented in [DAPO](dapo.md) and [CISPO](cispo.md). + ## GRPO Single Node To run GRPO on a single GPU for `Qwen/Qwen2.5-1.5B`: diff --git a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml deleted file mode 100644 index d793048e2e..0000000000 --- a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml +++ /dev/null @@ -1,104 +0,0 @@ -# MiniMax-M1 async lag-1 high-off-policy study (https://arxiv.org/abs/2506.13585), GRPO arm. -# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe: -# grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml -# Only the loss_fn block (and logger names) differs. -# -# Paired with: -# cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml -# -# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=32 -# -> 16 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, -# token-level loss, sampling temperature 1.0. -# NOT in the PR; local research artifact. 
-defaults: ../../grpo_math_qwen30ba3b_megatron.yaml - -grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_num_steps: 100 - val_period: 20 - val_at_start: true - val_at_end: true - max_val_samples: 128 - val_batch_size: 128 - async_grpo: - enabled: true - max_trajectory_age_steps: 1 - in_flight_weight_updates: true - -policy: - model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 32 - train_micro_batch_size: 1 - logprob_batch_size: 1 - max_total_sequence_length: 4096 - sequence_packing: - enabled: true - algorithm: modified_first_fit_decreasing - sequence_length_round: 64 - megatron_cfg: - enabled: true - converter_type: LlamaForCausalLM - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - expert_model_parallel_size: 8 - sequence_parallel: true - empty_unused_memory_level: 1 - freeze_moe_router: true - moe_router_dtype: fp64 - moe_router_load_balancing_type: none - moe_router_bias_update_rate: 0.0 - optimizer: - lr: 3.0e-7 - min_lr: 3.0e-8 - scheduler: - lr_decay_iters: 500 - lr_warmup_iters: 10 - lr_warmup_init: 3.0e-8 - env_vars: - PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False - generation: - vllm_cfg: - tensor_parallel_size: 4 - gpu_memory_utilization: 0.7 - enforce_eager: false - async_engine: true - colocated: - enabled: false - resources: - num_nodes: 1 - gpus_per_node: 8 - -loss_fn: - # ---- GRPO baseline arm ---- (only the loss_fn block differs across arms) - reference_policy_kl_penalty: 0.0 - reference_policy_kl_type: k3 - ratio_clip_min: 0.2 # standard PPO/GRPO epsilon - ratio_clip_max: 0.2 - ratio_clip_c: null - token_level_loss: true - use_importance_sampling_correction: true - sequence_level_importance_ratios: false - force_on_policy_ratio: false - use_cispo: false - # Shared CISPO-style diagnostics across both arms so the GRPO baseline - # also reports grpo_would_clip_frac (the gap CISPO claims to close). 
- cispo_diagnostics: true - cispo_diag_grpo_eps: 0.2 - cispo_diag_low_prob_threshold: 0.05 - -checkpointing: - enabled: false - -logger: - log_dir: logs/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: false - wandb: - project: nemo-rl - name: cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo - -cluster: - gpus_per_node: 8 - num_nodes: 3 diff --git a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml deleted file mode 100644 index bf3e896906..0000000000 --- a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml +++ /dev/null @@ -1,95 +0,0 @@ -# MiniMax-M1 high-off-policy study (https://arxiv.org/abs/2506.13585), CISPO arm. -# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the -# loss_fn block differs across arms. -# -# CISPO (MiniMax-M1 §3.1) clips the IS weight as a stop-gradient -# coefficient instead of clipping the policy ratio. Gradients flow -# through log pi for *every* token, including the rare reflective -# tokens ("However", "Wait", "Recheck") that GRPO/DAPO would zero out. -# -# L_CISPO = -A * sg(clip(r, 1 - eps_low, 1 + eps_high)) * log pi(a) -# -# Per ms-swift's CISPO recipe and ScaleRL (arXiv:2510.13786), we use a -# very loose lower clip and a much wider upper clip (eps_high = 5.0). 
-defaults: ../../grpo_math_qwen30ba3b_megatron.yaml - -grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_num_steps: 100 - val_period: 20 - val_at_start: true - val_at_end: true - max_val_samples: 128 - val_batch_size: 128 - -policy: - model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 32 - train_micro_batch_size: 1 - logprob_batch_size: 1 - max_total_sequence_length: 4096 - sequence_packing: - enabled: true - algorithm: modified_first_fit_decreasing - sequence_length_round: 64 - megatron_cfg: - enabled: true - converter_type: LlamaForCausalLM - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - expert_model_parallel_size: 8 - sequence_parallel: true - empty_unused_memory_level: 1 - freeze_moe_router: true - moe_router_dtype: fp64 - moe_router_load_balancing_type: none - moe_router_bias_update_rate: 0.0 - optimizer: - lr: 3.0e-7 - min_lr: 3.0e-8 - scheduler: - lr_decay_iters: 500 - lr_warmup_iters: 10 - lr_warmup_init: 3.0e-8 - env_vars: - PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False - generation: - vllm_cfg: - tensor_parallel_size: 4 - gpu_memory_utilization: 0.7 - enforce_eager: false - colocated: - enabled: true - -loss_fn: - # ---- CISPO arm ---- - reference_policy_kl_penalty: 0.0 - reference_policy_kl_type: k3 - ratio_clip_min: 1.0 # lower bound = 0; effectively unclipped - ratio_clip_max: 5.0 # eps_high = 5.0 (ms-swift / ScaleRL) - ratio_clip_c: null # dual clipping MUST be off for CISPO - token_level_loss: true - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - force_on_policy_ratio: false - use_cispo: true - cispo_diagnostics: true - cispo_diag_grpo_eps: 0.2 # measure GRPO-equivalent clip rate - cispo_diag_low_prob_threshold: 0.05 - -checkpointing: - enabled: false - -logger: - log_dir: logs/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: false - wandb: - project: nemo-rl - name: 
cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo - -cluster: - gpus_per_node: 8 - num_nodes: 2 diff --git a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml b/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml deleted file mode 100644 index f1a19e9b69..0000000000 --- a/examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml +++ /dev/null @@ -1,96 +0,0 @@ -# MiniMax-M1 high-off-policy study (https://arxiv.org/abs/2506.13585), GRPO arm. -# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe: -# grpo-qwen3-30ba3b-2n8g-megatron-sapo-asym.yaml -# Only the loss_fn block (and logger names) differs. -# -# Paired with: -# cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml -# -# Off-policy regime: 32 x 16 = 512 trajectories, train_global_batch_size=32 -# -> 16 gradient updates per rollout (SAPO/GSPO Sec 5.1 setting). KL beta=0, -# token-level loss, sampling temperature 1.0. -# NOT in the PR; local research artifact. 
-defaults: ../../grpo_math_qwen30ba3b_megatron.yaml - -grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_num_steps: 100 - val_period: 20 - val_at_start: true - val_at_end: true - max_val_samples: 128 - val_batch_size: 128 - -policy: - model_name: Qwen/Qwen3-30B-A3B - train_global_batch_size: 32 - train_micro_batch_size: 1 - logprob_batch_size: 1 - max_total_sequence_length: 4096 - sequence_packing: - enabled: true - algorithm: modified_first_fit_decreasing - sequence_length_round: 64 - megatron_cfg: - enabled: true - converter_type: LlamaForCausalLM - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - expert_model_parallel_size: 8 - sequence_parallel: true - empty_unused_memory_level: 1 - freeze_moe_router: true - moe_router_dtype: fp64 - moe_router_load_balancing_type: none - moe_router_bias_update_rate: 0.0 - optimizer: - lr: 3.0e-7 - min_lr: 3.0e-8 - scheduler: - lr_decay_iters: 500 - lr_warmup_iters: 10 - lr_warmup_init: 3.0e-8 - env_vars: - PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False - generation: - vllm_cfg: - tensor_parallel_size: 4 - gpu_memory_utilization: 0.7 - enforce_eager: false - colocated: - enabled: true - -loss_fn: - # ---- GRPO baseline arm ---- (only the loss_fn block differs across arms) - reference_policy_kl_penalty: 0.0 - reference_policy_kl_type: k3 - ratio_clip_min: 0.2 # standard PPO/GRPO epsilon - ratio_clip_max: 0.2 - ratio_clip_c: null - token_level_loss: true - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - force_on_policy_ratio: false - use_cispo: false - # Shared CISPO-style diagnostics across all 3 arms so the GRPO baseline - # also reports grpo_would_clip_frac (the gap CISPO claims to close). 
- cispo_diagnostics: true - cispo_diag_grpo_eps: 0.2 - cispo_diag_low_prob_threshold: 0.05 - -checkpointing: - enabled: false - -logger: - log_dir: logs/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo - wandb_enabled: true - tensorboard_enabled: true - monitor_gpus: false - wandb: - project: nemo-rl - name: cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo - -cluster: - gpus_per_node: 8 - num_nodes: 2 diff --git a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh index c638f4d46c..1c26ebe7e4 100755 --- a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh +++ b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -2,11 +2,12 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) source $SCRIPT_DIR/common.env -# MiniMax-M1 async lag-1 high-off-policy study, CISPO arm (2n8g). +# MiniMax-M1 async lag-1 high-off-policy study, CISPO arm. +# Uses 2 nodes for Megatron policy training plus 1 non-colocated vLLM node. # See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. 
# ===== BEGIN CONFIG ===== -NUM_NODES=2 +NUM_NODES=3 STEPS_PER_RUN=100 MAX_STEPS=100 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) diff --git a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh deleted file mode 100755 index b000ae00b0..0000000000 --- a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# MiniMax-M1 async lag-1 high-off-policy study, GRPO baseline arm (2n8g sized to match the -# proven SAPO sister recipe). -# See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml. - -# ===== BEGIN CONFIG ===== -NUM_NODES=2 -STEPS_PER_RUN=100 -MAX_STEPS=100 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=$((24 * 60)) -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh deleted file mode 100755 index c10cb9237f..0000000000 --- a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# 
MiniMax-M1 high-off-policy study, CISPO arm (2n8g). -# See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. - -# ===== BEGIN CONFIG ===== -NUM_NODES=2 -STEPS_PER_RUN=100 -MAX_STEPS=100 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=$((24 * 60)) -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh b/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh deleted file mode 100755 index c307c00189..0000000000 --- a/tests/test_suites/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# MiniMax-M1 high-off-policy study, GRPO baseline arm (2n8g sized to match the -# proven SAPO sister recipe). -# See examples/configs/recipes/llm/cispo-mm1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-grpo.yaml. 
- -# ===== BEGIN CONFIG ===== -NUM_NODES=2 -STEPS_PER_RUN=100 -MAX_STEPS=100 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) -NUM_MINUTES=$((24 * 60)) -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -cd $PROJECT_ROOT -uv run examples/run_grpo.py \ - --config $CONFIG_PATH \ - grpo.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=False \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 0ceb4b13fd..59cf1dffea 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -63,6 +63,9 @@ tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh # GRPO math test run (32K context mcore) tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh +# CISPO async lag-1 high-off-policy run (Qwen3-30B-A3B, Megatron + non-colocated vLLM) +tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh + # FP8 tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.sh From 49078728918f9108af3793b649602c6b27defc5d Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 21 May 2026 23:32:32 +0000 Subject: [PATCH 11/12] remove diagnostic Signed-off-by: pengdurice --- ...licy-qwen3-30ba3b-2n8g-megatron-cispo.yaml | 3 - nemo_rl/algorithms/loss/loss_functions.py | 93 ------------------- 2 files changed, 96 deletions(-) diff --git a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml index a0d248ed4f..99bd7955fb 100644 --- 
a/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml +++ b/examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml @@ -82,9 +82,6 @@ loss_fn: sequence_level_importance_ratios: false force_on_policy_ratio: false use_cispo: true - cispo_diagnostics: true - cispo_diag_grpo_eps: 0.2 # measure GRPO-equivalent clip rate - cispo_diag_low_prob_threshold: 0.05 checkpointing: enabled: false diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index b504324a8a..0481ef7fbe 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch @@ -131,26 +130,6 @@ class ClippedPGLossConfig(BaseModel, extra="allow"): force_on_policy_ratio: bool = False # If True, use CISPO (Clipped IS-weight Policy Optimization) from MiniMax-M1. use_cispo: bool = False - # If True, log per-step CISPO diagnostic metrics that quantify *why* - # CISPO is (or isn't) helping vs hard-clipped GRPO/DAPO. Off by default - # because the percentile / boolean-mask reductions add a small amount - # of per-step overhead. See ClippedPGLossFn.__call__ for the full list. - # Useful in any arm (GRPO / DAPO / CISPO) - the same metrics let the - # GRPO baseline tell you how much gradient signal it's losing to its - # own hard clip, which is precisely the gap CISPO claims to close. - cispo_diagnostics: bool = False - # The hard-clip epsilon to use when computing the "what fraction of - # tokens would standard GRPO have zeroed the gradient on?" diagnostic. - # Defaults to 0.2 (the original PPO/GRPO value). 
Independent from - # ratio_clip_min/ratio_clip_max so we can probe the GRPO-equivalent - # behaviour even on a CISPO run with epsilon_high=5.0. - cispo_diag_grpo_eps: float = 0.2 - # Probability threshold under which a token is counted as "low-prob" - # (a coarse, tokenizer-free proxy for rare reflective tokens like - # "However", "Wait", "Recheck" - see MiniMax-M1 paper §3.1). - # Defaults to 0.05; we log the fraction of generated tokens whose - # behaviour-policy probability is below this. - cispo_diag_low_prob_threshold: float = 0.05 class ClippedPGLossDataDict(TypedDict): @@ -265,17 +244,6 @@ def __init__(self, cfg: ClippedPGLossConfig): "the dual-clip block runs after the CISPO loss assembly and would " "silently overwrite it. Set ratio_clip_c=null when use_cispo=True." ) - # CISPO-style diagnostics. Off by default to avoid extra reductions. - self.cispo_diagnostics = cfg.cispo_diagnostics - self.cispo_diag_grpo_eps = cfg.cispo_diag_grpo_eps - self.cispo_diag_low_prob_threshold = cfg.cispo_diag_low_prob_threshold - assert self.cispo_diag_grpo_eps > 0, ( - f"cispo_diag_grpo_eps must be positive, got {self.cispo_diag_grpo_eps}" - ) - assert 0.0 < self.cispo_diag_low_prob_threshold < 1.0, ( - "cispo_diag_low_prob_threshold must be a probability in (0, 1), " - f"got {self.cispo_diag_low_prob_threshold}" - ) if self.truncated_importance_sampling_ratio is not None: assert self.use_importance_sampling_correction, ( "truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True" @@ -674,66 +642,6 @@ def __call__( probs_ratio_clamped_min = float("inf") probs_ratio_clamped_max = float("-inf") - # CISPO-style diagnostics. Designed to be tokenizer-free and cheap: - # all reductions are over the same (mask) we already build above. We - # log even on GRPO/DAPO arms so the gap CISPO claims to close can be - # *measured directly* on the baseline (the would_clip_frac). 
- cispo_diag_metrics: dict[str, float] = {} - if self.cispo_diagnostics: - with torch.no_grad(): - eps = self.cispo_diag_grpo_eps - detached_ratios = ratios.detach() - # "Would standard GRPO have zeroed this token's gradient?" - # - positive-advantage tokens lose their gradient when r > 1+eps - # - negative-advantage tokens lose their gradient when r < 1-eps - adv_pos = (advantages > 0).float() - adv_neg = (advantages < 0).float() - would_clip_pos = adv_pos * (detached_ratios > 1.0 + eps).float() - would_clip_neg = adv_neg * (detached_ratios < 1.0 - eps).float() - grpo_would_clip_frac = masked_mean( - would_clip_pos + would_clip_neg, - mask, - global_normalization_factor=global_valid_toks, - ).item() - grpo_would_clip_pos_frac = masked_mean( - would_clip_pos, - mask, - global_normalization_factor=global_valid_toks, - ).item() - grpo_would_clip_neg_frac = masked_mean( - would_clip_neg, - mask, - global_normalization_factor=global_valid_toks, - ).item() - - # Coarse, tokenizer-free proxy for "rare reflective tokens": - # tokens whose behaviour-policy probability was below the - # threshold. CISPO's central claim is that these are exactly - # the tokens GRPO's hard clip throws away. - low_thr = math.log(self.cispo_diag_low_prob_threshold) - low_prob_token = (prev_logprobs < low_thr).float() - low_prob_token_frac = masked_mean( - low_prob_token, - mask, - global_normalization_factor=global_valid_toks, - ).item() - # Of the would-be-clipped tokens, what fraction are also - # "low-prob"? A high number here is the smoking-gun - # confirmation of the paper's diagnosis. 
- would_clip_and_low_prob = masked_mean( - (would_clip_pos + would_clip_neg) * low_prob_token, - mask, - global_normalization_factor=global_valid_toks, - ).item() - - cispo_diag_metrics = { - "cispo_diag/grpo_would_clip_frac": grpo_would_clip_frac, - "cispo_diag/grpo_would_clip_pos_frac": grpo_would_clip_pos_frac, - "cispo_diag/grpo_would_clip_neg_frac": grpo_would_clip_neg_frac, - "cispo_diag/low_prob_token_frac": low_prob_token_frac, - "cispo_diag/would_clip_and_low_prob_frac": would_clip_and_low_prob, - } - # If you provided a global_valid_{seqs/toks}, all metrics here are globally normalized # by either sequence or token count, depending on particular metric. # To get the true metric, you'll need to sum over the microbatch. @@ -756,7 +664,6 @@ def __call__( "num_valid_samples": sample_mask.sum().item(), "approx_entropy": seq_entropy_approx.item(), **_is_filter_metrics, - **cispo_diag_metrics, }, ) From 06c4feed91034e8884fecb20649550ef30ffb393 Mon Sep 17 00:00:00 2001 From: pengdurice Date: Thu, 21 May 2026 23:33:56 +0000 Subject: [PATCH 12/12] add cispo.md Signed-off-by: pengdurice --- docs/about/algorithms/cispo.md | 62 ++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 docs/about/algorithms/cispo.md diff --git a/docs/about/algorithms/cispo.md b/docs/about/algorithms/cispo.md new file mode 100644 index 0000000000..d94ad06e37 --- /dev/null +++ b/docs/about/algorithms/cispo.md @@ -0,0 +1,62 @@ +# CISPO + +[Clipped Importance Sampling Policy Optimization (CISPO)](https://arxiv.org/abs/2506.13585) is a GRPO-family policy-gradient objective that clips the importance-sampling weight as a detached coefficient instead of using PPO-style hard ratio clipping. + +For each generated token, CISPO computes the policy ratio + +```text +r_t(theta) = pi_theta(o_t | q, o_