diff --git a/.gitignore b/.gitignore index c5cf71db88..b12d763d96 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ token_length.png birr/ oe-eval-internal/ +olmo-eval-internal/ results models diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b811ccd37..3d981a2230 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ### Changed +- Add `--use_kondo_gate` flag to GRPO that skips backward passes on low-delight samples via the Kondo gate (https://arxiv.org/abs/2603.20526), with `--kondo_gate_rate`, `--kondo_gate_temperature`, `--kondo_gate_history_size`, and `--kondo_gate_warmup` controls. +- Add `--use_delight` flag to GRPO loss that gates per-token policy-gradient terms with the Delightful Policy Gradient sigmoid (https://github.com/allenai/open-instruct/pull/1628). - Simplified model step tracking logic (https://github.com/allenai/open-instruct/pull/1616). - Pass `attention_mask=None` in GRPO `forward_for_logprobs` calls — HF constructs the correct 3D intra-document mask from `position_ids` internally (https://github.com/allenai/open-instruct/pull/1617). - Migrate GRPO trainer→vLLM weight sync to vLLM 0.16.0's native weight transfer API (`NCCLWeightTransferEngine`), replacing custom NCCL process-group and broadcast code (https://github.com/allenai/open-instruct/pull/1515). 
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 886162529b..4cf4bb3820 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -471,6 +471,11 @@ def load(self, path: str, map_location=None): alpha=args.alpha, ) self.local_metrics = utils.MetricsTracker(max_metrics=512, device=self.device) + self._kondo_gate = ( + grpo_utils.KondoGateState(args, self.device, process_group=None, seed=args.seed) + if args.use_kondo_gate + else None + ) if self.mpu is not None: self.splitter = UlyssesSPSplitter( @@ -668,6 +673,8 @@ def step(self): token_counts_per_sample = torch.stack([mask[:, 1:].sum().float() for mask in data_BT.response_masks]) device = token_counts_per_sample.device grad_norms: list[float] = [] # May include nan/inf values reported by DeepSpeed. + group_had_backward = False + kondo_gate_stats: list[grpo_utils.KondoGateDecision] = [] # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): loss_stats_B = grpo_utils.create_loss_stats(num_samples, device, record_entropy=self.args.record_entropy) @@ -743,16 +750,17 @@ def step(self): self.args.truncated_importance_sampling_ratio_cap, ) - pg_losses_BT, pg_losses2_BT, pg_loss_max_BT, kl_BT = grpo_utils.compute_grpo_loss( + loss_output = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs_BT, ratio=ratio_BT, advantages=data_BT.advantages[i][:, 1:], ref_logprobs=ref_logprobs_BT[i] if self.args.load_ref_policy else None, config=self.args, tis_weights=tis_clamped_BT, + response_mask=response_mask_BT, ) - per_token_loss_BT = pg_loss_max_BT + self.args.beta * kl_BT + per_token_loss_BT = loss_output.pg_loss_max + self.args.beta * loss_output.kl loss = masked_mean(per_token_loss_BT, response_mask_BT, None, loss_denominator) # we already took world size into account via the tokens @@ -760,22 +768,39 @@ def step(self): # up, adjusting for the sequence parallel 
size (adjust by dp world size). loss *= self.args.world_size // self.args.sequence_parallel_size - # Clear CUDA cache before backward pass to free memory for reduce_scatter operations - torch.cuda.empty_cache() + if self._kondo_gate is not None: + decision = self._kondo_gate.decide(loss_output.delight, response_mask_BT) + kondo_gate_stats.append(decision) + should_backward = decision.should_backward + else: + should_backward = True + is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 - # Tell deepspeed whether this backward is the last in the accumulation group. - self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) - self.model.backward(loss) + if should_backward: + # Clear CUDA cache before backward pass to free memory for reduce_scatter operations + torch.cuda.empty_cache() + self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) + self.model.backward(loss) + group_had_backward = True + elif is_accumulation_boundary and group_had_backward: + # DeepSpeed defers the accumulation-group reduce-scatter to the boundary + # backward; if the boundary sample is gated, we still need a backward here + # (zeroed so it contributes no gradient) to flush earlier micro-steps' grads. 
+ torch.cuda.empty_cache() + self.model.set_gradient_accumulation_boundary(True) + self.model.backward(loss * 0.0) if is_accumulation_boundary: - self.model.step() - grad_norms.append(float(self.model.get_global_grad_norm())) + if group_had_backward: + self.model.step() + grad_norms.append(float(self.model.get_global_grad_norm())) + group_had_backward = False local_step += 1 grpo_utils.populate_sample_loss_stats( loss_stats_B, i, - pg_losses_BT, - pg_losses2_BT, - pg_loss_max_BT, + loss_output.pg_losses, + loss_output.pg_losses2, + loss_output.pg_loss_max, ratio_BT, loss, response_mask_BT, @@ -790,7 +815,12 @@ def step(self): batch_metrics = batch_data["metrics"] with torch.no_grad(): self._compute_loss_metrics(loss_stats_B, token_counts_per_sample) - self.local_metrics["optim/grad_norm"] = sum(grad_norms) / len(grad_norms) + self.local_metrics["optim/grad_norm"] = ( + sum(grad_norms) / len(grad_norms) if grad_norms else float("nan") + ) + if self._kondo_gate is not None: + for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items(): + self.local_metrics[k] = v array_metrics = {} for key, value in batch_metrics.items(): if value is None: diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 830c0e560d..9e569f152d 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -2,7 +2,7 @@ import math import os from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, NamedTuple import numpy as np import torch @@ -118,6 +118,21 @@ class GRPOExperimentConfig( """Whether to load and use a reference policy for KL penalty calculation.""" loss_fn: GRPOLossType = GRPOLossType.dapo """Whether to use DAPO or CISPO loss function.""" + use_delight: bool = False + """Whether to gate per-token policy-gradient terms with the Delightful Policy Gradient sigmoid + of delight = advantage * surprisal (https://arxiv.org/abs/2603.14608).""" + use_kondo_gate: bool = False + """Whether to enable 
the Kondo gate (https://arxiv.org/abs/2603.20526): per-sample Bernoulli gate + on whether to run the backward pass, driven by sample-level delight against an adaptive threshold.""" + kondo_gate_rate: float = 1.0 + """Target fraction rho of samples that receive a backward pass. 1.0 is a no-op even when the gate + is enabled (always backward). Smaller values keep only the highest-delight samples.""" + kondo_gate_temperature: float = 1.0 + """Temperature eta in the Kondo gate Bernoulli probability sigma((chi - lambda) / eta).""" + kondo_gate_history_size: int = 1024 + """Size of the ring buffer of past sample delights used to compute lambda = quantile_{1-rho}.""" + kondo_gate_warmup: int = 128 + """Never gate until the history contains at least this many sample delights.""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" use_vllm_logprobs: bool = False @@ -338,6 +353,15 @@ def resolve_old_logprob( return result +@dataclass +class LossOutput: + pg_losses: torch.Tensor + pg_losses2: torch.Tensor + pg_loss_max: torch.Tensor + kl: torch.Tensor + delight: torch.Tensor + + def compute_grpo_loss( new_logprobs: torch.Tensor, ratio: torch.Tensor, @@ -345,7 +369,21 @@ def compute_grpo_loss( ref_logprobs: torch.Tensor | None, config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + response_mask: torch.Tensor | None = None, +) -> LossOutput: + delight = -advantages * new_logprobs.detach() + if config.use_delight: + # Delightful Policy Gradient gate applied at sample level: one sigmoid per rollout, + # broadcast across tokens. GRPO's advantage is constant across a response, so a + # per-token gate would zero out the exact "blunder" tokens whose negative signal + # we need to learn from; a sample-level chi = mean_t(-A * surprisal_t) preserves + # that signal while keeping the paper's breakthrough/blunder interpretation. 
class KondoGateDecision(NamedTuple):
    """Outcome of one Kondo gate evaluation for a single training sample."""

    # Whether the caller should run the backward pass for this sample.
    should_backward: bool
    # Bernoulli pass probability that was used (1.0 whenever the gate is inactive).
    prob: float
    # Adaptive threshold lambda; NaN while the gate is still warming up.
    lam: float


class KondoGateState:
    """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526).

    Maintains a ring buffer of past sample-level delight values, computes an adaptive
    threshold lambda = quantile_{1-rho}(history), and draws a Bernoulli gate with
    probability sigma((chi - lambda) / eta). The sample delight is all-reduced across
    the process group and every rank uses an identically-seeded generator, so every
    rank produces the same gate decision -- required to keep DeepSpeed / FSDP
    collectives in sync.

    Args:
        config: Experiment config; reads ``kondo_gate_history_size``,
            ``kondo_gate_warmup``, ``kondo_gate_rate`` and ``kondo_gate_temperature``.
        device: Device holding the history buffer and the random generator.
        process_group: Group used for the all-reduce (``None`` = default group).
        seed: Generator seed; must be the same on every rank of the group.

    Raises:
        ValueError: If the history size is < 1, the rate is outside [0, 1], or the
            temperature is not strictly positive.
    """

    def __init__(
        self,
        config: GRPOExperimentConfig,
        device: torch.device,
        process_group: dist.ProcessGroup | None = None,
        seed: int = 0,
    ) -> None:
        # Fail at construction time instead of deep inside a training step:
        # a zero-sized buffer breaks the ring-index modulo, an out-of-range rate
        # makes torch.quantile raise, and eta <= 0 yields NaN or an inverted gate.
        if config.kondo_gate_history_size < 1:
            raise ValueError(f"kondo_gate_history_size must be >= 1, got {config.kondo_gate_history_size}")
        if not 0.0 <= config.kondo_gate_rate <= 1.0:
            raise ValueError(f"kondo_gate_rate must be in [0, 1], got {config.kondo_gate_rate}")
        if config.kondo_gate_temperature <= 0:
            raise ValueError(f"kondo_gate_temperature must be > 0, got {config.kondo_gate_temperature}")

        self.device = device
        self.process_group = process_group
        self.history_size = config.kondo_gate_history_size
        self.warmup = config.kondo_gate_warmup
        self.rate = config.kondo_gate_rate
        self.temperature = config.kondo_gate_temperature
        self._buffer = torch.zeros(self.history_size, device=device)
        self._count = 0  # number of valid entries in the ring buffer (<= history_size)
        self._num_seen = 0  # total samples observed; drives warmup
        self._write_idx = 0
        self._generator = torch.Generator(device=device)
        self._generator.manual_seed(int(seed))

    def _reduced_chi(self, delight: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
        """Reduce (sum_delight, sum_tokens) across the process group and return sum/count.

        With Ulysses SP, each rank holds a sequence-slice of its sample, so per-rank
        slice-means differ across SP-mates. Reducing the numerator and denominator
        separately gives the correct token-weighted mean regardless of slice lengths,
        and the result is identical on every rank in the group.

        Returns 0.0 (instead of NaN) when the group holds no response tokens at all.
        """
        mask = response_mask.bool()
        values = delight.detach().float()
        # Select rather than multiply: 0 * inf and 0 * nan are NaN, so a non-finite
        # value at a *masked* position would otherwise leak into the sum.
        masked = torch.where(mask, values, values.new_zeros(()))
        packed = torch.stack([masked.sum(), mask.sum().float()])
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=self.process_group)
        # A zero token count would give 0/0 = NaN, which then sits in the history
        # buffer and turns every later quantile (and gate probability) into NaN.
        return packed[0] / packed[1].clamp(min=1.0)

    def _append(self, value: torch.Tensor) -> None:
        """Write ``value`` into the ring buffer, overwriting the oldest entry when full."""
        self._buffer[self._write_idx] = value
        self._write_idx = (self._write_idx + 1) % self.history_size
        self._count = min(self._count + 1, self.history_size)
        self._num_seen += 1

    def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGateDecision:
        """Computes token-weighted chi over the response, all-reduces across ranks, and gates.

        The current sample is added to the history before the threshold is computed.
        Returns identical values on every rank in the process group.
        """
        chi = self._reduced_chi(delight, response_mask)
        self._append(chi)
        # Warmup counts samples seen, not buffer fill: the fill is capped at
        # history_size, so comparing it to a larger warmup would disable the gate forever.
        if self._num_seen < self.warmup:
            return KondoGateDecision(True, 1.0, float("nan"))
        history = self._buffer[: self._count]
        lam = torch.quantile(history, 1.0 - self.rate)
        if self.rate >= 1.0:
            # rate == 1.0 is documented as "always backward". Without this branch
            # lambda is the history minimum and sigma((chi - min) / eta) lies in
            # [0.5, 1), so samples would still be dropped at random.
            return KondoGateDecision(True, 1.0, float(lam.item()))
        prob = torch.sigmoid((chi - lam) / self.temperature)
        gate = torch.bernoulli(prob, generator=self._generator)
        # One host transfer instead of three separate .item() synchronisations.
        gate_value, prob_value, lam_value = torch.stack([gate, prob, lam]).tolist()
        return KondoGateDecision(bool(gate_value), prob_value, lam_value)


def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]:
    """Aggregate per-sample gate decisions into scalar metrics.

    Always returns the same three keys. With no decisions every value is NaN.
    ``val/kondo_lambda`` averages only decisions that carry a threshold (warmup
    decisions report NaN), and is NaN when there are none.
    """
    nan = float("nan")
    if not stats:
        return {"val/kondo_gate_backward_frac": nan, "val/kondo_gate_prob_avg": nan, "val/kondo_lambda": nan}
    n = len(stats)
    lams = [s.lam for s in stats if not math.isnan(s.lam)]
    return {
        "val/kondo_gate_backward_frac": sum(int(s.should_backward) for s in stats) / n,
        "val/kondo_gate_prob_avg": sum(s.prob for s in stats) / n,
        "val/kondo_lambda": sum(lams) / len(lams) if lams else nan,
    }
- pass + if self.grpo_config.use_kondo_gate: + self._kondo_gate = grpo_utils.KondoGateState( + self.grpo_config, self.device, process_group=self.trainer.dp_process_group, seed=self.grpo_config.seed + ) def state_dict(self, *, optim: bool | None = None) -> dict[str, Any]: state = super().state_dict(optim=optim) @@ -413,6 +418,8 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: num_steps = 0 local_step = 0 + group_had_backward = False + kondo_gate_stats: list[grpo_utils.KondoGateDecision] = [] for epoch_idx in range(self.grpo_config.num_epochs): for sample_idx in range(num_samples): @@ -450,28 +457,44 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: old_logprob, vllm_logprobs, response_mask, self.grpo_config.truncated_importance_sampling_ratio_cap ) - pg_losses, pg_losses2, pg_loss, kl = grpo_utils.compute_grpo_loss( + loss_output = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages[:, 1:], ref_logprobs=ref_logprobs_BT[sample_idx] if ref_logprobs_BT is not None else None, config=self.grpo_config, tis_weights=tis_clamped, + response_mask=response_mask, ) batch_start = (sample_idx // accumulation_steps) * accumulation_steps loss_denominator = accumulation_token_counts[batch_start] - loss = masked_mean(pg_loss + self.grpo_config.beta * kl, response_mask, None, loss_denominator) + loss = masked_mean( + loss_output.pg_loss_max + self.grpo_config.beta * loss_output.kl, + response_mask, + None, + loss_denominator, + ) loss = loss * dp_world_size - loss.backward() + + if self._kondo_gate is not None: + decision = self._kondo_gate.decide(loss_output.delight, response_mask) + kondo_gate_stats.append(decision) + should_backward = decision.should_backward + else: + should_backward = True + + if should_backward: + loss.backward() + group_had_backward = True grpo_utils.populate_sample_loss_stats( loss_stats_B, sample_idx, - pg_losses, - pg_losses2, - pg_loss, + 
loss_output.pg_losses, + loss_output.pg_losses2, + loss_output.pg_loss_max, ratio, loss, response_mask, @@ -487,14 +510,16 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: local_step += 1 if local_step % accumulation_steps == 0: - if not dry_run: + if not dry_run and group_had_backward: self.optim_step() self.zero_grads() + group_had_backward = False if local_step % accumulation_steps != 0: - if not dry_run: + if not dry_run and group_had_backward: self.optim_step() self.zero_grads() + group_had_backward = False if not dry_run and num_steps > 0: local_metrics = grpo_utils.compute_metrics_from_loss_stats(loss_stats_B, token_counts) @@ -516,6 +541,9 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: ) self.record_metric("lr", float(lr), reduce_type=None) self.record_metric("_token_count", global_tokens, reduce_type=None) + if self._kondo_gate is not None: + for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items(): + self.record_metric(k, v, reduce_type=None) data_prep_metrics = batch.get("metrics") or {} for metric_key, metric_value in data_prep_metrics.items(): diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index eb0714af20..0c433ea2b7 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -1,3 +1,4 @@ +import math import unittest from unittest.mock import MagicMock @@ -160,6 +161,12 @@ def _make_grpo_config(**kwargs) -> grpo_utils.GRPOExperimentConfig: "kl_estimator": 2, "loss_fn": grpo_utils.GRPOLossType.dapo, "load_ref_policy": False, + "use_delight": False, + "use_kondo_gate": False, + "kondo_gate_rate": 1.0, + "kondo_gate_temperature": 1.0, + "kondo_gate_history_size": 32, + "kondo_gate_warmup": 8, } defaults.update(kwargs) config = MagicMock(spec=grpo_utils.GRPOExperimentConfig) @@ -177,14 +184,44 @@ def test_output_shapes(self, _name, loss_type): ratio = 
torch.exp(torch.randn(batch_size, seq_len)) advantages = torch.randn(batch_size, seq_len) - pg_losses, pg_losses2, pg_loss_max, kl = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - self.assertEqual(pg_losses.shape, (batch_size, seq_len)) - self.assertEqual(pg_losses2.shape, (batch_size, seq_len)) - self.assertEqual(pg_loss_max.shape, (batch_size, seq_len)) - self.assertEqual(kl.shape, (batch_size, seq_len)) + self.assertEqual(result.pg_losses.shape, (batch_size, seq_len)) + self.assertEqual(result.pg_losses2.shape, (batch_size, seq_len)) + self.assertEqual(result.pg_loss_max.shape, (batch_size, seq_len)) + self.assertEqual(result.kl.shape, (batch_size, seq_len)) + self.assertEqual(result.delight.shape, (batch_size, seq_len)) + torch.testing.assert_close(result.delight, -advantages * new_logprobs.detach()) + + def test_use_delight_applies_sample_level_gate(self): + config = _make_grpo_config(use_delight=True, loss_fn=grpo_utils.GRPOLossType.dapo) + batch_size, seq_len = 3, 5 + new_logprobs = torch.randn(batch_size, seq_len) + ratio = torch.ones(batch_size, seq_len) + advantages = torch.randn(batch_size, seq_len) + response_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 1, 0, 0, 0]], dtype=torch.bool) + + result = grpo_utils.compute_grpo_loss( + new_logprobs=new_logprobs, + ratio=ratio, + advantages=advantages, + ref_logprobs=None, + config=config, + response_mask=response_mask, + ) + + delight = -advantages * new_logprobs.detach() + mask_f = response_mask.float() + sample_chi = (delight * mask_f).sum(-1) / mask_f.sum(-1).clamp(min=1.0) + sample_gate = torch.sigmoid(sample_chi).unsqueeze(-1) + expected_pg = -(advantages * sample_gate) * ratio + torch.testing.assert_close(result.pg_losses, expected_pg) + # Gate is constant across the sequence for each sample. 
+ gate_col0 = result.pg_losses[:, 0] / (-advantages[:, 0] * ratio[:, 0]) + gate_col1 = result.pg_losses[:, 1] / (-advantages[:, 1] * ratio[:, 1]) + torch.testing.assert_close(gate_col0, gate_col1) def test_dapo_clipping(self): config = _make_grpo_config(clip_lower=0.2, clip_higher=0.2) @@ -192,12 +229,12 @@ def test_dapo_clipping(self): new_logprobs = torch.randn(1, 3) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) expected_clamped = torch.clamp(ratio, 0.8, 1.2) - torch.testing.assert_close(pg_losses2, -advantages * expected_clamped) + torch.testing.assert_close(result.pg_losses2, -advantages * expected_clamped) def test_cispo_uses_detached_ratio(self): config = _make_grpo_config(loss_fn=grpo_utils.GRPOLossType.cispo, clip_higher=0.2) @@ -205,11 +242,11 @@ def test_cispo_uses_detached_ratio(self): new_logprobs = torch.randn(1, 3, requires_grad=True) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - pg_loss_max.sum().backward() + result.pg_loss_max.sum().backward() self.assertIsNone(ratio.grad) self.assertIsNotNone(new_logprobs.grad) @@ -221,11 +258,11 @@ def test_with_ref_logprobs(self): advantages = torch.randn(batch_size, seq_len) ref_logprobs = torch.randn(batch_size, seq_len) - _, _, _, kl = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=ref_logprobs, config=config ) - self.assertFalse(torch.all(kl == 0)) + self.assertFalse(torch.all(result.kl == 0)) def test_without_ref_logprobs(self): config = _make_grpo_config() @@ -233,11 +270,11 @@ def test_without_ref_logprobs(self): ratio = 
torch.exp(torch.randn(2, 4)) advantages = torch.randn(2, 4) - _, _, _, kl = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - torch.testing.assert_close(kl, torch.zeros_like(kl)) + torch.testing.assert_close(result.kl, torch.zeros_like(result.kl)) def test_tis_weights(self): config = _make_grpo_config() @@ -246,7 +283,7 @@ def test_tis_weights(self): advantages = torch.randn(2, 4) tis_weights = torch.full((2, 4), 2.0) - pg_no_tis, pg2_no_tis, _, _ = grpo_utils.compute_grpo_loss( + no_tis = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -255,7 +292,7 @@ def test_tis_weights(self): tis_weights=None, ) - pg_tis, pg2_tis, _, _ = grpo_utils.compute_grpo_loss( + tis = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -264,8 +301,8 @@ def test_tis_weights(self): tis_weights=tis_weights, ) - torch.testing.assert_close(pg_tis, pg_no_tis * 2.0) - torch.testing.assert_close(pg2_tis, pg2_no_tis * 2.0) + torch.testing.assert_close(tis.pg_losses, no_tis.pg_losses * 2.0) + torch.testing.assert_close(tis.pg_losses2, no_tis.pg_losses2 * 2.0) def test_invalid_loss_fn(self): config = _make_grpo_config(loss_fn="invalid") @@ -279,5 +316,40 @@ def test_invalid_loss_fn(self): ) +class TestKondoGateState(unittest.TestCase): + def test_warmup_always_passes(self): + config = _make_grpo_config( + use_kondo_gate=True, kondo_gate_rate=0.1, kondo_gate_warmup=4, kondo_gate_history_size=8 + ) + gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) + for _ in range(3): + should_backward, prob, lam = gate.decide(torch.tensor(0.0), torch.tensor(1.0)) + self.assertTrue(should_backward) + self.assertEqual(prob, 1.0) + self.assertTrue(math.isnan(lam)) + + def test_gate_rate_matches_target_in_expectation(self): + rate = 0.3 + config = 
_make_grpo_config( + use_kondo_gate=True, + kondo_gate_rate=rate, + kondo_gate_warmup=64, + kondo_gate_history_size=512, + kondo_gate_temperature=0.01, # hard threshold limit + ) + gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=42) + # Draw from a stable distribution so lambda stabilizes. + torch.manual_seed(0) + n = 2000 + values = torch.randn(n) + passes = 0 + for v in values: + should_backward, _, _ = gate.decide(v, torch.tensor(1.0)) + passes += int(should_backward) + # After warmup (64), remaining samples should pass at roughly `rate`. + observed = passes / n + self.assertAlmostEqual(observed, rate, delta=0.1) + + if __name__ == "__main__": unittest.main() diff --git a/run_aime_eval.sh b/run_aime_eval.sh new file mode 100755 index 0000000000..93599fa634 --- /dev/null +++ b/run_aime_eval.sh @@ -0,0 +1,26 @@ +#! /bin/bash +# Usage: ./run_aime_eval.sh +# check we have at least 2 arguments +if [ "$#" -lt 2 ]; then + echo "Usage: $0 " + exit 1 +fi +MODEL_NAME=$1 +BEAKER_DATASET_ID=$2 +uv run python scripts/submit_eval_jobs.py \ + --model_name "$MODEL_NAME" \ + --location "$BEAKER_DATASET_ID" \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale \ + --beaker_image oe-eval-beaker/oe_eval_auto_finbarr \ + --is_tuned \ + --preemptible \ + --priority urgent \ + --workspace ai2/open-instruct-dev \ + --use_hf_tokenizer_template \ + --run_oe_eval_experiments \ + --oe_eval_tasks aime:zs_cot_r1::maj_at_32_2025 \ + --evaluate_on_weka \ + --run_id placeholder \ + --oe_eval_max_length 8192 \ + --process_output r1_style \ + --skip_oi_evals diff --git a/scripts/eval/oe-eval.sh b/scripts/eval/oe-eval.sh index d4f2fd33a5..acb137dade 100755 --- a/scripts/eval/oe-eval.sh +++ b/scripts/eval/oe-eval.sh @@ -288,7 +288,7 @@ GPU_COUNT_OTHER=$((NUM_GPUS * 2)) MODEL_TYPE_OTHER="" # Build model args JSON with optional process_output -MODEL_ARGS="{\"model_path\":\"${MODEL_LOCATION}\", 
\"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": \"true\"" +MODEL_ARGS="{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": true" if [[ -n "$PROCESS_OUTPUT" ]]; then MODEL_ARGS+=", \"process_output\": \"${PROCESS_OUTPUT}\"" fi @@ -319,7 +319,7 @@ for TASK in "${TASKS[@]}"; do # HF model: check if it's supported if [[ " ${SUPPORTED_MODELS[*]} " =~ " ${MODEL_LOCATION} " ]]; then # Supported HF model: remove model_path, no metadata needed - BASE_ARGS="{\"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": \"true\"" + BASE_ARGS="{\"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": true" if [[ -n "$PROCESS_OUTPUT" ]]; then BASE_ARGS+=", \"process_output\": \"${PROCESS_OUTPUT}\"" fi diff --git a/scripts/submit_eval_jobs_new.py b/scripts/submit_eval_jobs_new.py new file mode 100644 index 0000000000..3553cb8ac8 --- /dev/null +++ b/scripts/submit_eval_jobs_new.py @@ -0,0 +1,218 @@ +"""Submit evaluation jobs using allenai/olmo-eval-internal. + +Replicates the Beaker-dataset flow of `scripts/submit_eval_jobs.py` but uses +`olmo-eval run` (from olmo-eval-internal, baked into the Beaker image) as the +in-container command. The model is mounted at `/model` when `--location` refers +to a Beaker dataset. + +The in-container command is: + + olmo-eval run -m --harness default \ + -o provider.kind=vllm_server \ + -o provider.max_model_len= \ + -o provider.trust_remote_code=true \ + -t [-t ...] \ + --num-gpus \ + --output-dir /results + +This submits directly via `beaker experiment create`, avoiding gantry (and +therefore any git/ref requirements). The wrapper writes a v2 experiment spec +YAML to `configs/beaker_configs/auto_created/` and invokes the Beaker CLI. 
+ +Example: + uv run python scripts/submit_eval_jobs_new.py \ + --model_name qwen3_4b_base_dapo_20260422_083224 \ + --location 01KPTSPMHGEZVYCDNR0XBVJCGZ \ + --tasks aime_2025:pass_at_32 \ + --max_length 8192 \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --priority urgent \ + --preemptible \ + --workspace ai2/open-instruct-dev +""" + +import argparse +import os +import re +import shlex +import subprocess +from datetime import date + +import yaml + + +BEAKER_ID_RE = re.compile(r"^[0-9A-Z]{26}$") +DEFAULT_CLUSTERS = ("ai2/jupiter",) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_name", type=str, required=True, help="Human-readable run name.") + parser.add_argument( + "--location", + type=str, + required=True, + help=( + "Model location. Accepts: a bare Beaker dataset id (26 uppercase alphanumerics), " + "'beaker://', an HF repo id (e.g. allenai/OLMo-2-1124-7B-Instruct), " + "an absolute Weka/NFS path, or a gs:// URL." + ), + ) + parser.add_argument( + "--tasks", + type=str, + default="aime_2025:pass_at_32", + help="Comma-separated olmo-eval task specs. See `olmo-eval tasks`/`olmo-eval suites`.", + ) + parser.add_argument("--num_gpus", type=int, default=1) + parser.add_argument("--cluster", nargs="+", default=list(DEFAULT_CLUSTERS)) + parser.add_argument("--priority", type=str, default="normal") + parser.add_argument("--preemptible", action="store_true") + parser.add_argument("--workspace", type=str, default="ai2/tulu-3-results") + parser.add_argument("--budget", type=str, default="ai2/oe-adapt") + parser.add_argument( + "--beaker_image", + type=str, + default="ai2-tylerm/olmo-eval-cu1281-trc290-amd64", + help="Beaker image with olmo-eval installed.", + ) + parser.add_argument("--revision", type=str, default=None, help="HF revision (git sha/tag).") + parser.add_argument( + "--max_length", + type=int, + default=32768, + help="Provider max_model_len. 
Sampling max_tokens comes from the task definition.", + ) + parser.add_argument( + "--sampling_max_tokens", + type=int, + default=None, + help="Override per-task sampling max_tokens (applied via -o max_tokens=N after each -t).", + ) + parser.add_argument("--experiment_name", type=str, default=None) + parser.add_argument( + "--dry_run", action="store_true", help="Write the spec YAML and print the beaker command, but do not submit." + ) + return parser.parse_args() + + +def resolve_model_mount(location: str) -> tuple[str, str | None]: + """Resolve --location into (model_path_in_container, beaker_dataset_id_or_None).""" + if location.startswith("beaker://"): + return "/model", location[len("beaker://") :] + if BEAKER_ID_RE.match(location): + return "/model", location + return location, None + + +def build_inner_cmd(args: argparse.Namespace, model_path: str) -> list[str]: + cmd = [ + "olmo-eval", + "run", + "-m", + model_path, + "--harness", + "default", + "-o", + "provider.kind=vllm_server", + "-o", + f"provider.max_model_len={args.max_length}", + "-o", + "provider.trust_remote_code=true", + ] + if args.revision: + cmd += ["-o", f"provider.revision={args.revision}"] + for task in args.tasks.split(","): + task = task.strip() + if not task: + continue + cmd += ["-t", task] + if args.sampling_max_tokens is not None: + cmd += ["-o", f"max_tokens={args.sampling_max_tokens}"] + cmd += ["--num-gpus", str(args.num_gpus)] + cmd += ["--output-dir", "/results"] + return cmd + + +INSTALL_SCRIPT = ( + "set -euo pipefail && " + "git clone --depth=1 " + "https://x-access-token:${GITHUB_TOKEN}@github.com/allenai/olmo-eval-internal.git " + "/opt/olmo-eval-internal && " + "cd /opt/olmo-eval-internal && " + "uv pip install --cache-dir /weka/oe-eval-default/olmo-eval-pypi-cache -e '.[vllm]' && " + "uv pip install --cache-dir /weka/oe-eval-default/olmo-eval-pypi-cache " + "--upgrade 'vllm[runai]>=0.19.0' 'transformers>=5.4.0' && " + "cd /workspace" +) + + +def build_spec(args: 
argparse.Namespace, inner_cmd: list[str], dataset_id: str | None, experiment_name: str) -> dict: + datasets: list[dict] = [ + {"mountPath": "/weka/oe-adapt-default", "source": {"weka": "oe-adapt-default"}}, + {"mountPath": "/weka/oe-training-default", "source": {"weka": "oe-training-default"}}, + {"mountPath": "/weka/oe-eval-default", "source": {"weka": "oe-eval-default"}}, + ] + if dataset_id: + datasets.append({"mountPath": "/model", "source": {"beaker": dataset_id}}) + + full_command = f"{INSTALL_SCRIPT} && {shlex.join(inner_cmd)}" + + return { + "version": "v2", + "description": experiment_name, + "budget": args.budget, + "retry": {"allowedTaskRetries": 2}, + "tasks": [ + { + "name": experiment_name, + "image": {"beaker": args.beaker_image}, + "command": ["/bin/bash", "-c"], + "arguments": [full_command], + "envVars": [ + {"name": "HF_TOKEN", "secret": "HF_TOKEN"}, + {"name": "OPENAI_API_KEY", "secret": "openai_api_key"}, + {"name": "GITHUB_TOKEN", "secret": "GITHUB_TOKEN"}, + {"name": "VLLM_ALLOW_LONG_MAX_MODEL_LEN", "value": "1"}, + ], + "datasets": datasets, + "result": {"path": "/results"}, + "resources": {"gpuCount": args.num_gpus}, + "constraints": {"cluster": list(args.cluster)}, + "context": {"priority": args.priority, "preemptible": args.preemptible}, + } + ], + } + + +def main() -> None: + args = parse_args() + + if len(args.workspace.split("/")) != 2 or not all(args.workspace.split("/")): + raise ValueError(f"--workspace must be '/'. 
Received: '{args.workspace}'") + + model_path, dataset_id = resolve_model_mount(args.location) + inner_cmd = build_inner_cmd(args, model_path) + + today = date.today().strftime("%m%d%Y") + experiment_name = (args.experiment_name or f"olmo_eval_{args.model_name}_{today}")[:128] + spec = build_spec(args, inner_cmd, dataset_id, experiment_name) + + out_dir = "configs/beaker_configs/auto_created" + os.makedirs(out_dir, exist_ok=True) + spec_path = os.path.join(out_dir, f"{experiment_name}.yaml") + with open(spec_path, "w") as f: + yaml.dump(spec, f, default_flow_style=False, sort_keys=False) + + print("Inner command:", shlex.join(inner_cmd)) + print("Spec written to:", spec_path) + + beaker_cmd = ["beaker", "experiment", "create", spec_path, "--workspace", args.workspace] + print("Running:", shlex.join(beaker_cmd)) + if args.dry_run: + return + subprocess.run(beaker_cmd, check=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh index 72e615f4ee..e55762b08e 100755 --- a/scripts/train/debug/large_test_script.sh +++ b/scripts/train/debug/large_test_script.sh @@ -67,4 +67,8 @@ uv run python mason.py \ --checkpoint_state_dir /tmp/checkpoint_test \ --active_sampling \ --async_steps 4 \ + --use_delight true \ + --use_kondo_gate true \ + --kondo_gate_rate 0.5 \ + --kondo_gate_warmup 16 \ --push_to_hub False diff --git a/scripts/train/qwen/qwen3_4b_dapo_math.sh b/scripts/train/qwen/qwen3_4b_dapo_math.sh index a82a474ca3..e77e3112c2 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math.sh @@ -4,16 +4,18 @@ EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo}" RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" NUM_GPUS="${NUM_GPUS:-8}" -BEAKER_IMAGE="${BEAKER_IMAGE:-nathanl/open_instruct_auto}" +BEAKER_IMAGE="${1:-${BEAKER_IMAGE:-nathanl/open_instruct_auto}}" +shift || true -CLUSTER="${CLUSTER:-ai2/jupiter ai2/ceres}" -PRIORITY="${PRIORITY:-high}" 
+CLUSTER="${CLUSTER:-ai2/jupiter}" +PRIORITY="${PRIORITY:-urgent}" +WORKSPACE="${WORKSPACE:-ai2/open-instruct-dev}" uv run mason.py \ --task_name ${EXP_NAME} \ --description "${RUN_NAME}" \ --cluster ${CLUSTER} \ - --workspace ai2/oe-adapt-code \ + --workspace ${WORKSPACE} \ --priority ${PRIORITY} \ --pure_docker_mode \ --no_auto_dataset_cache \ @@ -71,4 +73,5 @@ uv run open_instruct/grpo_fast.py \ --chat_template qwen_instruct_user_boxed_math \ --load_ref_policy False \ --keep_last_n_checkpoints -1 \ + --use_delight true \ --push_to_hub False "$@"