From 185f3f6428a68a3a7ea1c166676bf47f039dfa88 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 16:31:42 -0600 Subject: [PATCH 01/17] Add Delightful Policy Gradient gate (use_delight) to GRPO loss and enable it in large_test_script. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_utils.py | 7 +++++++ scripts/train/debug/large_test_script.sh | 1 + 2 files changed, 8 insertions(+) diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 8a5811acc9..998c2cb9c5 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -118,6 +118,9 @@ class GRPOExperimentConfig( """Whether to load and use a reference policy for KL penalty calculation.""" loss_fn: GRPOLossType = GRPOLossType.dapo """Whether to use DAPO or CISPO loss function.""" + use_delight: bool = False + """Whether to gate per-token policy-gradient terms with the Delightful Policy Gradient sigmoid + of delight = advantage * surprisal (https://arxiv.org/abs/2603.14608).""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" use_vllm_logprobs: bool = False @@ -342,6 +345,10 @@ def compute_grpo_loss( config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if config.use_delight: + # Delightful Policy Gradient gate; temperature eta is fixed to 1 per the paper. + advantages = advantages * torch.sigmoid(-advantages * new_logprobs.detach()) + if config.loss_fn == GRPOLossType.dapo: pg_losses = -advantages * ratio pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - config.clip_lower, 1.0 + config.clip_higher) diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh index 72e615f4ee..d5729effbd 100755 --- a/scripts/train/debug/large_test_script.sh +++ b/scripts/train/debug/large_test_script.sh @@ -67,4 +67,5 @@ uv run python mason.py \ --checkpoint_state_dir /tmp/checkpoint_test \ --active_sampling \ --async_steps 4 \ + --use_delight true \ --push_to_hub False From 853d736f8362804eb00e70e709c5813d298bb578 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 16:48:33 -0600 Subject: [PATCH 02/17] Add CHANGELOG entry for --use_delight. Co-Authored-By: Claude Opus 4.7 --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17edc7af8b..bb9b5f09be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Add `--use_delight` flag to GRPO loss that gates per-token policy-gradient terms with the Delightful Policy Gradient sigmoid (https://github.com/allenai/open-instruct/pull/1628). - Pass `attention_mask=None` in GRPO `forward_for_logprobs` calls — HF constructs the correct 3D intra-document mask from `position_ids` internally (https://github.com/allenai/open-instruct/pull/1617). - Migrate GRPO trainer→vLLM weight sync to vLLM 0.16.0's native weight transfer API (`NCCLWeightTransferEngine`), replacing custom NCCL process-group and broadcast code (https://github.com/allenai/open-instruct/pull/1515). - Extend pre-commit hook to also ban `nonlocal` keyword (https://github.com/allenai/open-instruct/pull/1613). From 77588c9b2d53f51ab6d67bffc2990a2186d5cf28 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 16:51:12 -0600 Subject: [PATCH 03/17] Add Kondo gate (per-sample Bernoulli backward-skip on delight) to GRPO. Co-Authored-By: Claude Opus 4.7 --- CHANGELOG.md | 1 + open_instruct/grpo_fast.py | 48 ++++++++-- open_instruct/grpo_utils.py | 90 ++++++++++++++++++- open_instruct/olmo_core_train_modules.py | 41 ++++++++- open_instruct/test_olmo_core_train_modules.py | 66 ++++++++++++-- scripts/train/debug/large_test_script.sh | 2 + 6 files changed, 227 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb9b5f09be..c2f9c4a491 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Add `--use_kondo_gate` flag to GRPO that skips backward passes on low-delight samples via the Kondo gate (https://arxiv.org/abs/2603.20526), with `--kondo_gate_rate`, `--kondo_gate_temperature`, `--kondo_gate_history_size`, and `--kondo_gate_warmup` controls. - Add `--use_delight` flag to GRPO loss that gates per-token policy-gradient terms with the Delightful Policy Gradient sigmoid (https://github.com/allenai/open-instruct/pull/1628). - Pass `attention_mask=None` in GRPO `forward_for_logprobs` calls — HF constructs the correct 3D intra-document mask from `position_ids` internally (https://github.com/allenai/open-instruct/pull/1617). - Migrate GRPO trainer→vLLM weight sync to vLLM 0.16.0's native weight transfer API (`NCCLWeightTransferEngine`), replacing custom NCCL process-group and broadcast code (https://github.com/allenai/open-instruct/pull/1515). diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 468da7c174..7452a1b588 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -471,6 +471,11 @@ def load(self, path: str, map_location=None): alpha=args.alpha, ) self.local_metrics = utils.MetricsTracker(max_metrics=512, device=self.device) + self._kondo_gate = ( + grpo_utils.KondoGateState(args, self.device, process_group=None, seed=args.seed) + if args.use_kondo_gate + else None + ) if self.mpu is not None: self.splitter = UlyssesSPSplitter( @@ -654,6 +659,8 @@ def step(self): token_counts_per_sample = torch.stack([mask[:, 1:].sum().float() for mask in data_BT.response_masks]) device = token_counts_per_sample.device grad_norms: list[float] = [] # May include nan/inf values reported by DeepSpeed. + group_had_backward = False + kondo_gate_stats: list[tuple[int, float, float]] = [] # (should_backward, gate_prob, lambda) per sample # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): loss_stats_B = grpo_utils.create_loss_stats(num_samples, device, record_entropy=self.args.record_entropy) @@ -729,7 +736,7 @@ def step(self): self.args.truncated_importance_sampling_ratio_cap, ) - pg_losses_BT, pg_losses2_BT, pg_loss_max_BT, kl_BT = grpo_utils.compute_grpo_loss( + pg_losses_BT, pg_losses2_BT, pg_loss_max_BT, kl_BT, delight_BT = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs_BT, ratio=ratio_BT, advantages=data_BT.advantages[i][:, 1:], @@ -746,15 +753,32 @@ def step(self): # up, adjusting for the sequence parallel size (adjust by dp world size). loss *= self.args.world_size // self.args.sequence_parallel_size + should_backward = True + gate_prob = 1.0 + gate_lambda = float("-inf") + if self._kondo_gate is not None: + sample_delight = masked_mean(delight_BT, response_mask_BT, None, None) + should_backward, gate_prob, gate_lambda = self._kondo_gate.decide(sample_delight) + kondo_gate_stats.append((int(should_backward), gate_prob, gate_lambda)) + # Clear CUDA cache before backward pass to free memory for reduce_scatter operations torch.cuda.empty_cache() is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 - # Tell deepspeed whether this backward is the last in the accumulation group. - self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) - self.model.backward(loss) + if should_backward: + # Tell deepspeed whether this backward is the last in the accumulation group. + self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) + self.model.backward(loss) + group_had_backward = True + elif is_accumulation_boundary and group_had_backward: + # The last sample in the group was gated but earlier ungated samples left + # un-reduce-scattered grads. Trigger the reduce-scatter with a zero-contribution backward. + self.model.set_gradient_accumulation_boundary(True) + self.model.backward(loss * 0.0) if is_accumulation_boundary: - self.model.step() - grad_norms.append(float(self.model.get_global_grad_norm())) + if group_had_backward: + self.model.step() + grad_norms.append(float(self.model.get_global_grad_norm())) + group_had_backward = False local_step += 1 grpo_utils.populate_sample_loss_stats( loss_stats_B, @@ -776,7 +800,17 @@ def step(self): batch_metrics = batch_data["metrics"] with torch.no_grad(): self._compute_loss_metrics(loss_stats_B, token_counts_per_sample) - self.local_metrics["optim/grad_norm"] = sum(grad_norms) / len(grad_norms) + self.local_metrics["optim/grad_norm"] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0 + if self._kondo_gate is not None and kondo_gate_stats: + self.local_metrics["val/kondo_gate_backward_frac"] = sum(s[0] for s in kondo_gate_stats) / len( + kondo_gate_stats + ) + self.local_metrics["val/kondo_gate_prob_avg"] = sum(s[1] for s in kondo_gate_stats) / len( + kondo_gate_stats + ) + finite_lams = [s[2] for s in kondo_gate_stats if math.isfinite(s[2])] + if finite_lams: + self.local_metrics["val/kondo_lambda"] = sum(finite_lams) / len(finite_lams) array_metrics = {} for key, value in batch_metrics.items(): if value is None: diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 998c2cb9c5..fe3617f9cb 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -121,6 +121,18 @@ class GRPOExperimentConfig( use_delight: bool = False """Whether to gate per-token policy-gradient terms with the Delightful Policy Gradient sigmoid of delight = advantage * surprisal (https://arxiv.org/abs/2603.14608).""" + use_kondo_gate: bool = False + """Whether to enable the Kondo gate (https://arxiv.org/abs/2603.20526): per-sample Bernoulli gate + on whether to run the backward pass, driven by sample-level delight against an adaptive threshold.""" + kondo_gate_rate: float = 1.0 + """Target fraction rho of samples that receive a backward pass. 1.0 is a no-op even when the gate + is enabled (always backward). Smaller values keep only the highest-delight samples.""" + kondo_gate_temperature: float = 1.0 + """Temperature eta in the Kondo gate Bernoulli probability sigma((chi - lambda) / eta).""" + kondo_gate_history_size: int = 1024 + """Size of the ring buffer of past sample delights used to compute lambda = quantile_{1-rho}.""" + kondo_gate_warmup: int = 128 + """Never gate until the history contains at least this many sample delights.""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" use_vllm_logprobs: bool = False @@ -236,6 +248,18 @@ def __post_init__(self): raise ValueError(f"`gs_bucket_path` must start with 'gs://', got: {self.gs_bucket_path}") if self.sequence_parallel_size > 1 and self.deepspeed_stage != 3: raise ValueError("`sequence_parallel_size` > 1 requires `deepspeed_stage` to be 3!") + if self.use_kondo_gate: + if not (0.0 < self.kondo_gate_rate <= 1.0): + raise ValueError(f"`kondo_gate_rate` must be in (0, 1], got {self.kondo_gate_rate}") + if self.kondo_gate_temperature <= 0.0: + raise ValueError(f"`kondo_gate_temperature` must be > 0, got {self.kondo_gate_temperature}") + if self.kondo_gate_warmup <= 0: + raise ValueError(f"`kondo_gate_warmup` must be > 0, got {self.kondo_gate_warmup}") + if self.kondo_gate_history_size < self.kondo_gate_warmup: + raise ValueError( + f"`kondo_gate_history_size` ({self.kondo_gate_history_size}) must be >= " + f"`kondo_gate_warmup` ({self.kondo_gate_warmup})." + ) total_learner_gpus = sum(self.num_learners_per_node) if self.fsdp_shard_degree is not None and self.fsdp_num_replicas is not None: @@ -344,10 +368,11 @@ def compute_grpo_loss( ref_logprobs: torch.Tensor | None, config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + delight = -advantages * new_logprobs.detach() if config.use_delight: # Delightful Policy Gradient gate; temperature eta is fixed to 1 per the paper. - advantages = advantages * torch.sigmoid(-advantages * new_logprobs.detach()) + advantages = advantages * torch.sigmoid(delight) if config.loss_fn == GRPOLossType.dapo: pg_losses = -advantages * ratio @@ -376,7 +401,66 @@ def compute_grpo_loss( else: kl = torch.zeros_like(pg_loss_max) - return pg_losses, pg_losses2, pg_loss_max, kl + return pg_losses, pg_losses2, pg_loss_max, kl, delight + + +class KondoGateState: + """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526). + + Maintains a ring buffer of past sample-level delight values, computes an adaptive + threshold lambda = quantile_{1-rho}(history), and draws a Bernoulli gate with + probability sigma((chi - lambda) / eta). All-reduces the sample delight across DP + ranks and uses an identically-seeded generator so every rank produces the same + gate decision -- required to keep DeepSpeed / FSDP collectives in sync. + """ + + def __init__( + self, + config: GRPOExperimentConfig, + device: torch.device, + process_group: dist.ProcessGroup | None = None, + seed: int = 0, + ) -> None: + self.config = config + self.device = device + self.process_group = process_group + self.history_size = config.kondo_gate_history_size + self.warmup = config.kondo_gate_warmup + self.rate = config.kondo_gate_rate + self.temperature = config.kondo_gate_temperature + self._buffer = torch.zeros(self.history_size, device=device) + self._count = 0 + self._write_idx = 0 + self._generator = torch.Generator(device=device) + self._generator.manual_seed(int(seed)) + + def _reduced_delight(self, sample_delight: torch.Tensor) -> torch.Tensor: + if not dist.is_available() or not dist.is_initialized(): + return sample_delight + value = sample_delight.detach().clone().to(self.device) + dist.all_reduce(value, op=dist.ReduceOp.SUM, group=self.process_group) + world_size = dist.get_world_size(group=self.process_group) + return value / world_size + + def _append(self, value: torch.Tensor) -> None: + self._buffer[self._write_idx] = value + self._write_idx = (self._write_idx + 1) % self.history_size + self._count = min(self._count + 1, self.history_size) + + def decide(self, sample_delight: torch.Tensor) -> tuple[bool, float, float]: + """All-reduces the scalar delight, appends to history, returns (should_backward, gate_prob, lambda). + + Identical return values on every rank in the process group. + """ + chi = self._reduced_delight(sample_delight).reshape(()) + self._append(chi) + if self._count < self.warmup or self.rate >= 1.0: + return True, 1.0, float("-inf") + history = self._buffer[: self._count] + lam = torch.quantile(history, 1.0 - self.rate) + prob = torch.sigmoid((chi - lam) / self.temperature) + gate = torch.bernoulli(prob, generator=self._generator) + return bool(gate.item()), float(prob.item()), float(lam.item()) def forward_for_logprobs( diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 409eb9bb76..f297816fa3 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -321,6 +321,8 @@ def __init__( if ref_policy is not None: self.ref_policy = ref_policy.to(device=self.device).eval().requires_grad_(False) + self._kondo_gate: grpo_utils.KondoGateState | None = None + def pre_train(self): # GRPO batches are prompt-grouped and do their own accumulation/token normalization # inside train_batch(), so the base TransformerTrainModule global-batch validation @@ -413,6 +415,13 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: num_steps = 0 local_step = 0 + group_had_backward = False + kondo_gate_stats: list[tuple[int, float, float]] = [] + + if self.grpo_config.use_kondo_gate and self._kondo_gate is None: + self._kondo_gate = grpo_utils.KondoGateState( + self.grpo_config, self.device, process_group=self.trainer.dp_process_group, seed=self.grpo_config.seed + ) for epoch_idx in range(self.grpo_config.num_epochs): for sample_idx in range(num_samples): @@ -450,7 +459,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: old_logprob, vllm_logprobs, response_mask, self.grpo_config.truncated_importance_sampling_ratio_cap ) - pg_losses, pg_losses2, pg_loss, kl = grpo_utils.compute_grpo_loss( + pg_losses, pg_losses2, pg_loss, kl, delight = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages[:, 1:], @@ -464,7 +473,18 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: loss = masked_mean(pg_loss + self.grpo_config.beta * kl, response_mask, None, loss_denominator) loss = loss * dp_world_size - loss.backward() + + should_backward = True + gate_prob = 1.0 + gate_lambda = float("-inf") + if self._kondo_gate is not None: + sample_delight = masked_mean(delight, response_mask, None, None) + should_backward, gate_prob, gate_lambda = self._kondo_gate.decide(sample_delight) + kondo_gate_stats.append((int(should_backward), gate_prob, gate_lambda)) + + if should_backward: + loss.backward() + group_had_backward = True grpo_utils.populate_sample_loss_stats( loss_stats_B, @@ -487,14 +507,16 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: local_step += 1 if local_step % accumulation_steps == 0: - if not dry_run: + if not dry_run and group_had_backward: self.optim_step() self.zero_grads() + group_had_backward = False if local_step % accumulation_steps != 0: - if not dry_run: + if not dry_run and group_had_backward: self.optim_step() self.zero_grads() + group_had_backward = False if not dry_run and num_steps > 0: local_metrics = grpo_utils.compute_metrics_from_loss_stats(loss_stats_B, token_counts) @@ -516,3 +538,14 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: ) self.record_metric("lr", float(lr), reduce_type=None) self.record_metric("_token_count", global_tokens, reduce_type=None) + if self._kondo_gate is not None and kondo_gate_stats: + n = len(kondo_gate_stats) + self.record_metric( + "val/kondo_gate_backward_frac", sum(s[0] for s in kondo_gate_stats) / n, reduce_type=None + ) + self.record_metric( + "val/kondo_gate_prob_avg", sum(s[1] for s in kondo_gate_stats) / n, reduce_type=None + ) + finite_lams = [s[2] for s in kondo_gate_stats if math.isfinite(s[2])] + if finite_lams: + self.record_metric("val/kondo_lambda", sum(finite_lams) / len(finite_lams), reduce_type=None) diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index eb0714af20..347d026545 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -160,6 +160,12 @@ def _make_grpo_config(**kwargs) -> grpo_utils.GRPOExperimentConfig: "kl_estimator": 2, "loss_fn": grpo_utils.GRPOLossType.dapo, "load_ref_policy": False, + "use_delight": False, + "use_kondo_gate": False, + "kondo_gate_rate": 1.0, + "kondo_gate_temperature": 1.0, + "kondo_gate_history_size": 32, + "kondo_gate_warmup": 8, } defaults.update(kwargs) config = MagicMock(spec=grpo_utils.GRPOExperimentConfig) @@ -177,7 +183,7 @@ def test_output_shapes(self, _name, loss_type): ratio = torch.exp(torch.randn(batch_size, seq_len)) advantages = torch.randn(batch_size, seq_len) - pg_losses, pg_losses2, pg_loss_max, kl = grpo_utils.compute_grpo_loss( + pg_losses, pg_losses2, pg_loss_max, kl, delight = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) @@ -185,6 +191,8 @@ def test_output_shapes(self, _name, loss_type): self.assertEqual(pg_losses2.shape, (batch_size, seq_len)) self.assertEqual(pg_loss_max.shape, (batch_size, seq_len)) self.assertEqual(kl.shape, (batch_size, seq_len)) + self.assertEqual(delight.shape, (batch_size, seq_len)) + torch.testing.assert_close(delight, -advantages * new_logprobs.detach()) def test_dapo_clipping(self): config = _make_grpo_config(clip_lower=0.2, clip_higher=0.2) @@ -192,7 +200,7 @@ def test_dapo_clipping(self): new_logprobs = torch.randn(1, 3) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _ = grpo_utils.compute_grpo_loss( + pg_losses, pg_losses2, pg_loss_max, _, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) @@ -205,7 +213,7 @@ def test_cispo_uses_detached_ratio(self): new_logprobs = torch.randn(1, 3, requires_grad=True) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _ = grpo_utils.compute_grpo_loss( + pg_losses, pg_losses2, pg_loss_max, _, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) @@ -221,7 +229,7 @@ def test_with_ref_logprobs(self): advantages = torch.randn(batch_size, seq_len) ref_logprobs = torch.randn(batch_size, seq_len) - _, _, _, kl = grpo_utils.compute_grpo_loss( + _, _, _, kl, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=ref_logprobs, config=config ) @@ -233,7 +241,7 @@ def test_without_ref_logprobs(self): ratio = torch.exp(torch.randn(2, 4)) advantages = torch.randn(2, 4) - _, _, _, kl = grpo_utils.compute_grpo_loss( + _, _, _, kl, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) @@ -246,7 +254,7 @@ def test_tis_weights(self): advantages = torch.randn(2, 4) tis_weights = torch.full((2, 4), 2.0) - pg_no_tis, pg2_no_tis, _, _ = grpo_utils.compute_grpo_loss( + pg_no_tis, pg2_no_tis, _, _, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -255,7 +263,7 @@ def test_tis_weights(self): tis_weights=None, ) - pg_tis, pg2_tis, _, _ = grpo_utils.compute_grpo_loss( + pg_tis, pg2_tis, _, _, _ = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -279,5 +287,49 @@ def test_invalid_loss_fn(self): ) +class TestKondoGateState(unittest.TestCase): + def test_warmup_always_passes(self): + config = _make_grpo_config( + use_kondo_gate=True, kondo_gate_rate=0.1, kondo_gate_warmup=4, kondo_gate_history_size=8 + ) + gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) + for _ in range(3): + should_backward, prob, lam = gate.decide(torch.tensor(0.0)) + self.assertTrue(should_backward) + self.assertEqual(prob, 1.0) + self.assertEqual(lam, float("-inf")) + + def test_rate_one_always_passes(self): + config = _make_grpo_config( + use_kondo_gate=True, kondo_gate_rate=1.0, kondo_gate_warmup=2, kondo_gate_history_size=8 + ) + gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) + for _ in range(16): + should_backward, _, _ = gate.decide(torch.tensor(torch.randn(1).item())) + self.assertTrue(should_backward) + + def test_gate_rate_matches_target_in_expectation(self): + rate = 0.3 + config = _make_grpo_config( + use_kondo_gate=True, + kondo_gate_rate=rate, + kondo_gate_warmup=64, + kondo_gate_history_size=512, + kondo_gate_temperature=0.01, # hard threshold limit + ) + gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=42) + # Draw from a stable distribution so lambda stabilizes. + torch.manual_seed(0) + n = 2000 + values = torch.randn(n) + passes = 0 + for v in values: + should_backward, _, _ = gate.decide(v) + passes += int(should_backward) + # After warmup (64), remaining samples should pass at roughly `rate`. + observed = passes / n + self.assertAlmostEqual(observed, rate, delta=0.1) + + if __name__ == "__main__": unittest.main() diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh index d5729effbd..aaef67783e 100755 --- a/scripts/train/debug/large_test_script.sh +++ b/scripts/train/debug/large_test_script.sh @@ -68,4 +68,6 @@ uv run python mason.py \ --active_sampling \ --async_steps 4 \ --use_delight true \ + --use_kondo_gate true \ + --kondo_gate_rate 0.5 \ --push_to_hub False From 4d5321cb6b1e0174386da8919bbde5c35be3b65f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 19:55:46 -0600 Subject: [PATCH 04/17] Simplify Kondo gate: NamedTuple decision, shared metrics helper, token-weighted chi, cached quantile. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 38 ++++------ open_instruct/grpo_utils.py | 73 ++++++++++++++----- open_instruct/olmo_core_train_modules.py | 29 +++----- open_instruct/test_olmo_core_train_modules.py | 6 +- 4 files changed, 85 insertions(+), 61 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7452a1b588..da72c6e390 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -660,7 +660,7 @@ def step(self): device = token_counts_per_sample.device grad_norms: list[float] = [] # May include nan/inf values reported by DeepSpeed. group_had_backward = False - kondo_gate_stats: list[tuple[int, float, float]] = [] # (should_backward, gate_prob, lambda) per sample + kondo_gate_stats: list[grpo_utils.KondoGateDecision] = [] # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): loss_stats_B = grpo_utils.create_loss_stats(num_samples, device, record_entropy=self.args.record_entropy) @@ -753,25 +753,24 @@ def step(self): # up, adjusting for the sequence parallel size (adjust by dp world size). loss *= self.args.world_size // self.args.sequence_parallel_size - should_backward = True - gate_prob = 1.0 - gate_lambda = float("-inf") + decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) if self._kondo_gate is not None: - sample_delight = masked_mean(delight_BT, response_mask_BT, None, None) - should_backward, gate_prob, gate_lambda = self._kondo_gate.decide(sample_delight) - kondo_gate_stats.append((int(should_backward), gate_prob, gate_lambda)) + delight_sum = (delight_BT * response_mask_BT).sum() + token_count = response_mask_BT.sum().float() + decision = self._kondo_gate.decide(delight_sum, token_count) + kondo_gate_stats.append(decision) - # Clear CUDA cache before backward pass to free memory for reduce_scatter operations - torch.cuda.empty_cache() is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 - if should_backward: - # Tell deepspeed whether this backward is the last in the accumulation group. + if decision.should_backward: + # Clear CUDA cache before backward pass to free memory for reduce_scatter operations + torch.cuda.empty_cache() self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) self.model.backward(loss) group_had_backward = True elif is_accumulation_boundary and group_had_backward: # The last sample in the group was gated but earlier ungated samples left # un-reduce-scattered grads. Trigger the reduce-scatter with a zero-contribution backward. + torch.cuda.empty_cache() self.model.set_gradient_accumulation_boundary(True) self.model.backward(loss * 0.0) if is_accumulation_boundary: @@ -800,17 +799,12 @@ def step(self): batch_metrics = batch_data["metrics"] with torch.no_grad(): self._compute_loss_metrics(loss_stats_B, token_counts_per_sample) - self.local_metrics["optim/grad_norm"] = sum(grad_norms) / len(grad_norms) if grad_norms else 0.0 - if self._kondo_gate is not None and kondo_gate_stats: - self.local_metrics["val/kondo_gate_backward_frac"] = sum(s[0] for s in kondo_gate_stats) / len( - kondo_gate_stats - ) - self.local_metrics["val/kondo_gate_prob_avg"] = sum(s[1] for s in kondo_gate_stats) / len( - kondo_gate_stats - ) - finite_lams = [s[2] for s in kondo_gate_stats if math.isfinite(s[2])] - if finite_lams: - self.local_metrics["val/kondo_lambda"] = sum(finite_lams) / len(finite_lams) + self.local_metrics["optim/grad_norm"] = ( + sum(grad_norms) / len(grad_norms) if grad_norms else float("nan") + ) + if self._kondo_gate is not None: + for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items(): + self.local_metrics[k] = v array_metrics = {} for key, value in batch_metrics.items(): if value is None: diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index fe3617f9cb..3c60dc82b9 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -2,7 +2,7 @@ import math import os from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, NamedTuple import numpy as np import torch @@ -404,6 +404,17 @@ def compute_grpo_loss( return pg_losses, pg_losses2, pg_loss_max, kl, delight +class KondoGateDecision(NamedTuple): + should_backward: bool + prob: float + lam: float + + +# Refresh the quantile lambda every N decide() calls; stale lambda is fine because the +# delight distribution shifts slowly relative to ring-buffer writes. +_KONDO_QUANTILE_REFRESH = 32 + + class KondoGateState: """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526). @@ -421,7 +432,6 @@ def __init__( process_group: dist.ProcessGroup | None = None, seed: int = 0, ) -> None: - self.config = config self.device = device self.process_group = process_group self.history_size = config.kondo_gate_history_size @@ -433,34 +443,63 @@ def __init__( self._write_idx = 0 self._generator = torch.Generator(device=device) self._generator.manual_seed(int(seed)) + self._cached_lam: torch.Tensor | None = None + self._calls_since_refresh = 0 + + def _reduced_chi(self, delight_sum: torch.Tensor, token_count: torch.Tensor) -> torch.Tensor: + """Reduce (sum_delight, sum_tokens) across the process group and return sum/count. - def _reduced_delight(self, sample_delight: torch.Tensor) -> torch.Tensor: - if not dist.is_available() or not dist.is_initialized(): - return sample_delight - value = sample_delight.detach().clone().to(self.device) - dist.all_reduce(value, op=dist.ReduceOp.SUM, group=self.process_group) - world_size = dist.get_world_size(group=self.process_group) - return value / world_size + With Ulysses SP, each rank holds a sequence-slice of its sample, so per-rank + slice-means differ across SP-mates. Reducing the numerator and denominator + separately gives the correct token-weighted mean regardless of slice lengths, + and the result is identical on every rank in the group (required to keep + DeepSpeed / FSDP collectives in sync). + """ + packed = torch.stack([delight_sum.detach(), token_count.detach()]).to(self.device) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=self.process_group) + return packed[0] / packed[1].clamp_min(1.0) def _append(self, value: torch.Tensor) -> None: self._buffer[self._write_idx] = value self._write_idx = (self._write_idx + 1) % self.history_size self._count = min(self._count + 1, self.history_size) - def decide(self, sample_delight: torch.Tensor) -> tuple[bool, float, float]: - """All-reduces the scalar delight, appends to history, returns (should_backward, gate_prob, lambda). + def _quantile(self) -> torch.Tensor: + if self._cached_lam is None or self._calls_since_refresh >= _KONDO_QUANTILE_REFRESH: + self._cached_lam = torch.quantile(self._buffer[: self._count], 1.0 - self.rate) + self._calls_since_refresh = 0 + self._calls_since_refresh += 1 + return self._cached_lam + + def decide(self, delight_sum: torch.Tensor, token_count: torch.Tensor) -> KondoGateDecision: + """All-reduces (delight_sum, token_count), computes chi = sum/count, and gates. - Identical return values on every rank in the process group. + Returns identical values on every rank in the process group. """ - chi = self._reduced_delight(sample_delight).reshape(()) + chi = self._reduced_chi(delight_sum, token_count) self._append(chi) if self._count < self.warmup or self.rate >= 1.0: - return True, 1.0, float("-inf") - history = self._buffer[: self._count] - lam = torch.quantile(history, 1.0 - self.rate) + return KondoGateDecision(True, 1.0, float("-inf")) + lam = self._quantile() prob = torch.sigmoid((chi - lam) / self.temperature) gate = torch.bernoulli(prob, generator=self._generator) - return bool(gate.item()), float(prob.item()), float(lam.item()) + return KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) + + +def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]: + """Aggregate per-sample gate decisions into scalar metrics.""" + if not stats: + return {} + n = len(stats) + out = { + "val/kondo_gate_backward_frac": sum(int(s.should_backward) for s in stats) / n, + "val/kondo_gate_prob_avg": sum(s.prob for s in stats) / n, + } + finite_lams = [s.lam for s in stats if math.isfinite(s.lam)] + if finite_lams: + out["val/kondo_lambda"] = sum(finite_lams) / len(finite_lams) + return out def forward_for_logprobs( diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index f297816fa3..2ce5331700 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -416,7 +416,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: num_steps = 0 local_step = 0 group_had_backward = False - kondo_gate_stats: list[tuple[int, float, float]] = [] + kondo_gate_stats: list[grpo_utils.KondoGateDecision] = [] if self.grpo_config.use_kondo_gate and self._kondo_gate is None: self._kondo_gate = grpo_utils.KondoGateState( @@ -474,15 +474,14 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: loss = loss * dp_world_size - should_backward = True - gate_prob = 1.0 - gate_lambda = float("-inf") + decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) if self._kondo_gate is not None: - sample_delight = masked_mean(delight, response_mask, None, None) - should_backward, gate_prob, gate_lambda = self._kondo_gate.decide(sample_delight) - kondo_gate_stats.append((int(should_backward), gate_prob, gate_lambda)) + delight_sum = (delight * response_mask).sum() + token_count = response_mask.sum().float() + decision = self._kondo_gate.decide(delight_sum, token_count) + kondo_gate_stats.append(decision) - if should_backward: + if decision.should_backward: loss.backward() group_had_backward = True @@ -538,14 +537,6 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: ) self.record_metric("lr", float(lr), reduce_type=None) self.record_metric("_token_count", global_tokens, reduce_type=None) - if self._kondo_gate is not None and kondo_gate_stats: - n = len(kondo_gate_stats) - self.record_metric( - "val/kondo_gate_backward_frac", sum(s[0] for s in kondo_gate_stats) / n, reduce_type=None - ) - self.record_metric( - "val/kondo_gate_prob_avg", sum(s[1] for s in kondo_gate_stats) / n, reduce_type=None - ) - finite_lams = [s[2] for s in kondo_gate_stats if math.isfinite(s[2])] - if finite_lams: - self.record_metric("val/kondo_lambda", sum(finite_lams) / len(finite_lams), reduce_type=None) + if self._kondo_gate is not None: + for k, v in grpo_utils.summarize_kondo_gate_stats(kondo_gate_stats).items(): + self.record_metric(k, v, reduce_type=None) diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index 347d026545..c64dbbf327 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -294,7 +294,7 @@ def test_warmup_always_passes(self): ) gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) for _ in range(3): - should_backward, prob, lam = gate.decide(torch.tensor(0.0)) + should_backward, prob, lam = gate.decide(torch.tensor(0.0), torch.tensor(1.0)) self.assertTrue(should_backward) self.assertEqual(prob, 1.0) self.assertEqual(lam, float("-inf")) @@ -305,7 +305,7 @@ def test_rate_one_always_passes(self): ) gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) for _ in range(16): - should_backward, _, _ = gate.decide(torch.tensor(torch.randn(1).item())) + should_backward, _, _ = gate.decide(torch.tensor(torch.randn(1).item()), torch.tensor(1.0)) self.assertTrue(should_backward) def test_gate_rate_matches_target_in_expectation(self): @@ -324,7 +324,7 @@ def test_gate_rate_matches_target_in_expectation(self): values = torch.randn(n) passes = 0 for v in values: - should_backward, _, _ = gate.decide(v) + should_backward, _, _ = gate.decide(v, torch.tensor(1.0)) passes += int(should_backward) # After warmup (64), remaining samples should pass at roughly `rate`. observed = passes / n From 0440cc4fbbd2c4faba9e268fd7a5b4104b2c6655 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 20:31:59 -0600 Subject: [PATCH 05/17] Debug Kondo gate: add tracing logs + simplify decide(delight, mask). Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 14 ++++-- open_instruct/grpo_utils.py | 57 +++++++++++++++--------- open_instruct/olmo_core_train_modules.py | 4 +- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index da72c6e390..4d65d01efd 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -755,9 +755,17 @@ def step(self): decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) if self._kondo_gate is not None: - delight_sum = (delight_BT * response_mask_BT).sum() - token_count = response_mask_BT.sum().float() - decision = self._kondo_gate.decide(delight_sum, token_count) + decision = self._kondo_gate.decide(delight_BT, response_mask_BT) + if self.rank == 0 and i < 3: + logger.info( + "[kondo] grpo_fast rank=0 i=%d delight_shape=%s mask_shape=%s " + "mask_sum=%.0f delight_sum=%.4g", + i, + tuple(delight_BT.shape), + tuple(response_mask_BT.shape), + float(response_mask_BT.sum()), + float((delight_BT * response_mask_BT).sum()), + ) kondo_gate_stats.append(decision) is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 3c60dc82b9..fff4d4860e 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -410,11 +410,6 @@ class KondoGateDecision(NamedTuple): lam: float -# Refresh the quantile lambda every N decide() calls; stale lambda is fine because the -# delight distribution shifts slowly relative to ring-buffer writes. -_KONDO_QUANTILE_REFRESH = 32 - - class KondoGateState: """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526). @@ -443,10 +438,9 @@ def __init__( self._write_idx = 0 self._generator = torch.Generator(device=device) self._generator.manual_seed(int(seed)) - self._cached_lam: torch.Tensor | None = None - self._calls_since_refresh = 0 + self._log_calls = 0 - def _reduced_chi(self, delight_sum: torch.Tensor, token_count: torch.Tensor) -> torch.Tensor: + def _reduced_chi(self, delight: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor: """Reduce (sum_delight, sum_tokens) across the process group and return sum/count. With Ulysses SP, each rank holds a sequence-slice of its sample, so per-rank @@ -455,7 +449,7 @@ def _reduced_chi(self, delight_sum: torch.Tensor, token_count: torch.Tensor) -> and the result is identical on every rank in the group (required to keep DeepSpeed / FSDP collectives in sync). """ - packed = torch.stack([delight_sum.detach(), token_count.detach()]).to(self.device) + packed = torch.stack([(delight * response_mask).sum().detach(), response_mask.sum().float().detach()]) if dist.is_available() and dist.is_initialized(): dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=self.process_group) return packed[0] / packed[1].clamp_min(1.0) @@ -465,26 +459,47 @@ def _append(self, value: torch.Tensor) -> None: self._write_idx = (self._write_idx + 1) % self.history_size self._count = min(self._count + 1, self.history_size) - def _quantile(self) -> torch.Tensor: - if self._cached_lam is None or self._calls_since_refresh >= _KONDO_QUANTILE_REFRESH: - self._cached_lam = torch.quantile(self._buffer[: self._count], 1.0 - self.rate) - self._calls_since_refresh = 0 - self._calls_since_refresh += 1 - return self._cached_lam - - def decide(self, delight_sum: torch.Tensor, token_count: torch.Tensor) -> KondoGateDecision: - """All-reduces (delight_sum, token_count), computes chi = sum/count, and gates. + def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGateDecision: + """Computes token-weighted chi over the response, all-reduces across ranks, and gates. Returns identical values on every rank in the process group. """ - chi = self._reduced_chi(delight_sum, token_count) + chi = self._reduced_chi(delight, response_mask) self._append(chi) + self._log_calls += 1 + should_log = self._log_calls <= 5 or self._log_calls % 50 == 0 if self._count < self.warmup or self.rate >= 1.0: + if should_log: + logger.info( + "[kondo] call=%d count=%d warmup=%d rate=%.3f chi=%.6g -> WARMUP (pass-through)", + self._log_calls, + self._count, + self.warmup, + self.rate, + float(chi.item()), + ) return KondoGateDecision(True, 1.0, float("-inf")) - lam = self._quantile() + lam = torch.quantile(self._buffer[: self._count], 1.0 - self.rate) prob = torch.sigmoid((chi - lam) / self.temperature) gate = torch.bernoulli(prob, generator=self._generator) - return KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) + decision = KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) + if should_log: + buf = self._buffer[: self._count] + logger.info( + "[kondo] call=%d count=%d chi=%.6g lam=%.6g prob=%.4f gate=%d temp=%.3g " + "buf[min/med/max]=%.6g/%.6g/%.6g", + self._log_calls, + self._count, + float(chi.item()), + decision.lam, + decision.prob, + int(decision.should_backward), + self.temperature, + float(buf.min()), + float(buf.median()), + float(buf.max()), + ) + return decision def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]: diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 2ce5331700..d7f3b19f9f 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -476,9 +476,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) if self._kondo_gate is not None: - delight_sum = (delight * response_mask).sum() - token_count = response_mask.sum().float() - decision = self._kondo_gate.decide(delight_sum, token_count) + decision = self._kondo_gate.decide(delight, response_mask) kondo_gate_stats.append(decision) if decision.should_backward: From 658827a824a87fe98f4dfadb295660f760592bd8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 20:34:03 -0600 Subject: [PATCH 06/17] Kondo gate: log every decide() call unconditionally for debug. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 20 ++++++++-------- open_instruct/grpo_utils.py | 48 ++++++++++++++++--------------------- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4d65d01efd..a8b5d37162 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -756,16 +756,16 @@ def step(self): decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) if self._kondo_gate is not None: decision = self._kondo_gate.decide(delight_BT, response_mask_BT) - if self.rank == 0 and i < 3: - logger.info( - "[kondo] grpo_fast rank=0 i=%d delight_shape=%s mask_shape=%s " - "mask_sum=%.0f delight_sum=%.4g", - i, - tuple(delight_BT.shape), - tuple(response_mask_BT.shape), - float(response_mask_BT.sum()), - float((delight_BT * response_mask_BT).sum()), - ) + logger.info( + "[kondo] grpo_fast rank=%d i=%d delight_shape=%s mask_shape=%s " + "mask_sum=%.0f delight_sum=%.4g", + self.rank, + i, + tuple(delight_BT.shape), + tuple(response_mask_BT.shape), + float(response_mask_BT.sum()), + float((delight_BT * response_mask_BT).sum()), + ) kondo_gate_stats.append(decision) is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index fff4d4860e..7fb13c9427 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -438,7 +438,6 @@ def __init__( self._write_idx = 0 self._generator = torch.Generator(device=device) self._generator.manual_seed(int(seed)) - self._log_calls = 0 def _reduced_chi(self, delight: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor: """Reduce (sum_delight, sum_tokens) across the process group and return sum/count. @@ -466,39 +465,32 @@ def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGat """ chi = self._reduced_chi(delight, response_mask) self._append(chi) - self._log_calls += 1 - should_log = self._log_calls <= 5 or self._log_calls % 50 == 0 if self._count < self.warmup or self.rate >= 1.0: - if should_log: - logger.info( - "[kondo] call=%d count=%d warmup=%d rate=%.3f chi=%.6g -> WARMUP (pass-through)", - self._log_calls, - self._count, - self.warmup, - self.rate, - float(chi.item()), - ) + logger.info( + "[kondo] count=%d warmup=%d rate=%.3f chi=%.6g -> WARMUP (pass-through)", + self._count, + self.warmup, + self.rate, + float(chi.item()), + ) return KondoGateDecision(True, 1.0, float("-inf")) lam = torch.quantile(self._buffer[: self._count], 1.0 - self.rate) prob = torch.sigmoid((chi - lam) / self.temperature) gate = torch.bernoulli(prob, generator=self._generator) decision = KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) - if should_log: - buf = self._buffer[: self._count] - logger.info( - "[kondo] call=%d count=%d chi=%.6g lam=%.6g prob=%.4f gate=%d temp=%.3g " - "buf[min/med/max]=%.6g/%.6g/%.6g", - self._log_calls, - self._count, - float(chi.item()), - decision.lam, - decision.prob, - int(decision.should_backward), - self.temperature, - float(buf.min()), - float(buf.median()), - float(buf.max()), - ) + buf = self._buffer[: self._count] + logger.info( + "[kondo] count=%d chi=%.6g lam=%.6g prob=%.4f gate=%d temp=%.3g buf[min/med/max]=%.6g/%.6g/%.6g", + self._count, + float(chi.item()), + decision.lam, + decision.prob, + int(decision.should_backward), + self.temperature, + float(buf.min()), + float(buf.median()), + float(buf.max()), + ) return decision From b765f405d4f043c1231c0d76f8dcd1a8ad22d8b0 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 20:35:52 -0600 Subject: [PATCH 07/17] Kondo gate: log quantile probes + frac_buf>lam to diagnose lambda. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 7fb13c9427..f90318c345 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -474,22 +474,27 @@ def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGat float(chi.item()), ) return KondoGateDecision(True, 1.0, float("-inf")) - lam = torch.quantile(self._buffer[: self._count], 1.0 - self.rate) + buf = self._buffer[: self._count] + q = 1.0 - self.rate + lam = torch.quantile(buf, q) prob = torch.sigmoid((chi - lam) / self.temperature) gate = torch.bernoulli(prob, generator=self._generator) decision = KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) - buf = self._buffer[: self._count] + probe_qs = torch.tensor([0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0], device=buf.device, dtype=buf.dtype) + probes = torch.quantile(buf, probe_qs) + frac_above = float((buf > lam).float().mean()) logger.info( - "[kondo] count=%d chi=%.6g lam=%.6g prob=%.4f gate=%d temp=%.3g buf[min/med/max]=%.6g/%.6g/%.6g", + "[kondo] count=%d chi=%.6g q=%.3f lam=%.6g prob=%.4f gate=%d temp=%.3g " + "frac_buf>lam=%.3f buf_quantiles[0/.1/.25/.5/.75/.9/1]=%s", self._count, float(chi.item()), + q, decision.lam, decision.prob, int(decision.should_backward), self.temperature, - float(buf.min()), - float(buf.median()), - float(buf.max()), + frac_above, + [f"{float(v):.6g}" for v in probes], ) return decision From a8bd61f7697162171de8d79f1add1c9f109dc1f6 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 20 Apr 2026 20:51:47 -0600 Subject: [PATCH 08/17] Lower kondo_gate_warmup to 16 in large_test_script so debug runs exit warmup. Co-Authored-By: Claude Opus 4.7 --- scripts/train/debug/large_test_script.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh index aaef67783e..e55762b08e 100755 --- a/scripts/train/debug/large_test_script.sh +++ b/scripts/train/debug/large_test_script.sh @@ -70,4 +70,5 @@ uv run python mason.py \ --use_delight true \ --use_kondo_gate true \ --kondo_gate_rate 0.5 \ + --kondo_gate_warmup 16 \ --push_to_hub False From 90d07ae182667b72ff0f87db05de8e3fcff10bd4 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 21 Apr 2026 07:58:57 -0600 Subject: [PATCH 09/17] Return dict from compute_grpo_loss instead of 5-tuple. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 24 +++++------ open_instruct/grpo_utils.py | 38 ++++-------------- open_instruct/olmo_core_train_modules.py | 9 ++++- open_instruct/test_olmo_core_train_modules.py | 40 +++++++++---------- 4 files changed, 45 insertions(+), 66 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a8b5d37162..3936648288 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -736,7 +736,7 @@ def step(self): self.args.truncated_importance_sampling_ratio_cap, ) - pg_losses_BT, pg_losses2_BT, pg_loss_max_BT, kl_BT, delight_BT = grpo_utils.compute_grpo_loss( + loss_terms = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs_BT, ratio=ratio_BT, advantages=data_BT.advantages[i][:, 1:], @@ -744,6 +744,11 @@ def step(self): config=self.args, tis_weights=tis_clamped_BT, ) + pg_losses_BT = loss_terms["pg_losses"] + pg_losses2_BT = loss_terms["pg_losses2"] + pg_loss_max_BT = loss_terms["pg_loss_max"] + kl_BT = loss_terms["kl"] + delight_BT = loss_terms["delight"] per_token_loss_BT = pg_loss_max_BT + self.args.beta * kl_BT loss = masked_mean(per_token_loss_BT, response_mask_BT, None, loss_denominator) @@ -753,19 +758,9 @@ def step(self): # up, adjusting for the sequence parallel size (adjust by dp world size). loss *= self.args.world_size // self.args.sequence_parallel_size - decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) + decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: decision = self._kondo_gate.decide(delight_BT, response_mask_BT) - logger.info( - "[kondo] grpo_fast rank=%d i=%d delight_shape=%s mask_shape=%s " - "mask_sum=%.0f delight_sum=%.4g", - self.rank, - i, - tuple(delight_BT.shape), - tuple(response_mask_BT.shape), - float(response_mask_BT.sum()), - float((delight_BT * response_mask_BT).sum()), - ) kondo_gate_stats.append(decision) is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 @@ -776,8 +771,9 @@ def step(self): self.model.backward(loss) group_had_backward = True elif is_accumulation_boundary and group_had_backward: - # The last sample in the group was gated but earlier ungated samples left - # un-reduce-scattered grads. Trigger the reduce-scatter with a zero-contribution backward. + # DeepSpeed defers the accumulation-group reduce-scatter to the boundary + # backward; if the boundary sample is gated, we still need a backward here + # (zeroed so it contributes no gradient) to flush earlier micro-steps' grads. torch.cuda.empty_cache() self.model.set_gradient_accumulation_boundary(True) self.model.backward(loss * 0.0) diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index f90318c345..6992d4324c 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -368,7 +368,7 @@ def compute_grpo_loss( ref_logprobs: torch.Tensor | None, config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> dict[str, torch.Tensor]: delight = -advantages * new_logprobs.detach() if config.use_delight: # Delightful Policy Gradient gate; temperature eta is fixed to 1 per the paper. @@ -401,7 +401,7 @@ def compute_grpo_loss( else: kl = torch.zeros_like(pg_loss_max) - return pg_losses, pg_losses2, pg_loss_max, kl, delight + return {"pg_losses": pg_losses, "pg_losses2": pg_losses2, "pg_loss_max": pg_loss_max, "kl": kl, "delight": delight} class KondoGateDecision(NamedTuple): @@ -410,6 +410,9 @@ class KondoGateDecision(NamedTuple): lam: float +KONDO_GATE_PASSTHROUGH = KondoGateDecision(True, 1.0, float("-inf")) + + class KondoGateState: """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526). @@ -466,37 +469,12 @@ def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGat chi = self._reduced_chi(delight, response_mask) self._append(chi) if self._count < self.warmup or self.rate >= 1.0: - logger.info( - "[kondo] count=%d warmup=%d rate=%.3f chi=%.6g -> WARMUP (pass-through)", - self._count, - self.warmup, - self.rate, - float(chi.item()), - ) - return KondoGateDecision(True, 1.0, float("-inf")) + return KONDO_GATE_PASSTHROUGH buf = self._buffer[: self._count] - q = 1.0 - self.rate - lam = torch.quantile(buf, q) + lam = torch.quantile(buf, 1.0 - self.rate) prob = torch.sigmoid((chi - lam) / self.temperature) gate = torch.bernoulli(prob, generator=self._generator) - decision = KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) - probe_qs = torch.tensor([0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0], device=buf.device, dtype=buf.dtype) - probes = torch.quantile(buf, probe_qs) - frac_above = float((buf > lam).float().mean()) - logger.info( - "[kondo] count=%d chi=%.6g q=%.3f lam=%.6g prob=%.4f gate=%d temp=%.3g " - "frac_buf>lam=%.3f buf_quantiles[0/.1/.25/.5/.75/.9/1]=%s", - self._count, - float(chi.item()), - q, - decision.lam, - decision.prob, - int(decision.should_backward), - self.temperature, - frac_above, - [f"{float(v):.6g}" for v in probes], - ) - return decision + return KondoGateDecision(bool(gate.item()), float(prob.item()), float(lam.item())) def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]: diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index d7f3b19f9f..34da2b4f2a 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -459,7 +459,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: old_logprob, vllm_logprobs, response_mask, self.grpo_config.truncated_importance_sampling_ratio_cap ) - pg_losses, pg_losses2, pg_loss, kl, delight = grpo_utils.compute_grpo_loss( + loss_terms = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages[:, 1:], @@ -467,6 +467,11 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: config=self.grpo_config, tis_weights=tis_clamped, ) + pg_losses = loss_terms["pg_losses"] + pg_losses2 = loss_terms["pg_losses2"] + pg_loss = loss_terms["pg_loss_max"] + kl = loss_terms["kl"] + delight = loss_terms["delight"] batch_start = (sample_idx // accumulation_steps) * accumulation_steps loss_denominator = accumulation_token_counts[batch_start] @@ -474,7 +479,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: loss = loss * dp_world_size - decision = grpo_utils.KondoGateDecision(True, 1.0, float("-inf")) + decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: decision = self._kondo_gate.decide(delight, response_mask) kondo_gate_stats.append(decision) diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index c64dbbf327..702f50e125 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -183,16 +183,16 @@ def test_output_shapes(self, _name, loss_type): ratio = torch.exp(torch.randn(batch_size, seq_len)) advantages = torch.randn(batch_size, seq_len) - pg_losses, pg_losses2, pg_loss_max, kl, delight = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - self.assertEqual(pg_losses.shape, (batch_size, seq_len)) - self.assertEqual(pg_losses2.shape, (batch_size, seq_len)) - self.assertEqual(pg_loss_max.shape, (batch_size, seq_len)) - self.assertEqual(kl.shape, (batch_size, seq_len)) - self.assertEqual(delight.shape, (batch_size, seq_len)) - torch.testing.assert_close(delight, -advantages * new_logprobs.detach()) + self.assertEqual(result["pg_losses"].shape, (batch_size, seq_len)) + self.assertEqual(result["pg_losses2"].shape, (batch_size, seq_len)) + self.assertEqual(result["pg_loss_max"].shape, (batch_size, seq_len)) + self.assertEqual(result["kl"].shape, (batch_size, seq_len)) + self.assertEqual(result["delight"].shape, (batch_size, seq_len)) + torch.testing.assert_close(result["delight"], -advantages * new_logprobs.detach()) def test_dapo_clipping(self): config = _make_grpo_config(clip_lower=0.2, clip_higher=0.2) @@ -200,12 +200,12 @@ def test_dapo_clipping(self): new_logprobs = torch.randn(1, 3) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) expected_clamped = torch.clamp(ratio, 0.8, 1.2) - torch.testing.assert_close(pg_losses2, -advantages * expected_clamped) + torch.testing.assert_close(result["pg_losses2"], -advantages * expected_clamped) def test_cispo_uses_detached_ratio(self): config = _make_grpo_config(loss_fn=grpo_utils.GRPOLossType.cispo, clip_higher=0.2) @@ -213,11 +213,11 @@ def test_cispo_uses_detached_ratio(self): new_logprobs = torch.randn(1, 3, requires_grad=True) advantages = torch.ones(1, 3) - pg_losses, pg_losses2, pg_loss_max, _, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - pg_loss_max.sum().backward() + result["pg_loss_max"].sum().backward() self.assertIsNone(ratio.grad) self.assertIsNotNone(new_logprobs.grad) @@ -229,11 +229,11 @@ def test_with_ref_logprobs(self): advantages = torch.randn(batch_size, seq_len) ref_logprobs = torch.randn(batch_size, seq_len) - _, _, _, kl, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=ref_logprobs, config=config ) - self.assertFalse(torch.all(kl == 0)) + self.assertFalse(torch.all(result["kl"] == 0)) def test_without_ref_logprobs(self): config = _make_grpo_config() @@ -241,11 +241,11 @@ def test_without_ref_logprobs(self): ratio = torch.exp(torch.randn(2, 4)) advantages = torch.randn(2, 4) - _, _, _, kl, _ = grpo_utils.compute_grpo_loss( + result = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - torch.testing.assert_close(kl, torch.zeros_like(kl)) + torch.testing.assert_close(result["kl"], torch.zeros_like(result["kl"])) def test_tis_weights(self): config = _make_grpo_config() @@ -254,7 +254,7 @@ def test_tis_weights(self): advantages = torch.randn(2, 4) tis_weights = torch.full((2, 4), 2.0) - pg_no_tis, pg2_no_tis, _, _, _ = grpo_utils.compute_grpo_loss( + no_tis = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -263,7 +263,7 @@ def test_tis_weights(self): tis_weights=None, ) - pg_tis, pg2_tis, _, _, _ = grpo_utils.compute_grpo_loss( + tis = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, @@ -272,8 +272,8 @@ def test_tis_weights(self): tis_weights=tis_weights, ) - torch.testing.assert_close(pg_tis, pg_no_tis * 2.0) - torch.testing.assert_close(pg2_tis, pg2_no_tis * 2.0) + torch.testing.assert_close(tis["pg_losses"], no_tis["pg_losses"] * 2.0) + torch.testing.assert_close(tis["pg_losses2"], no_tis["pg_losses2"] * 2.0) def test_invalid_loss_fn(self): config = _make_grpo_config(loss_fn="invalid") @@ -305,7 +305,7 @@ def test_rate_one_always_passes(self): ) gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) for _ in range(16): - should_backward, _, _ = gate.decide(torch.tensor(torch.randn(1).item()), torch.tensor(1.0)) + should_backward, _, _ = gate.decide(torch.randn(()), torch.tensor(1.0)) self.assertTrue(should_backward) def test_gate_rate_matches_target_in_expectation(self): From 5d2abca69154a6354650edbff4187157bd40710d Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 21 Apr 2026 08:11:10 -0600 Subject: [PATCH 10/17] Use LossOutput dataclass for compute_grpo_loss return value. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 17 +++++-------- open_instruct/grpo_utils.py | 13 ++++++++-- open_instruct/olmo_core_train_modules.py | 22 ++++++++--------- open_instruct/test_olmo_core_train_modules.py | 24 +++++++++---------- 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3936648288..a2b82de3dd 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -736,7 +736,7 @@ def step(self): self.args.truncated_importance_sampling_ratio_cap, ) - loss_terms = grpo_utils.compute_grpo_loss( + loss_output = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs_BT, ratio=ratio_BT, advantages=data_BT.advantages[i][:, 1:], @@ -744,13 +744,8 @@ def step(self): config=self.args, tis_weights=tis_clamped_BT, ) - pg_losses_BT = loss_terms["pg_losses"] - pg_losses2_BT = loss_terms["pg_losses2"] - pg_loss_max_BT = loss_terms["pg_loss_max"] - kl_BT = loss_terms["kl"] - delight_BT = loss_terms["delight"] - per_token_loss_BT = pg_loss_max_BT + self.args.beta * kl_BT + per_token_loss_BT = loss_output.pg_loss_max + self.args.beta * loss_output.kl loss = masked_mean(per_token_loss_BT, response_mask_BT, None, loss_denominator) # we already took world size into account via the tokens @@ -760,7 +755,7 @@ def step(self): decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: - decision = self._kondo_gate.decide(delight_BT, response_mask_BT) + decision = self._kondo_gate.decide(loss_output.delight, response_mask_BT) kondo_gate_stats.append(decision) is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 @@ -786,9 +781,9 @@ def step(self): grpo_utils.populate_sample_loss_stats( loss_stats_B, i, - pg_losses_BT, - pg_losses2_BT, - pg_loss_max_BT, + loss_output.pg_losses, + loss_output.pg_losses2, + loss_output.pg_loss_max, ratio_BT, loss, response_mask_BT, diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 6992d4324c..4047a10c53 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -361,6 +361,15 @@ def resolve_old_logprob( return result +@dataclass +class LossOutput: + pg_losses: torch.Tensor + pg_losses2: torch.Tensor + pg_loss_max: torch.Tensor + kl: torch.Tensor + delight: torch.Tensor + + def compute_grpo_loss( new_logprobs: torch.Tensor, ratio: torch.Tensor, @@ -368,7 +377,7 @@ def compute_grpo_loss( ref_logprobs: torch.Tensor | None, config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, -) -> dict[str, torch.Tensor]: +) -> LossOutput: delight = -advantages * new_logprobs.detach() if config.use_delight: # Delightful Policy Gradient gate; temperature eta is fixed to 1 per the paper. @@ -401,7 +410,7 @@ def compute_grpo_loss( else: kl = torch.zeros_like(pg_loss_max) - return {"pg_losses": pg_losses, "pg_losses2": pg_losses2, "pg_loss_max": pg_loss_max, "kl": kl, "delight": delight} + return LossOutput(pg_losses=pg_losses, pg_losses2=pg_losses2, pg_loss_max=pg_loss_max, kl=kl, delight=delight) class KondoGateDecision(NamedTuple): diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 34da2b4f2a..57dbc15be0 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -459,7 +459,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: old_logprob, vllm_logprobs, response_mask, self.grpo_config.truncated_importance_sampling_ratio_cap ) - loss_terms = grpo_utils.compute_grpo_loss( + loss_output = grpo_utils.compute_grpo_loss( new_logprobs=new_logprobs, ratio=ratio, advantages=advantages[:, 1:], @@ -467,21 +467,21 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: config=self.grpo_config, tis_weights=tis_clamped, ) - pg_losses = loss_terms["pg_losses"] - pg_losses2 = loss_terms["pg_losses2"] - pg_loss = loss_terms["pg_loss_max"] - kl = loss_terms["kl"] - delight = loss_terms["delight"] batch_start = (sample_idx // accumulation_steps) * accumulation_steps loss_denominator = accumulation_token_counts[batch_start] - loss = masked_mean(pg_loss + self.grpo_config.beta * kl, response_mask, None, loss_denominator) + loss = masked_mean( + loss_output.pg_loss_max + self.grpo_config.beta * loss_output.kl, + response_mask, + None, + loss_denominator, + ) loss = loss * dp_world_size decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: - decision = self._kondo_gate.decide(delight, response_mask) + decision = self._kondo_gate.decide(loss_output.delight, response_mask) kondo_gate_stats.append(decision) if decision.should_backward: @@ -491,9 +491,9 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: grpo_utils.populate_sample_loss_stats( loss_stats_B, sample_idx, - pg_losses, - pg_losses2, - pg_loss, + loss_output.pg_losses, + loss_output.pg_losses2, + loss_output.pg_loss_max, ratio, loss, response_mask, diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index 702f50e125..92d0a104e7 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -187,12 +187,12 @@ def test_output_shapes(self, _name, loss_type): new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - self.assertEqual(result["pg_losses"].shape, (batch_size, seq_len)) - self.assertEqual(result["pg_losses2"].shape, (batch_size, seq_len)) - self.assertEqual(result["pg_loss_max"].shape, (batch_size, seq_len)) - self.assertEqual(result["kl"].shape, (batch_size, seq_len)) - self.assertEqual(result["delight"].shape, (batch_size, seq_len)) - torch.testing.assert_close(result["delight"], -advantages * new_logprobs.detach()) + self.assertEqual(result.pg_losses.shape, (batch_size, seq_len)) + self.assertEqual(result.pg_losses2.shape, (batch_size, seq_len)) + self.assertEqual(result.pg_loss_max.shape, (batch_size, seq_len)) + self.assertEqual(result.kl.shape, (batch_size, seq_len)) + self.assertEqual(result.delight.shape, (batch_size, seq_len)) + torch.testing.assert_close(result.delight, -advantages * new_logprobs.detach()) def test_dapo_clipping(self): config = _make_grpo_config(clip_lower=0.2, clip_higher=0.2) @@ -205,7 +205,7 @@ def test_dapo_clipping(self): ) expected_clamped = torch.clamp(ratio, 0.8, 1.2) - torch.testing.assert_close(result["pg_losses2"], -advantages * expected_clamped) + torch.testing.assert_close(result.pg_losses2, -advantages * expected_clamped) def test_cispo_uses_detached_ratio(self): config = _make_grpo_config(loss_fn=grpo_utils.GRPOLossType.cispo, clip_higher=0.2) @@ -217,7 +217,7 @@ def test_cispo_uses_detached_ratio(self): new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - result["pg_loss_max"].sum().backward() + result.pg_loss_max.sum().backward() self.assertIsNone(ratio.grad) self.assertIsNotNone(new_logprobs.grad) @@ -233,7 +233,7 @@ def test_with_ref_logprobs(self): new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=ref_logprobs, config=config ) - self.assertFalse(torch.all(result["kl"] == 0)) + self.assertFalse(torch.all(result.kl == 0)) def test_without_ref_logprobs(self): config = _make_grpo_config() @@ -245,7 +245,7 @@ def test_without_ref_logprobs(self): new_logprobs=new_logprobs, ratio=ratio, advantages=advantages, ref_logprobs=None, config=config ) - torch.testing.assert_close(result["kl"], torch.zeros_like(result["kl"])) + torch.testing.assert_close(result.kl, torch.zeros_like(result.kl)) def test_tis_weights(self): config = _make_grpo_config() @@ -272,8 +272,8 @@ def test_tis_weights(self): tis_weights=tis_weights, ) - torch.testing.assert_close(tis["pg_losses"], no_tis["pg_losses"] * 2.0) - torch.testing.assert_close(tis["pg_losses2"], no_tis["pg_losses2"] * 2.0) + torch.testing.assert_close(tis.pg_losses, no_tis.pg_losses * 2.0) + torch.testing.assert_close(tis.pg_losses2, no_tis.pg_losses2 * 2.0) def test_invalid_loss_fn(self): config = _make_grpo_config(loss_fn="invalid") From 88fe46fdfae0cdf4a6dae57616d80fd4dfc55b3b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 21 Apr 2026 08:49:14 -0600 Subject: [PATCH 11/17] Strip defensive guards from Kondo gate / delight code. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 8 +++-- open_instruct/grpo_utils.py | 32 ++++--------------- open_instruct/olmo_core_train_modules.py | 18 +++++------ open_instruct/test_olmo_core_train_modules.py | 12 ++----- 4 files changed, 22 insertions(+), 48 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a2b82de3dd..697bdc8c7a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -753,13 +753,15 @@ def step(self): # up, adjusting for the sequence parallel size (adjust by dp world size). loss *= self.args.world_size // self.args.sequence_parallel_size - decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: decision = self._kondo_gate.decide(loss_output.delight, response_mask_BT) - kondo_gate_stats.append(decision) + kondo_gate_stats.append(decision) + should_backward = decision.should_backward + else: + should_backward = True is_accumulation_boundary = (local_step + 1) % accumulation_steps == 0 - if decision.should_backward: + if should_backward: # Clear CUDA cache before backward pass to free memory for reduce_scatter operations torch.cuda.empty_cache() self.model.set_gradient_accumulation_boundary(is_accumulation_boundary) diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index 4047a10c53..c62c53841e 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -248,18 +248,6 @@ def __post_init__(self): raise ValueError(f"`gs_bucket_path` must start with 'gs://', got: {self.gs_bucket_path}") if self.sequence_parallel_size > 1 and self.deepspeed_stage != 3: raise ValueError("`sequence_parallel_size` > 1 requires `deepspeed_stage` to be 3!") - if self.use_kondo_gate: - if not (0.0 < self.kondo_gate_rate <= 1.0): - raise ValueError(f"`kondo_gate_rate` must be in (0, 1], got {self.kondo_gate_rate}") - if self.kondo_gate_temperature <= 0.0: - raise ValueError(f"`kondo_gate_temperature` must be > 0, got {self.kondo_gate_temperature}") - if self.kondo_gate_warmup <= 0: - raise ValueError(f"`kondo_gate_warmup` must be > 0, got {self.kondo_gate_warmup}") - if self.kondo_gate_history_size < self.kondo_gate_warmup: - raise ValueError( - f"`kondo_gate_history_size` ({self.kondo_gate_history_size}) must be >= " - f"`kondo_gate_warmup` ({self.kondo_gate_warmup})." - ) total_learner_gpus = sum(self.num_learners_per_node) if self.fsdp_shard_degree is not None and self.fsdp_num_replicas is not None: @@ -419,9 +407,6 @@ class KondoGateDecision(NamedTuple): lam: float -KONDO_GATE_PASSTHROUGH = KondoGateDecision(True, 1.0, float("-inf")) - - class KondoGateState: """Per-sample Kondo gate over delight (https://arxiv.org/abs/2603.20526). @@ -461,9 +446,9 @@ def _reduced_chi(self, delight: torch.Tensor, response_mask: torch.Tensor) -> to DeepSpeed / FSDP collectives in sync). """ packed = torch.stack([(delight * response_mask).sum().detach(), response_mask.sum().float().detach()]) - if dist.is_available() and dist.is_initialized(): + if dist.is_initialized(): dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=self.process_group) - return packed[0] / packed[1].clamp_min(1.0) + return packed[0] / packed[1] def _append(self, value: torch.Tensor) -> None: self._buffer[self._write_idx] = value @@ -477,8 +462,8 @@ def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGat """ chi = self._reduced_chi(delight, response_mask) self._append(chi) - if self._count < self.warmup or self.rate >= 1.0: - return KONDO_GATE_PASSTHROUGH + if self._count < self.warmup: + return KondoGateDecision(True, 1.0, float("nan")) buf = self._buffer[: self._count] lam = torch.quantile(buf, 1.0 - self.rate) prob = torch.sigmoid((chi - lam) / self.temperature) @@ -488,17 +473,12 @@ def decide(self, delight: torch.Tensor, response_mask: torch.Tensor) -> KondoGat def summarize_kondo_gate_stats(stats: list[KondoGateDecision]) -> dict[str, float]: """Aggregate per-sample gate decisions into scalar metrics.""" - if not stats: - return {} n = len(stats) - out = { + return { "val/kondo_gate_backward_frac": sum(int(s.should_backward) for s in stats) / n, "val/kondo_gate_prob_avg": sum(s.prob for s in stats) / n, + "val/kondo_lambda": sum(s.lam for s in stats) / n, } - finite_lams = [s.lam for s in stats if math.isfinite(s.lam)] - if finite_lams: - out["val/kondo_lambda"] = sum(finite_lams) / len(finite_lams) - return out def forward_for_logprobs( diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 57dbc15be0..f52f705217 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -327,7 +327,10 @@ def pre_train(self): # GRPO batches are prompt-grouped and do their own accumulation/token normalization # inside train_batch(), so the base TransformerTrainModule global-batch validation # does not apply here. - pass + if self.grpo_config.use_kondo_gate: + self._kondo_gate = grpo_utils.KondoGateState( + self.grpo_config, self.device, process_group=self.trainer.dp_process_group, seed=self.grpo_config.seed + ) def state_dict(self, *, optim: bool | None = None) -> dict[str, Any]: state = super().state_dict(optim=optim) @@ -418,11 +421,6 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: group_had_backward = False kondo_gate_stats: list[grpo_utils.KondoGateDecision] = [] - if self.grpo_config.use_kondo_gate and self._kondo_gate is None: - self._kondo_gate = grpo_utils.KondoGateState( - self.grpo_config, self.device, process_group=self.trainer.dp_process_group, seed=self.grpo_config.seed - ) - for epoch_idx in range(self.grpo_config.num_epochs): for sample_idx in range(num_samples): new_logprobs, entropy = grpo_utils.forward_for_logprobs( @@ -479,12 +477,14 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: loss = loss * dp_world_size - decision = grpo_utils.KONDO_GATE_PASSTHROUGH if self._kondo_gate is not None: decision = self._kondo_gate.decide(loss_output.delight, response_mask) - kondo_gate_stats.append(decision) + kondo_gate_stats.append(decision) + should_backward = decision.should_backward + else: + should_backward = True - if decision.should_backward: + if should_backward: loss.backward() group_had_backward = True diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index 92d0a104e7..2fcbc78f59 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -1,3 +1,4 @@ +import math import unittest from unittest.mock import MagicMock @@ -297,16 +298,7 @@ def test_warmup_always_passes(self): should_backward, prob, lam = gate.decide(torch.tensor(0.0), torch.tensor(1.0)) self.assertTrue(should_backward) self.assertEqual(prob, 1.0) - self.assertEqual(lam, float("-inf")) - - def test_rate_one_always_passes(self): - config = _make_grpo_config( - use_kondo_gate=True, kondo_gate_rate=1.0, kondo_gate_warmup=2, kondo_gate_history_size=8 - ) - gate = grpo_utils.KondoGateState(config, device=torch.device("cpu"), process_group=None, seed=0) - for _ in range(16): - should_backward, _, _ = gate.decide(torch.randn(()), torch.tensor(1.0)) - self.assertTrue(should_backward) + self.assertTrue(math.isnan(lam)) def test_gate_rate_matches_target_in_expectation(self): rate = 0.3 From 06f94df1e1aacf0e31de822aa6a89bbb281abaa5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 21 Apr 2026 18:09:57 -0600 Subject: [PATCH 12/17] use delight --- scripts/train/qwen/qwen3_4b_dapo_math.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math.sh b/scripts/train/qwen/qwen3_4b_dapo_math.sh index a82a474ca3..bfb88a06d5 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math.sh @@ -71,4 +71,5 @@ uv run open_instruct/grpo_fast.py \ --chat_template qwen_instruct_user_boxed_math \ --load_ref_policy False \ --keep_last_n_checkpoints -1 \ + --use_delight true \ --push_to_hub False "$@" From 70acb5267326707bdc6ce37a1f6373577a33195c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 21 Apr 2026 18:11:58 -0600 Subject: [PATCH 13/17] fixed script --- scripts/train/qwen/qwen3_4b_dapo_math.sh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math.sh b/scripts/train/qwen/qwen3_4b_dapo_math.sh index bfb88a06d5..e77e3112c2 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math.sh @@ -4,16 +4,18 @@ EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo}" RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" NUM_GPUS="${NUM_GPUS:-8}" -BEAKER_IMAGE="${BEAKER_IMAGE:-nathanl/open_instruct_auto}" +BEAKER_IMAGE="${1:-${BEAKER_IMAGE:-nathanl/open_instruct_auto}}" +shift || true -CLUSTER="${CLUSTER:-ai2/jupiter ai2/ceres}" -PRIORITY="${PRIORITY:-high}" +CLUSTER="${CLUSTER:-ai2/jupiter}" +PRIORITY="${PRIORITY:-urgent}" +WORKSPACE="${WORKSPACE:-ai2/open-instruct-dev}" uv run mason.py \ --task_name ${EXP_NAME} \ --description "${RUN_NAME}" \ --cluster ${CLUSTER} \ - --workspace ai2/oe-adapt-code \ + --workspace ${WORKSPACE} \ --priority ${PRIORITY} \ --pure_docker_mode \ --no_auto_dataset_cache \ From 3e593ba695808074a516123cc6b31c4cc8b7e5ac Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 22 Apr 2026 07:01:48 -0600 Subject: [PATCH 14/17] Apply delight gate at sample level to preserve blunder learning signal. Co-Authored-By: Claude Opus 4.7 --- open_instruct/grpo_fast.py | 1 + open_instruct/grpo_utils.py | 13 +++++++-- open_instruct/olmo_core_train_modules.py | 1 + open_instruct/test_olmo_core_train_modules.py | 28 +++++++++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index b735fada1b..4cf4bb3820 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -757,6 +757,7 @@ def step(self): ref_logprobs=ref_logprobs_BT[i] if self.args.load_ref_policy else None, config=self.args, tis_weights=tis_clamped_BT, + response_mask=response_mask_BT, ) per_token_loss_BT = loss_output.pg_loss_max + self.args.beta * loss_output.kl diff --git a/open_instruct/grpo_utils.py b/open_instruct/grpo_utils.py index c864ccf897..9e569f152d 100644 --- a/open_instruct/grpo_utils.py +++ b/open_instruct/grpo_utils.py @@ -369,11 +369,20 @@ def compute_grpo_loss( ref_logprobs: torch.Tensor | None, config: GRPOExperimentConfig, tis_weights: torch.Tensor | None = None, + response_mask: torch.Tensor | None = None, ) -> LossOutput: delight = -advantages * new_logprobs.detach() if config.use_delight: - # Delightful Policy Gradient gate; temperature eta is fixed to 1 per the paper. - advantages = advantages * torch.sigmoid(delight) + # Delightful Policy Gradient gate applied at sample level: one sigmoid per rollout, + # broadcast across tokens. GRPO's advantage is constant across a response, so a + # per-token gate would zero out the exact "blunder" tokens whose negative signal + # we need to learn from; a sample-level chi = mean_t(-A * surprisal_t) preserves + # that signal while keeping the paper's breakthrough/blunder interpretation. + mask = response_mask.to(delight.dtype) + denom = mask.sum(dim=-1).clamp(min=1.0) + sample_chi = (delight * mask).sum(dim=-1) / denom + sample_gate = torch.sigmoid(sample_chi).unsqueeze(-1) + advantages = advantages * sample_gate if config.loss_fn == GRPOLossType.dapo: pg_losses = -advantages * ratio diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index fd82f71f01..775d5935e8 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -464,6 +464,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: ref_logprobs=ref_logprobs_BT[sample_idx] if ref_logprobs_BT is not None else None, config=self.grpo_config, tis_weights=tis_clamped, + response_mask=response_mask, ) batch_start = (sample_idx // accumulation_steps) * accumulation_steps diff --git a/open_instruct/test_olmo_core_train_modules.py b/open_instruct/test_olmo_core_train_modules.py index 2fcbc78f59..0c433ea2b7 100644 --- a/open_instruct/test_olmo_core_train_modules.py +++ b/open_instruct/test_olmo_core_train_modules.py @@ -195,6 +195,34 @@ def test_output_shapes(self, _name, loss_type): self.assertEqual(result.delight.shape, (batch_size, seq_len)) torch.testing.assert_close(result.delight, -advantages * new_logprobs.detach()) + def test_use_delight_applies_sample_level_gate(self): + config = _make_grpo_config(use_delight=True, loss_fn=grpo_utils.GRPOLossType.dapo) + batch_size, seq_len = 3, 5 + new_logprobs = torch.randn(batch_size, seq_len) + ratio = torch.ones(batch_size, seq_len) + advantages = torch.randn(batch_size, seq_len) + response_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 1, 0, 0, 0]], dtype=torch.bool) + + result = grpo_utils.compute_grpo_loss( + new_logprobs=new_logprobs, + ratio=ratio, + advantages=advantages, + ref_logprobs=None, + config=config, + response_mask=response_mask, + ) + + delight = -advantages * new_logprobs.detach() + mask_f = response_mask.float() + sample_chi = (delight * mask_f).sum(-1) / mask_f.sum(-1).clamp(min=1.0) + sample_gate = torch.sigmoid(sample_chi).unsqueeze(-1) + expected_pg = -(advantages * sample_gate) * ratio + torch.testing.assert_close(result.pg_losses, expected_pg) + # Gate is constant across the sequence for each sample. + gate_col0 = result.pg_losses[:, 0] / (-advantages[:, 0] * ratio[:, 0]) + gate_col1 = result.pg_losses[:, 1] / (-advantages[:, 1] * ratio[:, 1]) + torch.testing.assert_close(gate_col0, gate_col1) + def test_dapo_clipping(self): config = _make_grpo_config(clip_lower=0.2, clip_higher=0.2) ratio = torch.tensor([[1.5, 0.5, 1.0]]) From 2f1e65b093fc69a1644b8c4ec041704c95c1030b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 24 Apr 2026 14:38:37 -0600 Subject: [PATCH 15/17] Added eval scripts --- .gitignore | 1 + run_aime_eval.sh | 26 ++++ scripts/eval/oe-eval.sh | 4 +- scripts/submit_eval_jobs_new.py | 230 ++++++++++++++++++++++++++++++++ 4 files changed, 259 insertions(+), 2 deletions(-) create mode 100755 run_aime_eval.sh create mode 100644 scripts/submit_eval_jobs_new.py diff --git a/.gitignore b/.gitignore index c5cf71db88..b12d763d96 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ token_length.png birr/ oe-eval-internal/ +olmo-eval-internal/ results models diff --git a/run_aime_eval.sh b/run_aime_eval.sh new file mode 100755 index 0000000000..93599fa634 --- /dev/null +++ b/run_aime_eval.sh @@ -0,0 +1,26 @@ +#! /bin/bash +# Usage: ./run_aime_eval.sh +# check we have at least 2 arguments +if [ "$#" -lt 2 ]; then + echo "Usage: $0 " + exit 1 +fi +MODEL_NAME=$1 +BEAKER_DATASET_ID=$2 +uv run python scripts/submit_eval_jobs.py \ + --model_name "$MODEL_NAME" \ + --location "$BEAKER_DATASET_ID" \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale ai2/ceres-cirrascale ai2/neptune-cirrascale \ + --beaker_image oe-eval-beaker/oe_eval_auto_finbarr \ + --is_tuned \ + --preemptible \ + --priority urgent \ + --workspace ai2/open-instruct-dev \ + --use_hf_tokenizer_template \ + --run_oe_eval_experiments \ + --oe_eval_tasks aime:zs_cot_r1::maj_at_32_2025 \ + --evaluate_on_weka \ + --run_id placeholder \ + --oe_eval_max_length 8192 \ + --process_output r1_style \ + --skip_oi_evals diff --git a/scripts/eval/oe-eval.sh b/scripts/eval/oe-eval.sh index d4f2fd33a5..acb137dade 100755 --- a/scripts/eval/oe-eval.sh +++ b/scripts/eval/oe-eval.sh @@ -288,7 +288,7 @@ GPU_COUNT_OTHER=$((NUM_GPUS * 2)) MODEL_TYPE_OTHER="" # Build model args JSON with optional process_output -MODEL_ARGS="{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": \"true\"" +MODEL_ARGS="{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": true" if [[ -n "$PROCESS_OUTPUT" ]]; then MODEL_ARGS+=", \"process_output\": \"${PROCESS_OUTPUT}\"" fi @@ -319,7 +319,7 @@ for TASK in "${TASKS[@]}"; do # HF model: check if it's supported if [[ " ${SUPPORTED_MODELS[*]} " =~ " ${MODEL_LOCATION} " ]]; then # Supported HF model: remove model_path, no metadata needed - BASE_ARGS="{\"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": \"true\"" + BASE_ARGS="{\"max_length\": ${MAX_LENGTH}, \"trust_remote_code\": true" if [[ -n "$PROCESS_OUTPUT" ]]; then BASE_ARGS+=", \"process_output\": \"${PROCESS_OUTPUT}\"" fi diff --git a/scripts/submit_eval_jobs_new.py b/scripts/submit_eval_jobs_new.py new file mode 100644 index 0000000000..a5d02c118b --- /dev/null +++ b/scripts/submit_eval_jobs_new.py @@ -0,0 +1,230 @@ +"""Submit evaluation jobs using allenai/olmo-eval-internal. + +Replicates the Beaker-dataset flow of `scripts/submit_eval_jobs.py` but uses +`olmo-eval run` (from olmo-eval-internal, baked into the Beaker image) as the +in-container command. The model is mounted at `/model` when `--location` refers +to a Beaker dataset. + +The in-container command is: + + olmo-eval run -m --harness default \ + -o provider.kind=vllm_server \ + -o provider.max_model_len= \ + -o provider.trust_remote_code=true \ + -t [-t ...] \ + --num-gpus \ + --output-dir /results + +This submits directly via `beaker experiment create`, avoiding gantry (and +therefore any git/ref requirements). The wrapper writes a v2 experiment spec +YAML to `configs/beaker_configs/auto_created/` and invokes the Beaker CLI. + +Example: + uv run python scripts/submit_eval_jobs_new.py \ + --model_name qwen3_4b_base_dapo_20260422_083224 \ + --location 01KPTSPMHGEZVYCDNR0XBVJCGZ \ + --tasks aime_2025:pass_at_32 \ + --max_length 8192 \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --priority urgent \ + --preemptible \ + --workspace ai2/open-instruct-dev +""" + +import argparse +import os +import re +import shlex +import subprocess +from datetime import date + +import yaml + + +BEAKER_ID_RE = re.compile(r"^[0-9A-Z]{26}$") +DEFAULT_CLUSTERS = ( + "ai2/jupiter-cirrascale-2", + "ai2/saturn-cirrascale", + "ai2/ceres-cirrascale", + "ai2/neptune-cirrascale", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_name", type=str, required=True, help="Human-readable run name.") + parser.add_argument( + "--location", + type=str, + required=True, + help=( + "Model location. Accepts: a bare Beaker dataset id (26 uppercase alphanumerics), " + "'beaker://', an HF repo id (e.g. allenai/OLMo-2-1124-7B-Instruct), " + "an absolute Weka/NFS path, or a gs:// URL." + ), + ) + parser.add_argument( + "--tasks", + type=str, + default="aime_2025:pass_at_32", + help="Comma-separated olmo-eval task specs. See `olmo-eval tasks`/`olmo-eval suites`.", + ) + parser.add_argument("--num_gpus", type=int, default=1) + parser.add_argument("--cluster", nargs="+", default=list(DEFAULT_CLUSTERS)) + parser.add_argument("--priority", type=str, default="normal") + parser.add_argument("--preemptible", action="store_true") + parser.add_argument("--workspace", type=str, default="ai2/tulu-3-results") + parser.add_argument("--budget", type=str, default="ai2/oe-adapt") + parser.add_argument( + "--beaker_image", + type=str, + default="ai2-tylerm/olmo-eval-cu1281-trc290-amd64", + help="Beaker image with olmo-eval installed.", + ) + parser.add_argument("--revision", type=str, default=None, help="HF revision (git sha/tag).") + parser.add_argument( + "--max_length", + type=int, + default=32768, + help="Provider max_model_len. Sampling max_tokens comes from the task definition.", + ) + parser.add_argument( + "--sampling_max_tokens", + type=int, + default=None, + help="Override per-task sampling max_tokens (applied via -o max_tokens=N after each -t).", + ) + parser.add_argument("--experiment_name", type=str, default=None) + parser.add_argument( + "--dry_run", + action="store_true", + help="Write the spec YAML and print the beaker command, but do not submit.", + ) + return parser.parse_args() + + +def resolve_model_mount(location: str) -> tuple[str, str | None]: + """Resolve --location into (model_path_in_container, beaker_dataset_id_or_None).""" + if location.startswith("beaker://"): + return "/model", location[len("beaker://") :] + if BEAKER_ID_RE.match(location): + return "/model", location + return location, None + + +def build_inner_cmd(args: argparse.Namespace, model_path: str) -> list[str]: + cmd = [ + "olmo-eval", + "run", + "-m", + model_path, + "--harness", + "default", + "-o", + "provider.kind=vllm_server", + "-o", + f"provider.max_model_len={args.max_length}", + "-o", + "provider.trust_remote_code=true", + ] + if args.revision: + cmd += ["-o", f"provider.revision={args.revision}"] + for task in args.tasks.split(","): + task = task.strip() + if not task: + continue + cmd += ["-t", task] + if args.sampling_max_tokens is not None: + cmd += ["-o", f"max_tokens={args.sampling_max_tokens}"] + cmd += ["--num-gpus", str(args.num_gpus)] + cmd += ["--output-dir", "/results"] + return cmd + + +INSTALL_SCRIPT = ( + "set -euo pipefail && " + "git clone --depth=1 --branch finbarr/cli-sampling-override " + "https://x-access-token:${GITHUB_TOKEN}@github.com/allenai/olmo-eval-internal.git " + "/opt/olmo-eval-internal && " + "cd /opt/olmo-eval-internal && " + "uv pip install --cache-dir /weka/oe-eval-default/olmo-eval-pypi-cache -e '.[vllm]' && " + "uv pip install --cache-dir /weka/oe-eval-default/olmo-eval-pypi-cache " + "--upgrade 'vllm[runai]>=0.19.0' 'transformers>=5.4.0' && " + "cd /workspace" +) + + +def build_spec( + args: argparse.Namespace, + inner_cmd: list[str], + dataset_id: str | None, + experiment_name: str, +) -> dict: + datasets: list[dict] = [ + {"mountPath": "/weka/oe-adapt-default", "source": {"weka": "oe-adapt-default"}}, + {"mountPath": "/weka/oe-training-default", "source": {"weka": "oe-training-default"}}, + {"mountPath": "/weka/oe-eval-default", "source": {"weka": "oe-eval-default"}}, + ] + if dataset_id: + datasets.append({"mountPath": "/model", "source": {"beaker": dataset_id}}) + + full_command = f"{INSTALL_SCRIPT} && {shlex.join(inner_cmd)}" + + return { + "version": "v2", + "description": experiment_name, + "budget": args.budget, + "retry": {"allowedTaskRetries": 2}, + "tasks": [ + { + "name": experiment_name, + "image": {"beaker": args.beaker_image}, + "command": ["/bin/bash", "-c"], + "arguments": [full_command], + "envVars": [ + {"name": "HF_TOKEN", "secret": "HF_TOKEN"}, + {"name": "OPENAI_API_KEY", "secret": "openai_api_key"}, + {"name": "GITHUB_TOKEN", "secret": "GITHUB_TOKEN"}, + {"name": "VLLM_ALLOW_LONG_MAX_MODEL_LEN", "value": "1"}, + ], + "datasets": datasets, + "result": {"path": "/results"}, + "resources": {"gpuCount": args.num_gpus}, + "constraints": {"cluster": list(args.cluster)}, + "context": {"priority": args.priority, "preemptible": args.preemptible}, + } + ], + } + + +def main() -> None: + args = parse_args() + + if len(args.workspace.split("/")) != 2 or not all(args.workspace.split("/")): + raise ValueError(f"--workspace must be '/'. Received: '{args.workspace}'") + + model_path, dataset_id = resolve_model_mount(args.location) + inner_cmd = build_inner_cmd(args, model_path) + + today = date.today().strftime("%m%d%Y") + experiment_name = (args.experiment_name or f"olmo_eval_{args.model_name}_{today}")[:128] + spec = build_spec(args, inner_cmd, dataset_id, experiment_name) + + out_dir = "configs/beaker_configs/auto_created" + os.makedirs(out_dir, exist_ok=True) + spec_path = os.path.join(out_dir, f"{experiment_name}.yaml") + with open(spec_path, "w") as f: + yaml.dump(spec, f, default_flow_style=False, sort_keys=False) + + print("Inner command:", shlex.join(inner_cmd)) + print("Spec written to:", spec_path) + + beaker_cmd = ["beaker", "experiment", "create", spec_path, "--workspace", args.workspace] + print("Running:", shlex.join(beaker_cmd)) + if args.dry_run: + return + subprocess.run(beaker_cmd, check=True) + + +if __name__ == "__main__": + main() From 9fa3062309ba4ee22f0a104da813f46b51708033 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 24 Apr 2026 15:18:26 -0600 Subject: [PATCH 16/17] updated code --- scripts/submit_eval_jobs_new.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/scripts/submit_eval_jobs_new.py b/scripts/submit_eval_jobs_new.py index a5d02c118b..2412f2c4ac 100644 --- a/scripts/submit_eval_jobs_new.py +++ b/scripts/submit_eval_jobs_new.py @@ -42,12 +42,7 @@ BEAKER_ID_RE = re.compile(r"^[0-9A-Z]{26}$") -DEFAULT_CLUSTERS = ( - "ai2/jupiter-cirrascale-2", - "ai2/saturn-cirrascale", - "ai2/ceres-cirrascale", - "ai2/neptune-cirrascale", -) +DEFAULT_CLUSTERS = ("ai2/jupiter",) def parse_args() -> argparse.Namespace: @@ -96,9 +91,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--experiment_name", type=str, default=None) parser.add_argument( - "--dry_run", - action="store_true", - help="Write the spec YAML and print the beaker command, but do not submit.", + "--dry_run", action="store_true", help="Write the spec YAML and print the beaker command, but do not submit." ) return parser.parse_args() @@ -143,7 +136,7 @@ def build_inner_cmd(args: argparse.Namespace, model_path: str) -> list[str]: INSTALL_SCRIPT = ( "set -euo pipefail && " - "git clone --depth=1 --branch finbarr/cli-sampling-override " + "git clone --depth=1" "https://x-access-token:${GITHUB_TOKEN}@github.com/allenai/olmo-eval-internal.git " "/opt/olmo-eval-internal && " "cd /opt/olmo-eval-internal && " @@ -154,12 +147,7 @@ def build_inner_cmd(args: argparse.Namespace, model_path: str) -> list[str]: ) -def build_spec( - args: argparse.Namespace, - inner_cmd: list[str], - dataset_id: str | None, - experiment_name: str, -) -> dict: +def build_spec(args: argparse.Namespace, inner_cmd: list[str], dataset_id: str | None, experiment_name: str) -> dict: datasets: list[dict] = [ {"mountPath": "/weka/oe-adapt-default", "source": {"weka": "oe-adapt-default"}}, {"mountPath": "/weka/oe-training-default", "source": {"weka": "oe-training-default"}}, From 0b02490447c597cf036c69ce97530c756d81bb59 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Sat, 25 Apr 2026 08:15:44 -0600 Subject: [PATCH 17/17] cleaned up script --- scripts/submit_eval_jobs_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/submit_eval_jobs_new.py b/scripts/submit_eval_jobs_new.py index 2412f2c4ac..3553cb8ac8 100644 --- a/scripts/submit_eval_jobs_new.py +++ b/scripts/submit_eval_jobs_new.py @@ -136,7 +136,7 @@ def build_inner_cmd(args: argparse.Namespace, model_path: str) -> list[str]: INSTALL_SCRIPT = ( "set -euo pipefail && " - "git clone --depth=1" + "git clone --depth=1 " "https://x-access-token:${GITHUB_TOKEN}@github.com/allenai/olmo-eval-internal.git " "/opt/olmo-eval-internal && " "cd /opt/olmo-eval-internal && "