From 674058096d845e6bbe07e244e6b5caecccd7f4c4 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 09:13:47 -0700 Subject: [PATCH 01/40] WIP: difficulty mapper + persist sample level rewards --- open_instruct/benchmark_generators.py | 72 +- open_instruct/rl_utils.py | 46 +- open_instruct/test_rl_utils.py | 59 + .../create_bucketed_difficulty.py | 1080 +++++++++++++++++ .../qwen3_4b_dapo_math_gen.sh | 41 + tests/test_create_bucketed_difficulty.py | 389 ++++++ 6 files changed, 1685 insertions(+), 2 deletions(-) create mode 100644 scripts/data/difficulty_sampling/create_bucketed_difficulty.py create mode 100644 scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh create mode 100644 tests/test_create_bucketed_difficulty.py diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 9c95f372ec..73f6622cee 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -29,8 +29,9 @@ from open_instruct import data_loader, dataset_transformation, grpo_utils, logger_utils, model_utils, utils, vllm_utils from open_instruct.actor_manager import ActorManager -from open_instruct.data_types import PromptRequest +from open_instruct.data_types import GenerationResult, PromptRequest from open_instruct.ground_truth_utils import RewardConfig, build_all_verifiers +from open_instruct.rl_utils import build_rollout_batch_and_advantages, save_rollout_metadata, save_rollouts_to_disk logger = logger_utils.setup_logger(__name__) @@ -148,6 +149,52 @@ def save_benchmark_results_to_csv( logger.info(f"Saved benchmark results to {csv_path}") +def resolve_run_name(args: grpo_utils.GRPOExperimentConfig, timestamp: int) -> str: + """Resolve a stable run name for optional rollout trace persistence.""" + return args.run_name or f"{args.exp_name}__{timestamp}" + + +def maybe_save_scored_rollout_traces( + batch_results: list[GenerationResult], + dataset: datasets.Dataset, + streaming_config: 
data_loader.StreamingDataLoaderConfig, + *, + run_name: str, + step: int, + total_samples_written: int, +) -> int: + """Persist raw per-sample reward traces when explicitly requested.""" + if not streaming_config.save_traces: + return total_samples_written + + for result in batch_results: + if result.index is None: + raise ValueError("Cannot save scored rollout traces because the result is missing its dataset index.") + + example = dataset[result.index] + batch, advantages = build_rollout_batch_and_advantages( + result, + prompt_tokens=list(example[dataset_transformation.INPUT_IDS_PROMPT_KEY]), + ground_truth=example[dataset_transformation.GROUND_TRUTHS_KEY], + dataset_name=example[dataset_transformation.VERIFIER_SOURCE_KEY], + raw_query=example[dataset_transformation.RAW_PROMPT_KEY], + advantage_normalization_type=streaming_config.advantage_normalization_type, + ) + save_rollouts_to_disk( + streaming_config.rollouts_save_path, + run_name, + step, + batch, + result, + advantages, + len(result.responses), + total_samples_written, + ) + total_samples_written += len(result.responses) + + return total_samples_written + + def free_all_gpu_memory(device: int | str = 0) -> None: """ Aggressively free GPU memory used by PyTorch. 
@@ -360,6 +407,7 @@ def run_benchmark( streaming_config: data_loader.StreamingDataLoaderConfig, vllm_config: data_loader.VLLMConfig, model_config: model_utils.ModelConfig, + run_name: str, timestamp: int, num_batches: int = 5, ) -> list[dict[str, Any]]: @@ -384,6 +432,7 @@ def run_benchmark( executor = futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="benchmark") results = [] + total_samples_written = 0 # Get the model dimensions from one of the engines without loading weights model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) @@ -433,6 +482,14 @@ def run_benchmark( logger.info( f"Warmup batch completed with {total_warmup_responses} total responses from {len(warmup_results)} prompts" ) + total_samples_written = maybe_save_scored_rollout_traces( + warmup_results, + dataset, + streaming_config, + run_name=run_name, + step=0, + total_samples_written=total_samples_written, + ) logger.info(f"Submitting {num_batches - 1} batches for main benchmark...") submission_future = executor.submit( submission_thread, @@ -462,6 +519,15 @@ def run_benchmark( if time.time() > batch_deadline: raise TimeoutError(f"Batch timed out, got {len(batch_results)}/{num_prompts}") from None + total_samples_written = maybe_save_scored_rollout_traces( + batch_results, + dataset, + streaming_config, + run_name=run_name, + step=batch_idx, + total_samples_written=total_samples_written, + ) + # Simulate weight sync between batches weight_sync_time = simulate_weight_sync(actor_manager, vllm_engines, args) completion_time = time.perf_counter() @@ -720,7 +786,10 @@ def main() -> None: # Create the timestamp here so we use it for both filenames. 
timestamp = int(time.time()) + args.run_name = resolve_run_name(args, timestamp) save_config(args, tokenizer_config, model_config, streaming_config, timestamp) + if streaming_config.save_traces: + save_rollout_metadata(streaming_config.rollouts_save_path, args.run_name, model_config.model_name_or_path) run_benchmark( dataset, vllm_engines, @@ -731,6 +800,7 @@ def main() -> None: streaming_config, vllm_config, model_config, + args.run_name, timestamp, ) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index c7161b3320..9c00e13da9 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -5,7 +5,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass, field -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import numpy as np import torch @@ -159,6 +159,50 @@ def save_rollouts_to_disk( ) +def build_rollout_batch_and_advantages( + result: data_types.GenerationResult, + *, + prompt_tokens: list[int], + ground_truth: Any, + dataset_name: str, + raw_query: str, + advantage_normalization_type: str, +) -> tuple[model_utils.Batch, np.ndarray]: + """Convert a scored inference result into the rollout format used by difficulty bucketing.""" + if result.reward_scores is None: + raise ValueError("Cannot save scored rollout traces because reward_scores is missing from GenerationResult.") + + num_samples = len(result.responses) + if len(result.reward_scores) != num_samples: + raise ValueError( + "Cannot save scored rollout traces because reward_scores length does not match the number of responses." 
+ ) + + scores = [float(score) for score in result.reward_scores] + indices = [result.index] * num_samples if result.index is not None else None + batch = model_utils.Batch( + queries=[list(prompt_tokens)] * num_samples, + ground_truths=[ground_truth] * num_samples, + datasets=[dataset_name] * num_samples, + raw_queries=[raw_query] * num_samples, + decoded_responses=None, + indices=indices, + scores=scores, + model_steps=[result.model_step] * num_samples, + ) + + score_array = np.asarray(scores, dtype=float) + mean_score = score_array.mean() + if advantage_normalization_type == "standard": + advantages = (score_array - mean_score) / (score_array.std() + 1e-8) + elif advantage_normalization_type == "centered": + advantages = score_array - mean_score + else: + raise ValueError(f"Invalid advantage normalization type: {advantage_normalization_type}") + + return batch, advantages + + @dataclass class Timer(contextlib.ContextDecorator): """A context manager and decorator for timing code blocks""" diff --git a/open_instruct/test_rl_utils.py b/open_instruct/test_rl_utils.py index 2303a98b45..9e3fe92799 100644 --- a/open_instruct/test_rl_utils.py +++ b/open_instruct/test_rl_utils.py @@ -9,6 +9,7 @@ from parameterized import parameterized from open_instruct import rl_utils +from open_instruct.data_types import GenerationResult, RequestInfo PACK_LENGTH = 40 PROMPT_MAX_LEN = 20 @@ -337,6 +338,64 @@ def test_pack_sequences_min_num_batches(self): self.assertGreater(len(seq), 0) +class TestRolloutTraceSaving(unittest.TestCase): + def _make_generation_result(self, reward_scores: list[float]) -> GenerationResult: + num_samples = len(reward_scores) + return GenerationResult( + responses=[[10, sample_idx] for sample_idx in range(num_samples)], + finish_reasons=["stop"] * num_samples, + masks=[[1, 1]] * num_samples, + request_info=RequestInfo( + num_calls=[0] * num_samples, + timeouts=[0] * num_samples, + tool_errors=[""] * num_samples, + tool_outputs=[""] * num_samples, + 
tool_runtimes=[0.0] * num_samples, + tool_calleds=[False] * num_samples, + tool_call_stats=[[] for _ in range(num_samples)], + rollout_states=[{} for _ in range(num_samples)], + ), + index=3, + prompt_id="prompt_3", + reward_scores=reward_scores, + logprobs=[[0.1, 0.2]] * num_samples, + model_step=7, + ) + + def test_build_rollout_batch_and_advantages_preserves_scores(self): + result = self._make_generation_result([10.0, 0.0]) + + batch, advantages = rl_utils.build_rollout_batch_and_advantages( + result, + prompt_tokens=[1, 2, 3], + ground_truth="4", + dataset_name="math", + raw_query="user: solve 2+2", + advantage_normalization_type="centered", + ) + + self.assertEqual(batch.queries, [[1, 2, 3], [1, 2, 3]]) + self.assertEqual(batch.ground_truths, ["4", "4"]) + self.assertEqual(batch.datasets, ["math", "math"]) + self.assertEqual(batch.indices, [3, 3]) + self.assertEqual(batch.scores, [10.0, 0.0]) + np.testing.assert_allclose(advantages, np.array([5.0, -5.0])) + + def test_build_rollout_batch_and_advantages_raises_without_scores(self): + result = self._make_generation_result([1.0, 0.0]) + result.reward_scores = None + + with self.assertRaises(ValueError): + rl_utils.build_rollout_batch_and_advantages( + result, + prompt_tokens=[1, 2, 3], + ground_truth="4", + dataset_name="math", + raw_query="user: solve 2+2", + advantage_normalization_type="centered", + ) + + class TestMaskedMean(unittest.TestCase): def test_original_axis_int(self): values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py new file mode 100644 index 0000000000..3e12db0499 --- /dev/null +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -0,0 +1,1080 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = "==3.12.*" +# dependencies = [ +# "datasets>=4.0.0", +# "numpy>=2", +# "scipy>=1.14.0", +# ] +# /// + +""" +Build a per-instance 
difficulty map from open-instruct rollout traces. + +The script accepts one or more local rollout directories, metadata ``.jsonl`` +files, or rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``. +For each traced prompt instance it: + +1. loads rollout shards written by ``save_rollouts_to_disk()``, +2. groups attempts by a deterministic fingerprint over task name, prompt tokens, + and ground truth, +3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` + when possible, +4. fits a Beta prior across binary outcomes and estimates per-item success + rates, and +5. writes a JSONL difficulty file and schema/metadata sidecars. + +Examples: + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --source /tmp/qwen_math_rollouts \ + --task math \ + --output /tmp/qwen_math_difficulty + + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --source /tmp/qwen_math_rollouts/qwen_math_metadata.jsonl \ + --output /tmp/difficulty_map +""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import sys +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from datasets import Dataset +from scipy.optimize import minimize +from scipy.special import betaln +from scipy.stats import beta as beta_distribution + +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from open_instruct import logger_utils # noqa: E402 + +logger = logger_utils.setup_logger(__name__) + + +EPS = 1e-8 +EXPERIMENT_METADATA_KEYS = ("source_root", "model_name", "experiment_id", "experiment_name") +JEFFREYS_PRIOR_ALPHA = 0.5 +JEFFREYS_PRIOR_BETA = 0.5 +DEFAULT_DIFFICULTY_BUCKETS = 5 +POSTERIOR_QUANTILE_GRID_SIZE = 512 +POSTERIOR_QUANTILE_BATCH_SIZE = 256 +DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" 
+DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} +PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} +SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" +INSTANCE_ID_DEFINITION = "sha1(task_name,prompt_tokens,ground_truth)" + + +@dataclass(frozen=True) +class BetaPrior: + alpha: float + beta: float + source: str + + +@dataclass(frozen=True) +class RolloutSource: + input_arg: str + root_path: Path + metadata_path: Path + rollout_paths: tuple[Path, ...] + run_name: str + + +@dataclass(frozen=True) +class DifficultyPosteriorRow: + row: dict[str, Any] + difficulty_alpha: float + difficulty_beta: float + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Build a per-instance difficulty map from open-instruct rollout traces.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source", + nargs="+", + required=True, + help="One or more local rollout dirs, *_metadata.jsonl files, or *_rollouts_*.jsonl shards.", + ) + parser.add_argument( + "--task", + action="append", + default=[], + help="Optional task filter. Matches the rollout trace dataset/verifier source.", + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help=( + "Output directory or path-like root. The script writes one file per task/model inside it as " + "____.jsonl plus matching .schema.json and .metadata.json sidecars." + ), + ) + parser.add_argument( + "--push-to-hub", type=str, default=None, help="Optional dataset repo id to push the validated rows to." + ) + parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.") + parser.add_argument( + "--strict", action="store_true", help="Fail if a rollout record is malformed or required files are missing." 
+ ) + parser.add_argument( + "--allow-nonunit-scores", + action="store_true", + help="Keep rows whose rewards cannot be normalized to binary correctness. Difficulty will be null for them.", + ) + parser.add_argument( + "--max-instances", + type=int, + default=None, + help="Optional cap for the number of resolved instances written (useful for smoke tests).", + ) + parser.add_argument( + "--beta-prior", + choices=["empirical-bayes", "jeffreys"], + default="empirical-bayes", + help="Global Beta prior to use for smoothing binary solve rates.", + ) + parser.add_argument( + "--posterior-lower-quantile", + type=float, + default=0.1, + help="Lower posterior quantile used to define difficulty as 1 - quantile.", + ) + parser.add_argument( + "--difficulty-buckets", + type=int, + default=DEFAULT_DIFFICULTY_BUCKETS, + help=( + "Number of posterior-aware quantile buckets to assign for stratification. " + "Set to 0 to skip discrete bucket assignment." + ), + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + args = make_parser().parse_args(argv) + validate_args(args) + task_filters = set(args.task) + output_root = resolve_output_root(args.output) + + source_runs = discover_rollout_sources(args.source) + if not source_runs: + raise ValueError("No rollout trace sources were found.") + + contributions: list[dict[str, Any]] = [] + malformed_records = 0 + + for source_run in source_runs: + logger.info( + "Loading %s (run=%s, metadata=%s, shards=%s)", + source_run.input_arg, + source_run.run_name, + source_run.metadata_path, + len(source_run.rollout_paths), + ) + run_contributions, run_malformed = build_contributions_for_source( + source_run=source_run, task_filters=task_filters, strict=args.strict + ) + contributions.extend(run_contributions) + malformed_records += run_malformed + + if not contributions: + raise ValueError("No resolved per-instance rows were produced.") + + rows = aggregate_contributions(contributions) + rows = sorted( + rows, + key=lambda row: 
( + stable_string(row.get("task_name")), + stable_string((row.get("experiment_metadata") or {}).get("model_name")), + stable_string(row.get("instance_id")), + ), + ) + if args.max_instances is not None: + rows = rows[: args.max_instances] + + rows_by_group = group_rows_by_task_and_model(rows) + if args.push_to_hub is not None and len(rows_by_group) != 1: + raise ValueError( + "--push-to-hub requires a single task/model output. Filter with --task or use a source with one task." + ) + + skipped_nonunit = 0 + written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] + + for (task_name, model_name), group_rows in sorted( + rows_by_group.items(), key=lambda item: (item[0][0], stable_string(item[0][1])) + ): + group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( + group_rows, allow_nonunit_scores=args.allow_nonunit_scores + ) + skipped_nonunit += group_skipped_nonunit + + if not group_rows: + logger.warning( + "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name + ) + continue + + prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) + group_rows = apply_beta_binomial_difficulty( + group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets + ) + group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) + + dataset_metadata = build_dataset_metadata( + rows=group_rows, + task_name=task_name, + model_name=model_name, + requested_prior_mode=args.beta_prior, + requested_bucket_count=args.difficulty_buckets, + lower_quantile=args.posterior_lower_quantile, + prior=prior, + binary_row_count=binary_row_count, + score_processing=score_processing, + ) + + if prior is not None: + logger.info( + "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", + prior.source, + prior.alpha, + prior.beta, + binary_row_count, + task_name, + model_name, + ) + else: + 
logger.warning( + "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", + task_name, + model_name, + ) + + dataset = Dataset.from_list(group_rows) + annotate_dataset_metadata(dataset, dataset_metadata) + output_jsonl, schema_json, metadata_json = build_output_paths( + output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata + ) + write_output_files( + output_jsonl=output_jsonl, + schema_json=schema_json, + metadata_json=metadata_json, + rows=group_rows, + dataset=dataset, + dataset_metadata=dataset_metadata, + ) + + if args.push_to_hub is not None: + dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) + + written_outputs.append((task_name, model_name, len(group_rows), output_jsonl, schema_json, metadata_json)) + logger.info( + "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", + len(group_rows), + task_name, + model_name, + output_jsonl, + schema_json, + metadata_json, + ) + + logger.info( + "Finished writing %s output file groups (%s malformed rollout records, %s skipped due to unsupported scores).", + len(written_outputs), + malformed_records, + skipped_nonunit, + ) + + +def discover_rollout_sources(sources: list[str]) -> list[RolloutSource]: + discovered: dict[Path, RolloutSource] = {} + + for source in sources: + source_path = Path(source) + if not source_path.exists(): + raise FileNotFoundError(f"Could not find source path {source}") + + if source_path.is_dir(): + metadata_paths = sorted(source_path.rglob("*_metadata.jsonl")) + if not metadata_paths: + raise FileNotFoundError(f"Could not find *_metadata.jsonl under {source}") + for metadata_path in metadata_paths: + rollout_source = build_rollout_source_from_metadata(metadata_path, input_arg=source) + discovered[rollout_source.metadata_path] = rollout_source + continue + + if source_path.name.endswith("_metadata.jsonl"): + rollout_source = build_rollout_source_from_metadata(source_path, input_arg=source) + 
discovered[rollout_source.metadata_path] = rollout_source + continue + + if source_path.suffix == ".jsonl" and "_rollouts_" in source_path.name: + rollout_source = build_rollout_source_from_rollout(source_path, input_arg=source) + discovered[rollout_source.metadata_path] = rollout_source + continue + + raise ValueError( + f"Unsupported source path {source}. Expected a directory, *_metadata.jsonl, or *_rollouts_*.jsonl." + ) + + return sorted(discovered.values(), key=lambda source_run: (str(source_run.root_path), source_run.run_name)) + + +def build_rollout_source_from_metadata(metadata_path: Path, *, input_arg: str) -> RolloutSource: + run_name = parse_run_name_from_metadata_path(metadata_path) + rollout_paths = tuple(sorted(metadata_path.parent.glob(f"{run_name}_rollouts_*.jsonl"))) + if not rollout_paths: + raise FileNotFoundError(f"Could not find rollout shards for run {run_name} next to {metadata_path}") + return RolloutSource( + input_arg=input_arg, + root_path=metadata_path.parent.resolve(), + metadata_path=metadata_path.resolve(), + rollout_paths=rollout_paths, + run_name=run_name, + ) + + +def build_rollout_source_from_rollout(rollout_path: Path, *, input_arg: str) -> RolloutSource: + run_name = parse_run_name_from_rollout_path(rollout_path) + metadata_path = rollout_path.parent / f"{run_name}_metadata.jsonl" + if not metadata_path.exists(): + raise FileNotFoundError(f"Could not find metadata file {metadata_path} for rollout shard {rollout_path}") + return build_rollout_source_from_metadata(metadata_path, input_arg=input_arg) + + +def parse_run_name_from_metadata_path(metadata_path: Path) -> str: + suffix = "_metadata.jsonl" + if not metadata_path.name.endswith(suffix): + raise ValueError(f"Metadata path must end with {suffix}: {metadata_path}") + return metadata_path.name[: -len(suffix)] + + +def parse_run_name_from_rollout_path(rollout_path: Path) -> str: + marker = "_rollouts_" + if marker not in rollout_path.name: + raise ValueError(f"Rollout shard 
filename must contain {marker}: {rollout_path}") + return rollout_path.name.split(marker, 1)[0] + + +def build_contributions_for_source( + *, source_run: RolloutSource, task_filters: set[str], strict: bool +) -> tuple[list[dict[str, Any]], int]: + run_metadata = read_rollout_metadata(source_run.metadata_path, fallback_run_name=source_run.run_name) + contributions: list[dict[str, Any]] = [] + malformed_records = 0 + + for rollout_path in source_run.rollout_paths: + for line_number, record in enumerate(read_jsonl(rollout_path), start=1): + try: + contribution = build_rollout_contribution( + record=record, source_run=source_run, run_metadata=run_metadata + ) + except Exception as exc: + malformed_records += 1 + message = f"Malformed rollout record in {rollout_path}:{line_number}: {exc}" + if strict: + raise ValueError(message) from exc + logger.warning(message) + continue + + task_name = stable_string(contribution.get("task_name")) + if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: + continue + contributions.append(contribution) + + return contributions, malformed_records + + +def read_rollout_metadata(metadata_path: Path, *, fallback_run_name: str) -> dict[str, Any]: + rows = read_jsonl(metadata_path) + if not rows: + raise ValueError(f"Metadata file is empty: {metadata_path}") + if len(rows) > 1: + logger.warning("Expected one metadata row in %s but found %s. 
Using the first row.", metadata_path, len(rows)) + + metadata = rows[0] + return { + "run_name": optional_string(metadata.get("run_name")) or fallback_run_name, + "model_name": optional_string(metadata.get("model_name")), + "git_commit": optional_string(metadata.get("git_commit")), + "timestamp": optional_string(metadata.get("timestamp")), + } + + +def build_rollout_contribution( + *, record: dict[str, Any], source_run: RolloutSource, run_metadata: dict[str, Any] +) -> dict[str, Any]: + task_name = normalize_task_name(record.get("dataset")) + if task_name is None: + raise ValueError("missing dataset/verifier source") + + prompt_tokens = normalize_token_list(record.get("prompt_tokens")) + if prompt_tokens is None: + raise ValueError("missing or invalid prompt_tokens") + + reward = extract_numeric_reward(record.get("reward")) + if reward is None: + raise ValueError("missing or invalid reward") + + ground_truth = make_jsonable(record.get("ground_truth")) + finish_reason = optional_string(record.get("finish_reason")) + + return { + "instance_id": make_rollout_instance_id( + task_name=task_name, prompt_tokens=prompt_tokens, ground_truth=ground_truth + ), + "task_name": task_name, + "base_task_name": get_base_task_name(task_name), + "prompt_tokens": prompt_tokens, + "ground_truth": ground_truth, + "score_source": task_name, + "attempt_scores": [reward], + "finish_reasons": [finish_reason] if finish_reason else [], + "experiment_metadata": { + "source_root": str(source_run.root_path), + "model_name": run_metadata["model_name"], + "experiment_id": None, + "experiment_name": run_metadata["run_name"], + }, + "warnings": extract_rollout_warnings(record.get("request_info")), + } + + +def normalize_task_name(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + serialized = serialize_value(value) + return serialized or None + + +def normalize_token_list(value: Any) -> list[int] | None: + if not isinstance(value, list): + 
return None + + tokens: list[int] = [] + for item in value: + if isinstance(item, bool) or not isinstance(item, (int, float)): + return None + tokens.append(int(item)) + return tokens + + +def extract_numeric_reward(value: Any) -> float | None: + if not is_number(value): + return None + return float(value) + + +def extract_rollout_warnings(request_info: Any) -> list[str]: + if not isinstance(request_info, dict): + return [] + + warnings: list[str] = [] + if request_info.get("timeouts"): + warnings.append("timeout") + if optional_string(request_info.get("tool_errors")): + warnings.append("tool_error") + return warnings + + +def aggregate_contributions(contributions: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[str, dict[str, Any]] = {} + + for contribution in contributions: + instance_id = contribution["instance_id"] + if instance_id not in grouped: + grouped[instance_id] = { + key: value + for key, value in contribution.items() + if key not in {"attempt_scores", "finish_reasons", "experiment_metadata", "warnings", "score_source"} + } + grouped[instance_id]["attempt_scores"] = [] + grouped[instance_id]["finish_reasons"] = [] + grouped[instance_id]["experiment_metadata"] = None + grouped[instance_id]["score_sources"] = set() + grouped[instance_id]["warnings"] = set() + + row = grouped[instance_id] + row["attempt_scores"].extend(float(score) for score in contribution["attempt_scores"]) + row["finish_reasons"].extend(contribution["finish_reasons"]) + row["experiment_metadata"] = merge_experiment_metadata( + existing=row["experiment_metadata"], incoming=contribution["experiment_metadata"], instance_id=instance_id + ) + row["score_sources"].add(stable_string(contribution["score_source"])) + row["warnings"].update(contribution["warnings"]) + + rows: list[dict[str, Any]] = [] + for row in grouped.values(): + row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] + row["finish_reasons"] = [stable_string(reason) for reason in 
row["finish_reasons"] if stable_string(reason)] + row["experiment_metadata"] = normalize_experiment_metadata(row["experiment_metadata"]) + row["score_sources"] = sorted(value for value in row["score_sources"] if value) + row["warnings"] = sorted(value for value in row["warnings"] if value) + rows.append(row) + + return rows + + +def normalize_attempt_scores_for_group( + rows: list[dict[str, Any]], *, allow_nonunit_scores: bool +) -> tuple[list[dict[str, Any]], dict[str, Any], int]: + score_processing = infer_score_processing(rows) + normalized_rows: list[dict[str, Any]] = [] + skipped_nonunit = 0 + + for row in rows: + normalized_scores = normalize_attempt_scores(row["attempt_scores"], score_processing) + if normalized_scores is None: + if allow_nonunit_scores: + kept_row = dict(row) + kept_row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] + kept_row["warnings"] = sorted({*kept_row["warnings"], "nonbinary_reward_scores"}) + normalized_rows.append(kept_row) + else: + skipped_nonunit += 1 + continue + + normalized_row = dict(row) + normalized_row["attempt_scores"] = normalized_scores + normalized_rows.append(normalized_row) + + return normalized_rows, score_processing, skipped_nonunit + + +def infer_score_processing(rows: list[dict[str, Any]]) -> dict[str, Any]: + scores = [float(score) for row in rows for score in row.get("attempt_scores", [])] + score_processing = { + "source_field": "reward", + "output_field": "attempt_scores", + "normalization": "unsupported", + "positive_reward_value": None, + "supports_binary_difficulty": False, + } + + if not scores: + return score_processing + + if all(is_close(score, 0.0) or is_close(score, 1.0) for score in scores): + score_processing["normalization"] = "identity_binary" + score_processing["positive_reward_value"] = 1.0 + score_processing["supports_binary_difficulty"] = True + return score_processing + + if any(score < -EPS for score in scores): + return score_processing + + positive_scores = [score 
for score in scores if score > EPS] + if not positive_scores: + score_processing["normalization"] = "all_zero_binary" + score_processing["supports_binary_difficulty"] = True + return score_processing + + positive_reward_value = max(positive_scores) + if all(is_close(score, 0.0) or is_close(score, positive_reward_value) for score in scores): + score_processing["normalization"] = "binary_zero_or_constant" + score_processing["positive_reward_value"] = positive_reward_value + score_processing["supports_binary_difficulty"] = True + + return score_processing + + +def normalize_attempt_scores(attempt_scores: list[float], score_processing: dict[str, Any]) -> list[float] | None: + if not score_processing.get("supports_binary_difficulty"): + return None + + normalization = stable_string(score_processing.get("normalization")) + positive_reward_value = score_processing.get("positive_reward_value") + normalized_scores: list[float] = [] + + for score in attempt_scores: + if is_close(score, 0.0): + normalized_scores.append(0.0) + continue + + if normalization == "identity_binary" and is_close(score, 1.0): + normalized_scores.append(1.0) + continue + + if ( + normalization == "binary_zero_or_constant" + and positive_reward_value is not None + and is_close(score, float(positive_reward_value)) + ): + normalized_scores.append(1.0) + continue + + if normalization == "all_zero_binary": + return None + + return None + + return normalized_scores + + +def estimate_beta_prior(rows: list[dict[str, Any]], *, prior_mode: str) -> tuple[BetaPrior | None, int]: + binary_counts = [counts for row in rows if (counts := extract_binary_counts(row["attempt_scores"])) is not None] + if not binary_counts: + return None, 0 + + if prior_mode == "jeffreys": + return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys"), len(binary_counts) + + prior = fit_empirical_beta_prior(binary_counts) + if prior is not None: + return prior, len(binary_counts) + + logger.warning("Falling back to Jeffreys 
prior after empirical-Bayes fitting failed.") + return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys_fallback"), len(binary_counts) + + +def apply_beta_binomial_difficulty( + rows: list[dict[str, Any]], *, prior: BetaPrior | None, lower_quantile: float, num_buckets: int +) -> list[dict[str, Any]]: + posterior_rows: list[DifficultyPosteriorRow] = [] + + for row in rows: + row["difficulty"] = make_empty_difficulty_payload() + + if prior is None: + continue + + binary_counts = extract_binary_counts(row["attempt_scores"]) + if binary_counts is None: + continue + + success_count, attempt_count = binary_counts + posterior_alpha = success_count + prior.alpha + posterior_beta = attempt_count - success_count + prior.beta + posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta) + posterior_lower_bound = float(beta_distribution.ppf(lower_quantile, posterior_alpha, posterior_beta)) + + row["difficulty"] = { + "value": max(0.0, min(1.0, 1.0 - posterior_lower_bound)), + "posterior_mean": posterior_mean, + "posterior_lower_bound": posterior_lower_bound, + "expected_quantile": None, + "bucket_index": None, + "bucket_count": None, + } + posterior_rows.append( + DifficultyPosteriorRow(row=row, difficulty_alpha=posterior_beta, difficulty_beta=posterior_alpha) + ) + + assign_posterior_difficulty_buckets(posterior_rows, num_buckets=num_buckets) + return rows + + +def make_empty_difficulty_payload() -> dict[str, Any]: + return { + "value": None, + "posterior_mean": None, + "posterior_lower_bound": None, + "expected_quantile": None, + "bucket_index": None, + "bucket_count": None, + } + + +def assign_posterior_difficulty_buckets(posterior_rows: list[DifficultyPosteriorRow], *, num_buckets: int) -> None: + if not posterior_rows: + return + + expected_quantiles = estimate_expected_difficulty_quantiles(posterior_rows) + for posterior_row, expected_quantile in zip(posterior_rows, expected_quantiles, strict=True): + 
posterior_row.row["difficulty"]["expected_quantile"] = expected_quantile + + if num_buckets <= 0: + return + + effective_bucket_count = min(num_buckets, len(posterior_rows)) + ordered_rows = sorted( + zip(posterior_rows, expected_quantiles, strict=True), + key=lambda item: (item[1], item[0].row["difficulty"]["value"], stable_string(item[0].row["instance_id"])), + ) + base_bucket_size, remainder = divmod(len(ordered_rows), effective_bucket_count) + + cursor = 0 + for bucket_index in range(effective_bucket_count): + bucket_size = base_bucket_size + (1 if bucket_index < remainder else 0) + for posterior_row, _expected_quantile in ordered_rows[cursor : cursor + bucket_size]: + posterior_row.row["difficulty"]["bucket_index"] = bucket_index + posterior_row.row["difficulty"]["bucket_count"] = effective_bucket_count + cursor += bucket_size + + +def estimate_expected_difficulty_quantiles( + posterior_rows: list[DifficultyPosteriorRow], + *, + grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE, + batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE, +) -> list[float]: + if not posterior_rows: + return [] + if len(posterior_rows) == 1: + return [0.5] + + grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size + difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64) + difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64) + + mixture_cdf = np.zeros(grid_size, dtype=np.float64) + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_cdf = beta_distribution.cdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0) + mixture_cdf /= len(posterior_rows) + + quantiles = np.zeros(len(posterior_rows), dtype=np.float64) + dx = 1.0 / grid_size + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_pdf 
= beta_distribution.pdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + quantiles[start:stop] = np.clip( + np.nan_to_num(batch_pdf, nan=0.0, posinf=0.0, neginf=0.0).dot(mixture_cdf) * dx, 0.0, 1.0 + ) + + return quantiles.tolist() + + +def fit_empirical_beta_prior(binary_counts: list[tuple[int, int]]) -> BetaPrior | None: + total_successes = sum(success_count for success_count, _ in binary_counts) + total_attempts = sum(attempt_count for _, attempt_count in binary_counts) + if total_attempts == 0 or total_successes in {0, total_attempts}: + return None + + mean_rate = total_successes / total_attempts + init_alpha = max(mean_rate * 2.0, 1e-3) + init_beta = max((1.0 - mean_rate) * 2.0, 1e-3) + + def objective(log_params: tuple[float, float]) -> float: + alpha = math.exp(log_params[0]) + beta = math.exp(log_params[1]) + return -sum( + betaln(success_count + alpha, attempt_count - success_count + beta) - betaln(alpha, beta) + for success_count, attempt_count in binary_counts + ) + + result = minimize( + objective, + x0=(math.log(init_alpha), math.log(init_beta)), + method="L-BFGS-B", + bounds=[(-10.0, 10.0), (-10.0, 10.0)], + ) + if not result.success: + logger.warning("Empirical-Bayes fit failed: %s", result.message) + return None + + return BetaPrior(alpha=math.exp(result.x[0]), beta=math.exp(result.x[1]), source="empirical_bayes") + + +def merge_experiment_metadata( + existing: dict[str, Any] | None, incoming: dict[str, Any], *, instance_id: str +) -> dict[str, Any]: + normalized_incoming = normalize_experiment_metadata(incoming) + if existing is None: + return normalized_incoming + + merged = dict(existing) + for key in EXPERIMENT_METADATA_KEYS: + existing_value = merged.get(key) + incoming_value = normalized_incoming.get(key) + if existing_value in {None, ""}: + merged[key] = incoming_value + elif incoming_value in {None, ""} or incoming_value == existing_value: + continue + else: + raise ValueError( + f"Conflicting 
experiment metadata for instance {instance_id}: " + f"{key}={existing_value!r} vs {incoming_value!r}" + ) + return merged + + +def normalize_experiment_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]: + if metadata is None: + return {key: None for key in EXPERIMENT_METADATA_KEYS} + return {key: metadata.get(key) for key in EXPERIMENT_METADATA_KEYS} + + +def resolve_output_root(output: Path) -> Path: + output_str = str(output) + if output_str.endswith(".schema.json"): + return Path(output_str[: -len(".schema.json")]) + if output_str.endswith(".jsonl"): + return Path(output_str[: -len(".jsonl")]) + if output_str.endswith(".json"): + return Path(output_str[: -len(".json")]) + return output + + +def build_output_paths( + output_root: Path, *, task_name: str, model_name: str | None, dataset_metadata: dict[str, Any] +) -> tuple[Path, Path, Path]: + task_suffix = sanitize_name(task_name) or "unknown-task" + model_suffix = sanitize_name(model_name or "") or "unknown-model" + difficulty_suffix = build_difficulty_filename_suffix(dataset_metadata) + stem = output_root / f"{task_suffix}__{model_suffix}{difficulty_suffix}" + return Path(f"{stem}.jsonl"), Path(f"{stem}.schema.json"), Path(f"{stem}.metadata.json") + + +def write_output_files( + *, + output_jsonl: Path, + schema_json: Path, + metadata_json: Path, + rows: list[dict[str, Any]], + dataset: Dataset, + dataset_metadata: dict[str, Any], +) -> None: + output_jsonl.parent.mkdir(parents=True, exist_ok=True) + with output_jsonl.open("w") as output_file: + for row in rows: + output_file.write(json.dumps(row, ensure_ascii=False) + "\n") + + schema_json.parent.mkdir(parents=True, exist_ok=True) + try: + schema_payload: Any = dataset.features.to_dict() + except AttributeError: + schema_payload = str(dataset.features) + with schema_json.open("w") as output_file: + json.dump(schema_payload, output_file, indent=2, sort_keys=True) + + metadata_json.parent.mkdir(parents=True, exist_ok=True) + with metadata_json.open("w") 
as output_file: + json.dump(dataset_metadata, output_file, indent=2, sort_keys=True) + + +def build_dataset_metadata( + *, + rows: list[dict[str, Any]], + task_name: str, + model_name: str | None, + requested_prior_mode: str, + requested_bucket_count: int, + lower_quantile: float, + prior: BetaPrior | None, + binary_row_count: int, + score_processing: dict[str, Any], +) -> dict[str, Any]: + effective_bucket_count = extract_effective_bucket_count(rows) + difficulty_generation = { + "method": DIFFICULTY_GENERATION_METHOD, + "difficulty_value_field": "difficulty.value", + "difficulty_value_definition": "1 - difficulty.posterior_lower_bound", + "bucket_field": "difficulty.bucket_index", + "bucket_count_field": "difficulty.bucket_count", + "bucket_ranking_field": "difficulty.expected_quantile", + "posterior_lower_quantile": lower_quantile, + "bucket_count_requested": requested_bucket_count, + "bucket_count_effective": effective_bucket_count, + "beta_prior_requested": requested_prior_mode, + "beta_prior_used": { + "source": prior.source if prior is not None else None, + "alpha": prior.alpha if prior is not None else None, + "beta": prior.beta if prior is not None else None, + }, + "binary_instance_count": binary_row_count, + "nonbinary_instance_count": max(0, len(rows) - binary_row_count), + } + difficulty_generation["tag"] = build_difficulty_config_tag(difficulty_generation) + return { + "task_name": task_name, + "model_name": model_name, + "row_count": len(rows), + "source_format": { + "kind": SOURCE_FORMAT_KIND, + "task_field": "dataset", + "score_field": "reward", + "instance_id_definition": INSTANCE_ID_DEFINITION, + }, + "score_processing": dict(score_processing), + "difficulty_generation": difficulty_generation, + } + + +def extract_effective_bucket_count(rows: list[dict[str, Any]]) -> int: + effective_bucket_counts = { + difficulty.get("bucket_count") + for row in rows + if isinstance((difficulty := row.get("difficulty")), dict) and difficulty.get("bucket_count") 
is not None + } + if not effective_bucket_counts: + return 0 + if len(effective_bucket_counts) != 1: + raise ValueError(f"Expected a single effective bucket count, found {sorted(effective_bucket_counts)}") + return next(iter(effective_bucket_counts)) + + +def build_difficulty_filename_suffix(dataset_metadata: dict[str, Any]) -> str: + return f"__{dataset_metadata['difficulty_generation']['tag']}" + + +def build_difficulty_config_tag(difficulty_generation: dict[str, Any]) -> str: + method_token = abbreviate_filename_token( + optional_string(difficulty_generation.get("method")), + aliases=DIFFICULTY_METHOD_FILENAME_ALIASES, + default="diff", + ) + prior_source = optional_string((difficulty_generation.get("beta_prior_used") or {}).get("source")) + prior_token = abbreviate_filename_token(prior_source, aliases=PRIOR_SOURCE_FILENAME_ALIASES, default="none") + quantile_token = format_quantile_token(difficulty_generation["posterior_lower_quantile"]) + bucket_token = format_bucket_token( + requested_count=difficulty_generation["bucket_count_requested"], + effective_count=difficulty_generation["bucket_count_effective"], + ) + return "-".join([method_token, prior_token, quantile_token, bucket_token]) + + +def abbreviate_filename_token(value: str | None, *, aliases: dict[str, str], default: str) -> str: + if not value: + return default + return aliases.get(value, sanitize_name(value)) + + +def format_quantile_token(value: float) -> str: + return f"q{format_filename_number(value * 100.0)}" + + +def format_bucket_token(*, requested_count: int, effective_count: int) -> str: + if requested_count == effective_count: + return f"k{requested_count}" + return f"k{requested_count}e{effective_count}" + + +def annotate_dataset_metadata(dataset: Dataset, dataset_metadata: dict[str, Any]) -> None: + if not hasattr(dataset, "info") or dataset.info is None: + return + dataset.info.description = json.dumps(dataset_metadata, indent=2, sort_keys=True) + + +def validate_args(args: 
argparse.Namespace) -> None: + if not 0.0 < args.posterior_lower_quantile < 1.0: + raise ValueError("--posterior-lower-quantile must be between 0 and 1.") + if args.difficulty_buckets < 0: + raise ValueError("--difficulty-buckets must be non-negative.") + if args.max_instances is not None and args.max_instances <= 0: + raise ValueError("--max-instances must be positive when provided.") + + +def group_rows_by_task_and_model(rows: list[dict[str, Any]]) -> dict[tuple[str, str | None], list[dict[str, Any]]]: + rows_by_group: dict[tuple[str, str | None], list[dict[str, Any]]] = defaultdict(list) + for row in rows: + experiment_metadata = row.get("experiment_metadata") or {} + task_name = stable_string(row.get("task_name")) + model_name = optional_string(experiment_metadata.get("model_name")) + rows_by_group[(task_name, model_name)].append(row) + return dict(rows_by_group) + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + with path.open() as input_file: + return [json.loads(line) for line in input_file if line.strip()] + + +def get_base_task_name(task_name: str) -> str: + return task_name.split("@", 1)[0].split(":", 1)[0] + + +def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None: + if not attempt_scores: + return None + + success_count = 0 + for score in attempt_scores: + if is_close(score, 0.0): + continue + if is_close(score, 1.0): + success_count += 1 + continue + return None + + return success_count, len(attempt_scores) + + +def make_rollout_instance_id(*, task_name: str, prompt_tokens: list[int], ground_truth: Any) -> str: + fingerprint = {"task_name": task_name, "prompt_tokens": prompt_tokens, "ground_truth": make_jsonable(ground_truth)} + digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] + task_prefix = sanitize_name(task_name) or "unknown" + return f"{task_prefix}::{digest}" + + +def canonical_json(value: Any) -> str: + return json.dumps(make_jsonable(value), ensure_ascii=False, 
sort_keys=True, separators=(",", ":")) + + +def make_jsonable(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [make_jsonable(item) for item in value] + if isinstance(value, tuple): + return [make_jsonable(item) for item in value] + if isinstance(value, dict): + return {stable_string(key): make_jsonable(item) for key, item in value.items()} + return stable_string(value) + + +def stable_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return str(value) + + +def optional_string(value: Any) -> str | None: + text = stable_string(value) + return text or None + + +def serialize_value(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True) + + +def format_filename_number(value: float) -> str: + text = f"{value:.8g}" + return text.replace("-", "m").replace(".", "p") + + +def sanitize_name(value: str) -> str: + return value.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") + + +def is_number(value: Any) -> bool: + return isinstance(value, (int, float)) and not isinstance(value, bool) and not math.isnan(float(value)) + + +def is_close(lhs: float, rhs: float) -> bool: + tolerance = EPS * max(1.0, abs(lhs), abs(rhs)) + return abs(lhs - rhs) <= tolerance + + +if __name__ == "__main__": + main() diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh new file mode 100644 index 0000000000..8935b9bc02 --- /dev/null +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -0,0 +1,41 @@ +EXP_NAME=qwen3_4b_base_dapo_rollout_probe +RUN_NAME=${EXP_NAME}_$(date +%Y%m%d_%H%M%S) +NUM_GPUS=1 +CLUSTER=ai2/jupiter +WORKSPACE=ai2/olmo-instruct +PRIORITY=urgent +BEAKER_IMAGE=nathanl/open_instruct_auto + +uv run 
mason.py \ + --task_name "${EXP_NAME}" \ + --description "${RUN_NAME}" \ + --cluster "${CLUSTER}" \ + --workspace "${WORKSPACE}" \ + --priority "${PRIORITY}" \ + --pure_docker_mode \ + --no_auto_dataset_cache \ + --image "${BEAKER_IMAGE}" \ + --preemptible \ + --num_nodes 1 \ + --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + --gpus "${NUM_GPUS}" \ + --budget ai2/oe-adapt \ + -- \ +uv run open_instruct/benchmark_generators.py \ + --model_name_or_path "Qwen/Qwen3-4B-Base" \ + --chat_template_name qwen_instruct_user_boxed_math \ + --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 64 \ + --dataset_mixer_list_splits train \ + --num_unique_prompts_rollout 8 \ + --num_samples_per_prompt_rollout 16 \ + --max_prompt_token_length 2048 \ + --response_length 8192 \ + --pack_length 10240 \ + --vllm_num_engines 1 \ + --vllm_tensor_parallel_size 1 \ + --vllm_top_p 1.0 \ + --temperature 1.0 \ + --apply_verifiable_reward true \ + --verification_reward 10.0 \ + --save_traces \ + --seed 1 diff --git a/tests/test_create_bucketed_difficulty.py b/tests/test_create_bucketed_difficulty.py new file mode 100644 index 0000000000..1bc6817b11 --- /dev/null +++ b/tests/test_create_bucketed_difficulty.py @@ -0,0 +1,389 @@ +"""Unit tests for posterior-aware bucketing in create_bucketed_difficulty.py.""" + +import importlib.util +import json +import math +import sys +import tempfile +import types +import unittest +from collections import Counter +from pathlib import Path +from statistics import NormalDist +from unittest.mock import patch + +import numpy as np + +SCRIPT_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_bucketed_difficulty.py" + + +def _load_create_bucketed_difficulty_module(): + module_name = "test_create_bucketed_difficulty_script" + spec = importlib.util.spec_from_file_location(module_name, SCRIPT_PATH) + module = importlib.util.module_from_spec(spec) + + fake_datasets = types.ModuleType("datasets") + fake_datasets.Dataset = 
type("Dataset", (), {}) + + fake_scipy = types.ModuleType("scipy") + fake_scipy_optimize = types.ModuleType("scipy.optimize") + fake_scipy_optimize.minimize = lambda *_args, **_kwargs: types.SimpleNamespace( + success=False, message="not implemented in test stub", x=(0.0, 0.0) + ) + fake_scipy_special = types.ModuleType("scipy.special") + fake_scipy_special.betaln = lambda alpha, beta: math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta) + fake_scipy_stats = types.ModuleType("scipy.stats") + + class _ApproximateBetaDistribution: + @staticmethod + def _mean(alpha, beta): + return alpha / (alpha + beta) + + @staticmethod + def _sigma(alpha, beta): + total = alpha + beta + variance = np.maximum(alpha * beta / (total * total * (total + 1.0)), 1e-6) + return np.sqrt(variance) + + @classmethod + def cdf(cls, x, alpha, beta): + x_array, alpha_array, beta_array = np.broadcast_arrays( + np.asarray(x, dtype=float), np.asarray(alpha, dtype=float), np.asarray(beta, dtype=float) + ) + mean = cls._mean(alpha_array, beta_array) + sigma = cls._sigma(alpha_array, beta_array) + z_scores = (x_array - mean) / (sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + np.vectorize(math.erf)(z_scores)) + + @classmethod + def pdf(cls, x, alpha, beta): + x_array, alpha_array, beta_array = np.broadcast_arrays( + np.asarray(x, dtype=float), np.asarray(alpha, dtype=float), np.asarray(beta, dtype=float) + ) + mean = cls._mean(alpha_array, beta_array) + sigma = cls._sigma(alpha_array, beta_array) + z_scores = (x_array - mean) / sigma + normalizer = sigma * math.sqrt(2.0 * math.pi) + return np.exp(-0.5 * z_scores * z_scores) / normalizer + + @classmethod + def ppf(cls, q, alpha, beta): + q_array, alpha_array, beta_array = np.broadcast_arrays( + np.asarray(q, dtype=float), np.asarray(alpha, dtype=float), np.asarray(beta, dtype=float) + ) + quantiles = np.empty_like(q_array, dtype=float) + for index in np.ndindex(q_array.shape): + mean = float(cls._mean(alpha_array[index], beta_array[index])) + 
sigma = float(cls._sigma(alpha_array[index], beta_array[index])) + quantiles[index] = np.clip(NormalDist(mu=mean, sigma=sigma).inv_cdf(float(q_array[index])), 0.0, 1.0) + return quantiles + + fake_scipy_stats.beta = _ApproximateBetaDistribution + + modules = { + "datasets": fake_datasets, + "scipy": fake_scipy, + "scipy.optimize": fake_scipy_optimize, + "scipy.special": fake_scipy_special, + "scipy.stats": fake_scipy_stats, + } + + with patch.dict(sys.modules, modules): + assert spec.loader is not None + sys.modules[module_name] = module + spec.loader.exec_module(module) + sys.modules.pop(module_name, None) + + return module + + +MODULE = _load_create_bucketed_difficulty_module() + + +class TestCreateBucketedDifficulty(unittest.TestCase): + def test_discover_rollout_sources_resolves_directory_runs(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "demo_run_metadata.jsonl").write_text( + json.dumps({"run_name": "demo_run", "model_name": "demo-model"}) + "\n" + ) + (root / "demo_run_rollouts_000000.jsonl").write_text( + json.dumps( + { + "prompt_tokens": [1, 2, 3], + "reward": 1.0, + "finish_reason": "stop", + "dataset": "math", + "ground_truth": "4", + } + ) + + "\n" + ) + + sources = MODULE.discover_rollout_sources([str(root)]) + + self.assertEqual(len(sources), 1) + self.assertEqual(sources[0].run_name, "demo_run") + self.assertEqual(sources[0].metadata_path.name, "demo_run_metadata.jsonl") + self.assertEqual([path.name for path in sources[0].rollout_paths], ["demo_run_rollouts_000000.jsonl"]) + + def test_rollout_contributions_aggregate_and_normalize_constant_rewards(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "demo_run_metadata.jsonl").write_text( + json.dumps({"run_name": "demo_run", "model_name": "Qwen/Qwen3-4B-Base"}) + "\n" + ) + shard = root / "demo_run_rollouts_000000.jsonl" + shard.write_text( + "\n".join( + [ + json.dumps( + { + "prompt_tokens": [11, 12, 13], + "reward": 
10.0, + "finish_reason": "stop", + "dataset": "math", + "ground_truth": {"answer": "4"}, + "request_info": {"timeouts": 0, "tool_errors": ""}, + } + ), + json.dumps( + { + "prompt_tokens": [11, 12, 13], + "reward": 0.0, + "finish_reason": "length", + "dataset": "math", + "ground_truth": {"answer": "4"}, + "request_info": {"timeouts": 1, "tool_errors": ""}, + } + ), + json.dumps( + { + "prompt_tokens": [21, 22, 23], + "reward": 10.0, + "finish_reason": "stop", + "dataset": "math", + "ground_truth": {"answer": "9"}, + "request_info": {"timeouts": 0, "tool_errors": ""}, + } + ), + ] + ) + + "\n" + ) + + source = MODULE.discover_rollout_sources([str(root)])[0] + contributions, malformed_records = MODULE.build_contributions_for_source( + source_run=source, task_filters=set(), strict=True + ) + + self.assertEqual(malformed_records, 0) + + rows = MODULE.aggregate_contributions(contributions) + self.assertEqual(len(rows), 2) + + rows_by_group = MODULE.group_rows_by_task_and_model(rows) + group_rows, score_processing, skipped_nonunit = MODULE.normalize_attempt_scores_for_group( + rows_by_group[("math", "Qwen/Qwen3-4B-Base")], allow_nonunit_scores=False + ) + + self.assertEqual(skipped_nonunit, 0) + self.assertEqual(score_processing["normalization"], "binary_zero_or_constant") + self.assertEqual(score_processing["positive_reward_value"], 10.0) + + easy_row = next(row for row in group_rows if row["ground_truth"] == {"answer": "4"}) + self.assertEqual(easy_row["attempt_scores"], [1.0, 0.0]) + self.assertEqual(easy_row["prompt_tokens"], [11, 12, 13]) + self.assertEqual(easy_row["finish_reasons"], ["stop", "length"]) + self.assertEqual(easy_row["score_sources"], ["math"]) + self.assertEqual(easy_row["experiment_metadata"]["model_name"], "Qwen/Qwen3-4B-Base") + self.assertIn("timeout", easy_row["warnings"]) + + def test_normalize_attempt_scores_for_group_marks_unsupported_rewards(self): + rows = [ + { + "instance_id": "example", + "task_name": "math", + "base_task_name": "math", 
+ "prompt_tokens": [1, 2, 3], + "ground_truth": "4", + "attempt_scores": [10.0, 5.0], + "finish_reasons": ["stop", "stop"], + "experiment_metadata": { + "source_root": "/tmp/example-rollouts", + "model_name": "demo-model", + "experiment_id": None, + "experiment_name": "demo-run", + }, + "score_sources": ["math"], + "warnings": [], + } + ] + + kept_rows, score_processing, skipped_nonunit = MODULE.normalize_attempt_scores_for_group( + rows, allow_nonunit_scores=True + ) + + self.assertEqual(skipped_nonunit, 0) + self.assertFalse(score_processing["supports_binary_difficulty"]) + self.assertEqual(kept_rows[0]["attempt_scores"], [10.0, 5.0]) + self.assertIn("nonbinary_reward_scores", kept_rows[0]["warnings"]) + + dropped_rows, _, dropped_count = MODULE.normalize_attempt_scores_for_group(rows, allow_nonunit_scores=False) + + self.assertEqual(dropped_rows, []) + self.assertEqual(dropped_count, 1) + + def test_build_dataset_metadata_captures_difficulty_generation_details(self): + rows = [ + { + "instance_id": "easy", + "difficulty": { + "value": 0.1, + "posterior_mean": 0.2, + "posterior_lower_bound": 0.9, + "expected_quantile": 0.2, + "bucket_index": 0, + "bucket_count": 3, + }, + }, + { + "instance_id": "hard", + "difficulty": { + "value": 0.8, + "posterior_mean": 0.7, + "posterior_lower_bound": 0.2, + "expected_quantile": 0.9, + "bucket_index": 2, + "bucket_count": 3, + }, + }, + {"instance_id": "nonbinary", "difficulty": MODULE.make_empty_difficulty_payload()}, + ] + + metadata = MODULE.build_dataset_metadata( + rows=rows, + task_name="math", + model_name="demo-model", + requested_prior_mode="empirical-bayes", + requested_bucket_count=5, + lower_quantile=0.1, + prior=MODULE.BetaPrior(alpha=0.75, beta=1.25, source="empirical_bayes"), + binary_row_count=2, + score_processing={ + "source_field": "reward", + "output_field": "attempt_scores", + "normalization": "binary_zero_or_constant", + "positive_reward_value": 10.0, + "supports_binary_difficulty": True, + }, + ) + + 
self.assertEqual(metadata["task_name"], "math") + self.assertEqual(metadata["model_name"], "demo-model") + self.assertEqual(metadata["row_count"], 3) + self.assertEqual(metadata["source_format"]["kind"], "open_instruct_rollout_traces") + self.assertEqual(metadata["score_processing"]["normalization"], "binary_zero_or_constant") + self.assertEqual(metadata["score_processing"]["positive_reward_value"], 10.0) + self.assertEqual(metadata["difficulty_generation"]["method"], "beta_binomial_posterior_quantiles") + self.assertEqual(metadata["difficulty_generation"]["posterior_lower_quantile"], 0.1) + self.assertEqual(metadata["difficulty_generation"]["bucket_count_requested"], 5) + self.assertEqual(metadata["difficulty_generation"]["bucket_count_effective"], 3) + self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["source"], "empirical_bayes") + self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["alpha"], 0.75) + self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["beta"], 1.25) + self.assertEqual(metadata["difficulty_generation"]["binary_instance_count"], 2) + self.assertEqual(metadata["difficulty_generation"]["nonbinary_instance_count"], 1) + + def test_annotate_dataset_metadata_stores_json_description(self): + class FakeInfo: + description = "" + + class FakeDataset: + def __init__(self): + self.info = FakeInfo() + + dataset = FakeDataset() + dataset_metadata = {"task_name": "math", "difficulty_generation": {"bucket_count_requested": 5}} + + MODULE.annotate_dataset_metadata(dataset, dataset_metadata) + + self.assertEqual(json.loads(dataset.info.description), dataset_metadata) + + def test_normalize_experiment_metadata_uses_canonical_source_root_only(self): + normalized = MODULE.normalize_experiment_metadata( + { + "source_root": "/tmp/example-rollouts", + "source_input": "/tmp/example-rollouts/demo_run_metadata.jsonl", + "model_name": "demo-model", + "experiment_id": "exp-123", + "experiment_name": "demo-run", + } + 
) + + self.assertEqual( + normalized, + { + "source_root": "/tmp/example-rollouts", + "model_name": "demo-model", + "experiment_id": "exp-123", + "experiment_name": "demo-run", + }, + ) + + def test_apply_beta_binomial_difficulty_orders_rows_by_expected_quantile(self): + rows = [ + {"instance_id": "easy", "attempt_scores": [1.0, 1.0, 1.0, 1.0]}, + {"instance_id": "medium", "attempt_scores": [1.0, 1.0, 0.0, 0.0]}, + {"instance_id": "hard", "attempt_scores": [0.0, 0.0, 0.0, 0.0]}, + ] + + result = MODULE.apply_beta_binomial_difficulty( + rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=3 + ) + difficulties = {row["instance_id"]: row["difficulty"] for row in result} + + self.assertLess(difficulties["easy"]["expected_quantile"], difficulties["medium"]["expected_quantile"]) + self.assertLess(difficulties["medium"]["expected_quantile"], difficulties["hard"]["expected_quantile"]) + self.assertEqual(difficulties["easy"]["bucket_index"], 0) + self.assertEqual(difficulties["medium"]["bucket_index"], 1) + self.assertEqual(difficulties["hard"]["bucket_index"], 2) + self.assertTrue(all(difficulty["bucket_count"] == 3 for difficulty in difficulties.values())) + + def test_apply_beta_binomial_difficulty_balances_bucket_sizes(self): + rows = [ + {"instance_id": "easiest", "attempt_scores": [1.0, 1.0, 1.0, 1.0]}, + {"instance_id": "easy", "attempt_scores": [1.0, 1.0, 1.0, 0.0]}, + {"instance_id": "mid", "attempt_scores": [1.0, 1.0, 0.0, 0.0]}, + {"instance_id": "hard", "attempt_scores": [1.0, 0.0, 0.0, 0.0]}, + {"instance_id": "hardest", "attempt_scores": [0.0, 0.0, 0.0, 0.0]}, + ] + + result = MODULE.apply_beta_binomial_difficulty( + rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=2 + ) + bucket_counts = Counter(row["difficulty"]["bucket_index"] for row in result) + + self.assertEqual(bucket_counts[0], 3) + self.assertEqual(bucket_counts[1], 2) + + def 
test_apply_beta_binomial_difficulty_leaves_nonbinary_rows_unbucketed(self): + rows = [ + {"instance_id": "easy", "attempt_scores": [1.0, 1.0]}, + {"instance_id": "nonbinary", "attempt_scores": [0.5, 1.0]}, + {"instance_id": "hard", "attempt_scores": [0.0, 0.0]}, + ] + + result = MODULE.apply_beta_binomial_difficulty( + rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=2 + ) + difficulties = {row["instance_id"]: row["difficulty"] for row in result} + + self.assertIsNone(difficulties["nonbinary"]["value"]) + self.assertIsNone(difficulties["nonbinary"]["expected_quantile"]) + self.assertIsNone(difficulties["nonbinary"]["bucket_index"]) + self.assertIsNone(difficulties["nonbinary"]["bucket_count"]) + + +if __name__ == "__main__": + unittest.main() From 0b99fa1ade54197fa318bd5b122334813550adbd Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 09:15:37 -0700 Subject: [PATCH 02/40] File perms --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh old mode 100644 new mode 100755 From cb30daacdb2249b579a722782322e97ab91b9883 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 09:21:40 -0700 Subject: [PATCH 03/40] Script fixes --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 8935b9bc02..0d11d45e27 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -1,3 +1,5 @@ +#!/bin/bash + EXP_NAME=qwen3_4b_base_dapo_rollout_probe 
RUN_NAME=${EXP_NAME}_$(date +%Y%m%d_%H%M%S) NUM_GPUS=1 @@ -6,7 +8,7 @@ WORKSPACE=ai2/olmo-instruct PRIORITY=urgent BEAKER_IMAGE=nathanl/open_instruct_auto -uv run mason.py \ +uv run python mason.py \ --task_name "${EXP_NAME}" \ --description "${RUN_NAME}" \ --cluster "${CLUSTER}" \ From 35c69642be50397c12b83bda034e7c1b01811ade Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 09:27:13 -0700 Subject: [PATCH 04/40] Explicit tokenizer --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 0d11d45e27..f0d797a4aa 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -25,6 +25,7 @@ uv run python mason.py \ -- \ uv run open_instruct/benchmark_generators.py \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ + --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 64 \ --dataset_mixer_list_splits train \ From 6735515e7ed414023aa8d4b32fb3636dbe4b40de Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 10:55:51 -0700 Subject: [PATCH 05/40] Fix save destination --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index f0d797a4aa..3c80a55ac3 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -7,6 +7,7 @@ CLUSTER=ai2/jupiter WORKSPACE=ai2/olmo-instruct PRIORITY=urgent BEAKER_IMAGE=nathanl/open_instruct_auto +TRACE_DIR=/weka/oe-adapt-default/tylerm/deletable_rollouts/${EXP_NAME}/${RUN_NAME} uv run python mason.py 
\ --task_name "${EXP_NAME}" \ @@ -24,6 +25,8 @@ uv run python mason.py \ --budget ai2/oe-adapt \ -- \ uv run open_instruct/benchmark_generators.py \ + --run_name "${RUN_NAME}" \ + --exp_name "${EXP_NAME}" \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ @@ -41,4 +44,5 @@ uv run open_instruct/benchmark_generators.py \ --apply_verifiable_reward true \ --verification_reward 10.0 \ --save_traces \ + --rollouts_save_path "${TRACE_DIR}" \ --seed 1 From 2bbd3f2b9da5c497476854ac36de4ac2a6ccf792 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 1 May 2026 11:02:07 -0700 Subject: [PATCH 06/40] Tweaks for results output location --- open_instruct/benchmark_generators.py | 33 +++++++++++++++---- .../qwen3_4b_dapo_math_gen.sh | 1 + 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 73f6622cee..92a0342278 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -36,13 +36,27 @@ logger = logger_utils.setup_logger(__name__) -# Determine data directory -if pathlib.Path("/weka").exists(): - DATA_DIR = pathlib.Path("/weka") / "finbarrt" / "open_instruct_generators_benchmark" -elif pathlib.Path("/root").exists(): - DATA_DIR = pathlib.Path("/root") / "finbarrt" / "open_instruct_generators_benchmark" -else: - DATA_DIR = pathlib.Path("/tmp") / "open_instruct_generators_benchmark" +def get_default_data_dir() -> pathlib.Path: + """Return the legacy default directory for benchmark artifacts.""" + if pathlib.Path("/weka").exists(): + return pathlib.Path("/weka") / "finbarrt" / "open_instruct_generators_benchmark" + if pathlib.Path("/root").exists(): + return pathlib.Path("/root") / "finbarrt" / "open_instruct_generators_benchmark" + return pathlib.Path("/tmp") / "open_instruct_generators_benchmark" + + +DATA_DIR = get_default_data_dir() + + +def 
resolve_data_dir( + args: grpo_utils.GRPOExperimentConfig, streaming_config: data_loader.StreamingDataLoaderConfig +) -> pathlib.Path: + """Resolve where benchmark artifacts should be written for this run.""" + if args.output_dir.rstrip("/") != "output": + return pathlib.Path(args.output_dir) + if streaming_config.save_traces and streaming_config.rollouts_save_path: + return pathlib.Path(streaming_config.rollouts_save_path) + return get_default_data_dir() def save_completion_lengths(batch_results: list[dict], timestamp: int, batch_idx: int): @@ -746,6 +760,8 @@ def cleanup(vllm_engines: list[ray.actor.ActorHandle], actor_manager: ray.actor. def main() -> None: """Main benchmark function.""" + global DATA_DIR + # Parse arguments using ArgumentParserPlus parser = utils.ArgumentParserPlus( ( @@ -768,8 +784,11 @@ def main() -> None: parser.parse_args_into_dataclasses(), ) + DATA_DIR = resolve_data_dir(args, streaming_config) + # Ensure data directory exists DATA_DIR.mkdir(parents=True, exist_ok=True) + logger.info(f"Writing benchmark artifacts to {DATA_DIR}") # Calculate flops per token before starting vLLM logger.info("Calculating model FLOPs per token...") diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 3c80a55ac3..7c79972ec6 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -27,6 +27,7 @@ uv run python mason.py \ uv run open_instruct/benchmark_generators.py \ --run_name "${RUN_NAME}" \ --exp_name "${EXP_NAME}" \ + --output_dir "${TRACE_DIR}" \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ From 2ab2d637697f9fe0dac803da537ae2f3eb4f3983 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 07:57:04 -0700 Subject: [PATCH 07/40] Adjust to support image builder script --- 
.../qwen3_4b_dapo_math_gen.sh | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 7c79972ec6..68256ed30e 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -1,13 +1,19 @@ #!/bin/bash -EXP_NAME=qwen3_4b_base_dapo_rollout_probe -RUN_NAME=${EXP_NAME}_$(date +%Y%m%d_%H%M%S) -NUM_GPUS=1 -CLUSTER=ai2/jupiter -WORKSPACE=ai2/olmo-instruct -PRIORITY=urgent -BEAKER_IMAGE=nathanl/open_instruct_auto -TRACE_DIR=/weka/oe-adapt-default/tylerm/deletable_rollouts/${EXP_NAME}/${RUN_NAME} +EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_rollout_probe}" +RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" + +NUM_GPUS="${NUM_GPUS:-1}" +BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" + +CLUSTER="${CLUSTER:-ai2/jupiter}" +PRIORITY="${PRIORITY:-urgent}" +WORKSPACE="${WORKSPACE:-ai2/olmo-instruct}" +TRACE_DIR="${TRACE_DIR:-/weka/oe-adapt-default/tylerm/deletable_rollouts/${EXP_NAME}/${RUN_NAME}}" + +if [[ $# -gt 0 ]]; then + shift +fi uv run python mason.py \ --task_name "${EXP_NAME}" \ @@ -46,4 +52,4 @@ uv run open_instruct/benchmark_generators.py \ --verification_reward 10.0 \ --save_traces \ --rollouts_save_path "${TRACE_DIR}" \ - --seed 1 + --seed 1 "$@" From ae30b59a2da1847a4f5d0490b1f70e497344b644 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 10:16:11 -0700 Subject: [PATCH 08/40] Add source row id explicitly --- open_instruct/benchmark_generators.py | 3 +++ open_instruct/data_loader.py | 12 ++++++++++++ open_instruct/dataset_transformation.py | 10 ++++++++-- open_instruct/grpo_fast.py | 2 ++ open_instruct/model_utils.py | 8 ++++++++ open_instruct/rl_utils.py | 8 ++++++++ .../difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 2 +- 7 files changed, 42 insertions(+), 3 deletions(-) diff --git 
a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 92a0342278..dd11b3b730 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -193,6 +193,8 @@ def maybe_save_scored_rollout_traces( dataset_name=example[dataset_transformation.VERIFIER_SOURCE_KEY], raw_query=example[dataset_transformation.RAW_PROMPT_KEY], advantage_normalization_type=streaming_config.advantage_normalization_type, + source_row_id=example.get(dataset_transformation.SOURCE_ROW_ID_KEY), + source_dataset=example.get(dataset_transformation.DATASET_ORIGIN_KEY), ) save_rollouts_to_disk( streaming_config.rollouts_save_path, @@ -272,6 +274,7 @@ def setup_dataset( dataset_cache_mode=streaming_config.dataset_cache_mode, dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, + drop_dataset_source=not streaming_config.save_traces, ) # Shuffle dataset diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 01b62ed40b..a4ee28e83a 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -36,10 +36,12 @@ from open_instruct import data_types, padding_free_collator, utils from open_instruct.data_types import EnvConfig, EnvConfigEntry from open_instruct.dataset_transformation import ( + DATASET_ORIGIN_KEY, ENV_CONFIG_KEY, GROUND_TRUTHS_KEY, INPUT_IDS_PROMPT_KEY, RAW_PROMPT_KEY, + SOURCE_ROW_ID_KEY, TOOLS_COLUMN_KEY, VERIFIER_SOURCE_KEY, ) @@ -780,6 +782,8 @@ def accumulate_inference_batches( all_active_tools = [] all_scores = [] all_indices = [] + all_source_row_ids = [] + all_source_datasets = [] all_percent_solved = [] all_model_steps = [] total_filtered_prompts = 0 @@ -831,6 +835,8 @@ def accumulate_inference_batches( ground_truth = example[GROUND_TRUTHS_KEY] dataset_name = example[VERIFIER_SOURCE_KEY] raw_query = example[RAW_PROMPT_KEY] + source_row_id = example.get(SOURCE_ROW_ID_KEY) + source_dataset = 
example.get(DATASET_ORIGIN_KEY) sample_active_tools = example.get(TOOLS_COLUMN_KEY) if replenish_prompts: @@ -861,6 +867,8 @@ def accumulate_inference_batches( k_raw_queries = repeat_each([raw_query], generation_config.n) k_active_tools = repeat_each([sample_active_tools], generation_config.n) k_indices = repeat_each([result.index], generation_config.n) + k_source_row_ids = repeat_each([source_row_id], generation_config.n) + k_source_datasets = repeat_each([source_dataset], generation_config.n) percent_solved = np.mean(result.reward_scores).item() / max_possible_score if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: @@ -898,6 +906,8 @@ def accumulate_inference_batches( all_raw_queries.extend(k_raw_queries) all_active_tools.extend(k_active_tools) all_indices.extend(k_indices) + all_source_row_ids.extend(k_source_row_ids) + all_source_datasets.extend(k_source_datasets) all_decoded_responses.extend(decoded_responses) all_scores.extend(result.reward_scores) all_reward_metrics.append(result.reward_metrics) @@ -1000,6 +1010,8 @@ def accumulate_inference_batches( indices=all_indices, scores=all_scores, active_tools=all_active_tools if all_active_tools else None, + source_row_ids=all_source_row_ids, + source_datasets=all_source_datasets, model_steps=all_model_steps, ) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 33b297f492..193116e1bf 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -869,6 +869,7 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"): GROUND_TRUTHS_KEY = "ground_truth" VERIFIER_SOURCE_KEY = "dataset" RAW_PROMPT_KEY = "prompt" +SOURCE_ROW_ID_KEY = "source_row_id" @dataclass @@ -944,7 +945,8 @@ def remove_dataset_source_field(dataset: Dataset) -> Dataset: # Cache version: increment this when transformation logic changes significantly # to invalidate old caches. 
v6: Added return_dict=False to apply_chat_template calls for transformers 5.x. -DATASET_CACHE_VERSION = "v6" +# v7: Preserve original source row ids in transformed datasets for rollout trace joins. +DATASET_CACHE_VERSION = "v7" def _normalize_env_config_column(row: dict[str, Any]) -> None: @@ -1630,6 +1632,8 @@ def __post_init__(self): num_proc=max_num_processes(), ) assert isinstance(dataset, Dataset), f"Expected Dataset, got {type(dataset)}" + if SOURCE_ROW_ID_KEY not in dataset.column_names: + dataset = dataset.add_column(SOURCE_ROW_ID_KEY, range(len(dataset))) self.dataset = dataset if self.dataset_range is None: dataset_range = len(self.dataset) @@ -1731,6 +1735,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): target_columns = dataset.column_names if dc.target_columns is None else dc.target_columns # Always preserve dataset_source if it exists target_columns = _preserve_column(DATASET_ORIGIN_KEY, dataset, target_columns) + target_columns = _preserve_column(SOURCE_ROW_ID_KEY, dataset, target_columns) target_columns = _preserve_column(TOOLS_COLUMN_KEY, dataset, target_columns) target_columns = _preserve_column(ENV_CONFIG_KEY, dataset, target_columns) @@ -2162,6 +2167,7 @@ def get_cached_dataset_tulu( dataset_skip_cache: bool = False, dataset_config_seed: int = 42, system_prompt_override: str | None = None, + drop_dataset_source: bool = True, ) -> Dataset: return get_cached_dataset_tulu_with_statistics( dataset_mixer_list=dataset_mixer_list, @@ -2175,7 +2181,7 @@ def get_cached_dataset_tulu( hf_entity=hf_entity, dataset_local_cache_dir=dataset_local_cache_dir, dataset_skip_cache=dataset_skip_cache, - drop_dataset_source=True, + drop_dataset_source=drop_dataset_source, dataset_config_seed=dataset_config_seed, system_prompt_override=system_prompt_override, )[0] diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 886162529b..d56812326b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1180,6 +1180,7 
@@ def setup_datasets( dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, system_prompt_override=system_prompt_override, + drop_dataset_source=not streaming_config.save_traces, ) _validate_and_log_dataset_tools(train_dataset, configured_tool_call_names, "train_dataset") @@ -1197,6 +1198,7 @@ def setup_datasets( dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, system_prompt_override=system_prompt_override, + drop_dataset_source=not streaming_config.save_traces, ) _validate_and_log_dataset_tools(eval_dataset, configured_tool_call_names, "eval_dataset") diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 6a1ac44eb1..eaed9395a9 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -132,6 +132,8 @@ class Batch: indices: list[int] | None scores: list[float] | None active_tools: list[list[str] | None] | None = None + source_row_ids: list[int | None] | None = None + source_datasets: list[str | None] | None = None model_steps: list[int] = field(default_factory=list) def __getitem__(self, key: slice | int | list[int]) -> "Batch": @@ -147,6 +149,8 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": indices=self.indices[key] if self.indices is not None else None, scores=self.scores[key] if self.scores is not None else None, active_tools=self.active_tools[key] if self.active_tools is not None else None, + source_row_ids=self.source_row_ids[key] if self.source_row_ids is not None else None, + source_datasets=self.source_datasets[key] if self.source_datasets is not None else None, model_steps=self.model_steps[key], ) elif isinstance(key, int): @@ -160,6 +164,8 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": indices=[self.indices[key]] if self.indices is not None else None, scores=[self.scores[key]] if self.scores is not None else None, 
active_tools=[self.active_tools[key]] if self.active_tools is not None else None, + source_row_ids=[self.source_row_ids[key]] if self.source_row_ids is not None else None, + source_datasets=[self.source_datasets[key]] if self.source_datasets is not None else None, model_steps=[self.model_steps[key]], ) else: @@ -175,6 +181,8 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": indices=[self.indices[i] for i in key] if self.indices is not None else None, scores=[self.scores[i] for i in key] if self.scores is not None else None, active_tools=[self.active_tools[i] for i in key] if self.active_tools is not None else None, + source_row_ids=[self.source_row_ids[i] for i in key] if self.source_row_ids is not None else None, + source_datasets=[self.source_datasets[i] for i in key] if self.source_datasets is not None else None, model_steps=[self.model_steps[i] for i in key], ) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 9c00e13da9..ceb18496d5 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -39,6 +39,8 @@ class RolloutRecord: finish_reason: str dataset: str ground_truth: list[int] | None = None + source_row_id: int | None = None + source_dataset: str | None = None request_info: dict | None = None logprobs: list[float] | None = None @@ -116,6 +118,8 @@ def _save_rollouts( finish_reason=result.finish_reasons[i], dataset=batch.datasets[i], ground_truth=batch.ground_truths[i], + source_row_id=batch.source_row_ids[i] if batch.source_row_ids is not None else None, + source_dataset=batch.source_datasets[i] if batch.source_datasets is not None else None, request_info=_get_request_info_for_sample(result.request_info, i), logprobs=result.logprobs[i] if result.logprobs else None, ) @@ -167,6 +171,8 @@ def build_rollout_batch_and_advantages( dataset_name: str, raw_query: str, advantage_normalization_type: str, + source_row_id: int | None = None, + source_dataset: str | None = None, ) -> tuple[model_utils.Batch, 
np.ndarray]: """Convert a scored inference result into the rollout format used by difficulty bucketing.""" if result.reward_scores is None: @@ -188,6 +194,8 @@ def build_rollout_batch_and_advantages( decoded_responses=None, indices=indices, scores=scores, + source_row_ids=[source_row_id] * num_samples, + source_datasets=[source_dataset] * num_samples, model_steps=[result.model_step] * num_samples, ) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 68256ed30e..14f2e200a7 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -37,7 +37,7 @@ uv run open_instruct/benchmark_generators.py \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ - --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 64 \ + --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 8 \ --dataset_mixer_list_splits train \ --num_unique_prompts_rollout 8 \ --num_samples_per_prompt_rollout 16 \ From 866319f417bc5ae515dcee340f144dfe92855bc9 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 10:37:05 -0700 Subject: [PATCH 09/40] Drop sample limit --- .../create_bucketed_difficulty.py | 74 +++++++++++++++++-- .../qwen3_4b_dapo_math_gen.sh | 2 +- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index 3e12db0499..c0850ad2c4 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -16,8 +16,8 @@ For each traced prompt instance it: 1. loads rollout shards written by ``save_rollouts_to_disk()``, -2. 
groups attempts by a deterministic fingerprint over task name, prompt tokens, - and ground truth, +2. groups attempts by source dataset row identity when available, otherwise by a + deterministic fingerprint over task name, prompt tokens, and ground truth, 3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` when possible, 4. fits a Beta prior across binary outcomes and estimates per-item success @@ -73,7 +73,9 @@ DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" -INSTANCE_ID_DEFINITION = "sha1(task_name,prompt_tokens,ground_truth)" +INSTANCE_ID_DEFINITION = ( + "sha1(source_dataset,source_row_id) when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" +) @dataclass(frozen=True) @@ -425,8 +427,11 @@ def build_rollout_contribution( if task_name is None: raise ValueError("missing dataset/verifier source") + source_dataset = normalize_source_dataset(record.get("source_dataset")) + source_row_id = normalize_source_row_id(record.get("source_row_id")) + prompt_tokens = normalize_token_list(record.get("prompt_tokens")) - if prompt_tokens is None: + if prompt_tokens is None and (source_dataset is None or source_row_id is None): raise ValueError("missing or invalid prompt_tokens") reward = extract_numeric_reward(record.get("reward")) @@ -438,12 +443,18 @@ def build_rollout_contribution( return { "instance_id": make_rollout_instance_id( - task_name=task_name, prompt_tokens=prompt_tokens, ground_truth=ground_truth + task_name=task_name, + prompt_tokens=prompt_tokens, + ground_truth=ground_truth, + source_dataset=source_dataset, + source_row_id=source_row_id, ), "task_name": task_name, "base_task_name": get_base_task_name(task_name), "prompt_tokens": prompt_tokens, "ground_truth": ground_truth, + "source_dataset": source_dataset, + "source_row_id": source_row_id, 
"score_source": task_name, "attempt_scores": [reward], "finish_reasons": [finish_reason] if finish_reason else [], @@ -462,10 +473,43 @@ def normalize_task_name(value: Any) -> str | None: return None if isinstance(value, str): return value + if isinstance(value, (list, tuple)) and len(value) == 1: + return normalize_task_name(value[0]) + serialized = serialize_value(value) + return serialized or None + + +def normalize_source_dataset(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)) and len(value) == 1: + return normalize_source_dataset(value[0]) serialized = serialize_value(value) return serialized or None +def normalize_source_row_id(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + if not math.isfinite(value) or not value.is_integer(): + return None + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + return int(stripped) + except ValueError: + return None + return None + + def normalize_token_list(value: Any) -> list[int] | None: if not isinstance(value, list): return None @@ -908,6 +952,8 @@ def build_dataset_metadata( "kind": SOURCE_FORMAT_KIND, "task_field": "dataset", "score_field": "reward", + "source_dataset_field": "source_dataset", + "source_row_id_field": "source_row_id", "instance_id_definition": INSTANCE_ID_DEFINITION, }, "score_processing": dict(score_processing), @@ -1014,7 +1060,23 @@ def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None return success_count, len(attempt_scores) -def make_rollout_instance_id(*, task_name: str, prompt_tokens: list[int], ground_truth: Any) -> str: +def make_rollout_instance_id( + *, + task_name: str, + prompt_tokens: list[int] | None, + ground_truth: Any, + source_dataset: str | None = None, + source_row_id: int | None = 
None, +) -> str: + if source_dataset is not None and source_row_id is not None: + fingerprint = {"source_dataset": source_dataset, "source_row_id": source_row_id} + digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] + task_prefix = sanitize_name(source_dataset) or "unknown" + return f"{task_prefix}::{digest}" + + if prompt_tokens is None: + raise ValueError("prompt_tokens are required when source row identity is unavailable") + fingerprint = {"task_name": task_name, "prompt_tokens": prompt_tokens, "ground_truth": make_jsonable(ground_truth)} digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] task_prefix = sanitize_name(task_name) or "unknown" diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 14f2e200a7..374b54e714 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -37,7 +37,7 @@ uv run open_instruct/benchmark_generators.py \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ - --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 8 \ + --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered \ --dataset_mixer_list_splits train \ --num_unique_prompts_rollout 8 \ --num_samples_per_prompt_rollout 16 \ From 13c307b5d1d3c17917b10dde5affaa44a89071e1 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 10:56:58 -0700 Subject: [PATCH 10/40] Persist source / id and experiment id in the map --- open_instruct/benchmark_generators.py | 1 + open_instruct/rl_utils.py | 4 +- .../create_bucketed_difficulty.py | 57 ++++++++++++------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index dd11b3b730..bfb980e426 100644 --- 
a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -103,6 +103,7 @@ def save_config( "model_config": dataclasses.asdict(model_config), "streaming_config": dataclasses.asdict(streaming_config), "timestamp": timestamp, + "experiment_id": os.environ.get("BEAKER_WORKLOAD_ID") or None, } with open(config_path, "w") as f: diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index ceb18496d5..3a5a9e7fba 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -25,6 +25,7 @@ class RolloutMetadata: git_commit: str model_name: str timestamp: str + experiment_id: str | None = None @dataclass @@ -49,7 +50,7 @@ def save_rollout_metadata(save_path: str, run_name: str, model_name: str | None) """Save metadata about the rollout collection to disk. Creates a JSONL file containing run information including git commit, - model name, and timestamp for traceability. + model name, runtime experiment id, and timestamp for traceability. Args: save_path: Directory to save metadata file. @@ -60,6 +61,7 @@ def save_rollout_metadata(save_path: str, run_name: str, model_name: str | None) run_name=run_name, git_commit=utils.get_git_commit(), model_name=model_name or "unknown", + experiment_id=os.environ.get("BEAKER_WORKLOAD_ID") or None, timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(), ) metadata_path = os.path.join(save_path, f"{run_name}_metadata.jsonl") diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index c0850ad2c4..39f2908681 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -16,7 +16,7 @@ For each traced prompt instance it: 1. loads rollout shards written by ``save_rollouts_to_disk()``, -2. groups attempts by source dataset row identity when available, otherwise by a +2. 
groups attempts by source dataset identity when available, otherwise by a deterministic fingerprint over task name, prompt tokens, and ground truth, 3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` when possible, @@ -74,7 +74,7 @@ PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" INSTANCE_ID_DEFINITION = ( - "sha1(source_dataset,source_row_id) when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" + "source_dataset::source_dataset_id when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" ) @@ -239,9 +239,10 @@ def main(argv: list[str] | None = None) -> None: group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets ) group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) + output_rows = strip_output_only_rollout_fields(group_rows) dataset_metadata = build_dataset_metadata( - rows=group_rows, + rows=output_rows, task_name=task_name, model_name=model_name, requested_prior_mode=args.beta_prior, @@ -269,7 +270,7 @@ def main(argv: list[str] | None = None) -> None: model_name, ) - dataset = Dataset.from_list(group_rows) + dataset = Dataset.from_list(output_rows) annotate_dataset_metadata(dataset, dataset_metadata) output_jsonl, schema_json, metadata_json = build_output_paths( output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata @@ -278,7 +279,7 @@ def main(argv: list[str] | None = None) -> None: output_jsonl=output_jsonl, schema_json=schema_json, metadata_json=metadata_json, - rows=group_rows, + rows=output_rows, dataset=dataset, dataset_metadata=dataset_metadata, ) @@ -286,10 +287,10 @@ def main(argv: list[str] | None = None) -> None: if args.push_to_hub is not None: dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) - written_outputs.append((task_name, model_name, len(group_rows), output_jsonl, 
schema_json, metadata_json)) + written_outputs.append((task_name, model_name, len(output_rows), output_jsonl, schema_json, metadata_json)) logger.info( "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", - len(group_rows), + len(output_rows), task_name, model_name, output_jsonl, @@ -346,8 +347,8 @@ def build_rollout_source_from_metadata(metadata_path: Path, *, input_arg: str) - raise FileNotFoundError(f"Could not find rollout shards for run {run_name} next to {metadata_path}") return RolloutSource( input_arg=input_arg, - root_path=metadata_path.parent.resolve(), - metadata_path=metadata_path.resolve(), + root_path=metadata_path.parent.absolute(), + metadata_path=metadata_path.absolute(), rollout_paths=rollout_paths, run_name=run_name, ) @@ -415,6 +416,7 @@ def read_rollout_metadata(metadata_path: Path, *, fallback_run_name: str) -> dic return { "run_name": optional_string(metadata.get("run_name")) or fallback_run_name, "model_name": optional_string(metadata.get("model_name")), + "experiment_id": optional_string(metadata.get("experiment_id")), "git_commit": optional_string(metadata.get("git_commit")), "timestamp": optional_string(metadata.get("timestamp")), } @@ -428,10 +430,10 @@ def build_rollout_contribution( raise ValueError("missing dataset/verifier source") source_dataset = normalize_source_dataset(record.get("source_dataset")) - source_row_id = normalize_source_row_id(record.get("source_row_id")) + source_dataset_id = extract_source_dataset_id(record) prompt_tokens = normalize_token_list(record.get("prompt_tokens")) - if prompt_tokens is None and (source_dataset is None or source_row_id is None): + if prompt_tokens is None and (source_dataset is None or source_dataset_id is None): raise ValueError("missing or invalid prompt_tokens") reward = extract_numeric_reward(record.get("reward")) @@ -447,21 +449,21 @@ def build_rollout_contribution( prompt_tokens=prompt_tokens, ground_truth=ground_truth, source_dataset=source_dataset, - source_row_id=source_row_id, 
+ source_dataset_id=source_dataset_id, ), "task_name": task_name, "base_task_name": get_base_task_name(task_name), "prompt_tokens": prompt_tokens, "ground_truth": ground_truth, "source_dataset": source_dataset, - "source_row_id": source_row_id, + "source_dataset_id": source_dataset_id, "score_source": task_name, "attempt_scores": [reward], "finish_reasons": [finish_reason] if finish_reason else [], "experiment_metadata": { "source_root": str(source_run.root_path), "model_name": run_metadata["model_name"], - "experiment_id": None, + "experiment_id": run_metadata["experiment_id"], "experiment_name": run_metadata["run_name"], }, "warnings": extract_rollout_warnings(record.get("request_info")), @@ -490,7 +492,15 @@ def normalize_source_dataset(value: Any) -> str | None: return serialized or None -def normalize_source_row_id(value: Any) -> int | None: +def extract_source_dataset_id(record: dict[str, Any]) -> int | None: + for field_name in ("source_dataset_id", "source_row_id"): + source_dataset_id = normalize_source_dataset_id(record.get(field_name)) + if source_dataset_id is not None: + return source_dataset_id + return None + + +def normalize_source_dataset_id(value: Any) -> int | None: if value is None or isinstance(value, bool): return None if isinstance(value, int): @@ -578,6 +588,13 @@ def aggregate_contributions(contributions: list[dict[str, Any]]) -> list[dict[st return rows +def strip_output_only_rollout_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + {key: value for key, value in row.items() if key not in {"prompt_tokens", "ground_truth"}} + for row in rows + ] + + def normalize_attempt_scores_for_group( rows: list[dict[str, Any]], *, allow_nonunit_scores: bool ) -> tuple[list[dict[str, Any]], dict[str, Any], int]: @@ -953,6 +970,7 @@ def build_dataset_metadata( "task_field": "dataset", "score_field": "reward", "source_dataset_field": "source_dataset", + "source_dataset_id_field": "source_dataset_id", "source_row_id_field": 
"source_row_id", "instance_id_definition": INSTANCE_ID_DEFINITION, }, @@ -1066,13 +1084,10 @@ def make_rollout_instance_id( prompt_tokens: list[int] | None, ground_truth: Any, source_dataset: str | None = None, - source_row_id: int | None = None, + source_dataset_id: int | None = None, ) -> str: - if source_dataset is not None and source_row_id is not None: - fingerprint = {"source_dataset": source_dataset, "source_row_id": source_row_id} - digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] - task_prefix = sanitize_name(source_dataset) or "unknown" - return f"{task_prefix}::{digest}" + if source_dataset is not None and source_dataset_id is not None: + return f"{source_dataset}::{source_dataset_id}" if prompt_tokens is None: raise ValueError("prompt_tokens are required when source row identity is unavailable") From 6af2b45380d7fca3a7893556399d0b35da69b7a1 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:03:35 -0700 Subject: [PATCH 11/40] Small changes for sample size samller than batch size --- open_instruct/benchmark_generators.py | 121 +++++++++++++++++++------- 1 file changed, 88 insertions(+), 33 deletions(-) diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index bfb980e426..061984fb0a 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -121,6 +121,7 @@ def save_benchmark_results_to_csv( """Save benchmark results to CSV file.""" git_commit = utils.get_git_commit() agg_results = aggregate_results(results) + total_samples = len(agg_results["response_lengths"]) csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" row_data = { @@ -134,7 +135,9 @@ def save_benchmark_results_to_csv( "total_time": total_time, "total_generation_time": agg_results["total_generation_time"], "total_weight_sync_time": agg_results["total_weight_sync_time"], - "generation_time_percentage": (agg_results["total_generation_time"] / 
total_time) * 100, + "generation_time_percentage": (agg_results["total_generation_time"] / total_time) * 100 + if total_time > 0 + else 0, "weight_sync_time_percentage": (agg_results["total_weight_sync_time"] / total_time) * 100 if total_time > 0 else 0, @@ -144,12 +147,7 @@ def save_benchmark_results_to_csv( "avg_mbu": agg_results["avg_mbu"], "avg_generation_time_per_batch": agg_results["avg_generation_time"], "avg_weight_sync_time_per_batch": agg_results["avg_weight_sync_time"], - "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"] - / ( - len(results) - * streaming_config.num_unique_prompts_rollout - * streaming_config.num_samples_per_prompt_rollout - ), + "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"] / total_samples if total_samples > 0 else 0, } csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" @@ -385,20 +383,16 @@ def submission_thread( dataset: datasets.Dataset, generation_config: vllm_utils.SamplingConfig, stop_event: threading.Event, - batch_size: int, - start_batch_idx: int, - num_batches: int, + batch_specs: list[tuple[int, int, int]], ) -> None: """Thread that submits prompts to the queue.""" logger.info("[Submission Thread] Starting prompt submission") - for batch_idx in range(start_batch_idx, start_batch_idx + num_batches): + for batch_idx, start_idx, end_idx in batch_specs: if stop_event.is_set(): logger.info("[Submission Thread] Stopped due to stop event") break # Get batch data from dataset - start_idx = batch_idx * batch_size - end_idx = min(start_idx + batch_size, len(dataset)) batch_data = dataset[start_idx:end_idx] prompts = batch_data[dataset_transformation.INPUT_IDS_PROMPT_KEY] @@ -412,7 +406,21 @@ def submission_thread( generation_config=generation_config, ) ) - logger.info(f"[Submission Thread] All {num_batches} batches submitted") + logger.info(f"[Submission Thread] All {len(batch_specs)} batches submitted") + + +def build_batch_specs( + *, dataset_len: int, batch_size: int, 
start_batch_idx: int, num_batches: int +) -> list[tuple[int, int, int]]: + """Return (batch_idx, start_idx, end_idx) triples for non-empty batches only.""" + batch_specs = [] + for batch_idx in range(start_batch_idx, start_batch_idx + num_batches): + start_idx = batch_idx * batch_size + if start_idx >= dataset_len: + break + end_idx = min(start_idx + batch_size, dataset_len) + batch_specs.append((batch_idx, start_idx, end_idx)) + return batch_specs def run_benchmark( @@ -430,6 +438,9 @@ def run_benchmark( num_batches: int = 5, ) -> list[dict[str, Any]]: """Run the full benchmark.""" + if len(dataset) == 0: + raise ValueError("Benchmark dataset is empty after loading and filtering.") + logger.info( f"Starting benchmark with 1 warmup batch + {num_batches - 1} main batches of size {streaming_config.num_unique_prompts_rollout}" ) @@ -508,25 +519,42 @@ def run_benchmark( step=0, total_samples_written=total_samples_written, ) - logger.info(f"Submitting {num_batches - 1} batches for main benchmark...") - submission_future = executor.submit( - submission_thread, - param_prompt_Q, - dataset, - generation_config, - stop_event, - streaming_config.num_unique_prompts_rollout, - 1, - num_batches - 1, + main_batch_specs = build_batch_specs( + dataset_len=len(dataset), + batch_size=streaming_config.num_unique_prompts_rollout, + start_batch_idx=1, + num_batches=num_batches - 1, ) + if not main_batch_specs: + logger.warning( + "No main benchmark batches remain after warmup because the dataset only has %s prompt(s), " + "which fit entirely in the warmup batch size of %s.", + len(dataset), + streaming_config.num_unique_prompts_rollout, + ) + submission_future = None + else: + if len(main_batch_specs) < num_batches - 1: + logger.info( + "Submitting %s main benchmark batch(es) instead of %s because the dataset only has %s prompt(s).", + len(main_batch_specs), + num_batches - 1, + len(dataset), + ) + else: + logger.info(f"Submitting {len(main_batch_specs)} batches for main benchmark...") 
+ + submission_future = executor.submit( + submission_thread, param_prompt_Q, dataset, generation_config, stop_event, main_batch_specs + ) # Process remaining batches with timing - for batch_idx in range(1, num_batches): + for batch_position, (batch_idx, batch_start_idx, batch_end_idx) in enumerate(main_batch_specs, start=1): # Quick health check! - if submission_future.done(): + if submission_future is not None and submission_future.done(): submission_future.result() # Collect all results for this batch (one per prompt) using non-blocking polling - num_prompts = streaming_config.num_unique_prompts_rollout + num_prompts = batch_end_idx - batch_start_idx batch_results = [] batch_deadline = time.time() + 1200 while len(batch_results) < num_prompts: @@ -535,7 +563,9 @@ def run_benchmark( batch_results.append(result) except Empty: if time.time() > batch_deadline: - raise TimeoutError(f"Batch timed out, got {len(batch_results)}/{num_prompts}") from None + raise TimeoutError( + f"Batch {batch_idx} timed out, got {len(batch_results)}/{num_prompts}" + ) from None total_samples_written = maybe_save_scored_rollout_traces( batch_results, @@ -609,7 +639,7 @@ def run_benchmark( save_completion_lengths([result_dict], timestamp, batch_idx) results.append(result_dict) logger.info( - f"Batch {batch_idx}/{num_batches - 1}: " + f"Batch {batch_position}/{len(main_batch_specs)}: " f"{result_dict['tokens_per_second']:.2f} new tokens/sec, " f"MFU: {result_dict['mfu']:.2f}%, " f"MBU: {result_dict['mbu']:.2f}%, " @@ -618,6 +648,9 @@ def run_benchmark( f"total new tokens: {total_new_tokens}" ) + if submission_future is not None: + submission_future.result() + # Calculate total time for main benchmark only main_benchmark_time = sum(r["generation_time"] for r in results) @@ -657,6 +690,24 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]: prompt_lengths.extend(result["prompt_lengths"]) num_results = len(results) + if num_results == 0: + return { + "total_mfu": 0.0, + 
"total_mbu": 0.0, + "total_tokens_per_second": 0.0, + "total_generation_time": 0.0, + "total_weight_sync_time": 0.0, + "total_num_new_tokens": 0, + "finish_reasons": finish_reasons, + "response_lengths": response_lengths, + "prompt_lengths": prompt_lengths, + "avg_tokens_per_second": 0.0, + "avg_mfu": 0.0, + "avg_mbu": 0.0, + "avg_generation_time": 0.0, + "avg_weight_sync_time": 0.0, + } + avg_tokens_per_second = total_num_new_tokens / total_generation_time if total_generation_time > 0 else 0 avg_mfu = total_mfu / num_results avg_mbu = total_mbu / num_results @@ -691,10 +742,8 @@ def print_summary( """Print benchmark summary statistics.""" agg_results = aggregate_results(results) - total_samples = ( - len(results) * streaming_config.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout - ) - avg_new_tokens_per_sample = agg_results["total_num_new_tokens"] / total_samples + total_samples = len(agg_results["response_lengths"]) + avg_new_tokens_per_sample = agg_results["total_num_new_tokens"] / total_samples if total_samples > 0 else 0 print("\n" + "=" * 60) print("BENCHMARK SUMMARY") @@ -707,6 +756,12 @@ def print_summary( print(f"Unique prompts per batch: {streaming_config.num_unique_prompts_rollout}") print(f"Num rollouts: {streaming_config.num_samples_per_prompt_rollout}") print(f"Max tokens: {streaming_config.response_length}") + if not results: + print("-" * 60) + print("No main benchmark batches were executed after warmup.") + print("=" * 60) + return + print("-" * 60) print(f"Total time (main benchmark): {agg_results['total_generation_time']:.2f}s") print(f"Total weight sync time: {agg_results['total_weight_sync_time']:.2f}s") From 28b1bb006063d23226e6284a183fd5156da8f283 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:10:34 -0700 Subject: [PATCH 12/40] Run everything and data parallel setup --- open_instruct/benchmark_generators.py | 56 ++++++++++++++++++- .../qwen3_4b_dapo_math_gen.sh | 6 +- 2 files changed, 57 
insertions(+), 5 deletions(-) diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 061984fb0a..a10f463a7a 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -48,6 +48,27 @@ def get_default_data_dir() -> pathlib.Path: DATA_DIR = get_default_data_dir() +@dataclasses.dataclass +class BenchmarkConfig: + """Benchmark-only controls for benchmark_generators.py.""" + + num_batches: int = 5 + """Total number of benchmark batches to run, including the initial warmup batch.""" + run_all_instances: bool = False + """If True, ignore num_batches and run enough batches to cover the entire dataset.""" + + def __post_init__(self) -> None: + if self.num_batches < 1: + raise ValueError(f"num_batches must be >= 1, got {self.num_batches}") + + +def resolve_num_batches(*, dataset_len: int, prompts_per_batch: int, benchmark_config: BenchmarkConfig) -> int: + """Resolve the total number of benchmark batches to run, including warmup.""" + if benchmark_config.run_all_instances: + return max(1, -(-dataset_len // prompts_per_batch)) + return benchmark_config.num_batches + + def resolve_data_dir( args: grpo_utils.GRPOExperimentConfig, streaming_config: data_loader.StreamingDataLoaderConfig ) -> pathlib.Path: @@ -82,7 +103,13 @@ def save_completion_lengths(batch_results: list[dict], timestamp: int, batch_idx def save_config( - args, tokenizer_config, model_config, streaming_config: data_loader.StreamingDataLoaderConfig, timestamp: int + args, + tokenizer_config, + model_config, + streaming_config: data_loader.StreamingDataLoaderConfig, + benchmark_config: BenchmarkConfig, + resolved_num_batches: int, + timestamp: int, ): """ Save configuration to JSON file. 
@@ -92,6 +119,8 @@ def save_config( tokenizer_config: TokenizerConfig dataclass model_config: ModelConfig dataclass streaming_config: StreamingDataLoaderConfig dataclass + benchmark_config: Benchmark-specific config dataclass + resolved_num_batches: Effective total number of benchmark batches that will run timestamp: Unix timestamp """ config_path = DATA_DIR / f"config_{timestamp}.json" @@ -102,6 +131,8 @@ def save_config( "tokenizer_config": dataclasses.asdict(tokenizer_config), "model_config": dataclasses.asdict(model_config), "streaming_config": dataclasses.asdict(streaming_config), + "benchmark_config": dataclasses.asdict(benchmark_config), + "resolved_num_batches": resolved_num_batches, "timestamp": timestamp, "experiment_id": os.environ.get("BEAKER_WORKLOAD_ID") or None, } @@ -829,16 +860,18 @@ def main() -> None: model_utils.ModelConfig, data_loader.StreamingDataLoaderConfig, data_loader.VLLMConfig, + BenchmarkConfig, ) # type: ignore[arg-type] ) - args, tokenizer_config, model_config, streaming_config, vllm_config = cast( + args, tokenizer_config, model_config, streaming_config, vllm_config, benchmark_config = cast( tuple[ grpo_utils.GRPOExperimentConfig, dataset_transformation.TokenizerConfig, model_utils.ModelConfig, data_loader.StreamingDataLoaderConfig, data_loader.VLLMConfig, + BenchmarkConfig, ], parser.parse_args_into_dataclasses(), ) @@ -857,6 +890,20 @@ def main() -> None: free_all_gpu_memory() dataset = setup_dataset(args, streaming_config, tokenizer_config) + resolved_num_batches = resolve_num_batches( + dataset_len=len(dataset), + prompts_per_batch=streaming_config.num_unique_prompts_rollout, + benchmark_config=benchmark_config, + ) + if benchmark_config.run_all_instances: + logger.info( + "Resolved run_all_instances=True to %s total batch(es) for %s dataset prompt(s) at %s prompt(s) per batch.", + resolved_num_batches, + len(dataset), + streaming_config.num_unique_prompts_rollout, + ) + else: + logger.info("Using configured num_batches=%s.", 
resolved_num_batches) max_model_len = streaming_config.max_prompt_token_length + streaming_config.response_length vllm_engines, param_prompt_Q, inference_results_Q, actor_manager = setup_vllm_engines( args, streaming_config, vllm_config, tokenizer_config, model_config, max_model_len, dataset @@ -865,7 +912,9 @@ def main() -> None: # Create the timestamp here so we use it for both filenames. timestamp = int(time.time()) args.run_name = resolve_run_name(args, timestamp) - save_config(args, tokenizer_config, model_config, streaming_config, timestamp) + save_config( + args, tokenizer_config, model_config, streaming_config, benchmark_config, resolved_num_batches, timestamp + ) if streaming_config.save_traces: save_rollout_metadata(streaming_config.rollouts_save_path, args.run_name, model_config.model_name_or_path) run_benchmark( @@ -880,6 +929,7 @@ def main() -> None: model_config, args.run_name, timestamp, + num_batches=resolved_num_batches, ) cleanup(vllm_engines, actor_manager) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 374b54e714..adee6fa57d 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -37,10 +37,11 @@ uv run open_instruct/benchmark_generators.py \ --model_name_or_path "Qwen/Qwen3-4B-Base" \ --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ --chat_template_name qwen_instruct_user_boxed_math \ - --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered \ + --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 1.0 \ --dataset_mixer_list_splits train \ --num_unique_prompts_rollout 8 \ - --num_samples_per_prompt_rollout 16 \ + --num_unique_prompts_rollout 64 \ + --vllm_num_engines 8 \ --max_prompt_token_length 2048 \ --response_length 8192 \ --pack_length 10240 \ @@ -52,4 +53,5 @@ uv run open_instruct/benchmark_generators.py \ --verification_reward 10.0 \ --save_traces 
\ --rollouts_save_path "${TRACE_DIR}" \ + --run_all_instances \ --seed 1 "$@" From df7c586f491df3995cb44cbb5f894c21cda1bbce Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:23:19 -0700 Subject: [PATCH 13/40] Duplicate argument --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index adee6fa57d..d7a4737db6 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -39,7 +39,6 @@ uv run open_instruct/benchmark_generators.py \ --chat_template_name qwen_instruct_user_boxed_math \ --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 1.0 \ --dataset_mixer_list_splits train \ - --num_unique_prompts_rollout 8 \ --num_unique_prompts_rollout 64 \ --vllm_num_engines 8 \ --max_prompt_token_length 2048 \ From 1d3c168ca348410d8e79249cb8e993d38340dbf4 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:28:38 -0700 Subject: [PATCH 14/40] Ugh --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index d7a4737db6..abc99a1dbf 100755 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -41,11 +41,10 @@ uv run open_instruct/benchmark_generators.py \ --dataset_mixer_list_splits train \ --num_unique_prompts_rollout 64 \ --vllm_num_engines 8 \ + --vllm_tensor_parallel_size 1 \ --max_prompt_token_length 2048 \ --response_length 8192 \ --pack_length 10240 \ - --vllm_num_engines 1 \ - --vllm_tensor_parallel_size 1 \ --vllm_top_p 1.0 \ --temperature 1.0 \ --apply_verifiable_reward true \ From 
a9b1ffe913ad794e143d47f3748f71d41a195749 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:49:05 -0700 Subject: [PATCH 15/40] Need 16 samples --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 1 + 1 file changed, 1 insertion(+) mode change 100755 => 100644 scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh old mode 100755 new mode 100644 index abc99a1dbf..4f7e86417f --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -40,6 +40,7 @@ uv run open_instruct/benchmark_generators.py \ --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 1.0 \ --dataset_mixer_list_splits train \ --num_unique_prompts_rollout 64 \ + --num_samples_per_prompt_rollout 16 \ --vllm_num_engines 8 \ --vllm_tensor_parallel_size 1 \ --max_prompt_token_length 2048 \ From 2caeccacf75445b43726926735c65bb092061b96 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 11:57:37 -0700 Subject: [PATCH 16/40] Configurable output format --- open_instruct/benchmark_generators.py | 15 ++++++- open_instruct/data_loader.py | 3 ++ open_instruct/rl_utils.py | 39 ++++++++++++++++++- .../create_bucketed_difficulty.py | 4 +- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index a10f463a7a..83506ac9b3 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -216,10 +216,20 @@ def maybe_save_scored_rollout_traces( raise ValueError("Cannot save scored rollout traces because the result is missing its dataset index.") example = dataset[result.index] + prompt_tokens = ( + list(example[dataset_transformation.INPUT_IDS_PROMPT_KEY]) + if streaming_config.rollout_save_format == "full" + else [] + ) + ground_truth = ( + 
example[dataset_transformation.GROUND_TRUTHS_KEY] + if streaming_config.rollout_save_format == "full" + else None + ) batch, advantages = build_rollout_batch_and_advantages( result, - prompt_tokens=list(example[dataset_transformation.INPUT_IDS_PROMPT_KEY]), - ground_truth=example[dataset_transformation.GROUND_TRUTHS_KEY], + prompt_tokens=prompt_tokens, + ground_truth=ground_truth, dataset_name=example[dataset_transformation.VERIFIER_SOURCE_KEY], raw_query=example[dataset_transformation.RAW_PROMPT_KEY], advantage_normalization_type=streaming_config.advantage_normalization_type, @@ -235,6 +245,7 @@ def maybe_save_scored_rollout_traces( advantages, len(result.responses), total_samples_written, + record_format=streaming_config.rollout_save_format, ) total_samples_written += len(result.responses) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index a4ee28e83a..dd3cf85045 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -492,6 +492,8 @@ class StreamingDataLoaderConfig: # Rollout saving save_traces: bool = False rollouts_save_path: str = "/weka/oe-adapt-default/allennlp/deletable_rollouts/" + rollout_save_format: Literal["full", "scores_only"] = "full" + """Trace record shape to persist when save_traces is enabled.""" # Computed at post_init max_possible_score: float = 1.0 @@ -1330,6 +1332,7 @@ def _data_preparation_loop(self): advantages, self.config.num_samples_per_prompt_rollout, self.total_samples_written, + record_format=self.config.rollout_save_format, ) self.total_samples_written += len(batch.queries) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 3a5a9e7fba..1c4a7fc06b 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -5,7 +5,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass, field -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Literal, TypeVar import numpy as np import 
torch @@ -17,6 +17,7 @@ _rollout_executor = ThreadPoolExecutor(max_workers=2) ROLLOUT_SHARD_SIZE = 10000 +RolloutSaveFormat = Literal["full", "scores_only"] @dataclass @@ -46,6 +47,14 @@ class RolloutRecord: logprobs: list[float] | None = None +@dataclass +class RolloutScoreRecord: + dataset: str + reward: float + source_row_id: int | None = None + source_dataset: str | None = None + + def save_rollout_metadata(save_path: str, run_name: str, model_name: str | None) -> None: """Save metadata about the rollout collection to disk. @@ -98,6 +107,7 @@ def _save_rollouts( advantages: np.ndarray, num_samples_per_prompt: int, shard_idx: int, + record_format: RolloutSaveFormat, ) -> None: shard_filename = f"{run_name}_rollouts_{shard_idx:06d}.jsonl" filepath = os.path.join(save_path, shard_filename) @@ -107,6 +117,19 @@ def _save_rollouts( records = [] for i in range(len(batch.queries)): + if record_format == "scores_only": + records.append( + asdict( + RolloutScoreRecord( + dataset=batch.datasets[i], + reward=float(batch.scores[i]), + source_row_id=batch.source_row_ids[i] if batch.source_row_ids is not None else None, + source_dataset=batch.source_datasets[i] if batch.source_datasets is not None else None, + ) + ) + ) + continue + records.append( asdict( RolloutRecord( @@ -143,6 +166,7 @@ def save_rollouts_to_disk( advantages: np.ndarray, num_samples_per_prompt: int, total_samples_written: int, + record_format: RolloutSaveFormat = "full", ) -> None: """Asynchronously save rollout records to disk. @@ -158,10 +182,21 @@ def save_rollouts_to_disk( advantages: Calculated advantage values per sample. num_samples_per_prompt: Number of samples generated per prompt. total_samples_written: Total samples written so far, used for sharding. + record_format: Output schema to persist. Use "scores_only" for the + minimum fields needed by the difficulty-map builder. 
""" shard_idx = total_samples_written // ROLLOUT_SHARD_SIZE _rollout_executor.submit( - _save_rollouts, save_path, run_name, step, batch, result, advantages, num_samples_per_prompt, shard_idx + _save_rollouts, + save_path, + run_name, + step, + batch, + result, + advantages, + num_samples_per_prompt, + shard_idx, + record_format, ) diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index 39f2908681..8194f2cab1 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -15,7 +15,7 @@ files, or rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``. For each traced prompt instance it: -1. loads rollout shards written by ``save_rollouts_to_disk()``, +1. loads rollout shards written by ``save_rollouts_to_disk()``, including compact score-only shards, 2. groups attempts by source dataset identity when available, otherwise by a deterministic fingerprint over task name, prompt tokens, and ground truth, 3. 
normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` @@ -434,7 +434,7 @@ def build_rollout_contribution( prompt_tokens = normalize_token_list(record.get("prompt_tokens")) if prompt_tokens is None and (source_dataset is None or source_dataset_id is None): - raise ValueError("missing or invalid prompt_tokens") + raise ValueError("missing prompt_tokens and source dataset identity (source_dataset/source_row_id)") reward = extract_numeric_reward(record.get("reward")) if reward is None: From d45c6614a210bc4e330fb60bf674e5ec01fdf4db Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Sun, 3 May 2026 12:02:00 -0700 Subject: [PATCH 17/40] Only save scores --- scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 4f7e86417f..0287d057a5 100644 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -51,6 +51,7 @@ uv run open_instruct/benchmark_generators.py \ --apply_verifiable_reward true \ --verification_reward 10.0 \ --save_traces \ + --rollout_save_format scores_only \ --rollouts_save_path "${TRACE_DIR}" \ --run_all_instances \ --seed 1 "$@" From 2f415412a7101eb043872f7619cf1f09a24dea10 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 4 May 2026 10:58:04 -0700 Subject: [PATCH 18/40] Support HF datasets as well --- open_instruct/environments/tools/parsers.py | 2 +- .../create_bucketed_difficulty.py | 531 +++++++++++++++--- .../qwen3_4b_dapo_math_gen.sh | 1 + tests/test_create_bucketed_difficulty.py | 176 ++++++ 4 files changed, 633 insertions(+), 77 deletions(-) diff --git a/open_instruct/environments/tools/parsers.py b/open_instruct/environments/tools/parsers.py index a8f04486b4..e9f81fbb2b 100644 --- a/open_instruct/environments/tools/parsers.py +++ b/open_instruct/environments/tools/parsers.py 
@@ -154,7 +154,7 @@ def _make_request(self) -> Any: Usually these only need the list of tools. """ - return ChatCompletionRequest(model="dummy", messages=[], tools=self._tool_definitions) # ty: ignore[invalid-argument-type] + return ChatCompletionRequest(model="dummy", messages=[], tools=self._tool_definitions) def get_tool_calls(self, text: str) -> list[EnvCall]: """Extract tool calls from model output. diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index 8194f2cab1..f2c05d15d9 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -9,13 +9,16 @@ # /// """ -Build a per-instance difficulty map from open-instruct rollout traces. +Build a per-instance difficulty map from open-instruct rollout traces or +Hugging Face datasets with pass-rate aggregates. The script accepts one or more local rollout directories, metadata ``.jsonl`` -files, or rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``. -For each traced prompt instance it: +files, rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``, +or a Hugging Face dataset that already contains per-row pass counts. For each +prompt instance it: 1. loads rollout shards written by ``save_rollouts_to_disk()``, including compact score-only shards, + or loads per-row pass counts from a Hub dataset, 2. groups attempts by source dataset identity when available, otherwise by a deterministic fingerprint over task name, prompt tokens, and ground truth, 3. 
normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` @@ -33,6 +36,11 @@ uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ --source /tmp/qwen_math_rollouts/qwen_math_metadata.jsonl \ --output /tmp/difficulty_map + + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --output /tmp/dapo_math_qwen3_difficulty """ from __future__ import annotations @@ -48,7 +56,7 @@ from typing import Any import numpy as np -from datasets import Dataset +from datasets import Dataset, load_dataset from scipy.optimize import minimize from scipy.special import betaln from scipy.stats import beta as beta_distribution @@ -72,10 +80,16 @@ DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} -SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" -INSTANCE_ID_DEFINITION = ( +ROLLOUT_SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" +HF_SOURCE_FORMAT_KIND = "hugging_face_dataset_passrate_rows" +ROLLOUT_INSTANCE_ID_DEFINITION = ( "source_dataset::source_dataset_id when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" ) +HF_INSTANCE_ID_DEFINITION = ( + "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index" +) +HF_SOURCE_ROW_INDEX_FIELD = "_source_row_index" +HF_OUTPUT_COLUMNS = ("difficulty",) @dataclass(frozen=True) @@ -101,17 +115,66 @@ class DifficultyPosteriorRow: difficulty_beta: float +@dataclass(frozen=True) +class InputRowsBundle: + rows: list[dict[str, Any]] + malformed_records: int + source_format: dict[str, Any] + source_dataset: Dataset | None = None + + def make_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( - description="Build a per-instance 
difficulty map from open-instruct rollout traces.", + description="Build a per-instance difficulty map from open-instruct rollout traces or HF pass-rate datasets.", formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( + source_group = parser.add_mutually_exclusive_group(required=True) + source_group.add_argument( "--source", nargs="+", - required=True, help="One or more local rollout dirs, *_metadata.jsonl files, or *_rollouts_*.jsonl shards.", ) + source_group.add_argument( + "--hf-dataset", + type=str, + default=None, + help="Hugging Face dataset repo id containing per-row pass-rate aggregates.", + ) + parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.") + parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.") + parser.add_argument( + "--hf-row-id-field", + type=str, + default="extra_info.index", + help="Dot-path to the stable per-row id field inside --hf-dataset.", + ) + parser.add_argument( + "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in --hf-dataset." 
+ ) + parser.add_argument( + "--hf-model-field", + type=str, + default="generator_model", + help="Dot-path to the generator model field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-count-field", + type=str, + default="pass_count", + help="Dot-path to the integer pass-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-attempt-count-field", + type=str, + default="num_samples", + help="Dot-path to the total-attempt-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-rate-field", + type=str, + default="pass_rate", + help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.", + ) parser.add_argument( "--task", action="append", @@ -175,33 +238,13 @@ def main(argv: list[str] | None = None) -> None: task_filters = set(args.task) output_root = resolve_output_root(args.output) - source_runs = discover_rollout_sources(args.source) - if not source_runs: - raise ValueError("No rollout trace sources were found.") - - contributions: list[dict[str, Any]] = [] - malformed_records = 0 + input_rows = load_input_rows(args, task_filters=task_filters) - for source_run in source_runs: - logger.info( - "Loading %s (run=%s, metadata=%s, shards=%s)", - source_run.input_arg, - source_run.run_name, - source_run.metadata_path, - len(source_run.rollout_paths), - ) - run_contributions, run_malformed = build_contributions_for_source( - source_run=source_run, task_filters=task_filters, strict=args.strict - ) - contributions.extend(run_contributions) - malformed_records += run_malformed - - if not contributions: + if not input_rows.rows: raise ValueError("No resolved per-instance rows were produced.") - rows = aggregate_contributions(contributions) rows = sorted( - rows, + input_rows.rows, key=lambda row: ( stable_string(row.get("task_name")), stable_string((row.get("experiment_metadata") or {}).get("model_name")), @@ -226,6 +269,16 @@ def main(argv: list[str] | None = None) -> None: group_rows, 
score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( group_rows, allow_nonunit_scores=args.allow_nonunit_scores ) + if input_rows.source_format["kind"] == HF_SOURCE_FORMAT_KIND: + score_processing["source_field"] = ",".join( + field_name + for field_name in ( + input_rows.source_format.get("pass_count_field"), + input_rows.source_format.get("attempt_count_field"), + input_rows.source_format.get("pass_rate_field"), + ) + if field_name + ) skipped_nonunit += group_skipped_nonunit if not group_rows: @@ -238,8 +291,14 @@ def main(argv: list[str] | None = None) -> None: group_rows = apply_beta_binomial_difficulty( group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets ) - group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) - output_rows = strip_output_only_rollout_fields(group_rows) + if input_rows.source_dataset is None: + ordered_group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) + output_rows = strip_output_only_rollout_fields(ordered_group_rows) + dataset = Dataset.from_list(output_rows) + else: + ordered_group_rows = sort_hf_group_rows(group_rows) + output_rows = strip_internal_fields(ordered_group_rows) + dataset = build_hf_output_dataset(input_rows.source_dataset, ordered_group_rows) dataset_metadata = build_dataset_metadata( rows=output_rows, @@ -251,6 +310,7 @@ def main(argv: list[str] | None = None) -> None: prior=prior, binary_row_count=binary_row_count, score_processing=score_processing, + source_format=input_rows.source_format, ) if prior is not None: @@ -270,7 +330,6 @@ def main(argv: list[str] | None = None) -> None: model_name, ) - dataset = Dataset.from_list(output_rows) annotate_dataset_metadata(dataset, dataset_metadata) output_jsonl, schema_json, metadata_json = build_output_paths( output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata @@ -279,7 +338,6 @@ def main(argv: list[str] | None = None) -> None: 
output_jsonl=output_jsonl, schema_json=schema_json, metadata_json=metadata_json, - rows=output_rows, dataset=dataset, dataset_metadata=dataset_metadata, ) @@ -301,11 +359,364 @@ def main(argv: list[str] | None = None) -> None: logger.info( "Finished writing %s output file groups (%s malformed rollout records, %s skipped due to unsupported scores).", len(written_outputs), - malformed_records, + input_rows.malformed_records, skipped_nonunit, ) +def load_input_rows(args: argparse.Namespace, *, task_filters: set[str]) -> InputRowsBundle: + if args.hf_dataset is not None: + return load_hf_dataset_rows( + dataset_name=args.hf_dataset, + config_name=args.hf_config, + split=args.hf_split, + task_filters=task_filters, + strict=args.strict, + row_id_field=args.hf_row_id_field, + task_field=args.hf_task_field, + model_field=args.hf_model_field, + pass_count_field=args.hf_pass_count_field, + attempt_count_field=args.hf_attempt_count_field, + pass_rate_field=args.hf_pass_rate_field, + ) + + if not args.source: + raise ValueError("Expected --source when --hf-dataset is not provided.") + + source_runs = discover_rollout_sources(args.source) + if not source_runs: + raise ValueError("No rollout trace sources were found.") + + contributions: list[dict[str, Any]] = [] + malformed_records = 0 + + for source_run in source_runs: + logger.info( + "Loading %s (run=%s, metadata=%s, shards=%s)", + source_run.input_arg, + source_run.run_name, + source_run.metadata_path, + len(source_run.rollout_paths), + ) + run_contributions, run_malformed = build_contributions_for_source( + source_run=source_run, task_filters=task_filters, strict=args.strict + ) + contributions.extend(run_contributions) + malformed_records += run_malformed + + return InputRowsBundle( + rows=aggregate_contributions(contributions), + malformed_records=malformed_records, + source_format=build_rollout_source_format_metadata(), + ) + + +def load_hf_dataset_rows( + *, + dataset_name: str, + config_name: str | None, + split: str, 
+ task_filters: set[str], + strict: bool, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> InputRowsBundle: + logger.info( + "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split + ) + + if config_name: + source_dataset = load_dataset(dataset_name, config_name, split=split) + else: + source_dataset = load_dataset(dataset_name, split=split) + + rows: list[dict[str, Any]] = [] + malformed_records = 0 + + for row_index, source_row in enumerate(source_dataset): + try: + row = build_hf_dataset_row( + source_row=source_row, + source_row_index=row_index, + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + except Exception as exc: + malformed_records += 1 + message = f"Malformed HF dataset row {dataset_name}[{split}][{row_index}]: {exc}" + if strict: + raise ValueError(message) from exc + logger.warning(message) + continue + + task_name = stable_string(row.get("task_name")) + if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: + continue + rows.append(row) + + return InputRowsBundle( + rows=rows, + malformed_records=malformed_records, + source_format=build_hf_source_format_metadata( + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ), + source_dataset=source_dataset, + ) + + +def build_hf_dataset_row( + *, + source_row: dict[str, Any], + source_row_index: int, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, 
+ task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> dict[str, Any]: + task_name = normalize_task_name(get_nested_field(source_row, task_field)) + if task_name is None: + raise ValueError(f"missing task field {task_field!r}") + + source_row_id = normalize_identifier(get_nested_field(source_row, row_id_field)) or str(source_row_index) + pass_count, attempt_count = extract_hf_attempt_summary( + row=source_row, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + model_name = optional_string(get_nested_field(source_row, model_field)) + + return { + HF_SOURCE_ROW_INDEX_FIELD: source_row_index, + "instance_id": make_hf_instance_id(dataset_name=dataset_name, source_row_id=source_row_id), + "task_name": task_name, + "base_task_name": get_base_task_name(task_name), + "source_dataset": dataset_name, + "source_row_id": source_row_id, + "attempt_scores": expand_binary_attempt_scores(pass_count=pass_count, attempt_count=attempt_count), + "finish_reasons": [], + "experiment_metadata": { + "source_root": format_hf_source_locator(dataset_name=dataset_name, config_name=config_name, split=split), + "model_name": model_name, + "experiment_id": None, + "experiment_name": dataset_name, + }, + "score_sources": [task_name], + "warnings": [], + } + + +def build_rollout_source_format_metadata() -> dict[str, Any]: + return { + "kind": ROLLOUT_SOURCE_FORMAT_KIND, + "task_field": "dataset", + "score_field": "reward", + "source_dataset_field": "source_dataset", + "source_dataset_id_field": "source_dataset_id", + "source_row_id_field": "source_row_id", + "instance_id_definition": ROLLOUT_INSTANCE_ID_DEFINITION, + } + + +def build_hf_source_format_metadata( + *, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: 
str | None, +) -> dict[str, Any]: + return { + "kind": HF_SOURCE_FORMAT_KIND, + "dataset_repo_id": dataset_name, + "config_name": config_name, + "split": split, + "row_id_field": row_id_field, + "task_field": task_field, + "model_field": model_field, + "pass_count_field": pass_count_field, + "attempt_count_field": attempt_count_field, + "pass_rate_field": pass_rate_field, + "instance_id_definition": HF_INSTANCE_ID_DEFINITION, + } + + +def format_hf_source_locator(*, dataset_name: str, config_name: str | None, split: str) -> str: + config_token = config_name or "default" + return f"hf://{dataset_name}/{config_token}/{split}" + + +def make_hf_instance_id(*, dataset_name: str, source_row_id: str) -> str: + return f"{dataset_name}::{source_row_id}" + + +def sort_hf_group_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + return sorted(rows, key=lambda row: row[HF_SOURCE_ROW_INDEX_FIELD]) + + +def build_hf_output_dataset(source_dataset: Dataset, rows: list[dict[str, Any]]) -> Dataset: + ordered_rows = sort_hf_group_rows(rows) + dataset = source_dataset.select([row[HF_SOURCE_ROW_INDEX_FIELD] for row in ordered_rows]) + + for column_name in HF_OUTPUT_COLUMNS: + values = [make_jsonable(row.get(column_name)) for row in ordered_rows] + if column_name in dataset.column_names: + dataset = dataset.remove_columns(column_name) + dataset = dataset.add_column(column_name, values) + + return dataset + + +def strip_internal_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [{key: value for key, value in row.items() if key != HF_SOURCE_ROW_INDEX_FIELD} for row in rows] + + +def get_nested_field(value: Any, field_path: str) -> Any: + if not field_path: + return value + + current = value + for field_name in field_path.split("."): + if not isinstance(current, dict) or field_name not in current: + return None + current = current[field_name] + return current + + +def normalize_identifier(value: Any) -> str | None: + if value is None or isinstance(value, bool): + 
return None + text = stable_string(value).strip() + return text or None + + +def normalize_nonnegative_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value if value >= 0 else None + if isinstance(value, float): + if not math.isfinite(value) or not value.is_integer() or value < 0: + return None + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + parsed = int(stripped) + except ValueError: + return None + return parsed if parsed >= 0 else None + return None + + +def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: + if value is None: + return None, None, None + if is_number(value): + rate = float(value) + if 0.0 <= rate <= 1.0: + return None, None, rate + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + if not isinstance(value, str): + raise ValueError(f"unsupported pass-rate value {value!r}") + + stripped = value.strip() + if not stripped: + return None, None, None + + if "/" in stripped: + numerator_text, denominator_text = stripped.split("/", 1) + numerator = normalize_nonnegative_int(numerator_text) + denominator = normalize_nonnegative_int(denominator_text) + if numerator is None or denominator is None or numerator > denominator: + raise ValueError(f"invalid pass-rate fraction {value!r}") + rate = 0.0 if denominator == 0 else numerator / denominator + return numerator, denominator, rate + + try: + rate = float(stripped) + except ValueError as exc: + raise ValueError(f"invalid pass-rate value {value!r}") from exc + if not math.isfinite(rate) or rate < 0.0 or rate > 1.0: + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + return None, None, rate + + +def extract_hf_attempt_summary( + *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None +) -> tuple[int, int]: + pass_count = 
normalize_nonnegative_int(get_nested_field(row, pass_count_field)) + attempt_count = normalize_nonnegative_int(get_nested_field(row, attempt_count_field)) + + parsed_pass_count = None + parsed_attempt_count = None + parsed_pass_rate = None + if pass_rate_field: + parsed_pass_count, parsed_attempt_count, parsed_pass_rate = parse_pass_rate_value( + get_nested_field(row, pass_rate_field) + ) + + if pass_count is None and parsed_pass_count is not None: + pass_count = parsed_pass_count + if attempt_count is None and parsed_attempt_count is not None: + attempt_count = parsed_attempt_count + + if pass_count is None or attempt_count is None: + raise ValueError( + f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}" + f"{f' or parseable {pass_rate_field!r}' if pass_rate_field else ''}" + ) + if attempt_count <= 0: + raise ValueError(f"attempt count must be positive, received {attempt_count}") + if pass_count > attempt_count: + raise ValueError(f"pass count {pass_count} exceeds attempt count {attempt_count}") + + if parsed_pass_count is not None and parsed_pass_count != pass_count: + raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_attempt_count is not None and parsed_attempt_count != attempt_count: + raise ValueError(f"attempt-count field {attempt_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_pass_rate is not None and not is_close(pass_count / attempt_count, parsed_pass_rate): + raise ValueError( + f"pass-count fields {pass_count_field!r}/{attempt_count_field!r} disagree with {pass_rate_field!r}" + ) + + return pass_count, attempt_count + + +def expand_binary_attempt_scores(*, pass_count: int, attempt_count: int) -> list[float]: + return [1.0] * pass_count + [0.0] * (attempt_count - pass_count) + + def discover_rollout_sources(sources: list[str]) -> list[RolloutSource]: discovered: dict[Path, RolloutSource] = {} @@ -501,23 +912,7 @@ def extract_source_dataset_id(record: 
dict[str, Any]) -> int | None: def normalize_source_dataset_id(value: Any) -> int | None: - if value is None or isinstance(value, bool): - return None - if isinstance(value, int): - return value - if isinstance(value, float): - if not math.isfinite(value) or not value.is_integer(): - return None - return int(value) - if isinstance(value, str): - stripped = value.strip() - if not stripped: - return None - try: - return int(stripped) - except ValueError: - return None - return None + return normalize_nonnegative_int(value) def normalize_token_list(value: Any) -> list[int] | None: @@ -589,10 +984,7 @@ def aggregate_contributions(contributions: list[dict[str, Any]]) -> list[dict[st def strip_output_only_rollout_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [ - {key: value for key, value in row.items() if key not in {"prompt_tokens", "ground_truth"}} - for row in rows - ] + return [{key: value for key, value in row.items() if key not in {"prompt_tokens", "ground_truth"}} for row in rows] def normalize_attempt_scores_for_group( @@ -902,18 +1294,12 @@ def build_output_paths( def write_output_files( - *, - output_jsonl: Path, - schema_json: Path, - metadata_json: Path, - rows: list[dict[str, Any]], - dataset: Dataset, - dataset_metadata: dict[str, Any], + *, output_jsonl: Path, schema_json: Path, metadata_json: Path, dataset: Dataset, dataset_metadata: dict[str, Any] ) -> None: output_jsonl.parent.mkdir(parents=True, exist_ok=True) with output_jsonl.open("w") as output_file: - for row in rows: - output_file.write(json.dumps(row, ensure_ascii=False) + "\n") + for row in dataset: + output_file.write(json.dumps(make_jsonable(row), ensure_ascii=False) + "\n") schema_json.parent.mkdir(parents=True, exist_ok=True) try: @@ -939,6 +1325,7 @@ def build_dataset_metadata( prior: BetaPrior | None, binary_row_count: int, score_processing: dict[str, Any], + source_format: dict[str, Any], ) -> dict[str, Any]: effective_bucket_count = 
extract_effective_bucket_count(rows) difficulty_generation = { @@ -965,15 +1352,7 @@ def build_dataset_metadata( "task_name": task_name, "model_name": model_name, "row_count": len(rows), - "source_format": { - "kind": SOURCE_FORMAT_KIND, - "task_field": "dataset", - "score_field": "reward", - "source_dataset_field": "source_dataset", - "source_dataset_id_field": "source_dataset_id", - "source_row_id_field": "source_row_id", - "instance_id_definition": INSTANCE_ID_DEFINITION, - }, + "source_format": dict(source_format), "score_processing": dict(score_processing), "difficulty_generation": difficulty_generation, } diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh index 0287d057a5..2662fd816c 100644 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh @@ -51,6 +51,7 @@ uv run open_instruct/benchmark_generators.py \ --apply_verifiable_reward true \ --verification_reward 10.0 \ --save_traces \ + --vllm_enable_prefix_caching \ --rollout_save_format scores_only \ --rollouts_save_path "${TRACE_DIR}" \ --run_all_instances \ diff --git a/tests/test_create_bucketed_difficulty.py b/tests/test_create_bucketed_difficulty.py index 1bc6817b11..733b6b4e92 100644 --- a/tests/test_create_bucketed_difficulty.py +++ b/tests/test_create_bucketed_difficulty.py @@ -24,6 +24,7 @@ def _load_create_bucketed_difficulty_module(): fake_datasets = types.ModuleType("datasets") fake_datasets.Dataset = type("Dataset", (), {}) + fake_datasets.load_dataset = lambda *_args, **_kwargs: None fake_scipy = types.ModuleType("scipy") fake_scipy_optimize = types.ModuleType("scipy.optimize") @@ -101,6 +102,37 @@ def ppf(cls, q, alpha, beta): class TestCreateBucketedDifficulty(unittest.TestCase): + class FakeHFDataset: + def __init__(self, rows): + self._rows = [dict(row) for row in rows] + + def __getitem__(self, index): + return self._rows[index] + + 
def __iter__(self): + return iter(self._rows) + + def __len__(self): + return len(self._rows) + + @property + def column_names(self): + return list(self._rows[0].keys()) if self._rows else [] + + def select(self, indices): + return TestCreateBucketedDifficulty.FakeHFDataset([self._rows[index] for index in indices]) + + def remove_columns(self, column_names): + names = {column_names} if isinstance(column_names, str) else set(column_names) + return TestCreateBucketedDifficulty.FakeHFDataset( + [{key: value for key, value in row.items() if key not in names} for row in self._rows] + ) + + def add_column(self, name, values): + return TestCreateBucketedDifficulty.FakeHFDataset( + [{**row, name: value} for row, value in zip(self._rows, values, strict=True)] + ) + def test_discover_rollout_sources_resolves_directory_runs(self): with tempfile.TemporaryDirectory() as tmpdir: root = Path(tmpdir) @@ -277,6 +309,7 @@ def test_build_dataset_metadata_captures_difficulty_generation_details(self): "positive_reward_value": 10.0, "supports_binary_difficulty": True, }, + source_format=MODULE.build_rollout_source_format_metadata(), ) self.assertEqual(metadata["task_name"], "math") @@ -295,6 +328,149 @@ def test_build_dataset_metadata_captures_difficulty_generation_details(self): self.assertEqual(metadata["difficulty_generation"]["binary_instance_count"], 2) self.assertEqual(metadata["difficulty_generation"]["nonbinary_instance_count"], 1) + def test_build_hf_dataset_row_parses_pass_rate_counts(self): + row = MODULE.build_hf_dataset_row( + source_row={ + "dataset": "math", + "extra_info": {"index": "row-7"}, + "pass_count": 3, + "num_samples": 5, + "pass_rate": "3/5", + "generator_model": "Qwen/Qwen3-4B-Base", + }, + source_row_index=7, + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", + 
pass_rate_field="pass_rate", + ) + + self.assertEqual(row["instance_id"], "mnoukhov/demo::row-7") + self.assertEqual(row["source_row_id"], "row-7") + self.assertEqual(row["attempt_scores"], [1.0, 1.0, 1.0, 0.0, 0.0]) + self.assertEqual(row["experiment_metadata"]["model_name"], "Qwen/Qwen3-4B-Base") + self.assertEqual(row["experiment_metadata"]["source_root"], "hf://mnoukhov/demo/default/train") + + def test_load_hf_dataset_rows_builds_bundle_and_filters_tasks(self): + fake_dataset = self.FakeHFDataset( + [ + { + "dataset": "math", + "extra_info": {"index": "math-1"}, + "pass_count": 2, + "num_samples": 4, + "pass_rate": "2/4", + "generator_model": "Qwen/Qwen3-4B-Base", + }, + { + "dataset": "gsm8k", + "extra_info": {"index": "gsm-1"}, + "pass_count": 1, + "num_samples": 4, + "pass_rate": "1/4", + "generator_model": "Qwen/Qwen3-4B-Base", + }, + ] + ) + + with patch.object(MODULE, "load_dataset", return_value=fake_dataset): + bundle = MODULE.load_hf_dataset_rows( + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + task_filters={"math"}, + strict=True, + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", + pass_rate_field="pass_rate", + ) + + self.assertEqual(bundle.malformed_records, 0) + self.assertEqual(bundle.source_format["kind"], MODULE.HF_SOURCE_FORMAT_KIND) + self.assertEqual(bundle.source_format["dataset_repo_id"], "mnoukhov/demo") + self.assertEqual(len(bundle.rows), 1) + self.assertEqual(bundle.rows[0]["instance_id"], "mnoukhov/demo::math-1") + self.assertEqual(bundle.rows[0]["attempt_scores"], [1.0, 1.0, 0.0, 0.0]) + + def test_build_hf_output_dataset_preserves_source_rows_and_order(self): + source_dataset = self.FakeHFDataset( + [ + {"prompt": "first", "extra_info": {"index": "row-0"}}, + {"prompt": "second", "extra_info": {"index": "row-1"}}, + ] + ) + rows = [ + { + MODULE.HF_SOURCE_ROW_INDEX_FIELD: 1, + "instance_id": 
"mnoukhov/demo::row-1", + "task_name": "math", + "base_task_name": "math", + "source_dataset": "mnoukhov/demo", + "source_row_id": "row-1", + "attempt_scores": [0.0, 0.0], + "finish_reasons": [], + "experiment_metadata": { + "source_root": "hf://mnoukhov/demo/default/train", + "model_name": "Qwen/Qwen3-4B-Base", + "experiment_id": None, + "experiment_name": "mnoukhov/demo", + }, + "score_sources": ["math"], + "warnings": [], + "difficulty": { + "value": 0.9, + "posterior_mean": 0.1, + "posterior_lower_bound": 0.1, + "expected_quantile": 0.9, + "bucket_index": 1, + "bucket_count": 2, + }, + }, + { + MODULE.HF_SOURCE_ROW_INDEX_FIELD: 0, + "instance_id": "mnoukhov/demo::row-0", + "task_name": "math", + "base_task_name": "math", + "source_dataset": "mnoukhov/demo", + "source_row_id": "row-0", + "attempt_scores": [1.0, 1.0], + "finish_reasons": [], + "experiment_metadata": { + "source_root": "hf://mnoukhov/demo/default/train", + "model_name": "Qwen/Qwen3-4B-Base", + "experiment_id": None, + "experiment_name": "mnoukhov/demo", + }, + "score_sources": ["math"], + "warnings": [], + "difficulty": { + "value": 0.1, + "posterior_mean": 0.9, + "posterior_lower_bound": 0.9, + "expected_quantile": 0.1, + "bucket_index": 0, + "bucket_count": 2, + }, + }, + ] + + dataset = MODULE.build_hf_output_dataset(source_dataset, rows) + + self.assertEqual(len(dataset), 2) + self.assertEqual(dataset.column_names, ["prompt", "extra_info", "difficulty"]) + self.assertEqual(dataset[0]["prompt"], "first") + self.assertEqual(dataset[0]["difficulty"]["bucket_index"], 0) + self.assertEqual(dataset[1]["prompt"], "second") + self.assertEqual(dataset[1]["difficulty"]["bucket_index"], 1) + def test_annotate_dataset_metadata_stores_json_description(self): class FakeInfo: description = "" From fb8599ef5891249aebcf11a73666f6f39fa1c320 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 4 May 2026 11:48:34 -0700 Subject: [PATCH 19/40] WIP: First pass at difficulty sampling loader --- 
docs/algorithms/grpo.md | 68 ++ open_instruct/data_loader.py | 227 +++++- open_instruct/rlvr_curriculum.py | 656 ++++++++++++++++++ open_instruct/test_rlvr_curriculum.py | 185 +++++ ...wen3_4b_dapo_math_difficulty_curriculum.sh | 113 +++ 5 files changed, 1237 insertions(+), 12 deletions(-) create mode 100644 open_instruct/rlvr_curriculum.py create mode 100644 open_instruct/test_rlvr_curriculum.py create mode 100644 scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh diff --git a/docs/algorithms/grpo.md b/docs/algorithms/grpo.md index 1206d9a490..9b991a79cf 100644 --- a/docs/algorithms/grpo.md +++ b/docs/algorithms/grpo.md @@ -77,6 +77,74 @@ Both `grpo.py` and `grpo_fast.py` share the same config classes and accept the s | | `--save_freq` | Save every N train steps | `200` | | | `--with_tracking` | Track experiment with Weights and Biases | `False` | +### Difficulty-Aware RLVR Curriculum + +Open-Instruct can optionally replace uniform prompt reshuffling with a bucket-aware RLVR curriculum driven by per-instance beta-binomial metadata: + +```json +{ + "difficulty": { + "value": 0.9999999997624719, + "posterior_mean": 0.003437858035078528, + "posterior_lower_bound": 2.3752813430506325e-10, + "expected_quantile": 0.10139684528348392, + "bucket_index": 4, + "bucket_count": 5 + } +} +``` + +- `posterior_mean` is the estimated solve probability for that prompt. Lower means harder. +- `bucket_index = 0` is the easiest bucket and `bucket_index = bucket_count - 1` is the hardest. +- The sampler uses a smooth distribution with a configurable easy-heavy bootstrap phase, then gradually shifts mass toward harder buckets instead of hard-switching between discrete phases. +- Within each bucket, examples are weighted by a blend of uncertainty (`4 * p * (1 - p)`) and hardness (`1 - p`), so borderline prompts stay attractive while already-solved prompts are naturally down-weighted. 
+- If `--difficulty_curriculum_adaptive_enabled true` is set, bucket probabilities are additionally blended with live reward / advantage statistics so buckets with useful learning signal can get more mass during training. + +Recommended starting settings for `bucket_count=5`: + +- Bootstrap (first ~100 steps by default): buckets 0 and 1 dominate so the model sees easier prompts while it settles into the chat template and task format. +- Early after bootstrap: bucket 2 highest, buckets 1 and 3 nonzero, bucket 4 low. +- Mid: buckets 2 and 3 dominate, with bucket 4 increasing. +- Late: buckets 3 and 4 dominate, while buckets 0-2 remain nonzero. + +Useful flags: + +```bash +--difficulty_curriculum_enabled true \ +--difficulty_curriculum_field difficulty \ +--difficulty_curriculum_easy_focus_steps 100 \ +--difficulty_curriculum_bootstrap_target_bucket_ratio 0.125 \ +--difficulty_curriculum_warmup_target_bucket_ratio 0.5 \ +--difficulty_curriculum_final_target_bucket_ratio 1.0 \ +--difficulty_curriculum_warmup_steps 500 \ +--difficulty_curriculum_total_steps 10000 \ +--difficulty_curriculum_min_hard_frac 0.05 \ +--difficulty_curriculum_max_hard_frac 0.50 \ +--difficulty_curriculum_bucket_sigma 0.0 \ +--difficulty_curriculum_easy_focus_sigma 0.0 \ +--difficulty_curriculum_uncertainty_weight 0.5 \ +--difficulty_curriculum_adaptive_enabled true +``` + +Tuning tips: + +- Increase `difficulty_curriculum_easy_focus_steps` to keep the easy bootstrap around longer. +- Lower `difficulty_curriculum_bootstrap_target_bucket_ratio` to bias more strongly toward the easiest buckets early. +- Lower `difficulty_curriculum_bucket_sigma` or `difficulty_curriculum_easy_focus_sigma` to concentrate probability on fewer neighboring buckets. +- Lower `difficulty_curriculum_warmup_target_bucket_ratio` if you want the post-bootstrap warmup to stay easier for longer. + +Metrics are logged through the standard GRPO tracking path. 
The most useful ones are: + +- `curriculum/progress` +- `curriculum/static_bucket_prob_*` +- `curriculum/adaptive_bucket_prob_*` +- `curriculum/bucket_prob_*` +- `curriculum/sampled_bucket_count_*` +- `curriculum/bucket_reward_mean_*` +- `curriculum/bucket_abs_advantage_mean_*` + +See `scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh` for a concrete launch example. The dataset metadata can be produced with `scripts/data/difficulty_sampling/create_bucketed_difficulty.py`. + For details on how GRPO's HSDP sharding works, see [OLMo-core Sharding and Parallelism](olmo_core_sharding.md). --- diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index dd3cf85045..26f5cac02c 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -33,7 +33,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer -from open_instruct import data_types, padding_free_collator, utils +from open_instruct import data_types, padding_free_collator, rlvr_curriculum, utils from open_instruct.data_types import EnvConfig, EnvConfigEntry from open_instruct.dataset_transformation import ( DATASET_ORIGIN_KEY, @@ -424,6 +424,30 @@ class StreamingDataLoaderConfig: mask_truncated_completions: bool = False mask_tool_use: bool = True + # Difficulty-aware prompt curriculum + difficulty_curriculum_enabled: bool = False + difficulty_curriculum_field: str = "difficulty" + difficulty_curriculum_posterior_mean_field: str = "posterior_mean" + difficulty_curriculum_bucket_index_field: str = "bucket_index" + difficulty_curriculum_bucket_count_field: str = "bucket_count" + difficulty_curriculum_easy_focus_steps: int = 100 + difficulty_curriculum_bootstrap_target_bucket_ratio: float = 0.125 + difficulty_curriculum_warmup_target_bucket_ratio: float = 0.5 + difficulty_curriculum_final_target_bucket_ratio: float = 1.0 + difficulty_curriculum_warmup_steps: int = 500 + difficulty_curriculum_total_steps: int = 10000 + difficulty_curriculum_min_hard_frac: 
float = 0.05 + difficulty_curriculum_max_hard_frac: float = 0.50 + difficulty_curriculum_bucket_sigma: float = 0.0 + difficulty_curriculum_easy_focus_sigma: float = 0.0 + difficulty_curriculum_uncertainty_weight: float = 0.5 + difficulty_curriculum_adaptive_enabled: bool = False + difficulty_curriculum_adaptive_update_every: int = 50 + difficulty_curriculum_adaptive_learning_signal_weight: float = 0.7 + difficulty_curriculum_adaptive_exploration_weight: float = 0.3 + difficulty_curriculum_adaptive_blend_weight: float = 0.5 + difficulty_curriculum_strict_metadata: bool = True + # Dataset dataset_mixer_list: list[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"]) dataset_mixer_eval_list: list[str] = field(default_factory=list) @@ -566,6 +590,33 @@ def build_dataloader( fs_local_rank=fs_local_rank, ) + def build_difficulty_curriculum_config(self, seed: int) -> rlvr_curriculum.DifficultyCurriculumConfig: + return rlvr_curriculum.DifficultyCurriculumConfig( + enabled=self.difficulty_curriculum_enabled, + difficulty_field=self.difficulty_curriculum_field, + posterior_mean_field=self.difficulty_curriculum_posterior_mean_field, + bucket_index_field=self.difficulty_curriculum_bucket_index_field, + bucket_count_field=self.difficulty_curriculum_bucket_count_field, + easy_focus_steps=self.difficulty_curriculum_easy_focus_steps, + bootstrap_target_bucket_ratio=self.difficulty_curriculum_bootstrap_target_bucket_ratio, + warmup_target_bucket_ratio=self.difficulty_curriculum_warmup_target_bucket_ratio, + final_target_bucket_ratio=self.difficulty_curriculum_final_target_bucket_ratio, + warmup_steps=self.difficulty_curriculum_warmup_steps, + total_curriculum_steps=self.difficulty_curriculum_total_steps, + min_hard_frac=self.difficulty_curriculum_min_hard_frac, + max_hard_frac=self.difficulty_curriculum_max_hard_frac, + bucket_sigma=self.difficulty_curriculum_bucket_sigma, + easy_focus_sigma=self.difficulty_curriculum_easy_focus_sigma, + 
uncertainty_weight=self.difficulty_curriculum_uncertainty_weight, + adaptive_enabled=self.difficulty_curriculum_adaptive_enabled, + adaptive_update_every=self.difficulty_curriculum_adaptive_update_every, + adaptive_learning_signal_weight=self.difficulty_curriculum_adaptive_learning_signal_weight, + adaptive_exploration_weight=self.difficulty_curriculum_adaptive_exploration_weight, + adaptive_blend_weight=self.difficulty_curriculum_adaptive_blend_weight, + seed=seed, + strict_metadata=self.difficulty_curriculum_strict_metadata, + ) + class StreamingDataLoader(data_loader.DataLoaderBase): """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.""" @@ -665,6 +716,143 @@ def single_example_collator(examples: list[dict[str, Any]]) -> dict[str, Any]: return example | {"index": torch.tensor([example["index"]])} +class DifficultyCurriculumHFDataLoader(HFDataLoader): + """Prompt loader that samples with a difficulty-aware curriculum.""" + + def __init__( + self, + dataset: Dataset, + batch_size: int, + seed: int, + dp_rank: int, + dp_world_size: int, + work_dir: str, + curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig, + automatic_reshuffle: bool = True, + collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None, + device: torch.device | None = None, + drop_last: bool = True, + fs_local_rank: int | None = None, + max_seq_length: int = 1, + ) -> None: + if batch_size != 1: + raise ValueError("DifficultyCurriculumHFDataLoader currently supports batch_size=1 only") + if dp_world_size != 1 or dp_rank != 0: + raise ValueError("DifficultyCurriculumHFDataLoader currently supports dp_world_size=1 only") + + self._sampling_step = 0 + self._curriculum_sampler = rlvr_curriculum.BetaBinomialDifficultySampler( + dataset=dataset, + num_samples=max(len(dataset), 1), + config=curriculum_config, + global_step_getter=lambda: self._sampling_step, + ) + self._curriculum_iter = None + + super().__init__( + dataset=dataset, + 
batch_size=batch_size, + seed=seed, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + work_dir=work_dir, + automatic_reshuffle=automatic_reshuffle, + collator=collator, + device=device, + drop_last=drop_last, + fs_local_rank=fs_local_rank, + max_seq_length=max_seq_length, + ) + + @property + def curriculum_sampler(self) -> rlvr_curriculum.BetaBinomialDifficultySampler: + return self._curriculum_sampler + + def set_sampling_step(self, step: int) -> None: + self._sampling_step = int(step) + + def record_curriculum_observations( + self, + dataset_indices: list[int] | np.ndarray, + rewards: list[float] | np.ndarray, + advantages: list[float] | np.ndarray | None = None, + ) -> None: + self._curriculum_sampler.record_observations(dataset_indices, rewards, advantages) + + def build_curriculum_metrics(self, prompt_dataset_indices: list[int], step: int) -> dict[str, float]: + return self._curriculum_sampler.build_metrics(prompt_dataset_indices, step) + + def state_dict(self) -> dict[str, Any]: + return { + "epoch": self._epoch, + "batches_processed": self.batches_processed, + "sampling_step": self._sampling_step, + "curriculum_sampler_state": self._curriculum_sampler.state_dict(), + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._epoch = state_dict["epoch"] + self.batches_processed = state_dict["batches_processed"] + self._sampling_step = state_dict.get("sampling_step", 0) + self._curriculum_sampler.load_state_dict(state_dict["curriculum_sampler_state"]) + self.effective_size = max(len(self._full_dataset), 1) + self.dataset = self._full_dataset + self._curriculum_iter = iter(self._curriculum_sampler) + self._current_iter = None + + def exclude_index(self, index: int) -> None: + self._curriculum_sampler.exclude_index(index) + + def _reshard(self, epoch: int) -> None: + del epoch + self._precomputed_batch_sizes = None + self._num_padding_batches = 0 + self._overflow = [] + self.effective_size = max(len(self._full_dataset), 1) + self.dataset = 
self._full_dataset + self._curriculum_iter = iter(self._curriculum_sampler) + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + start_example = self.batches_processed + if self._curriculum_iter is None: + self._curriculum_iter = iter(self._curriculum_sampler) + + for batch_offset in range(start_example, self.effective_size): + dataset_index = next(self._curriculum_iter) + example = self._full_dataset[dataset_index] + prompt_id = f"{self._epoch}_{batch_offset}_{dataset_index}" + batch = to_device(self._collator([example | {"prompt_id": prompt_id}]), self._device) + yield batch + + +def build_data_preparation_prompt_dataloader( + dataset: Dataset, seed: int, work_dir: str, config: StreamingDataLoaderConfig +) -> HFDataLoader: + if config.difficulty_curriculum_enabled: + return DifficultyCurriculumHFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + automatic_reshuffle=True, + collator=single_example_collator, + curriculum_config=config.build_difficulty_curriculum_config(seed=seed), + ) + + return HFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + automatic_reshuffle=True, + collator=single_example_collator, + ) + + def _merge_env_config(base_env_config: EnvConfig, sample_env_config: dict[str, Any] | None) -> EnvConfig: """Merge base and sample env config into canonical payload. Sample env_config overrides any base env_configs with the same name. 
@@ -721,6 +909,7 @@ def add_prompt_to_generator( ground_truth_overrides: dict[int, Any] | None = None, ) -> None: index = int(example["index"]) + prompt_id = example.get("prompt_id", f"{epoch_number}_{index}") sample_env_config = example.get(ENV_CONFIG_KEY) env_config = _merge_env_config(base_env_config, sample_env_config) @@ -732,7 +921,7 @@ def add_prompt_to_generator( prompt=example[INPUT_IDS_PROMPT_KEY], generation_config=generation_config, index=index, - prompt_id=f"{epoch_number}_{index}", + prompt_id=prompt_id, is_eval=is_eval, active_tools=example.get(TOOLS_COLUMN_KEY), env_config=env_config, @@ -1179,15 +1368,8 @@ def __init__( self.model_name = model_name self.base_env_config = base_env_config - self.iter_dataloader = HFDataLoader( - dataset=dataset, - batch_size=1, - seed=seed, - dp_rank=0, - dp_world_size=1, - work_dir=work_dir, - automatic_reshuffle=True, - collator=single_example_collator, + self.iter_dataloader = build_data_preparation_prompt_dataloader( + dataset=dataset, seed=seed, work_dir=work_dir, config=config ) self.prepared_data: dict[int, list[data_types.CollatedBatchData]] = {} @@ -1227,6 +1409,8 @@ def _data_preparation_loop(self): num_initial_prompts = self.config.async_steps * self.global_batch_size logger.info(f"[DataPreparationActor] Pushing {num_initial_prompts} initial prompts to param_prompt_Q") + if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): + self.iter_dataloader.set_sampling_step(self.training_step) for _ in range(num_initial_prompts): add_prompt_to_generator( next(self.iter_dataloader), @@ -1246,6 +1430,8 @@ def _data_preparation_loop(self): ) time.sleep(0.1) generation_idle_wait_time = time.perf_counter() - generation_idle_wait_start_time + if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): + self.iter_dataloader.set_sampling_step(step) logger.info( f"[DataPreparationActor] Step {step}: calling accumulate_inference_batches for {self.global_batch_size} prompts" @@ -1289,14 +1475,23 @@ 
def _data_preparation_loop(self): ) for _ in range(self.dp_world_size) ] + empty_metrics = {"time/generation_idle_waiting_for_trainer": generation_idle_wait_time} + if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): + empty_metrics.update(self.iter_dataloader.build_curriculum_metrics([], step)) with self.lock: self.prepared_data[step] = empty_data - self.metrics[step] = {"time/generation_idle_waiting_for_trainer": generation_idle_wait_time} + self.metrics[step] = empty_metrics self.current_prepared_step = step + self.training_step = step + 1 continue assert batch is not None assert batch_stats is not None + prompt_dataset_indices = ( + [int(index) for index in batch.indices[:: self.config.num_samples_per_prompt_rollout]] + if batch.indices is not None + else [] + ) if self.rubric_manager and batch.decoded_responses: rubric_metrics, new_overrides = self.rubric_manager.run_step( @@ -1355,6 +1550,10 @@ def _data_preparation_loop(self): assert result.logprobs is not None result.logprobs = [result.logprobs[i] for i in stop_idxes] + if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader) and batch.indices is not None: + normalized_scores = np.clip(scores / max(self.config.max_possible_score, 1e-8), 0.0, 1.0) + self.iter_dataloader.record_curriculum_observations(batch.indices, normalized_scores, advantages) + assert result.logprobs is not None packed_sequences = pack_sequences( queries=batch.queries, @@ -1440,10 +1639,14 @@ def _data_preparation_loop(self): step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time step_metrics["time/getting_response"] = result.token_statistics.generation_time + if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): + step_metrics.update(self.iter_dataloader.build_curriculum_metrics(prompt_dataset_indices, step)) + with self.lock: self.prepared_data[step] = collated_data self.metrics[step] = step_metrics self.current_prepared_step = step + 
self.training_step = step + 1 def get_data(self, rank: int, step: int) -> dict: """Called by each rank's StreamingDataLoader. Blocks until data ready.""" diff --git a/open_instruct/rlvr_curriculum.py b/open_instruct/rlvr_curriculum.py new file mode 100644 index 0000000000..832e36db9b --- /dev/null +++ b/open_instruct/rlvr_curriculum.py @@ -0,0 +1,656 @@ +"""Difficulty-aware curriculum sampling for RLVR / GRPO prompt selection.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from torch.utils.data import Sampler + +from open_instruct import logger_utils + +logger = logger_utils.setup_logger(__name__) + +_DEFAULT_POSTERIOR_MEAN = 0.5 + + +def _resolve_path(value: Any, path: str) -> Any: + current = value + for part in path.split("."): + if not isinstance(current, dict) or part not in current: + return None + current = current[part] + return current + + +def _coerce_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + return None + + +def _coerce_float(value: Any) -> float | None: + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + value = float(value) + if math.isnan(value) or math.isinf(value): + return None + return value + return None + + +def _normalize_probs(values: np.ndarray) -> np.ndarray: + total = float(values.sum()) + if total <= 0: + return np.zeros_like(values, dtype=np.float64) + return values.astype(np.float64) / total + + +@dataclass +class DifficultyCurriculumConfig: + enabled: bool = False + difficulty_field: str = "difficulty" + posterior_mean_field: str = "posterior_mean" + bucket_index_field: str = "bucket_index" + bucket_count_field: str = "bucket_count" + easy_focus_steps: int = 100 + bootstrap_target_bucket_ratio: float = 0.125 + warmup_target_bucket_ratio: float = 0.5 + 
final_target_bucket_ratio: float = 1.0 + warmup_steps: int = 500 + total_curriculum_steps: int = 10000 + min_hard_frac: float = 0.05 + max_hard_frac: float = 0.50 + bucket_sigma: float = 0.0 + easy_focus_sigma: float = 0.0 + uncertainty_weight: float = 0.5 + adaptive_enabled: bool = False + adaptive_update_every: int = 50 + adaptive_learning_signal_weight: float = 0.7 + adaptive_exploration_weight: float = 0.3 + adaptive_blend_weight: float = 0.5 + seed: int = 0 + strict_metadata: bool = True + epsilon: float = 1e-8 + + def __post_init__(self) -> None: + if self.easy_focus_steps < 0: + raise ValueError("easy_focus_steps must be >= 0") + if not 0.0 <= self.bootstrap_target_bucket_ratio <= 1.0: + raise ValueError("bootstrap_target_bucket_ratio must be in [0, 1]") + if not 0.0 <= self.warmup_target_bucket_ratio <= 1.0: + raise ValueError("warmup_target_bucket_ratio must be in [0, 1]") + if not 0.0 <= self.final_target_bucket_ratio <= 1.0: + raise ValueError("final_target_bucket_ratio must be in [0, 1]") + if self.warmup_steps < 0: + raise ValueError("warmup_steps must be >= 0") + if self.total_curriculum_steps <= 0: + raise ValueError("total_curriculum_steps must be > 0") + if not 0.0 <= self.min_hard_frac <= 1.0: + raise ValueError("min_hard_frac must be in [0, 1]") + if not 0.0 <= self.max_hard_frac <= 1.0: + raise ValueError("max_hard_frac must be in [0, 1]") + if self.min_hard_frac > self.max_hard_frac: + raise ValueError("min_hard_frac must be <= max_hard_frac") + if self.bucket_sigma < 0: + raise ValueError("bucket_sigma must be >= 0") + if self.easy_focus_sigma < 0: + raise ValueError("easy_focus_sigma must be >= 0") + if not 0.0 <= self.uncertainty_weight <= 1.0: + raise ValueError("uncertainty_weight must be in [0, 1]") + if self.adaptive_update_every <= 0: + raise ValueError("adaptive_update_every must be > 0") + if not 0.0 <= self.adaptive_learning_signal_weight <= 1.0: + raise ValueError("adaptive_learning_signal_weight must be in [0, 1]") + if not 0.0 <= 
self.adaptive_exploration_weight <= 1.0: + raise ValueError("adaptive_exploration_weight must be in [0, 1]") + if not 0.0 <= self.adaptive_blend_weight <= 1.0: + raise ValueError("adaptive_blend_weight must be in [0, 1]") + if self.epsilon <= 0: + raise ValueError("epsilon must be > 0") + + +class AdaptiveBucketStats: + """Tracks per-bucket learning signal statistics for adaptive sampling.""" + + def __init__(self, learning_signal_weight: float, exploration_weight: float, epsilon: float) -> None: + self.learning_signal_weight = learning_signal_weight + self.exploration_weight = exploration_weight + self.epsilon = epsilon + + self.total_count = 0 + self._count_by_bucket: dict[int, int] = {} + self._reward_sum_by_bucket: dict[int, float] = {} + self._reward_sq_sum_by_bucket: dict[int, float] = {} + self._abs_advantage_sum_by_bucket: dict[int, float] = {} + self._advantage_count_by_bucket: dict[int, int] = {} + + def update( + self, + bucket_indices: list[int] | np.ndarray, + rewards: list[float] | np.ndarray, + advantages: list[float] | np.ndarray | None = None, + ) -> None: + if len(bucket_indices) != len(rewards): + raise ValueError("bucket_indices and rewards must have the same length") + if advantages is not None and len(advantages) != len(rewards): + raise ValueError("advantages and rewards must have the same length") + + reward_values = [float(np.clip(value, 0.0, 1.0)) for value in rewards] + advantage_values = None if advantages is None else [abs(float(value)) for value in advantages] + + for position, bucket_index in enumerate(bucket_indices): + bucket = int(bucket_index) + reward = reward_values[position] + + self.total_count += 1 + self._count_by_bucket[bucket] = self._count_by_bucket.get(bucket, 0) + 1 + self._reward_sum_by_bucket[bucket] = self._reward_sum_by_bucket.get(bucket, 0.0) + reward + self._reward_sq_sum_by_bucket[bucket] = self._reward_sq_sum_by_bucket.get(bucket, 0.0) + reward * reward + + if advantage_values is not None: + advantage = 
advantage_values[position] + self._abs_advantage_sum_by_bucket[bucket] = ( + self._abs_advantage_sum_by_bucket.get(bucket, 0.0) + advantage + ) + self._advantage_count_by_bucket[bucket] = self._advantage_count_by_bucket.get(bucket, 0) + 1 + + def state_dict(self) -> dict[str, Any]: + return { + "total_count": self.total_count, + "count_by_bucket": self._count_by_bucket, + "reward_sum_by_bucket": self._reward_sum_by_bucket, + "reward_sq_sum_by_bucket": self._reward_sq_sum_by_bucket, + "abs_advantage_sum_by_bucket": self._abs_advantage_sum_by_bucket, + "advantage_count_by_bucket": self._advantage_count_by_bucket, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.total_count = int(state_dict.get("total_count", 0)) + self._count_by_bucket = {int(k): int(v) for k, v in state_dict.get("count_by_bucket", {}).items()} + self._reward_sum_by_bucket = {int(k): float(v) for k, v in state_dict.get("reward_sum_by_bucket", {}).items()} + self._reward_sq_sum_by_bucket = { + int(k): float(v) for k, v in state_dict.get("reward_sq_sum_by_bucket", {}).items() + } + self._abs_advantage_sum_by_bucket = { + int(k): float(v) for k, v in state_dict.get("abs_advantage_sum_by_bucket", {}).items() + } + self._advantage_count_by_bucket = { + int(k): int(v) for k, v in state_dict.get("advantage_count_by_bucket", {}).items() + } + + def get_count(self, bucket_index: int) -> int: + return self._count_by_bucket.get(bucket_index, 0) + + def get_mean_reward(self, bucket_index: int) -> float: + count = self.get_count(bucket_index) + if count == 0: + return 0.0 + return self._reward_sum_by_bucket.get(bucket_index, 0.0) / count + + def get_mean_abs_advantage(self, bucket_index: int) -> float: + count = self._advantage_count_by_bucket.get(bucket_index, 0) + if count == 0: + return 0.0 + return self._abs_advantage_sum_by_bucket.get(bucket_index, 0.0) / count + + def _get_reward_variance(self, bucket_index: int) -> float: + count = self.get_count(bucket_index) + if count == 0: 
+ return 0.0 + mean_reward = self.get_mean_reward(bucket_index) + mean_reward_sq = self._reward_sq_sum_by_bucket.get(bucket_index, 0.0) / count + return max(0.0, mean_reward_sq - mean_reward * mean_reward) + + def get_bucket_scores(self, bucket_count: int) -> np.ndarray: + scores = np.zeros(bucket_count, dtype=np.float64) + total_count = max(self.total_count, 0) + + for bucket_index in range(bucket_count): + count = self.get_count(bucket_index) + mean_reward = self.get_mean_reward(bucket_index) + mean_abs_advantage = self.get_mean_abs_advantage(bucket_index) + + if self._advantage_count_by_bucket.get(bucket_index, 0) > 0: + learning_signal = mean_abs_advantage * max(0.0, 1.0 - mean_reward) + else: + reward_variance = self._get_reward_variance(bucket_index) + non_saturation = max(0.0, 1.0 - 2.0 * abs(mean_reward - 0.5)) + learning_signal = 0.5 * reward_variance + 0.5 * non_saturation + + exploration_bonus = math.sqrt(math.log(total_count + 1.0) / (count + 1.0)) + scores[bucket_index] = ( + self.learning_signal_weight * learning_signal + + self.exploration_weight * exploration_bonus + + self.epsilon + ) + + return scores + + +@dataclass(frozen=True) +class _ParsedDifficultyMetadata: + bucket_index: int | None + bucket_count: int | None + posterior_mean: float | None + error: str | None + + +class BetaBinomialDifficultySampler(Sampler[int]): + """Bucket-aware curriculum sampler that uses beta-binomial difficulty metadata.""" + + def __init__(self, dataset, num_samples: int, config: DifficultyCurriculumConfig, global_step_getter) -> None: + if num_samples <= 0: + raise ValueError("num_samples must be > 0") + + self.dataset = dataset + self.num_samples = num_samples + self.config = config + self.global_step_getter = global_step_getter + + self._generator = torch.Generator() + self._generator.manual_seed(self.config.seed) + + self._excluded_indices: set[int] = set() + self._index_to_bucket: dict[int, int] = {} + self.bucket_count = 1 + self.metadata_fallback_count = 0 + 
+ self._base_bucket_indices: list[list[int]] = [] + self._base_bucket_weights: list[torch.Tensor] = [] + self._active_bucket_indices: list[list[int]] = [] + self._active_bucket_weights: list[torch.Tensor] = [] + + self.adaptive_stats = None + if self.config.adaptive_enabled: + self.adaptive_stats = AdaptiveBucketStats( + learning_signal_weight=self.config.adaptive_learning_signal_weight, + exploration_weight=self.config.adaptive_exploration_weight, + epsilon=self.config.epsilon, + ) + + self._cached_adaptive_probs: np.ndarray | None = None + self._last_adaptive_refresh_step = -1 + + self._build_bucket_index() + + def _parse_metadata(self, example: dict[str, Any], index: int) -> _ParsedDifficultyMetadata: + difficulty_blob = _resolve_path(example, self.config.difficulty_field) + if not isinstance(difficulty_blob, dict): + return _ParsedDifficultyMetadata( + bucket_index=None, + bucket_count=None, + posterior_mean=None, + error=f"missing '{self.config.difficulty_field}' metadata for dataset index {index}", + ) + + bucket_index = _coerce_int(_resolve_path(difficulty_blob, self.config.bucket_index_field)) + bucket_count = _coerce_int(_resolve_path(difficulty_blob, self.config.bucket_count_field)) + posterior_mean = _coerce_float(_resolve_path(difficulty_blob, self.config.posterior_mean_field)) + + if bucket_index is None or bucket_index < 0: + return _ParsedDifficultyMetadata( + bucket_index=None, + bucket_count=bucket_count, + posterior_mean=posterior_mean, + error=f"invalid bucket_index for dataset index {index}", + ) + if bucket_count is None or bucket_count <= 0: + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, + bucket_count=None, + posterior_mean=posterior_mean, + error=f"invalid bucket_count for dataset index {index}", + ) + if posterior_mean is None: + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, + bucket_count=bucket_count, + posterior_mean=None, + error=f"invalid posterior_mean for dataset index {index}", + ) + return 
_ParsedDifficultyMetadata( + bucket_index=bucket_index, bucket_count=bucket_count, posterior_mean=posterior_mean, error=None + ) + + def _build_bucket_index(self) -> None: + parsed_rows: list[_ParsedDifficultyMetadata] = [] + observed_bucket_counts: set[int] = set() + max_bucket_index = -1 + + for dataset_index in range(len(self.dataset)): + parsed = self._parse_metadata(self.dataset[dataset_index], dataset_index) + if parsed.error is not None and self.config.strict_metadata: + raise ValueError(parsed.error) + if parsed.bucket_count is not None: + observed_bucket_counts.add(parsed.bucket_count) + if parsed.bucket_index is not None: + max_bucket_index = max(max_bucket_index, parsed.bucket_index) + parsed_rows.append(parsed) + + if observed_bucket_counts: + if self.config.strict_metadata and len(observed_bucket_counts) > 1: + raise ValueError( + f"inconsistent difficulty bucket_count values found: {sorted(observed_bucket_counts)}" + ) + self.bucket_count = max(observed_bucket_counts) + elif max_bucket_index >= 0: + self.bucket_count = max_bucket_index + 1 + else: + self.bucket_count = 1 + + self._base_bucket_indices = [[] for _ in range(self.bucket_count)] + bucket_weight_lists: list[list[float]] = [[] for _ in range(self.bucket_count)] + fallback_bucket = min(self.bucket_count - 1, self.bucket_count // 2) + + for dataset_index, parsed in enumerate(parsed_rows): + if parsed.error is not None: + bucket_index = fallback_bucket + posterior_mean = _DEFAULT_POSTERIOR_MEAN + self.metadata_fallback_count += 1 + else: + assert parsed.bucket_index is not None + bucket_index = int(np.clip(parsed.bucket_index, 0, self.bucket_count - 1)) + posterior_mean = parsed.posterior_mean + + if posterior_mean is None: + posterior_mean = _DEFAULT_POSTERIOR_MEAN + posterior_mean = float(np.clip(posterior_mean, 0.0, 1.0)) + example_weight = self._compute_example_weight(posterior_mean) + + self._index_to_bucket[dataset_index] = bucket_index + 
self._base_bucket_indices[bucket_index].append(dataset_index) + bucket_weight_lists[bucket_index].append(example_weight) + + self._base_bucket_weights = [ + torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists + ] + self._active_bucket_indices = [list(indices) for indices in self._base_bucket_indices] + self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] + + if self.metadata_fallback_count > 0 and not self.config.strict_metadata: + logger.warning( + "Difficulty curriculum fell back to conservative defaults for %s/%s rows because metadata was missing " + "or invalid.", + self.metadata_fallback_count, + len(self.dataset), + ) + + def _compute_example_weight(self, posterior_mean: float) -> float: + probability = float(np.clip(posterior_mean, 0.0, 1.0)) + uncertainty = 4.0 * probability * (1.0 - probability) + hardness = 1.0 - probability + return ( + self.config.uncertainty_weight * uncertainty + + (1.0 - self.config.uncertainty_weight) * hardness + + self.config.epsilon + ) + + def __len__(self) -> int: + return self.num_samples + + @property + def bucket_to_indices(self) -> tuple[tuple[int, ...], ...]: + return tuple(tuple(indices) for indices in self._base_bucket_indices) + + def _get_current_step(self) -> int: + step = self.global_step_getter() if self.global_step_getter is not None else 0 + return max(int(step), 0) + + def get_progress(self, step: int | None = None) -> float: + if step is None: + step = self._get_current_step() + if step < self.config.warmup_steps: + return 0.0 + return min(1.0, (step - self.config.warmup_steps) / self.config.total_curriculum_steps) + + def _smooth_progress(self, step: int | None = None) -> float: + progress = self.get_progress(step) + return progress * progress * (3.0 - 2.0 * progress) + + def _get_default_bucket_sigma(self) -> float: + return max(0.85, 0.25 * max(self.bucket_count - 1, 1)) + + def _get_bucket_sigma(self, step: int | None = None) -> float: + 
sigma = self.config.bucket_sigma if self.config.bucket_sigma > 0 else self._get_default_bucket_sigma() + if step is not None and step < self.config.easy_focus_steps and self.config.easy_focus_sigma > 0: + return self.config.easy_focus_sigma + return sigma + + def _bucket_ratio_to_bucket_index(self, bucket_ratio: float) -> float: + return float(self.bucket_count - 1) * bucket_ratio + + def _get_target_bucket(self, step: int | None = None) -> float: + if step is None: + step = self._get_current_step() + + warmup_target_bucket = self._bucket_ratio_to_bucket_index(self.config.warmup_target_bucket_ratio) + final_target_bucket = self._bucket_ratio_to_bucket_index(self.config.final_target_bucket_ratio) + + if self.config.easy_focus_steps > 0 and step < self.config.easy_focus_steps: + easy_progress = min(1.0, step / self.config.easy_focus_steps) + bootstrap_target_bucket = self._bucket_ratio_to_bucket_index(self.config.bootstrap_target_bucket_ratio) + return bootstrap_target_bucket + (warmup_target_bucket - bootstrap_target_bucket) * easy_progress + + smooth_progress = self._smooth_progress(step) + return warmup_target_bucket + (final_target_bucket - warmup_target_bucket) * smooth_progress + + def _available_bucket_mask(self) -> np.ndarray: + return np.array([1.0 if indices else 0.0 for indices in self._active_bucket_indices], dtype=np.float64) + + def get_static_bucket_probs(self, step: int | None = None) -> np.ndarray: + if self.bucket_count == 1: + return np.ones(1, dtype=np.float64) + + if step is None: + step = self._get_current_step() + + smooth_progress = self._smooth_progress(step) + target_bucket = self._get_target_bucket(step) + hard_bucket_frac = ( + self.config.min_hard_frac + (self.config.max_hard_frac - self.config.min_hard_frac) * smooth_progress + ) + + bucket_ids = np.arange(self.bucket_count - 1, dtype=np.float64) + sigma = self._get_bucket_sigma(step) + gaussian_logits = np.exp(-0.5 * ((bucket_ids - target_bucket) / sigma) ** 2) + non_hard_probs = 
_normalize_probs(gaussian_logits) + + static_probs = np.zeros(self.bucket_count, dtype=np.float64) + static_probs[:-1] = (1.0 - hard_bucket_frac) * non_hard_probs + static_probs[-1] = hard_bucket_frac + + mask = self._available_bucket_mask() + static_probs *= mask + if mask.sum() == 0: + return np.ones(self.bucket_count, dtype=np.float64) / self.bucket_count + if static_probs.sum() <= 0: + return _normalize_probs(mask) + return _normalize_probs(static_probs) + + def get_adaptive_bucket_probs(self, step: int | None = None) -> np.ndarray | None: + if not self.config.adaptive_enabled or self.adaptive_stats is None or self.adaptive_stats.total_count == 0: + return None + + refresh_step = self._get_current_step() if step is None else step + if ( + self._cached_adaptive_probs is not None + and refresh_step - self._last_adaptive_refresh_step < self.config.adaptive_update_every + ): + return self._cached_adaptive_probs.copy() + + adaptive_scores = self.adaptive_stats.get_bucket_scores(self.bucket_count) + adaptive_scores *= self._available_bucket_mask() + if adaptive_scores.sum() <= 0: + return None + + self._cached_adaptive_probs = _normalize_probs(adaptive_scores) + self._last_adaptive_refresh_step = refresh_step + return self._cached_adaptive_probs.copy() + + def get_bucket_probs(self, step: int | None = None) -> np.ndarray: + static_probs = self.get_static_bucket_probs(step) + adaptive_probs = self.get_adaptive_bucket_probs(step) + if adaptive_probs is None: + return static_probs + + final_probs = ( + 1.0 - self.config.adaptive_blend_weight + ) * static_probs + self.config.adaptive_blend_weight * adaptive_probs + final_probs *= self._available_bucket_mask() + if final_probs.sum() <= 0: + return static_probs + return _normalize_probs(final_probs) + + def bucket_for_dataset_index(self, dataset_index: int) -> int: + return self._index_to_bucket[int(dataset_index)] + + def get_example_probability(self, dataset_index: int, step: int | None = None) -> float: + if 
int(dataset_index) in self._excluded_indices: + return 0.0 + bucket_index = self.bucket_for_dataset_index(dataset_index) + active_indices = self._active_bucket_indices[bucket_index] + if not active_indices: + return 0.0 + try: + local_index = active_indices.index(int(dataset_index)) + except ValueError: + return 0.0 + bucket_weight = self._active_bucket_weights[bucket_index] + weight_total = float(bucket_weight.sum().item()) + if weight_total <= 0: + return 0.0 + bucket_probs = self.get_bucket_probs(step) + return float(bucket_probs[bucket_index] * bucket_weight[local_index].item() / weight_total) + + def sample_index(self, step: int | None = None) -> int: + if self._available_bucket_mask().sum() == 0: + raise RuntimeError("All dataset examples have been excluded. Cannot continue iteration.") + bucket_probs = torch.tensor(self.get_bucket_probs(step), dtype=torch.float64) + bucket_index = int(torch.multinomial(bucket_probs, 1, generator=self._generator).item()) + + example_weights = self._active_bucket_weights[bucket_index] + if example_weights.numel() == 0: + raise RuntimeError("attempted to sample from an empty curriculum bucket") + + sampled_offset = int(torch.multinomial(example_weights, 1, generator=self._generator).item()) + return self._active_bucket_indices[bucket_index][sampled_offset] + + def __iter__(self): + for _ in range(self.num_samples): + yield self.sample_index() + + def exclude_index(self, dataset_index: int) -> None: + dataset_index = int(dataset_index) + if dataset_index in self._excluded_indices: + return + + bucket_index = self._index_to_bucket.get(dataset_index) + if bucket_index is None: + return + + active_indices = self._active_bucket_indices[bucket_index] + try: + position = active_indices.index(dataset_index) + except ValueError: + self._excluded_indices.add(dataset_index) + return + + active_indices.pop(position) + weights = self._active_bucket_weights[bucket_index] + if weights.numel() <= 1: + self._active_bucket_weights[bucket_index] 
= weights[:0].clone() + else: + self._active_bucket_weights[bucket_index] = torch.cat((weights[:position], weights[position + 1 :])) + self._excluded_indices.add(dataset_index) + + def record_observations( + self, + dataset_indices: list[int] | np.ndarray, + rewards: list[float] | np.ndarray, + advantages: list[float] | np.ndarray | None = None, + ) -> None: + if not self.config.adaptive_enabled or self.adaptive_stats is None: + return + if len(dataset_indices) == 0: + return + + bucket_indices = [self.bucket_for_dataset_index(int(dataset_index)) for dataset_index in dataset_indices] + self.adaptive_stats.update(bucket_indices, rewards, advantages) + self._cached_adaptive_probs = None + + def build_metrics(self, prompt_dataset_indices: list[int], step: int | None = None) -> dict[str, float]: + metrics: dict[str, float] = {"curriculum/progress": self.get_progress(step)} + + static_probs = self.get_static_bucket_probs(step) + for bucket_index, probability in enumerate(static_probs): + metrics[f"curriculum/static_bucket_prob_{bucket_index}"] = float(probability) + + adaptive_probs = self.get_adaptive_bucket_probs(step) + if self.config.adaptive_enabled: + if adaptive_probs is None: + adaptive_probs = np.zeros(self.bucket_count, dtype=np.float64) + for bucket_index, probability in enumerate(adaptive_probs): + metrics[f"curriculum/adaptive_bucket_prob_{bucket_index}"] = float(probability) + + final_probs = self.get_bucket_probs(step) + for bucket_index, probability in enumerate(final_probs): + metrics[f"curriculum/bucket_prob_{bucket_index}"] = float(probability) + + sampled_counts = np.zeros(self.bucket_count, dtype=np.float64) + for dataset_index in prompt_dataset_indices: + sampled_counts[self.bucket_for_dataset_index(int(dataset_index))] += 1.0 + for bucket_index, count in enumerate(sampled_counts): + metrics[f"curriculum/sampled_bucket_count_{bucket_index}"] = float(count) + + if self.config.adaptive_enabled and self.adaptive_stats is not None: + for bucket_index 
in range(self.bucket_count): + metrics[f"curriculum/bucket_reward_mean_{bucket_index}"] = float( + self.adaptive_stats.get_mean_reward(bucket_index) + ) + metrics[f"curriculum/bucket_abs_advantage_mean_{bucket_index}"] = float( + self.adaptive_stats.get_mean_abs_advantage(bucket_index) + ) + + return metrics + + def state_dict(self) -> dict[str, Any]: + return { + "generator_state": self._generator.get_state(), + "excluded_indices": sorted(self._excluded_indices), + "adaptive_stats": None if self.adaptive_stats is None else self.adaptive_stats.state_dict(), + "last_adaptive_refresh_step": self._last_adaptive_refresh_step, + "cached_adaptive_probs": None + if self._cached_adaptive_probs is None + else self._cached_adaptive_probs.tolist(), + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + generator_state = state_dict.get("generator_state") + if generator_state is not None: + self._generator.set_state(generator_state) + + self._excluded_indices = {int(index) for index in state_dict.get("excluded_indices", [])} + self._active_bucket_indices = [] + self._active_bucket_weights = [] + for base_indices, base_weights in zip(self._base_bucket_indices, self._base_bucket_weights, strict=True): + keep_positions = [ + position for position, index in enumerate(base_indices) if index not in self._excluded_indices + ] + self._active_bucket_indices.append([base_indices[position] for position in keep_positions]) + self._active_bucket_weights.append(base_weights[keep_positions].clone()) + + if self.adaptive_stats is not None and state_dict.get("adaptive_stats") is not None: + self.adaptive_stats.load_state_dict(state_dict["adaptive_stats"]) + + self._last_adaptive_refresh_step = int(state_dict.get("last_adaptive_refresh_step", -1)) + cached_adaptive_probs = state_dict.get("cached_adaptive_probs") + self._cached_adaptive_probs = None if cached_adaptive_probs is None else np.array(cached_adaptive_probs) diff --git a/open_instruct/test_rlvr_curriculum.py 
b/open_instruct/test_rlvr_curriculum.py new file mode 100644 index 0000000000..486dfc9684 --- /dev/null +++ b/open_instruct/test_rlvr_curriculum.py @@ -0,0 +1,185 @@ +import sys +import tempfile +import types +import unittest + +from datasets import Dataset + +if "vllm" not in sys.modules: + vllm_stub = types.ModuleType("vllm") + vllm_stub.SamplingParams = object + sys.modules["vllm"] = vllm_stub + +from open_instruct import data_loader, rlvr_curriculum + + +class ListDataset: + def __init__(self, rows): + self.rows = rows + + def __len__(self): + return len(self.rows) + + def __getitem__(self, index): + return self.rows[index] + + +def make_difficulty_row(index: int, bucket_index: int, posterior_mean: float, bucket_count: int = 5) -> dict: + return { + "index": index, + "difficulty": { + "value": 1.0 - posterior_mean, + "posterior_mean": posterior_mean, + "posterior_lower_bound": 0.0, + "expected_quantile": bucket_index / max(bucket_count - 1, 1), + "bucket_index": bucket_index, + "bucket_count": bucket_count, + }, + } + + +def make_bucket_dataset(bucket_count: int = 5) -> ListDataset: + rows = [ + make_difficulty_row(index=0, bucket_index=0, posterior_mean=0.95, bucket_count=bucket_count), + make_difficulty_row(index=1, bucket_index=1, posterior_mean=0.80, bucket_count=bucket_count), + make_difficulty_row(index=2, bucket_index=2, posterior_mean=0.50, bucket_count=bucket_count), + make_difficulty_row(index=3, bucket_index=3, posterior_mean=0.20, bucket_count=bucket_count), + make_difficulty_row(index=4, bucket_index=4, posterior_mean=0.003, bucket_count=bucket_count), + ] + return ListDataset(rows) + + +def make_plain_hf_dataset(num_examples: int) -> Dataset: + return Dataset.from_dict( + {"text": [f"example_{index}" for index in range(num_examples)], "index": list(range(num_examples))} + ) + + +class TestDifficultyCurriculumSampler(unittest.TestCase): + def _make_config(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumConfig: + return 
rlvr_curriculum.DifficultyCurriculumConfig( + enabled=True, easy_focus_steps=100, warmup_steps=120, total_curriculum_steps=200, seed=13, **overrides + ) + + def _make_sampler(self, dataset, **config_overrides) -> rlvr_curriculum.BetaBinomialDifficultySampler: + config = self._make_config(**config_overrides) + return rlvr_curriculum.BetaBinomialDifficultySampler( + dataset=dataset, num_samples=max(len(dataset), 1), config=config, global_step_getter=lambda: 0 + ) + + def test_missing_metadata_raises_when_strict_metadata(self): + dataset = ListDataset([{"index": 0}]) + with self.assertRaises(ValueError): + self._make_sampler(dataset, strict_metadata=True) + + def test_missing_metadata_falls_back_when_not_strict(self): + dataset = ListDataset( + [make_difficulty_row(index=0, bucket_index=0, posterior_mean=0.9, bucket_count=5), {"index": 1}] + ) + sampler = self._make_sampler(dataset, strict_metadata=False) + self.assertEqual(sampler.metadata_fallback_count, 1) + self.assertIn(1, sampler.bucket_to_indices[2]) + + def test_bucket_grouping_works(self): + sampler = self._make_sampler(make_bucket_dataset()) + self.assertEqual(sampler.bucket_to_indices, ((0,), (1,), (2,), (3,), (4,))) + + def test_bootstrap_curriculum_heavily_samples_easy_buckets(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probs = sampler.get_static_bucket_probs(step=0) + self.assertGreater(early_probs[0] + early_probs[1], 0.75) + self.assertGreater(early_probs[0], early_probs[2]) + self.assertGreater(early_probs[1], early_probs[2]) + self.assertLessEqual(early_probs[4], sampler.config.min_hard_frac + 1e-6) + + def test_post_bootstrap_curriculum_returns_to_medium_buckets(self): + sampler = self._make_sampler(make_bucket_dataset()) + post_bootstrap_probs = sampler.get_static_bucket_probs(step=sampler.config.easy_focus_steps) + self.assertEqual(int(post_bootstrap_probs.argmax()), 2) + self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[1]) + 
self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[3]) + + def test_late_curriculum_increases_hard_bucket_probability(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probs = sampler.get_bucket_probs(step=0) + late_step = sampler.config.warmup_steps + sampler.config.total_curriculum_steps + late_probs = sampler.get_bucket_probs(step=late_step) + self.assertGreater(late_probs[4], early_probs[4]) + self.assertGreater(late_probs[4], late_probs[2]) + self.assertGreater(late_probs[3], late_probs[2]) + + def test_extremely_hard_example_is_rare_early_but_more_likely_late(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probability = sampler.get_example_probability(4, step=0) + late_step = sampler.config.warmup_steps + sampler.config.total_curriculum_steps + late_probability = sampler.get_example_probability(4, step=late_step) + self.assertLess(early_probability, 0.1) + self.assertGreater(late_probability, early_probability) + + def test_probabilities_always_sum_to_one(self): + sampler = self._make_sampler(make_bucket_dataset()) + for step in ( + 0, + sampler.config.warmup_steps + 5, + sampler.config.warmup_steps + sampler.config.total_curriculum_steps, + ): + self.assertAlmostEqual(float(sampler.get_static_bucket_probs(step=step).sum()), 1.0, places=6) + self.assertAlmostEqual(float(sampler.get_bucket_probs(step=step).sum()), 1.0, places=6) + + def test_adaptive_stats_increase_sampling_probability_for_high_signal_bucket(self): + sampler = self._make_sampler( + make_bucket_dataset(), adaptive_enabled=True, adaptive_update_every=1, adaptive_blend_weight=0.5 + ) + static_probs = sampler.get_bucket_probs(step=0) + sampler.record_observations( + dataset_indices=[4, 4, 4, 4, 4, 2, 2], + rewards=[0.3, 0.35, 0.25, 0.4, 0.3, 0.95, 0.9], + advantages=[1.2, 1.1, 1.0, 0.9, 1.3, 0.05, 0.02], + ) + adaptive_probs = sampler.get_bucket_probs(step=1) + self.assertGreater(adaptive_probs[4], static_probs[4]) + + def 
test_bootstrap_distribution_is_tunable(self): + default_sampler = self._make_sampler(make_bucket_dataset()) + tuned_sampler = self._make_sampler( + make_bucket_dataset(), + bootstrap_target_bucket_ratio=0.0, + warmup_target_bucket_ratio=0.4, + easy_focus_sigma=0.5, + ) + + default_probs = default_sampler.get_static_bucket_probs(step=0) + tuned_probs = tuned_sampler.get_static_bucket_probs(step=0) + + self.assertGreater(tuned_probs[0], default_probs[0]) + self.assertLess(tuned_probs[2], default_probs[2]) + + +class TestDifficultyCurriculumLoaderIntegration(unittest.TestCase): + def test_existing_behavior_is_unchanged_when_curriculum_disabled(self): + dataset = make_plain_hf_dataset(20) + config = data_loader.StreamingDataLoaderConfig(difficulty_curriculum_enabled=False) + + built_loader = data_loader.build_data_preparation_prompt_dataloader( + dataset=dataset, seed=7, work_dir=tempfile.gettempdir(), config=config + ) + baseline_loader = data_loader.HFDataLoader( + dataset=dataset, + batch_size=1, + seed=7, + dp_rank=0, + dp_world_size=1, + work_dir=tempfile.gettempdir(), + automatic_reshuffle=True, + collator=data_loader.single_example_collator, + ) + + self.assertIs(type(built_loader), data_loader.HFDataLoader) + + built_indices = [batch["index"].item() for batch in built_loader] + baseline_indices = [batch["index"].item() for batch in baseline_loader] + self.assertEqual(built_indices, baseline_indices) + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh new file mode 100644 index 0000000000..59fe9508ee --- /dev/null +++ b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh @@ -0,0 +1,113 @@ +#!/bin/bash +set -euo pipefail + +EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum}" +RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" + +NUM_GPUS="${NUM_GPUS:-8}" 
+BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" + +CLUSTER="${CLUSTER:-ai2/jupiter}" +PRIORITY="${PRIORITY:-urgent}" +WORKSPACE="${WORKSPACE:-ai2/olmo-instruct}" + +# Difficulty-annotated variant of hamishivi/DAPO-Math-17k-Processed_filtered +DATASET_WITH_DIFFICULTY="undfined/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples-ds" + +TOTAL_EPISODES="${TOTAL_EPISODES:-128000}" +NUM_SAMPLES_PER_PROMPT_ROLLOUT="${NUM_SAMPLES_PER_PROMPT_ROLLOUT:-16}" +NUM_UNIQUE_PROMPTS_ROLLOUT="${NUM_UNIQUE_PROMPTS_ROLLOUT:-8}" +LOCAL_EVAL_EVERY="${LOCAL_EVAL_EVERY:-100}" +SAVE_FREQ="${SAVE_FREQ:-100}" +CHECKPOINT_STATE_FREQ="${CHECKPOINT_STATE_FREQ:-100}" + +NUM_TRAINING_STEPS=$(( TOTAL_EPISODES / (NUM_UNIQUE_PROMPTS_ROLLOUT * NUM_SAMPLES_PER_PROMPT_ROLLOUT) )) + +# Keep the easy bootstrap aligned with the first logging/eval window by default. +DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS="${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS:-${LOCAL_EVAL_EVERY}}" +DIFFICULTY_CURRICULUM_WARMUP_STEPS="${DIFFICULTY_CURRICULUM_WARMUP_STEPS:-${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS}}" +if (( NUM_TRAINING_STEPS <= DIFFICULTY_CURRICULUM_WARMUP_STEPS )); then + DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS=1 +else + DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS=$(( NUM_TRAINING_STEPS - DIFFICULTY_CURRICULUM_WARMUP_STEPS )) +fi +DIFFICULTY_CURRICULUM_TOTAL_STEPS="${DIFFICULTY_CURRICULUM_TOTAL_STEPS:-${DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS}}" + +uv run mason.py \ + --task_name ${EXP_NAME} \ + --description "${RUN_NAME}" \ + --cluster ${CLUSTER} \ + --workspace ${WORKSPACE} \ + --priority ${PRIORITY} \ + --pure_docker_mode \ + --no_auto_dataset_cache \ + --image ${BEAKER_IMAGE} \ + --preemptible \ + --num_nodes 1 \ + --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ + --gpus $NUM_GPUS \ + --budget ai2/oe-adapt \ + -- \ +uv run open_instruct/grpo_fast.py \ + --run_name "${RUN_NAME}" \ + --exp_name "${EXP_NAME}" \ + --eval_pass_at_k 32 \ + --eval_top_p 0.95 \ + --vllm_top_p 1.0 \ + --beta 0.0 \ + --async_steps 4 \ + 
--active_sampling \ + --inflight_updates \ + --truncated_importance_sampling_ratio_cap 2.0 \ + --advantage_normalization_type centered \ + --num_samples_per_prompt_rollout ${NUM_SAMPLES_PER_PROMPT_ROLLOUT} \ + --num_unique_prompts_rollout ${NUM_UNIQUE_PROMPTS_ROLLOUT} \ + --num_mini_batches 1 \ + --learning_rate 1e-6 \ + --per_device_train_batch_size 1 \ + --dataset_mixer_list "${DATASET_WITH_DIFFICULTY}" 1.0 \ + --dataset_mixer_list_splits "train" \ + --dataset_mixer_eval_list allenai/aime_2025_openinstruct 1.0 allenai/brumo_2025_openinstruct 1.0 \ + --dataset_mixer_eval_list_splits "train" \ + --max_prompt_token_length 2048 \ + --response_length 8192 \ + --pack_length 10240 \ + --model_name_or_path "Qwen/Qwen3-4B-Base" \ + --non_stop_penalty False \ + --temperature 1.0 \ + --total_episodes ${TOTAL_EPISODES} \ + --deepspeed_stage 2 \ + --num_learners_per_node 4 \ + --vllm_num_engines 4 \ + --vllm_tensor_parallel_size 1 \ + --lr_scheduler_type constant \ + --apply_verifiable_reward true \ + --seed 1 \ + --local_eval_every ${LOCAL_EVAL_EVERY} \ + --save_freq ${SAVE_FREQ} \ + --checkpoint_state_freq ${CHECKPOINT_STATE_FREQ} \ + --gradient_checkpointing \ + --with_tracking \ + --send_slack_alerts \ + --vllm_enable_prefix_caching \ + --clip_higher 0.272 \ + --mask_truncated_completions False \ + --chat_template qwen_instruct_user_boxed_math \ + --load_ref_policy False \ + --keep_last_n_checkpoints -1 \ + --push_to_hub False \ + --difficulty_curriculum_enabled true \ + --difficulty_curriculum_field difficulty \ + --difficulty_curriculum_easy_focus_steps ${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS} \ + --difficulty_curriculum_bootstrap_target_bucket_ratio 0.125 \ + --difficulty_curriculum_warmup_target_bucket_ratio 0.5 \ + --difficulty_curriculum_final_target_bucket_ratio 1.0 \ + --difficulty_curriculum_warmup_steps ${DIFFICULTY_CURRICULUM_WARMUP_STEPS} \ + --difficulty_curriculum_total_steps ${DIFFICULTY_CURRICULUM_TOTAL_STEPS} \ + --difficulty_curriculum_min_hard_frac 
0.05 \ + --difficulty_curriculum_max_hard_frac 0.50 \ + --difficulty_curriculum_bucket_sigma 0.0 \ + --difficulty_curriculum_easy_focus_sigma 0.0 \ + --difficulty_curriculum_uncertainty_weight 0.5 \ + --difficulty_curriculum_adaptive_enabled False \ + --difficulty_curriculum_strict_metadata true "$@" From 3a7b3a45e4759c6eea7625dea54545228f94ee78 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 4 May 2026 11:57:01 -0700 Subject: [PATCH 20/40] Launcher tweaks --- .../train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh index 59fe9508ee..21e6bd427a 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh @@ -6,6 +6,9 @@ RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" NUM_GPUS="${NUM_GPUS:-8}" BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" +if [[ $# -gt 0 ]]; then + shift +fi CLUSTER="${CLUSTER:-ai2/jupiter}" PRIORITY="${PRIORITY:-urgent}" @@ -33,7 +36,7 @@ else fi DIFFICULTY_CURRICULUM_TOTAL_STEPS="${DIFFICULTY_CURRICULUM_TOTAL_STEPS:-${DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS}}" -uv run mason.py \ +uv run python mason.py \ --task_name ${EXP_NAME} \ --description "${RUN_NAME}" \ --cluster ${CLUSTER} \ From c96506e992005ebda485646092b84922c0c9551f Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 14:39:47 -0700 Subject: [PATCH 21/40] Some cleanup, renames --- ...Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json | 45 + ...n_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json | 137 ++ docs/algorithms/grpo.md | 40 +- open_instruct/benchmark_generators.py | 295 +--- open_instruct/data_loader.py | 96 +- open_instruct/grpo_fast.py | 15 +- open_instruct/rl_utils.py | 91 +- open_instruct/rlvr_curriculum.py | 573 ++++--- open_instruct/rlvr_difficulty.py | 1521 +++++++++++++++++ 
open_instruct/test_rl_utils.py | 59 - open_instruct/test_rlvr_curriculum.py | 78 +- open_instruct/test_rollout_traces.py | 84 + .../create_bucketed_difficulty.py | 1516 +--------------- .../qwen3_4b_dapo_math_gen.sh | 58 - ...wen3_4b_dapo_math_difficulty_curriculum.sh | 48 +- tests/test_create_bucketed_difficulty.py | 18 +- 16 files changed, 2360 insertions(+), 2314 deletions(-) create mode 100644 configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json create mode 100644 configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json create mode 100644 open_instruct/rlvr_difficulty.py create mode 100644 open_instruct/test_rollout_traces.py delete mode 100644 scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh diff --git a/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json new file mode 100644 index 0000000000..7ec86a1beb --- /dev/null +++ b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json @@ -0,0 +1,45 @@ +{ + "difficulty_generation": { + "beta_prior_requested": "empirical-bayes", + "beta_prior_used": { + "alpha": 0.4383101221875231, + "beta": 2.1002075643639166, + "source": "empirical_bayes" + }, + "binary_instance_count": 12643, + "bucket_count_effective": 5, + "bucket_count_field": "difficulty.bucket_count", + "bucket_count_requested": 5, + "bucket_field": "difficulty.bucket_index", + "bucket_ranking_field": "difficulty.expected_quantile", + "difficulty_value_definition": "1 - difficulty.posterior_lower_bound", + "difficulty_value_field": "difficulty.value", + "method": "beta_binomial_posterior_quantiles", + "nonbinary_instance_count": 0, + "posterior_lower_quantile": 0.1, + "tag": "bbq-eb-q10-k5" + }, + "model_name": "Qwen/Qwen3-4B-Base", + "row_count": 12643, + "score_processing": { + "normalization": 
"identity_binary", + "output_field": "attempt_scores", + "positive_reward_value": 1.0, + "source_field": "pass_count,num_samples,pass_rate", + "supports_binary_difficulty": true + }, + "source_format": { + "attempt_count_field": "num_samples", + "config_name": null, + "dataset_repo_id": "mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples", + "instance_id_definition": "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index", + "kind": "hugging_face_dataset_passrate_rows", + "model_field": "generator_model", + "pass_count_field": "pass_count", + "pass_rate_field": "pass_rate", + "row_id_field": "extra_info.index", + "split": "train", + "task_field": "dataset" + }, + "task_name": "math" +} \ No newline at end of file diff --git a/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json new file mode 100644 index 0000000000..1068399271 --- /dev/null +++ b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json @@ -0,0 +1,137 @@ +{ + "ability": { + "_type": "Value", + "dtype": "string" + }, + "completions": { + "_type": "List", + "feature": { + "_type": "Value", + "dtype": "string" + } + }, + "data_source": { + "_type": "Value", + "dtype": "string" + }, + "dataset": { + "_type": "Value", + "dtype": "string" + }, + "difficulty": { + "bucket_count": { + "_type": "Value", + "dtype": "int64" + }, + "bucket_index": { + "_type": "Value", + "dtype": "int64" + }, + "expected_quantile": { + "_type": "Value", + "dtype": "float64" + }, + "posterior_lower_bound": { + "_type": "Value", + "dtype": "float64" + }, + "posterior_mean": { + "_type": "Value", + "dtype": "float64" + }, + "value": { + "_type": "Value", + "dtype": "float64" + } + }, + "extra_info": { + "index": { + "_type": "Value", + "dtype": "string" + } + }, + "generator_chat_template": { + "_type": 
"Value", + "dtype": "string" + }, + "generator_max_tokens": { + "_type": "Value", + "dtype": "int64" + }, + "generator_model": { + "_type": "Value", + "dtype": "string" + }, + "generator_temperature": { + "_type": "Value", + "dtype": "float64" + }, + "generator_top_p": { + "_type": "Value", + "dtype": "float64" + }, + "ground_truth": { + "_type": "Value", + "dtype": "string" + }, + "messages": { + "_type": "List", + "feature": { + "content": { + "_type": "Value", + "dtype": "string" + }, + "role": { + "_type": "Value", + "dtype": "string" + } + } + }, + "num_samples": { + "_type": "Value", + "dtype": "int64" + }, + "pass_count": { + "_type": "Value", + "dtype": "int64" + }, + "pass_rate": { + "_type": "Value", + "dtype": "string" + }, + "prompt": { + "_type": "Value", + "dtype": "string" + }, + "reward_model": { + "ground_truth": { + "_type": "Value", + "dtype": "string" + }, + "style": { + "_type": "Value", + "dtype": "string" + } + }, + "solution": { + "_type": "Value", + "dtype": "string" + }, + "source_prompt": { + "_type": "List", + "feature": { + "content": { + "_type": "Value", + "dtype": "string" + }, + "role": { + "_type": "Value", + "dtype": "string" + } + } + }, + "source_split": { + "_type": "Value", + "dtype": "string" + } +} \ No newline at end of file diff --git a/docs/algorithms/grpo.md b/docs/algorithms/grpo.md index 9b991a79cf..fb47d3318f 100644 --- a/docs/algorithms/grpo.md +++ b/docs/algorithms/grpo.md @@ -79,7 +79,7 @@ Both `grpo.py` and `grpo_fast.py` share the same config classes and accept the s ### Difficulty-Aware RLVR Curriculum -Open-Instruct can optionally replace uniform prompt reshuffling with a bucket-aware RLVR curriculum driven by per-instance beta-binomial metadata: +`grpo_fast.py` can optionally replace uniform prompt reshuffling with `DifficultyCurriculumSampler`, a bucket-aware RLVR curriculum driven by per-instance difficulty metadata. 
The current recommended metadata format comes from the beta-binomial estimator in `scripts/data/difficulty_sampling/create_bucketed_difficulty.py`: ```json { @@ -98,7 +98,7 @@ Open-Instruct can optionally replace uniform prompt reshuffling with a bucket-aw - `bucket_index = 0` is the easiest bucket and `bucket_index = bucket_count - 1` is the hardest. - The sampler uses a smooth distribution with a configurable easy-heavy bootstrap phase, then gradually shifts mass toward harder buckets instead of hard-switching between discrete phases. - Within each bucket, examples are weighted by a blend of uncertainty (`4 * p * (1 - p)`) and hardness (`1 - p`), so borderline prompts stay attractive while already-solved prompts are naturally down-weighted. -- If `--difficulty_curriculum_adaptive_enabled true` is set, bucket probabilities are additionally blended with live reward / advantage statistics so buckets with useful learning signal can get more mass during training. +- If `--curriculum_adaptive true` is set, bucket probabilities are additionally blended with live reward / advantage statistics so buckets with useful learning signal can get more mass during training. 
Recommended starting settings for `bucket_count=5`: @@ -110,28 +110,28 @@ Recommended starting settings for `bucket_count=5`: Useful flags: ```bash ---difficulty_curriculum_enabled true \ ---difficulty_curriculum_field difficulty \ ---difficulty_curriculum_easy_focus_steps 100 \ ---difficulty_curriculum_bootstrap_target_bucket_ratio 0.125 \ ---difficulty_curriculum_warmup_target_bucket_ratio 0.5 \ ---difficulty_curriculum_final_target_bucket_ratio 1.0 \ ---difficulty_curriculum_warmup_steps 500 \ ---difficulty_curriculum_total_steps 10000 \ ---difficulty_curriculum_min_hard_frac 0.05 \ ---difficulty_curriculum_max_hard_frac 0.50 \ ---difficulty_curriculum_bucket_sigma 0.0 \ ---difficulty_curriculum_easy_focus_sigma 0.0 \ ---difficulty_curriculum_uncertainty_weight 0.5 \ ---difficulty_curriculum_adaptive_enabled true +--curriculum difficulty \ +--curriculum_metadata_field difficulty \ +--curriculum_bootstrap_steps 100 \ +--curriculum_bootstrap_target 0.125 \ +--curriculum_warmup_target 0.5 \ +--curriculum_final_target 1.0 \ +--curriculum_warmup_steps 500 \ +--curriculum_total_steps 10000 \ +--curriculum_min_hard_frac 0.05 \ +--curriculum_max_hard_frac 0.50 \ +--curriculum_bucket_sigma 0.0 \ +--curriculum_bootstrap_sigma 0.0 \ +--curriculum_uncertainty_weight 0.5 \ +--curriculum_adaptive true ``` Tuning tips: -- Increase `difficulty_curriculum_easy_focus_steps` to keep the easy bootstrap around longer. -- Lower `difficulty_curriculum_bootstrap_target_bucket_ratio` to bias more strongly toward the easiest buckets early. -- Lower `difficulty_curriculum_bucket_sigma` or `difficulty_curriculum_easy_focus_sigma` to concentrate probability on fewer neighboring buckets. -- Lower `difficulty_curriculum_warmup_target_bucket_ratio` if you want the post-bootstrap warmup to stay easier for longer. +- Increase `curriculum_bootstrap_steps` to keep the easy bootstrap around longer. +- Lower `curriculum_bootstrap_target` to bias more strongly toward the easiest buckets early. 
+- Lower `curriculum_bucket_sigma` or `curriculum_bootstrap_sigma` to concentrate probability on fewer neighboring buckets. +- Lower `curriculum_warmup_target` if you want the post-bootstrap warmup to stay easier for longer. Metrics are logged through the standard GRPO tracking path. The most useful ones are: diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 83506ac9b3..9c95f372ec 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -29,55 +29,19 @@ from open_instruct import data_loader, dataset_transformation, grpo_utils, logger_utils, model_utils, utils, vllm_utils from open_instruct.actor_manager import ActorManager -from open_instruct.data_types import GenerationResult, PromptRequest +from open_instruct.data_types import PromptRequest from open_instruct.ground_truth_utils import RewardConfig, build_all_verifiers -from open_instruct.rl_utils import build_rollout_batch_and_advantages, save_rollout_metadata, save_rollouts_to_disk logger = logger_utils.setup_logger(__name__) -def get_default_data_dir() -> pathlib.Path: - """Return the legacy default directory for benchmark artifacts.""" - if pathlib.Path("/weka").exists(): - return pathlib.Path("/weka") / "finbarrt" / "open_instruct_generators_benchmark" - if pathlib.Path("/root").exists(): - return pathlib.Path("/root") / "finbarrt" / "open_instruct_generators_benchmark" - return pathlib.Path("/tmp") / "open_instruct_generators_benchmark" - - -DATA_DIR = get_default_data_dir() - - -@dataclasses.dataclass -class BenchmarkConfig: - """Benchmark-only controls for benchmark_generators.py.""" - - num_batches: int = 5 - """Total number of benchmark batches to run, including the initial warmup batch.""" - run_all_instances: bool = False - """If True, ignore num_batches and run enough batches to cover the entire dataset.""" - - def __post_init__(self) -> None: - if self.num_batches < 1: - raise ValueError(f"num_batches must be >= 
1, got {self.num_batches}") - - -def resolve_num_batches(*, dataset_len: int, prompts_per_batch: int, benchmark_config: BenchmarkConfig) -> int: - """Resolve the total number of benchmark batches to run, including warmup.""" - if benchmark_config.run_all_instances: - return max(1, -(-dataset_len // prompts_per_batch)) - return benchmark_config.num_batches - - -def resolve_data_dir( - args: grpo_utils.GRPOExperimentConfig, streaming_config: data_loader.StreamingDataLoaderConfig -) -> pathlib.Path: - """Resolve where benchmark artifacts should be written for this run.""" - if args.output_dir.rstrip("/") != "output": - return pathlib.Path(args.output_dir) - if streaming_config.save_traces and streaming_config.rollouts_save_path: - return pathlib.Path(streaming_config.rollouts_save_path) - return get_default_data_dir() +# Determine data directory +if pathlib.Path("/weka").exists(): + DATA_DIR = pathlib.Path("/weka") / "finbarrt" / "open_instruct_generators_benchmark" +elif pathlib.Path("/root").exists(): + DATA_DIR = pathlib.Path("/root") / "finbarrt" / "open_instruct_generators_benchmark" +else: + DATA_DIR = pathlib.Path("/tmp") / "open_instruct_generators_benchmark" def save_completion_lengths(batch_results: list[dict], timestamp: int, batch_idx: int): @@ -103,13 +67,7 @@ def save_completion_lengths(batch_results: list[dict], timestamp: int, batch_idx def save_config( - args, - tokenizer_config, - model_config, - streaming_config: data_loader.StreamingDataLoaderConfig, - benchmark_config: BenchmarkConfig, - resolved_num_batches: int, - timestamp: int, + args, tokenizer_config, model_config, streaming_config: data_loader.StreamingDataLoaderConfig, timestamp: int ): """ Save configuration to JSON file. 
@@ -119,8 +77,6 @@ def save_config( tokenizer_config: TokenizerConfig dataclass model_config: ModelConfig dataclass streaming_config: StreamingDataLoaderConfig dataclass - benchmark_config: Benchmark-specific config dataclass - resolved_num_batches: Effective total number of benchmark batches that will run timestamp: Unix timestamp """ config_path = DATA_DIR / f"config_{timestamp}.json" @@ -131,10 +87,7 @@ def save_config( "tokenizer_config": dataclasses.asdict(tokenizer_config), "model_config": dataclasses.asdict(model_config), "streaming_config": dataclasses.asdict(streaming_config), - "benchmark_config": dataclasses.asdict(benchmark_config), - "resolved_num_batches": resolved_num_batches, "timestamp": timestamp, - "experiment_id": os.environ.get("BEAKER_WORKLOAD_ID") or None, } with open(config_path, "w") as f: @@ -152,7 +105,6 @@ def save_benchmark_results_to_csv( """Save benchmark results to CSV file.""" git_commit = utils.get_git_commit() agg_results = aggregate_results(results) - total_samples = len(agg_results["response_lengths"]) csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" row_data = { @@ -166,9 +118,7 @@ def save_benchmark_results_to_csv( "total_time": total_time, "total_generation_time": agg_results["total_generation_time"], "total_weight_sync_time": agg_results["total_weight_sync_time"], - "generation_time_percentage": (agg_results["total_generation_time"] / total_time) * 100 - if total_time > 0 - else 0, + "generation_time_percentage": (agg_results["total_generation_time"] / total_time) * 100, "weight_sync_time_percentage": (agg_results["total_weight_sync_time"] / total_time) * 100 if total_time > 0 else 0, @@ -178,7 +128,12 @@ def save_benchmark_results_to_csv( "avg_mbu": agg_results["avg_mbu"], "avg_generation_time_per_batch": agg_results["avg_generation_time"], "avg_weight_sync_time_per_batch": agg_results["avg_weight_sync_time"], - "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"] / total_samples if 
total_samples > 0 else 0, + "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"] + / ( + len(results) + * streaming_config.num_unique_prompts_rollout + * streaming_config.num_samples_per_prompt_rollout + ), } csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" @@ -193,65 +148,6 @@ def save_benchmark_results_to_csv( logger.info(f"Saved benchmark results to {csv_path}") -def resolve_run_name(args: grpo_utils.GRPOExperimentConfig, timestamp: int) -> str: - """Resolve a stable run name for optional rollout trace persistence.""" - return args.run_name or f"{args.exp_name}__{timestamp}" - - -def maybe_save_scored_rollout_traces( - batch_results: list[GenerationResult], - dataset: datasets.Dataset, - streaming_config: data_loader.StreamingDataLoaderConfig, - *, - run_name: str, - step: int, - total_samples_written: int, -) -> int: - """Persist raw per-sample reward traces when explicitly requested.""" - if not streaming_config.save_traces: - return total_samples_written - - for result in batch_results: - if result.index is None: - raise ValueError("Cannot save scored rollout traces because the result is missing its dataset index.") - - example = dataset[result.index] - prompt_tokens = ( - list(example[dataset_transformation.INPUT_IDS_PROMPT_KEY]) - if streaming_config.rollout_save_format == "full" - else [] - ) - ground_truth = ( - example[dataset_transformation.GROUND_TRUTHS_KEY] - if streaming_config.rollout_save_format == "full" - else None - ) - batch, advantages = build_rollout_batch_and_advantages( - result, - prompt_tokens=prompt_tokens, - ground_truth=ground_truth, - dataset_name=example[dataset_transformation.VERIFIER_SOURCE_KEY], - raw_query=example[dataset_transformation.RAW_PROMPT_KEY], - advantage_normalization_type=streaming_config.advantage_normalization_type, - source_row_id=example.get(dataset_transformation.SOURCE_ROW_ID_KEY), - source_dataset=example.get(dataset_transformation.DATASET_ORIGIN_KEY), - ) - 
save_rollouts_to_disk( - streaming_config.rollouts_save_path, - run_name, - step, - batch, - result, - advantages, - len(result.responses), - total_samples_written, - record_format=streaming_config.rollout_save_format, - ) - total_samples_written += len(result.responses) - - return total_samples_written - - def free_all_gpu_memory(device: int | str = 0) -> None: """ Aggressively free GPU memory used by PyTorch. @@ -315,7 +211,6 @@ def setup_dataset( dataset_cache_mode=streaming_config.dataset_cache_mode, dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, - drop_dataset_source=not streaming_config.save_traces, ) # Shuffle dataset @@ -425,16 +320,20 @@ def submission_thread( dataset: datasets.Dataset, generation_config: vllm_utils.SamplingConfig, stop_event: threading.Event, - batch_specs: list[tuple[int, int, int]], + batch_size: int, + start_batch_idx: int, + num_batches: int, ) -> None: """Thread that submits prompts to the queue.""" logger.info("[Submission Thread] Starting prompt submission") - for batch_idx, start_idx, end_idx in batch_specs: + for batch_idx in range(start_batch_idx, start_batch_idx + num_batches): if stop_event.is_set(): logger.info("[Submission Thread] Stopped due to stop event") break # Get batch data from dataset + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, len(dataset)) batch_data = dataset[start_idx:end_idx] prompts = batch_data[dataset_transformation.INPUT_IDS_PROMPT_KEY] @@ -448,21 +347,7 @@ def submission_thread( generation_config=generation_config, ) ) - logger.info(f"[Submission Thread] All {len(batch_specs)} batches submitted") - - -def build_batch_specs( - *, dataset_len: int, batch_size: int, start_batch_idx: int, num_batches: int -) -> list[tuple[int, int, int]]: - """Return (batch_idx, start_idx, end_idx) triples for non-empty batches only.""" - batch_specs = [] - for batch_idx in range(start_batch_idx, start_batch_idx + 
num_batches): - start_idx = batch_idx * batch_size - if start_idx >= dataset_len: - break - end_idx = min(start_idx + batch_size, dataset_len) - batch_specs.append((batch_idx, start_idx, end_idx)) - return batch_specs + logger.info(f"[Submission Thread] All {num_batches} batches submitted") def run_benchmark( @@ -475,14 +360,10 @@ def run_benchmark( streaming_config: data_loader.StreamingDataLoaderConfig, vllm_config: data_loader.VLLMConfig, model_config: model_utils.ModelConfig, - run_name: str, timestamp: int, num_batches: int = 5, ) -> list[dict[str, Any]]: """Run the full benchmark.""" - if len(dataset) == 0: - raise ValueError("Benchmark dataset is empty after loading and filtering.") - logger.info( f"Starting benchmark with 1 warmup batch + {num_batches - 1} main batches of size {streaming_config.num_unique_prompts_rollout}" ) @@ -503,7 +384,6 @@ def run_benchmark( executor = futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="benchmark") results = [] - total_samples_written = 0 # Get the model dimensions from one of the engines without loading weights model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) @@ -553,50 +433,25 @@ def run_benchmark( logger.info( f"Warmup batch completed with {total_warmup_responses} total responses from {len(warmup_results)} prompts" ) - total_samples_written = maybe_save_scored_rollout_traces( - warmup_results, + logger.info(f"Submitting {num_batches - 1} batches for main benchmark...") + submission_future = executor.submit( + submission_thread, + param_prompt_Q, dataset, - streaming_config, - run_name=run_name, - step=0, - total_samples_written=total_samples_written, - ) - main_batch_specs = build_batch_specs( - dataset_len=len(dataset), - batch_size=streaming_config.num_unique_prompts_rollout, - start_batch_idx=1, - num_batches=num_batches - 1, + generation_config, + stop_event, + streaming_config.num_unique_prompts_rollout, + 1, + num_batches - 1, ) - if not main_batch_specs: - logger.warning( - "No main 
benchmark batches remain after warmup because the dataset only has %s prompt(s), " - "which fit entirely in the warmup batch size of %s.", - len(dataset), - streaming_config.num_unique_prompts_rollout, - ) - submission_future = None - else: - if len(main_batch_specs) < num_batches - 1: - logger.info( - "Submitting %s main benchmark batch(es) instead of %s because the dataset only has %s prompt(s).", - len(main_batch_specs), - num_batches - 1, - len(dataset), - ) - else: - logger.info(f"Submitting {len(main_batch_specs)} batches for main benchmark...") - - submission_future = executor.submit( - submission_thread, param_prompt_Q, dataset, generation_config, stop_event, main_batch_specs - ) # Process remaining batches with timing - for batch_position, (batch_idx, batch_start_idx, batch_end_idx) in enumerate(main_batch_specs, start=1): + for batch_idx in range(1, num_batches): # Quick health check! - if submission_future is not None and submission_future.done(): + if submission_future.done(): submission_future.result() # Collect all results for this batch (one per prompt) using non-blocking polling - num_prompts = batch_end_idx - batch_start_idx + num_prompts = streaming_config.num_unique_prompts_rollout batch_results = [] batch_deadline = time.time() + 1200 while len(batch_results) < num_prompts: @@ -605,18 +460,7 @@ def run_benchmark( batch_results.append(result) except Empty: if time.time() > batch_deadline: - raise TimeoutError( - f"Batch {batch_idx} timed out, got {len(batch_results)}/{num_prompts}" - ) from None - - total_samples_written = maybe_save_scored_rollout_traces( - batch_results, - dataset, - streaming_config, - run_name=run_name, - step=batch_idx, - total_samples_written=total_samples_written, - ) + raise TimeoutError(f"Batch timed out, got {len(batch_results)}/{num_prompts}") from None # Simulate weight sync between batches weight_sync_time = simulate_weight_sync(actor_manager, vllm_engines, args) @@ -681,7 +525,7 @@ def run_benchmark( 
save_completion_lengths([result_dict], timestamp, batch_idx) results.append(result_dict) logger.info( - f"Batch {batch_position}/{len(main_batch_specs)}: " + f"Batch {batch_idx}/{num_batches - 1}: " f"{result_dict['tokens_per_second']:.2f} new tokens/sec, " f"MFU: {result_dict['mfu']:.2f}%, " f"MBU: {result_dict['mbu']:.2f}%, " @@ -690,9 +534,6 @@ def run_benchmark( f"total new tokens: {total_new_tokens}" ) - if submission_future is not None: - submission_future.result() - # Calculate total time for main benchmark only main_benchmark_time = sum(r["generation_time"] for r in results) @@ -732,24 +573,6 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]: prompt_lengths.extend(result["prompt_lengths"]) num_results = len(results) - if num_results == 0: - return { - "total_mfu": 0.0, - "total_mbu": 0.0, - "total_tokens_per_second": 0.0, - "total_generation_time": 0.0, - "total_weight_sync_time": 0.0, - "total_num_new_tokens": 0, - "finish_reasons": finish_reasons, - "response_lengths": response_lengths, - "prompt_lengths": prompt_lengths, - "avg_tokens_per_second": 0.0, - "avg_mfu": 0.0, - "avg_mbu": 0.0, - "avg_generation_time": 0.0, - "avg_weight_sync_time": 0.0, - } - avg_tokens_per_second = total_num_new_tokens / total_generation_time if total_generation_time > 0 else 0 avg_mfu = total_mfu / num_results avg_mbu = total_mbu / num_results @@ -784,8 +607,10 @@ def print_summary( """Print benchmark summary statistics.""" agg_results = aggregate_results(results) - total_samples = len(agg_results["response_lengths"]) - avg_new_tokens_per_sample = agg_results["total_num_new_tokens"] / total_samples if total_samples > 0 else 0 + total_samples = ( + len(results) * streaming_config.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout + ) + avg_new_tokens_per_sample = agg_results["total_num_new_tokens"] / total_samples print("\n" + "=" * 60) print("BENCHMARK SUMMARY") @@ -798,12 +623,6 @@ def print_summary( print(f"Unique prompts 
per batch: {streaming_config.num_unique_prompts_rollout}") print(f"Num rollouts: {streaming_config.num_samples_per_prompt_rollout}") print(f"Max tokens: {streaming_config.response_length}") - if not results: - print("-" * 60) - print("No main benchmark batches were executed after warmup.") - print("=" * 60) - return - print("-" * 60) print(f"Total time (main benchmark): {agg_results['total_generation_time']:.2f}s") print(f"Total weight sync time: {agg_results['total_weight_sync_time']:.2f}s") @@ -861,8 +680,6 @@ def cleanup(vllm_engines: list[ray.actor.ActorHandle], actor_manager: ray.actor. def main() -> None: """Main benchmark function.""" - global DATA_DIR - # Parse arguments using ArgumentParserPlus parser = utils.ArgumentParserPlus( ( @@ -871,27 +688,22 @@ def main() -> None: model_utils.ModelConfig, data_loader.StreamingDataLoaderConfig, data_loader.VLLMConfig, - BenchmarkConfig, ) # type: ignore[arg-type] ) - args, tokenizer_config, model_config, streaming_config, vllm_config, benchmark_config = cast( + args, tokenizer_config, model_config, streaming_config, vllm_config = cast( tuple[ grpo_utils.GRPOExperimentConfig, dataset_transformation.TokenizerConfig, model_utils.ModelConfig, data_loader.StreamingDataLoaderConfig, data_loader.VLLMConfig, - BenchmarkConfig, ], parser.parse_args_into_dataclasses(), ) - DATA_DIR = resolve_data_dir(args, streaming_config) - # Ensure data directory exists DATA_DIR.mkdir(parents=True, exist_ok=True) - logger.info(f"Writing benchmark artifacts to {DATA_DIR}") # Calculate flops per token before starting vLLM logger.info("Calculating model FLOPs per token...") @@ -901,20 +713,6 @@ def main() -> None: free_all_gpu_memory() dataset = setup_dataset(args, streaming_config, tokenizer_config) - resolved_num_batches = resolve_num_batches( - dataset_len=len(dataset), - prompts_per_batch=streaming_config.num_unique_prompts_rollout, - benchmark_config=benchmark_config, - ) - if benchmark_config.run_all_instances: - logger.info( - 
"Resolved run_all_instances=True to %s total batch(es) for %s dataset prompt(s) at %s prompt(s) per batch.", - resolved_num_batches, - len(dataset), - streaming_config.num_unique_prompts_rollout, - ) - else: - logger.info("Using configured num_batches=%s.", resolved_num_batches) max_model_len = streaming_config.max_prompt_token_length + streaming_config.response_length vllm_engines, param_prompt_Q, inference_results_Q, actor_manager = setup_vllm_engines( args, streaming_config, vllm_config, tokenizer_config, model_config, max_model_len, dataset @@ -922,12 +720,7 @@ def main() -> None: # Create the timestamp here so we use it for both filenames. timestamp = int(time.time()) - args.run_name = resolve_run_name(args, timestamp) - save_config( - args, tokenizer_config, model_config, streaming_config, benchmark_config, resolved_num_batches, timestamp - ) - if streaming_config.save_traces: - save_rollout_metadata(streaming_config.rollouts_save_path, args.run_name, model_config.model_name_or_path) + save_config(args, tokenizer_config, model_config, streaming_config, timestamp) run_benchmark( dataset, vllm_engines, @@ -938,9 +731,7 @@ def main() -> None: streaming_config, vllm_config, model_config, - args.run_name, timestamp, - num_batches=resolved_num_batches, ) cleanup(vllm_engines, actor_manager) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 26f5cac02c..4e15c81324 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -250,6 +250,21 @@ def exclude_index(self, index: int) -> None: """ self._excluded_indices.add(index) + def set_sampling_step(self, step: int) -> None: + del step + + def record_curriculum_observations( + self, + dataset_indices: list[int] | np.ndarray, + rewards: list[float] | np.ndarray, + advantages: list[float] | np.ndarray | None = None, + ) -> None: + del dataset_indices, rewards, advantages + + def build_curriculum_metrics(self, prompt_dataset_indices: list[int], step: int) -> dict[str, float]: 
+ del prompt_dataset_indices, step + return {} + def reshuffle(self, epoch: int | None = None, **kwargs: Any) -> None: """Reshuffle and reshard the dataset for a new epoch. @@ -424,30 +439,6 @@ class StreamingDataLoaderConfig: mask_truncated_completions: bool = False mask_tool_use: bool = True - # Difficulty-aware prompt curriculum - difficulty_curriculum_enabled: bool = False - difficulty_curriculum_field: str = "difficulty" - difficulty_curriculum_posterior_mean_field: str = "posterior_mean" - difficulty_curriculum_bucket_index_field: str = "bucket_index" - difficulty_curriculum_bucket_count_field: str = "bucket_count" - difficulty_curriculum_easy_focus_steps: int = 100 - difficulty_curriculum_bootstrap_target_bucket_ratio: float = 0.125 - difficulty_curriculum_warmup_target_bucket_ratio: float = 0.5 - difficulty_curriculum_final_target_bucket_ratio: float = 1.0 - difficulty_curriculum_warmup_steps: int = 500 - difficulty_curriculum_total_steps: int = 10000 - difficulty_curriculum_min_hard_frac: float = 0.05 - difficulty_curriculum_max_hard_frac: float = 0.50 - difficulty_curriculum_bucket_sigma: float = 0.0 - difficulty_curriculum_easy_focus_sigma: float = 0.0 - difficulty_curriculum_uncertainty_weight: float = 0.5 - difficulty_curriculum_adaptive_enabled: bool = False - difficulty_curriculum_adaptive_update_every: int = 50 - difficulty_curriculum_adaptive_learning_signal_weight: float = 0.7 - difficulty_curriculum_adaptive_exploration_weight: float = 0.3 - difficulty_curriculum_adaptive_blend_weight: float = 0.5 - difficulty_curriculum_strict_metadata: bool = True - # Dataset dataset_mixer_list: list[str] = field(default_factory=lambda: ["ai2-adapt-dev/rlvr_gsm8k_zs", "1.0"]) dataset_mixer_eval_list: list[str] = field(default_factory=list) @@ -590,33 +581,6 @@ def build_dataloader( fs_local_rank=fs_local_rank, ) - def build_difficulty_curriculum_config(self, seed: int) -> rlvr_curriculum.DifficultyCurriculumConfig: - return 
rlvr_curriculum.DifficultyCurriculumConfig( - enabled=self.difficulty_curriculum_enabled, - difficulty_field=self.difficulty_curriculum_field, - posterior_mean_field=self.difficulty_curriculum_posterior_mean_field, - bucket_index_field=self.difficulty_curriculum_bucket_index_field, - bucket_count_field=self.difficulty_curriculum_bucket_count_field, - easy_focus_steps=self.difficulty_curriculum_easy_focus_steps, - bootstrap_target_bucket_ratio=self.difficulty_curriculum_bootstrap_target_bucket_ratio, - warmup_target_bucket_ratio=self.difficulty_curriculum_warmup_target_bucket_ratio, - final_target_bucket_ratio=self.difficulty_curriculum_final_target_bucket_ratio, - warmup_steps=self.difficulty_curriculum_warmup_steps, - total_curriculum_steps=self.difficulty_curriculum_total_steps, - min_hard_frac=self.difficulty_curriculum_min_hard_frac, - max_hard_frac=self.difficulty_curriculum_max_hard_frac, - bucket_sigma=self.difficulty_curriculum_bucket_sigma, - easy_focus_sigma=self.difficulty_curriculum_easy_focus_sigma, - uncertainty_weight=self.difficulty_curriculum_uncertainty_weight, - adaptive_enabled=self.difficulty_curriculum_adaptive_enabled, - adaptive_update_every=self.difficulty_curriculum_adaptive_update_every, - adaptive_learning_signal_weight=self.difficulty_curriculum_adaptive_learning_signal_weight, - adaptive_exploration_weight=self.difficulty_curriculum_adaptive_exploration_weight, - adaptive_blend_weight=self.difficulty_curriculum_adaptive_blend_weight, - seed=seed, - strict_metadata=self.difficulty_curriculum_strict_metadata, - ) - class StreamingDataLoader(data_loader.DataLoaderBase): """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.""" @@ -741,7 +705,7 @@ def __init__( raise ValueError("DifficultyCurriculumHFDataLoader currently supports dp_world_size=1 only") self._sampling_step = 0 - self._curriculum_sampler = rlvr_curriculum.BetaBinomialDifficultySampler( + self._curriculum_sampler = 
rlvr_curriculum.DifficultyCurriculumSampler( dataset=dataset, num_samples=max(len(dataset), 1), config=curriculum_config, @@ -765,7 +729,7 @@ def __init__( ) @property - def curriculum_sampler(self) -> rlvr_curriculum.BetaBinomialDifficultySampler: + def curriculum_sampler(self) -> rlvr_curriculum.DifficultyCurriculumSampler: return self._curriculum_sampler def set_sampling_step(self, step: int) -> None: @@ -826,9 +790,12 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: def build_data_preparation_prompt_dataloader( - dataset: Dataset, seed: int, work_dir: str, config: StreamingDataLoaderConfig + dataset: Dataset, + seed: int, + work_dir: str, + curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None = None, ) -> HFDataLoader: - if config.difficulty_curriculum_enabled: + if curriculum_config is not None: return DifficultyCurriculumHFDataLoader( dataset=dataset, batch_size=1, @@ -838,7 +805,7 @@ def build_data_preparation_prompt_dataloader( work_dir=work_dir, automatic_reshuffle=True, collator=single_example_collator, - curriculum_config=config.build_difficulty_curriculum_config(seed=seed), + curriculum_config=curriculum_config, ) return HFDataLoader( @@ -1348,6 +1315,7 @@ def __init__( model_name: str | None, base_env_config: EnvConfig, initial_state: dict | None = None, + curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None = None, ): self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q @@ -1369,7 +1337,7 @@ def __init__( self.base_env_config = base_env_config self.iter_dataloader = build_data_preparation_prompt_dataloader( - dataset=dataset, seed=seed, work_dir=work_dir, config=config + dataset=dataset, seed=seed, work_dir=work_dir, curriculum_config=curriculum_config ) self.prepared_data: dict[int, list[data_types.CollatedBatchData]] = {} @@ -1409,8 +1377,7 @@ def _data_preparation_loop(self): num_initial_prompts = self.config.async_steps * self.global_batch_size 
logger.info(f"[DataPreparationActor] Pushing {num_initial_prompts} initial prompts to param_prompt_Q") - if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): - self.iter_dataloader.set_sampling_step(self.training_step) + self.iter_dataloader.set_sampling_step(self.training_step) for _ in range(num_initial_prompts): add_prompt_to_generator( next(self.iter_dataloader), @@ -1430,8 +1397,7 @@ def _data_preparation_loop(self): ) time.sleep(0.1) generation_idle_wait_time = time.perf_counter() - generation_idle_wait_start_time - if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): - self.iter_dataloader.set_sampling_step(step) + self.iter_dataloader.set_sampling_step(step) logger.info( f"[DataPreparationActor] Step {step}: calling accumulate_inference_batches for {self.global_batch_size} prompts" @@ -1476,8 +1442,7 @@ def _data_preparation_loop(self): for _ in range(self.dp_world_size) ] empty_metrics = {"time/generation_idle_waiting_for_trainer": generation_idle_wait_time} - if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): - empty_metrics.update(self.iter_dataloader.build_curriculum_metrics([], step)) + empty_metrics.update(self.iter_dataloader.build_curriculum_metrics([], step)) with self.lock: self.prepared_data[step] = empty_data self.metrics[step] = empty_metrics @@ -1550,7 +1515,7 @@ def _data_preparation_loop(self): assert result.logprobs is not None result.logprobs = [result.logprobs[i] for i in stop_idxes] - if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader) and batch.indices is not None: + if batch.indices is not None: normalized_scores = np.clip(scores / max(self.config.max_possible_score, 1e-8), 0.0, 1.0) self.iter_dataloader.record_curriculum_observations(batch.indices, normalized_scores, advantages) @@ -1639,8 +1604,7 @@ def _data_preparation_loop(self): step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time 
step_metrics["time/getting_response"] = result.token_statistics.generation_time - if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader): - step_metrics.update(self.iter_dataloader.build_curriculum_metrics(prompt_dataset_indices, step)) + step_metrics.update(self.iter_dataloader.build_curriculum_metrics(prompt_dataset_indices, step)) with self.lock: self.prepared_data[step] = collated_data diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index d56812326b..894ced3845 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -40,7 +40,7 @@ from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF from deepspeed.utils import groups -from open_instruct import data_loader as data_loader_lib +from open_instruct import data_loader as data_loader_lib, rlvr_curriculum from open_instruct import data_types, grpo_utils, utils from open_instruct.data_loader import accumulate_inference_batches, add_prompt_to_generator from open_instruct.data_types import EnvConfig, EnvConfigEntry @@ -1300,6 +1300,7 @@ def create_model_and_optimizer( reward_config: RewardConfig, generation_config, base_env_config: EnvConfig, + curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None, tool_definitions: list[dict[str, Any]] | None = None, tools_config: EnvsConfig | None = None, pools: dict[str, ray.actor.ActorHandle] | None = None, @@ -1361,6 +1362,7 @@ def create_model_and_optimizer( model_name=model_config.model_name_or_path, initial_state=None, base_env_config=base_env_config, + curriculum_config=curriculum_config, ) # Create policy group and start model loading BEFORE vLLM engines (matches main branch order). 
@@ -2395,10 +2397,12 @@ def main( streaming_config: data_loader_lib.StreamingDataLoaderConfig, vllm_config: data_loader_lib.VLLMConfig, tools_config: EnvsConfig, + curriculum_args: rlvr_curriculum.DifficultyCurriculumArgs, ): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args, streaming_config, tools_config) validate_configs(streaming_config, vllm_config, tuple(args.num_learners_per_node), args.sequence_parallel_size) + curriculum_config = curriculum_args.build_curriculum_config(seed=args.seed) if args.verbose: logging.getLogger().setLevel(logging.DEBUG) @@ -2455,7 +2459,7 @@ def main( if tc.tokenizer_name_or_path and tc.tokenizer_name_or_path != model_config.model_name_or_path: utils.ensure_hf_repo_cached(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision) - pprint([args, model_config, streaming_config, vllm_config, tools_config]) + pprint([args, model_config, streaming_config, vllm_config, tools_config, curriculum_args]) # Create Ray queues. # Since we now send/receive individual prompts, queue size should accommodate @@ -2511,6 +2515,7 @@ def main( reward_config, generation_configs["train"], base_env_config, + curriculum_config, tool_definitions, tools_config, pools, @@ -2593,10 +2598,11 @@ def main( data_loader_lib.StreamingDataLoaderConfig, data_loader_lib.VLLMConfig, EnvsConfig, + rlvr_curriculum.DifficultyCurriculumArgs, ) ) parser.set_defaults(exp_name="grpo", warmup_ratio=0.0, max_grad_norm=1.0, per_device_train_batch_size=1) - args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config = ( + args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum_args = ( parser.parse_args_into_dataclasses() ) assert isinstance(args, grpo_utils.GRPOExperimentConfig) @@ -2605,5 +2611,6 @@ def main( assert isinstance(streaming_config, data_loader_lib.StreamingDataLoaderConfig) assert isinstance(vllm_config, data_loader_lib.VLLMConfig) assert isinstance(tools_config, EnvsConfig) 
+ assert isinstance(curriculum_args, rlvr_curriculum.DifficultyCurriculumArgs) - main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config) + main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum_args) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 1c4a7fc06b..3c79b728db 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -98,21 +98,16 @@ def _get_request_info_for_sample(request_info: data_types.RequestInfo | None, i: } -def _save_rollouts( - save_path: str, - run_name: str, - step: int, +def build_rollout_records( batch: model_utils.Batch, result: data_types.GenerationResult, advantages: np.ndarray, + *, + step: int, num_samples_per_prompt: int, - shard_idx: int, - record_format: RolloutSaveFormat, -) -> None: - shard_filename = f"{run_name}_rollouts_{shard_idx:06d}.jsonl" - filepath = os.path.join(save_path, shard_filename) - os.makedirs(save_path, exist_ok=True) - + record_format: RolloutSaveFormat = "full", +) -> list[dict[str, Any]]: + """Build JSON-serializable rollout records for persistence.""" assert batch.scores is not None, "batch.scores must not be None when saving rollouts" records = [] @@ -151,6 +146,32 @@ def _save_rollouts( ) ) + return records + + +def _save_rollouts( + save_path: str, + run_name: str, + step: int, + batch: model_utils.Batch, + result: data_types.GenerationResult, + advantages: np.ndarray, + num_samples_per_prompt: int, + shard_idx: int, + record_format: RolloutSaveFormat, +) -> None: + shard_filename = f"{run_name}_rollouts_{shard_idx:06d}.jsonl" + filepath = os.path.join(save_path, shard_filename) + os.makedirs(save_path, exist_ok=True) + records = build_rollout_records( + batch, + result, + advantages, + step=step, + num_samples_per_prompt=num_samples_per_prompt, + record_format=record_format, + ) + with open(filepath, "a") as f: for record in records: f.write(json.dumps(record) + "\n") @@ -200,54 +221,6 @@ 
def save_rollouts_to_disk( ) -def build_rollout_batch_and_advantages( - result: data_types.GenerationResult, - *, - prompt_tokens: list[int], - ground_truth: Any, - dataset_name: str, - raw_query: str, - advantage_normalization_type: str, - source_row_id: int | None = None, - source_dataset: str | None = None, -) -> tuple[model_utils.Batch, np.ndarray]: - """Convert a scored inference result into the rollout format used by difficulty bucketing.""" - if result.reward_scores is None: - raise ValueError("Cannot save scored rollout traces because reward_scores is missing from GenerationResult.") - - num_samples = len(result.responses) - if len(result.reward_scores) != num_samples: - raise ValueError( - "Cannot save scored rollout traces because reward_scores length does not match the number of responses." - ) - - scores = [float(score) for score in result.reward_scores] - indices = [result.index] * num_samples if result.index is not None else None - batch = model_utils.Batch( - queries=[list(prompt_tokens)] * num_samples, - ground_truths=[ground_truth] * num_samples, - datasets=[dataset_name] * num_samples, - raw_queries=[raw_query] * num_samples, - decoded_responses=None, - indices=indices, - scores=scores, - source_row_ids=[source_row_id] * num_samples, - source_datasets=[source_dataset] * num_samples, - model_steps=[result.model_step] * num_samples, - ) - - score_array = np.asarray(scores, dtype=float) - mean_score = score_array.mean() - if advantage_normalization_type == "standard": - advantages = (score_array - mean_score) / (score_array.std() + 1e-8) - elif advantage_normalization_type == "centered": - advantages = score_array - mean_score - else: - raise ValueError(f"Invalid advantage normalization type: {advantage_normalization_type}") - - return batch, advantages - - @dataclass class Timer(contextlib.ContextDecorator): """A context manager and decorator for timing code blocks""" diff --git a/open_instruct/rlvr_curriculum.py b/open_instruct/rlvr_curriculum.py 
index 832e36db9b..60817457d6 100644 --- a/open_instruct/rlvr_curriculum.py +++ b/open_instruct/rlvr_curriculum.py @@ -3,8 +3,9 @@ from __future__ import annotations import math -from dataclasses import dataclass -from typing import Any +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal import numpy as np import torch @@ -55,45 +56,40 @@ def _normalize_probs(values: np.ndarray) -> np.ndarray: @dataclass -class DifficultyCurriculumConfig: - enabled: bool = False - difficulty_field: str = "difficulty" +class DifficultyCurriculumMetadataConfig: + field: str = "difficulty" posterior_mean_field: str = "posterior_mean" bucket_index_field: str = "bucket_index" bucket_count_field: str = "bucket_count" - easy_focus_steps: int = 100 - bootstrap_target_bucket_ratio: float = 0.125 - warmup_target_bucket_ratio: float = 0.5 - final_target_bucket_ratio: float = 1.0 + strict: bool = True + + +@dataclass +class DifficultyCurriculumScheduleConfig: + bootstrap_steps: int = 100 warmup_steps: int = 500 - total_curriculum_steps: int = 10000 + total_steps: int = 10000 + bootstrap_target: float = 0.125 + warmup_target: float = 0.5 + final_target: float = 1.0 min_hard_frac: float = 0.05 max_hard_frac: float = 0.50 bucket_sigma: float = 0.0 - easy_focus_sigma: float = 0.0 - uncertainty_weight: float = 0.5 - adaptive_enabled: bool = False - adaptive_update_every: int = 50 - adaptive_learning_signal_weight: float = 0.7 - adaptive_exploration_weight: float = 0.3 - adaptive_blend_weight: float = 0.5 - seed: int = 0 - strict_metadata: bool = True - epsilon: float = 1e-8 + bootstrap_sigma: float = 0.0 def __post_init__(self) -> None: - if self.easy_focus_steps < 0: - raise ValueError("easy_focus_steps must be >= 0") - if not 0.0 <= self.bootstrap_target_bucket_ratio <= 1.0: - raise ValueError("bootstrap_target_bucket_ratio must be in [0, 1]") - if not 0.0 <= self.warmup_target_bucket_ratio <= 1.0: - raise 
ValueError("warmup_target_bucket_ratio must be in [0, 1]") - if not 0.0 <= self.final_target_bucket_ratio <= 1.0: - raise ValueError("final_target_bucket_ratio must be in [0, 1]") + if self.bootstrap_steps < 0: + raise ValueError("bootstrap_steps must be >= 0") if self.warmup_steps < 0: raise ValueError("warmup_steps must be >= 0") - if self.total_curriculum_steps <= 0: - raise ValueError("total_curriculum_steps must be > 0") + if self.total_steps <= 0: + raise ValueError("total_steps must be > 0") + if not 0.0 <= self.bootstrap_target <= 1.0: + raise ValueError("bootstrap_target must be in [0, 1]") + if not 0.0 <= self.warmup_target <= 1.0: + raise ValueError("warmup_target must be in [0, 1]") + if not 0.0 <= self.final_target <= 1.0: + raise ValueError("final_target must be in [0, 1]") if not 0.0 <= self.min_hard_frac <= 1.0: raise ValueError("min_hard_frac must be in [0, 1]") if not 0.0 <= self.max_hard_frac <= 1.0: @@ -102,22 +98,108 @@ def __post_init__(self) -> None: raise ValueError("min_hard_frac must be <= max_hard_frac") if self.bucket_sigma < 0: raise ValueError("bucket_sigma must be >= 0") - if self.easy_focus_sigma < 0: - raise ValueError("easy_focus_sigma must be >= 0") + if self.bootstrap_sigma < 0: + raise ValueError("bootstrap_sigma must be >= 0") + + +@dataclass +class DifficultyCurriculumAdaptiveConfig: + enabled: bool = False + update_every: int = 50 + learning_weight: float = 0.7 + exploration_weight: float = 0.3 + blend: float = 0.5 + + def __post_init__(self) -> None: + if self.update_every <= 0: + raise ValueError("update_every must be > 0") + if not 0.0 <= self.learning_weight <= 1.0: + raise ValueError("learning_weight must be in [0, 1]") + if not 0.0 <= self.exploration_weight <= 1.0: + raise ValueError("exploration_weight must be in [0, 1]") + if not 0.0 <= self.blend <= 1.0: + raise ValueError("blend must be in [0, 1]") + + +@dataclass +class DifficultyCurriculumConfig: + metadata: DifficultyCurriculumMetadataConfig = 
field(default_factory=DifficultyCurriculumMetadataConfig) + schedule: DifficultyCurriculumScheduleConfig = field(default_factory=DifficultyCurriculumScheduleConfig) + adaptive: DifficultyCurriculumAdaptiveConfig = field(default_factory=DifficultyCurriculumAdaptiveConfig) + uncertainty_weight: float = 0.5 + seed: int = 0 + epsilon: float = 1e-8 + + def __post_init__(self) -> None: if not 0.0 <= self.uncertainty_weight <= 1.0: raise ValueError("uncertainty_weight must be in [0, 1]") - if self.adaptive_update_every <= 0: - raise ValueError("adaptive_update_every must be > 0") - if not 0.0 <= self.adaptive_learning_signal_weight <= 1.0: - raise ValueError("adaptive_learning_signal_weight must be in [0, 1]") - if not 0.0 <= self.adaptive_exploration_weight <= 1.0: - raise ValueError("adaptive_exploration_weight must be in [0, 1]") - if not 0.0 <= self.adaptive_blend_weight <= 1.0: - raise ValueError("adaptive_blend_weight must be in [0, 1]") if self.epsilon <= 0: raise ValueError("epsilon must be > 0") +@dataclass +class DifficultyCurriculumArgs: + curriculum: Literal["none", "difficulty"] = "none" + curriculum_metadata_field: str = "difficulty" + curriculum_posterior_mean_field: str = "posterior_mean" + curriculum_bucket_index_field: str = "bucket_index" + curriculum_bucket_count_field: str = "bucket_count" + curriculum_strict_metadata: bool = True + curriculum_bootstrap_steps: int = 100 + curriculum_warmup_steps: int = 500 + curriculum_total_steps: int = 10000 + curriculum_bootstrap_target: float = 0.125 + curriculum_warmup_target: float = 0.5 + curriculum_final_target: float = 1.0 + curriculum_min_hard_frac: float = 0.05 + curriculum_max_hard_frac: float = 0.50 + curriculum_bucket_sigma: float = 0.0 + curriculum_bootstrap_sigma: float = 0.0 + curriculum_uncertainty_weight: float = 0.5 + curriculum_adaptive: bool = False + curriculum_adaptive_update_every: int = 50 + curriculum_adaptive_learning_weight: float = 0.7 + curriculum_adaptive_exploration_weight: float = 0.3 
+ curriculum_adaptive_blend: float = 0.5 + + def build_curriculum_config(self, *, seed: int) -> DifficultyCurriculumConfig | None: + if self.curriculum == "none": + return None + if self.curriculum != "difficulty": + raise ValueError(f"Unsupported curriculum type: {self.curriculum}") + + return DifficultyCurriculumConfig( + metadata=DifficultyCurriculumMetadataConfig( + field=self.curriculum_metadata_field, + posterior_mean_field=self.curriculum_posterior_mean_field, + bucket_index_field=self.curriculum_bucket_index_field, + bucket_count_field=self.curriculum_bucket_count_field, + strict=self.curriculum_strict_metadata, + ), + schedule=DifficultyCurriculumScheduleConfig( + bootstrap_steps=self.curriculum_bootstrap_steps, + warmup_steps=self.curriculum_warmup_steps, + total_steps=self.curriculum_total_steps, + bootstrap_target=self.curriculum_bootstrap_target, + warmup_target=self.curriculum_warmup_target, + final_target=self.curriculum_final_target, + min_hard_frac=self.curriculum_min_hard_frac, + max_hard_frac=self.curriculum_max_hard_frac, + bucket_sigma=self.curriculum_bucket_sigma, + bootstrap_sigma=self.curriculum_bootstrap_sigma, + ), + adaptive=DifficultyCurriculumAdaptiveConfig( + enabled=self.curriculum_adaptive, + update_every=self.curriculum_adaptive_update_every, + learning_weight=self.curriculum_adaptive_learning_weight, + exploration_weight=self.curriculum_adaptive_exploration_weight, + blend=self.curriculum_adaptive_blend, + ), + uncertainty_weight=self.curriculum_uncertainty_weight, + seed=seed, + ) + + class AdaptiveBucketStats: """Tracks per-bucket learning signal statistics for adaptive sampling.""" @@ -244,10 +326,202 @@ class _ParsedDifficultyMetadata: error: str | None -class BetaBinomialDifficultySampler(Sampler[int]): - """Bucket-aware curriculum sampler that uses beta-binomial difficulty metadata.""" +@dataclass(frozen=True) +class _DifficultyBucketIndex: + index_to_bucket: dict[int, int] + bucket_to_indices: tuple[tuple[int, ...], ...] 
+ bucket_weights: tuple[torch.Tensor, ...] + bucket_count: int + metadata_fallback_count: int + + +class _DifficultyCurriculumSchedule: + def __init__(self, config: DifficultyCurriculumScheduleConfig, bucket_count: int) -> None: + self.config = config + self.bucket_count = bucket_count + + def get_progress(self, step: int) -> float: + if step < self.config.warmup_steps: + return 0.0 + return min(1.0, (step - self.config.warmup_steps) / self.config.total_steps) + + def _smooth_progress(self, step: int) -> float: + progress = self.get_progress(step) + return progress * progress * (3.0 - 2.0 * progress) + + def _get_default_bucket_sigma(self) -> float: + return max(0.85, 0.25 * max(self.bucket_count - 1, 1)) + + def _get_bucket_sigma(self, step: int) -> float: + sigma = self.config.bucket_sigma if self.config.bucket_sigma > 0 else self._get_default_bucket_sigma() + if step < self.config.bootstrap_steps and self.config.bootstrap_sigma > 0: + return self.config.bootstrap_sigma + return sigma + + def _bucket_ratio_to_bucket_index(self, bucket_ratio: float) -> float: + return float(self.bucket_count - 1) * bucket_ratio + + def _get_target_bucket(self, step: int) -> float: + warmup_target_bucket = self._bucket_ratio_to_bucket_index(self.config.warmup_target) + final_target_bucket = self._bucket_ratio_to_bucket_index(self.config.final_target) + + if self.config.bootstrap_steps > 0 and step < self.config.bootstrap_steps: + bootstrap_progress = min(1.0, step / self.config.bootstrap_steps) + bootstrap_target_bucket = self._bucket_ratio_to_bucket_index(self.config.bootstrap_target) + return bootstrap_target_bucket + (warmup_target_bucket - bootstrap_target_bucket) * bootstrap_progress + + smooth_progress = self._smooth_progress(step) + return warmup_target_bucket + (final_target_bucket - warmup_target_bucket) * smooth_progress + + def build_probs(self, step: int, available_mask: np.ndarray) -> np.ndarray: + if self.bucket_count == 1: + return np.ones(1, dtype=np.float64) + + 
smooth_progress = self._smooth_progress(step) + target_bucket = self._get_target_bucket(step) + hard_bucket_frac = ( + self.config.min_hard_frac + (self.config.max_hard_frac - self.config.min_hard_frac) * smooth_progress + ) + + bucket_ids = np.arange(self.bucket_count - 1, dtype=np.float64) + sigma = self._get_bucket_sigma(step) + gaussian_logits = np.exp(-0.5 * ((bucket_ids - target_bucket) / sigma) ** 2) + non_hard_probs = _normalize_probs(gaussian_logits) + + static_probs = np.zeros(self.bucket_count, dtype=np.float64) + static_probs[:-1] = (1.0 - hard_bucket_frac) * non_hard_probs + static_probs[-1] = hard_bucket_frac + + static_probs *= available_mask + if available_mask.sum() == 0: + return np.ones(self.bucket_count, dtype=np.float64) / self.bucket_count + if static_probs.sum() <= 0: + return _normalize_probs(available_mask) + return _normalize_probs(static_probs) + + +def _parse_difficulty_metadata( + example: dict[str, Any], index: int, metadata_config: DifficultyCurriculumMetadataConfig +) -> _ParsedDifficultyMetadata: + difficulty_blob = _resolve_path(example, metadata_config.field) + if not isinstance(difficulty_blob, dict): + return _ParsedDifficultyMetadata( + bucket_index=None, + bucket_count=None, + posterior_mean=None, + error=f"missing '{metadata_config.field}' metadata for dataset index {index}", + ) + + bucket_index = _coerce_int(_resolve_path(difficulty_blob, metadata_config.bucket_index_field)) + bucket_count = _coerce_int(_resolve_path(difficulty_blob, metadata_config.bucket_count_field)) + posterior_mean = _coerce_float(_resolve_path(difficulty_blob, metadata_config.posterior_mean_field)) + + if bucket_index is None or bucket_index < 0: + return _ParsedDifficultyMetadata( + bucket_index=None, + bucket_count=bucket_count, + posterior_mean=posterior_mean, + error=f"invalid bucket_index for dataset index {index}", + ) + if bucket_count is None or bucket_count <= 0: + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, + 
bucket_count=None, + posterior_mean=posterior_mean, + error=f"invalid bucket_count for dataset index {index}", + ) + if posterior_mean is None: + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, + bucket_count=bucket_count, + posterior_mean=None, + error=f"invalid posterior_mean for dataset index {index}", + ) + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, bucket_count=bucket_count, posterior_mean=posterior_mean, error=None + ) + + +def _compute_example_weight(posterior_mean: float, uncertainty_weight: float, epsilon: float) -> float: + probability = float(np.clip(posterior_mean, 0.0, 1.0)) + uncertainty = 4.0 * probability * (1.0 - probability) + hardness = 1.0 - probability + return uncertainty_weight * uncertainty + (1.0 - uncertainty_weight) * hardness + epsilon + + +def _build_difficulty_bucket_index( + dataset, metadata_config: DifficultyCurriculumMetadataConfig, uncertainty_weight: float, epsilon: float +) -> _DifficultyBucketIndex: + parsed_rows: list[_ParsedDifficultyMetadata] = [] + observed_bucket_counts: set[int] = set() + max_bucket_index = -1 + + for dataset_index in range(len(dataset)): + parsed = _parse_difficulty_metadata(dataset[dataset_index], dataset_index, metadata_config) + if parsed.error is not None and metadata_config.strict: + raise ValueError(parsed.error) + if parsed.bucket_count is not None: + observed_bucket_counts.add(parsed.bucket_count) + if parsed.bucket_index is not None: + max_bucket_index = max(max_bucket_index, parsed.bucket_index) + parsed_rows.append(parsed) + + if observed_bucket_counts: + if metadata_config.strict and len(observed_bucket_counts) > 1: + raise ValueError(f"inconsistent difficulty bucket_count values found: {sorted(observed_bucket_counts)}") + bucket_count = max(observed_bucket_counts) + elif max_bucket_index >= 0: + bucket_count = max_bucket_index + 1 + else: + bucket_count = 1 + + bucket_to_indices: list[list[int]] = [[] for _ in range(bucket_count)] + bucket_weight_lists: 
list[list[float]] = [[] for _ in range(bucket_count)] + index_to_bucket: dict[int, int] = {} + metadata_fallback_count = 0 + fallback_bucket = min(bucket_count - 1, bucket_count // 2) + + for dataset_index, parsed in enumerate(parsed_rows): + if parsed.error is not None: + bucket_index = fallback_bucket + posterior_mean = _DEFAULT_POSTERIOR_MEAN + metadata_fallback_count += 1 + else: + assert parsed.bucket_index is not None + bucket_index = int(np.clip(parsed.bucket_index, 0, bucket_count - 1)) + posterior_mean = parsed.posterior_mean + + if posterior_mean is None: + posterior_mean = _DEFAULT_POSTERIOR_MEAN + example_weight = _compute_example_weight( + posterior_mean=float(np.clip(posterior_mean, 0.0, 1.0)), + uncertainty_weight=uncertainty_weight, + epsilon=epsilon, + ) + + index_to_bucket[dataset_index] = bucket_index + bucket_to_indices[bucket_index].append(dataset_index) + bucket_weight_lists[bucket_index].append(example_weight) + + return _DifficultyBucketIndex( + index_to_bucket=index_to_bucket, + bucket_to_indices=tuple(tuple(indices) for indices in bucket_to_indices), + bucket_weights=tuple(torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists), + bucket_count=bucket_count, + metadata_fallback_count=metadata_fallback_count, + ) + + +class DifficultyCurriculumSampler(Sampler[int]): + """Bucket-aware curriculum sampler for difficulty-annotated prompt datasets.""" - def __init__(self, dataset, num_samples: int, config: DifficultyCurriculumConfig, global_step_getter) -> None: + def __init__( + self, + dataset, + num_samples: int, + config: DifficultyCurriculumConfig, + global_step_getter: Callable[[], int] | None, + ) -> None: if num_samples <= 0: raise ValueError("num_samples must be > 0") @@ -259,124 +533,35 @@ def __init__(self, dataset, num_samples: int, config: DifficultyCurriculumConfig self._generator = torch.Generator() self._generator.manual_seed(self.config.seed) - self._excluded_indices: set[int] = set() - 
self._index_to_bucket: dict[int, int] = {} - self.bucket_count = 1 - self.metadata_fallback_count = 0 + bucket_index = _build_difficulty_bucket_index( + dataset=dataset, + metadata_config=self.config.metadata, + uncertainty_weight=self.config.uncertainty_weight, + epsilon=self.config.epsilon, + ) + self._index_to_bucket = dict(bucket_index.index_to_bucket) + self.bucket_count = bucket_index.bucket_count + self.metadata_fallback_count = bucket_index.metadata_fallback_count + self._schedule = _DifficultyCurriculumSchedule(self.config.schedule, self.bucket_count) - self._base_bucket_indices: list[list[int]] = [] - self._base_bucket_weights: list[torch.Tensor] = [] - self._active_bucket_indices: list[list[int]] = [] - self._active_bucket_weights: list[torch.Tensor] = [] + self._excluded_indices: set[int] = set() + self._base_bucket_indices = [list(indices) for indices in bucket_index.bucket_to_indices] + self._base_bucket_weights = [weights.clone() for weights in bucket_index.bucket_weights] + self._active_bucket_indices = [list(indices) for indices in self._base_bucket_indices] + self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] self.adaptive_stats = None - if self.config.adaptive_enabled: + if self.config.adaptive.enabled: self.adaptive_stats = AdaptiveBucketStats( - learning_signal_weight=self.config.adaptive_learning_signal_weight, - exploration_weight=self.config.adaptive_exploration_weight, + learning_signal_weight=self.config.adaptive.learning_weight, + exploration_weight=self.config.adaptive.exploration_weight, epsilon=self.config.epsilon, ) self._cached_adaptive_probs: np.ndarray | None = None self._last_adaptive_refresh_step = -1 - self._build_bucket_index() - - def _parse_metadata(self, example: dict[str, Any], index: int) -> _ParsedDifficultyMetadata: - difficulty_blob = _resolve_path(example, self.config.difficulty_field) - if not isinstance(difficulty_blob, dict): - return _ParsedDifficultyMetadata( - 
bucket_index=None, - bucket_count=None, - posterior_mean=None, - error=f"missing '{self.config.difficulty_field}' metadata for dataset index {index}", - ) - - bucket_index = _coerce_int(_resolve_path(difficulty_blob, self.config.bucket_index_field)) - bucket_count = _coerce_int(_resolve_path(difficulty_blob, self.config.bucket_count_field)) - posterior_mean = _coerce_float(_resolve_path(difficulty_blob, self.config.posterior_mean_field)) - - if bucket_index is None or bucket_index < 0: - return _ParsedDifficultyMetadata( - bucket_index=None, - bucket_count=bucket_count, - posterior_mean=posterior_mean, - error=f"invalid bucket_index for dataset index {index}", - ) - if bucket_count is None or bucket_count <= 0: - return _ParsedDifficultyMetadata( - bucket_index=bucket_index, - bucket_count=None, - posterior_mean=posterior_mean, - error=f"invalid bucket_count for dataset index {index}", - ) - if posterior_mean is None: - return _ParsedDifficultyMetadata( - bucket_index=bucket_index, - bucket_count=bucket_count, - posterior_mean=None, - error=f"invalid posterior_mean for dataset index {index}", - ) - return _ParsedDifficultyMetadata( - bucket_index=bucket_index, bucket_count=bucket_count, posterior_mean=posterior_mean, error=None - ) - - def _build_bucket_index(self) -> None: - parsed_rows: list[_ParsedDifficultyMetadata] = [] - observed_bucket_counts: set[int] = set() - max_bucket_index = -1 - - for dataset_index in range(len(self.dataset)): - parsed = self._parse_metadata(self.dataset[dataset_index], dataset_index) - if parsed.error is not None and self.config.strict_metadata: - raise ValueError(parsed.error) - if parsed.bucket_count is not None: - observed_bucket_counts.add(parsed.bucket_count) - if parsed.bucket_index is not None: - max_bucket_index = max(max_bucket_index, parsed.bucket_index) - parsed_rows.append(parsed) - - if observed_bucket_counts: - if self.config.strict_metadata and len(observed_bucket_counts) > 1: - raise ValueError( - f"inconsistent 
difficulty bucket_count values found: {sorted(observed_bucket_counts)}" - ) - self.bucket_count = max(observed_bucket_counts) - elif max_bucket_index >= 0: - self.bucket_count = max_bucket_index + 1 - else: - self.bucket_count = 1 - - self._base_bucket_indices = [[] for _ in range(self.bucket_count)] - bucket_weight_lists: list[list[float]] = [[] for _ in range(self.bucket_count)] - fallback_bucket = min(self.bucket_count - 1, self.bucket_count // 2) - - for dataset_index, parsed in enumerate(parsed_rows): - if parsed.error is not None: - bucket_index = fallback_bucket - posterior_mean = _DEFAULT_POSTERIOR_MEAN - self.metadata_fallback_count += 1 - else: - assert parsed.bucket_index is not None - bucket_index = int(np.clip(parsed.bucket_index, 0, self.bucket_count - 1)) - posterior_mean = parsed.posterior_mean - - if posterior_mean is None: - posterior_mean = _DEFAULT_POSTERIOR_MEAN - posterior_mean = float(np.clip(posterior_mean, 0.0, 1.0)) - example_weight = self._compute_example_weight(posterior_mean) - - self._index_to_bucket[dataset_index] = bucket_index - self._base_bucket_indices[bucket_index].append(dataset_index) - bucket_weight_lists[bucket_index].append(example_weight) - - self._base_bucket_weights = [ - torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists - ] - self._active_bucket_indices = [list(indices) for indices in self._base_bucket_indices] - self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] - - if self.metadata_fallback_count > 0 and not self.config.strict_metadata: + if self.metadata_fallback_count > 0 and not self.config.metadata.strict: logger.warning( "Difficulty curriculum fell back to conservative defaults for %s/%s rows because metadata was missing " "or invalid.", @@ -384,16 +569,6 @@ def _build_bucket_index(self) -> None: len(self.dataset), ) - def _compute_example_weight(self, posterior_mean: float) -> float: - probability = float(np.clip(posterior_mean, 0.0, 1.0)) 
- uncertainty = 4.0 * probability * (1.0 - probability) - hardness = 1.0 - probability - return ( - self.config.uncertainty_weight * uncertainty - + (1.0 - self.config.uncertainty_weight) * hardness - + self.config.epsilon - ) - def __len__(self) -> int: return self.num_samples @@ -408,82 +583,24 @@ def _get_current_step(self) -> int: def get_progress(self, step: int | None = None) -> float: if step is None: step = self._get_current_step() - if step < self.config.warmup_steps: - return 0.0 - return min(1.0, (step - self.config.warmup_steps) / self.config.total_curriculum_steps) - - def _smooth_progress(self, step: int | None = None) -> float: - progress = self.get_progress(step) - return progress * progress * (3.0 - 2.0 * progress) - - def _get_default_bucket_sigma(self) -> float: - return max(0.85, 0.25 * max(self.bucket_count - 1, 1)) - - def _get_bucket_sigma(self, step: int | None = None) -> float: - sigma = self.config.bucket_sigma if self.config.bucket_sigma > 0 else self._get_default_bucket_sigma() - if step is not None and step < self.config.easy_focus_steps and self.config.easy_focus_sigma > 0: - return self.config.easy_focus_sigma - return sigma - - def _bucket_ratio_to_bucket_index(self, bucket_ratio: float) -> float: - return float(self.bucket_count - 1) * bucket_ratio - - def _get_target_bucket(self, step: int | None = None) -> float: - if step is None: - step = self._get_current_step() - - warmup_target_bucket = self._bucket_ratio_to_bucket_index(self.config.warmup_target_bucket_ratio) - final_target_bucket = self._bucket_ratio_to_bucket_index(self.config.final_target_bucket_ratio) - - if self.config.easy_focus_steps > 0 and step < self.config.easy_focus_steps: - easy_progress = min(1.0, step / self.config.easy_focus_steps) - bootstrap_target_bucket = self._bucket_ratio_to_bucket_index(self.config.bootstrap_target_bucket_ratio) - return bootstrap_target_bucket + (warmup_target_bucket - bootstrap_target_bucket) * easy_progress - - smooth_progress = 
self._smooth_progress(step) - return warmup_target_bucket + (final_target_bucket - warmup_target_bucket) * smooth_progress + return self._schedule.get_progress(step) def _available_bucket_mask(self) -> np.ndarray: return np.array([1.0 if indices else 0.0 for indices in self._active_bucket_indices], dtype=np.float64) def get_static_bucket_probs(self, step: int | None = None) -> np.ndarray: - if self.bucket_count == 1: - return np.ones(1, dtype=np.float64) - if step is None: step = self._get_current_step() - - smooth_progress = self._smooth_progress(step) - target_bucket = self._get_target_bucket(step) - hard_bucket_frac = ( - self.config.min_hard_frac + (self.config.max_hard_frac - self.config.min_hard_frac) * smooth_progress - ) - - bucket_ids = np.arange(self.bucket_count - 1, dtype=np.float64) - sigma = self._get_bucket_sigma(step) - gaussian_logits = np.exp(-0.5 * ((bucket_ids - target_bucket) / sigma) ** 2) - non_hard_probs = _normalize_probs(gaussian_logits) - - static_probs = np.zeros(self.bucket_count, dtype=np.float64) - static_probs[:-1] = (1.0 - hard_bucket_frac) * non_hard_probs - static_probs[-1] = hard_bucket_frac - - mask = self._available_bucket_mask() - static_probs *= mask - if mask.sum() == 0: - return np.ones(self.bucket_count, dtype=np.float64) / self.bucket_count - if static_probs.sum() <= 0: - return _normalize_probs(mask) - return _normalize_probs(static_probs) + return self._schedule.build_probs(step, self._available_bucket_mask()) def get_adaptive_bucket_probs(self, step: int | None = None) -> np.ndarray | None: - if not self.config.adaptive_enabled or self.adaptive_stats is None or self.adaptive_stats.total_count == 0: + if not self.config.adaptive.enabled or self.adaptive_stats is None or self.adaptive_stats.total_count == 0: return None refresh_step = self._get_current_step() if step is None else step if ( self._cached_adaptive_probs is not None - and refresh_step - self._last_adaptive_refresh_step < self.config.adaptive_update_every + 
and refresh_step - self._last_adaptive_refresh_step < self.config.adaptive.update_every ): return self._cached_adaptive_probs.copy() @@ -502,9 +619,7 @@ def get_bucket_probs(self, step: int | None = None) -> np.ndarray: if adaptive_probs is None: return static_probs - final_probs = ( - 1.0 - self.config.adaptive_blend_weight - ) * static_probs + self.config.adaptive_blend_weight * adaptive_probs + final_probs = (1.0 - self.config.adaptive.blend) * static_probs + self.config.adaptive.blend * adaptive_probs final_probs *= self._available_bucket_mask() if final_probs.sum() <= 0: return static_probs @@ -578,7 +693,7 @@ def record_observations( rewards: list[float] | np.ndarray, advantages: list[float] | np.ndarray | None = None, ) -> None: - if not self.config.adaptive_enabled or self.adaptive_stats is None: + if not self.config.adaptive.enabled or self.adaptive_stats is None: return if len(dataset_indices) == 0: return @@ -595,7 +710,7 @@ def build_metrics(self, prompt_dataset_indices: list[int], step: int | None = No metrics[f"curriculum/static_bucket_prob_{bucket_index}"] = float(probability) adaptive_probs = self.get_adaptive_bucket_probs(step) - if self.config.adaptive_enabled: + if self.config.adaptive.enabled: if adaptive_probs is None: adaptive_probs = np.zeros(self.bucket_count, dtype=np.float64) for bucket_index, probability in enumerate(adaptive_probs): @@ -611,7 +726,7 @@ def build_metrics(self, prompt_dataset_indices: list[int], step: int | None = No for bucket_index, count in enumerate(sampled_counts): metrics[f"curriculum/sampled_bucket_count_{bucket_index}"] = float(count) - if self.config.adaptive_enabled and self.adaptive_stats is not None: + if self.config.adaptive.enabled and self.adaptive_stats is not None: for bucket_index in range(self.bucket_count): metrics[f"curriculum/bucket_reward_mean_{bucket_index}"] = float( self.adaptive_stats.get_mean_reward(bucket_index) diff --git a/open_instruct/rlvr_difficulty.py b/open_instruct/rlvr_difficulty.py 
new file mode 100644 index 0000000000..3d193283e0 --- /dev/null +++ b/open_instruct/rlvr_difficulty.py @@ -0,0 +1,1521 @@ +""" +Build a per-instance difficulty map from open-instruct rollout traces or +Hugging Face datasets with pass-rate aggregates. + +The script accepts one or more local rollout directories, metadata ``.jsonl`` +files, rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``, +or a Hugging Face dataset that already contains per-row pass counts. For each +prompt instance it: + +1. loads rollout shards written by ``save_rollouts_to_disk()``, including compact score-only shards, + or loads per-row pass counts from a Hub dataset, +2. groups attempts by source dataset identity when available, otherwise by a + deterministic fingerprint over task name, prompt tokens, and ground truth, +3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` + when possible, +4. fits a Beta prior across binary outcomes and estimates per-item success + rates, and +5. writes a JSONL difficulty file and schema/metadata sidecars. 
+ +Examples: + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --source /tmp/qwen_math_rollouts \ + --task math \ + --output /tmp/qwen_math_difficulty + + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --source /tmp/qwen_math_rollouts/qwen_math_metadata.jsonl \ + --output /tmp/difficulty_map + + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --output /tmp/dapo_math_qwen3_difficulty +""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import math +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from datasets import Dataset, load_dataset +from scipy.optimize import minimize +from scipy.special import betaln +from scipy.stats import beta as beta_distribution + +from open_instruct import logger_utils + +logger = logger_utils.setup_logger(__name__) + + +EPS = 1e-8 +EXPERIMENT_METADATA_KEYS = ("source_root", "model_name", "experiment_id", "experiment_name") +JEFFREYS_PRIOR_ALPHA = 0.5 +JEFFREYS_PRIOR_BETA = 0.5 +DEFAULT_DIFFICULTY_BUCKETS = 5 +POSTERIOR_QUANTILE_GRID_SIZE = 512 +POSTERIOR_QUANTILE_BATCH_SIZE = 256 +DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" +DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} +PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} +ROLLOUT_SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" +HF_SOURCE_FORMAT_KIND = "hugging_face_dataset_passrate_rows" +ROLLOUT_INSTANCE_ID_DEFINITION = ( + "source_dataset::source_dataset_id when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" +) +HF_INSTANCE_ID_DEFINITION = ( + "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index" +) 
+HF_SOURCE_ROW_INDEX_FIELD = "_source_row_index" +HF_OUTPUT_COLUMNS = ("difficulty",) + + +@dataclass(frozen=True) +class BetaPrior: + alpha: float + beta: float + source: str + + +@dataclass(frozen=True) +class RolloutSource: + input_arg: str + root_path: Path + metadata_path: Path + rollout_paths: tuple[Path, ...] + run_name: str + + +@dataclass(frozen=True) +class DifficultyPosteriorRow: + row: dict[str, Any] + difficulty_alpha: float + difficulty_beta: float + + +@dataclass(frozen=True) +class InputRowsBundle: + rows: list[dict[str, Any]] + malformed_records: int + source_format: dict[str, Any] + source_dataset: Dataset | None = None + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Build a per-instance difficulty map from open-instruct rollout traces or HF pass-rate datasets.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + source_group = parser.add_mutually_exclusive_group(required=True) + source_group.add_argument( + "--source", + nargs="+", + help="One or more local rollout dirs, *_metadata.jsonl files, or *_rollouts_*.jsonl shards.", + ) + source_group.add_argument( + "--hf-dataset", + type=str, + default=None, + help="Hugging Face dataset repo id containing per-row pass-rate aggregates.", + ) + parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.") + parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.") + parser.add_argument( + "--hf-row-id-field", + type=str, + default="extra_info.index", + help="Dot-path to the stable per-row id field inside --hf-dataset.", + ) + parser.add_argument( + "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in --hf-dataset." 
+ ) + parser.add_argument( + "--hf-model-field", + type=str, + default="generator_model", + help="Dot-path to the generator model field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-count-field", + type=str, + default="pass_count", + help="Dot-path to the integer pass-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-attempt-count-field", + type=str, + default="num_samples", + help="Dot-path to the total-attempt-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-rate-field", + type=str, + default="pass_rate", + help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.", + ) + parser.add_argument( + "--task", + action="append", + default=[], + help="Optional task filter. Matches the rollout trace dataset/verifier source.", + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help=( + "Output directory or path-like root. The script writes one file per task/model inside it as " + "____.jsonl plus matching .schema.json and .metadata.json sidecars." + ), + ) + parser.add_argument( + "--push-to-hub", type=str, default=None, help="Optional dataset repo id to push the validated rows to." + ) + parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.") + parser.add_argument( + "--strict", action="store_true", help="Fail if a rollout record is malformed or required files are missing." + ) + parser.add_argument( + "--allow-nonunit-scores", + action="store_true", + help="Keep rows whose rewards cannot be normalized to binary correctness. 
Difficulty will be null for them.", + ) + parser.add_argument( + "--max-instances", + type=int, + default=None, + help="Optional cap for the number of resolved instances written (useful for smoke tests).", + ) + parser.add_argument( + "--beta-prior", + choices=["empirical-bayes", "jeffreys"], + default="empirical-bayes", + help="Global Beta prior to use for smoothing binary solve rates.", + ) + parser.add_argument( + "--posterior-lower-quantile", + type=float, + default=0.1, + help="Lower posterior quantile used to define difficulty as 1 - quantile.", + ) + parser.add_argument( + "--difficulty-buckets", + type=int, + default=DEFAULT_DIFFICULTY_BUCKETS, + help=( + "Number of posterior-aware quantile buckets to assign for stratification. " + "Set to 0 to skip discrete bucket assignment." + ), + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + args = make_parser().parse_args(argv) + validate_args(args) + task_filters = set(args.task) + output_root = resolve_output_root(args.output) + + input_rows = load_input_rows(args, task_filters=task_filters) + + if not input_rows.rows: + raise ValueError("No resolved per-instance rows were produced.") + + rows = sorted( + input_rows.rows, + key=lambda row: ( + stable_string(row.get("task_name")), + stable_string((row.get("experiment_metadata") or {}).get("model_name")), + stable_string(row.get("instance_id")), + ), + ) + if args.max_instances is not None: + rows = rows[: args.max_instances] + + rows_by_group = group_rows_by_task_and_model(rows) + if args.push_to_hub is not None and len(rows_by_group) != 1: + raise ValueError( + "--push-to-hub requires a single task/model output. Filter with --task or use a source with one task." 
+ ) + + skipped_nonunit = 0 + written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] + + for (task_name, model_name), group_rows in sorted( + rows_by_group.items(), key=lambda item: (item[0][0], stable_string(item[0][1])) + ): + group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( + group_rows, allow_nonunit_scores=args.allow_nonunit_scores + ) + if input_rows.source_format["kind"] == HF_SOURCE_FORMAT_KIND: + score_processing["source_field"] = ",".join( + field_name + for field_name in ( + input_rows.source_format.get("pass_count_field"), + input_rows.source_format.get("attempt_count_field"), + input_rows.source_format.get("pass_rate_field"), + ) + if field_name + ) + skipped_nonunit += group_skipped_nonunit + + if not group_rows: + logger.warning( + "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name + ) + continue + + prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) + group_rows = apply_beta_binomial_difficulty( + group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets + ) + if input_rows.source_dataset is None: + ordered_group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) + output_rows = strip_output_only_rollout_fields(ordered_group_rows) + dataset = Dataset.from_list(output_rows) + else: + ordered_group_rows = sort_hf_group_rows(group_rows) + output_rows = strip_internal_fields(ordered_group_rows) + dataset = build_hf_output_dataset(input_rows.source_dataset, ordered_group_rows) + + dataset_metadata = build_dataset_metadata( + rows=output_rows, + task_name=task_name, + model_name=model_name, + requested_prior_mode=args.beta_prior, + requested_bucket_count=args.difficulty_buckets, + lower_quantile=args.posterior_lower_quantile, + prior=prior, + binary_row_count=binary_row_count, + score_processing=score_processing, + 
source_format=input_rows.source_format, + ) + + if prior is not None: + logger.info( + "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", + prior.source, + prior.alpha, + prior.beta, + binary_row_count, + task_name, + model_name, + ) + else: + logger.warning( + "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", + task_name, + model_name, + ) + + annotate_dataset_metadata(dataset, dataset_metadata) + output_jsonl, schema_json, metadata_json = build_output_paths( + output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata + ) + write_output_files( + output_jsonl=output_jsonl, + schema_json=schema_json, + metadata_json=metadata_json, + dataset=dataset, + dataset_metadata=dataset_metadata, + ) + + if args.push_to_hub is not None: + dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) + + written_outputs.append((task_name, model_name, len(output_rows), output_jsonl, schema_json, metadata_json)) + logger.info( + "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", + len(output_rows), + task_name, + model_name, + output_jsonl, + schema_json, + metadata_json, + ) + + logger.info( + "Finished writing %s output file groups (%s malformed rollout records, %s skipped due to unsupported scores).", + len(written_outputs), + input_rows.malformed_records, + skipped_nonunit, + ) + + +def load_input_rows(args: argparse.Namespace, *, task_filters: set[str]) -> InputRowsBundle: + if args.hf_dataset is not None: + return load_hf_dataset_rows( + dataset_name=args.hf_dataset, + config_name=args.hf_config, + split=args.hf_split, + task_filters=task_filters, + strict=args.strict, + row_id_field=args.hf_row_id_field, + task_field=args.hf_task_field, + model_field=args.hf_model_field, + pass_count_field=args.hf_pass_count_field, + attempt_count_field=args.hf_attempt_count_field, + pass_rate_field=args.hf_pass_rate_field, + ) + + if not args.source: + 
raise ValueError("Expected --source when --hf-dataset is not provided.") + + source_runs = discover_rollout_sources(args.source) + if not source_runs: + raise ValueError("No rollout trace sources were found.") + + contributions: list[dict[str, Any]] = [] + malformed_records = 0 + + for source_run in source_runs: + logger.info( + "Loading %s (run=%s, metadata=%s, shards=%s)", + source_run.input_arg, + source_run.run_name, + source_run.metadata_path, + len(source_run.rollout_paths), + ) + run_contributions, run_malformed = build_contributions_for_source( + source_run=source_run, task_filters=task_filters, strict=args.strict + ) + contributions.extend(run_contributions) + malformed_records += run_malformed + + return InputRowsBundle( + rows=aggregate_contributions(contributions), + malformed_records=malformed_records, + source_format=build_rollout_source_format_metadata(), + ) + + +def load_hf_dataset_rows( + *, + dataset_name: str, + config_name: str | None, + split: str, + task_filters: set[str], + strict: bool, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> InputRowsBundle: + logger.info( + "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split + ) + + if config_name: + source_dataset = load_dataset(dataset_name, config_name, split=split) + else: + source_dataset = load_dataset(dataset_name, split=split) + + rows: list[dict[str, Any]] = [] + malformed_records = 0 + + for row_index, source_row in enumerate(source_dataset): + try: + row = build_hf_dataset_row( + source_row=source_row, + source_row_index=row_index, + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + except Exception as exc: + 
malformed_records += 1 + message = f"Malformed HF dataset row {dataset_name}[{split}][{row_index}]: {exc}" + if strict: + raise ValueError(message) from exc + logger.warning(message) + continue + + task_name = stable_string(row.get("task_name")) + if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: + continue + rows.append(row) + + return InputRowsBundle( + rows=rows, + malformed_records=malformed_records, + source_format=build_hf_source_format_metadata( + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ), + source_dataset=source_dataset, + ) + + +def build_hf_dataset_row( + *, + source_row: dict[str, Any], + source_row_index: int, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> dict[str, Any]: + task_name = normalize_task_name(get_nested_field(source_row, task_field)) + if task_name is None: + raise ValueError(f"missing task field {task_field!r}") + + source_row_id = normalize_identifier(get_nested_field(source_row, row_id_field)) or str(source_row_index) + pass_count, attempt_count = extract_hf_attempt_summary( + row=source_row, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + model_name = optional_string(get_nested_field(source_row, model_field)) + + return { + HF_SOURCE_ROW_INDEX_FIELD: source_row_index, + "instance_id": make_hf_instance_id(dataset_name=dataset_name, source_row_id=source_row_id), + "task_name": task_name, + "base_task_name": get_base_task_name(task_name), + "source_dataset": dataset_name, + "source_row_id": source_row_id, + "attempt_scores": 
expand_binary_attempt_scores(pass_count=pass_count, attempt_count=attempt_count), + "finish_reasons": [], + "experiment_metadata": { + "source_root": format_hf_source_locator(dataset_name=dataset_name, config_name=config_name, split=split), + "model_name": model_name, + "experiment_id": None, + "experiment_name": dataset_name, + }, + "score_sources": [task_name], + "warnings": [], + } + + +def build_rollout_source_format_metadata() -> dict[str, Any]: + return { + "kind": ROLLOUT_SOURCE_FORMAT_KIND, + "task_field": "dataset", + "score_field": "reward", + "source_dataset_field": "source_dataset", + "source_dataset_id_field": "source_dataset_id", + "source_row_id_field": "source_row_id", + "instance_id_definition": ROLLOUT_INSTANCE_ID_DEFINITION, + } + + +def build_hf_source_format_metadata( + *, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> dict[str, Any]: + return { + "kind": HF_SOURCE_FORMAT_KIND, + "dataset_repo_id": dataset_name, + "config_name": config_name, + "split": split, + "row_id_field": row_id_field, + "task_field": task_field, + "model_field": model_field, + "pass_count_field": pass_count_field, + "attempt_count_field": attempt_count_field, + "pass_rate_field": pass_rate_field, + "instance_id_definition": HF_INSTANCE_ID_DEFINITION, + } + + +def format_hf_source_locator(*, dataset_name: str, config_name: str | None, split: str) -> str: + config_token = config_name or "default" + return f"hf://{dataset_name}/{config_token}/{split}" + + +def make_hf_instance_id(*, dataset_name: str, source_row_id: str) -> str: + return f"{dataset_name}::{source_row_id}" + + +def sort_hf_group_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + return sorted(rows, key=lambda row: row[HF_SOURCE_ROW_INDEX_FIELD]) + + +def build_hf_output_dataset(source_dataset: Dataset, rows: list[dict[str, Any]]) -> 
Dataset: + ordered_rows = sort_hf_group_rows(rows) + dataset = source_dataset.select([row[HF_SOURCE_ROW_INDEX_FIELD] for row in ordered_rows]) + + for column_name in HF_OUTPUT_COLUMNS: + values = [make_jsonable(row.get(column_name)) for row in ordered_rows] + if column_name in dataset.column_names: + dataset = dataset.remove_columns(column_name) + dataset = dataset.add_column(column_name, values) + + return dataset + + +def strip_internal_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [{key: value for key, value in row.items() if key != HF_SOURCE_ROW_INDEX_FIELD} for row in rows] + + +def get_nested_field(value: Any, field_path: str) -> Any: + if not field_path: + return value + + current = value + for field_name in field_path.split("."): + if not isinstance(current, dict) or field_name not in current: + return None + current = current[field_name] + return current + + +def normalize_identifier(value: Any) -> str | None: + if value is None or isinstance(value, bool): + return None + text = stable_string(value).strip() + return text or None + + +def normalize_nonnegative_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value if value >= 0 else None + if isinstance(value, float): + if not math.isfinite(value) or not value.is_integer() or value < 0: + return None + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + parsed = int(stripped) + except ValueError: + return None + return parsed if parsed >= 0 else None + return None + + +def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: + if value is None: + return None, None, None + if is_number(value): + rate = float(value) + if 0.0 <= rate <= 1.0: + return None, None, rate + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + if not isinstance(value, str): + raise ValueError(f"unsupported 
def extract_hf_attempt_summary(
    *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None
) -> tuple[int, int]:
    """Resolve and cross-check the pass and attempt counts of one dataset row.

    The explicit count fields take precedence. When one of them is missing,
    the matching component parsed from ``pass_rate_field`` (if configured) is
    used instead. The resolved counts are then checked against every
    component the pass-rate field supplied.

    Returns:
        ``(pass_count, attempt_count)``.

    Raises:
        ValueError: If a count cannot be resolved, the attempt count is not
            positive, the pass count exceeds the attempt count, or the counts
            disagree with the pass-rate field.
    """
    explicit_passes = normalize_nonnegative_int(get_nested_field(row, pass_count_field))
    explicit_attempts = normalize_nonnegative_int(get_nested_field(row, attempt_count_field))

    if pass_rate_field:
        rate_passes, rate_attempts, rate_value = parse_pass_rate_value(get_nested_field(row, pass_rate_field))
    else:
        rate_passes = rate_attempts = rate_value = None

    # Explicit fields win; the pass-rate components only fill gaps.
    passes = rate_passes if explicit_passes is None else explicit_passes
    attempts = rate_attempts if explicit_attempts is None else explicit_attempts

    if passes is None or attempts is None:
        rate_hint = f" or parseable {pass_rate_field!r}" if pass_rate_field else ""
        raise ValueError(
            f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}{rate_hint}"
        )
    if attempts <= 0:
        raise ValueError(f"attempt count must be positive, received {attempts}")
    if passes > attempts:
        raise ValueError(f"pass count {passes} exceeds attempt count {attempts}")

    # Consistency checks against whatever the pass-rate field provided.
    if rate_passes is not None and rate_passes != passes:
        raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}")
    if rate_attempts is not None and rate_attempts != attempts:
        raise ValueError(f"attempt-count field {attempt_count_field!r} disagrees with {pass_rate_field!r}")
    if rate_value is not None and not is_close(passes / attempts, rate_value):
        raise ValueError(
            f"pass-count fields {pass_count_field!r}/{attempt_count_field!r} disagree with {pass_rate_field!r}"
        )

    return passes, attempts
def parse_run_name_from_rollout_path(rollout_path: Path) -> str:
    """Return the run-name prefix of a ``<run>_rollouts_<shard>.jsonl`` filename.

    The filename is split at the LAST ``_rollouts_`` marker, so a run name
    that itself contains the marker (``exp_rollouts_v2_rollouts_000001.jsonl``)
    is returned whole instead of being truncated at the first occurrence.

    Args:
        rollout_path: Path of a rollout shard file.

    Returns:
        The part of the filename before the final ``_rollouts_`` marker.

    Raises:
        ValueError: If the filename does not contain ``_rollouts_``.
    """
    marker = "_rollouts_"
    run_name, separator, _shard_suffix = rollout_path.name.rpartition(marker)
    # rpartition yields an empty separator when the marker is absent.
    if not separator:
        raise ValueError(f"Rollout shard filename must contain {marker}: {rollout_path}")
    return run_name
def read_rollout_metadata(metadata_path: Path, *, fallback_run_name: str) -> dict[str, Any]:
    """Read the run-level metadata record stored in a metadata JSONL file.

    Only the first record is used; a warning is logged when the file holds
    more than one. Each returned field is passed through ``optional_string``,
    and ``run_name`` falls back to ``fallback_run_name`` when it is absent.

    Returns:
        A dict with the keys ``run_name``, ``model_name``, ``experiment_id``,
        ``git_commit`` and ``timestamp``.

    Raises:
        ValueError: If the metadata file contains no records.
    """
    records = read_jsonl(metadata_path)
    record_count = len(records)
    if record_count == 0:
        raise ValueError(f"Metadata file is empty: {metadata_path}")
    if record_count > 1:
        logger.warning(
            "Expected one metadata row in %s but found %s. Using the first row.", metadata_path, record_count
        )

    first_record = records[0]
    field_order = ("run_name", "model_name", "experiment_id", "git_commit", "timestamp")
    summary = {name: optional_string(first_record.get(name)) for name in field_order}
    if not summary["run_name"]:
        summary["run_name"] = fallback_run_name
    return summary
run_metadata["model_name"], + "experiment_id": run_metadata["experiment_id"], + "experiment_name": run_metadata["run_name"], + }, + "warnings": extract_rollout_warnings(record.get("request_info")), + } + + +def normalize_task_name(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)) and len(value) == 1: + return normalize_task_name(value[0]) + serialized = serialize_value(value) + return serialized or None + + +def normalize_source_dataset(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)) and len(value) == 1: + return normalize_source_dataset(value[0]) + serialized = serialize_value(value) + return serialized or None + + +def extract_source_dataset_id(record: dict[str, Any]) -> int | None: + for field_name in ("source_dataset_id", "source_row_id"): + source_dataset_id = normalize_source_dataset_id(record.get(field_name)) + if source_dataset_id is not None: + return source_dataset_id + return None + + +def normalize_source_dataset_id(value: Any) -> int | None: + return normalize_nonnegative_int(value) + + +def normalize_token_list(value: Any) -> list[int] | None: + if not isinstance(value, list): + return None + + tokens: list[int] = [] + for item in value: + if isinstance(item, bool) or not isinstance(item, (int, float)): + return None + tokens.append(int(item)) + return tokens + + +def extract_numeric_reward(value: Any) -> float | None: + if not is_number(value): + return None + return float(value) + + +def extract_rollout_warnings(request_info: Any) -> list[str]: + if not isinstance(request_info, dict): + return [] + + warnings: list[str] = [] + if request_info.get("timeouts"): + warnings.append("timeout") + if optional_string(request_info.get("tool_errors")): + warnings.append("tool_error") + return warnings + + +def aggregate_contributions(contributions: list[dict[str, Any]]) -> 
def strip_output_only_rollout_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return copies of *rows* without ``prompt_tokens`` and ``ground_truth``.

    The input rows are left untouched; every other key keeps its value and
    its position.
    """
    dropped_keys = ("prompt_tokens", "ground_truth")
    stripped_rows: list[dict[str, Any]] = []
    for row in rows:
        trimmed = dict(row)
        for key in dropped_keys:
            trimmed.pop(key, None)
        stripped_rows.append(trimmed)
    return stripped_rows
def infer_score_processing(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Decide how a group's raw reward scores map onto binary pass/fail.

    All ``attempt_scores`` of the group are inspected together. The result
    starts as ``"unsupported"`` and is upgraded when the scores are:

    * all close to 0 or 1 -> ``"identity_binary"`` (positive value 1.0);
    * all non-negative with none above ``EPS`` -> ``"all_zero_binary"``;
    * all close to 0 or to the single largest positive score ->
      ``"binary_zero_or_constant"`` (positive value is that score).

    An empty group, or any score below ``-EPS``, stays unsupported.

    Returns:
        A dict with the keys ``source_field``, ``output_field``,
        ``normalization``, ``positive_reward_value`` and
        ``supports_binary_difficulty``.
    """
    processing: dict[str, Any] = {
        "source_field": "reward",
        "output_field": "attempt_scores",
        "normalization": "unsupported",
        "positive_reward_value": None,
        "supports_binary_difficulty": False,
    }

    def mark_supported(normalization: str, positive_reward_value: float | None) -> dict[str, Any]:
        # Existing keys are updated in place, so the key order is unchanged.
        processing["normalization"] = normalization
        processing["positive_reward_value"] = positive_reward_value
        processing["supports_binary_difficulty"] = True
        return processing

    flat_scores: list[float] = []
    for row in rows:
        flat_scores.extend(float(score) for score in row.get("attempt_scores", []))
    if not flat_scores:
        return processing

    def all_zero_or(target: float) -> bool:
        return all(is_close(score, 0.0) or is_close(score, target) for score in flat_scores)

    if all_zero_or(1.0):
        return mark_supported("identity_binary", 1.0)

    if any(score < -EPS for score in flat_scores):
        return processing

    positives = [score for score in flat_scores if score > EPS]
    if not positives:
        return mark_supported("all_zero_binary", None)

    largest_positive = max(positives)
    if all_zero_or(largest_positive):
        return mark_supported("binary_zero_or_constant", largest_positive)

    return processing
def make_empty_difficulty_payload() -> dict[str, Any]:
    """Return a new difficulty record in which every field is ``None``."""
    field_names = (
        "value",
        "posterior_mean",
        "posterior_lower_bound",
        "expected_quantile",
        "bucket_index",
        "bucket_count",
    )
    return dict.fromkeys(field_names)
def estimate_expected_difficulty_quantiles(
    posterior_rows: list[DifficultyPosteriorRow],
    *,
    grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE,
    batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE,
) -> list[float]:
    """Estimate, per row, the expected quantile of its difficulty in the pool.

    Each row contributes a Beta distribution parameterised by its
    ``difficulty_alpha`` and ``difficulty_beta``. The pool is the equally
    weighted mixture of those distributions. For row *i* the function
    approximates ``sum(pdf_i(x) * mixture_cdf(x)) * dx`` over a midpoint grid
    on (0, 1), i.e. the mixture CDF averaged under row *i*'s own density.

    Args:
        posterior_rows: Rows carrying the Beta parameters.
        grid_size: Number of midpoint grid cells on (0, 1).
        batch_size: Number of rows evaluated per vectorised call.

    Returns:
        One value in [0, 1] per input row, in input order. An empty input
        gives ``[]`` and a single row gives ``[0.5]``.
    """
    if not posterior_rows:
        return []
    if len(posterior_rows) == 1:
        # With one row there is no pool to rank against; use the midpoint.
        return [0.5]

    # Cell midpoints: (k + 0.5) / grid_size for k = 0 .. grid_size - 1.
    grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size
    difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64)
    difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64)

    # Pass 1: accumulate the per-row CDFs, then average them into the mixture
    # CDF. Rows are handled batch_size at a time, so each call broadcasts to a
    # (rows-in-batch, grid_size) array (presumably to bound memory use).
    # NOTE(review): beta_distribution is presumably scipy.stats.beta; confirm
    # against the module imports.
    mixture_cdf = np.zeros(grid_size, dtype=np.float64)
    for start in range(0, len(posterior_rows), batch_size):
        stop = start + batch_size
        batch_cdf = beta_distribution.cdf(
            grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None]
        )
        # NaN counts as 0, +inf as 1 and -inf as 0 before summing over rows.
        mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0)
    mixture_cdf /= len(posterior_rows)

    # Pass 2: integrate each row's density against the mixture CDF.
    quantiles = np.zeros(len(posterior_rows), dtype=np.float64)
    dx = 1.0 / grid_size
    for start in range(0, len(posterior_rows), batch_size):
        stop = start + batch_size
        batch_pdf = beta_distribution.pdf(
            grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None]
        )
        # Non-finite density values are zeroed, and the result is clipped to
        # [0, 1].
        quantiles[start:stop] = np.clip(
            np.nan_to_num(batch_pdf, nan=0.0, posinf=0.0, neginf=0.0).dot(mixture_cdf) * dx, 0.0, 1.0
        )

    return quantiles.tolist()
def resolve_output_root(output: Path) -> Path:
    """Strip a known output-file suffix from *output* to get the output root.

    Recognised suffixes are ``.schema.json``, ``.metadata.json``, ``.jsonl``
    and ``.json``. The compound suffixes are tested first, so a path ending
    in ``.metadata.json`` maps to the same root as its ``.jsonl`` and
    ``.schema.json`` siblings instead of keeping a stray ``.metadata``.

    Args:
        output: The path given as the output location.

    Returns:
        The path without the recognised suffix, or *output* itself when it
        ends with none of them.
    """
    output_str = str(output)
    # Longest (compound) suffixes first, so ".json" cannot shadow them.
    for suffix in (".schema.json", ".metadata.json", ".jsonl", ".json"):
        if output_str.endswith(suffix):
            return Path(output_str[: -len(suffix)])
    return output
def extract_effective_bucket_count(rows: list[dict[str, Any]]) -> int:
    """Return the single bucket count recorded on the rows' difficulty data.

    Rows whose ``difficulty`` entry is not a dict, or whose ``bucket_count``
    is ``None``, are ignored.

    Returns:
        The shared bucket count, or ``0`` when no row records one.

    Raises:
        ValueError: If the rows record more than one distinct bucket count.
    """
    distinct_counts = set()
    for row in rows:
        difficulty = row.get("difficulty")
        if not isinstance(difficulty, dict):
            continue
        bucket_count = difficulty.get("bucket_count")
        if bucket_count is not None:
            distinct_counts.add(bucket_count)

    if len(distinct_counts) > 1:
        raise ValueError(f"Expected a single effective bucket count, found {sorted(distinct_counts)}")
    return distinct_counts.pop() if distinct_counts else 0
def format_bucket_token(*, requested_count: int, effective_count: int) -> str:
    """Render the bucket-count part of a filename tag.

    Produces ``k<requested>``. When the effective count differs from the
    requested one, ``e<effective>`` is appended.
    """
    token = f"k{requested_count}"
    if requested_count != effective_count:
        token += f"e{effective_count}"
    return token
for row in rows: + experiment_metadata = row.get("experiment_metadata") or {} + task_name = stable_string(row.get("task_name")) + model_name = optional_string(experiment_metadata.get("model_name")) + rows_by_group[(task_name, model_name)].append(row) + return dict(rows_by_group) + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + with path.open() as input_file: + return [json.loads(line) for line in input_file if line.strip()] + + +def get_base_task_name(task_name: str) -> str: + return task_name.split("@", 1)[0].split(":", 1)[0] + + +def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None: + if not attempt_scores: + return None + + success_count = 0 + for score in attempt_scores: + if is_close(score, 0.0): + continue + if is_close(score, 1.0): + success_count += 1 + continue + return None + + return success_count, len(attempt_scores) + + +def make_rollout_instance_id( + *, + task_name: str, + prompt_tokens: list[int] | None, + ground_truth: Any, + source_dataset: str | None = None, + source_dataset_id: int | None = None, +) -> str: + if source_dataset is not None and source_dataset_id is not None: + return f"{source_dataset}::{source_dataset_id}" + + if prompt_tokens is None: + raise ValueError("prompt_tokens are required when source row identity is unavailable") + + fingerprint = {"task_name": task_name, "prompt_tokens": prompt_tokens, "ground_truth": make_jsonable(ground_truth)} + digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] + task_prefix = sanitize_name(task_name) or "unknown" + return f"{task_prefix}::{digest}" + + +def canonical_json(value: Any) -> str: + return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True, separators=(",", ":")) + + +def make_jsonable(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, list): + return [make_jsonable(item) for item in value] + if isinstance(value, tuple): + 
return [make_jsonable(item) for item in value] + if isinstance(value, dict): + return {stable_string(key): make_jsonable(item) for key, item in value.items()} + return stable_string(value) + + +def stable_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return str(value) + + +def optional_string(value: Any) -> str | None: + text = stable_string(value) + return text or None + + +def serialize_value(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True) + + +def format_filename_number(value: float) -> str: + text = f"{value:.8g}" + return text.replace("-", "m").replace(".", "p") + + +def sanitize_name(value: str) -> str: + return value.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") + + +def is_number(value: Any) -> bool: + return isinstance(value, (int, float)) and not isinstance(value, bool) and not math.isnan(float(value)) + + +def is_close(lhs: float, rhs: float) -> bool: + tolerance = EPS * max(1.0, abs(lhs), abs(rhs)) + return abs(lhs - rhs) <= tolerance + + +if __name__ == "__main__": + main() diff --git a/open_instruct/test_rl_utils.py b/open_instruct/test_rl_utils.py index 9e3fe92799..2303a98b45 100644 --- a/open_instruct/test_rl_utils.py +++ b/open_instruct/test_rl_utils.py @@ -9,7 +9,6 @@ from parameterized import parameterized from open_instruct import rl_utils -from open_instruct.data_types import GenerationResult, RequestInfo PACK_LENGTH = 40 PROMPT_MAX_LEN = 20 @@ -338,64 +337,6 @@ def test_pack_sequences_min_num_batches(self): self.assertGreater(len(seq), 0) -class TestRolloutTraceSaving(unittest.TestCase): - def _make_generation_result(self, reward_scores: list[float]) -> GenerationResult: - num_samples = len(reward_scores) - return GenerationResult( - responses=[[10, sample_idx] for sample_idx in range(num_samples)], - finish_reasons=["stop"] * 
num_samples, - masks=[[1, 1]] * num_samples, - request_info=RequestInfo( - num_calls=[0] * num_samples, - timeouts=[0] * num_samples, - tool_errors=[""] * num_samples, - tool_outputs=[""] * num_samples, - tool_runtimes=[0.0] * num_samples, - tool_calleds=[False] * num_samples, - tool_call_stats=[[] for _ in range(num_samples)], - rollout_states=[{} for _ in range(num_samples)], - ), - index=3, - prompt_id="prompt_3", - reward_scores=reward_scores, - logprobs=[[0.1, 0.2]] * num_samples, - model_step=7, - ) - - def test_build_rollout_batch_and_advantages_preserves_scores(self): - result = self._make_generation_result([10.0, 0.0]) - - batch, advantages = rl_utils.build_rollout_batch_and_advantages( - result, - prompt_tokens=[1, 2, 3], - ground_truth="4", - dataset_name="math", - raw_query="user: solve 2+2", - advantage_normalization_type="centered", - ) - - self.assertEqual(batch.queries, [[1, 2, 3], [1, 2, 3]]) - self.assertEqual(batch.ground_truths, ["4", "4"]) - self.assertEqual(batch.datasets, ["math", "math"]) - self.assertEqual(batch.indices, [3, 3]) - self.assertEqual(batch.scores, [10.0, 0.0]) - np.testing.assert_allclose(advantages, np.array([5.0, -5.0])) - - def test_build_rollout_batch_and_advantages_raises_without_scores(self): - result = self._make_generation_result([1.0, 0.0]) - result.reward_scores = None - - with self.assertRaises(ValueError): - rl_utils.build_rollout_batch_and_advantages( - result, - prompt_tokens=[1, 2, 3], - ground_truth="4", - dataset_name="math", - raw_query="user: solve 2+2", - advantage_normalization_type="centered", - ) - - class TestMaskedMean(unittest.TestCase): def test_original_axis_int(self): values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) diff --git a/open_instruct/test_rlvr_curriculum.py b/open_instruct/test_rlvr_curriculum.py index 486dfc9684..cd5e8aa767 100644 --- a/open_instruct/test_rlvr_curriculum.py +++ b/open_instruct/test_rlvr_curriculum.py @@ -4,6 +4,7 @@ import unittest from datasets import Dataset 
+from transformers import HfArgumentParser if "vllm" not in sys.modules: vllm_stub = types.ModuleType("vllm") @@ -56,27 +57,42 @@ def make_plain_hf_dataset(num_examples: int) -> Dataset: class TestDifficultyCurriculumSampler(unittest.TestCase): + def _make_metadata(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumMetadataConfig: + return rlvr_curriculum.DifficultyCurriculumMetadataConfig(**overrides) + + def _make_schedule(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumScheduleConfig: + return rlvr_curriculum.DifficultyCurriculumScheduleConfig( + bootstrap_steps=100, warmup_steps=120, total_steps=200, **overrides + ) + + def _make_adaptive(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumAdaptiveConfig: + return rlvr_curriculum.DifficultyCurriculumAdaptiveConfig(**overrides) + def _make_config(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumConfig: return rlvr_curriculum.DifficultyCurriculumConfig( - enabled=True, easy_focus_steps=100, warmup_steps=120, total_curriculum_steps=200, seed=13, **overrides + metadata=overrides.pop("metadata", self._make_metadata()), + schedule=overrides.pop("schedule", self._make_schedule()), + adaptive=overrides.pop("adaptive", self._make_adaptive()), + seed=13, + **overrides, ) - def _make_sampler(self, dataset, **config_overrides) -> rlvr_curriculum.BetaBinomialDifficultySampler: + def _make_sampler(self, dataset, **config_overrides) -> rlvr_curriculum.DifficultyCurriculumSampler: config = self._make_config(**config_overrides) - return rlvr_curriculum.BetaBinomialDifficultySampler( + return rlvr_curriculum.DifficultyCurriculumSampler( dataset=dataset, num_samples=max(len(dataset), 1), config=config, global_step_getter=lambda: 0 ) def test_missing_metadata_raises_when_strict_metadata(self): dataset = ListDataset([{"index": 0}]) with self.assertRaises(ValueError): - self._make_sampler(dataset, strict_metadata=True) + self._make_sampler(dataset, metadata=self._make_metadata(strict=True)) def 
test_missing_metadata_falls_back_when_not_strict(self): dataset = ListDataset( [make_difficulty_row(index=0, bucket_index=0, posterior_mean=0.9, bucket_count=5), {"index": 1}] ) - sampler = self._make_sampler(dataset, strict_metadata=False) + sampler = self._make_sampler(dataset, metadata=self._make_metadata(strict=False)) self.assertEqual(sampler.metadata_fallback_count, 1) self.assertIn(1, sampler.bucket_to_indices[2]) @@ -90,11 +106,11 @@ def test_bootstrap_curriculum_heavily_samples_easy_buckets(self): self.assertGreater(early_probs[0] + early_probs[1], 0.75) self.assertGreater(early_probs[0], early_probs[2]) self.assertGreater(early_probs[1], early_probs[2]) - self.assertLessEqual(early_probs[4], sampler.config.min_hard_frac + 1e-6) + self.assertLessEqual(early_probs[4], sampler.config.schedule.min_hard_frac + 1e-6) def test_post_bootstrap_curriculum_returns_to_medium_buckets(self): sampler = self._make_sampler(make_bucket_dataset()) - post_bootstrap_probs = sampler.get_static_bucket_probs(step=sampler.config.easy_focus_steps) + post_bootstrap_probs = sampler.get_static_bucket_probs(step=sampler.config.schedule.bootstrap_steps) self.assertEqual(int(post_bootstrap_probs.argmax()), 2) self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[1]) self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[3]) @@ -102,7 +118,7 @@ def test_post_bootstrap_curriculum_returns_to_medium_buckets(self): def test_late_curriculum_increases_hard_bucket_probability(self): sampler = self._make_sampler(make_bucket_dataset()) early_probs = sampler.get_bucket_probs(step=0) - late_step = sampler.config.warmup_steps + sampler.config.total_curriculum_steps + late_step = sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps late_probs = sampler.get_bucket_probs(step=late_step) self.assertGreater(late_probs[4], early_probs[4]) self.assertGreater(late_probs[4], late_probs[2]) @@ -111,7 +127,7 @@ def 
test_late_curriculum_increases_hard_bucket_probability(self): def test_extremely_hard_example_is_rare_early_but_more_likely_late(self): sampler = self._make_sampler(make_bucket_dataset()) early_probability = sampler.get_example_probability(4, step=0) - late_step = sampler.config.warmup_steps + sampler.config.total_curriculum_steps + late_step = sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps late_probability = sampler.get_example_probability(4, step=late_step) self.assertLess(early_probability, 0.1) self.assertGreater(late_probability, early_probability) @@ -120,15 +136,15 @@ def test_probabilities_always_sum_to_one(self): sampler = self._make_sampler(make_bucket_dataset()) for step in ( 0, - sampler.config.warmup_steps + 5, - sampler.config.warmup_steps + sampler.config.total_curriculum_steps, + sampler.config.schedule.warmup_steps + 5, + sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps, ): self.assertAlmostEqual(float(sampler.get_static_bucket_probs(step=step).sum()), 1.0, places=6) self.assertAlmostEqual(float(sampler.get_bucket_probs(step=step).sum()), 1.0, places=6) def test_adaptive_stats_increase_sampling_probability_for_high_signal_bucket(self): sampler = self._make_sampler( - make_bucket_dataset(), adaptive_enabled=True, adaptive_update_every=1, adaptive_blend_weight=0.5 + make_bucket_dataset(), adaptive=self._make_adaptive(enabled=True, update_every=1, blend=0.5) ) static_probs = sampler.get_bucket_probs(step=0) sampler.record_observations( @@ -143,9 +159,7 @@ def test_bootstrap_distribution_is_tunable(self): default_sampler = self._make_sampler(make_bucket_dataset()) tuned_sampler = self._make_sampler( make_bucket_dataset(), - bootstrap_target_bucket_ratio=0.0, - warmup_target_bucket_ratio=0.4, - easy_focus_sigma=0.5, + schedule=self._make_schedule(bootstrap_target=0.0, warmup_target=0.4, bootstrap_sigma=0.5), ) default_probs = default_sampler.get_static_bucket_probs(step=0) @@ -154,14 +168,42 @@ def 
test_bootstrap_distribution_is_tunable(self): self.assertGreater(tuned_probs[0], default_probs[0]) self.assertLess(tuned_probs[2], default_probs[2]) + def test_curriculum_args_parser_builds_grouped_config(self): + parser = HfArgumentParser((rlvr_curriculum.DifficultyCurriculumArgs,)) + (curriculum_args,) = parser.parse_args_into_dataclasses( + [ + "--curriculum", + "difficulty", + "--curriculum_bootstrap_steps", + "12", + "--curriculum_warmup_steps", + "34", + "--curriculum_total_steps", + "56", + "--curriculum_adaptive", + "true", + "--curriculum_adaptive_blend", + "0.25", + ] + ) + + curriculum_config = curriculum_args.build_curriculum_config(seed=17) + + self.assertIsNotNone(curriculum_config) + assert curriculum_config is not None + self.assertEqual(curriculum_config.schedule.bootstrap_steps, 12) + self.assertEqual(curriculum_config.schedule.warmup_steps, 34) + self.assertEqual(curriculum_config.schedule.total_steps, 56) + self.assertTrue(curriculum_config.adaptive.enabled) + self.assertEqual(curriculum_config.adaptive.blend, 0.25) + self.assertEqual(curriculum_config.seed, 17) + class TestDifficultyCurriculumLoaderIntegration(unittest.TestCase): def test_existing_behavior_is_unchanged_when_curriculum_disabled(self): dataset = make_plain_hf_dataset(20) - config = data_loader.StreamingDataLoaderConfig(difficulty_curriculum_enabled=False) - built_loader = data_loader.build_data_preparation_prompt_dataloader( - dataset=dataset, seed=7, work_dir=tempfile.gettempdir(), config=config + dataset=dataset, seed=7, work_dir=tempfile.gettempdir(), curriculum_config=None ) baseline_loader = data_loader.HFDataLoader( dataset=dataset, diff --git a/open_instruct/test_rollout_traces.py b/open_instruct/test_rollout_traces.py new file mode 100644 index 0000000000..ea0947b634 --- /dev/null +++ b/open_instruct/test_rollout_traces.py @@ -0,0 +1,84 @@ +import unittest + +import numpy as np + +from open_instruct import model_utils, rl_utils +from open_instruct.data_types import 
GenerationResult, RequestInfo + + +class TestRolloutRecords(unittest.TestCase): + def _make_result(self) -> GenerationResult: + return GenerationResult( + responses=[[10, 11], [12, 13]], + finish_reasons=["stop", "length"], + masks=[[1, 1], [1, 1]], + request_info=RequestInfo( + num_calls=[0, 1], + timeouts=[0, 0], + tool_errors=["", ""], + tool_outputs=["", "ok"], + tool_runtimes=[0.0, 0.1], + tool_calleds=[False, True], + tool_call_stats=[[], []], + rollout_states=[{}, {"done": False}], + ), + index=3, + prompt_id="prompt_3", + logprobs=[[0.1, 0.2], [0.3, 0.4]], + model_step=7, + ) + + def _make_batch(self) -> model_utils.Batch: + return model_utils.Batch( + queries=[[1, 2, 3], [1, 2, 3]], + ground_truths=[[4], [4]], + datasets=["math", "math"], + raw_queries=["user: solve 2+2", "user: solve 2+2"], + decoded_responses=None, + indices=[3, 3], + scores=[10.0, 0.0], + source_row_ids=[11, 11], + source_datasets=["demo", "demo"], + model_steps=[7, 7], + ) + + def test_build_rollout_records_full_format(self): + records = rl_utils.build_rollout_records( + self._make_batch(), + self._make_result(), + np.array([5.0, -5.0]), + step=9, + num_samples_per_prompt=2, + record_format="full", + ) + + self.assertEqual(len(records), 2) + self.assertEqual(records[0]["step"], 9) + self.assertEqual(records[0]["prompt_idx"], 0) + self.assertEqual(records[0]["source_row_id"], 11) + self.assertEqual(records[0]["source_dataset"], "demo") + self.assertEqual(records[0]["response_tokens"], [10, 11]) + self.assertEqual(records[1]["finish_reason"], "length") + self.assertEqual(records[1]["request_info"]["tool_outputs"], "ok") + + def test_build_rollout_records_scores_only_format(self): + records = rl_utils.build_rollout_records( + self._make_batch(), + self._make_result(), + np.array([5.0, -5.0]), + step=9, + num_samples_per_prompt=2, + record_format="scores_only", + ) + + self.assertEqual( + records, + [ + {"dataset": "math", "reward": 10.0, "source_row_id": 11, "source_dataset": "demo"}, + 
{"dataset": "math", "reward": 0.0, "source_row_id": 11, "source_dataset": "demo"}, + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index f2c05d15d9..12ebe6a906 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -8,1529 +8,19 @@ # ] # /// -""" -Build a per-instance difficulty map from open-instruct rollout traces or -Hugging Face datasets with pass-rate aggregates. - -The script accepts one or more local rollout directories, metadata ``.jsonl`` -files, rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``, -or a Hugging Face dataset that already contains per-row pass counts. For each -prompt instance it: - -1. loads rollout shards written by ``save_rollouts_to_disk()``, including compact score-only shards, - or loads per-row pass counts from a Hub dataset, -2. groups attempts by source dataset identity when available, otherwise by a - deterministic fingerprint over task name, prompt tokens, and ground truth, -3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` - when possible, -4. fits a Beta prior across binary outcomes and estimates per-item success - rates, and -5. writes a JSONL difficulty file and schema/metadata sidecars. 
- -Examples: - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --source /tmp/qwen_math_rollouts \ - --task math \ - --output /tmp/qwen_math_difficulty - - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --source /tmp/qwen_math_rollouts/qwen_math_metadata.jsonl \ - --output /tmp/difficulty_map - - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ - --hf-split train \ - --output /tmp/dapo_math_qwen3_difficulty -""" +"""Thin CLI wrapper for ``open_instruct.rlvr_difficulty``.""" from __future__ import annotations -import argparse -import hashlib -import json -import math import sys -from collections import defaultdict -from dataclasses import dataclass from pathlib import Path -from typing import Any - -import numpy as np -from datasets import Dataset, load_dataset -from scipy.optimize import minimize -from scipy.special import betaln -from scipy.stats import beta as beta_distribution REPO_ROOT = Path(__file__).resolve().parents[3] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from open_instruct import logger_utils # noqa: E402 - -logger = logger_utils.setup_logger(__name__) - - -EPS = 1e-8 -EXPERIMENT_METADATA_KEYS = ("source_root", "model_name", "experiment_id", "experiment_name") -JEFFREYS_PRIOR_ALPHA = 0.5 -JEFFREYS_PRIOR_BETA = 0.5 -DEFAULT_DIFFICULTY_BUCKETS = 5 -POSTERIOR_QUANTILE_GRID_SIZE = 512 -POSTERIOR_QUANTILE_BATCH_SIZE = 256 -DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" -DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} -PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} -ROLLOUT_SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" -HF_SOURCE_FORMAT_KIND = "hugging_face_dataset_passrate_rows" -ROLLOUT_INSTANCE_ID_DEFINITION = ( - "source_dataset::source_dataset_id when available; 
otherwise sha1(task_name,prompt_tokens,ground_truth)" -) -HF_INSTANCE_ID_DEFINITION = ( - "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index" -) -HF_SOURCE_ROW_INDEX_FIELD = "_source_row_index" -HF_OUTPUT_COLUMNS = ("difficulty",) - - -@dataclass(frozen=True) -class BetaPrior: - alpha: float - beta: float - source: str - - -@dataclass(frozen=True) -class RolloutSource: - input_arg: str - root_path: Path - metadata_path: Path - rollout_paths: tuple[Path, ...] - run_name: str - - -@dataclass(frozen=True) -class DifficultyPosteriorRow: - row: dict[str, Any] - difficulty_alpha: float - difficulty_beta: float - - -@dataclass(frozen=True) -class InputRowsBundle: - rows: list[dict[str, Any]] - malformed_records: int - source_format: dict[str, Any] - source_dataset: Dataset | None = None - - -def make_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Build a per-instance difficulty map from open-instruct rollout traces or HF pass-rate datasets.", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - source_group = parser.add_mutually_exclusive_group(required=True) - source_group.add_argument( - "--source", - nargs="+", - help="One or more local rollout dirs, *_metadata.jsonl files, or *_rollouts_*.jsonl shards.", - ) - source_group.add_argument( - "--hf-dataset", - type=str, - default=None, - help="Hugging Face dataset repo id containing per-row pass-rate aggregates.", - ) - parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.") - parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.") - parser.add_argument( - "--hf-row-id-field", - type=str, - default="extra_info.index", - help="Dot-path to the stable per-row id field inside --hf-dataset.", - ) - parser.add_argument( - "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in 
--hf-dataset." - ) - parser.add_argument( - "--hf-model-field", - type=str, - default="generator_model", - help="Dot-path to the generator model field in --hf-dataset.", - ) - parser.add_argument( - "--hf-pass-count-field", - type=str, - default="pass_count", - help="Dot-path to the integer pass-count field in --hf-dataset.", - ) - parser.add_argument( - "--hf-attempt-count-field", - type=str, - default="num_samples", - help="Dot-path to the total-attempt-count field in --hf-dataset.", - ) - parser.add_argument( - "--hf-pass-rate-field", - type=str, - default="pass_rate", - help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.", - ) - parser.add_argument( - "--task", - action="append", - default=[], - help="Optional task filter. Matches the rollout trace dataset/verifier source.", - ) - parser.add_argument( - "--output", - type=Path, - required=True, - help=( - "Output directory or path-like root. The script writes one file per task/model inside it as " - "____.jsonl plus matching .schema.json and .metadata.json sidecars." - ), - ) - parser.add_argument( - "--push-to-hub", type=str, default=None, help="Optional dataset repo id to push the validated rows to." - ) - parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.") - parser.add_argument( - "--strict", action="store_true", help="Fail if a rollout record is malformed or required files are missing." - ) - parser.add_argument( - "--allow-nonunit-scores", - action="store_true", - help="Keep rows whose rewards cannot be normalized to binary correctness. 
Difficulty will be null for them.", - ) - parser.add_argument( - "--max-instances", - type=int, - default=None, - help="Optional cap for the number of resolved instances written (useful for smoke tests).", - ) - parser.add_argument( - "--beta-prior", - choices=["empirical-bayes", "jeffreys"], - default="empirical-bayes", - help="Global Beta prior to use for smoothing binary solve rates.", - ) - parser.add_argument( - "--posterior-lower-quantile", - type=float, - default=0.1, - help="Lower posterior quantile used to define difficulty as 1 - quantile.", - ) - parser.add_argument( - "--difficulty-buckets", - type=int, - default=DEFAULT_DIFFICULTY_BUCKETS, - help=( - "Number of posterior-aware quantile buckets to assign for stratification. " - "Set to 0 to skip discrete bucket assignment." - ), - ) - return parser - - -def main(argv: list[str] | None = None) -> None: - args = make_parser().parse_args(argv) - validate_args(args) - task_filters = set(args.task) - output_root = resolve_output_root(args.output) - - input_rows = load_input_rows(args, task_filters=task_filters) - - if not input_rows.rows: - raise ValueError("No resolved per-instance rows were produced.") - - rows = sorted( - input_rows.rows, - key=lambda row: ( - stable_string(row.get("task_name")), - stable_string((row.get("experiment_metadata") or {}).get("model_name")), - stable_string(row.get("instance_id")), - ), - ) - if args.max_instances is not None: - rows = rows[: args.max_instances] - - rows_by_group = group_rows_by_task_and_model(rows) - if args.push_to_hub is not None and len(rows_by_group) != 1: - raise ValueError( - "--push-to-hub requires a single task/model output. Filter with --task or use a source with one task." 
- ) - - skipped_nonunit = 0 - written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] - - for (task_name, model_name), group_rows in sorted( - rows_by_group.items(), key=lambda item: (item[0][0], stable_string(item[0][1])) - ): - group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( - group_rows, allow_nonunit_scores=args.allow_nonunit_scores - ) - if input_rows.source_format["kind"] == HF_SOURCE_FORMAT_KIND: - score_processing["source_field"] = ",".join( - field_name - for field_name in ( - input_rows.source_format.get("pass_count_field"), - input_rows.source_format.get("attempt_count_field"), - input_rows.source_format.get("pass_rate_field"), - ) - if field_name - ) - skipped_nonunit += group_skipped_nonunit - - if not group_rows: - logger.warning( - "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name - ) - continue - - prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) - group_rows = apply_beta_binomial_difficulty( - group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets - ) - if input_rows.source_dataset is None: - ordered_group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) - output_rows = strip_output_only_rollout_fields(ordered_group_rows) - dataset = Dataset.from_list(output_rows) - else: - ordered_group_rows = sort_hf_group_rows(group_rows) - output_rows = strip_internal_fields(ordered_group_rows) - dataset = build_hf_output_dataset(input_rows.source_dataset, ordered_group_rows) - - dataset_metadata = build_dataset_metadata( - rows=output_rows, - task_name=task_name, - model_name=model_name, - requested_prior_mode=args.beta_prior, - requested_bucket_count=args.difficulty_buckets, - lower_quantile=args.posterior_lower_quantile, - prior=prior, - binary_row_count=binary_row_count, - score_processing=score_processing, - 
source_format=input_rows.source_format, - ) - - if prior is not None: - logger.info( - "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", - prior.source, - prior.alpha, - prior.beta, - binary_row_count, - task_name, - model_name, - ) - else: - logger.warning( - "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", - task_name, - model_name, - ) - - annotate_dataset_metadata(dataset, dataset_metadata) - output_jsonl, schema_json, metadata_json = build_output_paths( - output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata - ) - write_output_files( - output_jsonl=output_jsonl, - schema_json=schema_json, - metadata_json=metadata_json, - dataset=dataset, - dataset_metadata=dataset_metadata, - ) - - if args.push_to_hub is not None: - dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) - - written_outputs.append((task_name, model_name, len(output_rows), output_jsonl, schema_json, metadata_json)) - logger.info( - "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", - len(output_rows), - task_name, - model_name, - output_jsonl, - schema_json, - metadata_json, - ) - - logger.info( - "Finished writing %s output file groups (%s malformed rollout records, %s skipped due to unsupported scores).", - len(written_outputs), - input_rows.malformed_records, - skipped_nonunit, - ) - - -def load_input_rows(args: argparse.Namespace, *, task_filters: set[str]) -> InputRowsBundle: - if args.hf_dataset is not None: - return load_hf_dataset_rows( - dataset_name=args.hf_dataset, - config_name=args.hf_config, - split=args.hf_split, - task_filters=task_filters, - strict=args.strict, - row_id_field=args.hf_row_id_field, - task_field=args.hf_task_field, - model_field=args.hf_model_field, - pass_count_field=args.hf_pass_count_field, - attempt_count_field=args.hf_attempt_count_field, - pass_rate_field=args.hf_pass_rate_field, - ) - - if not args.source: - 
raise ValueError("Expected --source when --hf-dataset is not provided.") - - source_runs = discover_rollout_sources(args.source) - if not source_runs: - raise ValueError("No rollout trace sources were found.") - - contributions: list[dict[str, Any]] = [] - malformed_records = 0 - - for source_run in source_runs: - logger.info( - "Loading %s (run=%s, metadata=%s, shards=%s)", - source_run.input_arg, - source_run.run_name, - source_run.metadata_path, - len(source_run.rollout_paths), - ) - run_contributions, run_malformed = build_contributions_for_source( - source_run=source_run, task_filters=task_filters, strict=args.strict - ) - contributions.extend(run_contributions) - malformed_records += run_malformed - - return InputRowsBundle( - rows=aggregate_contributions(contributions), - malformed_records=malformed_records, - source_format=build_rollout_source_format_metadata(), - ) - - -def load_hf_dataset_rows( - *, - dataset_name: str, - config_name: str | None, - split: str, - task_filters: set[str], - strict: bool, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> InputRowsBundle: - logger.info( - "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split - ) - - if config_name: - source_dataset = load_dataset(dataset_name, config_name, split=split) - else: - source_dataset = load_dataset(dataset_name, split=split) - - rows: list[dict[str, Any]] = [] - malformed_records = 0 - - for row_index, source_row in enumerate(source_dataset): - try: - row = build_hf_dataset_row( - source_row=source_row, - source_row_index=row_index, - dataset_name=dataset_name, - config_name=config_name, - split=split, - row_id_field=row_id_field, - task_field=task_field, - model_field=model_field, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ) - except Exception as exc: - 
malformed_records += 1 - message = f"Malformed HF dataset row {dataset_name}[{split}][{row_index}]: {exc}" - if strict: - raise ValueError(message) from exc - logger.warning(message) - continue - - task_name = stable_string(row.get("task_name")) - if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: - continue - rows.append(row) - - return InputRowsBundle( - rows=rows, - malformed_records=malformed_records, - source_format=build_hf_source_format_metadata( - dataset_name=dataset_name, - config_name=config_name, - split=split, - row_id_field=row_id_field, - task_field=task_field, - model_field=model_field, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ), - source_dataset=source_dataset, - ) - - -def build_hf_dataset_row( - *, - source_row: dict[str, Any], - source_row_index: int, - dataset_name: str, - config_name: str | None, - split: str, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> dict[str, Any]: - task_name = normalize_task_name(get_nested_field(source_row, task_field)) - if task_name is None: - raise ValueError(f"missing task field {task_field!r}") - - source_row_id = normalize_identifier(get_nested_field(source_row, row_id_field)) or str(source_row_index) - pass_count, attempt_count = extract_hf_attempt_summary( - row=source_row, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ) - model_name = optional_string(get_nested_field(source_row, model_field)) - - return { - HF_SOURCE_ROW_INDEX_FIELD: source_row_index, - "instance_id": make_hf_instance_id(dataset_name=dataset_name, source_row_id=source_row_id), - "task_name": task_name, - "base_task_name": get_base_task_name(task_name), - "source_dataset": dataset_name, - "source_row_id": source_row_id, - "attempt_scores": 
expand_binary_attempt_scores(pass_count=pass_count, attempt_count=attempt_count), - "finish_reasons": [], - "experiment_metadata": { - "source_root": format_hf_source_locator(dataset_name=dataset_name, config_name=config_name, split=split), - "model_name": model_name, - "experiment_id": None, - "experiment_name": dataset_name, - }, - "score_sources": [task_name], - "warnings": [], - } - - -def build_rollout_source_format_metadata() -> dict[str, Any]: - return { - "kind": ROLLOUT_SOURCE_FORMAT_KIND, - "task_field": "dataset", - "score_field": "reward", - "source_dataset_field": "source_dataset", - "source_dataset_id_field": "source_dataset_id", - "source_row_id_field": "source_row_id", - "instance_id_definition": ROLLOUT_INSTANCE_ID_DEFINITION, - } - - -def build_hf_source_format_metadata( - *, - dataset_name: str, - config_name: str | None, - split: str, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> dict[str, Any]: - return { - "kind": HF_SOURCE_FORMAT_KIND, - "dataset_repo_id": dataset_name, - "config_name": config_name, - "split": split, - "row_id_field": row_id_field, - "task_field": task_field, - "model_field": model_field, - "pass_count_field": pass_count_field, - "attempt_count_field": attempt_count_field, - "pass_rate_field": pass_rate_field, - "instance_id_definition": HF_INSTANCE_ID_DEFINITION, - } - - -def format_hf_source_locator(*, dataset_name: str, config_name: str | None, split: str) -> str: - config_token = config_name or "default" - return f"hf://{dataset_name}/{config_token}/{split}" - - -def make_hf_instance_id(*, dataset_name: str, source_row_id: str) -> str: - return f"{dataset_name}::{source_row_id}" - - -def sort_hf_group_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return sorted(rows, key=lambda row: row[HF_SOURCE_ROW_INDEX_FIELD]) - - -def build_hf_output_dataset(source_dataset: Dataset, rows: list[dict[str, Any]]) -> 
Dataset: - ordered_rows = sort_hf_group_rows(rows) - dataset = source_dataset.select([row[HF_SOURCE_ROW_INDEX_FIELD] for row in ordered_rows]) - - for column_name in HF_OUTPUT_COLUMNS: - values = [make_jsonable(row.get(column_name)) for row in ordered_rows] - if column_name in dataset.column_names: - dataset = dataset.remove_columns(column_name) - dataset = dataset.add_column(column_name, values) - - return dataset - - -def strip_internal_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{key: value for key, value in row.items() if key != HF_SOURCE_ROW_INDEX_FIELD} for row in rows] - - -def get_nested_field(value: Any, field_path: str) -> Any: - if not field_path: - return value - - current = value - for field_name in field_path.split("."): - if not isinstance(current, dict) or field_name not in current: - return None - current = current[field_name] - return current - - -def normalize_identifier(value: Any) -> str | None: - if value is None or isinstance(value, bool): - return None - text = stable_string(value).strip() - return text or None - - -def normalize_nonnegative_int(value: Any) -> int | None: - if value is None or isinstance(value, bool): - return None - if isinstance(value, int): - return value if value >= 0 else None - if isinstance(value, float): - if not math.isfinite(value) or not value.is_integer() or value < 0: - return None - return int(value) - if isinstance(value, str): - stripped = value.strip() - if not stripped: - return None - try: - parsed = int(stripped) - except ValueError: - return None - return parsed if parsed >= 0 else None - return None - - -def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: - if value is None: - return None, None, None - if is_number(value): - rate = float(value) - if 0.0 <= rate <= 1.0: - return None, None, rate - raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") - if not isinstance(value, str): - raise ValueError(f"unsupported 
pass-rate value {value!r}") - - stripped = value.strip() - if not stripped: - return None, None, None - - if "/" in stripped: - numerator_text, denominator_text = stripped.split("/", 1) - numerator = normalize_nonnegative_int(numerator_text) - denominator = normalize_nonnegative_int(denominator_text) - if numerator is None or denominator is None or numerator > denominator: - raise ValueError(f"invalid pass-rate fraction {value!r}") - rate = 0.0 if denominator == 0 else numerator / denominator - return numerator, denominator, rate - - try: - rate = float(stripped) - except ValueError as exc: - raise ValueError(f"invalid pass-rate value {value!r}") from exc - if not math.isfinite(rate) or rate < 0.0 or rate > 1.0: - raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") - return None, None, rate - - -def extract_hf_attempt_summary( - *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None -) -> tuple[int, int]: - pass_count = normalize_nonnegative_int(get_nested_field(row, pass_count_field)) - attempt_count = normalize_nonnegative_int(get_nested_field(row, attempt_count_field)) - - parsed_pass_count = None - parsed_attempt_count = None - parsed_pass_rate = None - if pass_rate_field: - parsed_pass_count, parsed_attempt_count, parsed_pass_rate = parse_pass_rate_value( - get_nested_field(row, pass_rate_field) - ) - - if pass_count is None and parsed_pass_count is not None: - pass_count = parsed_pass_count - if attempt_count is None and parsed_attempt_count is not None: - attempt_count = parsed_attempt_count - - if pass_count is None or attempt_count is None: - raise ValueError( - f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}" - f"{f' or parseable {pass_rate_field!r}' if pass_rate_field else ''}" - ) - if attempt_count <= 0: - raise ValueError(f"attempt count must be positive, received {attempt_count}") - if pass_count > attempt_count: - raise ValueError(f"pass count 
{pass_count} exceeds attempt count {attempt_count}") - - if parsed_pass_count is not None and parsed_pass_count != pass_count: - raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}") - if parsed_attempt_count is not None and parsed_attempt_count != attempt_count: - raise ValueError(f"attempt-count field {attempt_count_field!r} disagrees with {pass_rate_field!r}") - if parsed_pass_rate is not None and not is_close(pass_count / attempt_count, parsed_pass_rate): - raise ValueError( - f"pass-count fields {pass_count_field!r}/{attempt_count_field!r} disagree with {pass_rate_field!r}" - ) - - return pass_count, attempt_count - - -def expand_binary_attempt_scores(*, pass_count: int, attempt_count: int) -> list[float]: - return [1.0] * pass_count + [0.0] * (attempt_count - pass_count) - - -def discover_rollout_sources(sources: list[str]) -> list[RolloutSource]: - discovered: dict[Path, RolloutSource] = {} - - for source in sources: - source_path = Path(source) - if not source_path.exists(): - raise FileNotFoundError(f"Could not find source path {source}") - - if source_path.is_dir(): - metadata_paths = sorted(source_path.rglob("*_metadata.jsonl")) - if not metadata_paths: - raise FileNotFoundError(f"Could not find *_metadata.jsonl under {source}") - for metadata_path in metadata_paths: - rollout_source = build_rollout_source_from_metadata(metadata_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - if source_path.name.endswith("_metadata.jsonl"): - rollout_source = build_rollout_source_from_metadata(source_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - if source_path.suffix == ".jsonl" and "_rollouts_" in source_path.name: - rollout_source = build_rollout_source_from_rollout(source_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - raise ValueError( - f"Unsupported source path {source}. 
Expected a directory, *_metadata.jsonl, or *_rollouts_*.jsonl." - ) - - return sorted(discovered.values(), key=lambda source_run: (str(source_run.root_path), source_run.run_name)) - - -def build_rollout_source_from_metadata(metadata_path: Path, *, input_arg: str) -> RolloutSource: - run_name = parse_run_name_from_metadata_path(metadata_path) - rollout_paths = tuple(sorted(metadata_path.parent.glob(f"{run_name}_rollouts_*.jsonl"))) - if not rollout_paths: - raise FileNotFoundError(f"Could not find rollout shards for run {run_name} next to {metadata_path}") - return RolloutSource( - input_arg=input_arg, - root_path=metadata_path.parent.absolute(), - metadata_path=metadata_path.absolute(), - rollout_paths=rollout_paths, - run_name=run_name, - ) - - -def build_rollout_source_from_rollout(rollout_path: Path, *, input_arg: str) -> RolloutSource: - run_name = parse_run_name_from_rollout_path(rollout_path) - metadata_path = rollout_path.parent / f"{run_name}_metadata.jsonl" - if not metadata_path.exists(): - raise FileNotFoundError(f"Could not find metadata file {metadata_path} for rollout shard {rollout_path}") - return build_rollout_source_from_metadata(metadata_path, input_arg=input_arg) - - -def parse_run_name_from_metadata_path(metadata_path: Path) -> str: - suffix = "_metadata.jsonl" - if not metadata_path.name.endswith(suffix): - raise ValueError(f"Metadata path must end with {suffix}: {metadata_path}") - return metadata_path.name[: -len(suffix)] - - -def parse_run_name_from_rollout_path(rollout_path: Path) -> str: - marker = "_rollouts_" - if marker not in rollout_path.name: - raise ValueError(f"Rollout shard filename must contain {marker}: {rollout_path}") - return rollout_path.name.split(marker, 1)[0] - - -def build_contributions_for_source( - *, source_run: RolloutSource, task_filters: set[str], strict: bool -) -> tuple[list[dict[str, Any]], int]: - run_metadata = read_rollout_metadata(source_run.metadata_path, fallback_run_name=source_run.run_name) - 
contributions: list[dict[str, Any]] = [] - malformed_records = 0 - - for rollout_path in source_run.rollout_paths: - for line_number, record in enumerate(read_jsonl(rollout_path), start=1): - try: - contribution = build_rollout_contribution( - record=record, source_run=source_run, run_metadata=run_metadata - ) - except Exception as exc: - malformed_records += 1 - message = f"Malformed rollout record in {rollout_path}:{line_number}: {exc}" - if strict: - raise ValueError(message) from exc - logger.warning(message) - continue - - task_name = stable_string(contribution.get("task_name")) - if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: - continue - contributions.append(contribution) - - return contributions, malformed_records - - -def read_rollout_metadata(metadata_path: Path, *, fallback_run_name: str) -> dict[str, Any]: - rows = read_jsonl(metadata_path) - if not rows: - raise ValueError(f"Metadata file is empty: {metadata_path}") - if len(rows) > 1: - logger.warning("Expected one metadata row in %s but found %s. 
Using the first row.", metadata_path, len(rows)) - - metadata = rows[0] - return { - "run_name": optional_string(metadata.get("run_name")) or fallback_run_name, - "model_name": optional_string(metadata.get("model_name")), - "experiment_id": optional_string(metadata.get("experiment_id")), - "git_commit": optional_string(metadata.get("git_commit")), - "timestamp": optional_string(metadata.get("timestamp")), - } - - -def build_rollout_contribution( - *, record: dict[str, Any], source_run: RolloutSource, run_metadata: dict[str, Any] -) -> dict[str, Any]: - task_name = normalize_task_name(record.get("dataset")) - if task_name is None: - raise ValueError("missing dataset/verifier source") - - source_dataset = normalize_source_dataset(record.get("source_dataset")) - source_dataset_id = extract_source_dataset_id(record) - - prompt_tokens = normalize_token_list(record.get("prompt_tokens")) - if prompt_tokens is None and (source_dataset is None or source_dataset_id is None): - raise ValueError("missing prompt_tokens and source dataset identity (source_dataset/source_row_id)") - - reward = extract_numeric_reward(record.get("reward")) - if reward is None: - raise ValueError("missing or invalid reward") - - ground_truth = make_jsonable(record.get("ground_truth")) - finish_reason = optional_string(record.get("finish_reason")) - - return { - "instance_id": make_rollout_instance_id( - task_name=task_name, - prompt_tokens=prompt_tokens, - ground_truth=ground_truth, - source_dataset=source_dataset, - source_dataset_id=source_dataset_id, - ), - "task_name": task_name, - "base_task_name": get_base_task_name(task_name), - "prompt_tokens": prompt_tokens, - "ground_truth": ground_truth, - "source_dataset": source_dataset, - "source_dataset_id": source_dataset_id, - "score_source": task_name, - "attempt_scores": [reward], - "finish_reasons": [finish_reason] if finish_reason else [], - "experiment_metadata": { - "source_root": str(source_run.root_path), - "model_name": 
run_metadata["model_name"], - "experiment_id": run_metadata["experiment_id"], - "experiment_name": run_metadata["run_name"], - }, - "warnings": extract_rollout_warnings(record.get("request_info")), - } - - -def normalize_task_name(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - if isinstance(value, (list, tuple)) and len(value) == 1: - return normalize_task_name(value[0]) - serialized = serialize_value(value) - return serialized or None - - -def normalize_source_dataset(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - if isinstance(value, (list, tuple)) and len(value) == 1: - return normalize_source_dataset(value[0]) - serialized = serialize_value(value) - return serialized or None - - -def extract_source_dataset_id(record: dict[str, Any]) -> int | None: - for field_name in ("source_dataset_id", "source_row_id"): - source_dataset_id = normalize_source_dataset_id(record.get(field_name)) - if source_dataset_id is not None: - return source_dataset_id - return None - - -def normalize_source_dataset_id(value: Any) -> int | None: - return normalize_nonnegative_int(value) - - -def normalize_token_list(value: Any) -> list[int] | None: - if not isinstance(value, list): - return None - - tokens: list[int] = [] - for item in value: - if isinstance(item, bool) or not isinstance(item, (int, float)): - return None - tokens.append(int(item)) - return tokens - - -def extract_numeric_reward(value: Any) -> float | None: - if not is_number(value): - return None - return float(value) - - -def extract_rollout_warnings(request_info: Any) -> list[str]: - if not isinstance(request_info, dict): - return [] - - warnings: list[str] = [] - if request_info.get("timeouts"): - warnings.append("timeout") - if optional_string(request_info.get("tool_errors")): - warnings.append("tool_error") - return warnings - - -def aggregate_contributions(contributions: list[dict[str, Any]]) -> 
list[dict[str, Any]]: - grouped: dict[str, dict[str, Any]] = {} - - for contribution in contributions: - instance_id = contribution["instance_id"] - if instance_id not in grouped: - grouped[instance_id] = { - key: value - for key, value in contribution.items() - if key not in {"attempt_scores", "finish_reasons", "experiment_metadata", "warnings", "score_source"} - } - grouped[instance_id]["attempt_scores"] = [] - grouped[instance_id]["finish_reasons"] = [] - grouped[instance_id]["experiment_metadata"] = None - grouped[instance_id]["score_sources"] = set() - grouped[instance_id]["warnings"] = set() - - row = grouped[instance_id] - row["attempt_scores"].extend(float(score) for score in contribution["attempt_scores"]) - row["finish_reasons"].extend(contribution["finish_reasons"]) - row["experiment_metadata"] = merge_experiment_metadata( - existing=row["experiment_metadata"], incoming=contribution["experiment_metadata"], instance_id=instance_id - ) - row["score_sources"].add(stable_string(contribution["score_source"])) - row["warnings"].update(contribution["warnings"]) - - rows: list[dict[str, Any]] = [] - for row in grouped.values(): - row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] - row["finish_reasons"] = [stable_string(reason) for reason in row["finish_reasons"] if stable_string(reason)] - row["experiment_metadata"] = normalize_experiment_metadata(row["experiment_metadata"]) - row["score_sources"] = sorted(value for value in row["score_sources"] if value) - row["warnings"] = sorted(value for value in row["warnings"] if value) - rows.append(row) - - return rows - - -def strip_output_only_rollout_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{key: value for key, value in row.items() if key not in {"prompt_tokens", "ground_truth"}} for row in rows] - - -def normalize_attempt_scores_for_group( - rows: list[dict[str, Any]], *, allow_nonunit_scores: bool -) -> tuple[list[dict[str, Any]], dict[str, Any], int]: - 
score_processing = infer_score_processing(rows) - normalized_rows: list[dict[str, Any]] = [] - skipped_nonunit = 0 - - for row in rows: - normalized_scores = normalize_attempt_scores(row["attempt_scores"], score_processing) - if normalized_scores is None: - if allow_nonunit_scores: - kept_row = dict(row) - kept_row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] - kept_row["warnings"] = sorted({*kept_row["warnings"], "nonbinary_reward_scores"}) - normalized_rows.append(kept_row) - else: - skipped_nonunit += 1 - continue - - normalized_row = dict(row) - normalized_row["attempt_scores"] = normalized_scores - normalized_rows.append(normalized_row) - - return normalized_rows, score_processing, skipped_nonunit - - -def infer_score_processing(rows: list[dict[str, Any]]) -> dict[str, Any]: - scores = [float(score) for row in rows for score in row.get("attempt_scores", [])] - score_processing = { - "source_field": "reward", - "output_field": "attempt_scores", - "normalization": "unsupported", - "positive_reward_value": None, - "supports_binary_difficulty": False, - } - - if not scores: - return score_processing - - if all(is_close(score, 0.0) or is_close(score, 1.0) for score in scores): - score_processing["normalization"] = "identity_binary" - score_processing["positive_reward_value"] = 1.0 - score_processing["supports_binary_difficulty"] = True - return score_processing - - if any(score < -EPS for score in scores): - return score_processing - - positive_scores = [score for score in scores if score > EPS] - if not positive_scores: - score_processing["normalization"] = "all_zero_binary" - score_processing["supports_binary_difficulty"] = True - return score_processing - - positive_reward_value = max(positive_scores) - if all(is_close(score, 0.0) or is_close(score, positive_reward_value) for score in scores): - score_processing["normalization"] = "binary_zero_or_constant" - score_processing["positive_reward_value"] = positive_reward_value - 
score_processing["supports_binary_difficulty"] = True - - return score_processing - - -def normalize_attempt_scores(attempt_scores: list[float], score_processing: dict[str, Any]) -> list[float] | None: - if not score_processing.get("supports_binary_difficulty"): - return None - - normalization = stable_string(score_processing.get("normalization")) - positive_reward_value = score_processing.get("positive_reward_value") - normalized_scores: list[float] = [] - - for score in attempt_scores: - if is_close(score, 0.0): - normalized_scores.append(0.0) - continue - - if normalization == "identity_binary" and is_close(score, 1.0): - normalized_scores.append(1.0) - continue - - if ( - normalization == "binary_zero_or_constant" - and positive_reward_value is not None - and is_close(score, float(positive_reward_value)) - ): - normalized_scores.append(1.0) - continue - - if normalization == "all_zero_binary": - return None - - return None - - return normalized_scores - - -def estimate_beta_prior(rows: list[dict[str, Any]], *, prior_mode: str) -> tuple[BetaPrior | None, int]: - binary_counts = [counts for row in rows if (counts := extract_binary_counts(row["attempt_scores"])) is not None] - if not binary_counts: - return None, 0 - - if prior_mode == "jeffreys": - return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys"), len(binary_counts) - - prior = fit_empirical_beta_prior(binary_counts) - if prior is not None: - return prior, len(binary_counts) - - logger.warning("Falling back to Jeffreys prior after empirical-Bayes fitting failed.") - return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys_fallback"), len(binary_counts) - - -def apply_beta_binomial_difficulty( - rows: list[dict[str, Any]], *, prior: BetaPrior | None, lower_quantile: float, num_buckets: int -) -> list[dict[str, Any]]: - posterior_rows: list[DifficultyPosteriorRow] = [] - - for row in rows: - row["difficulty"] = make_empty_difficulty_payload() - - if prior is None: - continue - - 
binary_counts = extract_binary_counts(row["attempt_scores"]) - if binary_counts is None: - continue - - success_count, attempt_count = binary_counts - posterior_alpha = success_count + prior.alpha - posterior_beta = attempt_count - success_count + prior.beta - posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta) - posterior_lower_bound = float(beta_distribution.ppf(lower_quantile, posterior_alpha, posterior_beta)) - - row["difficulty"] = { - "value": max(0.0, min(1.0, 1.0 - posterior_lower_bound)), - "posterior_mean": posterior_mean, - "posterior_lower_bound": posterior_lower_bound, - "expected_quantile": None, - "bucket_index": None, - "bucket_count": None, - } - posterior_rows.append( - DifficultyPosteriorRow(row=row, difficulty_alpha=posterior_beta, difficulty_beta=posterior_alpha) - ) - - assign_posterior_difficulty_buckets(posterior_rows, num_buckets=num_buckets) - return rows - - -def make_empty_difficulty_payload() -> dict[str, Any]: - return { - "value": None, - "posterior_mean": None, - "posterior_lower_bound": None, - "expected_quantile": None, - "bucket_index": None, - "bucket_count": None, - } - - -def assign_posterior_difficulty_buckets(posterior_rows: list[DifficultyPosteriorRow], *, num_buckets: int) -> None: - if not posterior_rows: - return - - expected_quantiles = estimate_expected_difficulty_quantiles(posterior_rows) - for posterior_row, expected_quantile in zip(posterior_rows, expected_quantiles, strict=True): - posterior_row.row["difficulty"]["expected_quantile"] = expected_quantile - - if num_buckets <= 0: - return - - effective_bucket_count = min(num_buckets, len(posterior_rows)) - ordered_rows = sorted( - zip(posterior_rows, expected_quantiles, strict=True), - key=lambda item: (item[1], item[0].row["difficulty"]["value"], stable_string(item[0].row["instance_id"])), - ) - base_bucket_size, remainder = divmod(len(ordered_rows), effective_bucket_count) - - cursor = 0 - for bucket_index in range(effective_bucket_count): - 
bucket_size = base_bucket_size + (1 if bucket_index < remainder else 0) - for posterior_row, _expected_quantile in ordered_rows[cursor : cursor + bucket_size]: - posterior_row.row["difficulty"]["bucket_index"] = bucket_index - posterior_row.row["difficulty"]["bucket_count"] = effective_bucket_count - cursor += bucket_size - - -def estimate_expected_difficulty_quantiles( - posterior_rows: list[DifficultyPosteriorRow], - *, - grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE, - batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE, -) -> list[float]: - if not posterior_rows: - return [] - if len(posterior_rows) == 1: - return [0.5] - - grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size - difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64) - difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64) - - mixture_cdf = np.zeros(grid_size, dtype=np.float64) - for start in range(0, len(posterior_rows), batch_size): - stop = start + batch_size - batch_cdf = beta_distribution.cdf( - grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] - ) - mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0) - mixture_cdf /= len(posterior_rows) - - quantiles = np.zeros(len(posterior_rows), dtype=np.float64) - dx = 1.0 / grid_size - for start in range(0, len(posterior_rows), batch_size): - stop = start + batch_size - batch_pdf = beta_distribution.pdf( - grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] - ) - quantiles[start:stop] = np.clip( - np.nan_to_num(batch_pdf, nan=0.0, posinf=0.0, neginf=0.0).dot(mixture_cdf) * dx, 0.0, 1.0 - ) - - return quantiles.tolist() - - -def fit_empirical_beta_prior(binary_counts: list[tuple[int, int]]) -> BetaPrior | None: - total_successes = sum(success_count for success_count, _ in binary_counts) - total_attempts = sum(attempt_count for _, attempt_count in 
binary_counts) - if total_attempts == 0 or total_successes in {0, total_attempts}: - return None - - mean_rate = total_successes / total_attempts - init_alpha = max(mean_rate * 2.0, 1e-3) - init_beta = max((1.0 - mean_rate) * 2.0, 1e-3) - - def objective(log_params: tuple[float, float]) -> float: - alpha = math.exp(log_params[0]) - beta = math.exp(log_params[1]) - return -sum( - betaln(success_count + alpha, attempt_count - success_count + beta) - betaln(alpha, beta) - for success_count, attempt_count in binary_counts - ) - - result = minimize( - objective, - x0=(math.log(init_alpha), math.log(init_beta)), - method="L-BFGS-B", - bounds=[(-10.0, 10.0), (-10.0, 10.0)], - ) - if not result.success: - logger.warning("Empirical-Bayes fit failed: %s", result.message) - return None - - return BetaPrior(alpha=math.exp(result.x[0]), beta=math.exp(result.x[1]), source="empirical_bayes") - - -def merge_experiment_metadata( - existing: dict[str, Any] | None, incoming: dict[str, Any], *, instance_id: str -) -> dict[str, Any]: - normalized_incoming = normalize_experiment_metadata(incoming) - if existing is None: - return normalized_incoming - - merged = dict(existing) - for key in EXPERIMENT_METADATA_KEYS: - existing_value = merged.get(key) - incoming_value = normalized_incoming.get(key) - if existing_value in {None, ""}: - merged[key] = incoming_value - elif incoming_value in {None, ""} or incoming_value == existing_value: - continue - else: - raise ValueError( - f"Conflicting experiment metadata for instance {instance_id}: " - f"{key}={existing_value!r} vs {incoming_value!r}" - ) - return merged - - -def normalize_experiment_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]: - if metadata is None: - return {key: None for key in EXPERIMENT_METADATA_KEYS} - return {key: metadata.get(key) for key in EXPERIMENT_METADATA_KEYS} - - -def resolve_output_root(output: Path) -> Path: - output_str = str(output) - if output_str.endswith(".schema.json"): - return Path(output_str[: 
-len(".schema.json")]) - if output_str.endswith(".jsonl"): - return Path(output_str[: -len(".jsonl")]) - if output_str.endswith(".json"): - return Path(output_str[: -len(".json")]) - return output - - -def build_output_paths( - output_root: Path, *, task_name: str, model_name: str | None, dataset_metadata: dict[str, Any] -) -> tuple[Path, Path, Path]: - task_suffix = sanitize_name(task_name) or "unknown-task" - model_suffix = sanitize_name(model_name or "") or "unknown-model" - difficulty_suffix = build_difficulty_filename_suffix(dataset_metadata) - stem = output_root / f"{task_suffix}__{model_suffix}{difficulty_suffix}" - return Path(f"{stem}.jsonl"), Path(f"{stem}.schema.json"), Path(f"{stem}.metadata.json") - - -def write_output_files( - *, output_jsonl: Path, schema_json: Path, metadata_json: Path, dataset: Dataset, dataset_metadata: dict[str, Any] -) -> None: - output_jsonl.parent.mkdir(parents=True, exist_ok=True) - with output_jsonl.open("w") as output_file: - for row in dataset: - output_file.write(json.dumps(make_jsonable(row), ensure_ascii=False) + "\n") - - schema_json.parent.mkdir(parents=True, exist_ok=True) - try: - schema_payload: Any = dataset.features.to_dict() - except AttributeError: - schema_payload = str(dataset.features) - with schema_json.open("w") as output_file: - json.dump(schema_payload, output_file, indent=2, sort_keys=True) - - metadata_json.parent.mkdir(parents=True, exist_ok=True) - with metadata_json.open("w") as output_file: - json.dump(dataset_metadata, output_file, indent=2, sort_keys=True) - - -def build_dataset_metadata( - *, - rows: list[dict[str, Any]], - task_name: str, - model_name: str | None, - requested_prior_mode: str, - requested_bucket_count: int, - lower_quantile: float, - prior: BetaPrior | None, - binary_row_count: int, - score_processing: dict[str, Any], - source_format: dict[str, Any], -) -> dict[str, Any]: - effective_bucket_count = extract_effective_bucket_count(rows) - difficulty_generation = { - "method": 
DIFFICULTY_GENERATION_METHOD, - "difficulty_value_field": "difficulty.value", - "difficulty_value_definition": "1 - difficulty.posterior_lower_bound", - "bucket_field": "difficulty.bucket_index", - "bucket_count_field": "difficulty.bucket_count", - "bucket_ranking_field": "difficulty.expected_quantile", - "posterior_lower_quantile": lower_quantile, - "bucket_count_requested": requested_bucket_count, - "bucket_count_effective": effective_bucket_count, - "beta_prior_requested": requested_prior_mode, - "beta_prior_used": { - "source": prior.source if prior is not None else None, - "alpha": prior.alpha if prior is not None else None, - "beta": prior.beta if prior is not None else None, - }, - "binary_instance_count": binary_row_count, - "nonbinary_instance_count": max(0, len(rows) - binary_row_count), - } - difficulty_generation["tag"] = build_difficulty_config_tag(difficulty_generation) - return { - "task_name": task_name, - "model_name": model_name, - "row_count": len(rows), - "source_format": dict(source_format), - "score_processing": dict(score_processing), - "difficulty_generation": difficulty_generation, - } - - -def extract_effective_bucket_count(rows: list[dict[str, Any]]) -> int: - effective_bucket_counts = { - difficulty.get("bucket_count") - for row in rows - if isinstance((difficulty := row.get("difficulty")), dict) and difficulty.get("bucket_count") is not None - } - if not effective_bucket_counts: - return 0 - if len(effective_bucket_counts) != 1: - raise ValueError(f"Expected a single effective bucket count, found {sorted(effective_bucket_counts)}") - return next(iter(effective_bucket_counts)) - - -def build_difficulty_filename_suffix(dataset_metadata: dict[str, Any]) -> str: - return f"__{dataset_metadata['difficulty_generation']['tag']}" - - -def build_difficulty_config_tag(difficulty_generation: dict[str, Any]) -> str: - method_token = abbreviate_filename_token( - optional_string(difficulty_generation.get("method")), - 
aliases=DIFFICULTY_METHOD_FILENAME_ALIASES, - default="diff", - ) - prior_source = optional_string((difficulty_generation.get("beta_prior_used") or {}).get("source")) - prior_token = abbreviate_filename_token(prior_source, aliases=PRIOR_SOURCE_FILENAME_ALIASES, default="none") - quantile_token = format_quantile_token(difficulty_generation["posterior_lower_quantile"]) - bucket_token = format_bucket_token( - requested_count=difficulty_generation["bucket_count_requested"], - effective_count=difficulty_generation["bucket_count_effective"], - ) - return "-".join([method_token, prior_token, quantile_token, bucket_token]) - - -def abbreviate_filename_token(value: str | None, *, aliases: dict[str, str], default: str) -> str: - if not value: - return default - return aliases.get(value, sanitize_name(value)) - - -def format_quantile_token(value: float) -> str: - return f"q{format_filename_number(value * 100.0)}" - - -def format_bucket_token(*, requested_count: int, effective_count: int) -> str: - if requested_count == effective_count: - return f"k{requested_count}" - return f"k{requested_count}e{effective_count}" - - -def annotate_dataset_metadata(dataset: Dataset, dataset_metadata: dict[str, Any]) -> None: - if not hasattr(dataset, "info") or dataset.info is None: - return - dataset.info.description = json.dumps(dataset_metadata, indent=2, sort_keys=True) - - -def validate_args(args: argparse.Namespace) -> None: - if not 0.0 < args.posterior_lower_quantile < 1.0: - raise ValueError("--posterior-lower-quantile must be between 0 and 1.") - if args.difficulty_buckets < 0: - raise ValueError("--difficulty-buckets must be non-negative.") - if args.max_instances is not None and args.max_instances <= 0: - raise ValueError("--max-instances must be positive when provided.") - - -def group_rows_by_task_and_model(rows: list[dict[str, Any]]) -> dict[tuple[str, str | None], list[dict[str, Any]]]: - rows_by_group: dict[tuple[str, str | None], list[dict[str, Any]]] = defaultdict(list) - 
for row in rows: - experiment_metadata = row.get("experiment_metadata") or {} - task_name = stable_string(row.get("task_name")) - model_name = optional_string(experiment_metadata.get("model_name")) - rows_by_group[(task_name, model_name)].append(row) - return dict(rows_by_group) - - -def read_jsonl(path: Path) -> list[dict[str, Any]]: - with path.open() as input_file: - return [json.loads(line) for line in input_file if line.strip()] - - -def get_base_task_name(task_name: str) -> str: - return task_name.split("@", 1)[0].split(":", 1)[0] - - -def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None: - if not attempt_scores: - return None - - success_count = 0 - for score in attempt_scores: - if is_close(score, 0.0): - continue - if is_close(score, 1.0): - success_count += 1 - continue - return None - - return success_count, len(attempt_scores) - - -def make_rollout_instance_id( - *, - task_name: str, - prompt_tokens: list[int] | None, - ground_truth: Any, - source_dataset: str | None = None, - source_dataset_id: int | None = None, -) -> str: - if source_dataset is not None and source_dataset_id is not None: - return f"{source_dataset}::{source_dataset_id}" - - if prompt_tokens is None: - raise ValueError("prompt_tokens are required when source row identity is unavailable") - - fingerprint = {"task_name": task_name, "prompt_tokens": prompt_tokens, "ground_truth": make_jsonable(ground_truth)} - digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] - task_prefix = sanitize_name(task_name) or "unknown" - return f"{task_prefix}::{digest}" - - -def canonical_json(value: Any) -> str: - return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True, separators=(",", ":")) - - -def make_jsonable(value: Any) -> Any: - if value is None or isinstance(value, (str, int, float, bool)): - return value - if isinstance(value, list): - return [make_jsonable(item) for item in value] - if isinstance(value, tuple): - 
return [make_jsonable(item) for item in value] - if isinstance(value, dict): - return {stable_string(key): make_jsonable(item) for key, item in value.items()} - return stable_string(value) - - -def stable_string(value: Any) -> str: - if value is None: - return "" - if isinstance(value, str): - return value - return str(value) - - -def optional_string(value: Any) -> str | None: - text = stable_string(value) - return text or None - - -def serialize_value(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True) - - -def format_filename_number(value: float) -> str: - text = f"{value:.8g}" - return text.replace("-", "m").replace(".", "p") - - -def sanitize_name(value: str) -> str: - return value.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") - - -def is_number(value: Any) -> bool: - return isinstance(value, (int, float)) and not isinstance(value, bool) and not math.isnan(float(value)) - - -def is_close(lhs: float, rhs: float) -> bool: - tolerance = EPS * max(1.0, abs(lhs), abs(rhs)) - return abs(lhs - rhs) <= tolerance +from open_instruct import rlvr_difficulty if __name__ == "__main__": - main() + rlvr_difficulty.main() diff --git a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh b/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh deleted file mode 100644 index 2662fd816c..0000000000 --- a/scripts/data/difficulty_sampling/qwen3_4b_dapo_math_gen.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_rollout_probe}" -RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}" - -NUM_GPUS="${NUM_GPUS:-1}" -BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" - -CLUSTER="${CLUSTER:-ai2/jupiter}" -PRIORITY="${PRIORITY:-urgent}" -WORKSPACE="${WORKSPACE:-ai2/olmo-instruct}" -TRACE_DIR="${TRACE_DIR:-/weka/oe-adapt-default/tylerm/deletable_rollouts/${EXP_NAME}/${RUN_NAME}}" - -if 
[[ $# -gt 0 ]]; then - shift -fi - -uv run python mason.py \ - --task_name "${EXP_NAME}" \ - --description "${RUN_NAME}" \ - --cluster "${CLUSTER}" \ - --workspace "${WORKSPACE}" \ - --priority "${PRIORITY}" \ - --pure_docker_mode \ - --no_auto_dataset_cache \ - --image "${BEAKER_IMAGE}" \ - --preemptible \ - --num_nodes 1 \ - --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ - --gpus "${NUM_GPUS}" \ - --budget ai2/oe-adapt \ - -- \ -uv run open_instruct/benchmark_generators.py \ - --run_name "${RUN_NAME}" \ - --exp_name "${EXP_NAME}" \ - --output_dir "${TRACE_DIR}" \ - --model_name_or_path "Qwen/Qwen3-4B-Base" \ - --tokenizer_name_or_path "Qwen/Qwen3-4B-Base" \ - --chat_template_name qwen_instruct_user_boxed_math \ - --dataset_mixer_list hamishivi/DAPO-Math-17k-Processed_filtered 1.0 \ - --dataset_mixer_list_splits train \ - --num_unique_prompts_rollout 64 \ - --num_samples_per_prompt_rollout 16 \ - --vllm_num_engines 8 \ - --vllm_tensor_parallel_size 1 \ - --max_prompt_token_length 2048 \ - --response_length 8192 \ - --pack_length 10240 \ - --vllm_top_p 1.0 \ - --temperature 1.0 \ - --apply_verifiable_reward true \ - --verification_reward 10.0 \ - --save_traces \ - --vllm_enable_prefix_caching \ - --rollout_save_format scores_only \ - --rollouts_save_path "${TRACE_DIR}" \ - --run_all_instances \ - --seed 1 "$@" diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh index 21e6bd427a..24f238fc8b 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh @@ -26,15 +26,33 @@ CHECKPOINT_STATE_FREQ="${CHECKPOINT_STATE_FREQ:-100}" NUM_TRAINING_STEPS=$(( TOTAL_EPISODES / (NUM_UNIQUE_PROMPTS_ROLLOUT * NUM_SAMPLES_PER_PROMPT_ROLLOUT) )) -# Keep the easy bootstrap aligned with the first logging/eval window by default. 
-DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS="${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS:-${LOCAL_EVAL_EVERY}}" -DIFFICULTY_CURRICULUM_WARMUP_STEPS="${DIFFICULTY_CURRICULUM_WARMUP_STEPS:-${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS}}" -if (( NUM_TRAINING_STEPS <= DIFFICULTY_CURRICULUM_WARMUP_STEPS )); then - DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS=1 +# Keep the bootstrap aligned with the first logging/eval window by default. +CURRICULUM_BOOTSTRAP_STEPS="${CURRICULUM_BOOTSTRAP_STEPS:-${LOCAL_EVAL_EVERY}}" +CURRICULUM_WARMUP_STEPS="${CURRICULUM_WARMUP_STEPS:-${CURRICULUM_BOOTSTRAP_STEPS}}" +if (( NUM_TRAINING_STEPS <= CURRICULUM_WARMUP_STEPS )); then + DEFAULT_CURRICULUM_TOTAL_STEPS=1 else - DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS=$(( NUM_TRAINING_STEPS - DIFFICULTY_CURRICULUM_WARMUP_STEPS )) + DEFAULT_CURRICULUM_TOTAL_STEPS=$(( NUM_TRAINING_STEPS - CURRICULUM_WARMUP_STEPS )) fi -DIFFICULTY_CURRICULUM_TOTAL_STEPS="${DIFFICULTY_CURRICULUM_TOTAL_STEPS:-${DEFAULT_DIFFICULTY_CURRICULUM_TOTAL_STEPS}}" +CURRICULUM_TOTAL_STEPS="${CURRICULUM_TOTAL_STEPS:-${DEFAULT_CURRICULUM_TOTAL_STEPS}}" + +CURRICULUM_ARGS=( + --curriculum difficulty + --curriculum_metadata_field difficulty + --curriculum_bootstrap_steps "${CURRICULUM_BOOTSTRAP_STEPS}" + --curriculum_bootstrap_target 0.125 + --curriculum_warmup_target 0.5 + --curriculum_final_target 1.0 + --curriculum_warmup_steps "${CURRICULUM_WARMUP_STEPS}" + --curriculum_total_steps "${CURRICULUM_TOTAL_STEPS}" + --curriculum_min_hard_frac 0.05 + --curriculum_max_hard_frac 0.50 + --curriculum_bucket_sigma 0.0 + --curriculum_bootstrap_sigma 0.0 + --curriculum_uncertainty_weight 0.5 + --curriculum_adaptive false + --curriculum_strict_metadata true +) uv run python mason.py \ --task_name ${EXP_NAME} \ @@ -99,18 +117,4 @@ uv run open_instruct/grpo_fast.py \ --load_ref_policy False \ --keep_last_n_checkpoints -1 \ --push_to_hub False \ - --difficulty_curriculum_enabled true \ - --difficulty_curriculum_field difficulty \ - 
--difficulty_curriculum_easy_focus_steps ${DIFFICULTY_CURRICULUM_EASY_FOCUS_STEPS} \ - --difficulty_curriculum_bootstrap_target_bucket_ratio 0.125 \ - --difficulty_curriculum_warmup_target_bucket_ratio 0.5 \ - --difficulty_curriculum_final_target_bucket_ratio 1.0 \ - --difficulty_curriculum_warmup_steps ${DIFFICULTY_CURRICULUM_WARMUP_STEPS} \ - --difficulty_curriculum_total_steps ${DIFFICULTY_CURRICULUM_TOTAL_STEPS} \ - --difficulty_curriculum_min_hard_frac 0.05 \ - --difficulty_curriculum_max_hard_frac 0.50 \ - --difficulty_curriculum_bucket_sigma 0.0 \ - --difficulty_curriculum_easy_focus_sigma 0.0 \ - --difficulty_curriculum_uncertainty_weight 0.5 \ - --difficulty_curriculum_adaptive_enabled False \ - --difficulty_curriculum_strict_metadata true "$@" + "${CURRICULUM_ARGS[@]}" "$@" diff --git a/tests/test_create_bucketed_difficulty.py b/tests/test_create_bucketed_difficulty.py index 733b6b4e92..21c0c550af 100644 --- a/tests/test_create_bucketed_difficulty.py +++ b/tests/test_create_bucketed_difficulty.py @@ -1,6 +1,6 @@ -"""Unit tests for posterior-aware bucketing in create_bucketed_difficulty.py.""" +"""Unit tests for posterior-aware bucketing in open_instruct.rlvr_difficulty.""" -import importlib.util +import importlib import json import math import sys @@ -14,14 +14,8 @@ import numpy as np -SCRIPT_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_bucketed_difficulty.py" - def _load_create_bucketed_difficulty_module(): - module_name = "test_create_bucketed_difficulty_script" - spec = importlib.util.spec_from_file_location(module_name, SCRIPT_PATH) - module = importlib.util.module_from_spec(spec) - fake_datasets = types.ModuleType("datasets") fake_datasets.Dataset = type("Dataset", (), {}) fake_datasets.load_dataset = lambda *_args, **_kwargs: None @@ -90,12 +84,8 @@ def ppf(cls, q, alpha, beta): } with patch.dict(sys.modules, modules): - assert spec.loader is not None - sys.modules[module_name] = module - 
spec.loader.exec_module(module) - sys.modules.pop(module_name, None) - - return module + sys.modules.pop("open_instruct.rlvr_difficulty", None) + return importlib.import_module("open_instruct.rlvr_difficulty") MODULE = _load_create_bucketed_difficulty_module() From bc4b9cc3143b056d82947f696d0a72428ed26c64 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 15:55:09 -0700 Subject: [PATCH 22/40] More cleanup --- open_instruct/rlvr_difficulty.py | 1521 ----------------- .../create_bucketed_difficulty.py | 1135 +++++++++++- tests/test_create_bucketed_difficulty.py | 496 +++--- 3 files changed, 1348 insertions(+), 1804 deletions(-) delete mode 100644 open_instruct/rlvr_difficulty.py diff --git a/open_instruct/rlvr_difficulty.py b/open_instruct/rlvr_difficulty.py deleted file mode 100644 index 3d193283e0..0000000000 --- a/open_instruct/rlvr_difficulty.py +++ /dev/null @@ -1,1521 +0,0 @@ -""" -Build a per-instance difficulty map from open-instruct rollout traces or -Hugging Face datasets with pass-rate aggregates. - -The script accepts one or more local rollout directories, metadata ``.jsonl`` -files, rollout shard ``.jsonl`` files written by ``open_instruct.rl_utils``, -or a Hugging Face dataset that already contains per-row pass counts. For each -prompt instance it: - -1. loads rollout shards written by ``save_rollouts_to_disk()``, including compact score-only shards, - or loads per-row pass counts from a Hub dataset, -2. groups attempts by source dataset identity when available, otherwise by a - deterministic fingerprint over task name, prompt tokens, and ground truth, -3. normalizes binary verifiable rewards from ``{0, C}`` back to ``{0, 1}`` - when possible, -4. fits a Beta prior across binary outcomes and estimates per-item success - rates, and -5. writes a JSONL difficulty file and schema/metadata sidecars. 
- -Examples: - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --source /tmp/qwen_math_rollouts \ - --task math \ - --output /tmp/qwen_math_difficulty - - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --source /tmp/qwen_math_rollouts/qwen_math_metadata.jsonl \ - --output /tmp/difficulty_map - - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ - --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ - --hf-split train \ - --output /tmp/dapo_math_qwen3_difficulty -""" - -from __future__ import annotations - -import argparse -import hashlib -import json -import math -from collections import defaultdict -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -import numpy as np -from datasets import Dataset, load_dataset -from scipy.optimize import minimize -from scipy.special import betaln -from scipy.stats import beta as beta_distribution - -from open_instruct import logger_utils - -logger = logger_utils.setup_logger(__name__) - - -EPS = 1e-8 -EXPERIMENT_METADATA_KEYS = ("source_root", "model_name", "experiment_id", "experiment_name") -JEFFREYS_PRIOR_ALPHA = 0.5 -JEFFREYS_PRIOR_BETA = 0.5 -DEFAULT_DIFFICULTY_BUCKETS = 5 -POSTERIOR_QUANTILE_GRID_SIZE = 512 -POSTERIOR_QUANTILE_BATCH_SIZE = 256 -DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" -DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} -PRIOR_SOURCE_FILENAME_ALIASES = {"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} -ROLLOUT_SOURCE_FORMAT_KIND = "open_instruct_rollout_traces" -HF_SOURCE_FORMAT_KIND = "hugging_face_dataset_passrate_rows" -ROLLOUT_INSTANCE_ID_DEFINITION = ( - "source_dataset::source_dataset_id when available; otherwise sha1(task_name,prompt_tokens,ground_truth)" -) -HF_INSTANCE_ID_DEFINITION = ( - "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index" -) 
-HF_SOURCE_ROW_INDEX_FIELD = "_source_row_index" -HF_OUTPUT_COLUMNS = ("difficulty",) - - -@dataclass(frozen=True) -class BetaPrior: - alpha: float - beta: float - source: str - - -@dataclass(frozen=True) -class RolloutSource: - input_arg: str - root_path: Path - metadata_path: Path - rollout_paths: tuple[Path, ...] - run_name: str - - -@dataclass(frozen=True) -class DifficultyPosteriorRow: - row: dict[str, Any] - difficulty_alpha: float - difficulty_beta: float - - -@dataclass(frozen=True) -class InputRowsBundle: - rows: list[dict[str, Any]] - malformed_records: int - source_format: dict[str, Any] - source_dataset: Dataset | None = None - - -def make_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Build a per-instance difficulty map from open-instruct rollout traces or HF pass-rate datasets.", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - source_group = parser.add_mutually_exclusive_group(required=True) - source_group.add_argument( - "--source", - nargs="+", - help="One or more local rollout dirs, *_metadata.jsonl files, or *_rollouts_*.jsonl shards.", - ) - source_group.add_argument( - "--hf-dataset", - type=str, - default=None, - help="Hugging Face dataset repo id containing per-row pass-rate aggregates.", - ) - parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.") - parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.") - parser.add_argument( - "--hf-row-id-field", - type=str, - default="extra_info.index", - help="Dot-path to the stable per-row id field inside --hf-dataset.", - ) - parser.add_argument( - "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in --hf-dataset." 
- ) - parser.add_argument( - "--hf-model-field", - type=str, - default="generator_model", - help="Dot-path to the generator model field in --hf-dataset.", - ) - parser.add_argument( - "--hf-pass-count-field", - type=str, - default="pass_count", - help="Dot-path to the integer pass-count field in --hf-dataset.", - ) - parser.add_argument( - "--hf-attempt-count-field", - type=str, - default="num_samples", - help="Dot-path to the total-attempt-count field in --hf-dataset.", - ) - parser.add_argument( - "--hf-pass-rate-field", - type=str, - default="pass_rate", - help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.", - ) - parser.add_argument( - "--task", - action="append", - default=[], - help="Optional task filter. Matches the rollout trace dataset/verifier source.", - ) - parser.add_argument( - "--output", - type=Path, - required=True, - help=( - "Output directory or path-like root. The script writes one file per task/model inside it as " - "____.jsonl plus matching .schema.json and .metadata.json sidecars." - ), - ) - parser.add_argument( - "--push-to-hub", type=str, default=None, help="Optional dataset repo id to push the validated rows to." - ) - parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.") - parser.add_argument( - "--strict", action="store_true", help="Fail if a rollout record is malformed or required files are missing." - ) - parser.add_argument( - "--allow-nonunit-scores", - action="store_true", - help="Keep rows whose rewards cannot be normalized to binary correctness. 
Difficulty will be null for them.", - ) - parser.add_argument( - "--max-instances", - type=int, - default=None, - help="Optional cap for the number of resolved instances written (useful for smoke tests).", - ) - parser.add_argument( - "--beta-prior", - choices=["empirical-bayes", "jeffreys"], - default="empirical-bayes", - help="Global Beta prior to use for smoothing binary solve rates.", - ) - parser.add_argument( - "--posterior-lower-quantile", - type=float, - default=0.1, - help="Lower posterior quantile used to define difficulty as 1 - quantile.", - ) - parser.add_argument( - "--difficulty-buckets", - type=int, - default=DEFAULT_DIFFICULTY_BUCKETS, - help=( - "Number of posterior-aware quantile buckets to assign for stratification. " - "Set to 0 to skip discrete bucket assignment." - ), - ) - return parser - - -def main(argv: list[str] | None = None) -> None: - args = make_parser().parse_args(argv) - validate_args(args) - task_filters = set(args.task) - output_root = resolve_output_root(args.output) - - input_rows = load_input_rows(args, task_filters=task_filters) - - if not input_rows.rows: - raise ValueError("No resolved per-instance rows were produced.") - - rows = sorted( - input_rows.rows, - key=lambda row: ( - stable_string(row.get("task_name")), - stable_string((row.get("experiment_metadata") or {}).get("model_name")), - stable_string(row.get("instance_id")), - ), - ) - if args.max_instances is not None: - rows = rows[: args.max_instances] - - rows_by_group = group_rows_by_task_and_model(rows) - if args.push_to_hub is not None and len(rows_by_group) != 1: - raise ValueError( - "--push-to-hub requires a single task/model output. Filter with --task or use a source with one task." 
- ) - - skipped_nonunit = 0 - written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] - - for (task_name, model_name), group_rows in sorted( - rows_by_group.items(), key=lambda item: (item[0][0], stable_string(item[0][1])) - ): - group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( - group_rows, allow_nonunit_scores=args.allow_nonunit_scores - ) - if input_rows.source_format["kind"] == HF_SOURCE_FORMAT_KIND: - score_processing["source_field"] = ",".join( - field_name - for field_name in ( - input_rows.source_format.get("pass_count_field"), - input_rows.source_format.get("attempt_count_field"), - input_rows.source_format.get("pass_rate_field"), - ) - if field_name - ) - skipped_nonunit += group_skipped_nonunit - - if not group_rows: - logger.warning( - "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name - ) - continue - - prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) - group_rows = apply_beta_binomial_difficulty( - group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets - ) - if input_rows.source_dataset is None: - ordered_group_rows = sorted(group_rows, key=lambda row: row["instance_id"]) - output_rows = strip_output_only_rollout_fields(ordered_group_rows) - dataset = Dataset.from_list(output_rows) - else: - ordered_group_rows = sort_hf_group_rows(group_rows) - output_rows = strip_internal_fields(ordered_group_rows) - dataset = build_hf_output_dataset(input_rows.source_dataset, ordered_group_rows) - - dataset_metadata = build_dataset_metadata( - rows=output_rows, - task_name=task_name, - model_name=model_name, - requested_prior_mode=args.beta_prior, - requested_bucket_count=args.difficulty_buckets, - lower_quantile=args.posterior_lower_quantile, - prior=prior, - binary_row_count=binary_row_count, - score_processing=score_processing, - 
source_format=input_rows.source_format, - ) - - if prior is not None: - logger.info( - "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", - prior.source, - prior.alpha, - prior.beta, - binary_row_count, - task_name, - model_name, - ) - else: - logger.warning( - "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", - task_name, - model_name, - ) - - annotate_dataset_metadata(dataset, dataset_metadata) - output_jsonl, schema_json, metadata_json = build_output_paths( - output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata - ) - write_output_files( - output_jsonl=output_jsonl, - schema_json=schema_json, - metadata_json=metadata_json, - dataset=dataset, - dataset_metadata=dataset_metadata, - ) - - if args.push_to_hub is not None: - dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) - - written_outputs.append((task_name, model_name, len(output_rows), output_jsonl, schema_json, metadata_json)) - logger.info( - "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", - len(output_rows), - task_name, - model_name, - output_jsonl, - schema_json, - metadata_json, - ) - - logger.info( - "Finished writing %s output file groups (%s malformed rollout records, %s skipped due to unsupported scores).", - len(written_outputs), - input_rows.malformed_records, - skipped_nonunit, - ) - - -def load_input_rows(args: argparse.Namespace, *, task_filters: set[str]) -> InputRowsBundle: - if args.hf_dataset is not None: - return load_hf_dataset_rows( - dataset_name=args.hf_dataset, - config_name=args.hf_config, - split=args.hf_split, - task_filters=task_filters, - strict=args.strict, - row_id_field=args.hf_row_id_field, - task_field=args.hf_task_field, - model_field=args.hf_model_field, - pass_count_field=args.hf_pass_count_field, - attempt_count_field=args.hf_attempt_count_field, - pass_rate_field=args.hf_pass_rate_field, - ) - - if not args.source: - 
raise ValueError("Expected --source when --hf-dataset is not provided.") - - source_runs = discover_rollout_sources(args.source) - if not source_runs: - raise ValueError("No rollout trace sources were found.") - - contributions: list[dict[str, Any]] = [] - malformed_records = 0 - - for source_run in source_runs: - logger.info( - "Loading %s (run=%s, metadata=%s, shards=%s)", - source_run.input_arg, - source_run.run_name, - source_run.metadata_path, - len(source_run.rollout_paths), - ) - run_contributions, run_malformed = build_contributions_for_source( - source_run=source_run, task_filters=task_filters, strict=args.strict - ) - contributions.extend(run_contributions) - malformed_records += run_malformed - - return InputRowsBundle( - rows=aggregate_contributions(contributions), - malformed_records=malformed_records, - source_format=build_rollout_source_format_metadata(), - ) - - -def load_hf_dataset_rows( - *, - dataset_name: str, - config_name: str | None, - split: str, - task_filters: set[str], - strict: bool, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> InputRowsBundle: - logger.info( - "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split - ) - - if config_name: - source_dataset = load_dataset(dataset_name, config_name, split=split) - else: - source_dataset = load_dataset(dataset_name, split=split) - - rows: list[dict[str, Any]] = [] - malformed_records = 0 - - for row_index, source_row in enumerate(source_dataset): - try: - row = build_hf_dataset_row( - source_row=source_row, - source_row_index=row_index, - dataset_name=dataset_name, - config_name=config_name, - split=split, - row_id_field=row_id_field, - task_field=task_field, - model_field=model_field, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ) - except Exception as exc: - 
malformed_records += 1 - message = f"Malformed HF dataset row {dataset_name}[{split}][{row_index}]: {exc}" - if strict: - raise ValueError(message) from exc - logger.warning(message) - continue - - task_name = stable_string(row.get("task_name")) - if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: - continue - rows.append(row) - - return InputRowsBundle( - rows=rows, - malformed_records=malformed_records, - source_format=build_hf_source_format_metadata( - dataset_name=dataset_name, - config_name=config_name, - split=split, - row_id_field=row_id_field, - task_field=task_field, - model_field=model_field, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ), - source_dataset=source_dataset, - ) - - -def build_hf_dataset_row( - *, - source_row: dict[str, Any], - source_row_index: int, - dataset_name: str, - config_name: str | None, - split: str, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> dict[str, Any]: - task_name = normalize_task_name(get_nested_field(source_row, task_field)) - if task_name is None: - raise ValueError(f"missing task field {task_field!r}") - - source_row_id = normalize_identifier(get_nested_field(source_row, row_id_field)) or str(source_row_index) - pass_count, attempt_count = extract_hf_attempt_summary( - row=source_row, - pass_count_field=pass_count_field, - attempt_count_field=attempt_count_field, - pass_rate_field=pass_rate_field, - ) - model_name = optional_string(get_nested_field(source_row, model_field)) - - return { - HF_SOURCE_ROW_INDEX_FIELD: source_row_index, - "instance_id": make_hf_instance_id(dataset_name=dataset_name, source_row_id=source_row_id), - "task_name": task_name, - "base_task_name": get_base_task_name(task_name), - "source_dataset": dataset_name, - "source_row_id": source_row_id, - "attempt_scores": 
expand_binary_attempt_scores(pass_count=pass_count, attempt_count=attempt_count), - "finish_reasons": [], - "experiment_metadata": { - "source_root": format_hf_source_locator(dataset_name=dataset_name, config_name=config_name, split=split), - "model_name": model_name, - "experiment_id": None, - "experiment_name": dataset_name, - }, - "score_sources": [task_name], - "warnings": [], - } - - -def build_rollout_source_format_metadata() -> dict[str, Any]: - return { - "kind": ROLLOUT_SOURCE_FORMAT_KIND, - "task_field": "dataset", - "score_field": "reward", - "source_dataset_field": "source_dataset", - "source_dataset_id_field": "source_dataset_id", - "source_row_id_field": "source_row_id", - "instance_id_definition": ROLLOUT_INSTANCE_ID_DEFINITION, - } - - -def build_hf_source_format_metadata( - *, - dataset_name: str, - config_name: str | None, - split: str, - row_id_field: str, - task_field: str, - model_field: str, - pass_count_field: str, - attempt_count_field: str, - pass_rate_field: str | None, -) -> dict[str, Any]: - return { - "kind": HF_SOURCE_FORMAT_KIND, - "dataset_repo_id": dataset_name, - "config_name": config_name, - "split": split, - "row_id_field": row_id_field, - "task_field": task_field, - "model_field": model_field, - "pass_count_field": pass_count_field, - "attempt_count_field": attempt_count_field, - "pass_rate_field": pass_rate_field, - "instance_id_definition": HF_INSTANCE_ID_DEFINITION, - } - - -def format_hf_source_locator(*, dataset_name: str, config_name: str | None, split: str) -> str: - config_token = config_name or "default" - return f"hf://{dataset_name}/{config_token}/{split}" - - -def make_hf_instance_id(*, dataset_name: str, source_row_id: str) -> str: - return f"{dataset_name}::{source_row_id}" - - -def sort_hf_group_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return sorted(rows, key=lambda row: row[HF_SOURCE_ROW_INDEX_FIELD]) - - -def build_hf_output_dataset(source_dataset: Dataset, rows: list[dict[str, Any]]) -> 
Dataset: - ordered_rows = sort_hf_group_rows(rows) - dataset = source_dataset.select([row[HF_SOURCE_ROW_INDEX_FIELD] for row in ordered_rows]) - - for column_name in HF_OUTPUT_COLUMNS: - values = [make_jsonable(row.get(column_name)) for row in ordered_rows] - if column_name in dataset.column_names: - dataset = dataset.remove_columns(column_name) - dataset = dataset.add_column(column_name, values) - - return dataset - - -def strip_internal_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{key: value for key, value in row.items() if key != HF_SOURCE_ROW_INDEX_FIELD} for row in rows] - - -def get_nested_field(value: Any, field_path: str) -> Any: - if not field_path: - return value - - current = value - for field_name in field_path.split("."): - if not isinstance(current, dict) or field_name not in current: - return None - current = current[field_name] - return current - - -def normalize_identifier(value: Any) -> str | None: - if value is None or isinstance(value, bool): - return None - text = stable_string(value).strip() - return text or None - - -def normalize_nonnegative_int(value: Any) -> int | None: - if value is None or isinstance(value, bool): - return None - if isinstance(value, int): - return value if value >= 0 else None - if isinstance(value, float): - if not math.isfinite(value) or not value.is_integer() or value < 0: - return None - return int(value) - if isinstance(value, str): - stripped = value.strip() - if not stripped: - return None - try: - parsed = int(stripped) - except ValueError: - return None - return parsed if parsed >= 0 else None - return None - - -def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: - if value is None: - return None, None, None - if is_number(value): - rate = float(value) - if 0.0 <= rate <= 1.0: - return None, None, rate - raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") - if not isinstance(value, str): - raise ValueError(f"unsupported 
pass-rate value {value!r}") - - stripped = value.strip() - if not stripped: - return None, None, None - - if "/" in stripped: - numerator_text, denominator_text = stripped.split("/", 1) - numerator = normalize_nonnegative_int(numerator_text) - denominator = normalize_nonnegative_int(denominator_text) - if numerator is None or denominator is None or numerator > denominator: - raise ValueError(f"invalid pass-rate fraction {value!r}") - rate = 0.0 if denominator == 0 else numerator / denominator - return numerator, denominator, rate - - try: - rate = float(stripped) - except ValueError as exc: - raise ValueError(f"invalid pass-rate value {value!r}") from exc - if not math.isfinite(rate) or rate < 0.0 or rate > 1.0: - raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") - return None, None, rate - - -def extract_hf_attempt_summary( - *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None -) -> tuple[int, int]: - pass_count = normalize_nonnegative_int(get_nested_field(row, pass_count_field)) - attempt_count = normalize_nonnegative_int(get_nested_field(row, attempt_count_field)) - - parsed_pass_count = None - parsed_attempt_count = None - parsed_pass_rate = None - if pass_rate_field: - parsed_pass_count, parsed_attempt_count, parsed_pass_rate = parse_pass_rate_value( - get_nested_field(row, pass_rate_field) - ) - - if pass_count is None and parsed_pass_count is not None: - pass_count = parsed_pass_count - if attempt_count is None and parsed_attempt_count is not None: - attempt_count = parsed_attempt_count - - if pass_count is None or attempt_count is None: - raise ValueError( - f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}" - f"{f' or parseable {pass_rate_field!r}' if pass_rate_field else ''}" - ) - if attempt_count <= 0: - raise ValueError(f"attempt count must be positive, received {attempt_count}") - if pass_count > attempt_count: - raise ValueError(f"pass count 
{pass_count} exceeds attempt count {attempt_count}") - - if parsed_pass_count is not None and parsed_pass_count != pass_count: - raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}") - if parsed_attempt_count is not None and parsed_attempt_count != attempt_count: - raise ValueError(f"attempt-count field {attempt_count_field!r} disagrees with {pass_rate_field!r}") - if parsed_pass_rate is not None and not is_close(pass_count / attempt_count, parsed_pass_rate): - raise ValueError( - f"pass-count fields {pass_count_field!r}/{attempt_count_field!r} disagree with {pass_rate_field!r}" - ) - - return pass_count, attempt_count - - -def expand_binary_attempt_scores(*, pass_count: int, attempt_count: int) -> list[float]: - return [1.0] * pass_count + [0.0] * (attempt_count - pass_count) - - -def discover_rollout_sources(sources: list[str]) -> list[RolloutSource]: - discovered: dict[Path, RolloutSource] = {} - - for source in sources: - source_path = Path(source) - if not source_path.exists(): - raise FileNotFoundError(f"Could not find source path {source}") - - if source_path.is_dir(): - metadata_paths = sorted(source_path.rglob("*_metadata.jsonl")) - if not metadata_paths: - raise FileNotFoundError(f"Could not find *_metadata.jsonl under {source}") - for metadata_path in metadata_paths: - rollout_source = build_rollout_source_from_metadata(metadata_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - if source_path.name.endswith("_metadata.jsonl"): - rollout_source = build_rollout_source_from_metadata(source_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - if source_path.suffix == ".jsonl" and "_rollouts_" in source_path.name: - rollout_source = build_rollout_source_from_rollout(source_path, input_arg=source) - discovered[rollout_source.metadata_path] = rollout_source - continue - - raise ValueError( - f"Unsupported source path {source}. 
Expected a directory, *_metadata.jsonl, or *_rollouts_*.jsonl." - ) - - return sorted(discovered.values(), key=lambda source_run: (str(source_run.root_path), source_run.run_name)) - - -def build_rollout_source_from_metadata(metadata_path: Path, *, input_arg: str) -> RolloutSource: - run_name = parse_run_name_from_metadata_path(metadata_path) - rollout_paths = tuple(sorted(metadata_path.parent.glob(f"{run_name}_rollouts_*.jsonl"))) - if not rollout_paths: - raise FileNotFoundError(f"Could not find rollout shards for run {run_name} next to {metadata_path}") - return RolloutSource( - input_arg=input_arg, - root_path=metadata_path.parent.absolute(), - metadata_path=metadata_path.absolute(), - rollout_paths=rollout_paths, - run_name=run_name, - ) - - -def build_rollout_source_from_rollout(rollout_path: Path, *, input_arg: str) -> RolloutSource: - run_name = parse_run_name_from_rollout_path(rollout_path) - metadata_path = rollout_path.parent / f"{run_name}_metadata.jsonl" - if not metadata_path.exists(): - raise FileNotFoundError(f"Could not find metadata file {metadata_path} for rollout shard {rollout_path}") - return build_rollout_source_from_metadata(metadata_path, input_arg=input_arg) - - -def parse_run_name_from_metadata_path(metadata_path: Path) -> str: - suffix = "_metadata.jsonl" - if not metadata_path.name.endswith(suffix): - raise ValueError(f"Metadata path must end with {suffix}: {metadata_path}") - return metadata_path.name[: -len(suffix)] - - -def parse_run_name_from_rollout_path(rollout_path: Path) -> str: - marker = "_rollouts_" - if marker not in rollout_path.name: - raise ValueError(f"Rollout shard filename must contain {marker}: {rollout_path}") - return rollout_path.name.split(marker, 1)[0] - - -def build_contributions_for_source( - *, source_run: RolloutSource, task_filters: set[str], strict: bool -) -> tuple[list[dict[str, Any]], int]: - run_metadata = read_rollout_metadata(source_run.metadata_path, fallback_run_name=source_run.run_name) - 
contributions: list[dict[str, Any]] = [] - malformed_records = 0 - - for rollout_path in source_run.rollout_paths: - for line_number, record in enumerate(read_jsonl(rollout_path), start=1): - try: - contribution = build_rollout_contribution( - record=record, source_run=source_run, run_metadata=run_metadata - ) - except Exception as exc: - malformed_records += 1 - message = f"Malformed rollout record in {rollout_path}:{line_number}: {exc}" - if strict: - raise ValueError(message) from exc - logger.warning(message) - continue - - task_name = stable_string(contribution.get("task_name")) - if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: - continue - contributions.append(contribution) - - return contributions, malformed_records - - -def read_rollout_metadata(metadata_path: Path, *, fallback_run_name: str) -> dict[str, Any]: - rows = read_jsonl(metadata_path) - if not rows: - raise ValueError(f"Metadata file is empty: {metadata_path}") - if len(rows) > 1: - logger.warning("Expected one metadata row in %s but found %s. 
Using the first row.", metadata_path, len(rows)) - - metadata = rows[0] - return { - "run_name": optional_string(metadata.get("run_name")) or fallback_run_name, - "model_name": optional_string(metadata.get("model_name")), - "experiment_id": optional_string(metadata.get("experiment_id")), - "git_commit": optional_string(metadata.get("git_commit")), - "timestamp": optional_string(metadata.get("timestamp")), - } - - -def build_rollout_contribution( - *, record: dict[str, Any], source_run: RolloutSource, run_metadata: dict[str, Any] -) -> dict[str, Any]: - task_name = normalize_task_name(record.get("dataset")) - if task_name is None: - raise ValueError("missing dataset/verifier source") - - source_dataset = normalize_source_dataset(record.get("source_dataset")) - source_dataset_id = extract_source_dataset_id(record) - - prompt_tokens = normalize_token_list(record.get("prompt_tokens")) - if prompt_tokens is None and (source_dataset is None or source_dataset_id is None): - raise ValueError("missing prompt_tokens and source dataset identity (source_dataset/source_row_id)") - - reward = extract_numeric_reward(record.get("reward")) - if reward is None: - raise ValueError("missing or invalid reward") - - ground_truth = make_jsonable(record.get("ground_truth")) - finish_reason = optional_string(record.get("finish_reason")) - - return { - "instance_id": make_rollout_instance_id( - task_name=task_name, - prompt_tokens=prompt_tokens, - ground_truth=ground_truth, - source_dataset=source_dataset, - source_dataset_id=source_dataset_id, - ), - "task_name": task_name, - "base_task_name": get_base_task_name(task_name), - "prompt_tokens": prompt_tokens, - "ground_truth": ground_truth, - "source_dataset": source_dataset, - "source_dataset_id": source_dataset_id, - "score_source": task_name, - "attempt_scores": [reward], - "finish_reasons": [finish_reason] if finish_reason else [], - "experiment_metadata": { - "source_root": str(source_run.root_path), - "model_name": 
run_metadata["model_name"], - "experiment_id": run_metadata["experiment_id"], - "experiment_name": run_metadata["run_name"], - }, - "warnings": extract_rollout_warnings(record.get("request_info")), - } - - -def normalize_task_name(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - if isinstance(value, (list, tuple)) and len(value) == 1: - return normalize_task_name(value[0]) - serialized = serialize_value(value) - return serialized or None - - -def normalize_source_dataset(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - if isinstance(value, (list, tuple)) and len(value) == 1: - return normalize_source_dataset(value[0]) - serialized = serialize_value(value) - return serialized or None - - -def extract_source_dataset_id(record: dict[str, Any]) -> int | None: - for field_name in ("source_dataset_id", "source_row_id"): - source_dataset_id = normalize_source_dataset_id(record.get(field_name)) - if source_dataset_id is not None: - return source_dataset_id - return None - - -def normalize_source_dataset_id(value: Any) -> int | None: - return normalize_nonnegative_int(value) - - -def normalize_token_list(value: Any) -> list[int] | None: - if not isinstance(value, list): - return None - - tokens: list[int] = [] - for item in value: - if isinstance(item, bool) or not isinstance(item, (int, float)): - return None - tokens.append(int(item)) - return tokens - - -def extract_numeric_reward(value: Any) -> float | None: - if not is_number(value): - return None - return float(value) - - -def extract_rollout_warnings(request_info: Any) -> list[str]: - if not isinstance(request_info, dict): - return [] - - warnings: list[str] = [] - if request_info.get("timeouts"): - warnings.append("timeout") - if optional_string(request_info.get("tool_errors")): - warnings.append("tool_error") - return warnings - - -def aggregate_contributions(contributions: list[dict[str, Any]]) -> 
list[dict[str, Any]]: - grouped: dict[str, dict[str, Any]] = {} - - for contribution in contributions: - instance_id = contribution["instance_id"] - if instance_id not in grouped: - grouped[instance_id] = { - key: value - for key, value in contribution.items() - if key not in {"attempt_scores", "finish_reasons", "experiment_metadata", "warnings", "score_source"} - } - grouped[instance_id]["attempt_scores"] = [] - grouped[instance_id]["finish_reasons"] = [] - grouped[instance_id]["experiment_metadata"] = None - grouped[instance_id]["score_sources"] = set() - grouped[instance_id]["warnings"] = set() - - row = grouped[instance_id] - row["attempt_scores"].extend(float(score) for score in contribution["attempt_scores"]) - row["finish_reasons"].extend(contribution["finish_reasons"]) - row["experiment_metadata"] = merge_experiment_metadata( - existing=row["experiment_metadata"], incoming=contribution["experiment_metadata"], instance_id=instance_id - ) - row["score_sources"].add(stable_string(contribution["score_source"])) - row["warnings"].update(contribution["warnings"]) - - rows: list[dict[str, Any]] = [] - for row in grouped.values(): - row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] - row["finish_reasons"] = [stable_string(reason) for reason in row["finish_reasons"] if stable_string(reason)] - row["experiment_metadata"] = normalize_experiment_metadata(row["experiment_metadata"]) - row["score_sources"] = sorted(value for value in row["score_sources"] if value) - row["warnings"] = sorted(value for value in row["warnings"] if value) - rows.append(row) - - return rows - - -def strip_output_only_rollout_fields(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{key: value for key, value in row.items() if key not in {"prompt_tokens", "ground_truth"}} for row in rows] - - -def normalize_attempt_scores_for_group( - rows: list[dict[str, Any]], *, allow_nonunit_scores: bool -) -> tuple[list[dict[str, Any]], dict[str, Any], int]: - 
score_processing = infer_score_processing(rows) - normalized_rows: list[dict[str, Any]] = [] - skipped_nonunit = 0 - - for row in rows: - normalized_scores = normalize_attempt_scores(row["attempt_scores"], score_processing) - if normalized_scores is None: - if allow_nonunit_scores: - kept_row = dict(row) - kept_row["attempt_scores"] = [float(score) for score in row["attempt_scores"]] - kept_row["warnings"] = sorted({*kept_row["warnings"], "nonbinary_reward_scores"}) - normalized_rows.append(kept_row) - else: - skipped_nonunit += 1 - continue - - normalized_row = dict(row) - normalized_row["attempt_scores"] = normalized_scores - normalized_rows.append(normalized_row) - - return normalized_rows, score_processing, skipped_nonunit - - -def infer_score_processing(rows: list[dict[str, Any]]) -> dict[str, Any]: - scores = [float(score) for row in rows for score in row.get("attempt_scores", [])] - score_processing = { - "source_field": "reward", - "output_field": "attempt_scores", - "normalization": "unsupported", - "positive_reward_value": None, - "supports_binary_difficulty": False, - } - - if not scores: - return score_processing - - if all(is_close(score, 0.0) or is_close(score, 1.0) for score in scores): - score_processing["normalization"] = "identity_binary" - score_processing["positive_reward_value"] = 1.0 - score_processing["supports_binary_difficulty"] = True - return score_processing - - if any(score < -EPS for score in scores): - return score_processing - - positive_scores = [score for score in scores if score > EPS] - if not positive_scores: - score_processing["normalization"] = "all_zero_binary" - score_processing["supports_binary_difficulty"] = True - return score_processing - - positive_reward_value = max(positive_scores) - if all(is_close(score, 0.0) or is_close(score, positive_reward_value) for score in scores): - score_processing["normalization"] = "binary_zero_or_constant" - score_processing["positive_reward_value"] = positive_reward_value - 
score_processing["supports_binary_difficulty"] = True - - return score_processing - - -def normalize_attempt_scores(attempt_scores: list[float], score_processing: dict[str, Any]) -> list[float] | None: - if not score_processing.get("supports_binary_difficulty"): - return None - - normalization = stable_string(score_processing.get("normalization")) - positive_reward_value = score_processing.get("positive_reward_value") - normalized_scores: list[float] = [] - - for score in attempt_scores: - if is_close(score, 0.0): - normalized_scores.append(0.0) - continue - - if normalization == "identity_binary" and is_close(score, 1.0): - normalized_scores.append(1.0) - continue - - if ( - normalization == "binary_zero_or_constant" - and positive_reward_value is not None - and is_close(score, float(positive_reward_value)) - ): - normalized_scores.append(1.0) - continue - - if normalization == "all_zero_binary": - return None - - return None - - return normalized_scores - - -def estimate_beta_prior(rows: list[dict[str, Any]], *, prior_mode: str) -> tuple[BetaPrior | None, int]: - binary_counts = [counts for row in rows if (counts := extract_binary_counts(row["attempt_scores"])) is not None] - if not binary_counts: - return None, 0 - - if prior_mode == "jeffreys": - return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys"), len(binary_counts) - - prior = fit_empirical_beta_prior(binary_counts) - if prior is not None: - return prior, len(binary_counts) - - logger.warning("Falling back to Jeffreys prior after empirical-Bayes fitting failed.") - return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys_fallback"), len(binary_counts) - - -def apply_beta_binomial_difficulty( - rows: list[dict[str, Any]], *, prior: BetaPrior | None, lower_quantile: float, num_buckets: int -) -> list[dict[str, Any]]: - posterior_rows: list[DifficultyPosteriorRow] = [] - - for row in rows: - row["difficulty"] = make_empty_difficulty_payload() - - if prior is None: - continue - - 
binary_counts = extract_binary_counts(row["attempt_scores"]) - if binary_counts is None: - continue - - success_count, attempt_count = binary_counts - posterior_alpha = success_count + prior.alpha - posterior_beta = attempt_count - success_count + prior.beta - posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta) - posterior_lower_bound = float(beta_distribution.ppf(lower_quantile, posterior_alpha, posterior_beta)) - - row["difficulty"] = { - "value": max(0.0, min(1.0, 1.0 - posterior_lower_bound)), - "posterior_mean": posterior_mean, - "posterior_lower_bound": posterior_lower_bound, - "expected_quantile": None, - "bucket_index": None, - "bucket_count": None, - } - posterior_rows.append( - DifficultyPosteriorRow(row=row, difficulty_alpha=posterior_beta, difficulty_beta=posterior_alpha) - ) - - assign_posterior_difficulty_buckets(posterior_rows, num_buckets=num_buckets) - return rows - - -def make_empty_difficulty_payload() -> dict[str, Any]: - return { - "value": None, - "posterior_mean": None, - "posterior_lower_bound": None, - "expected_quantile": None, - "bucket_index": None, - "bucket_count": None, - } - - -def assign_posterior_difficulty_buckets(posterior_rows: list[DifficultyPosteriorRow], *, num_buckets: int) -> None: - if not posterior_rows: - return - - expected_quantiles = estimate_expected_difficulty_quantiles(posterior_rows) - for posterior_row, expected_quantile in zip(posterior_rows, expected_quantiles, strict=True): - posterior_row.row["difficulty"]["expected_quantile"] = expected_quantile - - if num_buckets <= 0: - return - - effective_bucket_count = min(num_buckets, len(posterior_rows)) - ordered_rows = sorted( - zip(posterior_rows, expected_quantiles, strict=True), - key=lambda item: (item[1], item[0].row["difficulty"]["value"], stable_string(item[0].row["instance_id"])), - ) - base_bucket_size, remainder = divmod(len(ordered_rows), effective_bucket_count) - - cursor = 0 - for bucket_index in range(effective_bucket_count): - 
bucket_size = base_bucket_size + (1 if bucket_index < remainder else 0) - for posterior_row, _expected_quantile in ordered_rows[cursor : cursor + bucket_size]: - posterior_row.row["difficulty"]["bucket_index"] = bucket_index - posterior_row.row["difficulty"]["bucket_count"] = effective_bucket_count - cursor += bucket_size - - -def estimate_expected_difficulty_quantiles( - posterior_rows: list[DifficultyPosteriorRow], - *, - grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE, - batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE, -) -> list[float]: - if not posterior_rows: - return [] - if len(posterior_rows) == 1: - return [0.5] - - grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size - difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64) - difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64) - - mixture_cdf = np.zeros(grid_size, dtype=np.float64) - for start in range(0, len(posterior_rows), batch_size): - stop = start + batch_size - batch_cdf = beta_distribution.cdf( - grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] - ) - mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0) - mixture_cdf /= len(posterior_rows) - - quantiles = np.zeros(len(posterior_rows), dtype=np.float64) - dx = 1.0 / grid_size - for start in range(0, len(posterior_rows), batch_size): - stop = start + batch_size - batch_pdf = beta_distribution.pdf( - grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] - ) - quantiles[start:stop] = np.clip( - np.nan_to_num(batch_pdf, nan=0.0, posinf=0.0, neginf=0.0).dot(mixture_cdf) * dx, 0.0, 1.0 - ) - - return quantiles.tolist() - - -def fit_empirical_beta_prior(binary_counts: list[tuple[int, int]]) -> BetaPrior | None: - total_successes = sum(success_count for success_count, _ in binary_counts) - total_attempts = sum(attempt_count for _, attempt_count in 
binary_counts) - if total_attempts == 0 or total_successes in {0, total_attempts}: - return None - - mean_rate = total_successes / total_attempts - init_alpha = max(mean_rate * 2.0, 1e-3) - init_beta = max((1.0 - mean_rate) * 2.0, 1e-3) - - def objective(log_params: tuple[float, float]) -> float: - alpha = math.exp(log_params[0]) - beta = math.exp(log_params[1]) - return -sum( - betaln(success_count + alpha, attempt_count - success_count + beta) - betaln(alpha, beta) - for success_count, attempt_count in binary_counts - ) - - result = minimize( - objective, - x0=(math.log(init_alpha), math.log(init_beta)), - method="L-BFGS-B", - bounds=[(-10.0, 10.0), (-10.0, 10.0)], - ) - if not result.success: - logger.warning("Empirical-Bayes fit failed: %s", result.message) - return None - - return BetaPrior(alpha=math.exp(result.x[0]), beta=math.exp(result.x[1]), source="empirical_bayes") - - -def merge_experiment_metadata( - existing: dict[str, Any] | None, incoming: dict[str, Any], *, instance_id: str -) -> dict[str, Any]: - normalized_incoming = normalize_experiment_metadata(incoming) - if existing is None: - return normalized_incoming - - merged = dict(existing) - for key in EXPERIMENT_METADATA_KEYS: - existing_value = merged.get(key) - incoming_value = normalized_incoming.get(key) - if existing_value in {None, ""}: - merged[key] = incoming_value - elif incoming_value in {None, ""} or incoming_value == existing_value: - continue - else: - raise ValueError( - f"Conflicting experiment metadata for instance {instance_id}: " - f"{key}={existing_value!r} vs {incoming_value!r}" - ) - return merged - - -def normalize_experiment_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]: - if metadata is None: - return {key: None for key in EXPERIMENT_METADATA_KEYS} - return {key: metadata.get(key) for key in EXPERIMENT_METADATA_KEYS} - - -def resolve_output_root(output: Path) -> Path: - output_str = str(output) - if output_str.endswith(".schema.json"): - return Path(output_str[: 
-len(".schema.json")]) - if output_str.endswith(".jsonl"): - return Path(output_str[: -len(".jsonl")]) - if output_str.endswith(".json"): - return Path(output_str[: -len(".json")]) - return output - - -def build_output_paths( - output_root: Path, *, task_name: str, model_name: str | None, dataset_metadata: dict[str, Any] -) -> tuple[Path, Path, Path]: - task_suffix = sanitize_name(task_name) or "unknown-task" - model_suffix = sanitize_name(model_name or "") or "unknown-model" - difficulty_suffix = build_difficulty_filename_suffix(dataset_metadata) - stem = output_root / f"{task_suffix}__{model_suffix}{difficulty_suffix}" - return Path(f"{stem}.jsonl"), Path(f"{stem}.schema.json"), Path(f"{stem}.metadata.json") - - -def write_output_files( - *, output_jsonl: Path, schema_json: Path, metadata_json: Path, dataset: Dataset, dataset_metadata: dict[str, Any] -) -> None: - output_jsonl.parent.mkdir(parents=True, exist_ok=True) - with output_jsonl.open("w") as output_file: - for row in dataset: - output_file.write(json.dumps(make_jsonable(row), ensure_ascii=False) + "\n") - - schema_json.parent.mkdir(parents=True, exist_ok=True) - try: - schema_payload: Any = dataset.features.to_dict() - except AttributeError: - schema_payload = str(dataset.features) - with schema_json.open("w") as output_file: - json.dump(schema_payload, output_file, indent=2, sort_keys=True) - - metadata_json.parent.mkdir(parents=True, exist_ok=True) - with metadata_json.open("w") as output_file: - json.dump(dataset_metadata, output_file, indent=2, sort_keys=True) - - -def build_dataset_metadata( - *, - rows: list[dict[str, Any]], - task_name: str, - model_name: str | None, - requested_prior_mode: str, - requested_bucket_count: int, - lower_quantile: float, - prior: BetaPrior | None, - binary_row_count: int, - score_processing: dict[str, Any], - source_format: dict[str, Any], -) -> dict[str, Any]: - effective_bucket_count = extract_effective_bucket_count(rows) - difficulty_generation = { - "method": 
DIFFICULTY_GENERATION_METHOD, - "difficulty_value_field": "difficulty.value", - "difficulty_value_definition": "1 - difficulty.posterior_lower_bound", - "bucket_field": "difficulty.bucket_index", - "bucket_count_field": "difficulty.bucket_count", - "bucket_ranking_field": "difficulty.expected_quantile", - "posterior_lower_quantile": lower_quantile, - "bucket_count_requested": requested_bucket_count, - "bucket_count_effective": effective_bucket_count, - "beta_prior_requested": requested_prior_mode, - "beta_prior_used": { - "source": prior.source if prior is not None else None, - "alpha": prior.alpha if prior is not None else None, - "beta": prior.beta if prior is not None else None, - }, - "binary_instance_count": binary_row_count, - "nonbinary_instance_count": max(0, len(rows) - binary_row_count), - } - difficulty_generation["tag"] = build_difficulty_config_tag(difficulty_generation) - return { - "task_name": task_name, - "model_name": model_name, - "row_count": len(rows), - "source_format": dict(source_format), - "score_processing": dict(score_processing), - "difficulty_generation": difficulty_generation, - } - - -def extract_effective_bucket_count(rows: list[dict[str, Any]]) -> int: - effective_bucket_counts = { - difficulty.get("bucket_count") - for row in rows - if isinstance((difficulty := row.get("difficulty")), dict) and difficulty.get("bucket_count") is not None - } - if not effective_bucket_counts: - return 0 - if len(effective_bucket_counts) != 1: - raise ValueError(f"Expected a single effective bucket count, found {sorted(effective_bucket_counts)}") - return next(iter(effective_bucket_counts)) - - -def build_difficulty_filename_suffix(dataset_metadata: dict[str, Any]) -> str: - return f"__{dataset_metadata['difficulty_generation']['tag']}" - - -def build_difficulty_config_tag(difficulty_generation: dict[str, Any]) -> str: - method_token = abbreviate_filename_token( - optional_string(difficulty_generation.get("method")), - 
aliases=DIFFICULTY_METHOD_FILENAME_ALIASES, - default="diff", - ) - prior_source = optional_string((difficulty_generation.get("beta_prior_used") or {}).get("source")) - prior_token = abbreviate_filename_token(prior_source, aliases=PRIOR_SOURCE_FILENAME_ALIASES, default="none") - quantile_token = format_quantile_token(difficulty_generation["posterior_lower_quantile"]) - bucket_token = format_bucket_token( - requested_count=difficulty_generation["bucket_count_requested"], - effective_count=difficulty_generation["bucket_count_effective"], - ) - return "-".join([method_token, prior_token, quantile_token, bucket_token]) - - -def abbreviate_filename_token(value: str | None, *, aliases: dict[str, str], default: str) -> str: - if not value: - return default - return aliases.get(value, sanitize_name(value)) - - -def format_quantile_token(value: float) -> str: - return f"q{format_filename_number(value * 100.0)}" - - -def format_bucket_token(*, requested_count: int, effective_count: int) -> str: - if requested_count == effective_count: - return f"k{requested_count}" - return f"k{requested_count}e{effective_count}" - - -def annotate_dataset_metadata(dataset: Dataset, dataset_metadata: dict[str, Any]) -> None: - if not hasattr(dataset, "info") or dataset.info is None: - return - dataset.info.description = json.dumps(dataset_metadata, indent=2, sort_keys=True) - - -def validate_args(args: argparse.Namespace) -> None: - if not 0.0 < args.posterior_lower_quantile < 1.0: - raise ValueError("--posterior-lower-quantile must be between 0 and 1.") - if args.difficulty_buckets < 0: - raise ValueError("--difficulty-buckets must be non-negative.") - if args.max_instances is not None and args.max_instances <= 0: - raise ValueError("--max-instances must be positive when provided.") - - -def group_rows_by_task_and_model(rows: list[dict[str, Any]]) -> dict[tuple[str, str | None], list[dict[str, Any]]]: - rows_by_group: dict[tuple[str, str | None], list[dict[str, Any]]] = defaultdict(list) - 
for row in rows: - experiment_metadata = row.get("experiment_metadata") or {} - task_name = stable_string(row.get("task_name")) - model_name = optional_string(experiment_metadata.get("model_name")) - rows_by_group[(task_name, model_name)].append(row) - return dict(rows_by_group) - - -def read_jsonl(path: Path) -> list[dict[str, Any]]: - with path.open() as input_file: - return [json.loads(line) for line in input_file if line.strip()] - - -def get_base_task_name(task_name: str) -> str: - return task_name.split("@", 1)[0].split(":", 1)[0] - - -def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None: - if not attempt_scores: - return None - - success_count = 0 - for score in attempt_scores: - if is_close(score, 0.0): - continue - if is_close(score, 1.0): - success_count += 1 - continue - return None - - return success_count, len(attempt_scores) - - -def make_rollout_instance_id( - *, - task_name: str, - prompt_tokens: list[int] | None, - ground_truth: Any, - source_dataset: str | None = None, - source_dataset_id: int | None = None, -) -> str: - if source_dataset is not None and source_dataset_id is not None: - return f"{source_dataset}::{source_dataset_id}" - - if prompt_tokens is None: - raise ValueError("prompt_tokens are required when source row identity is unavailable") - - fingerprint = {"task_name": task_name, "prompt_tokens": prompt_tokens, "ground_truth": make_jsonable(ground_truth)} - digest = hashlib.sha1(canonical_json(fingerprint).encode("utf-8")).hexdigest()[:20] - task_prefix = sanitize_name(task_name) or "unknown" - return f"{task_prefix}::{digest}" - - -def canonical_json(value: Any) -> str: - return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True, separators=(",", ":")) - - -def make_jsonable(value: Any) -> Any: - if value is None or isinstance(value, (str, int, float, bool)): - return value - if isinstance(value, list): - return [make_jsonable(item) for item in value] - if isinstance(value, tuple): - 
return [make_jsonable(item) for item in value] - if isinstance(value, dict): - return {stable_string(key): make_jsonable(item) for key, item in value.items()} - return stable_string(value) - - -def stable_string(value: Any) -> str: - if value is None: - return "" - if isinstance(value, str): - return value - return str(value) - - -def optional_string(value: Any) -> str | None: - text = stable_string(value) - return text or None - - -def serialize_value(value: Any) -> str | None: - if value is None: - return None - if isinstance(value, str): - return value - return json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True) - - -def format_filename_number(value: float) -> str: - text = f"{value:.8g}" - return text.replace("-", "m").replace(".", "p") - - -def sanitize_name(value: str) -> str: - return value.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") - - -def is_number(value: Any) -> bool: - return isinstance(value, (int, float)) and not isinstance(value, bool) and not math.isnan(float(value)) - - -def is_close(lhs: float, rhs: float) -> bool: - tolerance = EPS * max(1.0, abs(lhs), abs(rhs)) - return abs(lhs - rhs) <= tolerance - - -if __name__ == "__main__": - main() diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py index 12ebe6a906..8140c713c3 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_bucketed_difficulty.py @@ -8,19 +8,1140 @@ # ] # /// -"""Thin CLI wrapper for ``open_instruct.rlvr_difficulty``.""" +""" +Build a per-instance difficulty map from Hugging Face datasets with pass-rate +aggregates. 
+ +The script loads a Hugging Face dataset that already contains per-row pass +counts, expands those counts into binary attempt outcomes, fits a Beta prior +across binary outcomes, estimates per-item difficulty, and writes JSONL +difficulty files plus schema/metadata sidecars. When `--push-to-hub` is set, it +also uploads the validated output dataset to the requested Hugging Face repo. +Hub uploads require exactly one task/model output group, so use `--task` or a +single-group input dataset when pushing. + +Examples: + Write local difficulty files: + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --output /tmp/dapo_math_qwen3_difficulty + + Write local files and push the single output group to the Hub: + uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --task math \ + --output /tmp/dapo_math_qwen3_difficulty \ + --push-to-hub your-org/dapo-math-qwen3-difficulty \ + --split train +""" from __future__ import annotations -import sys +import argparse +import json +import logging +import math +from collections import defaultdict +from dataclasses import asdict, dataclass, field, is_dataclass from pathlib import Path +from typing import Any + +import numpy as np +from datasets import Dataset, load_dataset +from scipy.optimize import minimize +from scipy.special import betaln +from scipy.stats import beta as beta_distribution + +logger = logging.getLogger(__name__) + + +EPS = 1e-8 +JEFFREYS_PRIOR_ALPHA = 0.5 +JEFFREYS_PRIOR_BETA = 0.5 +DEFAULT_DIFFICULTY_BUCKETS = 5 +POSTERIOR_QUANTILE_GRID_SIZE = 512 +POSTERIOR_QUANTILE_BATCH_SIZE = 256 +DIFFICULTY_GENERATION_METHOD = "beta_binomial_posterior_quantiles" +DIFFICULTY_METHOD_FILENAME_ALIASES = {DIFFICULTY_GENERATION_METHOD: "bbq"} +PRIOR_SOURCE_FILENAME_ALIASES = 
{"empirical_bayes": "eb", "jeffreys": "j", "jeffreys_fallback": "jf"} +HF_SOURCE_FORMAT_KIND = "hugging_face_dataset_passrate_rows" +HF_INSTANCE_ID_DEFINITION = ( + "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index" +) +HF_OUTPUT_COLUMNS = ("difficulty",) + + +@dataclass(frozen=True) +class BetaPrior: + alpha: float + beta: float + source: str + + +@dataclass +class ExperimentMetadata: + source_root: str + model_name: str | None + experiment_id: str | None + experiment_name: str + + +@dataclass +class DifficultyPayload: + value: float | None = None + posterior_mean: float | None = None + posterior_lower_bound: float | None = None + expected_quantile: float | None = None + bucket_index: int | None = None + bucket_count: int | None = None + + +@dataclass +class DifficultyRow: + source_row_index: int + instance_id: str + task_name: str + base_task_name: str + source_dataset: str + source_row_id: str + attempt_scores: list[float] + finish_reasons: list[str] + experiment_metadata: ExperimentMetadata + score_sources: list[str] + warnings: list[str] + difficulty: DifficultyPayload = field(default_factory=DifficultyPayload) + + +@dataclass(frozen=True) +class SourceFormatMetadata: + kind: str + dataset_repo_id: str + config_name: str | None + split: str + row_id_field: str + task_field: str + model_field: str + pass_count_field: str + attempt_count_field: str + pass_rate_field: str | None + instance_id_definition: str + + +@dataclass +class ScoreProcessingMetadata: + source_field: str = "reward" + output_field: str = "attempt_scores" + normalization: str = "unsupported" + positive_reward_value: float | None = None + supports_binary_difficulty: bool = False + + +@dataclass(frozen=True) +class BetaPriorUsageMetadata: + source: str | None + alpha: float | None + beta: float | None + + +@dataclass +class DifficultyGenerationMetadata: + method: str + difficulty_value_field: str + difficulty_value_definition: str + 
bucket_field: str + bucket_count_field: str + bucket_ranking_field: str + posterior_lower_quantile: float + bucket_count_requested: int + bucket_count_effective: int + beta_prior_requested: str + beta_prior_used: BetaPriorUsageMetadata + binary_instance_count: int + nonbinary_instance_count: int + tag: str | None = None + + +@dataclass +class DatasetMetadata: + task_name: str + model_name: str | None + row_count: int + source_format: SourceFormatMetadata + score_processing: ScoreProcessingMetadata + difficulty_generation: DifficultyGenerationMetadata + + +@dataclass(frozen=True) +class DifficultyPosteriorRow: + row: DifficultyRow + difficulty_alpha: float + difficulty_beta: float + + +@dataclass(frozen=True) +class InputRowsBundle: + rows: list[DifficultyRow] + malformed_records: int + source_format: SourceFormatMetadata + source_dataset: Dataset + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Build a per-instance difficulty map from HF pass-rate datasets.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--hf-dataset", + type=str, + required=True, + help="Hugging Face dataset repo id containing per-row pass-rate aggregates.", + ) + parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.") + parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.") + parser.add_argument( + "--hf-row-id-field", + type=str, + default="extra_info.index", + help="Dot-path to the stable per-row id field inside --hf-dataset.", + ) + parser.add_argument( + "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in --hf-dataset." 
+ ) + parser.add_argument( + "--hf-model-field", + type=str, + default="generator_model", + help="Dot-path to the generator model field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-count-field", + type=str, + default="pass_count", + help="Dot-path to the integer pass-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-attempt-count-field", + type=str, + default="num_samples", + help="Dot-path to the total-attempt-count field in --hf-dataset.", + ) + parser.add_argument( + "--hf-pass-rate-field", + type=str, + default="pass_rate", + help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.", + ) + parser.add_argument( + "--task", action="append", default=[], help="Optional task filter. Matches the dataset task/verifier source." + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help=( + "Output directory or path-like root. The script writes one file per task/model inside it as " + "____.jsonl plus matching .schema.json and .metadata.json sidecars." + ), + ) + parser.add_argument( + "--push-to-hub", + type=str, + default=None, + help=("Optional dataset repo id to push the validated rows to. Requires exactly one task/model output group."), + ) + parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.") + parser.add_argument("--strict", action="store_true", help="Fail if an input dataset row is malformed.") + parser.add_argument( + "--allow-nonunit-scores", + action="store_true", + help="Keep rows whose rewards cannot be normalized to binary correctness. 
Difficulty will be null for them.", + ) + parser.add_argument( + "--max-instances", + type=int, + default=None, + help="Optional cap for the number of resolved instances written (useful for smoke tests).", + ) + parser.add_argument( + "--beta-prior", + choices=["empirical-bayes", "jeffreys"], + default="empirical-bayes", + help="Global Beta prior to use for smoothing binary solve rates.", + ) + parser.add_argument( + "--posterior-lower-quantile", + type=float, + default=0.1, + help="Lower posterior quantile used to define difficulty as 1 - quantile.", + ) + parser.add_argument( + "--difficulty-buckets", + type=int, + default=DEFAULT_DIFFICULTY_BUCKETS, + help=( + "Number of posterior-aware quantile buckets to assign for stratification. " + "Set to 0 to skip discrete bucket assignment." + ), + ) + return parser + + +def main(argv: list[str] | None = None) -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") + args = make_parser().parse_args(argv) + validate_args(args) + task_filters = set(args.task) + output_root = args.output + output_root_str = str(output_root) + for suffix in (".schema.json", ".jsonl", ".json"): + if output_root_str.endswith(suffix): + output_root = Path(output_root_str[: -len(suffix)]) + break + + input_rows = load_hf_dataset_rows( + dataset_name=args.hf_dataset, + config_name=args.hf_config, + split=args.hf_split, + task_filters=task_filters, + strict=args.strict, + row_id_field=args.hf_row_id_field, + task_field=args.hf_task_field, + model_field=args.hf_model_field, + pass_count_field=args.hf_pass_count_field, + attempt_count_field=args.hf_attempt_count_field, + pass_rate_field=args.hf_pass_rate_field, + ) + + if not input_rows.rows: + raise ValueError("No resolved per-instance rows were produced.") + + rows = sorted( + input_rows.rows, key=lambda row: (row.task_name, row.experiment_metadata.model_name or "", row.instance_id) + ) + if args.max_instances is not None: + rows = rows[: args.max_instances] + + 
rows_by_group = group_rows_by_task_and_model(rows) + if args.push_to_hub is not None and len(rows_by_group) != 1: + raise ValueError( + "--push-to-hub requires a single task/model output. Filter with --task or use a dataset with one task." + ) + + score_source_field = ",".join( + field_name + for field_name in ( + input_rows.source_format.pass_count_field, + input_rows.source_format.attempt_count_field, + input_rows.source_format.pass_rate_field, + ) + if field_name + ) + skipped_nonunit = 0 + written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] + + for (task_name, model_name), group_rows in sorted( + rows_by_group.items(), key=lambda item: (item[0][0], item[0][1] or "") + ): + group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( + group_rows, allow_nonunit_scores=args.allow_nonunit_scores + ) + score_processing.source_field = score_source_field + skipped_nonunit += group_skipped_nonunit + + if not group_rows: + logger.warning( + "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name + ) + continue + + prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) + group_rows = apply_beta_binomial_difficulty( + group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets + ) + output_rows, dataset = build_hf_output_dataset(input_rows.source_dataset, group_rows) + + dataset_metadata = build_dataset_metadata( + rows=group_rows, + task_name=task_name, + model_name=model_name, + requested_prior_mode=args.beta_prior, + requested_bucket_count=args.difficulty_buckets, + lower_quantile=args.posterior_lower_quantile, + prior=prior, + binary_row_count=binary_row_count, + score_processing=score_processing, + source_format=input_rows.source_format, + ) + + if prior is not None: + logger.info( + "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", + prior.source, + 
prior.alpha, + prior.beta, + binary_row_count, + task_name, + model_name, + ) + else: + logger.warning( + "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", + task_name, + model_name, + ) + + annotate_dataset_metadata(dataset, dataset_metadata) + output_jsonl, schema_json, metadata_json = build_output_paths( + output_root, task_name=task_name, model_name=model_name, dataset_metadata=dataset_metadata + ) + write_output_files( + output_jsonl=output_jsonl, + schema_json=schema_json, + metadata_json=metadata_json, + dataset=dataset, + dataset_metadata=dataset_metadata, + ) + + if args.push_to_hub is not None: + dataset.push_to_hub(args.push_to_hub, split=args.split, private=True) + + written_outputs.append((task_name, model_name, len(output_rows), output_jsonl, schema_json, metadata_json)) + logger.info( + "Wrote %s rows for task=%s model=%s to %s, %s, and %s.", + len(group_rows), + task_name, + model_name, + output_jsonl, + schema_json, + metadata_json, + ) + + logger.info( + "Finished writing %s output file groups (%s malformed dataset rows, %s skipped due to unsupported scores).", + len(written_outputs), + input_rows.malformed_records, + skipped_nonunit, + ) + + +def load_hf_dataset_rows( + *, + dataset_name: str, + config_name: str | None, + split: str, + task_filters: set[str], + strict: bool, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> InputRowsBundle: + logger.info( + "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split + ) + + if config_name: + source_dataset = load_dataset(dataset_name, config_name, split=split) + else: + source_dataset = load_dataset(dataset_name, split=split) + + rows: list[DifficultyRow] = [] + malformed_records = 0 + + for row_index, source_row in enumerate(source_dataset): + try: + row = build_hf_dataset_row( + source_row=source_row, + 
source_row_index=row_index, + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + except Exception as exc: + malformed_records += 1 + message = f"Malformed HF dataset row {dataset_name}[{split}][{row_index}]: {exc}" + if strict: + raise ValueError(message) from exc + logger.warning(message) + continue + + task_name = row.task_name + if task_filters and task_name not in task_filters and get_base_task_name(task_name) not in task_filters: + continue + rows.append(row) + + return InputRowsBundle( + rows=rows, + malformed_records=malformed_records, + source_format=build_hf_source_format_metadata( + dataset_name=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ), + source_dataset=source_dataset, + ) + + +def build_hf_dataset_row( + *, + source_row: dict[str, Any], + source_row_index: int, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> DifficultyRow: + task_name = normalize_task_name(get_nested_field(source_row, task_field)) + if task_name is None: + raise ValueError(f"missing task field {task_field!r}") + + source_row_id = normalize_identifier(get_nested_field(source_row, row_id_field)) or str(source_row_index) + pass_count, attempt_count = extract_hf_attempt_summary( + row=source_row, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + ) + raw_model_name = get_nested_field(source_row, model_field) + model_name = None if 
raw_model_name is None else str(raw_model_name) or None + + return DifficultyRow( + source_row_index=source_row_index, + instance_id=f"{dataset_name}::{source_row_id}", + task_name=task_name, + base_task_name=get_base_task_name(task_name), + source_dataset=dataset_name, + source_row_id=source_row_id, + attempt_scores=expand_binary_attempt_scores(pass_count=pass_count, attempt_count=attempt_count), + finish_reasons=[], + experiment_metadata=ExperimentMetadata( + source_root=f"hf://{dataset_name}/{config_name or 'default'}/{split}", + model_name=model_name, + experiment_id=None, + experiment_name=dataset_name, + ), + score_sources=[task_name], + warnings=[], + ) + + +def build_hf_source_format_metadata( + *, + dataset_name: str, + config_name: str | None, + split: str, + row_id_field: str, + task_field: str, + model_field: str, + pass_count_field: str, + attempt_count_field: str, + pass_rate_field: str | None, +) -> SourceFormatMetadata: + return SourceFormatMetadata( + kind=HF_SOURCE_FORMAT_KIND, + dataset_repo_id=dataset_name, + config_name=config_name, + split=split, + row_id_field=row_id_field, + task_field=task_field, + model_field=model_field, + pass_count_field=pass_count_field, + attempt_count_field=attempt_count_field, + pass_rate_field=pass_rate_field, + instance_id_definition=HF_INSTANCE_ID_DEFINITION, + ) + + +def build_hf_output_dataset( + source_dataset: Dataset, rows: list[DifficultyRow] +) -> tuple[list[dict[str, Any]], Dataset]: + ordered_rows = sorted(rows, key=lambda row: row.source_row_index) + output_rows = [] + for row in ordered_rows: + output_row = asdict(row) + output_row.pop("source_row_index") + output_rows.append(output_row) + + dataset = source_dataset.select([row.source_row_index for row in ordered_rows]) + + for column_name in HF_OUTPUT_COLUMNS: + values = [make_jsonable(getattr(row, column_name)) for row in ordered_rows] + if column_name in dataset.column_names: + dataset = dataset.remove_columns(column_name) + dataset = 
dataset.add_column(column_name, values) + + return output_rows, dataset + + +def get_nested_field(value: Any, field_path: str) -> Any: + if not field_path: + return value + + current = value + for field_name in field_path.split("."): + if not isinstance(current, dict) or field_name not in current: + return None + current = current[field_name] + return current + + +def normalize_identifier(value: Any) -> str | None: + if value is None or isinstance(value, bool): + return None + text = (value if isinstance(value, str) else str(value)).strip() + return text or None + + +def normalize_nonnegative_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value if value >= 0 else None + if isinstance(value, float): + if not math.isfinite(value) or not value.is_integer() or value < 0: + return None + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + parsed = int(stripped) + except ValueError: + return None + return parsed if parsed >= 0 else None + return None + + +def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: + if value is None: + return None, None, None + if isinstance(value, (int, float)) and not isinstance(value, bool): + rate = float(value) + if not math.isfinite(rate): + raise ValueError(f"expected finite pass-rate value in [0, 1], received {value!r}") + if 0.0 <= rate <= 1.0: + return None, None, rate + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + if not isinstance(value, str): + raise ValueError(f"unsupported pass-rate value {value!r}") + + stripped = value.strip() + if not stripped: + return None, None, None + + if "/" in stripped: + numerator_text, denominator_text = stripped.split("/", 1) + numerator = normalize_nonnegative_int(numerator_text) + denominator = normalize_nonnegative_int(denominator_text) + if numerator is None or denominator is None or 
numerator > denominator: + raise ValueError(f"invalid pass-rate fraction {value!r}") + rate = 0.0 if denominator == 0 else numerator / denominator + return numerator, denominator, rate + + try: + rate = float(stripped) + except ValueError as exc: + raise ValueError(f"invalid pass-rate value {value!r}") from exc + if not math.isfinite(rate) or rate < 0.0 or rate > 1.0: + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + return None, None, rate + + +def extract_hf_attempt_summary( + *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None +) -> tuple[int, int]: + pass_count = normalize_nonnegative_int(get_nested_field(row, pass_count_field)) + attempt_count = normalize_nonnegative_int(get_nested_field(row, attempt_count_field)) + + parsed_pass_count = None + parsed_attempt_count = None + parsed_pass_rate = None + if pass_rate_field: + parsed_pass_count, parsed_attempt_count, parsed_pass_rate = parse_pass_rate_value( + get_nested_field(row, pass_rate_field) + ) + + if pass_count is None and parsed_pass_count is not None: + pass_count = parsed_pass_count + if attempt_count is None and parsed_attempt_count is not None: + attempt_count = parsed_attempt_count + + if pass_count is None or attempt_count is None: + raise ValueError( + f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}" + f"{f' or parseable {pass_rate_field!r}' if pass_rate_field else ''}" + ) + if attempt_count <= 0: + raise ValueError(f"attempt count must be positive, received {attempt_count}") + if pass_count > attempt_count: + raise ValueError(f"pass count {pass_count} exceeds attempt count {attempt_count}") + + if parsed_pass_count is not None and parsed_pass_count != pass_count: + raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_attempt_count is not None and parsed_attempt_count != attempt_count: + raise ValueError(f"attempt-count field 
{attempt_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_pass_rate is not None and not math.isclose( + pass_count / attempt_count, parsed_pass_rate, rel_tol=EPS, abs_tol=EPS + ): + raise ValueError( + f"pass-count fields {pass_count_field!r}/{attempt_count_field!r} disagree with {pass_rate_field!r}" + ) + + return pass_count, attempt_count + + +def expand_binary_attempt_scores(*, pass_count: int, attempt_count: int) -> list[float]: + return [1.0] * pass_count + [0.0] * (attempt_count - pass_count) + + +def normalize_task_name(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)) and len(value) == 1: + return normalize_task_name(value[0]) + serialized = json.dumps(make_jsonable(value), ensure_ascii=False, sort_keys=True) + return serialized or None + + +def normalize_attempt_scores_for_group( + rows: list[DifficultyRow], *, allow_nonunit_scores: bool +) -> tuple[list[DifficultyRow], ScoreProcessingMetadata, int]: + score_processing = infer_score_processing(rows) + normalized_rows: list[DifficultyRow] = [] + skipped_nonunit = 0 + + for row in rows: + normalized_scores = normalize_attempt_scores(row.attempt_scores, score_processing) + if normalized_scores is None: + if allow_nonunit_scores: + row.attempt_scores = [float(score) for score in row.attempt_scores] + row.warnings = sorted({*row.warnings, "nonbinary_reward_scores"}) + normalized_rows.append(row) + else: + skipped_nonunit += 1 + continue + + row.attempt_scores = normalized_scores + normalized_rows.append(row) + + return normalized_rows, score_processing, skipped_nonunit + + +def infer_score_processing(rows: list[DifficultyRow]) -> ScoreProcessingMetadata: + scores = [float(score) for row in rows for score in row.attempt_scores] + score_processing = ScoreProcessingMetadata() + + if not scores: + return score_processing + + if all( + math.isclose(score, 0.0, rel_tol=EPS, abs_tol=EPS) or math.isclose(score, 
1.0, rel_tol=EPS, abs_tol=EPS) + for score in scores + ): + score_processing.normalization = "identity_binary" + score_processing.positive_reward_value = 1.0 + score_processing.supports_binary_difficulty = True + return score_processing + + if any(score < -EPS for score in scores): + return score_processing + + positive_scores = [score for score in scores if score > EPS] + if not positive_scores: + score_processing.normalization = "all_zero_binary" + score_processing.supports_binary_difficulty = True + return score_processing + + positive_reward_value = max(positive_scores) + if all( + math.isclose(score, 0.0, rel_tol=EPS, abs_tol=EPS) + or math.isclose(score, positive_reward_value, rel_tol=EPS, abs_tol=EPS) + for score in scores + ): + score_processing.normalization = "binary_zero_or_constant" + score_processing.positive_reward_value = positive_reward_value + score_processing.supports_binary_difficulty = True + + return score_processing + + +def normalize_attempt_scores( + attempt_scores: list[float], score_processing: ScoreProcessingMetadata +) -> list[float] | None: + if not score_processing.supports_binary_difficulty: + return None + + normalization = score_processing.normalization + positive_reward_value = score_processing.positive_reward_value + normalized_scores: list[float] = [] + + for score in attempt_scores: + if math.isclose(score, 0.0, rel_tol=EPS, abs_tol=EPS): + normalized_scores.append(0.0) + continue + + if normalization == "identity_binary" and math.isclose(score, 1.0, rel_tol=EPS, abs_tol=EPS): + normalized_scores.append(1.0) + continue + + if ( + normalization == "binary_zero_or_constant" + and positive_reward_value is not None + and math.isclose(score, float(positive_reward_value), rel_tol=EPS, abs_tol=EPS) + ): + normalized_scores.append(1.0) + continue + + if normalization == "all_zero_binary": + return None + + return None + + return normalized_scores + + +def estimate_beta_prior(rows: list[DifficultyRow], *, prior_mode: str) -> 
tuple[BetaPrior | None, int]: + binary_counts = [counts for row in rows if (counts := extract_binary_counts(row.attempt_scores)) is not None] + if not binary_counts: + return None, 0 + + if prior_mode == "jeffreys": + return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys"), len(binary_counts) + + prior = fit_empirical_beta_prior(binary_counts) + if prior is not None: + return prior, len(binary_counts) + + logger.warning("Falling back to Jeffreys prior after empirical-Bayes fitting failed.") + return BetaPrior(JEFFREYS_PRIOR_ALPHA, JEFFREYS_PRIOR_BETA, "jeffreys_fallback"), len(binary_counts) + + +def apply_beta_binomial_difficulty( + rows: list[DifficultyRow], *, prior: BetaPrior | None, lower_quantile: float, num_buckets: int +) -> list[DifficultyRow]: + posterior_rows: list[DifficultyPosteriorRow] = [] + + for row in rows: + row.difficulty = DifficultyPayload() + + if prior is None: + continue + + binary_counts = extract_binary_counts(row.attempt_scores) + if binary_counts is None: + continue + + success_count, attempt_count = binary_counts + posterior_alpha = success_count + prior.alpha + posterior_beta = attempt_count - success_count + prior.beta + posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta) + posterior_lower_bound = float(beta_distribution.ppf(lower_quantile, posterior_alpha, posterior_beta)) + + row.difficulty = DifficultyPayload( + value=max(0.0, min(1.0, 1.0 - posterior_lower_bound)), + posterior_mean=posterior_mean, + posterior_lower_bound=posterior_lower_bound, + ) + posterior_rows.append( + DifficultyPosteriorRow(row=row, difficulty_alpha=posterior_beta, difficulty_beta=posterior_alpha) + ) + + assign_posterior_difficulty_buckets(posterior_rows, num_buckets=num_buckets) + return rows + + +def assign_posterior_difficulty_buckets(posterior_rows: list[DifficultyPosteriorRow], *, num_buckets: int) -> None: + if not posterior_rows: + return + + expected_quantiles = 
estimate_expected_difficulty_quantiles(posterior_rows) + for posterior_row, expected_quantile in zip(posterior_rows, expected_quantiles, strict=True): + posterior_row.row.difficulty.expected_quantile = expected_quantile + + if num_buckets <= 0: + return + + effective_bucket_count = min(num_buckets, len(posterior_rows)) + ordered_rows = sorted( + zip(posterior_rows, expected_quantiles, strict=True), + key=lambda item: (item[1], item[0].row.difficulty.value, item[0].row.instance_id), + ) + base_bucket_size, remainder = divmod(len(ordered_rows), effective_bucket_count) + + cursor = 0 + for bucket_index in range(effective_bucket_count): + bucket_size = base_bucket_size + (1 if bucket_index < remainder else 0) + for posterior_row, _expected_quantile in ordered_rows[cursor : cursor + bucket_size]: + posterior_row.row.difficulty.bucket_index = bucket_index + posterior_row.row.difficulty.bucket_count = effective_bucket_count + cursor += bucket_size + + +def estimate_expected_difficulty_quantiles( + posterior_rows: list[DifficultyPosteriorRow], + *, + grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE, + batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE, +) -> list[float]: + if not posterior_rows: + return [] + if len(posterior_rows) == 1: + return [0.5] + + grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size + difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64) + difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64) + + mixture_cdf = np.zeros(grid_size, dtype=np.float64) + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_cdf = beta_distribution.cdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0) + mixture_cdf /= len(posterior_rows) + + quantiles = np.zeros(len(posterior_rows), dtype=np.float64) + dx = 1.0 
/ grid_size + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_pdf = beta_distribution.pdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + quantiles[start:stop] = np.clip( + np.nan_to_num(batch_pdf, nan=0.0, posinf=0.0, neginf=0.0).dot(mixture_cdf) * dx, 0.0, 1.0 + ) + + return quantiles.tolist() + + +def fit_empirical_beta_prior(binary_counts: list[tuple[int, int]]) -> BetaPrior | None: + total_successes = sum(success_count for success_count, _ in binary_counts) + total_attempts = sum(attempt_count for _, attempt_count in binary_counts) + if total_attempts == 0 or total_successes in {0, total_attempts}: + return None + + mean_rate = total_successes / total_attempts + init_alpha = max(mean_rate * 2.0, 1e-3) + init_beta = max((1.0 - mean_rate) * 2.0, 1e-3) + + def objective(log_params: tuple[float, float]) -> float: + alpha = math.exp(log_params[0]) + beta = math.exp(log_params[1]) + return -sum( + betaln(success_count + alpha, attempt_count - success_count + beta) - betaln(alpha, beta) + for success_count, attempt_count in binary_counts + ) + + result = minimize( + objective, + x0=(math.log(init_alpha), math.log(init_beta)), + method="L-BFGS-B", + bounds=[(-10.0, 10.0), (-10.0, 10.0)], + ) + if not result.success: + logger.warning("Empirical-Bayes fit failed: %s", result.message) + return None + + return BetaPrior(alpha=math.exp(result.x[0]), beta=math.exp(result.x[1]), source="empirical_bayes") + + +def build_output_paths( + output_root: Path, *, task_name: str, model_name: str | None, dataset_metadata: DatasetMetadata +) -> tuple[Path, Path, Path]: + task_suffix = task_name.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") or "unknown-task" + model_suffix = (model_name or "").replace(":", "_").replace("/", "_").replace("\\", "_").replace( + " ", "_" + ) or "unknown-model" + difficulty_suffix = f"__{dataset_metadata.difficulty_generation.tag}" + stem = 
output_root / f"{task_suffix}__{model_suffix}{difficulty_suffix}" + return Path(f"{stem}.jsonl"), Path(f"{stem}.schema.json"), Path(f"{stem}.metadata.json") + + +def write_output_files( + *, output_jsonl: Path, schema_json: Path, metadata_json: Path, dataset: Dataset, dataset_metadata: DatasetMetadata +) -> None: + output_jsonl.parent.mkdir(parents=True, exist_ok=True) + with output_jsonl.open("w") as output_file: + for row in dataset: + output_file.write(json.dumps(make_jsonable(row), ensure_ascii=False) + "\n") + + schema_json.parent.mkdir(parents=True, exist_ok=True) + try: + schema_payload: Any = dataset.features.to_dict() + except AttributeError: + schema_payload = str(dataset.features) + with schema_json.open("w") as output_file: + json.dump(schema_payload, output_file, indent=2, sort_keys=True) + + metadata_json.parent.mkdir(parents=True, exist_ok=True) + with metadata_json.open("w") as output_file: + json.dump(make_jsonable(dataset_metadata), output_file, indent=2, sort_keys=True) + + +def build_dataset_metadata( + *, + rows: list[DifficultyRow], + task_name: str, + model_name: str | None, + requested_prior_mode: str, + requested_bucket_count: int, + lower_quantile: float, + prior: BetaPrior | None, + binary_row_count: int, + score_processing: ScoreProcessingMetadata, + source_format: SourceFormatMetadata, +) -> DatasetMetadata: + effective_bucket_count = extract_effective_bucket_count(rows) + difficulty_generation = DifficultyGenerationMetadata( + method=DIFFICULTY_GENERATION_METHOD, + difficulty_value_field="difficulty.value", + difficulty_value_definition="1 - difficulty.posterior_lower_bound", + bucket_field="difficulty.bucket_index", + bucket_count_field="difficulty.bucket_count", + bucket_ranking_field="difficulty.expected_quantile", + posterior_lower_quantile=lower_quantile, + bucket_count_requested=requested_bucket_count, + bucket_count_effective=effective_bucket_count, + beta_prior_requested=requested_prior_mode, + 
beta_prior_used=BetaPriorUsageMetadata( + source=prior.source if prior is not None else None, + alpha=prior.alpha if prior is not None else None, + beta=prior.beta if prior is not None else None, + ), + binary_instance_count=binary_row_count, + nonbinary_instance_count=max(0, len(rows) - binary_row_count), + ) + method_value = difficulty_generation.method + if method_value: + method_token = DIFFICULTY_METHOD_FILENAME_ALIASES.get( + method_value, method_value.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") + ) + else: + method_token = "diff" + + prior_source = difficulty_generation.beta_prior_used.source + if prior_source: + prior_token = PRIOR_SOURCE_FILENAME_ALIASES.get( + prior_source, prior_source.replace(":", "_").replace("/", "_").replace("\\", "_").replace(" ", "_") + ) + else: + prior_token = "none" + + quantile_token = ( + f"q{f'{difficulty_generation.posterior_lower_quantile * 100.0:.8g}'.replace('-', 'm').replace('.', 'p')}" + ) + if difficulty_generation.bucket_count_requested == difficulty_generation.bucket_count_effective: + bucket_token = f"k{difficulty_generation.bucket_count_requested}" + else: + bucket_token = ( + f"k{difficulty_generation.bucket_count_requested}e{difficulty_generation.bucket_count_effective}" + ) + difficulty_generation.tag = "-".join([method_token, prior_token, quantile_token, bucket_token]) + return DatasetMetadata( + task_name=task_name, + model_name=model_name, + row_count=len(rows), + source_format=source_format, + score_processing=score_processing, + difficulty_generation=difficulty_generation, + ) + + +def extract_effective_bucket_count(rows: list[DifficultyRow]) -> int: + effective_bucket_counts = {row.difficulty.bucket_count for row in rows if row.difficulty.bucket_count is not None} + if not effective_bucket_counts: + return 0 + if len(effective_bucket_counts) != 1: + raise ValueError(f"Expected a single effective bucket count, found {sorted(effective_bucket_counts)}") + return 
next(iter(effective_bucket_counts)) + + +def annotate_dataset_metadata(dataset: Dataset, dataset_metadata: DatasetMetadata) -> None: + if not hasattr(dataset, "info") or dataset.info is None: + return + dataset.info.description = json.dumps(make_jsonable(dataset_metadata), indent=2, sort_keys=True) + + +def validate_args(args: argparse.Namespace) -> None: + if not 0.0 < args.posterior_lower_quantile < 1.0: + raise ValueError("--posterior-lower-quantile must be between 0 and 1.") + if args.difficulty_buckets < 0: + raise ValueError("--difficulty-buckets must be non-negative.") + if args.max_instances is not None and args.max_instances <= 0: + raise ValueError("--max-instances must be positive when provided.") + + +def group_rows_by_task_and_model(rows: list[DifficultyRow]) -> dict[tuple[str, str | None], list[DifficultyRow]]: + rows_by_group: dict[tuple[str, str | None], list[DifficultyRow]] = defaultdict(list) + for row in rows: + task_name = row.task_name + model_name = row.experiment_metadata.model_name + rows_by_group[(task_name, model_name)].append(row) + return dict(rows_by_group) + + +def get_base_task_name(task_name: str) -> str: + return task_name.split("@", 1)[0].split(":", 1)[0] + + +def extract_binary_counts(attempt_scores: list[float]) -> tuple[int, int] | None: + if not attempt_scores: + return None + + success_count = 0 + for score in attempt_scores: + if math.isclose(score, 0.0, rel_tol=EPS, abs_tol=EPS): + continue + if math.isclose(score, 1.0, rel_tol=EPS, abs_tol=EPS): + success_count += 1 + continue + return None + + return success_count, len(attempt_scores) -REPO_ROOT = Path(__file__).resolve().parents[3] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) -from open_instruct import rlvr_difficulty +def make_jsonable(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if is_dataclass(value): + return {key: make_jsonable(item) for key, item in asdict(value).items()} + if 
isinstance(value, list): + return [make_jsonable(item) for item in value] + if isinstance(value, tuple): + return [make_jsonable(item) for item in value] + if isinstance(value, dict): + return { + (key if isinstance(key, str) else "" if key is None else str(key)): make_jsonable(item) + for key, item in value.items() + } + return "" if value is None else str(value) if __name__ == "__main__": - rlvr_difficulty.main() + main() diff --git a/tests/test_create_bucketed_difficulty.py b/tests/test_create_bucketed_difficulty.py index 21c0c550af..83bb476c1a 100644 --- a/tests/test_create_bucketed_difficulty.py +++ b/tests/test_create_bucketed_difficulty.py @@ -1,10 +1,9 @@ -"""Unit tests for posterior-aware bucketing in open_instruct.rlvr_difficulty.""" +"""Unit tests for posterior-aware bucketing in create_bucketed_difficulty.py.""" -import importlib +import importlib.util import json import math import sys -import tempfile import types import unittest from collections import Counter @@ -14,6 +13,8 @@ import numpy as np +MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_bucketed_difficulty.py" + def _load_create_bucketed_difficulty_module(): fake_datasets = types.ModuleType("datasets") @@ -82,10 +83,16 @@ def ppf(cls, q, alpha, beta): "scipy.special": fake_scipy_special, "scipy.stats": fake_scipy_stats, } + module_name = "test_create_bucketed_difficulty_module" + spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) with patch.dict(sys.modules, modules): - sys.modules.pop("open_instruct.rlvr_difficulty", None) - return importlib.import_module("open_instruct.rlvr_difficulty") + sys.modules.pop(module_name, None) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module MODULE = _load_create_bucketed_difficulty_module() @@ -123,133 +130,57 @@ def add_column(self, name, values): [{**row, 
name: value} for row, value in zip(self._rows, values, strict=True)] ) - def test_discover_rollout_sources_resolves_directory_runs(self): - with tempfile.TemporaryDirectory() as tmpdir: - root = Path(tmpdir) - (root / "demo_run_metadata.jsonl").write_text( - json.dumps({"run_name": "demo_run", "model_name": "demo-model"}) + "\n" - ) - (root / "demo_run_rollouts_000000.jsonl").write_text( - json.dumps( - { - "prompt_tokens": [1, 2, 3], - "reward": 1.0, - "finish_reason": "stop", - "dataset": "math", - "ground_truth": "4", - } - ) - + "\n" - ) - - sources = MODULE.discover_rollout_sources([str(root)]) - - self.assertEqual(len(sources), 1) - self.assertEqual(sources[0].run_name, "demo_run") - self.assertEqual(sources[0].metadata_path.name, "demo_run_metadata.jsonl") - self.assertEqual([path.name for path in sources[0].rollout_paths], ["demo_run_rollouts_000000.jsonl"]) - - def test_rollout_contributions_aggregate_and_normalize_constant_rewards(self): - with tempfile.TemporaryDirectory() as tmpdir: - root = Path(tmpdir) - (root / "demo_run_metadata.jsonl").write_text( - json.dumps({"run_name": "demo_run", "model_name": "Qwen/Qwen3-4B-Base"}) + "\n" - ) - shard = root / "demo_run_rollouts_000000.jsonl" - shard.write_text( - "\n".join( - [ - json.dumps( - { - "prompt_tokens": [11, 12, 13], - "reward": 10.0, - "finish_reason": "stop", - "dataset": "math", - "ground_truth": {"answer": "4"}, - "request_info": {"timeouts": 0, "tool_errors": ""}, - } - ), - json.dumps( - { - "prompt_tokens": [11, 12, 13], - "reward": 0.0, - "finish_reason": "length", - "dataset": "math", - "ground_truth": {"answer": "4"}, - "request_info": {"timeouts": 1, "tool_errors": ""}, - } - ), - json.dumps( - { - "prompt_tokens": [21, 22, 23], - "reward": 10.0, - "finish_reason": "stop", - "dataset": "math", - "ground_truth": {"answer": "9"}, - "request_info": {"timeouts": 0, "tool_errors": ""}, - } - ), - ] - ) - + "\n" - ) - - source = MODULE.discover_rollout_sources([str(root)])[0] - contributions, 
malformed_records = MODULE.build_contributions_for_source( - source_run=source, task_filters=set(), strict=True - ) - - self.assertEqual(malformed_records, 0) - - rows = MODULE.aggregate_contributions(contributions) - self.assertEqual(len(rows), 2) - - rows_by_group = MODULE.group_rows_by_task_and_model(rows) - group_rows, score_processing, skipped_nonunit = MODULE.normalize_attempt_scores_for_group( - rows_by_group[("math", "Qwen/Qwen3-4B-Base")], allow_nonunit_scores=False + def make_row( + self, + *, + source_row_index=0, + instance_id="row-0", + task_name="math", + source_dataset="mnoukhov/demo", + source_row_id="row-0", + attempt_scores=None, + model_name="demo-model", + warnings=None, + difficulty=None, + ): + return MODULE.DifficultyRow( + source_row_index=source_row_index, + instance_id=instance_id, + task_name=task_name, + base_task_name=MODULE.get_base_task_name(task_name), + source_dataset=source_dataset, + source_row_id=source_row_id, + attempt_scores=list(attempt_scores or []), + finish_reasons=[], + experiment_metadata=MODULE.ExperimentMetadata( + source_root=f"hf://{source_dataset}/default/train", + model_name=model_name, + experiment_id=None, + experiment_name=source_dataset, + ), + score_sources=[task_name], + warnings=list(warnings or []), + difficulty=difficulty or MODULE.DifficultyPayload(), ) - self.assertEqual(skipped_nonunit, 0) - self.assertEqual(score_processing["normalization"], "binary_zero_or_constant") - self.assertEqual(score_processing["positive_reward_value"], 10.0) + def test_parser_requires_hf_dataset_and_rejects_source(self): + with self.assertRaises(SystemExit): + MODULE.make_parser().parse_args(["--output", "/tmp/difficulty"]) - easy_row = next(row for row in group_rows if row["ground_truth"] == {"answer": "4"}) - self.assertEqual(easy_row["attempt_scores"], [1.0, 0.0]) - self.assertEqual(easy_row["prompt_tokens"], [11, 12, 13]) - self.assertEqual(easy_row["finish_reasons"], ["stop", "length"]) - 
self.assertEqual(easy_row["score_sources"], ["math"]) - self.assertEqual(easy_row["experiment_metadata"]["model_name"], "Qwen/Qwen3-4B-Base") - self.assertIn("timeout", easy_row["warnings"]) + with self.assertRaises(SystemExit): + MODULE.make_parser().parse_args(["--source", "/tmp/rollouts", "--output", "/tmp/difficulty"]) def test_normalize_attempt_scores_for_group_marks_unsupported_rewards(self): - rows = [ - { - "instance_id": "example", - "task_name": "math", - "base_task_name": "math", - "prompt_tokens": [1, 2, 3], - "ground_truth": "4", - "attempt_scores": [10.0, 5.0], - "finish_reasons": ["stop", "stop"], - "experiment_metadata": { - "source_root": "/tmp/example-rollouts", - "model_name": "demo-model", - "experiment_id": None, - "experiment_name": "demo-run", - }, - "score_sources": ["math"], - "warnings": [], - } - ] + rows = [self.make_row(instance_id="example", source_row_id="example", attempt_scores=[10.0, 5.0])] kept_rows, score_processing, skipped_nonunit = MODULE.normalize_attempt_scores_for_group( rows, allow_nonunit_scores=True ) self.assertEqual(skipped_nonunit, 0) - self.assertFalse(score_processing["supports_binary_difficulty"]) - self.assertEqual(kept_rows[0]["attempt_scores"], [10.0, 5.0]) - self.assertIn("nonbinary_reward_scores", kept_rows[0]["warnings"]) + self.assertFalse(score_processing.supports_binary_difficulty) + self.assertEqual(kept_rows[0].attempt_scores, [10.0, 5.0]) + self.assertIn("nonbinary_reward_scores", kept_rows[0].warnings) dropped_rows, _, dropped_count = MODULE.normalize_attempt_scores_for_group(rows, allow_nonunit_scores=False) @@ -258,29 +189,32 @@ def test_normalize_attempt_scores_for_group_marks_unsupported_rewards(self): def test_build_dataset_metadata_captures_difficulty_generation_details(self): rows = [ - { - "instance_id": "easy", - "difficulty": { - "value": 0.1, - "posterior_mean": 0.2, - "posterior_lower_bound": 0.9, - "expected_quantile": 0.2, - "bucket_index": 0, - "bucket_count": 3, - }, - }, - { - 
"instance_id": "hard", - "difficulty": { - "value": 0.8, - "posterior_mean": 0.7, - "posterior_lower_bound": 0.2, - "expected_quantile": 0.9, - "bucket_index": 2, - "bucket_count": 3, - }, - }, - {"instance_id": "nonbinary", "difficulty": MODULE.make_empty_difficulty_payload()}, + self.make_row( + instance_id="easy", + source_row_id="easy", + difficulty=MODULE.DifficultyPayload( + value=0.1, + posterior_mean=0.2, + posterior_lower_bound=0.9, + expected_quantile=0.2, + bucket_index=0, + bucket_count=3, + ), + ), + self.make_row( + source_row_index=1, + instance_id="hard", + source_row_id="hard", + difficulty=MODULE.DifficultyPayload( + value=0.8, + posterior_mean=0.7, + posterior_lower_bound=0.2, + expected_quantile=0.9, + bucket_index=2, + bucket_count=3, + ), + ), + self.make_row(source_row_index=2, instance_id="nonbinary", source_row_id="nonbinary"), ] metadata = MODULE.build_dataset_metadata( @@ -292,31 +226,41 @@ def test_build_dataset_metadata_captures_difficulty_generation_details(self): lower_quantile=0.1, prior=MODULE.BetaPrior(alpha=0.75, beta=1.25, source="empirical_bayes"), binary_row_count=2, - score_processing={ - "source_field": "reward", - "output_field": "attempt_scores", - "normalization": "binary_zero_or_constant", - "positive_reward_value": 10.0, - "supports_binary_difficulty": True, - }, - source_format=MODULE.build_rollout_source_format_metadata(), + score_processing=MODULE.ScoreProcessingMetadata( + source_field="reward", + output_field="attempt_scores", + normalization="binary_zero_or_constant", + positive_reward_value=10.0, + supports_binary_difficulty=True, + ), + source_format=MODULE.build_hf_source_format_metadata( + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", + pass_rate_field="pass_rate", + ), ) - self.assertEqual(metadata["task_name"], "math") - 
self.assertEqual(metadata["model_name"], "demo-model") - self.assertEqual(metadata["row_count"], 3) - self.assertEqual(metadata["source_format"]["kind"], "open_instruct_rollout_traces") - self.assertEqual(metadata["score_processing"]["normalization"], "binary_zero_or_constant") - self.assertEqual(metadata["score_processing"]["positive_reward_value"], 10.0) - self.assertEqual(metadata["difficulty_generation"]["method"], "beta_binomial_posterior_quantiles") - self.assertEqual(metadata["difficulty_generation"]["posterior_lower_quantile"], 0.1) - self.assertEqual(metadata["difficulty_generation"]["bucket_count_requested"], 5) - self.assertEqual(metadata["difficulty_generation"]["bucket_count_effective"], 3) - self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["source"], "empirical_bayes") - self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["alpha"], 0.75) - self.assertEqual(metadata["difficulty_generation"]["beta_prior_used"]["beta"], 1.25) - self.assertEqual(metadata["difficulty_generation"]["binary_instance_count"], 2) - self.assertEqual(metadata["difficulty_generation"]["nonbinary_instance_count"], 1) + self.assertEqual(metadata.task_name, "math") + self.assertEqual(metadata.model_name, "demo-model") + self.assertEqual(metadata.row_count, 3) + self.assertEqual(metadata.source_format.kind, MODULE.HF_SOURCE_FORMAT_KIND) + self.assertEqual(metadata.score_processing.normalization, "binary_zero_or_constant") + self.assertEqual(metadata.score_processing.positive_reward_value, 10.0) + self.assertEqual(metadata.difficulty_generation.method, "beta_binomial_posterior_quantiles") + self.assertEqual(metadata.difficulty_generation.posterior_lower_quantile, 0.1) + self.assertEqual(metadata.difficulty_generation.bucket_count_requested, 5) + self.assertEqual(metadata.difficulty_generation.bucket_count_effective, 3) + self.assertEqual(metadata.difficulty_generation.beta_prior_used.source, "empirical_bayes") + 
self.assertEqual(metadata.difficulty_generation.beta_prior_used.alpha, 0.75) + self.assertEqual(metadata.difficulty_generation.beta_prior_used.beta, 1.25) + self.assertEqual(metadata.difficulty_generation.binary_instance_count, 2) + self.assertEqual(metadata.difficulty_generation.nonbinary_instance_count, 1) def test_build_hf_dataset_row_parses_pass_rate_counts(self): row = MODULE.build_hf_dataset_row( @@ -340,11 +284,11 @@ def test_build_hf_dataset_row_parses_pass_rate_counts(self): pass_rate_field="pass_rate", ) - self.assertEqual(row["instance_id"], "mnoukhov/demo::row-7") - self.assertEqual(row["source_row_id"], "row-7") - self.assertEqual(row["attempt_scores"], [1.0, 1.0, 1.0, 0.0, 0.0]) - self.assertEqual(row["experiment_metadata"]["model_name"], "Qwen/Qwen3-4B-Base") - self.assertEqual(row["experiment_metadata"]["source_root"], "hf://mnoukhov/demo/default/train") + self.assertEqual(row.instance_id, "mnoukhov/demo::row-7") + self.assertEqual(row.source_row_id, "row-7") + self.assertEqual(row.attempt_scores, [1.0, 1.0, 1.0, 0.0, 0.0]) + self.assertEqual(row.experiment_metadata.model_name, "Qwen/Qwen3-4B-Base") + self.assertEqual(row.experiment_metadata.source_root, "hf://mnoukhov/demo/default/train") def test_load_hf_dataset_rows_builds_bundle_and_filters_tasks(self): fake_dataset = self.FakeHFDataset( @@ -384,11 +328,11 @@ def test_load_hf_dataset_rows_builds_bundle_and_filters_tasks(self): ) self.assertEqual(bundle.malformed_records, 0) - self.assertEqual(bundle.source_format["kind"], MODULE.HF_SOURCE_FORMAT_KIND) - self.assertEqual(bundle.source_format["dataset_repo_id"], "mnoukhov/demo") + self.assertEqual(bundle.source_format.kind, MODULE.HF_SOURCE_FORMAT_KIND) + self.assertEqual(bundle.source_format.dataset_repo_id, "mnoukhov/demo") self.assertEqual(len(bundle.rows), 1) - self.assertEqual(bundle.rows[0]["instance_id"], "mnoukhov/demo::math-1") - self.assertEqual(bundle.rows[0]["attempt_scores"], [1.0, 1.0, 0.0, 0.0]) + 
self.assertEqual(bundle.rows[0].instance_id, "mnoukhov/demo::math-1") + self.assertEqual(bundle.rows[0].attempt_scores, [1.0, 1.0, 0.0, 0.0]) def test_build_hf_output_dataset_preserves_source_rows_and_order(self): source_dataset = self.FakeHFDataset( @@ -398,64 +342,44 @@ def test_build_hf_output_dataset_preserves_source_rows_and_order(self): ] ) rows = [ - { - MODULE.HF_SOURCE_ROW_INDEX_FIELD: 1, - "instance_id": "mnoukhov/demo::row-1", - "task_name": "math", - "base_task_name": "math", - "source_dataset": "mnoukhov/demo", - "source_row_id": "row-1", - "attempt_scores": [0.0, 0.0], - "finish_reasons": [], - "experiment_metadata": { - "source_root": "hf://mnoukhov/demo/default/train", - "model_name": "Qwen/Qwen3-4B-Base", - "experiment_id": None, - "experiment_name": "mnoukhov/demo", - }, - "score_sources": ["math"], - "warnings": [], - "difficulty": { - "value": 0.9, - "posterior_mean": 0.1, - "posterior_lower_bound": 0.1, - "expected_quantile": 0.9, - "bucket_index": 1, - "bucket_count": 2, - }, - }, - { - MODULE.HF_SOURCE_ROW_INDEX_FIELD: 0, - "instance_id": "mnoukhov/demo::row-0", - "task_name": "math", - "base_task_name": "math", - "source_dataset": "mnoukhov/demo", - "source_row_id": "row-0", - "attempt_scores": [1.0, 1.0], - "finish_reasons": [], - "experiment_metadata": { - "source_root": "hf://mnoukhov/demo/default/train", - "model_name": "Qwen/Qwen3-4B-Base", - "experiment_id": None, - "experiment_name": "mnoukhov/demo", - }, - "score_sources": ["math"], - "warnings": [], - "difficulty": { - "value": 0.1, - "posterior_mean": 0.9, - "posterior_lower_bound": 0.9, - "expected_quantile": 0.1, - "bucket_index": 0, - "bucket_count": 2, - }, - }, + self.make_row( + source_row_index=1, + instance_id="mnoukhov/demo::row-1", + source_row_id="row-1", + attempt_scores=[0.0, 0.0], + model_name="Qwen/Qwen3-4B-Base", + difficulty=MODULE.DifficultyPayload( + value=0.9, + posterior_mean=0.1, + posterior_lower_bound=0.1, + expected_quantile=0.9, + bucket_index=1, + 
bucket_count=2, + ), + ), + self.make_row( + source_row_index=0, + instance_id="mnoukhov/demo::row-0", + source_row_id="row-0", + attempt_scores=[1.0, 1.0], + model_name="Qwen/Qwen3-4B-Base", + difficulty=MODULE.DifficultyPayload( + value=0.1, + posterior_mean=0.9, + posterior_lower_bound=0.9, + expected_quantile=0.1, + bucket_index=0, + bucket_count=2, + ), + ), ] - dataset = MODULE.build_hf_output_dataset(source_dataset, rows) + output_rows, dataset = MODULE.build_hf_output_dataset(source_dataset, rows) self.assertEqual(len(dataset), 2) self.assertEqual(dataset.column_names, ["prompt", "extra_info", "difficulty"]) + self.assertEqual(output_rows[0]["source_row_id"], "row-0") + self.assertEqual(output_rows[1]["source_row_id"], "row-1") self.assertEqual(dataset[0]["prompt"], "first") self.assertEqual(dataset[0]["difficulty"]["bucket_index"], 0) self.assertEqual(dataset[1]["prompt"], "second") @@ -470,85 +394,105 @@ def __init__(self): self.info = FakeInfo() dataset = FakeDataset() - dataset_metadata = {"task_name": "math", "difficulty_generation": {"bucket_count_requested": 5}} + dataset_metadata = MODULE.build_dataset_metadata( + rows=[self.make_row(difficulty=MODULE.DifficultyPayload(bucket_index=0, bucket_count=5))], + task_name="math", + model_name="demo-model", + requested_prior_mode="empirical-bayes", + requested_bucket_count=5, + lower_quantile=0.1, + prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="empirical_bayes"), + binary_row_count=1, + score_processing=MODULE.ScoreProcessingMetadata( + source_field="pass_count,num_samples,pass_rate", + output_field="attempt_scores", + normalization="identity_binary", + positive_reward_value=1.0, + supports_binary_difficulty=True, + ), + source_format=MODULE.build_hf_source_format_metadata( + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", 
+ pass_rate_field="pass_rate", + ), + ) MODULE.annotate_dataset_metadata(dataset, dataset_metadata) - self.assertEqual(json.loads(dataset.info.description), dataset_metadata) - - def test_normalize_experiment_metadata_uses_canonical_source_root_only(self): - normalized = MODULE.normalize_experiment_metadata( - { - "source_root": "/tmp/example-rollouts", - "source_input": "/tmp/example-rollouts/demo_run_metadata.jsonl", - "model_name": "demo-model", - "experiment_id": "exp-123", - "experiment_name": "demo-run", - } - ) - - self.assertEqual( - normalized, - { - "source_root": "/tmp/example-rollouts", - "model_name": "demo-model", - "experiment_id": "exp-123", - "experiment_name": "demo-run", - }, - ) + self.assertEqual(json.loads(dataset.info.description), MODULE.make_jsonable(dataset_metadata)) def test_apply_beta_binomial_difficulty_orders_rows_by_expected_quantile(self): rows = [ - {"instance_id": "easy", "attempt_scores": [1.0, 1.0, 1.0, 1.0]}, - {"instance_id": "medium", "attempt_scores": [1.0, 1.0, 0.0, 0.0]}, - {"instance_id": "hard", "attempt_scores": [0.0, 0.0, 0.0, 0.0]}, + self.make_row(instance_id="easy", source_row_id="easy", attempt_scores=[1.0, 1.0, 1.0, 1.0]), + self.make_row( + source_row_index=1, instance_id="medium", source_row_id="medium", attempt_scores=[1.0, 1.0, 0.0, 0.0] + ), + self.make_row( + source_row_index=2, instance_id="hard", source_row_id="hard", attempt_scores=[0.0, 0.0, 0.0, 0.0] + ), ] result = MODULE.apply_beta_binomial_difficulty( rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=3 ) - difficulties = {row["instance_id"]: row["difficulty"] for row in result} + difficulties = {row.instance_id: row.difficulty for row in result} - self.assertLess(difficulties["easy"]["expected_quantile"], difficulties["medium"]["expected_quantile"]) - self.assertLess(difficulties["medium"]["expected_quantile"], difficulties["hard"]["expected_quantile"]) - 
self.assertEqual(difficulties["easy"]["bucket_index"], 0) - self.assertEqual(difficulties["medium"]["bucket_index"], 1) - self.assertEqual(difficulties["hard"]["bucket_index"], 2) - self.assertTrue(all(difficulty["bucket_count"] == 3 for difficulty in difficulties.values())) + self.assertLess(difficulties["easy"].expected_quantile, difficulties["medium"].expected_quantile) + self.assertLess(difficulties["medium"].expected_quantile, difficulties["hard"].expected_quantile) + self.assertEqual(difficulties["easy"].bucket_index, 0) + self.assertEqual(difficulties["medium"].bucket_index, 1) + self.assertEqual(difficulties["hard"].bucket_index, 2) + self.assertTrue(all(difficulty.bucket_count == 3 for difficulty in difficulties.values())) def test_apply_beta_binomial_difficulty_balances_bucket_sizes(self): rows = [ - {"instance_id": "easiest", "attempt_scores": [1.0, 1.0, 1.0, 1.0]}, - {"instance_id": "easy", "attempt_scores": [1.0, 1.0, 1.0, 0.0]}, - {"instance_id": "mid", "attempt_scores": [1.0, 1.0, 0.0, 0.0]}, - {"instance_id": "hard", "attempt_scores": [1.0, 0.0, 0.0, 0.0]}, - {"instance_id": "hardest", "attempt_scores": [0.0, 0.0, 0.0, 0.0]}, + self.make_row(instance_id="easiest", source_row_id="easiest", attempt_scores=[1.0, 1.0, 1.0, 1.0]), + self.make_row( + source_row_index=1, instance_id="easy", source_row_id="easy", attempt_scores=[1.0, 1.0, 1.0, 0.0] + ), + self.make_row( + source_row_index=2, instance_id="mid", source_row_id="mid", attempt_scores=[1.0, 1.0, 0.0, 0.0] + ), + self.make_row( + source_row_index=3, instance_id="hard", source_row_id="hard", attempt_scores=[1.0, 0.0, 0.0, 0.0] + ), + self.make_row( + source_row_index=4, instance_id="hardest", source_row_id="hardest", attempt_scores=[0.0, 0.0, 0.0, 0.0] + ), ] result = MODULE.apply_beta_binomial_difficulty( rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=2 ) - bucket_counts = Counter(row["difficulty"]["bucket_index"] for row in result) + 
bucket_counts = Counter(row.difficulty.bucket_index for row in result) self.assertEqual(bucket_counts[0], 3) self.assertEqual(bucket_counts[1], 2) def test_apply_beta_binomial_difficulty_leaves_nonbinary_rows_unbucketed(self): rows = [ - {"instance_id": "easy", "attempt_scores": [1.0, 1.0]}, - {"instance_id": "nonbinary", "attempt_scores": [0.5, 1.0]}, - {"instance_id": "hard", "attempt_scores": [0.0, 0.0]}, + self.make_row(instance_id="easy", source_row_id="easy", attempt_scores=[1.0, 1.0]), + self.make_row( + source_row_index=1, instance_id="nonbinary", source_row_id="nonbinary", attempt_scores=[0.5, 1.0] + ), + self.make_row(source_row_index=2, instance_id="hard", source_row_id="hard", attempt_scores=[0.0, 0.0]), ] result = MODULE.apply_beta_binomial_difficulty( rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=2 ) - difficulties = {row["instance_id"]: row["difficulty"] for row in result} + difficulties = {row.instance_id: row.difficulty for row in result} - self.assertIsNone(difficulties["nonbinary"]["value"]) - self.assertIsNone(difficulties["nonbinary"]["expected_quantile"]) - self.assertIsNone(difficulties["nonbinary"]["bucket_index"]) - self.assertIsNone(difficulties["nonbinary"]["bucket_count"]) + self.assertIsNone(difficulties["nonbinary"].value) + self.assertIsNone(difficulties["nonbinary"].expected_quantile) + self.assertIsNone(difficulties["nonbinary"].bucket_index) + self.assertIsNone(difficulties["nonbinary"].bucket_count) if __name__ == "__main__": From 484a455fd04f79fea15db31599c7864409d352f7 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:04:00 -0700 Subject: [PATCH 23/40] Move some stuff around --- docs/algorithms/grpo.md | 68 --------- scripts/data/difficulty_sampling/README.md | 130 ++++++++++++++++++ ...difficulty.py => create_difficulty_map.py} | 4 +- ...culty.py => test_create_difficulty_map.py} | 18 +-- 4 files changed, 141 insertions(+), 79 deletions(-) create mode 100644 
scripts/data/difficulty_sampling/README.md rename scripts/data/difficulty_sampling/{create_bucketed_difficulty.py => create_difficulty_map.py} (99%) rename tests/{test_create_bucketed_difficulty.py => test_create_difficulty_map.py} (97%) diff --git a/docs/algorithms/grpo.md b/docs/algorithms/grpo.md index fb47d3318f..1206d9a490 100644 --- a/docs/algorithms/grpo.md +++ b/docs/algorithms/grpo.md @@ -77,74 +77,6 @@ Both `grpo.py` and `grpo_fast.py` share the same config classes and accept the s | | `--save_freq` | Save every N train steps | `200` | | | `--with_tracking` | Track experiment with Weights and Biases | `False` | -### Difficulty-Aware RLVR Curriculum - -`grpo_fast.py` can optionally replace uniform prompt reshuffling with `DifficultyCurriculumSampler`, a bucket-aware RLVR curriculum driven by per-instance difficulty metadata. The current recommended metadata format comes from the beta-binomial estimator in `scripts/data/difficulty_sampling/create_bucketed_difficulty.py`: - -```json -{ - "difficulty": { - "value": 0.9999999997624719, - "posterior_mean": 0.003437858035078528, - "posterior_lower_bound": 2.3752813430506325e-10, - "expected_quantile": 0.10139684528348392, - "bucket_index": 4, - "bucket_count": 5 - } -} -``` - -- `posterior_mean` is the estimated solve probability for that prompt. Lower means harder. -- `bucket_index = 0` is the easiest bucket and `bucket_index = bucket_count - 1` is the hardest. -- The sampler uses a smooth distribution with a configurable easy-heavy bootstrap phase, then gradually shifts mass toward harder buckets instead of hard-switching between discrete phases. -- Within each bucket, examples are weighted by a blend of uncertainty (`4 * p * (1 - p)`) and hardness (`1 - p`), so borderline prompts stay attractive while already-solved prompts are naturally down-weighted. 
-- If `--curriculum_adaptive true` is set, bucket probabilities are additionally blended with live reward / advantage statistics so buckets with useful learning signal can get more mass during training. - -Recommended starting settings for `bucket_count=5`: - -- Bootstrap (first ~100 steps by default): buckets 0 and 1 dominate so the model sees easier prompts while it settles into the chat template and task format. -- Early after bootstrap: bucket 2 highest, buckets 1 and 3 nonzero, bucket 4 low. -- Mid: buckets 2 and 3 dominate, with bucket 4 increasing. -- Late: buckets 3 and 4 dominate, while buckets 0-2 remain nonzero. - -Useful flags: - -```bash ---curriculum difficulty \ ---curriculum_metadata_field difficulty \ ---curriculum_bootstrap_steps 100 \ ---curriculum_bootstrap_target 0.125 \ ---curriculum_warmup_target 0.5 \ ---curriculum_final_target 1.0 \ ---curriculum_warmup_steps 500 \ ---curriculum_total_steps 10000 \ ---curriculum_min_hard_frac 0.05 \ ---curriculum_max_hard_frac 0.50 \ ---curriculum_bucket_sigma 0.0 \ ---curriculum_bootstrap_sigma 0.0 \ ---curriculum_uncertainty_weight 0.5 \ ---curriculum_adaptive true -``` - -Tuning tips: - -- Increase `curriculum_bootstrap_steps` to keep the easy bootstrap around longer. -- Lower `curriculum_bootstrap_target` to bias more strongly toward the easiest buckets early. -- Lower `curriculum_bucket_sigma` or `curriculum_bootstrap_sigma` to concentrate probability on fewer neighboring buckets. -- Lower `curriculum_warmup_target` if you want the post-bootstrap warmup to stay easier for longer. - -Metrics are logged through the standard GRPO tracking path. 
The most useful ones are: - -- `curriculum/progress` -- `curriculum/static_bucket_prob_*` -- `curriculum/adaptive_bucket_prob_*` -- `curriculum/bucket_prob_*` -- `curriculum/sampled_bucket_count_*` -- `curriculum/bucket_reward_mean_*` -- `curriculum/bucket_abs_advantage_mean_*` - -See `scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh` for a concrete launch example. The dataset metadata can be produced with `scripts/data/difficulty_sampling/create_bucketed_difficulty.py`. - For details on how GRPO's HSDP sharding works, see [OLMo-core Sharding and Parallelism](olmo_core_sharding.md). --- diff --git a/scripts/data/difficulty_sampling/README.md b/scripts/data/difficulty_sampling/README.md new file mode 100644 index 0000000000..e946a3f37b --- /dev/null +++ b/scripts/data/difficulty_sampling/README.md @@ -0,0 +1,130 @@ +# Difficulty Sampling + +This directory contains tooling for building per-instance difficulty metadata +for RLVR curricula. + +## Create A Difficulty Map + +Use `create_difficulty_map.py` to build a difficulty map from a Hugging Face +dataset that already contains per-row pass-rate aggregates. + +The script expands pass-count summaries into binary attempt outcomes, fits a +Beta prior across binary outcomes, estimates per-item difficulty, and writes +JSONL difficulty files plus schema and metadata sidecars. 
+ +### Examples + +Write local difficulty files: + +```bash +uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --output /tmp/dapo_math_qwen3_difficulty +``` + +Write local files and push the single output group to the Hub: + +```bash +uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --task math \ + --output /tmp/dapo_math_qwen3_difficulty \ + --push-to-hub your-org/dapo-math-qwen3-difficulty \ + --split train +``` + +Hub uploads require exactly one task/model output group, so use `--task` or a +single-group input dataset when pushing. + +## Difficulty Metadata Format + +`grpo_fast.py` can optionally replace uniform prompt reshuffling with +`DifficultyCurriculumSampler`, a bucket-aware RLVR curriculum driven by +per-instance difficulty metadata. The current recommended metadata format comes +from the beta-binomial estimator in `create_difficulty_map.py`: + +```json +{ + "difficulty": { + "value": 0.9999999997624719, + "posterior_mean": 0.003437858035078528, + "posterior_lower_bound": 2.3752813430506325e-10, + "expected_quantile": 0.10139684528348392, + "bucket_index": 4, + "bucket_count": 5 + } +} +``` + +- `posterior_mean` is the estimated solve probability for that prompt. Lower + means harder. +- `bucket_index = 0` is the easiest bucket and + `bucket_index = bucket_count - 1` is the hardest. +- The sampler uses a smooth distribution with a configurable easy-heavy + bootstrap phase, then gradually shifts mass toward harder buckets instead of + hard-switching between discrete phases. +- Within each bucket, examples are weighted by a blend of uncertainty + (`4 * p * (1 - p)`) and hardness (`1 - p`), so borderline prompts stay + attractive while already-solved prompts are naturally down-weighted. 
+- If `--curriculum_adaptive true` is set, bucket probabilities are additionally + blended with live reward / advantage statistics so buckets with useful + learning signal can get more mass during training. + +## Recommended Starting Point + +For `bucket_count=5`: + +- Bootstrap (first ~100 steps by default): buckets 0 and 1 dominate so the + model sees easier prompts while it settles into the chat template and task + format. +- Early after bootstrap: bucket 2 highest, buckets 1 and 3 nonzero, bucket 4 + low. +- Mid: buckets 2 and 3 dominate, with bucket 4 increasing. +- Late: buckets 3 and 4 dominate, while buckets 0-2 remain nonzero. + +Useful flags: + +```bash +--curriculum difficulty \ +--curriculum_metadata_field difficulty \ +--curriculum_bootstrap_steps 100 \ +--curriculum_bootstrap_target 0.125 \ +--curriculum_warmup_target 0.5 \ +--curriculum_final_target 1.0 \ +--curriculum_warmup_steps 500 \ +--curriculum_total_steps 10000 \ +--curriculum_min_hard_frac 0.05 \ +--curriculum_max_hard_frac 0.50 \ +--curriculum_bucket_sigma 0.0 \ +--curriculum_bootstrap_sigma 0.0 \ +--curriculum_uncertainty_weight 0.5 \ +--curriculum_adaptive true +``` + +Tuning tips: + +- Increase `curriculum_bootstrap_steps` to keep the easy bootstrap around + longer. +- Lower `curriculum_bootstrap_target` to bias more strongly toward the easiest + buckets early. +- Lower `curriculum_bucket_sigma` or `curriculum_bootstrap_sigma` to + concentrate probability on fewer neighboring buckets. +- Lower `curriculum_warmup_target` if you want the post-bootstrap warmup to + stay easier for longer. 
+ +## Metrics + +The most useful curriculum metrics are: + +- `curriculum/progress` +- `curriculum/static_bucket_prob_*` +- `curriculum/adaptive_bucket_prob_*` +- `curriculum/bucket_prob_*` +- `curriculum/sampled_bucket_count_*` +- `curriculum/bucket_reward_mean_*` +- `curriculum/bucket_abs_advantage_mean_*` + +See `scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh` for a +concrete launch example. diff --git a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py b/scripts/data/difficulty_sampling/create_difficulty_map.py similarity index 99% rename from scripts/data/difficulty_sampling/create_bucketed_difficulty.py rename to scripts/data/difficulty_sampling/create_difficulty_map.py index 8140c713c3..4804b6b044 100644 --- a/scripts/data/difficulty_sampling/create_bucketed_difficulty.py +++ b/scripts/data/difficulty_sampling/create_difficulty_map.py @@ -22,13 +22,13 @@ Examples: Write local difficulty files: - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ --hf-split train \ --output /tmp/dapo_math_qwen3_difficulty Write local files and push the single output group to the Hub: - uv run scripts/data/difficulty_sampling/create_bucketed_difficulty.py \ + uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ --hf-split train \ --task math \ diff --git a/tests/test_create_bucketed_difficulty.py b/tests/test_create_difficulty_map.py similarity index 97% rename from tests/test_create_bucketed_difficulty.py rename to tests/test_create_difficulty_map.py index 83bb476c1a..6debf587b2 100644 --- a/tests/test_create_bucketed_difficulty.py +++ b/tests/test_create_difficulty_map.py @@ -1,4 +1,4 @@ -"""Unit tests for posterior-aware bucketing in create_bucketed_difficulty.py.""" +"""Unit tests 
for posterior-aware bucketing in create_difficulty_map.py.""" import importlib.util import json @@ -13,10 +13,10 @@ import numpy as np -MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_bucketed_difficulty.py" +MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_difficulty_map.py" -def _load_create_bucketed_difficulty_module(): +def _load_create_difficulty_map_module(): fake_datasets = types.ModuleType("datasets") fake_datasets.Dataset = type("Dataset", (), {}) fake_datasets.load_dataset = lambda *_args, **_kwargs: None @@ -83,7 +83,7 @@ def ppf(cls, q, alpha, beta): "scipy.special": fake_scipy_special, "scipy.stats": fake_scipy_stats, } - module_name = "test_create_bucketed_difficulty_module" + module_name = "test_create_difficulty_map_module" spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) assert spec is not None and spec.loader is not None module = importlib.util.module_from_spec(spec) @@ -95,10 +95,10 @@ def ppf(cls, q, alpha, beta): return module -MODULE = _load_create_bucketed_difficulty_module() +MODULE = _load_create_difficulty_map_module() -class TestCreateBucketedDifficulty(unittest.TestCase): +class TestCreateDifficultyMap(unittest.TestCase): class FakeHFDataset: def __init__(self, rows): self._rows = [dict(row) for row in rows] @@ -117,16 +117,16 @@ def column_names(self): return list(self._rows[0].keys()) if self._rows else [] def select(self, indices): - return TestCreateBucketedDifficulty.FakeHFDataset([self._rows[index] for index in indices]) + return TestCreateDifficultyMap.FakeHFDataset([self._rows[index] for index in indices]) def remove_columns(self, column_names): names = {column_names} if isinstance(column_names, str) else set(column_names) - return TestCreateBucketedDifficulty.FakeHFDataset( + return TestCreateDifficultyMap.FakeHFDataset( [{key: value for key, value in row.items() if key not in names} for row in self._rows] ) def 
add_column(self, name, values): - return TestCreateBucketedDifficulty.FakeHFDataset( + return TestCreateDifficultyMap.FakeHFDataset( [{**row, name: value} for row, value in zip(self._rows, values, strict=True)] ) From caf89c7367449555490f8bdf61a599f46ea03957 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:06:25 -0700 Subject: [PATCH 24/40] Shorten filename --- ...apo_math_difficulty_curriculum.sh => qwen3_4b_dapo_math_dc.sh} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/train/qwen/{qwen3_4b_dapo_math_difficulty_curriculum.sh => qwen3_4b_dapo_math_dc.sh} (100%) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh similarity index 100% rename from scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh rename to scripts/train/qwen/qwen3_4b_dapo_math_dc.sh From efc2860e920bb64b57bbda0bad0a13452bfa5cdf Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:22:44 -0700 Subject: [PATCH 25/40] Rename + revert some unnecessary changes --- open_instruct/data_loader.py | 42 +++------ open_instruct/dataset_transformation.py | 3 +- ...curriculum.py => difficulty_curriculum.py} | 0 open_instruct/environments/tools/parsers.py | 2 +- open_instruct/grpo_fast.py | 10 +-- open_instruct/model_utils.py | 8 -- open_instruct/rl_utils.py | 88 +++---------------- ...culum.py => test_difficulty_curriculum.py} | 28 +++--- open_instruct/test_rollout_traces.py | 84 ------------------ 9 files changed, 45 insertions(+), 220 deletions(-) rename open_instruct/{rlvr_curriculum.py => difficulty_curriculum.py} (100%) rename open_instruct/{test_rlvr_curriculum.py => test_difficulty_curriculum.py} (88%) delete mode 100644 open_instruct/test_rollout_traces.py diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 4e15c81324..7414990fa9 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -33,15 +33,13 @@ from tqdm import 
tqdm from transformers import PreTrainedTokenizer -from open_instruct import data_types, padding_free_collator, rlvr_curriculum, utils +from open_instruct import data_types, difficulty_curriculum, padding_free_collator, utils from open_instruct.data_types import EnvConfig, EnvConfigEntry from open_instruct.dataset_transformation import ( - DATASET_ORIGIN_KEY, ENV_CONFIG_KEY, GROUND_TRUTHS_KEY, INPUT_IDS_PROMPT_KEY, RAW_PROMPT_KEY, - SOURCE_ROW_ID_KEY, TOOLS_COLUMN_KEY, VERIFIER_SOURCE_KEY, ) @@ -507,8 +505,6 @@ class StreamingDataLoaderConfig: # Rollout saving save_traces: bool = False rollouts_save_path: str = "/weka/oe-adapt-default/allennlp/deletable_rollouts/" - rollout_save_format: Literal["full", "scores_only"] = "full" - """Trace record shape to persist when save_traces is enabled.""" # Computed at post_init max_possible_score: float = 1.0 @@ -691,7 +687,7 @@ def __init__( dp_rank: int, dp_world_size: int, work_dir: str, - curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig, automatic_reshuffle: bool = True, collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None, device: torch.device | None = None, @@ -705,7 +701,7 @@ def __init__( raise ValueError("DifficultyCurriculumHFDataLoader currently supports dp_world_size=1 only") self._sampling_step = 0 - self._curriculum_sampler = rlvr_curriculum.DifficultyCurriculumSampler( + self._curriculum_sampler = difficulty_curriculum.DifficultyCurriculumSampler( dataset=dataset, num_samples=max(len(dataset), 1), config=curriculum_config, @@ -729,7 +725,7 @@ def __init__( ) @property - def curriculum_sampler(self) -> rlvr_curriculum.DifficultyCurriculumSampler: + def curriculum_sampler(self) -> difficulty_curriculum.DifficultyCurriculumSampler: return self._curriculum_sampler def set_sampling_step(self, step: int) -> None: @@ -789,14 +785,14 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: yield batch -def 
build_data_preparation_prompt_dataloader( +def create_prompt_dataloader( dataset: Dataset, seed: int, work_dir: str, - curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None = None, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None = None, ) -> HFDataLoader: - if curriculum_config is not None: - return DifficultyCurriculumHFDataLoader( + if curriculum_config is None: + return HFDataLoader( dataset=dataset, batch_size=1, seed=seed, @@ -805,10 +801,8 @@ def build_data_preparation_prompt_dataloader( work_dir=work_dir, automatic_reshuffle=True, collator=single_example_collator, - curriculum_config=curriculum_config, ) - - return HFDataLoader( + return DifficultyCurriculumHFDataLoader( dataset=dataset, batch_size=1, seed=seed, @@ -817,6 +811,7 @@ def build_data_preparation_prompt_dataloader( work_dir=work_dir, automatic_reshuffle=True, collator=single_example_collator, + curriculum_config=curriculum_config, ) @@ -940,8 +935,6 @@ def accumulate_inference_batches( all_active_tools = [] all_scores = [] all_indices = [] - all_source_row_ids = [] - all_source_datasets = [] all_percent_solved = [] all_model_steps = [] total_filtered_prompts = 0 @@ -993,8 +986,6 @@ def accumulate_inference_batches( ground_truth = example[GROUND_TRUTHS_KEY] dataset_name = example[VERIFIER_SOURCE_KEY] raw_query = example[RAW_PROMPT_KEY] - source_row_id = example.get(SOURCE_ROW_ID_KEY) - source_dataset = example.get(DATASET_ORIGIN_KEY) sample_active_tools = example.get(TOOLS_COLUMN_KEY) if replenish_prompts: @@ -1025,8 +1016,6 @@ def accumulate_inference_batches( k_raw_queries = repeat_each([raw_query], generation_config.n) k_active_tools = repeat_each([sample_active_tools], generation_config.n) k_indices = repeat_each([result.index], generation_config.n) - k_source_row_ids = repeat_each([source_row_id], generation_config.n) - k_source_datasets = repeat_each([source_dataset], generation_config.n) percent_solved = np.mean(result.reward_scores).item() / 
max_possible_score if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: @@ -1064,8 +1053,6 @@ def accumulate_inference_batches( all_raw_queries.extend(k_raw_queries) all_active_tools.extend(k_active_tools) all_indices.extend(k_indices) - all_source_row_ids.extend(k_source_row_ids) - all_source_datasets.extend(k_source_datasets) all_decoded_responses.extend(decoded_responses) all_scores.extend(result.reward_scores) all_reward_metrics.append(result.reward_metrics) @@ -1168,8 +1155,6 @@ def accumulate_inference_batches( indices=all_indices, scores=all_scores, active_tools=all_active_tools if all_active_tools else None, - source_row_ids=all_source_row_ids, - source_datasets=all_source_datasets, model_steps=all_model_steps, ) @@ -1315,7 +1300,7 @@ def __init__( model_name: str | None, base_env_config: EnvConfig, initial_state: dict | None = None, - curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None = None, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None = None, ): self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q @@ -1336,7 +1321,7 @@ def __init__( self.model_name = model_name self.base_env_config = base_env_config - self.iter_dataloader = build_data_preparation_prompt_dataloader( + self.iter_dataloader = create_prompt_dataloader( dataset=dataset, seed=seed, work_dir=work_dir, curriculum_config=curriculum_config ) @@ -1447,7 +1432,6 @@ def _data_preparation_loop(self): self.prepared_data[step] = empty_data self.metrics[step] = empty_metrics self.current_prepared_step = step - self.training_step = step + 1 continue assert batch is not None @@ -1492,7 +1476,6 @@ def _data_preparation_loop(self): advantages, self.config.num_samples_per_prompt_rollout, self.total_samples_written, - record_format=self.config.rollout_save_format, ) self.total_samples_written += len(batch.queries) @@ -1610,7 +1593,6 @@ def _data_preparation_loop(self): self.prepared_data[step] = 
collated_data self.metrics[step] = step_metrics self.current_prepared_step = step - self.training_step = step + 1 def get_data(self, rank: int, step: int) -> dict: """Called by each rank's StreamingDataLoader. Blocks until data ready.""" diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 193116e1bf..b123ca87e6 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -945,8 +945,7 @@ def remove_dataset_source_field(dataset: Dataset) -> Dataset: # Cache version: increment this when transformation logic changes significantly # to invalidate old caches. v6: Added return_dict=False to apply_chat_template calls for transformers 5.x. -# v7: Preserve original source row ids in transformed datasets for rollout trace joins. -DATASET_CACHE_VERSION = "v7" +DATASET_CACHE_VERSION = "v6" def _normalize_env_config_column(row: dict[str, Any]) -> None: diff --git a/open_instruct/rlvr_curriculum.py b/open_instruct/difficulty_curriculum.py similarity index 100% rename from open_instruct/rlvr_curriculum.py rename to open_instruct/difficulty_curriculum.py diff --git a/open_instruct/environments/tools/parsers.py b/open_instruct/environments/tools/parsers.py index e9f81fbb2b..a8f04486b4 100644 --- a/open_instruct/environments/tools/parsers.py +++ b/open_instruct/environments/tools/parsers.py @@ -154,7 +154,7 @@ def _make_request(self) -> Any: Usually these only need the list of tools. """ - return ChatCompletionRequest(model="dummy", messages=[], tools=self._tool_definitions) + return ChatCompletionRequest(model="dummy", messages=[], tools=self._tool_definitions) # ty: ignore[invalid-argument-type] def get_tool_calls(self, text: str) -> list[EnvCall]: """Extract tool calls from model output. 
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 894ced3845..4fda0a813a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -40,7 +40,7 @@ from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF from deepspeed.utils import groups -from open_instruct import data_loader as data_loader_lib, rlvr_curriculum +from open_instruct import data_loader as data_loader_lib, difficulty_curriculum from open_instruct import data_types, grpo_utils, utils from open_instruct.data_loader import accumulate_inference_batches, add_prompt_to_generator from open_instruct.data_types import EnvConfig, EnvConfigEntry @@ -1300,7 +1300,7 @@ def create_model_and_optimizer( reward_config: RewardConfig, generation_config, base_env_config: EnvConfig, - curriculum_config: rlvr_curriculum.DifficultyCurriculumConfig | None, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None, tool_definitions: list[dict[str, Any]] | None = None, tools_config: EnvsConfig | None = None, pools: dict[str, ray.actor.ActorHandle] | None = None, @@ -2397,7 +2397,7 @@ def main( streaming_config: data_loader_lib.StreamingDataLoaderConfig, vllm_config: data_loader_lib.VLLMConfig, tools_config: EnvsConfig, - curriculum_args: rlvr_curriculum.DifficultyCurriculumArgs, + curriculum_args: difficulty_curriculum.DifficultyCurriculumArgs, ): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args, streaming_config, tools_config) @@ -2598,7 +2598,7 @@ def main( data_loader_lib.StreamingDataLoaderConfig, data_loader_lib.VLLMConfig, EnvsConfig, - rlvr_curriculum.DifficultyCurriculumArgs, + difficulty_curriculum.DifficultyCurriculumArgs, ) ) parser.set_defaults(exp_name="grpo", warmup_ratio=0.0, max_grad_norm=1.0, per_device_train_batch_size=1) @@ -2611,6 +2611,6 @@ def main( assert isinstance(streaming_config, data_loader_lib.StreamingDataLoaderConfig) assert isinstance(vllm_config, data_loader_lib.VLLMConfig) assert 
isinstance(tools_config, EnvsConfig) - assert isinstance(curriculum_args, rlvr_curriculum.DifficultyCurriculumArgs) + assert isinstance(curriculum_args, difficulty_curriculum.DifficultyCurriculumArgs) main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum_args) diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index eaed9395a9..6a1ac44eb1 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -132,8 +132,6 @@ class Batch: indices: list[int] | None scores: list[float] | None active_tools: list[list[str] | None] | None = None - source_row_ids: list[int | None] | None = None - source_datasets: list[str | None] | None = None model_steps: list[int] = field(default_factory=list) def __getitem__(self, key: slice | int | list[int]) -> "Batch": @@ -149,8 +147,6 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": indices=self.indices[key] if self.indices is not None else None, scores=self.scores[key] if self.scores is not None else None, active_tools=self.active_tools[key] if self.active_tools is not None else None, - source_row_ids=self.source_row_ids[key] if self.source_row_ids is not None else None, - source_datasets=self.source_datasets[key] if self.source_datasets is not None else None, model_steps=self.model_steps[key], ) elif isinstance(key, int): @@ -164,8 +160,6 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": indices=[self.indices[key]] if self.indices is not None else None, scores=[self.scores[key]] if self.scores is not None else None, active_tools=[self.active_tools[key]] if self.active_tools is not None else None, - source_row_ids=[self.source_row_ids[key]] if self.source_row_ids is not None else None, - source_datasets=[self.source_datasets[key]] if self.source_datasets is not None else None, model_steps=[self.model_steps[key]], ) else: @@ -181,8 +175,6 @@ def __getitem__(self, key: slice | int | list[int]) -> "Batch": 
indices=[self.indices[i] for i in key] if self.indices is not None else None, scores=[self.scores[i] for i in key] if self.scores is not None else None, active_tools=[self.active_tools[i] for i in key] if self.active_tools is not None else None, - source_row_ids=[self.source_row_ids[i] for i in key] if self.source_row_ids is not None else None, - source_datasets=[self.source_datasets[i] for i in key] if self.source_datasets is not None else None, model_steps=[self.model_steps[i] for i in key], ) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 3c79b728db..c7161b3320 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -5,7 +5,7 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass, field -from typing import Any, Generic, Literal, TypeVar +from typing import Generic, TypeVar import numpy as np import torch @@ -17,7 +17,6 @@ _rollout_executor = ThreadPoolExecutor(max_workers=2) ROLLOUT_SHARD_SIZE = 10000 -RolloutSaveFormat = Literal["full", "scores_only"] @dataclass @@ -26,7 +25,6 @@ class RolloutMetadata: git_commit: str model_name: str timestamp: str - experiment_id: str | None = None @dataclass @@ -41,25 +39,15 @@ class RolloutRecord: finish_reason: str dataset: str ground_truth: list[int] | None = None - source_row_id: int | None = None - source_dataset: str | None = None request_info: dict | None = None logprobs: list[float] | None = None -@dataclass -class RolloutScoreRecord: - dataset: str - reward: float - source_row_id: int | None = None - source_dataset: str | None = None - - def save_rollout_metadata(save_path: str, run_name: str, model_name: str | None) -> None: """Save metadata about the rollout collection to disk. Creates a JSONL file containing run information including git commit, - model name, runtime experiment id, and timestamp for traceability. + model name, and timestamp for traceability. Args: save_path: Directory to save metadata file. 
@@ -70,7 +58,6 @@ def save_rollout_metadata(save_path: str, run_name: str, model_name: str | None) run_name=run_name, git_commit=utils.get_git_commit(), model_name=model_name or "unknown", - experiment_id=os.environ.get("BEAKER_WORKLOAD_ID") or None, timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(), ) metadata_path = os.path.join(save_path, f"{run_name}_metadata.jsonl") @@ -98,33 +85,24 @@ def _get_request_info_for_sample(request_info: data_types.RequestInfo | None, i: } -def build_rollout_records( +def _save_rollouts( + save_path: str, + run_name: str, + step: int, batch: model_utils.Batch, result: data_types.GenerationResult, advantages: np.ndarray, - *, - step: int, num_samples_per_prompt: int, - record_format: RolloutSaveFormat = "full", -) -> list[dict[str, Any]]: - """Build JSON-serializable rollout records for persistence.""" + shard_idx: int, +) -> None: + shard_filename = f"{run_name}_rollouts_{shard_idx:06d}.jsonl" + filepath = os.path.join(save_path, shard_filename) + os.makedirs(save_path, exist_ok=True) + assert batch.scores is not None, "batch.scores must not be None when saving rollouts" records = [] for i in range(len(batch.queries)): - if record_format == "scores_only": - records.append( - asdict( - RolloutScoreRecord( - dataset=batch.datasets[i], - reward=float(batch.scores[i]), - source_row_id=batch.source_row_ids[i] if batch.source_row_ids is not None else None, - source_dataset=batch.source_datasets[i] if batch.source_datasets is not None else None, - ) - ) - ) - continue - records.append( asdict( RolloutRecord( @@ -138,40 +116,12 @@ def build_rollout_records( finish_reason=result.finish_reasons[i], dataset=batch.datasets[i], ground_truth=batch.ground_truths[i], - source_row_id=batch.source_row_ids[i] if batch.source_row_ids is not None else None, - source_dataset=batch.source_datasets[i] if batch.source_datasets is not None else None, request_info=_get_request_info_for_sample(result.request_info, i), 
logprobs=result.logprobs[i] if result.logprobs else None, ) ) ) - return records - - -def _save_rollouts( - save_path: str, - run_name: str, - step: int, - batch: model_utils.Batch, - result: data_types.GenerationResult, - advantages: np.ndarray, - num_samples_per_prompt: int, - shard_idx: int, - record_format: RolloutSaveFormat, -) -> None: - shard_filename = f"{run_name}_rollouts_{shard_idx:06d}.jsonl" - filepath = os.path.join(save_path, shard_filename) - os.makedirs(save_path, exist_ok=True) - records = build_rollout_records( - batch, - result, - advantages, - step=step, - num_samples_per_prompt=num_samples_per_prompt, - record_format=record_format, - ) - with open(filepath, "a") as f: for record in records: f.write(json.dumps(record) + "\n") @@ -187,7 +137,6 @@ def save_rollouts_to_disk( advantages: np.ndarray, num_samples_per_prompt: int, total_samples_written: int, - record_format: RolloutSaveFormat = "full", ) -> None: """Asynchronously save rollout records to disk. @@ -203,21 +152,10 @@ def save_rollouts_to_disk( advantages: Calculated advantage values per sample. num_samples_per_prompt: Number of samples generated per prompt. total_samples_written: Total samples written so far, used for sharding. - record_format: Output schema to persist. Use "scores_only" for the - minimum fields needed by the difficulty-map builder. 
""" shard_idx = total_samples_written // ROLLOUT_SHARD_SIZE _rollout_executor.submit( - _save_rollouts, - save_path, - run_name, - step, - batch, - result, - advantages, - num_samples_per_prompt, - shard_idx, - record_format, + _save_rollouts, save_path, run_name, step, batch, result, advantages, num_samples_per_prompt, shard_idx ) diff --git a/open_instruct/test_rlvr_curriculum.py b/open_instruct/test_difficulty_curriculum.py similarity index 88% rename from open_instruct/test_rlvr_curriculum.py rename to open_instruct/test_difficulty_curriculum.py index cd5e8aa767..64e0522578 100644 --- a/open_instruct/test_rlvr_curriculum.py +++ b/open_instruct/test_difficulty_curriculum.py @@ -11,7 +11,7 @@ vllm_stub.SamplingParams = object sys.modules["vllm"] = vllm_stub -from open_instruct import data_loader, rlvr_curriculum +from open_instruct import data_loader, difficulty_curriculum class ListDataset: @@ -57,19 +57,19 @@ def make_plain_hf_dataset(num_examples: int) -> Dataset: class TestDifficultyCurriculumSampler(unittest.TestCase): - def _make_metadata(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumMetadataConfig: - return rlvr_curriculum.DifficultyCurriculumMetadataConfig(**overrides) + def _make_metadata(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumMetadataConfig: + return difficulty_curriculum.DifficultyCurriculumMetadataConfig(**overrides) - def _make_schedule(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumScheduleConfig: - return rlvr_curriculum.DifficultyCurriculumScheduleConfig( + def _make_schedule(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumScheduleConfig: + return difficulty_curriculum.DifficultyCurriculumScheduleConfig( bootstrap_steps=100, warmup_steps=120, total_steps=200, **overrides ) - def _make_adaptive(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumAdaptiveConfig: - return rlvr_curriculum.DifficultyCurriculumAdaptiveConfig(**overrides) + def _make_adaptive(self, **overrides) -> 
difficulty_curriculum.DifficultyCurriculumAdaptiveConfig: + return difficulty_curriculum.DifficultyCurriculumAdaptiveConfig(**overrides) - def _make_config(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumConfig: - return rlvr_curriculum.DifficultyCurriculumConfig( + def _make_config(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumConfig: + return difficulty_curriculum.DifficultyCurriculumConfig( metadata=overrides.pop("metadata", self._make_metadata()), schedule=overrides.pop("schedule", self._make_schedule()), adaptive=overrides.pop("adaptive", self._make_adaptive()), @@ -77,9 +77,9 @@ def _make_config(self, **overrides) -> rlvr_curriculum.DifficultyCurriculumConfi **overrides, ) - def _make_sampler(self, dataset, **config_overrides) -> rlvr_curriculum.DifficultyCurriculumSampler: + def _make_sampler(self, dataset, **config_overrides) -> difficulty_curriculum.DifficultyCurriculumSampler: config = self._make_config(**config_overrides) - return rlvr_curriculum.DifficultyCurriculumSampler( + return difficulty_curriculum.DifficultyCurriculumSampler( dataset=dataset, num_samples=max(len(dataset), 1), config=config, global_step_getter=lambda: 0 ) @@ -169,7 +169,7 @@ def test_bootstrap_distribution_is_tunable(self): self.assertLess(tuned_probs[2], default_probs[2]) def test_curriculum_args_parser_builds_grouped_config(self): - parser = HfArgumentParser((rlvr_curriculum.DifficultyCurriculumArgs,)) + parser = HfArgumentParser((difficulty_curriculum.DifficultyCurriculumArgs,)) (curriculum_args,) = parser.parse_args_into_dataclasses( [ "--curriculum", @@ -202,9 +202,7 @@ def test_curriculum_args_parser_builds_grouped_config(self): class TestDifficultyCurriculumLoaderIntegration(unittest.TestCase): def test_existing_behavior_is_unchanged_when_curriculum_disabled(self): dataset = make_plain_hf_dataset(20) - built_loader = data_loader.build_data_preparation_prompt_dataloader( - dataset=dataset, seed=7, work_dir=tempfile.gettempdir(), 
curriculum_config=None - ) + built_loader = data_loader.create_prompt_dataloader(dataset=dataset, seed=7, work_dir=tempfile.gettempdir()) baseline_loader = data_loader.HFDataLoader( dataset=dataset, batch_size=1, diff --git a/open_instruct/test_rollout_traces.py b/open_instruct/test_rollout_traces.py deleted file mode 100644 index ea0947b634..0000000000 --- a/open_instruct/test_rollout_traces.py +++ /dev/null @@ -1,84 +0,0 @@ -import unittest - -import numpy as np - -from open_instruct import model_utils, rl_utils -from open_instruct.data_types import GenerationResult, RequestInfo - - -class TestRolloutRecords(unittest.TestCase): - def _make_result(self) -> GenerationResult: - return GenerationResult( - responses=[[10, 11], [12, 13]], - finish_reasons=["stop", "length"], - masks=[[1, 1], [1, 1]], - request_info=RequestInfo( - num_calls=[0, 1], - timeouts=[0, 0], - tool_errors=["", ""], - tool_outputs=["", "ok"], - tool_runtimes=[0.0, 0.1], - tool_calleds=[False, True], - tool_call_stats=[[], []], - rollout_states=[{}, {"done": False}], - ), - index=3, - prompt_id="prompt_3", - logprobs=[[0.1, 0.2], [0.3, 0.4]], - model_step=7, - ) - - def _make_batch(self) -> model_utils.Batch: - return model_utils.Batch( - queries=[[1, 2, 3], [1, 2, 3]], - ground_truths=[[4], [4]], - datasets=["math", "math"], - raw_queries=["user: solve 2+2", "user: solve 2+2"], - decoded_responses=None, - indices=[3, 3], - scores=[10.0, 0.0], - source_row_ids=[11, 11], - source_datasets=["demo", "demo"], - model_steps=[7, 7], - ) - - def test_build_rollout_records_full_format(self): - records = rl_utils.build_rollout_records( - self._make_batch(), - self._make_result(), - np.array([5.0, -5.0]), - step=9, - num_samples_per_prompt=2, - record_format="full", - ) - - self.assertEqual(len(records), 2) - self.assertEqual(records[0]["step"], 9) - self.assertEqual(records[0]["prompt_idx"], 0) - self.assertEqual(records[0]["source_row_id"], 11) - self.assertEqual(records[0]["source_dataset"], "demo") - 
self.assertEqual(records[0]["response_tokens"], [10, 11]) - self.assertEqual(records[1]["finish_reason"], "length") - self.assertEqual(records[1]["request_info"]["tool_outputs"], "ok") - - def test_build_rollout_records_scores_only_format(self): - records = rl_utils.build_rollout_records( - self._make_batch(), - self._make_result(), - np.array([5.0, -5.0]), - step=9, - num_samples_per_prompt=2, - record_format="scores_only", - ) - - self.assertEqual( - records, - [ - {"dataset": "math", "reward": 10.0, "source_row_id": 11, "source_dataset": "demo"}, - {"dataset": "math", "reward": 0.0, "source_row_id": 11, "source_dataset": "demo"}, - ], - ) - - -if __name__ == "__main__": - unittest.main() From 1eebbbc62080ac5834456d9c6033301d55a82bd0 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:32:19 -0700 Subject: [PATCH 26/40] More cleanup --- open_instruct/dataset_transformation.py | 4 - open_instruct/difficulty_curriculum.py | 84 +++++++++++++-------- open_instruct/test_difficulty_curriculum.py | 15 ++++ 3 files changed, 67 insertions(+), 36 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index b123ca87e6..0d617b7ed7 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -869,7 +869,6 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"): GROUND_TRUTHS_KEY = "ground_truth" VERIFIER_SOURCE_KEY = "dataset" RAW_PROMPT_KEY = "prompt" -SOURCE_ROW_ID_KEY = "source_row_id" @dataclass @@ -1631,8 +1630,6 @@ def __post_init__(self): num_proc=max_num_processes(), ) assert isinstance(dataset, Dataset), f"Expected Dataset, got {type(dataset)}" - if SOURCE_ROW_ID_KEY not in dataset.column_names: - dataset = dataset.add_column(SOURCE_ROW_ID_KEY, range(len(dataset))) self.dataset = dataset if self.dataset_range is None: dataset_range = len(self.dataset) @@ -1734,7 +1731,6 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): target_columns = 
dataset.column_names if dc.target_columns is None else dc.target_columns # Always preserve dataset_source if it exists target_columns = _preserve_column(DATASET_ORIGIN_KEY, dataset, target_columns) - target_columns = _preserve_column(SOURCE_ROW_ID_KEY, dataset, target_columns) target_columns = _preserve_column(TOOLS_COLUMN_KEY, dataset, target_columns) target_columns = _preserve_column(ENV_CONFIG_KEY, dataset, target_columns) diff --git a/open_instruct/difficulty_curriculum.py b/open_instruct/difficulty_curriculum.py index 60817457d6..ac71daf418 100644 --- a/open_instruct/difficulty_curriculum.py +++ b/open_instruct/difficulty_curriculum.py @@ -226,12 +226,12 @@ def update( if advantages is not None and len(advantages) != len(rewards): raise ValueError("advantages and rewards must have the same length") - reward_values = [float(np.clip(value, 0.0, 1.0)) for value in rewards] - advantage_values = None if advantages is None else [abs(float(value)) for value in advantages] + reward_values = np.clip(np.asarray(rewards, dtype=np.float64), 0.0, 1.0) + advantage_values = None if advantages is None else np.abs(np.asarray(advantages, dtype=np.float64)) for position, bucket_index in enumerate(bucket_indices): bucket = int(bucket_index) - reward = reward_values[position] + reward = float(reward_values[position]) self.total_count += 1 self._count_by_bucket[bucket] = self._count_by_bucket.get(bucket, 0) + 1 @@ -239,7 +239,7 @@ def update( self._reward_sq_sum_by_bucket[bucket] = self._reward_sq_sum_by_bucket.get(bucket, 0.0) + reward * reward if advantage_values is not None: - advantage = advantage_values[position] + advantage = float(advantage_values[position]) self._abs_advantage_sum_by_bucket[bucket] = ( self._abs_advantage_sum_by_bucket.get(bucket, 0.0) + advantage ) @@ -329,6 +329,7 @@ class _ParsedDifficultyMetadata: @dataclass(frozen=True) class _DifficultyBucketIndex: index_to_bucket: dict[int, int] + index_to_bucket_position: dict[int, int] bucket_to_indices: 
tuple[tuple[int, ...], ...] bucket_weights: tuple[torch.Tensor, ...] bucket_count: int @@ -478,6 +479,7 @@ def _build_difficulty_bucket_index( bucket_to_indices: list[list[int]] = [[] for _ in range(bucket_count)] bucket_weight_lists: list[list[float]] = [[] for _ in range(bucket_count)] index_to_bucket: dict[int, int] = {} + index_to_bucket_position: dict[int, int] = {} metadata_fallback_count = 0 fallback_bucket = min(bucket_count - 1, bucket_count // 2) @@ -500,11 +502,13 @@ def _build_difficulty_bucket_index( ) index_to_bucket[dataset_index] = bucket_index + index_to_bucket_position[dataset_index] = len(bucket_to_indices[bucket_index]) bucket_to_indices[bucket_index].append(dataset_index) bucket_weight_lists[bucket_index].append(example_weight) return _DifficultyBucketIndex( index_to_bucket=index_to_bucket, + index_to_bucket_position=index_to_bucket_position, bucket_to_indices=tuple(tuple(indices) for indices in bucket_to_indices), bucket_weights=tuple(torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists), bucket_count=bucket_count, @@ -540,6 +544,7 @@ def __init__( epsilon=self.config.epsilon, ) self._index_to_bucket = dict(bucket_index.index_to_bucket) + self._index_to_bucket_position = dict(bucket_index.index_to_bucket_position) self.bucket_count = bucket_index.bucket_count self.metadata_fallback_count = bucket_index.metadata_fallback_count self._schedule = _DifficultyCurriculumSchedule(self.config.schedule, self.bucket_count) @@ -547,8 +552,9 @@ def __init__( self._excluded_indices: set[int] = set() self._base_bucket_indices = [list(indices) for indices in bucket_index.bucket_to_indices] self._base_bucket_weights = [weights.clone() for weights in bucket_index.bucket_weights] - self._active_bucket_indices = [list(indices) for indices in self._base_bucket_indices] self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] + self._active_bucket_counts = [len(indices) for indices in 
self._base_bucket_indices] + self._active_bucket_weight_sums = [float(weights.sum().item()) for weights in self._active_bucket_weights] self.adaptive_stats = None if self.config.adaptive.enabled: @@ -586,7 +592,7 @@ def get_progress(self, step: int | None = None) -> float: return self._schedule.get_progress(step) def _available_bucket_mask(self) -> np.ndarray: - return np.array([1.0 if indices else 0.0 for indices in self._active_bucket_indices], dtype=np.float64) + return np.array([1.0 if count > 0 else 0.0 for count in self._active_bucket_counts], dtype=np.float64) def get_static_bucket_probs(self, step: int | None = None) -> np.ndarray: if step is None: @@ -632,19 +638,20 @@ def get_example_probability(self, dataset_index: int, step: int | None = None) - if int(dataset_index) in self._excluded_indices: return 0.0 bucket_index = self.bucket_for_dataset_index(dataset_index) - active_indices = self._active_bucket_indices[bucket_index] - if not active_indices: + if self._active_bucket_counts[bucket_index] == 0: return 0.0 - try: - local_index = active_indices.index(int(dataset_index)) - except ValueError: + local_index = self._index_to_bucket_position.get(int(dataset_index)) + if local_index is None: return 0.0 bucket_weight = self._active_bucket_weights[bucket_index] - weight_total = float(bucket_weight.sum().item()) + dataset_weight = float(bucket_weight[local_index].item()) + if dataset_weight <= 0: + return 0.0 + weight_total = self._active_bucket_weight_sums[bucket_index] if weight_total <= 0: return 0.0 bucket_probs = self.get_bucket_probs(step) - return float(bucket_probs[bucket_index] * bucket_weight[local_index].item() / weight_total) + return float(bucket_probs[bucket_index] * dataset_weight / weight_total) def sample_index(self, step: int | None = None) -> int: if self._available_bucket_mask().sum() == 0: @@ -653,11 +660,11 @@ def sample_index(self, step: int | None = None) -> int: bucket_index = int(torch.multinomial(bucket_probs, 1, 
generator=self._generator).item()) example_weights = self._active_bucket_weights[bucket_index] - if example_weights.numel() == 0: + if self._active_bucket_counts[bucket_index] == 0 or self._active_bucket_weight_sums[bucket_index] <= 0: raise RuntimeError("attempted to sample from an empty curriculum bucket") sampled_offset = int(torch.multinomial(example_weights, 1, generator=self._generator).item()) - return self._active_bucket_indices[bucket_index][sampled_offset] + return self._base_bucket_indices[bucket_index][sampled_offset] def __iter__(self): for _ in range(self.num_samples): @@ -672,19 +679,22 @@ def exclude_index(self, dataset_index: int) -> None: if bucket_index is None: return - active_indices = self._active_bucket_indices[bucket_index] - try: - position = active_indices.index(dataset_index) - except ValueError: + position = self._index_to_bucket_position.get(dataset_index) + if position is None: self._excluded_indices.add(dataset_index) return - active_indices.pop(position) weights = self._active_bucket_weights[bucket_index] - if weights.numel() <= 1: - self._active_bucket_weights[bucket_index] = weights[:0].clone() - else: - self._active_bucket_weights[bucket_index] = torch.cat((weights[:position], weights[position + 1 :])) + current_weight = float(weights[position].item()) + if current_weight <= 0: + self._excluded_indices.add(dataset_index) + return + + weights[position] = 0.0 + self._active_bucket_counts[bucket_index] -= 1 + self._active_bucket_weight_sums[bucket_index] = max( + 0.0, self._active_bucket_weight_sums[bucket_index] - current_weight + ) self._excluded_indices.add(dataset_index) def record_observations( @@ -754,14 +764,24 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._generator.set_state(generator_state) self._excluded_indices = {int(index) for index in state_dict.get("excluded_indices", [])} - self._active_bucket_indices = [] - self._active_bucket_weights = [] - for base_indices, base_weights in 
zip(self._base_bucket_indices, self._base_bucket_weights, strict=True): - keep_positions = [ - position for position, index in enumerate(base_indices) if index not in self._excluded_indices - ] - self._active_bucket_indices.append([base_indices[position] for position in keep_positions]) - self._active_bucket_weights.append(base_weights[keep_positions].clone()) + self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] + self._active_bucket_counts = [len(indices) for indices in self._base_bucket_indices] + self._active_bucket_weight_sums = [float(weights.sum().item()) for weights in self._active_bucket_weights] + for dataset_index in self._excluded_indices: + bucket_index = self._index_to_bucket.get(dataset_index) + position = self._index_to_bucket_position.get(dataset_index) + if bucket_index is None or position is None: + continue + + current_weight = float(self._active_bucket_weights[bucket_index][position].item()) + if current_weight <= 0: + continue + + self._active_bucket_weights[bucket_index][position] = 0.0 + self._active_bucket_counts[bucket_index] -= 1 + self._active_bucket_weight_sums[bucket_index] = max( + 0.0, self._active_bucket_weight_sums[bucket_index] - current_weight + ) if self.adaptive_stats is not None and state_dict.get("adaptive_stats") is not None: self.adaptive_stats.load_state_dict(state_dict["adaptive_stats"]) diff --git a/open_instruct/test_difficulty_curriculum.py b/open_instruct/test_difficulty_curriculum.py index 64e0522578..913758cc08 100644 --- a/open_instruct/test_difficulty_curriculum.py +++ b/open_instruct/test_difficulty_curriculum.py @@ -155,6 +155,21 @@ def test_adaptive_stats_increase_sampling_probability_for_high_signal_bucket(sel adaptive_probs = sampler.get_bucket_probs(step=1) self.assertGreater(adaptive_probs[4], static_probs[4]) + def test_excluded_index_has_zero_probability_and_persists_after_restore(self): + dataset = make_bucket_dataset() + sampler = self._make_sampler(dataset) + 
sampler.exclude_index(4) + + self.assertEqual(sampler.get_example_probability(4, step=0), 0.0) + self.assertEqual(float(sampler.get_bucket_probs(step=0)[4]), 0.0) + self.assertTrue(all(sampler.sample_index(step=0) != 4 for _ in range(20))) + + restored_sampler = self._make_sampler(dataset) + restored_sampler.load_state_dict(sampler.state_dict()) + self.assertEqual(restored_sampler.get_example_probability(4, step=0), 0.0) + self.assertEqual(float(restored_sampler.get_bucket_probs(step=0)[4]), 0.0) + self.assertTrue(all(restored_sampler.sample_index(step=0) != 4 for _ in range(20))) + def test_bootstrap_distribution_is_tunable(self): default_sampler = self._make_sampler(make_bucket_dataset()) tuned_sampler = self._make_sampler( From 8c1e55e99657c5478e2139fc469522c29a54dd41 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:37:49 -0700 Subject: [PATCH 27/40] More cruft --- open_instruct/grpo_fast.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4fda0a813a..5ec04c9777 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1180,7 +1180,6 @@ def setup_datasets( dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, system_prompt_override=system_prompt_override, - drop_dataset_source=not streaming_config.save_traces, ) _validate_and_log_dataset_tools(train_dataset, configured_tool_call_names, "train_dataset") @@ -1198,7 +1197,6 @@ def setup_datasets( dataset_local_cache_dir=streaming_config.dataset_local_cache_dir, dataset_skip_cache=streaming_config.dataset_skip_cache, system_prompt_override=system_prompt_override, - drop_dataset_source=not streaming_config.save_traces, ) _validate_and_log_dataset_tools(eval_dataset, configured_tool_call_names, "eval_dataset") From 3c4bfaa1c4781254ff963cea1d97140478519505 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:48:57 -0700 Subject: [PATCH 
28/40] Drop noop in favor of conditional --- open_instruct/data_loader.py | 59 +++++++++++------------------ tests/test_create_difficulty_map.py | 4 +- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 7414990fa9..4722af757b 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -248,21 +248,6 @@ def exclude_index(self, index: int) -> None: """ self._excluded_indices.add(index) - def set_sampling_step(self, step: int) -> None: - del step - - def record_curriculum_observations( - self, - dataset_indices: list[int] | np.ndarray, - rewards: list[float] | np.ndarray, - advantages: list[float] | np.ndarray | None = None, - ) -> None: - del dataset_indices, rewards, advantages - - def build_curriculum_metrics(self, prompt_dataset_indices: list[int], step: int) -> dict[str, float]: - del prompt_dataset_indices, step - return {} - def reshuffle(self, epoch: int | None = None, **kwargs: Any) -> None: """Reshuffle and reshard the dataset for a new epoch. 
@@ -731,17 +716,6 @@ def curriculum_sampler(self) -> difficulty_curriculum.DifficultyCurriculumSample def set_sampling_step(self, step: int) -> None: self._sampling_step = int(step) - def record_curriculum_observations( - self, - dataset_indices: list[int] | np.ndarray, - rewards: list[float] | np.ndarray, - advantages: list[float] | np.ndarray | None = None, - ) -> None: - self._curriculum_sampler.record_observations(dataset_indices, rewards, advantages) - - def build_curriculum_metrics(self, prompt_dataset_indices: list[int], step: int) -> dict[str, float]: - return self._curriculum_sampler.build_metrics(prompt_dataset_indices, step) - def state_dict(self) -> dict[str, Any]: return { "epoch": self._epoch, @@ -1324,6 +1298,9 @@ def __init__( self.iter_dataloader = create_prompt_dataloader( dataset=dataset, seed=seed, work_dir=work_dir, curriculum_config=curriculum_config ) + self.curriculum_dataloader = ( + self.iter_dataloader if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader) else None + ) self.prepared_data: dict[int, list[data_types.CollatedBatchData]] = {} self.metrics: dict[int, dict] = {} @@ -1362,7 +1339,8 @@ def _data_preparation_loop(self): num_initial_prompts = self.config.async_steps * self.global_batch_size logger.info(f"[DataPreparationActor] Pushing {num_initial_prompts} initial prompts to param_prompt_Q") - self.iter_dataloader.set_sampling_step(self.training_step) + if self.curriculum_dataloader is not None: + self.curriculum_dataloader.set_sampling_step(self.training_step) for _ in range(num_initial_prompts): add_prompt_to_generator( next(self.iter_dataloader), @@ -1382,7 +1360,8 @@ def _data_preparation_loop(self): ) time.sleep(0.1) generation_idle_wait_time = time.perf_counter() - generation_idle_wait_start_time - self.iter_dataloader.set_sampling_step(step) + if self.curriculum_dataloader is not None: + self.curriculum_dataloader.set_sampling_step(step) logger.info( f"[DataPreparationActor] Step {step}: calling 
accumulate_inference_batches for {self.global_batch_size} prompts" @@ -1427,7 +1406,8 @@ def _data_preparation_loop(self): for _ in range(self.dp_world_size) ] empty_metrics = {"time/generation_idle_waiting_for_trainer": generation_idle_wait_time} - empty_metrics.update(self.iter_dataloader.build_curriculum_metrics([], step)) + if self.curriculum_dataloader is not None: + empty_metrics.update(self.curriculum_dataloader.curriculum_sampler.build_metrics([], step)) with self.lock: self.prepared_data[step] = empty_data self.metrics[step] = empty_metrics @@ -1436,11 +1416,11 @@ def _data_preparation_loop(self): assert batch is not None assert batch_stats is not None - prompt_dataset_indices = ( - [int(index) for index in batch.indices[:: self.config.num_samples_per_prompt_rollout]] - if batch.indices is not None - else [] - ) + prompt_dataset_indices: list[int] = [] + if self.curriculum_dataloader is not None and batch.indices is not None: + prompt_dataset_indices = [ + int(index) for index in batch.indices[:: self.config.num_samples_per_prompt_rollout] + ] if self.rubric_manager and batch.decoded_responses: rubric_metrics, new_overrides = self.rubric_manager.run_step( @@ -1498,9 +1478,11 @@ def _data_preparation_loop(self): assert result.logprobs is not None result.logprobs = [result.logprobs[i] for i in stop_idxes] - if batch.indices is not None: + if self.curriculum_dataloader is not None and batch.indices is not None: normalized_scores = np.clip(scores / max(self.config.max_possible_score, 1e-8), 0.0, 1.0) - self.iter_dataloader.record_curriculum_observations(batch.indices, normalized_scores, advantages) + self.curriculum_dataloader.curriculum_sampler.record_observations( + batch.indices, normalized_scores, advantages + ) assert result.logprobs is not None packed_sequences = pack_sequences( @@ -1587,7 +1569,10 @@ def _data_preparation_loop(self): step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time 
step_metrics["time/getting_response"] = result.token_statistics.generation_time - step_metrics.update(self.iter_dataloader.build_curriculum_metrics(prompt_dataset_indices, step)) + if self.curriculum_dataloader is not None: + step_metrics.update( + self.curriculum_dataloader.curriculum_sampler.build_metrics(prompt_dataset_indices, step) + ) with self.lock: self.prepared_data[step] = collated_data diff --git a/tests/test_create_difficulty_map.py b/tests/test_create_difficulty_map.py index 6debf587b2..841a0da407 100644 --- a/tests/test_create_difficulty_map.py +++ b/tests/test_create_difficulty_map.py @@ -13,7 +13,7 @@ import numpy as np -MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_difficulty_map.py" +SCRIPT_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_difficulty_map.py" def _load_create_difficulty_map_module(): @@ -84,7 +84,7 @@ def ppf(cls, q, alpha, beta): "scipy.stats": fake_scipy_stats, } module_name = "test_create_difficulty_map_module" - spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) + spec = importlib.util.spec_from_file_location(module_name, SCRIPT_PATH) assert spec is not None and spec.loader is not None module = importlib.util.module_from_spec(spec) From 8b04761e3d98b033e2d8c60d3202469a86d96ba9 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 16:59:40 -0700 Subject: [PATCH 29/40] Revert --- open_instruct/dataset_transformation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 0d617b7ed7..33b297f492 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -2162,7 +2162,6 @@ def get_cached_dataset_tulu( dataset_skip_cache: bool = False, dataset_config_seed: int = 42, system_prompt_override: str | None = None, - drop_dataset_source: bool = True, ) -> Dataset: return 
get_cached_dataset_tulu_with_statistics( dataset_mixer_list=dataset_mixer_list, @@ -2176,7 +2175,7 @@ def get_cached_dataset_tulu( hf_entity=hf_entity, dataset_local_cache_dir=dataset_local_cache_dir, dataset_skip_cache=dataset_skip_cache, - drop_dataset_source=drop_dataset_source, + drop_dataset_source=True, dataset_config_seed=dataset_config_seed, system_prompt_override=system_prompt_override, )[0] From cefb1b9589a2b6fa7a69b325a9ea8ca12e6a09c0 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 17:08:57 -0700 Subject: [PATCH 30/40] Some verification methods --- open_instruct/difficulty_curriculum.py | 74 +++++++++++++-------- open_instruct/grpo_fast.py | 15 +++-- open_instruct/test_difficulty_curriculum.py | 10 ++- 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/open_instruct/difficulty_curriculum.py b/open_instruct/difficulty_curriculum.py index ac71daf418..2b5cc6363a 100644 --- a/open_instruct/difficulty_curriculum.py +++ b/open_instruct/difficulty_curriculum.py @@ -162,39 +162,57 @@ class DifficultyCurriculumArgs: curriculum_adaptive_exploration_weight: float = 0.3 curriculum_adaptive_blend: float = 0.5 - def build_curriculum_config(self, *, seed: int) -> DifficultyCurriculumConfig | None: + def build_metadata_config(self) -> DifficultyCurriculumMetadataConfig: + return DifficultyCurriculumMetadataConfig( + field=self.curriculum_metadata_field, + posterior_mean_field=self.curriculum_posterior_mean_field, + bucket_index_field=self.curriculum_bucket_index_field, + bucket_count_field=self.curriculum_bucket_count_field, + strict=self.curriculum_strict_metadata, + ) + + def build_schedule_config(self) -> DifficultyCurriculumScheduleConfig: + return DifficultyCurriculumScheduleConfig( + bootstrap_steps=self.curriculum_bootstrap_steps, + warmup_steps=self.curriculum_warmup_steps, + total_steps=self.curriculum_total_steps, + bootstrap_target=self.curriculum_bootstrap_target, + warmup_target=self.curriculum_warmup_target, + 
final_target=self.curriculum_final_target, + min_hard_frac=self.curriculum_min_hard_frac, + max_hard_frac=self.curriculum_max_hard_frac, + bucket_sigma=self.curriculum_bucket_sigma, + bootstrap_sigma=self.curriculum_bootstrap_sigma, + ) + + def build_adaptive_config(self) -> DifficultyCurriculumAdaptiveConfig: + return DifficultyCurriculumAdaptiveConfig( + enabled=self.curriculum_adaptive, + update_every=self.curriculum_adaptive_update_every, + learning_weight=self.curriculum_adaptive_learning_weight, + exploration_weight=self.curriculum_adaptive_exploration_weight, + blend=self.curriculum_adaptive_blend, + ) + + def verify(self) -> None: if self.curriculum == "none": - return None + return if self.curriculum != "difficulty": raise ValueError(f"Unsupported curriculum type: {self.curriculum}") + self.build_metadata_config() + self.build_schedule_config() + self.build_adaptive_config() + + def build_curriculum_config(self, *, seed: int) -> DifficultyCurriculumConfig | None: + self.verify() + if self.curriculum == "none": + return None + return DifficultyCurriculumConfig( - metadata=DifficultyCurriculumMetadataConfig( - field=self.curriculum_metadata_field, - posterior_mean_field=self.curriculum_posterior_mean_field, - bucket_index_field=self.curriculum_bucket_index_field, - bucket_count_field=self.curriculum_bucket_count_field, - strict=self.curriculum_strict_metadata, - ), - schedule=DifficultyCurriculumScheduleConfig( - bootstrap_steps=self.curriculum_bootstrap_steps, - warmup_steps=self.curriculum_warmup_steps, - total_steps=self.curriculum_total_steps, - bootstrap_target=self.curriculum_bootstrap_target, - warmup_target=self.curriculum_warmup_target, - final_target=self.curriculum_final_target, - min_hard_frac=self.curriculum_min_hard_frac, - max_hard_frac=self.curriculum_max_hard_frac, - bucket_sigma=self.curriculum_bucket_sigma, - bootstrap_sigma=self.curriculum_bootstrap_sigma, - ), - adaptive=DifficultyCurriculumAdaptiveConfig( - 
enabled=self.curriculum_adaptive, - update_every=self.curriculum_adaptive_update_every, - learning_weight=self.curriculum_adaptive_learning_weight, - exploration_weight=self.curriculum_adaptive_exploration_weight, - blend=self.curriculum_adaptive_blend, - ), + metadata=self.build_metadata_config(), + schedule=self.build_schedule_config(), + adaptive=self.build_adaptive_config(), uncertainty_weight=self.curriculum_uncertainty_weight, seed=seed, ) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index aca12a2be1..b265d26138 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2339,12 +2339,15 @@ def main( streaming_config: data_loader_lib.StreamingDataLoaderConfig, vllm_config: data_loader_lib.VLLMConfig, tools_config: EnvsConfig, - curriculum_args: difficulty_curriculum.DifficultyCurriculumArgs, + curriculum: difficulty_curriculum.DifficultyCurriculumArgs | None = None, ): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args, streaming_config, tools_config) validate_configs(streaming_config, vllm_config, tuple(args.num_learners_per_node), args.sequence_parallel_size) - curriculum_config = curriculum_args.build_curriculum_config(seed=args.seed) + if curriculum is None: + curriculum = difficulty_curriculum.DifficultyCurriculumArgs() + curriculum.verify() + curriculum_config = curriculum.build_curriculum_config(seed=args.seed) if args.verbose: logging.getLogger().setLevel(logging.DEBUG) @@ -2401,7 +2404,7 @@ def main( if tc.tokenizer_name_or_path and tc.tokenizer_name_or_path != model_config.model_name_or_path: utils.ensure_hf_repo_cached(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision) - pprint([args, model_config, streaming_config, vllm_config, tools_config, curriculum_args]) + pprint([args, model_config, streaming_config, vllm_config, tools_config, curriculum]) # Create Ray queues. 
# Since we now send/receive individual prompts, queue size should accommodate @@ -2544,7 +2547,7 @@ def main( ) ) parser.set_defaults(exp_name="grpo", warmup_ratio=0.0, max_grad_norm=1.0, per_device_train_batch_size=1) - args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum_args = ( + args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum = ( parser.parse_args_into_dataclasses() ) assert isinstance(args, grpo_utils.GRPOExperimentConfig) @@ -2553,6 +2556,6 @@ def main( assert isinstance(streaming_config, data_loader_lib.StreamingDataLoaderConfig) assert isinstance(vllm_config, data_loader_lib.VLLMConfig) assert isinstance(tools_config, EnvsConfig) - assert isinstance(curriculum_args, difficulty_curriculum.DifficultyCurriculumArgs) + assert isinstance(curriculum, difficulty_curriculum.DifficultyCurriculumArgs) - main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum_args) + main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum) diff --git a/open_instruct/test_difficulty_curriculum.py b/open_instruct/test_difficulty_curriculum.py index 913758cc08..6de5a9be3c 100644 --- a/open_instruct/test_difficulty_curriculum.py +++ b/open_instruct/test_difficulty_curriculum.py @@ -183,9 +183,9 @@ def test_bootstrap_distribution_is_tunable(self): self.assertGreater(tuned_probs[0], default_probs[0]) self.assertLess(tuned_probs[2], default_probs[2]) - def test_curriculum_args_parser_builds_grouped_config(self): + def test_curriculum_parser_builds_grouped_config(self): parser = HfArgumentParser((difficulty_curriculum.DifficultyCurriculumArgs,)) - (curriculum_args,) = parser.parse_args_into_dataclasses( + (curriculum,) = parser.parse_args_into_dataclasses( [ "--curriculum", "difficulty", @@ -202,7 +202,8 @@ def test_curriculum_args_parser_builds_grouped_config(self): ] ) - curriculum_config = 
curriculum_args.build_curriculum_config(seed=17) + curriculum.verify() + curriculum_config = curriculum.build_curriculum_config(seed=17) self.assertIsNotNone(curriculum_config) assert curriculum_config is not None @@ -213,6 +214,9 @@ def test_curriculum_args_parser_builds_grouped_config(self): self.assertEqual(curriculum_config.adaptive.blend, 0.25) self.assertEqual(curriculum_config.seed, 17) + def test_curriculum_verify_accepts_none_mode(self): + difficulty_curriculum.DifficultyCurriculumArgs().verify() + class TestDifficultyCurriculumLoaderIntegration(unittest.TestCase): def test_existing_behavior_is_unchanged_when_curriculum_disabled(self): From d5414f872731b7aa329fca6479576e1eb3710e21 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 6 May 2026 17:13:37 -0700 Subject: [PATCH 31/40] Add CHANGELOG update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1c0be107c..783dc112d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Add optional difficulty-curriculum prompt sampling to grpo_fast.py, `DifficultyCurriculumHFDataLoader`, and a scripts/data/difficulty_sampling/create_difficulty_map.py builder, with matching docs/tests and a Qwen launch script (https://github.com/allenai/open-instruct/pull/1661). - Aggregate prompt/response lengths across all DP ranks (deduplicating SP groups) when computing GRPO step token counts and utilization metrics, instead of using only rank 0 (https://github.com/allenai/open-instruct/pull/1659). - Split `accumulate_inference_batches` into `process_single_result` and `combine_processed_results` for clarity (https://github.com/allenai/open-instruct/pull/1614). - Match reference SFT run: `olmo_core_finetune.py` parity with pure olmo-core; default CP strategy switched to `ulysses` and ring-flash-attn dependency removed (https://github.com/allenai/open-instruct/pull/1620). 
From cbe96f349eb1ee6d334afd9cb2b2b5dc4a5b225e Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 11:32:47 -0700 Subject: [PATCH 32/40] Cache some stuff in the init --- open_instruct/difficulty_curriculum.py | 29 ++++++++++++++++++--- open_instruct/test_difficulty_curriculum.py | 26 ++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/open_instruct/difficulty_curriculum.py b/open_instruct/difficulty_curriculum.py index 2b5cc6363a..998eddc1eb 100644 --- a/open_instruct/difficulty_curriculum.py +++ b/open_instruct/difficulty_curriculum.py @@ -358,6 +358,7 @@ class _DifficultyCurriculumSchedule: def __init__(self, config: DifficultyCurriculumScheduleConfig, bucket_count: int) -> None: self.config = config self.bucket_count = bucket_count + self._bucket_ids = np.arange(max(self.bucket_count - 1, 0), dtype=np.float64) def get_progress(self, step: int) -> float: if step < self.config.warmup_steps: @@ -402,9 +403,8 @@ def build_probs(self, step: int, available_mask: np.ndarray) -> np.ndarray: self.config.min_hard_frac + (self.config.max_hard_frac - self.config.min_hard_frac) * smooth_progress ) - bucket_ids = np.arange(self.bucket_count - 1, dtype=np.float64) sigma = self._get_bucket_sigma(step) - gaussian_logits = np.exp(-0.5 * ((bucket_ids - target_bucket) / sigma) ** 2) + gaussian_logits = np.exp(-0.5 * ((self._bucket_ids - target_bucket) / sigma) ** 2) non_hard_probs = _normalize_probs(gaussian_logits) static_probs = np.zeros(self.bucket_count, dtype=np.float64) @@ -582,6 +582,8 @@ def __init__( epsilon=self.config.epsilon, ) + self._cached_static_probs: np.ndarray | None = None + self._last_static_refresh_step = -1 self._cached_adaptive_probs: np.ndarray | None = None self._last_adaptive_refresh_step = -1 @@ -612,10 +614,27 @@ def get_progress(self, step: int | None = None) -> float: def _available_bucket_mask(self) -> np.ndarray: return np.array([1.0 if count > 0 else 0.0 for count in self._active_bucket_counts], 
dtype=np.float64) + def _invalidate_static_bucket_probs(self) -> None: + self._cached_static_probs = None + self._last_static_refresh_step = -1 + + def _invalidate_adaptive_bucket_probs(self) -> None: + self._cached_adaptive_probs = None + self._last_adaptive_refresh_step = -1 + + def _invalidate_bucket_prob_caches(self) -> None: + self._invalidate_static_bucket_probs() + self._invalidate_adaptive_bucket_probs() + def get_static_bucket_probs(self, step: int | None = None) -> np.ndarray: if step is None: step = self._get_current_step() - return self._schedule.build_probs(step, self._available_bucket_mask()) + if self._cached_static_probs is not None and step == self._last_static_refresh_step: + return self._cached_static_probs.copy() + + self._cached_static_probs = self._schedule.build_probs(step, self._available_bucket_mask()) + self._last_static_refresh_step = step + return self._cached_static_probs.copy() def get_adaptive_bucket_probs(self, step: int | None = None) -> np.ndarray | None: if not self.config.adaptive.enabled or self.adaptive_stats is None or self.adaptive_stats.total_count == 0: @@ -714,6 +733,7 @@ def exclude_index(self, dataset_index: int) -> None: 0.0, self._active_bucket_weight_sums[bucket_index] - current_weight ) self._excluded_indices.add(dataset_index) + self._invalidate_bucket_prob_caches() def record_observations( self, @@ -728,7 +748,7 @@ def record_observations( bucket_indices = [self.bucket_for_dataset_index(int(dataset_index)) for dataset_index in dataset_indices] self.adaptive_stats.update(bucket_indices, rewards, advantages) - self._cached_adaptive_probs = None + self._invalidate_adaptive_bucket_probs() def build_metrics(self, prompt_dataset_indices: list[int], step: int | None = None) -> dict[str, float]: metrics: dict[str, float] = {"curriculum/progress": self.get_progress(step)} @@ -804,6 +824,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.adaptive_stats is not None and state_dict.get("adaptive_stats") 
is not None: self.adaptive_stats.load_state_dict(state_dict["adaptive_stats"]) + self._invalidate_static_bucket_probs() self._last_adaptive_refresh_step = int(state_dict.get("last_adaptive_refresh_step", -1)) cached_adaptive_probs = state_dict.get("cached_adaptive_probs") self._cached_adaptive_probs = None if cached_adaptive_probs is None else np.array(cached_adaptive_probs) diff --git a/open_instruct/test_difficulty_curriculum.py b/open_instruct/test_difficulty_curriculum.py index 6de5a9be3c..1acf72a4cb 100644 --- a/open_instruct/test_difficulty_curriculum.py +++ b/open_instruct/test_difficulty_curriculum.py @@ -2,6 +2,7 @@ import tempfile import types import unittest +from unittest import mock from datasets import Dataset from transformers import HfArgumentParser @@ -142,6 +143,31 @@ def test_probabilities_always_sum_to_one(self): self.assertAlmostEqual(float(sampler.get_static_bucket_probs(step=step).sum()), 1.0, places=6) self.assertAlmostEqual(float(sampler.get_bucket_probs(step=step).sum()), 1.0, places=6) + def test_static_bucket_probs_are_cached_per_step(self): + sampler = self._make_sampler(make_bucket_dataset()) + + with mock.patch.object(sampler._schedule, "build_probs", wraps=sampler._schedule.build_probs) as build_probs: + sampler.get_bucket_probs(step=7) + sampler.get_bucket_probs(step=7) + sampler.get_static_bucket_probs(step=7) + self.assertEqual(build_probs.call_count, 1) + + sampler.get_bucket_probs(step=8) + self.assertEqual(build_probs.call_count, 2) + + def test_excluding_index_invalidates_cached_bucket_probs(self): + sampler = self._make_sampler(make_bucket_dataset()) + + with mock.patch.object(sampler._schedule, "build_probs", wraps=sampler._schedule.build_probs) as build_probs: + sampler.get_static_bucket_probs(step=0) + self.assertEqual(build_probs.call_count, 1) + + sampler.exclude_index(4) + excluded_probs = sampler.get_static_bucket_probs(step=0) + self.assertEqual(build_probs.call_count, 2) + + self.assertEqual(float(excluded_probs[4]), 
0.0) + def test_adaptive_stats_increase_sampling_probability_for_high_signal_bucket(self): sampler = self._make_sampler( make_bucket_dataset(), adaptive=self._make_adaptive(enabled=True, update_every=1, blend=0.5) From bb65e55a2592236c4bde4a665034c061faa54ab6 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 11:58:01 -0700 Subject: [PATCH 33/40] Adds some variants --- open_instruct/difficulty_curriculum.py | 76 ++++++++++++++++++- open_instruct/test_difficulty_curriculum.py | 18 +++++ scripts/train/qwen/qwen3_4b_dapo_math_dc.sh | 53 ++++++++++--- .../qwen3_4b_dapo_math_dc_adaptive_light.sh | 11 +++ .../qwen/qwen3_4b_dapo_math_dc_hardest50.sh | 10 +++ ...3_4b_dapo_math_dc_longer_easy_bootstrap.sh | 10 +++ 6 files changed, 164 insertions(+), 14 deletions(-) create mode 100755 scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh create mode 100755 scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh create mode 100755 scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh diff --git a/open_instruct/difficulty_curriculum.py b/open_instruct/difficulty_curriculum.py index 998eddc1eb..0b1bc4ad6c 100644 --- a/open_instruct/difficulty_curriculum.py +++ b/open_instruct/difficulty_curriculum.py @@ -61,6 +61,7 @@ class DifficultyCurriculumMetadataConfig: posterior_mean_field: str = "posterior_mean" bucket_index_field: str = "bucket_index" bucket_count_field: str = "bucket_count" + quantile_field: str = "expected_quantile" strict: bool = True @@ -127,12 +128,20 @@ class DifficultyCurriculumConfig: schedule: DifficultyCurriculumScheduleConfig = field(default_factory=DifficultyCurriculumScheduleConfig) adaptive: DifficultyCurriculumAdaptiveConfig = field(default_factory=DifficultyCurriculumAdaptiveConfig) uncertainty_weight: float = 0.5 + min_quantile: float = 0.0 + max_quantile: float = 1.0 seed: int = 0 epsilon: float = 1e-8 def __post_init__(self) -> None: if not 0.0 <= self.uncertainty_weight <= 1.0: raise ValueError("uncertainty_weight must 
be in [0, 1]") + if not 0.0 <= self.min_quantile <= 1.0: + raise ValueError("min_quantile must be in [0, 1]") + if not 0.0 <= self.max_quantile <= 1.0: + raise ValueError("max_quantile must be in [0, 1]") + if self.min_quantile > self.max_quantile: + raise ValueError("min_quantile must be <= max_quantile") if self.epsilon <= 0: raise ValueError("epsilon must be > 0") @@ -144,6 +153,7 @@ class DifficultyCurriculumArgs: curriculum_posterior_mean_field: str = "posterior_mean" curriculum_bucket_index_field: str = "bucket_index" curriculum_bucket_count_field: str = "bucket_count" + curriculum_quantile_field: str = "expected_quantile" curriculum_strict_metadata: bool = True curriculum_bootstrap_steps: int = 100 curriculum_warmup_steps: int = 500 @@ -161,6 +171,8 @@ class DifficultyCurriculumArgs: curriculum_adaptive_learning_weight: float = 0.7 curriculum_adaptive_exploration_weight: float = 0.3 curriculum_adaptive_blend: float = 0.5 + curriculum_min_quantile: float = 0.0 + curriculum_max_quantile: float = 1.0 def build_metadata_config(self) -> DifficultyCurriculumMetadataConfig: return DifficultyCurriculumMetadataConfig( @@ -168,6 +180,7 @@ def build_metadata_config(self) -> DifficultyCurriculumMetadataConfig: posterior_mean_field=self.curriculum_posterior_mean_field, bucket_index_field=self.curriculum_bucket_index_field, bucket_count_field=self.curriculum_bucket_count_field, + quantile_field=self.curriculum_quantile_field, strict=self.curriculum_strict_metadata, ) @@ -214,6 +227,8 @@ def build_curriculum_config(self, *, seed: int) -> DifficultyCurriculumConfig | schedule=self.build_schedule_config(), adaptive=self.build_adaptive_config(), uncertainty_weight=self.curriculum_uncertainty_weight, + min_quantile=self.curriculum_min_quantile, + max_quantile=self.curriculum_max_quantile, seed=seed, ) @@ -341,6 +356,7 @@ class _ParsedDifficultyMetadata: bucket_index: int | None bucket_count: int | None posterior_mean: float | None + expected_quantile: float | None error: str | 
None @@ -352,6 +368,7 @@ class _DifficultyBucketIndex: bucket_weights: tuple[torch.Tensor, ...] bucket_count: int metadata_fallback_count: int + filtered_out_count: int class _DifficultyCurriculumSchedule: @@ -428,18 +445,21 @@ def _parse_difficulty_metadata( bucket_index=None, bucket_count=None, posterior_mean=None, + expected_quantile=None, error=f"missing '{metadata_config.field}' metadata for dataset index {index}", ) bucket_index = _coerce_int(_resolve_path(difficulty_blob, metadata_config.bucket_index_field)) bucket_count = _coerce_int(_resolve_path(difficulty_blob, metadata_config.bucket_count_field)) posterior_mean = _coerce_float(_resolve_path(difficulty_blob, metadata_config.posterior_mean_field)) + expected_quantile = _coerce_float(_resolve_path(difficulty_blob, metadata_config.quantile_field)) if bucket_index is None or bucket_index < 0: return _ParsedDifficultyMetadata( bucket_index=None, bucket_count=bucket_count, posterior_mean=posterior_mean, + expected_quantile=expected_quantile, error=f"invalid bucket_index for dataset index {index}", ) if bucket_count is None or bucket_count <= 0: @@ -447,6 +467,7 @@ def _parse_difficulty_metadata( bucket_index=bucket_index, bucket_count=None, posterior_mean=posterior_mean, + expected_quantile=expected_quantile, error=f"invalid bucket_count for dataset index {index}", ) if posterior_mean is None: @@ -454,10 +475,23 @@ def _parse_difficulty_metadata( bucket_index=bucket_index, bucket_count=bucket_count, posterior_mean=None, + expected_quantile=expected_quantile, error=f"invalid posterior_mean for dataset index {index}", ) + if expected_quantile is not None and not 0.0 <= expected_quantile <= 1.0: + return _ParsedDifficultyMetadata( + bucket_index=bucket_index, + bucket_count=bucket_count, + posterior_mean=posterior_mean, + expected_quantile=None, + error=f"invalid expected_quantile for dataset index {index}", + ) return _ParsedDifficultyMetadata( - bucket_index=bucket_index, bucket_count=bucket_count, 
posterior_mean=posterior_mean, error=None + bucket_index=bucket_index, + bucket_count=bucket_count, + posterior_mean=posterior_mean, + expected_quantile=expected_quantile, + error=None, ) @@ -469,11 +503,17 @@ def _compute_example_weight(posterior_mean: float, uncertainty_weight: float, ep def _build_difficulty_bucket_index( - dataset, metadata_config: DifficultyCurriculumMetadataConfig, uncertainty_weight: float, epsilon: float + dataset, + metadata_config: DifficultyCurriculumMetadataConfig, + uncertainty_weight: float, + epsilon: float, + min_quantile: float, + max_quantile: float, ) -> _DifficultyBucketIndex: parsed_rows: list[_ParsedDifficultyMetadata] = [] observed_bucket_counts: set[int] = set() max_bucket_index = -1 + filter_active = min_quantile > 0.0 or max_quantile < 1.0 for dataset_index in range(len(dataset)): parsed = _parse_difficulty_metadata(dataset[dataset_index], dataset_index, metadata_config) @@ -499,10 +539,14 @@ def _build_difficulty_bucket_index( index_to_bucket: dict[int, int] = {} index_to_bucket_position: dict[int, int] = {} metadata_fallback_count = 0 + filtered_out_count = 0 fallback_bucket = min(bucket_count - 1, bucket_count // 2) for dataset_index, parsed in enumerate(parsed_rows): if parsed.error is not None: + if filter_active: + filtered_out_count += 1 + continue bucket_index = fallback_bucket posterior_mean = _DEFAULT_POSTERIOR_MEAN metadata_fallback_count += 1 @@ -510,6 +554,16 @@ def _build_difficulty_bucket_index( assert parsed.bucket_index is not None bucket_index = int(np.clip(parsed.bucket_index, 0, bucket_count - 1)) posterior_mean = parsed.posterior_mean + if filter_active: + expected_quantile = parsed.expected_quantile + if expected_quantile is None: + if metadata_config.strict: + raise ValueError(f"invalid {metadata_config.quantile_field} for dataset index {dataset_index}") + filtered_out_count += 1 + continue + if expected_quantile < min_quantile or expected_quantile > max_quantile: + filtered_out_count += 1 + continue 
if posterior_mean is None: posterior_mean = _DEFAULT_POSTERIOR_MEAN @@ -531,6 +585,7 @@ def _build_difficulty_bucket_index( bucket_weights=tuple(torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists), bucket_count=bucket_count, metadata_fallback_count=metadata_fallback_count, + filtered_out_count=filtered_out_count, ) @@ -560,11 +615,14 @@ def __init__( metadata_config=self.config.metadata, uncertainty_weight=self.config.uncertainty_weight, epsilon=self.config.epsilon, + min_quantile=self.config.min_quantile, + max_quantile=self.config.max_quantile, ) self._index_to_bucket = dict(bucket_index.index_to_bucket) self._index_to_bucket_position = dict(bucket_index.index_to_bucket_position) self.bucket_count = bucket_index.bucket_count self.metadata_fallback_count = bucket_index.metadata_fallback_count + self.filtered_out_count = bucket_index.filtered_out_count self._schedule = _DifficultyCurriculumSchedule(self.config.schedule, self.bucket_count) self._excluded_indices: set[int] = set() @@ -594,6 +652,16 @@ def __init__( self.metadata_fallback_count, len(self.dataset), ) + if self.filtered_out_count > 0: + logger.info( + "Difficulty curriculum filtered out %s/%s rows outside quantile range [%s, %s].", + self.filtered_out_count, + len(self.dataset), + self.config.min_quantile, + self.config.max_quantile, + ) + if sum(self._active_bucket_counts) == 0: + raise ValueError("Difficulty curriculum filter removed all dataset examples.") def __len__(self) -> int: return self.num_samples @@ -674,7 +742,9 @@ def bucket_for_dataset_index(self, dataset_index: int) -> int: def get_example_probability(self, dataset_index: int, step: int | None = None) -> float: if int(dataset_index) in self._excluded_indices: return 0.0 - bucket_index = self.bucket_for_dataset_index(dataset_index) + bucket_index = self._index_to_bucket.get(int(dataset_index)) + if bucket_index is None: + return 0.0 if self._active_bucket_counts[bucket_index] == 0: return 0.0 local_index 
= self._index_to_bucket_position.get(int(dataset_index)) diff --git a/open_instruct/test_difficulty_curriculum.py b/open_instruct/test_difficulty_curriculum.py index 1acf72a4cb..0a33cae981 100644 --- a/open_instruct/test_difficulty_curriculum.py +++ b/open_instruct/test_difficulty_curriculum.py @@ -97,6 +97,21 @@ def test_missing_metadata_falls_back_when_not_strict(self): self.assertEqual(sampler.metadata_fallback_count, 1) self.assertIn(1, sampler.bucket_to_indices[2]) + def test_quantile_filter_keeps_only_hardest_half(self): + sampler = self._make_sampler(make_bucket_dataset(), min_quantile=0.5) + self.assertEqual(sampler.bucket_to_indices, ((), (), (2,), (3,), (4,))) + self.assertEqual(sampler.filtered_out_count, 2) + self.assertEqual(sampler.get_example_probability(0, step=0), 0.0) + self.assertTrue(all(sampler.sample_index(step=0) in {2, 3, 4} for _ in range(20))) + + def test_quantile_filter_reweights_only_available_buckets(self): + sampler = self._make_sampler(make_bucket_dataset(), min_quantile=0.5) + early_probs = sampler.get_static_bucket_probs(step=0) + self.assertAlmostEqual(float(early_probs.sum()), 1.0, places=6) + self.assertEqual(float(early_probs[0]), 0.0) + self.assertEqual(float(early_probs[1]), 0.0) + self.assertGreater(early_probs[2], early_probs[3]) + def test_bucket_grouping_works(self): sampler = self._make_sampler(make_bucket_dataset()) self.assertEqual(sampler.bucket_to_indices, ((0,), (1,), (2,), (3,), (4,))) @@ -225,6 +240,8 @@ def test_curriculum_parser_builds_grouped_config(self): "true", "--curriculum_adaptive_blend", "0.25", + "--curriculum_min_quantile", + "0.5", ] ) @@ -238,6 +255,7 @@ def test_curriculum_parser_builds_grouped_config(self): self.assertEqual(curriculum_config.schedule.total_steps, 56) self.assertTrue(curriculum_config.adaptive.enabled) self.assertEqual(curriculum_config.adaptive.blend, 0.25) + self.assertEqual(curriculum_config.min_quantile, 0.5) self.assertEqual(curriculum_config.seed, 17) def 
test_curriculum_verify_accepts_none_mode(self): diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh index 24f238fc8b..4407f9e4e6 100644 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh @@ -35,23 +35,54 @@ else DEFAULT_CURRICULUM_TOTAL_STEPS=$(( NUM_TRAINING_STEPS - CURRICULUM_WARMUP_STEPS )) fi CURRICULUM_TOTAL_STEPS="${CURRICULUM_TOTAL_STEPS:-${DEFAULT_CURRICULUM_TOTAL_STEPS}}" +CURRICULUM_METADATA_FIELD="${CURRICULUM_METADATA_FIELD:-difficulty}" +CURRICULUM_POSTERIOR_MEAN_FIELD="${CURRICULUM_POSTERIOR_MEAN_FIELD:-posterior_mean}" +CURRICULUM_BUCKET_INDEX_FIELD="${CURRICULUM_BUCKET_INDEX_FIELD:-bucket_index}" +CURRICULUM_BUCKET_COUNT_FIELD="${CURRICULUM_BUCKET_COUNT_FIELD:-bucket_count}" +CURRICULUM_QUANTILE_FIELD="${CURRICULUM_QUANTILE_FIELD:-expected_quantile}" +CURRICULUM_STRICT_METADATA="${CURRICULUM_STRICT_METADATA:-true}" +CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.125}" +CURRICULUM_WARMUP_TARGET="${CURRICULUM_WARMUP_TARGET:-0.5}" +CURRICULUM_FINAL_TARGET="${CURRICULUM_FINAL_TARGET:-1.0}" +CURRICULUM_MIN_HARD_FRAC="${CURRICULUM_MIN_HARD_FRAC:-0.05}" +CURRICULUM_MAX_HARD_FRAC="${CURRICULUM_MAX_HARD_FRAC:-0.50}" +CURRICULUM_BUCKET_SIGMA="${CURRICULUM_BUCKET_SIGMA:-0.0}" +CURRICULUM_BOOTSTRAP_SIGMA="${CURRICULUM_BOOTSTRAP_SIGMA:-0.0}" +CURRICULUM_UNCERTAINTY_WEIGHT="${CURRICULUM_UNCERTAINTY_WEIGHT:-0.5}" +CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-false}" +CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-50}" +CURRICULUM_ADAPTIVE_LEARNING_WEIGHT="${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT:-0.7}" +CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT="${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT:-0.3}" +CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.5}" +CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.0}" +CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" CURRICULUM_ARGS=( --curriculum difficulty - 
--curriculum_metadata_field difficulty + --curriculum_metadata_field "${CURRICULUM_METADATA_FIELD}" + --curriculum_posterior_mean_field "${CURRICULUM_POSTERIOR_MEAN_FIELD}" + --curriculum_bucket_index_field "${CURRICULUM_BUCKET_INDEX_FIELD}" + --curriculum_bucket_count_field "${CURRICULUM_BUCKET_COUNT_FIELD}" + --curriculum_quantile_field "${CURRICULUM_QUANTILE_FIELD}" --curriculum_bootstrap_steps "${CURRICULUM_BOOTSTRAP_STEPS}" - --curriculum_bootstrap_target 0.125 - --curriculum_warmup_target 0.5 - --curriculum_final_target 1.0 + --curriculum_bootstrap_target "${CURRICULUM_BOOTSTRAP_TARGET}" + --curriculum_warmup_target "${CURRICULUM_WARMUP_TARGET}" + --curriculum_final_target "${CURRICULUM_FINAL_TARGET}" --curriculum_warmup_steps "${CURRICULUM_WARMUP_STEPS}" --curriculum_total_steps "${CURRICULUM_TOTAL_STEPS}" - --curriculum_min_hard_frac 0.05 - --curriculum_max_hard_frac 0.50 - --curriculum_bucket_sigma 0.0 - --curriculum_bootstrap_sigma 0.0 - --curriculum_uncertainty_weight 0.5 - --curriculum_adaptive false - --curriculum_strict_metadata true + --curriculum_min_hard_frac "${CURRICULUM_MIN_HARD_FRAC}" + --curriculum_max_hard_frac "${CURRICULUM_MAX_HARD_FRAC}" + --curriculum_bucket_sigma "${CURRICULUM_BUCKET_SIGMA}" + --curriculum_bootstrap_sigma "${CURRICULUM_BOOTSTRAP_SIGMA}" + --curriculum_uncertainty_weight "${CURRICULUM_UNCERTAINTY_WEIGHT}" + --curriculum_adaptive "${CURRICULUM_ADAPTIVE}" + --curriculum_adaptive_update_every "${CURRICULUM_ADAPTIVE_UPDATE_EVERY}" + --curriculum_adaptive_learning_weight "${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT}" + --curriculum_adaptive_exploration_weight "${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT}" + --curriculum_adaptive_blend "${CURRICULUM_ADAPTIVE_BLEND}" + --curriculum_min_quantile "${CURRICULUM_MIN_QUANTILE}" + --curriculum_max_quantile "${CURRICULUM_MAX_QUANTILE}" + --curriculum_strict_metadata "${CURRICULUM_STRICT_METADATA}" ) uv run python mason.py \ diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh 
b/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh new file mode 100755 index 0000000000..9cc7d53ea0 --- /dev/null +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_light}" +export CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-true}" +export CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-20}" +export CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.25}" + +exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh new file mode 100755 index 0000000000..3a5bccb4e3 --- /dev/null +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" + +exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh new file mode 100755 index 0000000000..935cfceb21 --- /dev/null +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_longer_easy_bootstrap}" +export CURRICULUM_BOOTSTRAP_STEPS="${CURRICULUM_BOOTSTRAP_STEPS:-200}" +export CURRICULUM_WARMUP_STEPS="${CURRICULUM_WARMUP_STEPS:-200}" + +exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" From 
0318c407b7bdc1b19eb6f3bbb8b9356a9606b471 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 12:15:11 -0700 Subject: [PATCH 34/40] Perms --- scripts/train/qwen/qwen3_4b_dapo_math_dc.sh | 0 scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh | 2 +- scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh | 2 +- .../train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 scripts/train/qwen/qwen3_4b_dapo_math_dc.sh diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh old mode 100644 new mode 100755 diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh index 9cc7d53ea0..baa1173007 100755 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh @@ -8,4 +8,4 @@ export CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-true}" export CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-20}" export CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.25}" -exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh index 3a5bccb4e3..8579af7b9c 100755 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh @@ -7,4 +7,4 @@ export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_hardest50} export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" -exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh 
b/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh index 935cfceb21..b9d0a937e6 100755 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh @@ -7,4 +7,4 @@ export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_longer_eas export CURRICULUM_BOOTSTRAP_STEPS="${CURRICULUM_BOOTSTRAP_STEPS:-200}" export CURRICULUM_WARMUP_STEPS="${CURRICULUM_WARMUP_STEPS:-200}" -exec "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" From 335b6c66d66e6bf65235acfc8922273bd77a6cd1 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 12:20:28 -0700 Subject: [PATCH 35/40] Budget --- scripts/train/qwen/qwen3_4b_dapo_math_dc.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh index 4407f9e4e6..c476b21c73 100755 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh @@ -12,7 +12,8 @@ fi CLUSTER="${CLUSTER:-ai2/jupiter}" PRIORITY="${PRIORITY:-urgent}" -WORKSPACE="${WORKSPACE:-ai2/olmo-instruct}" +WORKSPACE="${WORKSPACE:-ai2/open-instruct-dev}" +BUDGET="${BUDGET:-ai2/oe-omai}" # Difficulty-annotated variant of hamishivi/DAPO-Math-17k-Processed_filtered DATASET_WITH_DIFFICULTY="undfined/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples-ds" @@ -98,7 +99,7 @@ uv run python mason.py \ --num_nodes 1 \ --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ --gpus $NUM_GPUS \ - --budget ai2/oe-adapt \ + --budget ${BUDGET} \ -- \ uv run open_instruct/grpo_fast.py \ --run_name "${RUN_NAME}" \ From b4a7e567067fd01372c61a290cb5d624b22c4de0 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 12:25:13 -0700 Subject: [PATCH 36/40] Outdated argument --- scripts/train/qwen/qwen3_4b_dapo_math_dc.sh | 1 - 1 file changed, 1 deletion(-) diff --git 
a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh index c476b21c73..d9eae4861c 100755 --- a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh +++ b/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh @@ -111,7 +111,6 @@ uv run open_instruct/grpo_fast.py \ --async_steps 4 \ --active_sampling \ --inflight_updates \ - --truncated_importance_sampling_ratio_cap 2.0 \ --advantage_normalization_type centered \ --num_samples_per_prompt_rollout ${NUM_SAMPLES_PER_PROMPT_ROLLOUT} \ --num_unique_prompts_rollout ${NUM_UNIQUE_PROMPTS_ROLLOUT} \ From ccbf6866d503d831a868e2633b6da9afb30dee43 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 11 May 2026 16:26:54 -0700 Subject: [PATCH 37/40] Better layout --- .../qwen/{ => difficult-curriculum}/qwen3_4b_dapo_math_dc.sh | 0 .../qwen3_4b_dapo_math_dc_adaptive_light.sh | 0 .../{ => difficult-curriculum}/qwen3_4b_dapo_math_dc_hardest50.sh | 0 .../qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename scripts/train/qwen/{ => difficult-curriculum}/qwen3_4b_dapo_math_dc.sh (100%) rename scripts/train/qwen/{ => difficult-curriculum}/qwen3_4b_dapo_math_dc_adaptive_light.sh (100%) rename scripts/train/qwen/{ => difficult-curriculum}/qwen3_4b_dapo_math_dc_hardest50.sh (100%) rename scripts/train/qwen/{ => difficult-curriculum}/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh (100%) diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc.sh similarity index 100% rename from scripts/train/qwen/qwen3_4b_dapo_math_dc.sh rename to scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc.sh diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light.sh similarity index 100% rename from scripts/train/qwen/qwen3_4b_dapo_math_dc_adaptive_light.sh rename to 
scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light.sh diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh similarity index 100% rename from scripts/train/qwen/qwen3_4b_dapo_math_dc_hardest50.sh rename to scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh diff --git a/scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh similarity index 100% rename from scripts/train/qwen/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh rename to scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh From 4de09bff5fbb80e9a8198dd5e58a84705e0a5bd5 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Tue, 12 May 2026 08:08:40 -0700 Subject: [PATCH 38/40] Readme, stronger adaptive, cleanup --- open_instruct/data_loader.py | 12 +-- open_instruct/test_data_loader.py | 12 +-- scripts/data/difficulty_sampling/README.md | 75 +++++++++++++------ .../qwen3_4b_dapo_math_dc_adaptive_strong.sh | 14 ++++ .../qwen3_4b_dapo_math_dc_hardest50.sh | 3 + 5 files changed, 80 insertions(+), 36 deletions(-) create mode 100755 scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong.sh diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index fb1ff2d270..56e67dbc69 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -20,12 +20,11 @@ from dataclasses import asdict, dataclass, field from pathlib import Path from queue import Empty -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import ray import torch -import vllm from datasets import Dataset from olmo_core.data import data_loader from ray.util import queue as ray_queue @@ -48,6 +47,9 @@ from open_instruct.rubrics import RubricManager from open_instruct.utils import 
combine_reward_metrics +if TYPE_CHECKING: + import vllm + logger = logger_utils.setup_logger(__name__) DATA_PREP_ACTOR_NAME = "data_prep_singleton" @@ -885,7 +887,7 @@ class Group: def process_group( result: data_types.GenerationResult, - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", tokenizer: PreTrainedTokenizer, dataset: Dataset, max_possible_score: float, @@ -959,7 +961,7 @@ def process_group( def make_batch_from_groups( groups: list[Group], - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", training_step: int, actor_manager=None, filtered_prompts: int = 0, @@ -1115,7 +1117,7 @@ def make_batch_from_groups( def accumulate_inference_batches( inference_results_Q: ray_queue.Queue, - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", num_prompts: int, model_dims: utils.ModelDims, tokenizer: PreTrainedTokenizer, diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 560f5983b3..aa7a6b19fc 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -1,6 +1,4 @@ -import sys import tempfile -import types import unittest import numpy as np @@ -8,13 +6,9 @@ import torch from datasets import Dataset -vllm_stub = types.ModuleType("vllm") -vllm_stub.SamplingParams = object -sys.modules.setdefault("vllm", vllm_stub) - -from open_instruct import data_loader, data_types # noqa: E402 -from open_instruct.model_utils import Batch # noqa: E402 -from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO # noqa: E402 +from open_instruct import data_loader, data_types +from open_instruct.model_utils import Batch +from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO def _make_dpo_dataset(num_samples: int, max_seq_length: int) -> Dataset: diff --git a/scripts/data/difficulty_sampling/README.md b/scripts/data/difficulty_sampling/README.md index 
e946a3f37b..31d8ea8c51 100644 --- a/scripts/data/difficulty_sampling/README.md +++ b/scripts/data/difficulty_sampling/README.md @@ -1,16 +1,22 @@ # Difficulty Sampling This directory contains tooling for building per-instance difficulty metadata -for RLVR curricula. +from pass-rate style datasets. + +A common use case is difficulty-aware curricula for GRPO / RLVR, but the +outputs are not specific to that setting. They are also useful for stratified +sampling, filtering, analysis, and dataset construction. ## Create A Difficulty Map -Use `create_difficulty_map.py` to build a difficulty map from a Hugging Face -dataset that already contains per-row pass-rate aggregates. +Use `create_difficulty_map.py` to turn a Hugging Face dataset with per-row +pass-rate aggregates into reusable difficulty annotations. The script expands pass-count summaries into binary attempt outcomes, fits a -Beta prior across binary outcomes, estimates per-item difficulty, and writes -JSONL difficulty files plus schema and metadata sidecars. +shared Beta prior across binary outcomes, estimates per-item difficulty, and +writes JSONL difficulty files plus schema and metadata sidecars. When +`--push-to-hub` is set, it also uploads the validated output dataset as a +private Hub split. ### Examples @@ -38,12 +44,14 @@ uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ Hub uploads require exactly one task/model output group, so use `--task` or a single-group input dataset when pushing. +If you only need continuous difficulty scores, you can set +`--difficulty-buckets 0` to skip discrete bucket assignment. If you want to use +the default difficulty curriculum in `grpo_fast.py`, keep bucket assignment +enabled. + ## Difficulty Metadata Format -`grpo_fast.py` can optionally replace uniform prompt reshuffling with -`DifficultyCurriculumSampler`, a bucket-aware RLVR curriculum driven by -per-instance difficulty metadata. 
The current recommended metadata format comes -from the beta-binomial estimator in `create_difficulty_map.py`: +`create_difficulty_map.py` writes a nested `difficulty` blob like this: ```json { @@ -58,23 +66,40 @@ from the beta-binomial estimator in `create_difficulty_map.py`: } ``` +- `value` is the overall difficulty score, defined as + `1 - posterior_lower_bound`. Higher means harder. - `posterior_mean` is the estimated solve probability for that prompt. Lower means harder. +- `posterior_lower_bound` is a conservative lower bound on solve probability at + the configured `--posterior-lower-quantile`. +- `expected_quantile` is a posterior-aware rank in `[0, 1]`. It is useful for + filtering or clamping to a subset of the difficulty range. - `bucket_index = 0` is the easiest bucket and `bucket_index = bucket_count - 1` is the hardest. -- The sampler uses a smooth distribution with a configurable easy-heavy - bootstrap phase, then gradually shifts mass toward harder buckets instead of - hard-switching between discrete phases. -- Within each bucket, examples are weighted by a blend of uncertainty - (`4 * p * (1 - p)`) and hardness (`1 - p`), so borderline prompts stay - attractive while already-solved prompts are naturally down-weighted. -- If `--curriculum_adaptive true` is set, bucket probabilities are additionally - blended with live reward / advantage statistics so buckets with useful - learning signal can get more mass during training. + +`grpo_fast.py` can optionally replace uniform prompt reshuffling with +`DifficultyCurriculumSampler`, which consumes this metadata by default. The +current curriculum implementation uses: + +- `difficulty.posterior_mean` for within-bucket weighting. +- `difficulty.bucket_index` and `difficulty.bucket_count` for bucketed + sampling. +- `difficulty.expected_quantile` for optional + `--curriculum_min_quantile` / `--curriculum_max_quantile` filtering. 
+ +The sampler uses a smooth distribution with a configurable easy-heavy +bootstrap phase, then gradually shifts mass toward harder buckets instead of +hard-switching between discrete phases. Within each bucket, examples are +weighted by a blend of uncertainty (`4 * p * (1 - p)`) and hardness +(`1 - p`), so borderline prompts stay attractive while already-solved prompts +are naturally down-weighted. If `--curriculum_adaptive true` is set, bucket +probabilities are additionally blended with live reward / advantage statistics +so buckets with useful learning signal can get more mass during training. ## Recommended Starting Point -For `bucket_count=5`: +If you are using this metadata with `DifficultyCurriculumSampler`, +`bucket_count=5` is a reasonable starting point: - Bootstrap (first ~100 steps by default): buckets 0 and 1 dominate so the model sees easier prompts while it settles into the chat template and task @@ -100,6 +125,9 @@ Useful flags: --curriculum_bucket_sigma 0.0 \ --curriculum_bootstrap_sigma 0.0 \ --curriculum_uncertainty_weight 0.5 \ +--curriculum_strict_metadata true \ +--curriculum_min_quantile 0.0 \ +--curriculum_max_quantile 1.0 \ --curriculum_adaptive true ``` @@ -113,10 +141,12 @@ Tuning tips: concentrate probability on fewer neighboring buckets. - Lower `curriculum_warmup_target` if you want the post-bootstrap warmup to stay easier for longer. +- Raise `curriculum_min_quantile` or lower `curriculum_max_quantile` if you + want to focus on a narrower difficulty band. ## Metrics -The most useful curriculum metrics are: +Some useful curriculum metrics are: - `curriculum/progress` - `curriculum/static_bucket_prob_*` @@ -126,5 +156,6 @@ The most useful curriculum metrics are: - `curriculum/bucket_reward_mean_*` - `curriculum/bucket_abs_advantage_mean_*` -See `scripts/train/qwen/qwen3_4b_dapo_math_difficulty_curriculum.sh` for a -concrete launch example. 
+See [qwen3_4b_dapo_math_dc.sh](../../train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc.sh) +for a concrete launch example, and the rest of +`scripts/train/qwen/difficult-curriculum/` for variants. diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong.sh new file mode 100755 index 0000000000..0dfbe9bef9 --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_strong}" +export CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-true}" +# Pull harder toward live curriculum stats than the light variant. +export CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-10}" +export CURRICULUM_ADAPTIVE_LEARNING_WEIGHT="${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT:-0.9}" +export CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT="${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT:-0.1}" +export CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.75}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh index 8579af7b9c..61a43cbd3b 100755 --- a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh @@ -6,5 +6,8 @@ SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_hardest50}" export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# 
bucket instead of inheriting the base global target near bucket 0. +export CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.5}" exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" From 45463cdef0088e58bb93be359392a99f8f1aa9d6 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Tue, 12 May 2026 10:10:52 -0700 Subject: [PATCH 39/40] Adds adaptive + hardest 50th variants --- ...wen3_4b_dapo_math_dc_adaptive_light_hardest50.sh | 13 +++++++++++++ ...en3_4b_dapo_math_dc_adaptive_strong_hardest50.sh | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100755 scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh create mode 100755 scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh new file mode 100755 index 0000000000..7467e967e7 --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_light_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# bucket instead of inheriting the base global target near bucket 0. 
+export CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.5}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc_adaptive_light.sh" "$@" diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh new file mode 100755 index 0000000000..80da2b44ed --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_strong_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# bucket instead of inheriting the base global target near bucket 0. 
+export CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.5}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc_adaptive_strong.sh" "$@" From ae2354c8dacd7445ff13704cec1d95481c232a40 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Tue, 12 May 2026 10:18:29 -0700 Subject: [PATCH 40/40] Char length in wandb --- .../qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh | 3 ++- .../qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh index 7467e967e7..cec1aff22c 100755 --- a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh @@ -3,7 +3,8 @@ set -euo pipefail SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_light_hardest50}" +# Keep under W&B's 64-character tag limit; EXP_NAME is used as a run tag. 
+export EXP_NAME="${EXP_NAME:-qwen3_4b_dapo_dc_adaptive_light_hardest50}" export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" # After filtering out the easy half, start bootstrap at the easiest remaining diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh index 80da2b44ed..2dfe1abdf1 100755 --- a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh @@ -3,7 +3,8 @@ set -euo pipefail SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_strong_hardest50}" +# Keep under W&B's 64-character tag limit; EXP_NAME is used as a run tag. +export EXP_NAME="${EXP_NAME:-qwen3_4b_dapo_dc_adaptive_strong_hardest50}" export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" # After filtering out the easy half, start bootstrap at the easiest remaining