Skip to content

Commit ea4859d

Browse files
committed
feat(rl): add REINFORCE advantage estimator
1 parent a897e1f commit ea4859d

7 files changed

Lines changed: 65 additions & 6 deletions

File tree

.github/workflows/pr-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ jobs:
372372
strategy:
373373
fail-fast: false
374374
matrix:
375-
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
375+
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_reinforce.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
376376
defaults:
377377
run:
378378
working-directory: ${{ github.workspace }}

.github/workflows/pr-test.yml.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
{'test_file': 'test_logprob_response_spans.py', 'num_gpus': 0},
7373
{'test_file': 'test_value_temperature.py', 'num_gpus': 0},
7474
{'test_file': 'test_cispo_loss.py', 'num_gpus': 0},
75+
{'test_file': 'test_reinforce.py', 'num_gpus': 0},
7576
{'test_file': 'test_rm_f1.py', 'num_gpus': 0},
7677
{'test_file': 'test_rm_gpqa.py', 'num_gpus': 0},
7778
{'test_file': 'test_rm_math.py', 'num_gpus': 0},

slime/backends/megatron_utils/loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
compute_gspo_kl,
1818
compute_opsm_mask,
1919
compute_policy_loss,
20+
compute_reinforce_loss,
2021
get_advantages_and_returns_batch,
2122
get_grpo_returns,
2223
get_reinforce_plus_plus_baseline_advantages,
@@ -713,7 +714,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
713714
custom_adv_fn(args, rollout_data)
714715
advantages, returns = rollout_data["advantages"], rollout_data["returns"]
715716

716-
elif args.advantage_estimator in ["grpo", "gspo", "cispo"]:
717+
elif args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"]:
717718
rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device)
718719
returns = get_grpo_returns(rewards, kl)
719720
# TODO: is the copy necessary?
@@ -973,6 +974,8 @@ def policy_loss_function(
973974

974975
if args.advantage_estimator == "cispo":
975976
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip, args.eps_clip_high)
977+
elif args.advantage_estimator == "reinforce":
978+
pg_loss, pg_clipfrac = compute_reinforce_loss(advantages, log_probs)
976979
else:
977980
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
978981

slime/ray/rollout.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
689689

690690
raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
691691
if (
692-
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"]
692+
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce", "reinforce_plus_plus_baseline"]
693693
and self.args.rewards_normalization
694694
):
695695
# group norm
@@ -702,7 +702,10 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
702702
mean = rewards.mean(dim=-1, keepdim=True)
703703
rewards = rewards - mean
704704

705-
if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization:
705+
if (
706+
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce"]
707+
and self.args.grpo_std_normalization
708+
):
706709
std = rewards.std(dim=-1, keepdim=True)
707710
rewards = rewards / (std + 1e-6)
708711

slime/utils/arguments.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,14 +911,17 @@ def add_algo_arguments(parser):
911911
"grpo",
912912
"gspo",
913913
"cispo",
914+
"reinforce",
914915
"reinforce_plus_plus",
915916
"reinforce_plus_plus_baseline",
916917
"ppo",
917918
],
918919
default="grpo",
919920
help=(
920-
"Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal "
921-
"to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator."
921+
"Advantage estimator to use. 'reinforce' uses GRPO-style group-normalized "
922+
"advantages with the plain additive surrogate (no PPO/IS ratio, no clipping). "
923+
"Note: on-policy distillation (OPD) is now orthogonal to the advantage estimator. "
924+
"Use --opd-kl-coef > 0 to enable OPD on top of any estimator."
922925
),
923926
)
924927
parser.add_argument(

slime/utils/ppo_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ def compute_cispo_loss(
171171
return pg_losses, clipfrac
172172

173173

174+
@torch.compile(dynamic=True)
175+
def compute_reinforce_loss(
176+
advantages: torch.Tensor,
177+
log_probs: torch.Tensor,
178+
):
179+
"""REINFORCE surrogate ``-A * log pi_theta`` (no IS ratio, no clipping); gradient
180+
flows only through ``log_probs``. Same ``(per_token_loss, clipfrac)`` contract as
181+
:func:`compute_policy_loss`, with ``clipfrac`` identically zero (nothing is clipped).
182+
"""
183+
pg_losses = -advantages * log_probs
184+
clipfrac = torch.zeros_like(pg_losses)
185+
return pg_losses, clipfrac
186+
187+
174188
def compute_log_probs(
175189
logits: torch.Tensor,
176190
tokens: torch.Tensor,

tests/test_reinforce.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""CPU tests for compute_reinforce_loss (plain ``-A * log pi_theta`` surrogate)."""
2+
3+
import pytest
4+
import torch
5+
6+
from slime.utils.ppo_utils import compute_reinforce_loss
7+
8+
NUM_GPUS = 0
9+
10+
11+
@pytest.mark.unit
def test_reinforce_loss_matches_closed_form():
    """The per-token loss equals ``-A * log_probs`` and clipfrac is all zeros."""
    log_probs = torch.tensor([-0.1, -0.2, -0.3])
    advantages = torch.tensor([2.0, -1.0, 0.5])
    expected_loss = -advantages * log_probs
    expected_clipfrac = torch.zeros(3)

    actual_loss, actual_clipfrac = compute_reinforce_loss(advantages, log_probs)

    assert torch.allclose(actual_loss, expected_loss)
    assert torch.allclose(actual_clipfrac, expected_clipfrac)
20+
21+
22+
@pytest.mark.unit
def test_reinforce_gradient_flows_only_through_log_probs():
    """Back-propagating the summed loss yields ``-A`` as the gradient on log_probs."""
    log_probs = torch.tensor([-0.1, -0.2, -0.3], requires_grad=True)
    advantages = torch.tensor([2.0, -1.0, 0.5])

    per_token_loss, _clipfrac = compute_reinforce_loss(advantages, log_probs)
    total_loss = per_token_loss.sum()
    total_loss.backward()

    # d/d log_probs [ -A * log_probs ] = -A
    expected_grad = -advantages
    assert torch.allclose(log_probs.grad, expected_grad)
32+
33+
34+
if __name__ == "__main__":
    # Running the file directly executes just this module's tests and
    # propagates pytest's exit status to the shell.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)

0 commit comments

Comments
 (0)