Skip to content

Commit 1b70bf3

Browse files
EazyRealclaude
andcommitted
feat(rl): composable off-policy importance-sampling correction
Expose the current grad-carrying log-probs to the policy-loss TIS hook as `cur_log_probs`, and add `off_policy_is_function` (in ppo_utils, next to compute_policy_loss/compute_cispo_loss) -- a truncated-IS correction between the *current* policy and the *actual rollout generator*: the (detached) weight is `clip(pi_theta / pi_rollout)` against the real rollout logprob, so one weight corrects both the train/inference mismatch and async (multi-version) staleness. The existing TIS hook only had pi_theta_old / pi_rollout, which equals this only in the single-update-per-rollout limit. On a plain REINFORCE base (`--advantage-estimator reinforce`) this reproduces the CISPO surrogate expressed as a correction rather than the dedicated `compute_cispo_loss` estimator. Existing corrections ignore the new kwarg via **kwargs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent e46ca0a commit 1b70bf3

3 files changed

Lines changed: 122 additions & 0 deletions

File tree

slime/backends/megatron_utils/loss.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,9 @@ def policy_loss_function(
846846
)
847847

848848
log_probs = log_probs_and_entropy["log_probs"]
849+
# Current pi_theta (grad-carrying), captured before the cat below; passed to TIS-hook
850+
# corrections so they can form pi_theta / pi_rollout (see off_policy_is_function).
851+
cur_log_probs_list = log_probs
849852
if not args.use_rollout_logprobs and not old_log_probs:
850853
old_log_probs = [log_prob.detach() for log_prob in log_probs]
851854
train_log_probs_for_tis = batch.get("log_probs")
@@ -919,9 +922,12 @@ def policy_loss_function(
919922
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
920923

921924
ois = (-ppo_kl).exp()
925+
# Pass cur_log_probs (current pi_theta, grad-carrying) so corrections can form
926+
# pi_theta/pi_rollout, not just the frozen pi_theta_old/pi_rollout of vanilla TIS.
922927
tis_kwargs = {
923928
"args": args,
924929
"pg_loss": pg_loss,
930+
"cur_log_probs": cur_log_probs_list,
925931
"train_log_probs": train_log_probs_for_tis,
926932
"rollout_log_probs": batch["rollout_log_probs"],
927933
"loss_masks": batch["loss_masks"],

slime/utils/ppo_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py
33

44
from argparse import Namespace
5+
from typing import Any
56

67
import torch
78
import torch.distributed as dist
@@ -171,6 +172,37 @@ def compute_cispo_loss(
171172
return pg_losses, clipfrac
172173

173174

175+
def off_policy_is_function(
176+
args: Namespace,
177+
*,
178+
pg_loss: torch.Tensor,
179+
cur_log_probs: list[torch.Tensor],
180+
rollout_log_probs: list[torch.Tensor],
181+
loss_masks: list[torch.Tensor],
182+
**kwargs: Any,
183+
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
184+
"""Off-policy truncated IS (TIS hook): like ``vanilla_tis_function`` but with the
185+
*current* policy in the numerator instead of the old recompute, so the (detached)
186+
weight is ``clip(pi_theta / pi_rollout)`` against the actual rollout logprob -- one
187+
weight that corrects both the train/inference mismatch and async (multi-version)
188+
staleness. Composed with ``--advantage-estimator reinforce`` it is the CISPO surrogate
189+
(https://arxiv.org/abs/2506.13585), expressed as a correction rather than the dedicated
190+
``compute_cispo_loss`` estimator; ``--eps-clip 1.0`` gives canonical single-sided clipping.
191+
Same ``(pg_loss, loss_masks, metrics)`` contract; ``loss_masks`` unchanged.
192+
"""
193+
cur = torch.cat([lp.detach() for lp in cur_log_probs], dim=0)
194+
rollout = torch.cat(rollout_log_probs, dim=0)
195+
ratio = torch.exp(cur - rollout)
196+
is_weights = torch.clamp(ratio, min=1.0 - args.eps_clip, max=1.0 + args.eps_clip_high)
197+
is_clipfrac = (is_weights != ratio).float()
198+
metrics = {
199+
"is_weight": ratio.clone().detach(),
200+
"is_clipfrac": is_clipfrac.clone().detach(),
201+
}
202+
pg_loss = pg_loss * is_weights
203+
return pg_loss, loss_masks, metrics
204+
205+
174206
def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
175207
# TODO: when megatron is not installed, fall back to naive implementation
176208
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy

tests/test_off_policy_is.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Unit tests for the ``off_policy_is_function`` importance-sampling correction
2+
(slime/utils/ppo_utils.py).
3+
4+
It is truncated IS between the *current* policy and the *actual rollout generator*:
5+
the (detached) weight is ``clip(pi_theta / pi_rollout)``. On a plain REINFORCE base
6+
``-A * log pi`` it reproduces the CISPO surrogate (https://arxiv.org/abs/2506.13585).
7+
8+
Pure-torch (no megatron), like tests/test_chunked_gae.py; runs on CPU. NUM_GPUS = 0
9+
selects the CPU runner in the changed-test CI matrix; the __main__ block lets CI run
10+
it as a script. (The hook wiring in loss.py that supplies cur_log_probs imports
11+
megatron and is exercised in the GPU CI suites.)
12+
"""
13+
14+
from argparse import Namespace
15+
16+
import torch
17+
18+
from slime.utils.ppo_utils import off_policy_is_function
19+
20+
# CPU-only test: selects the 0-GPU runner in the changed-test CI matrix.
21+
NUM_GPUS = 0
22+
23+
24+
def test_off_policy_is_function_clips_weight_and_passes_masks_through():
25+
# ratio = exp(cur - rollout): ln(2) -> 2 -> clamp 1.2; ln(0.5) -> 0.5 -> 0.8; 0 -> 1.0
26+
cur = torch.tensor([1.0, 1.0, 1.0])
27+
rollout = cur - torch.tensor([2.0, 0.5, 1.0]).log()
28+
pg_loss = torch.tensor([1.0, 1.0, 1.0])
29+
loss_masks = [torch.ones(3)]
30+
args = Namespace(eps_clip=0.2, eps_clip_high=0.2)
31+
32+
out_loss, out_masks, metrics = off_policy_is_function(
33+
args, pg_loss=pg_loss, cur_log_probs=[cur], rollout_log_probs=[rollout], loss_masks=loss_masks
34+
)
35+
36+
expected_w = torch.tensor([1.2, 0.8, 1.0])
37+
assert torch.allclose(out_loss, pg_loss * expected_w)
38+
assert torch.allclose(metrics["is_clipfrac"], torch.tensor([1.0, 1.0, 0.0]))
39+
assert out_masks is loss_masks # no rejection-sampling masking
40+
41+
42+
def test_off_policy_is_on_reinforce_base_equals_cispo_surrogate():
43+
# On a plain REINFORCE base (-A * log pi), off_policy_is_function reproduces the
44+
# CISPO surrogate exactly, with gradient flowing ONLY through log_probs.
45+
advantages = torch.tensor([2.0, -1.0, 0.5, 1.5])
46+
rollout = torch.tensor([-0.5, -0.2, -0.9, -0.3]) # behavior policy mu (frozen)
47+
log_probs = torch.tensor([-0.1, -0.4, -0.3, -0.8], requires_grad=True)
48+
args = Namespace(eps_clip=0.2, eps_clip_high=0.2)
49+
50+
pg_loss = -advantages * log_probs # plain REINFORCE base
51+
pg_loss, _, _ = off_policy_is_function(
52+
args, pg_loss=pg_loss, cur_log_probs=[log_probs], rollout_log_probs=[rollout], loss_masks=[torch.ones(4)]
53+
)
54+
55+
ratio = torch.exp(log_probs.detach() - rollout) # pi_theta / pi_rollout
56+
clipped = ratio.clamp(1 - args.eps_clip, 1 + args.eps_clip_high)
57+
assert torch.allclose(pg_loss, -clipped * advantages * log_probs.detach())
58+
59+
pg_loss.sum().backward()
60+
# d/d log_probs [ -clip(ratio).detach() * A * log_probs ] = -clip(ratio) * A
61+
assert torch.allclose(log_probs.grad, -clipped * advantages)
62+
63+
64+
def test_off_policy_is_single_sided_when_eps_clip_one():
65+
# Canonical CISPO: eps_clip=1.0 disables the lower bound (ratio >= 0 never clipped low).
66+
cur = torch.tensor([0.0, 0.0])
67+
rollout = cur - torch.tensor([10.0, 0.01]).log() # ratios 10.0 (high) and ~0.01 (very low)
68+
pg_loss = torch.tensor([1.0, 1.0])
69+
args = Namespace(eps_clip=1.0, eps_clip_high=4.0)
70+
71+
_, _, metrics = off_policy_is_function(
72+
args, pg_loss=pg_loss, cur_log_probs=[cur], rollout_log_probs=[rollout], loss_masks=[torch.ones(2)]
73+
)
74+
75+
# high ratio 10.0 > 1+eps_clip_high=5.0 clipped; low ratio ~0.01 >= 1-eps_clip=0.0 NOT clipped
76+
assert torch.allclose(metrics["is_clipfrac"], torch.tensor([1.0, 0.0]))
77+
78+
79+
if __name__ == "__main__":
80+
for name, fn in sorted(globals().items()):
81+
if name.startswith("test_") and callable(fn):
82+
fn()
83+
print(f"PASSED {name}")
84+
print("OK")

0 commit comments

Comments
 (0)