Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions docs/about/algorithms/cispo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# CISPO

[Clipped Importance Sampling Policy Optimization (CISPO)](https://arxiv.org/abs/2506.13585) is a GRPO-family policy-gradient objective that clips the importance-sampling weight as a detached coefficient instead of using PPO-style hard ratio clipping.

For each generated token, CISPO computes the policy ratio

```text
r_t(theta) = pi_theta(o_t | q, o_<t) / pi_old(o_t | q, o_<t)
```

and uses a clipped, stop-gradient importance weight in the policy loss:

```text
L_CISPO = -A_t * sg(clip(r_t(theta), 1 - eps_low, 1 + eps_high)) * log pi_theta(o_t | q, o_<t)
```

This keeps gradients flowing through `log pi_theta` for every token while bounding the scalar importance weight. In contrast, standard GRPO/PPO-style clipping can zero out the gradient contribution for tokens whose ratios leave the clip range.

## Configuration

CISPO uses the same GRPO training path and `ClippedPGLossFn` as GRPO. Enable it in the `loss_fn` block:

```yaml
loss_fn:
use_cispo: true
token_level_loss: true
sequence_level_importance_ratios: false
force_on_policy_ratio: false
ratio_clip_min: 1.0
ratio_clip_max: 5.0
ratio_clip_c: null
```

`ratio_clip_min` and `ratio_clip_max` follow the paper's additive epsilon convention. The effective clamp range is:

```text
[1 - ratio_clip_min, 1 + ratio_clip_max]
```

For example, `ratio_clip_min: 1.0` and `ratio_clip_max: 5.0` clamp ratios to `[0, 6]`. Since policy ratios are non-negative, this is effectively an upper-only clamp at `6`.

## Async Lag-1 Recipe

The nightly CISPO recipe validates the objective in a high-off-policy setting with repeated updates per rollout and non-colocated async vLLM generation:

```bash
bash tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh
```

The corresponding config is:

```text
examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml
```

The recipe uses `Qwen/Qwen3-30B-A3B`, Megatron policy training, async GRPO with `max_trajectory_age_steps: 1`, and a separate non-colocated vLLM generation node.

## Additional Resources

- [MiniMax-M1 paper](https://arxiv.org/abs/2506.13585)
- [GRPO documentation](grpo.md)
- [DAPO documentation](dapo.md)
2 changes: 2 additions & 0 deletions docs/about/algorithms/grpo.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ We provide a reference GRPO configuration for math benchmarks using the [OpenIns

You can read about the details of the GRPO implementation [here](../../guides/grpo.md).

Related GRPO-family objectives are documented in [DAPO](dapo.md) and [CISPO](cispo.md).

## GRPO Single Node

To run GRPO on a single GPU for `Qwen/Qwen2.5-1.5B`:
Expand Down
6 changes: 6 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ loss_fn:
token_level_loss: true
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
use_kl_in_reward: false # Reinforce++: add KL penalty to reward instead of loss
use_cispo: false # CISPO (https://arxiv.org/abs/2506.13585): clipped IS-weight policy optimization
# Optional CISPO-style diagnostics. Cheap; works on GRPO/DAPO/CISPO arms.
# See ClippedPGLossConfig in nemo_rl/algorithms/loss/loss_functions.py.
cispo_diagnostics: false
cispo_diag_grpo_eps: 0.2 # baseline GRPO eps for would_clip_frac
cispo_diag_low_prob_threshold: 0.05 # proxy threshold for rare reflective tokens

checkpointing:
enabled: true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# MiniMax-M1 async lag-1 high-off-policy study (https://arxiv.org/abs/2506.13585), CISPO arm.
# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the
# loss_fn block differs across arms.
#
# CISPO (MiniMax-M1 §3.1) clips the IS weight as a stop-gradient
# coefficient instead of clipping the policy ratio. Gradients flow
# through log pi for *every* token, including the rare reflective
# tokens ("However", "Wait", "Recheck") that GRPO/DAPO would zero out.
#
# L_CISPO = -A * sg(clip(r, 1 - eps_low, 1 + eps_high)) * log pi(a)
#
# Per ms-swift's CISPO recipe and ScaleRL (arXiv:2510.13786), we use a
# very loose lower clip and a much wider upper clip (eps_high = 5.0).
defaults: ../../grpo_math_qwen30ba3b_megatron.yaml

grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_num_steps: 100
val_period: 20
val_at_start: true
val_at_end: true
max_val_samples: 128
val_batch_size: 128
async_grpo:
enabled: true
max_trajectory_age_steps: 1
in_flight_weight_updates: true

policy:
model_name: Qwen/Qwen3-30B-A3B
train_global_batch_size: 32
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
sequence_packing:
enabled: true
algorithm: modified_first_fit_decreasing
sequence_length_round: 64
megatron_cfg:
enabled: true
converter_type: LlamaForCausalLM
tensor_model_parallel_size: 2
pipeline_model_parallel_size: 1
expert_model_parallel_size: 8
sequence_parallel: true
empty_unused_memory_level: 1
freeze_moe_router: true
moe_router_dtype: fp64
moe_router_load_balancing_type: none
moe_router_bias_update_rate: 0.0
optimizer:
lr: 3.0e-7
min_lr: 3.0e-8
scheduler:
lr_decay_iters: 500
lr_warmup_iters: 10
lr_warmup_init: 3.0e-8
env_vars:
PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False
generation:
vllm_cfg:
tensor_parallel_size: 4
gpu_memory_utilization: 0.7
enforce_eager: false
async_engine: true
colocated:
enabled: false
resources:
num_nodes: 1
gpus_per_node: 8

loss_fn:
# ---- CISPO arm ----
reference_policy_kl_penalty: 0.0
reference_policy_kl_type: k3
ratio_clip_min: 1.0 # lower bound = 0; effectively unclipped
ratio_clip_max: 5.0 # eps_high = 5.0 (ms-swift / ScaleRL)
ratio_clip_c: null # dual clipping MUST be off for CISPO
token_level_loss: true
use_importance_sampling_correction: true
sequence_level_importance_ratios: false
force_on_policy_ratio: false
use_cispo: true

checkpointing:
enabled: false

logger:
log_dir: logs/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo
wandb_enabled: true
tensorboard_enabled: true
monitor_gpus: false
wandb:
project: nemo-rl
name: cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo

cluster:
gpus_per_node: 8
num_nodes: 3
36 changes: 35 additions & 1 deletion nemo_rl/algorithms/loss/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class ClippedPGLossConfig(BaseModel, extra="allow"):
# NOTE: This should only be used when doing exactly one update per rollout
# (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
force_on_policy_ratio: bool = False
# If True, use CISPO (Clipped IS-weight Policy Optimization) from MiniMax-M1.
use_cispo: bool = False


class ClippedPGLossDataDict(TypedDict):
Expand All @@ -152,6 +154,7 @@ class ClippedPGLossFn(LossFunction):
- GRPO - https://arxiv.org/abs/2402.03300
- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071
- CISPO (set use_cispo = True) - https://arxiv.org/abs/2506.13585
- Truly on-policy (set force_on_policy_ratio = True to force ratio = 1.0, requires one update per rollout)

Formula:
Expand All @@ -171,6 +174,10 @@ class ClippedPGLossFn(LossFunction):
For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to:
L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref)

Formula (CISPO):
L(θ) = E_t [ sg(clip(r_t(θ), 1-ε_low, 1+ε_high)) * A_t * log π_θ(a_t|s_t) ]


Also supports "Dual-Clipping" from https://arxiv.org/pdf/1912.09729, which
imposes an additional upper bound on the probability ratio when advantages are negative.
This prevents excessive policy updates. $rA << 0$ -> $cA$(clipped)
Expand Down Expand Up @@ -218,6 +225,29 @@ def __init__(self, cfg: ClippedPGLossConfig):
"sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss"
)

self.use_cispo = cfg.use_cispo
if self.use_cispo:
assert not self.disable_ppo_ratio, (
"use_cispo is incompatible with disable_ppo_ratio; "
"CISPO needs the pi_theta/pi_theta_old ratio but disable_ppo_ratio removes it"
)
assert not self.force_on_policy_ratio, (
"use_cispo is incompatible with force_on_policy_ratio; "
"forcing ratio=1 removes the clipped IS-weight that CISPO optimizes"
)
assert not self.sequence_level_importance_ratios, (
"use_cispo is incompatible with sequence_level_importance_ratios; "
"CISPO uses token-level importance weights"
)
assert self.ratio_clip_c is None, (
"use_cispo is incompatible with dual clipping (ratio_clip_c); "
"the dual-clip block runs after the CISPO loss assembly and would "
"silently overwrite it. Set ratio_clip_c=null when use_cispo=True."
)
if self.truncated_importance_sampling_ratio is not None:
assert self.use_importance_sampling_correction, (
"truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True"
)
if self.truncated_importance_sampling_type is not None:
assert self.use_importance_sampling_correction, (
"truncated importance sampling is only supported when use_importance_sampling_correction is True"
Expand Down Expand Up @@ -417,7 +447,11 @@ def __call__(

# Determine which value to use for clipping (max for pessimistic estimate)
clip_loss = torch.max(loss1, loss2)

if self.use_cispo:
ratios_clamped = ratios.clamp(
1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max
)
clip_loss = -advantages * ratios_clamped.detach() * curr_logprobs
# Dual-clipping see https://arxiv.org/pdf/1912.09729
if self.ratio_clip_c is not None:
assert self.ratio_clip_c > 1, (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# MiniMax-M1 async lag-1 high-off-policy study, CISPO arm.
# Uses 2 nodes for Megatron policy training plus 1 non-colocated vLLM node.
# See examples/configs/recipes/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.yaml.

# ===== BEGIN CONFIG =====
NUM_NODES=3
STEPS_PER_RUN=100
MAX_STEPS=100
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))
NUM_MINUTES=$((24 * 60))
# ===== END CONFIG =====

exit_if_max_steps_reached

cd $PROJECT_ROOT
uv run examples/run_grpo.py \
--config $CONFIG_PATH \
grpo.max_num_steps=$MAX_STEPS \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=True \
logger.wandb.project=nemo-rl \
logger.wandb.name=$EXP_NAME \
logger.tensorboard_enabled=True \
checkpointing.enabled=False \
$@ \
2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
3 changes: 3 additions & 0 deletions tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh
# GRPO math test run (32K context mcore)
tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh

# CISPO async lag-1 high-off-policy run (Qwen3-30B-A3B, Megatron + non-colocated vLLM)
tests/test_suites/llm/cispo-mm1-async-lag1-highoffpolicy-qwen3-30ba3b-2n8g-megatron-cispo.sh

# FP8
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.sh
Expand Down
Loading
Loading