Skip to content

Commit bfa408f

Browse files
committed
initial fix of the previous PR, add many test cases now, and will remove / check later
Signed-off-by: pengdurice <pengduhit@gmail.com>
1 parent 96f9b2e commit bfa408f

17 files changed

Lines changed: 875 additions & 38 deletions

examples/configs/cispo_math_8B.yaml

Lines changed: 0 additions & 25 deletions
This file was deleted.

examples/configs/grpo_math_1B.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ loss_fn:
7171
token_level_loss: true
7272
force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt)
7373
use_kl_in_reward: false # Reinforce++: add KL penalty to reward instead of loss
74+
use_cispo: false # CISPO (https://arxiv.org/abs/2506.13585): clipped IS-weight policy optimization
75+
# Optional CISPO-style diagnostics. Cheap; works on GRPO/DAPO/CISPO arms.
76+
# See ClippedPGLossConfig in nemo_rl/algorithms/loss/loss_functions.py.
77+
cispo_diagnostics: false
78+
cispo_diag_grpo_eps: 0.2 # baseline GRPO eps for would_clip_frac
79+
cispo_diag_low_prob_threshold: 0.05 # proxy threshold for rare reflective tokens
7480

7581
checkpointing:
7682
enabled: true
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# A/B treatment arm: CISPO (Clipped IS-weight Policy Optimization,
2+
# arXiv:2506.13585) on Qwen2.5-Math-1.5B-Instruct.
3+
#
4+
# CISPO replaces GRPO's hard PPO clip + advantage product with a stop-gradient
5+
# clipped importance weight applied to the log-probability:
6+
#
7+
# L_CISPO = -A_t * sg(clip(r_t, 1 - eps_low, 1 + eps_high)) * log pi(a_t)
8+
#
9+
# Pair with:
10+
# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo.yaml
11+
# Everything except the loss-fn block is identical between the two arms.
12+
#
13+
# Off-policy regime (where CISPO and the hard PPO clip diverge most):
14+
# * 32 prompts x 16 generations = 512 trajectories per step
15+
# * train_global_batch_size = 128 -> 4 gradient updates per rollout
16+
# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071)
17+
# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the
18+
# KL regularizer is not a confounder)
19+
#
20+
# NOT in the CISPO PR - this is a local research-validation artifact.
21+
# (The PR ships the on-policy machinery-smoke recipe at
22+
# examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml.)
23+
defaults: ../../grpo_math_1B.yaml
24+
25+
grpo:
26+
max_num_steps: 100
27+
val_period: 10
28+
val_at_start: true
29+
val_at_end: true
30+
max_val_samples: 256
31+
val_batch_size: 256
32+
seed: 42 # matched-pair: identical RNG to GRPO arm
33+
34+
policy:
35+
model_name: Qwen/Qwen2.5-Math-1.5B-Instruct
36+
tokenizer:
37+
name: Qwen/Qwen2.5-Math-1.5B-Instruct
38+
train_global_batch_size: 128 # off-policy: 4 grad updates / rollout
39+
train_micro_batch_size: 4
40+
logprob_batch_size: 4
41+
max_total_sequence_length: 1024
42+
dynamic_batching:
43+
enabled: true
44+
sequence_packing:
45+
enabled: false
46+
make_sequence_length_divisible_by: 1
47+
generation:
48+
max_new_tokens: 512
49+
vllm_cfg:
50+
max_model_len: 1024
51+
52+
data:
53+
max_input_seq_length: 512
54+
55+
loss_fn:
56+
# CISPO treatment arm. Paper-recommended clip: very loose lower (no
57+
# effective lower clip), tighter upper. With nemo-rl's parameterisation
58+
# (lower = 1 - ratio_clip_min, upper = 1 + ratio_clip_max):
59+
# ratio_clip_min = 1.0 -> lower bound = 0.0 (ratios are positive, so this
60+
# is effectively unclipped below)
61+
# ratio_clip_max = 0.8 -> upper bound = 1.8
62+
use_cispo: true
63+
reference_policy_kl_penalty: 0.0 # matched to the GRPO arm for fairness
64+
reference_policy_kl_type: k3
65+
ratio_clip_min: 1.0
66+
ratio_clip_max: 0.8
67+
ratio_clip_c: null # dual clipping MUST be off for CISPO
68+
token_level_loss: true
69+
use_importance_sampling_correction: false
70+
sequence_level_importance_ratios: false
71+
force_on_policy_ratio: false
72+
73+
checkpointing:
74+
enabled: false # research run; skip checkpoint I/O
75+
76+
logger:
77+
log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo
78+
wandb_enabled: true
79+
tensorboard_enabled: true
80+
monitor_gpus: true
81+
wandb:
82+
project: nemo-rl
83+
name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo
84+
85+
cluster:
86+
gpus_per_node: 8
87+
num_nodes: 1
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# A/B baseline arm: vanilla GRPO with the standard hard PPO clip.
2+
#
3+
# This recipe is the *control* arm in a back-to-back A/B comparison meant to
4+
# isolate the effect of swapping the hard PPO clip for CISPO's clipped IS-
5+
# weight surrogate (arXiv:2506.13585). Pair with:
6+
# examples/configs/recipes/llm/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-cispo.yaml
7+
# Everything except the loss-fn block is identical between the two arms.
8+
#
9+
# Off-policy regime (where CISPO and the hard PPO clip diverge most):
10+
# * 32 prompts x 16 generations = 512 trajectories per step
11+
# * train_global_batch_size = 128 -> 4 gradient updates per rollout
12+
# (matches the GSPO Sec 5.1 reference setting, arXiv:2507.18071)
13+
# * KL beta = 0 (CISPO paper Sec 5.1; kept identical in both arms so the
14+
# KL regularizer is not a confounder)
15+
# * token-level loss, sampling temperature inherited from base
16+
#
17+
# NOT in the CISPO PR - this is a local research-validation artifact.
18+
defaults: ../../grpo_math_1B.yaml
19+
20+
grpo:
21+
max_num_steps: 100
22+
val_period: 10
23+
val_at_start: true
24+
val_at_end: true
25+
max_val_samples: 256
26+
val_batch_size: 256
27+
seed: 42
28+
29+
policy:
30+
model_name: Qwen/Qwen2.5-Math-1.5B-Instruct
31+
tokenizer:
32+
name: Qwen/Qwen2.5-Math-1.5B-Instruct
33+
train_global_batch_size: 128 # off-policy: 4 grad updates / rollout
34+
train_micro_batch_size: 4
35+
logprob_batch_size: 4
36+
max_total_sequence_length: 1024
37+
dynamic_batching:
38+
enabled: true
39+
sequence_packing:
40+
enabled: false
41+
make_sequence_length_divisible_by: 1
42+
generation:
43+
max_new_tokens: 512
44+
vllm_cfg:
45+
max_model_len: 1024
46+
47+
data:
48+
max_input_seq_length: 512
49+
50+
loss_fn:
51+
# GRPO control arm: standard hard PPO clip at +/- 0.2.
52+
use_cispo: false
53+
reference_policy_kl_penalty: 0.0 # matched to the CISPO arm for fairness
54+
reference_policy_kl_type: k3
55+
ratio_clip_min: 0.2 # PPO clip lower bound = 1 - 0.2 = 0.8
56+
ratio_clip_max: 0.2 # PPO clip upper bound = 1 + 0.2 = 1.2
57+
ratio_clip_c: null
58+
token_level_loss: true
59+
use_importance_sampling_correction: false
60+
sequence_level_importance_ratios: false
61+
force_on_policy_ratio: false
62+
63+
checkpointing:
64+
enabled: false # research run; skip checkpoint I/O
65+
66+
logger:
67+
log_dir: logs/cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo
68+
wandb_enabled: true
69+
tensorboard_enabled: true
70+
monitor_gpus: true
71+
wandb:
72+
project: nemo-rl
73+
name: cispo-ab-qwen2.5-math-1.5b-instruct-1n8g-grpo
74+
75+
cluster:
76+
gpus_per_node: 8
77+
num_nodes: 1
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), CISPO arm.
2+
# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the
3+
# loss_fn block differs across arms.
4+
#
5+
# CISPO (MiniMax-M1 §3.1) clips the IS weight as a stop-gradient
6+
# coefficient instead of clipping the policy ratio. Gradients flow
7+
# through log pi for *every* token, including the rare reflective
8+
# tokens ("However", "Wait", "Recheck") that GRPO/DAPO would zero out.
9+
#
10+
# L_CISPO = -A * sg(clip(r, 1 - eps_low, 1 + eps_high)) * log pi(a)
11+
#
12+
# Per ms-swift's CISPO recipe and ScaleRL (arXiv:2510.13786), we use a
13+
# very loose lower clip and a much wider upper clip (eps_high = 5.0).
14+
defaults: ../../grpo_math_qwen30ba3b_megatron.yaml
15+
16+
grpo:
17+
num_prompts_per_step: 32
18+
num_generations_per_prompt: 16
19+
max_num_steps: 200
20+
val_period: 20
21+
val_at_start: true
22+
val_at_end: true
23+
max_val_samples: 128
24+
val_batch_size: 128
25+
26+
policy:
27+
model_name: Qwen/Qwen3-30B-A3B
28+
train_global_batch_size: 128
29+
train_micro_batch_size: 1
30+
logprob_batch_size: 1
31+
max_total_sequence_length: 4096
32+
sequence_packing:
33+
enabled: true
34+
algorithm: modified_first_fit_decreasing
35+
sequence_length_round: 64
36+
megatron_cfg:
37+
enabled: true
38+
converter_type: LlamaForCausalLM
39+
tensor_model_parallel_size: 2
40+
pipeline_model_parallel_size: 1
41+
expert_model_parallel_size: 8
42+
sequence_parallel: true
43+
empty_unused_memory_level: 1
44+
freeze_moe_router: true
45+
moe_router_dtype: fp64
46+
moe_router_load_balancing_type: none
47+
moe_router_bias_update_rate: 0.0
48+
optimizer:
49+
lr: 3.0e-7
50+
min_lr: 3.0e-8
51+
scheduler:
52+
lr_decay_iters: 500
53+
lr_warmup_iters: 10
54+
lr_warmup_init: 3.0e-8
55+
env_vars:
56+
PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False
57+
generation:
58+
vllm_cfg:
59+
tensor_parallel_size: 4
60+
gpu_memory_utilization: 0.7
61+
enforce_eager: false
62+
colocated:
63+
enabled: true
64+
65+
loss_fn:
66+
# ---- CISPO arm ----
67+
reference_policy_kl_penalty: 0.0
68+
reference_policy_kl_type: k3
69+
ratio_clip_min: 1.0 # lower bound = 0; effectively unclipped
70+
ratio_clip_max: 5.0 # eps_high = 5.0 (ms-swift / ScaleRL)
71+
ratio_clip_c: null # dual clipping MUST be off for CISPO
72+
token_level_loss: true
73+
use_importance_sampling_correction: false
74+
sequence_level_importance_ratios: false
75+
force_on_policy_ratio: false
76+
use_cispo: true
77+
cispo_diagnostics: true
78+
cispo_diag_grpo_eps: 0.2 # measure GRPO-equivalent clip rate
79+
cispo_diag_low_prob_threshold: 0.05
80+
81+
checkpointing:
82+
enabled: false
83+
84+
logger:
85+
log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo
86+
wandb_enabled: true
87+
tensorboard_enabled: true
88+
monitor_gpus: false
89+
wandb:
90+
project: nemo-rl
91+
name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-cispo
92+
93+
cluster:
94+
gpus_per_node: 8
95+
num_nodes: 2
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# MiniMax-M1 replication study (https://arxiv.org/abs/2506.13585), DAPO arm.
2+
# Minimal-diff copy of workspace-4's proven 2n8g SAPO recipe. Only the
3+
# loss_fn block differs across arms.
4+
#
5+
# DAPO ("Clip-Higher", https://arxiv.org/abs/2503.14476): asymmetric clip
6+
# with a tighter lower bound and a looser upper bound.
7+
defaults: ../../grpo_math_qwen30ba3b_megatron.yaml
8+
9+
grpo:
10+
num_prompts_per_step: 32
11+
num_generations_per_prompt: 16
12+
max_num_steps: 200
13+
val_period: 20
14+
val_at_start: true
15+
val_at_end: true
16+
max_val_samples: 128
17+
val_batch_size: 128
18+
19+
policy:
20+
model_name: Qwen/Qwen3-30B-A3B
21+
train_global_batch_size: 128
22+
train_micro_batch_size: 1
23+
logprob_batch_size: 1
24+
max_total_sequence_length: 4096
25+
sequence_packing:
26+
enabled: true
27+
algorithm: modified_first_fit_decreasing
28+
sequence_length_round: 64
29+
megatron_cfg:
30+
enabled: true
31+
converter_type: LlamaForCausalLM
32+
tensor_model_parallel_size: 2
33+
pipeline_model_parallel_size: 1
34+
expert_model_parallel_size: 8
35+
sequence_parallel: true
36+
empty_unused_memory_level: 1
37+
freeze_moe_router: true
38+
moe_router_dtype: fp64
39+
moe_router_load_balancing_type: none
40+
moe_router_bias_update_rate: 0.0
41+
optimizer:
42+
lr: 3.0e-7
43+
min_lr: 3.0e-8
44+
scheduler:
45+
lr_decay_iters: 500
46+
lr_warmup_iters: 10
47+
lr_warmup_init: 3.0e-8
48+
env_vars:
49+
PYTORCH_CUDA_ALLOC_CONF: expandable_segments:False
50+
generation:
51+
vllm_cfg:
52+
tensor_parallel_size: 4
53+
gpu_memory_utilization: 0.7
54+
enforce_eager: false
55+
colocated:
56+
enabled: true
57+
58+
loss_fn:
59+
# ---- DAPO ("Clip-Higher") arm ----
60+
reference_policy_kl_penalty: 0.0
61+
reference_policy_kl_type: k3
62+
ratio_clip_min: 0.2 # eps_low - identical to GRPO
63+
ratio_clip_max: 0.28 # eps_high - DAPO "Clip-Higher"
64+
ratio_clip_c: null
65+
token_level_loss: true
66+
use_importance_sampling_correction: false
67+
sequence_level_importance_ratios: false
68+
force_on_policy_ratio: false
69+
use_cispo: false
70+
cispo_diagnostics: true
71+
cispo_diag_grpo_eps: 0.2
72+
cispo_diag_low_prob_threshold: 0.05
73+
74+
checkpointing:
75+
enabled: false
76+
77+
logger:
78+
log_dir: logs/cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo
79+
wandb_enabled: true
80+
tensorboard_enabled: true
81+
monitor_gpus: false
82+
wandb:
83+
project: nemo-rl
84+
name: cispo-mm1-replica-qwen3-30ba3b-2n8g-megatron-dapo
85+
86+
cluster:
87+
gpus_per_node: 8
88+
num_nodes: 2

0 commit comments

Comments
 (0)