diff --git a/docs/about/algorithms/cispo.md b/docs/about/algorithms/cispo.md new file mode 100644 index 0000000000..d94ad06e37 --- /dev/null +++ b/docs/about/algorithms/cispo.md @@ -0,0 +1,62 @@ +# CISPO + +[Clipped Importance Sampling Policy Optimization (CISPO)](https://arxiv.org/abs/2506.13585) is a GRPO-family policy-gradient objective that clips the importance-sampling weight as a detached coefficient instead of using PPO-style hard ratio clipping. + +For each generated token, CISPO computes the policy ratio + +```text +r_t(theta) = pi_theta(o_t | q, o_ $cA$(clipped) @@ -218,6 +225,29 @@ def __init__(self, cfg: ClippedPGLossConfig): "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" ) + self.use_cispo = cfg.use_cispo + if self.use_cispo: + assert not self.disable_ppo_ratio, ( + "use_cispo is incompatible with disable_ppo_ratio; " + "CISPO needs the pi_theta/pi_theta_old ratio but disable_ppo_ratio removes it" + ) + assert not self.force_on_policy_ratio, ( + "use_cispo is incompatible with force_on_policy_ratio; " + "forcing ratio=1 removes the clipped IS-weight that CISPO optimizes" + ) + assert not self.sequence_level_importance_ratios, ( + "use_cispo is incompatible with sequence_level_importance_ratios; " + "CISPO uses token-level importance weights" + ) + assert self.ratio_clip_c is None, ( + "use_cispo is incompatible with dual clipping (ratio_clip_c); " + "the dual-clip block runs after the CISPO loss assembly and would " + "silently overwrite it. Set ratio_clip_c=null when use_cispo=True." + ) + if self.truncated_importance_sampling_ratio is not None: + assert self.use_importance_sampling_correction, ( + "truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True" + ) if self.truncated_importance_sampling_type is not None: assert self.use_importance_sampling_correction, ( "truncated importance sampling is only supported when use_importance_sampling_correction is True" @@ -417,7 +447,11 @@ def __call__( # Determine which value to use for clipping (max for pessimistic estimate) clip_loss = torch.max(loss1, loss2) - + if self.use_cispo: + ratios_clamped = ratios.clamp( + 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max + ) + clip_loss = -advantages * ratios_clamped.detach() * curr_logprobs # Dual-clipping see https://arxiv.org/pdf/1912.09729 if self.ratio_clip_c is not None: assert self.ratio_clip_c > 1, ( diff --git a/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh new file mode 100755 index 0000000000..1c26ebe7e4 --- /dev/null +++ b/tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh @@ -0,0 +1,32 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# MiniMax-M1 async lag-1 high-off-policy study, CISPO arm. +# Uses 2 nodes for Megatron policy training plus 1 non-colocated vLLM node. +# See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml. + +# ===== BEGIN CONFIG ===== +NUM_NODES=3 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) +NUM_MINUTES=$((24 * 60)) +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=False \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 0ceb4b13fd..59cf1dffea 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -63,6 +63,9 @@ tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh # GRPO math test run (32K context mcore) tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh +# CISPO async lag-1 high-off-policy run (Qwen3-30B-A3B, Megatron + non-colocated vLLM) +tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh + # FP8 tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.sh diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index e4ac4fab66..96475fca17 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -623,11 +623,12 @@ def test_clipped_pg_loss_force_on_policy_ratio(): def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): - """Tests that force_on_policy_ratio ignores prev_logprobs from data and uses curr_logprobs instead. + """Tests that force_on_policy_ratio ignores prev_logprobs from data. - When force_on_policy_ratio=True, the loss function should use curr_logprobs.detach() - as prev_logprobs, so the actual prev_logprobs in data are irrelevant. This allows - skipping the expensive prev_logprobs computation upstream. + When force_on_policy_ratio=True, the loss function should use + curr_logprobs.detach() as prev_logprobs, so the actual prev_logprobs in + data are irrelevant. This allows skipping the expensive prev_logprobs + computation upstream. """ if not torch.cuda.is_available(): pytest.skip("No GPU available") @@ -678,6 +679,91 @@ def test_clipped_pg_loss_force_on_policy_ratio_ignores_prev_logprobs(): assert metrics_1["probs_ratio"] == metrics_2["probs_ratio"] == 1.0 +@pytest.mark.parametrize( + "incompatible_config", + [ + {"disable_ppo_ratio": True}, + {"force_on_policy_ratio": True}, + {"ratio_clip_c": 3.0}, + {"sequence_level_importance_ratios": True, "token_level_loss": False}, + ], +) +def test_clipped_pg_loss_cispo_incompatibility_asserts(incompatible_config): + """CISPO must reject configs that conflict with its semantics. + + - disable_ppo_ratio removes the pi_theta / pi_theta_old ratio that CISPO + uses as the importance weight, so they are mutually exclusive. + - force_on_policy_ratio makes every ratio 1.0, removing CISPO's clipped + importance-weight behavior. + - sequence_level_importance_ratios changes the token-level IS weights that + CISPO is defined over. + - ratio_clip_c (dual clipping) runs after the CISPO loss assembly inside + ClippedPGLossFn and would silently overwrite it. + """ + cfg = ClippedPGLossConfig( + reference_policy_kl_penalty=0.0, + use_cispo=True, + **incompatible_config, + ) + with pytest.raises(AssertionError): + ClippedPGLossFn(cfg) + + +def test_clipped_pg_loss_cispo(): + """Tests CISPO (Clipped IS-weight Policy Optimization) path in ClippedPGLossFn. + + Uses the same data pattern as test_clipped_pg_loss_ppo_clipping: ratios are + [0.5, 1.0, 1.5] and clamp to [0.8, 1.0, 1.2]. CISPO formula: + + L = -advantages * clip(ratio, 1-eps, 1+eps).detach() * curr_logprobs + + The IS weight is clipped and stop-gradiented; gradients flow only through + curr_logprobs. + """ + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = ClippedPGLossConfig(reference_policy_kl_penalty=0.0, use_cispo=True) + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target ratios: 0.5, 1.0, 1.5 -> after clip(0.2, 0.2): 0.8, 1.0, 1.2 + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand calculation: CISPO loss = -A * clip(r, 1-ε, 1+ε) * curr_lp (ratio stop-grad) --- + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - cfg.ratio_clip_min, 1.0 + cfg.ratio_clip_max + ) # [0.8, 1.0, 1.2] + cispo_loss_per_token = -adv_masked * ratios_clamped * curr_lp_masked + expected_loss = torch.mean(cispo_loss_per_token) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) + + actual_loss, _ = loss_fn( + data=data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum( + data["sample_mask"].unsqueeze(-1) * data["token_mask"] + ), + **loss_input, + ) + torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"]) def test_calculate_kl(kl_type): """Tests KL calculations."""