Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
185f3f6
Add Delightful Policy Gradient gate (use_delight) to GRPO loss and en…
finbarrtimbers Apr 20, 2026
853d736
Add CHANGELOG entry for --use_delight. Co-Authored-By: Claude Opus 4.…
finbarrtimbers Apr 20, 2026
77588c9
Add Kondo gate (per-sample Bernoulli backward-skip on delight) to GRP…
finbarrtimbers Apr 20, 2026
4d5321c
Simplify Kondo gate: NamedTuple decision, shared metrics helper, toke…
finbarrtimbers Apr 21, 2026
0440cc4
Debug Kondo gate: add tracing logs + simplify decide(delight, mask). …
finbarrtimbers Apr 21, 2026
658827a
Kondo gate: log every decide() call unconditionally for debug. Co-Aut…
finbarrtimbers Apr 21, 2026
b765f40
Kondo gate: log quantile probes + frac_buf>lam to diagnose lambda. Co…
finbarrtimbers Apr 21, 2026
a8bd61f
Lower kondo_gate_warmup to 16 in large_test_script so debug runs exit…
finbarrtimbers Apr 21, 2026
90d07ae
Return dict from compute_grpo_loss instead of 5-tuple. Co-Authored-By…
finbarrtimbers Apr 21, 2026
5d2abca
Use LossOutput dataclass for compute_grpo_loss return value. Co-Autho…
finbarrtimbers Apr 21, 2026
88fe46f
Strip defensive guards from Kondo gate / delight code. Co-Authored-By…
finbarrtimbers Apr 21, 2026
04c52df
Merge remote-tracking branch 'origin/main' into finbarr/delight
finbarrtimbers Apr 22, 2026
06f94df
use delight
finbarrtimbers Apr 22, 2026
70acb52
fixed script
finbarrtimbers Apr 22, 2026
3e593ba
Apply delight gate at sample level to preserve blunder learning signa…
finbarrtimbers Apr 22, 2026
2f1e65b
Added eval scripts
finbarrtimbers Apr 24, 2026
9fa3062
updated code
finbarrtimbers Apr 24, 2026
0b02490
cleaned up script
finbarrtimbers Apr 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ token_length.png
birr/

oe-eval-internal/
olmo-eval-internal/

results
models
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.


### Changed
- Add `--use_kondo_gate` flag to GRPO that skips backward passes on low-delight samples via the Kondo gate (https://arxiv.org/abs/2603.20526), with `--kondo_gate_rate`, `--kondo_gate_temperature`, `--kondo_gate_history_size`, and `--kondo_gate_warmup` controls.
- Add `--use_delight` flag to GRPO loss that scales each sample's policy-gradient terms by a sample-level Delightful Policy Gradient sigmoid gate on the sample's mean delight (advantage × surprisal over response tokens) (https://github.com/allenai/open-instruct/pull/1628).
- Simplified model step tracking logic (https://github.com/allenai/open-instruct/pull/1616).
- Pass `attention_mask=None` in GRPO `forward_for_logprobs` calls — HF constructs the correct 3D intra-document mask from `position_ids` internally (https://github.com/allenai/open-instruct/pull/1617).
- Migrate GRPO trainer→vLLM weight sync to vLLM 0.16.0's native weight transfer API (`NCCLWeightTransferEngine`), replacing custom NCCL process-group and broadcast code (https://github.com/allenai/open-instruct/pull/1515).
Expand Down
56 changes: 43 additions & 13 deletions open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,11 @@ def load(self, path: str, map_location=None):
alpha=args.alpha,
)
self.local_metrics = utils.MetricsTracker(max_metrics=512, device=self.device)
self._kondo_gate = (
grpo_utils.KondoGateState(args, self.device, process_group=None, seed=args.seed)
if args.use_kondo_gate
else None
)

if self.mpu is not None:
self.splitter = UlyssesSPSplitter(
Expand Down Expand Up @@ -668,6 +673,8 @@ def step(self):
token_counts_per_sample = torch.stack([mask[:, 1:].sum().float() for mask in data_BT.response_masks])
device = token_counts_per_sample.device
grad_norms: list[float] = [] # May include nan/inf values reported by DeepSpeed.
group_had_backward = False
kondo_gate_stats: list[grpo_utils.KondoGateDecision] = []
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
loss_stats_B = grpo_utils.create_loss_stats(num_samples, device, record_entropy=self.args.record_entropy)
Expand Down Expand Up @@ -743,39 +750,57 @@ def step(self):
self.args.truncated_importance_sampling_ratio_cap,
)

pg_losses_BT, pg_losses2_BT, pg_loss_max_BT, kl_BT = grpo_utils.compute_grpo_loss(
loss_output = grpo_utils.compute_grpo_loss(
new_logprobs=new_logprobs_BT,
ratio=ratio_BT,
advantages=data_BT.advantages[i][:, 1:],
ref_logprobs=ref_logprobs_BT[i] if self.args.load_ref_policy else None,
config=self.args,
tis_weights=tis_clamped_BT,
response_mask=response_mask_BT,
)

per_token_loss_BT = pg_loss_max_BT + self.args.beta * kl_BT
per_token_loss_BT = loss_output.pg_loss_max + self.args.beta * loss_output.kl
loss = masked_mean(per_token_loss_BT, response_mask_BT, None, loss_denominator)

# we already took world size into account via the tokens
# but deepspeed will try to average over ranks, so multiply back
# up, adjusting for the sequence parallel size (adjust by dp world size).
loss *= self.args.world_size // self.args.sequence_parallel_size

# Clear CUDA cache before backward pass to free memory for reduce_scatter operations
torch.cuda.empty_cache()
if self._kondo_gate is not None:
decision = self._kondo_gate.decide(loss_output.delight, response_mask_BT)
kondo_gate_stats.append(decision)
should_backward = decision.should_backward
else:
should_backward = True

is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0
# Tell deepspeed whether this backward is the last in the accumulation group.
self.model.set_gradient_accumulation_boundary(is_accumulation_boundary)
self.model.backward(loss)
if should_backward:
# Clear CUDA cache before backward pass to free memory for reduce_scatter operations
torch.cuda.empty_cache()
self.model.set_gradient_accumulation_boundary(is_accumulation_boundary)
self.model.backward(loss)
group_had_backward = True
elif is_accumulation_boundary and group_had_backward:
# DeepSpeed defers the accumulation-group reduce-scatter to the boundary
# backward; if the boundary sample is gated, we still need a backward here
# (zeroed so it contributes no gradient) to flush earlier micro-steps' grads.
torch.cuda.empty_cache()
self.model.set_gradient_accumulation_boundary(True)
self.model.backward(loss * 0.0)
if is_accumulation_boundary:
self.model.step()
grad_norms.append(float(self.model.get_global_grad_norm()))
if group_had_backward:
self.model.step()
grad_norms.append(float(self.model.get_global_grad_norm()))
group_had_backward = False
local_step += 1
grpo_utils.populate_sample_loss_stats(
loss_stats_B,
i,
pg_losses_BT,
pg_losses2_BT,
pg_loss_max_BT,
loss_output.pg_losses,
loss_output.pg_losses2,
loss_output.pg_loss_max,
ratio_BT,
loss,
response_mask_BT,
Expand All @@ -790,7 +815,12 @@ def step(self):
batch_metrics = batch_data["metrics"]
with torch.no_grad():
self._compute_loss_metrics(loss_stats_B, token_counts_per_sample)
self.local_metrics["optim/grad_norm"] = sum(grad_norms) / len(grad_norms)
self.local_metrics["optim/grad_norm"] = (
sum(grad_norms) / len(grad_norms) if grad_norms else float("nan")
)
if self._kondo_gate is not None:
for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items():
self.local_metrics[k] = v
array_metrics = {}
for key, value in batch_metrics.items():
if value is None:
Expand Down
124 changes: 121 additions & 3 deletions open_instruct/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
import os
from dataclasses import dataclass, field
from typing import Literal
from typing import Literal, NamedTuple

import numpy as np
import torch
Expand Down Expand Up @@ -118,6 +118,21 @@ class GRPOExperimentConfig(
"""Whether to load and use a reference policy for KL penalty calculation."""
loss_fn: GRPOLossType = GRPOLossType.dapo
"""Whether to use DAPO or CISPO loss function."""
use_delight: bool = False
"""Whether to gate per-token policy-gradient terms with the Delightful Policy Gradient sigmoid
of delight = advantage * surprisal (https://arxiv.org/abs/2603.14608)."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The arXiv link provided in the docstring contains a typo. The correct arXiv ID for the 'Delightful Policy Gradient' paper by Osband et al. (2024) is 2403.14608, not 2603.14608.

    of delight = advantage * surprisal (https://arxiv.org/abs/2403.14608)."

use_kondo_gate: bool = False
"""Whether to enable the Kondo gate (https://arxiv.org/abs/2603.20526): per-sample Bernoulli gate
on whether to run the backward pass, driven by sample-level delight against an adaptive threshold."""
kondo_gate_rate: float = 1.0
"""Target fraction rho of samples that receive a backward pass. 1.0 is a no-op even when the gate
is enabled (always backward). Smaller values keep only the highest-delight samples."""
kondo_gate_temperature: float = 1.0
"""Temperature eta in the Kondo gate Bernoulli probability sigma((chi - lambda) / eta)."""
kondo_gate_history_size: int = 1024
"""Size of the ring buffer of past sample delights used to compute lambda = quantile_{1-rho}."""
kondo_gate_warmup: int = 128
"""Never gate until the history contains at least this many sample delights."""
record_entropy: bool = False
"""whether to record the entropy of the policy during training. Uses extra memory."""
use_vllm_logprobs: bool = False
Expand Down Expand Up @@ -338,14 +353,37 @@ def resolve_old_logprob(
return result


@dataclass
class LossOutput:
    """Named bundle of the tensors produced by ``compute_grpo_loss``.

    Attributes:
        pg_losses: Policy-gradient loss term; for the DAPO loss this is
            ``-advantages * ratio``.
        pg_losses2: Second policy-gradient term; for the DAPO loss this is
            ``-advantages`` times the ratio clamped to
            ``[1 - clip_lower, 1 + clip_higher]``.
        pg_loss_max: The policy-gradient term callers add to ``beta * kl`` to form
            the per-token training loss.
        kl: KL penalty term; all zeros (shaped like ``pg_loss_max``) on the branch
            where no KL is computed.
        delight: ``-advantages * new_logprobs.detach()``, computed from the
            advantages as passed in (before any delight gating rescales them).

    All fields presumably share the per-token (batch, time) layout of the inputs
    -- TODO confirm against the caller.
    """

    pg_losses: torch.Tensor
    pg_losses2: torch.Tensor
    pg_loss_max: torch.Tensor
    kl: torch.Tensor
    delight: torch.Tensor


def compute_grpo_loss(
new_logprobs: torch.Tensor,
ratio: torch.Tensor,
advantages: torch.Tensor,
ref_logprobs: torch.Tensor | None,
config: GRPOExperimentConfig,
tis_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
response_mask: torch.Tensor | None = None,
) -> LossOutput:
delight = -advantages * new_logprobs.detach()
if config.use_delight:
# Delightful Policy Gradient gate applied at sample level: one sigmoid per rollout,
# broadcast across tokens. GRPO's advantage is constant across a response, so a
# per-token gate would zero out the exact "blunder" tokens whose negative signal
# we need to learn from; a sample-level chi = mean_t(-A * surprisal_t) preserves
# that signal while keeping the paper's breakthrough/blunder interpretation.
mask = response_mask.to(delight.dtype)
denom = mask.sum(dim=-1).clamp(min=1.0)
sample_chi = (delight * mask).sum(dim=-1) / denom
sample_gate = torch.sigmoid(sample_chi).unsqueeze(-1)
advantages = advantages * sample_gate

if config.loss_fn == GRPOLossType.dapo:
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - config.clip_lower, 1.0 + config.clip_higher)
Expand Down Expand Up @@ -373,7 +411,87 @@ def compute_grpo_loss(
else:
kl = torch.zeros_like(pg_loss_max)

return pg_losses, pg_losses2, pg_loss_max, kl
return LossOutput(pg_losses=pg_losses, pg_losses2=pg_losses2, pg_loss_max=pg_loss_max, kl=kl, delight=delight)


class KondoGateDecision(NamedTuple):
    """Outcome of a single Kondo gate evaluation for one sample.

    Attributes:
        should_backward: Whether the caller should run the backward pass for
            this sample.
        prob: Bernoulli probability the decision was drawn with (``1.0`` when the
            gate returned without sampling, e.g. during warm-up).
        lam: Threshold lambda the sample's delight was compared against; NaN when
            no threshold was computed (e.g. during warm-up).
    """

    should_backward: bool
    prob: float
    lam: float


class KondoGateState:
    """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526).

    Maintains a ring buffer of past sample-level delight values, computes an adaptive
    threshold lambda = quantile_{1-rho}(history), and draws a Bernoulli gate with
    probability sigma((chi - lambda) / eta). The sample delight is all-reduced across
    ``process_group`` and the gate is drawn from a generator seeded with ``seed``, so
    ranks constructed with the same seed and fed the same reduced delight make the
    same decision -- required to keep DeepSpeed / FSDP collectives in sync.

    NOTE: the history count saturates at ``kondo_gate_history_size``, so a
    ``kondo_gate_warmup`` larger than the history size means the gate never
    leaves warm-up.
    """

    def __init__(
        self,
        config: GRPOExperimentConfig,
        device: torch.device,
        process_group: dist.ProcessGroup | None = None,
        seed: int = 0,
    ) -> None:
        """Allocate the delight history and seed the gate's random generator.

        Args:
            config: Source of the ``kondo_gate_history_size``, ``kondo_gate_warmup``,
                ``kondo_gate_rate`` and ``kondo_gate_temperature`` settings.
            device: Device holding the history buffer and the random generator.
            process_group: Group the sample delight is all-reduced over; ``None``
                is passed straight through to ``dist.all_reduce``.
            seed: Seed for the Bernoulli generator.
        """
        self.device = device
        self.process_group = process_group
        self.history_size = config.kondo_gate_history_size
        self.warmup = config.kondo_gate_warmup
        self.rate = config.kondo_gate_rate
        self.temperature = config.kondo_gate_temperature
        self._buffer = torch.zeros(self.history_size, device=device)
        self._count = 0
        self._write_idx = 0
        self._generator = torch.Generator(device=device)
        self._generator.manual_seed(int(seed))

    def _reduced_chi(self, delight: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
        """Reduce (sum_delight, sum_tokens) across the process group and return sum/count.

        With Ulysses SP, each rank holds a sequence-slice of its sample, so per-rank
        slice-means differ across SP-mates. Reducing the numerator and denominator
        separately gives the correct token-weighted mean regardless of slice lengths,
        and the result is identical on every rank in the group (required to keep
        DeepSpeed / FSDP collectives in sync).

        Returns:
            A 0-dim float32 tensor. If the reduced token count is zero the result
            is 0.0 rather than NaN.
        """
        # Accumulate in float32 so low-precision delights neither lose precision in
        # the sum nor mismatch the dtype of the token count when stacked.
        mask = response_mask.to(torch.float32)
        packed = torch.stack([(delight.detach().float() * mask).sum(), mask.sum()])
        if dist.is_initialized():
            dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=self.process_group)
        # Clamp after the reduction: 0/0 would put a NaN into the history buffer and
        # make every later quantile NaN until that slot is overwritten.
        return packed[0] / packed[1].clamp(min=1.0)

    def _append(self, value: torch.Tensor) -> None:
        """Write ``value`` into the ring buffer, overwriting the oldest slot once full."""
        self._buffer[self._write_idx] = value
        self._write_idx = (self._write_idx + 1) % self.history_size
        self._count = min(self._count + 1, self.history_size)

    def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGateDecision:
        """Computes token-weighted chi over the response, all-reduces across ranks, and gates.

        The reduced chi is always recorded in the history, even when no sampling
        takes place, so the collective runs on every call and the history keeps
        filling.

        Returns:
            ``KondoGateDecision(True, 1.0, nan)`` while in warm-up or when
            ``kondo_gate_rate >= 1.0``; otherwise the sampled decision together with
            its Bernoulli probability and the threshold lambda.
        """
        chi = self._reduced_chi(delight, response_mask)
        self._append(chi)
        # rate >= 1.0 is documented as "always backward". Without this guard lambda
        # would be the history minimum, the probability sigmoid(>= 0) < 1, and
        # samples would still be dropped (rate > 1.0 would also hand torch.quantile
        # a negative q).
        if self.rate >= 1.0 or self._count < self.warmup:
            return KondoGateDecision(True, 1.0, float("nan"))
        history = self._buffer[: self._count]
        lam = torch.quantile(history, 1.0 - self.rate)
        prob = torch.sigmoid((chi - lam) / self.temperature)
        gate = torch.bernoulli(prob, generator=self._generator)
        # Single host transfer instead of three separate .item() synchronisations.
        gate_value, prob_value, lam_value = torch.stack((gate, prob, lam)).tolist()
        return KondoGateDecision(bool(gate_value), prob_value, lam_value)


def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]:
    """Aggregate per-sample gate decisions into scalar metrics.

    Args:
        stats: Gate decisions collected over one training step.

    Returns:
        A dict with the fraction of samples that ran a backward pass, the mean
        gate probability, and the mean threshold lambda. Lambda is averaged only
        over decisions that carry a threshold: decisions whose ``lam`` is NaN
        (no threshold was computed) are skipped so that a single such decision
        does not turn the whole metric into NaN. Every metric is NaN when
        ``stats`` is empty, and lambda is NaN when no decision carried a
        threshold.
    """
    nan = float("nan")
    total = len(stats)
    if total == 0:
        # Nothing was gated this step; report NaN instead of dividing by zero.
        return {
            "val/kondo_gate_backward_frac": nan,
            "val/kondo_gate_prob_avg": nan,
            "val/kondo_lambda": nan,
        }
    thresholds = [s.lam for s in stats if not math.isnan(s.lam)]
    return {
        "val/kondo_gate_backward_frac": sum(1 for s in stats if s.should_backward) / total,
        "val/kondo_gate_prob_avg": sum(s.prob for s in stats) / total,
        "val/kondo_lambda": sum(thresholds) / len(thresholds) if thresholds else nan,
    }


def forward_for_logprobs(
Expand Down
46 changes: 37 additions & 9 deletions open_instruct/olmo_core_train_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,16 @@ def __init__(
if ref_policy is not None:
self.ref_policy = ref_policy.to(device=self.device).eval().requires_grad_(False)

self._kondo_gate: grpo_utils.KondoGateState | None = None

def pre_train(self):
    """Prepare optional per-run state before training begins.

    The base TransformerTrainModule global-batch validation is deliberately not
    performed here: GRPO batches are prompt-grouped and do their own
    accumulation/token normalization inside train_batch(). The only work done is
    building the Kondo gate state when the GRPO config enables it.
    """
    if not self.grpo_config.use_kondo_gate:
        # Gate disabled: leave self._kondo_gate as it was.
        return
    gate_config = self.grpo_config
    gate_device = self.device
    dp_group = self.trainer.dp_process_group
    self._kondo_gate = grpo_utils.KondoGateState(
        gate_config,
        gate_device,
        process_group=dp_group,
        seed=self.grpo_config.seed,
    )

def state_dict(self, *, optim: bool | None = None) -> dict[str, Any]:
state = super().state_dict(optim=optim)
Expand Down Expand Up @@ -413,6 +418,8 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:

num_steps = 0
local_step = 0
group_had_backward = False
kondo_gate_stats: list[grpo_utils.KondoGateDecision] = []

for epoch_idx in range(self.grpo_config.num_epochs):
for sample_idx in range(num_samples):
Expand Down Expand Up @@ -450,28 +457,44 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
old_logprob, vllm_logprobs, response_mask, self.grpo_config.truncated_importance_sampling_ratio_cap
)

pg_losses, pg_losses2, pg_loss, kl = grpo_utils.compute_grpo_loss(
loss_output = grpo_utils.compute_grpo_loss(
new_logprobs=new_logprobs,
ratio=ratio,
advantages=advantages[:, 1:],
ref_logprobs=ref_logprobs_BT[sample_idx] if ref_logprobs_BT is not None else None,
config=self.grpo_config,
tis_weights=tis_clamped,
response_mask=response_mask,
)

batch_start = (sample_idx // accumulation_steps) * accumulation_steps
loss_denominator = accumulation_token_counts[batch_start]
loss = masked_mean(pg_loss + self.grpo_config.beta * kl, response_mask, None, loss_denominator)
loss = masked_mean(
loss_output.pg_loss_max + self.grpo_config.beta * loss_output.kl,
response_mask,
None,
loss_denominator,
)

loss = loss * dp_world_size
loss.backward()

if self._kondo_gate is not None:
decision = self._kondo_gate.decide(loss_output.delight, response_mask)
kondo_gate_stats.append(decision)
should_backward = decision.should_backward
else:
should_backward = True

if should_backward:
loss.backward()
group_had_backward = True

grpo_utils.populate_sample_loss_stats(
loss_stats_B,
sample_idx,
pg_losses,
pg_losses2,
pg_loss,
loss_output.pg_losses,
loss_output.pg_losses2,
loss_output.pg_loss_max,
ratio,
loss,
response_mask,
Expand All @@ -487,14 +510,16 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
local_step += 1

if local_step % accumulation_steps == 0:
if not dry_run:
if not dry_run and group_had_backward:
self.optim_step()
self.zero_grads()
group_had_backward = False

if local_step % accumulation_steps != 0:
if not dry_run:
if not dry_run and group_had_backward:
self.optim_step()
self.zero_grads()
group_had_backward = False

if not dry_run and num_steps > 0:
local_metrics = grpo_utils.compute_metrics_from_loss_stats(loss_stats_B, token_counts)
Expand All @@ -516,6 +541,9 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None:
)
self.record_metric("lr", float(lr), reduce_type=None)
self.record_metric("_token_count", global_tokens, reduce_type=None)
if self._kondo_gate is not None:
for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items():
self.record_metric(k, v, reduce_type=None)

data_prep_metrics = batch.get("metrics") or {}
for metric_key, metric_value in data_prep_metrics.items():
Expand Down
Loading