diff --git a/CHANGELOG.md b/CHANGELOG.md index 86a5e479c5..746798bb9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Add optional difficulty-curriculum prompt sampling to `grpo_fast.py`, plus a `DifficultyCurriculumHFDataLoader`, a `scripts/data/difficulty_sampling/create_difficulty_map.py` builder, matching docs/tests, and a Qwen launch script (https://github.com/allenai/open-instruct/pull/1661). - Add parameterized `combine_dataset` tests in `open_instruct/test_utils.py` against local jsonl fixtures (no network), covering varied fractional/sample-count weight combinations and split-count mismatch (would have caught the bug fixed in #1674). Extract the interleaved-list→dict parsing into a shared `utils.parse_dataset_mixer_list` helper (with its own parameterized unit tests) and tighten `combine_dataset` / `get_datasets` to accept dict-only `dataset_mixer`; the one external list-form caller (`rejection_sampling/generation.py`) now converts at the call site. - Make `mason.py` `--output_dir` / `--checkpoint_state_dir` overrides idempotent via `replace_or_append_flag`, add `open_instruct/grpo.py` to `OPEN_INSTRUCT_COMMANDS` / `OPEN_INSTRUCT_RESUMABLES`, and wire OLMo-core checkpoint save/resume into `grpo.py` (`CheckpointerCallback` + `DataPreparationActorCheckpointCallback` + `LoadStrategy.if_available`) so resumable Beaker jobs actually resume (https://github.com/allenai/open-instruct/pull/1666). - Make `--budget` optional in `mason.py` (falls back to the workspace's default budget) and drop the explicit `--budget` flag from launch scripts where it already matched the workspace default (https://github.com/allenai/open-instruct/pull/1673). 
diff --git a/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json new file mode 100644 index 0000000000..7ec86a1beb --- /dev/null +++ b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json @@ -0,0 +1,45 @@ +{ + "difficulty_generation": { + "beta_prior_requested": "empirical-bayes", + "beta_prior_used": { + "alpha": 0.4383101221875231, + "beta": 2.1002075643639166, + "source": "empirical_bayes" + }, + "binary_instance_count": 12643, + "bucket_count_effective": 5, + "bucket_count_field": "difficulty.bucket_count", + "bucket_count_requested": 5, + "bucket_field": "difficulty.bucket_index", + "bucket_ranking_field": "difficulty.expected_quantile", + "difficulty_value_definition": "1 - difficulty.posterior_lower_bound", + "difficulty_value_field": "difficulty.value", + "method": "beta_binomial_posterior_quantiles", + "nonbinary_instance_count": 0, + "posterior_lower_quantile": 0.1, + "tag": "bbq-eb-q10-k5" + }, + "model_name": "Qwen/Qwen3-4B-Base", + "row_count": 12643, + "score_processing": { + "normalization": "identity_binary", + "output_field": "attempt_scores", + "positive_reward_value": 1.0, + "source_field": "pass_count,num_samples,pass_rate", + "supports_binary_difficulty": true + }, + "source_format": { + "attempt_count_field": "num_samples", + "config_name": null, + "dataset_repo_id": "mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples", + "instance_id_definition": "dataset_repo_id::row_id_field when a stable row id is available; otherwise dataset_repo_id::row_index", + "kind": "hugging_face_dataset_passrate_rows", + "model_field": "generator_model", + "pass_count_field": "pass_count", + "pass_rate_field": "pass_rate", + "row_id_field": "extra_info.index", + "split": "train", + "task_field": "dataset" + }, + "task_name": "math" +} \ No newline at end of file diff --git 
a/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json new file mode 100644 index 0000000000..1068399271 --- /dev/null +++ b/configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json @@ -0,0 +1,137 @@ +{ + "ability": { + "_type": "Value", + "dtype": "string" + }, + "completions": { + "_type": "List", + "feature": { + "_type": "Value", + "dtype": "string" + } + }, + "data_source": { + "_type": "Value", + "dtype": "string" + }, + "dataset": { + "_type": "Value", + "dtype": "string" + }, + "difficulty": { + "bucket_count": { + "_type": "Value", + "dtype": "int64" + }, + "bucket_index": { + "_type": "Value", + "dtype": "int64" + }, + "expected_quantile": { + "_type": "Value", + "dtype": "float64" + }, + "posterior_lower_bound": { + "_type": "Value", + "dtype": "float64" + }, + "posterior_mean": { + "_type": "Value", + "dtype": "float64" + }, + "value": { + "_type": "Value", + "dtype": "float64" + } + }, + "extra_info": { + "index": { + "_type": "Value", + "dtype": "string" + } + }, + "generator_chat_template": { + "_type": "Value", + "dtype": "string" + }, + "generator_max_tokens": { + "_type": "Value", + "dtype": "int64" + }, + "generator_model": { + "_type": "Value", + "dtype": "string" + }, + "generator_temperature": { + "_type": "Value", + "dtype": "float64" + }, + "generator_top_p": { + "_type": "Value", + "dtype": "float64" + }, + "ground_truth": { + "_type": "Value", + "dtype": "string" + }, + "messages": { + "_type": "List", + "feature": { + "content": { + "_type": "Value", + "dtype": "string" + }, + "role": { + "_type": "Value", + "dtype": "string" + } + } + }, + "num_samples": { + "_type": "Value", + "dtype": "int64" + }, + "pass_count": { + "_type": "Value", + "dtype": "int64" + }, + "pass_rate": { + "_type": "Value", + "dtype": "string" + }, + "prompt": { + "_type": "Value", + "dtype": "string" + }, 
+ "reward_model": { + "ground_truth": { + "_type": "Value", + "dtype": "string" + }, + "style": { + "_type": "Value", + "dtype": "string" + } + }, + "solution": { + "_type": "Value", + "dtype": "string" + }, + "source_prompt": { + "_type": "List", + "feature": { + "content": { + "_type": "Value", + "dtype": "string" + }, + "role": { + "_type": "Value", + "dtype": "string" + } + } + }, + "source_split": { + "_type": "Value", + "dtype": "string" + } +} \ No newline at end of file diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 87131a7faa..56e67dbc69 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import threading import time @@ -21,19 +20,18 @@ from dataclasses import asdict, dataclass, field from pathlib import Path from queue import Empty -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import ray import torch -import vllm from datasets import Dataset from olmo_core.data import data_loader from ray.util import queue as ray_queue from tqdm import tqdm from transformers import PreTrainedTokenizer -from open_instruct import data_types, padding_free_collator, utils +from open_instruct import data_types, difficulty_curriculum, logger_utils, padding_free_collator, utils from open_instruct.data_types import EnvConfig, EnvConfigEntry from open_instruct.dataset_transformation import ( ENV_CONFIG_KEY, @@ -49,7 +47,10 @@ from open_instruct.rubrics import RubricManager from open_instruct.utils import combine_reward_metrics -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + import vllm + +logger = logger_utils.setup_logger(__name__) DATA_PREP_ACTOR_NAME = "data_prep_singleton" @@ -662,6 +663,134 @@ def single_example_collator(examples: list[dict[str, Any]]) -> dict[str, Any]: return example | {"index": 
torch.tensor([example["index"]])} +class DifficultyCurriculumHFDataLoader(HFDataLoader): + """Prompt loader that samples with a difficulty-aware curriculum.""" + + def __init__( + self, + dataset: Dataset, + batch_size: int, + seed: int, + dp_rank: int, + dp_world_size: int, + work_dir: str, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig, + automatic_reshuffle: bool = True, + collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None, + device: torch.device | None = None, + drop_last: bool = True, + fs_local_rank: int | None = None, + max_seq_length: int = 1, + ) -> None: + if batch_size != 1: + raise ValueError("DifficultyCurriculumHFDataLoader currently supports batch_size=1 only") + if dp_world_size != 1 or dp_rank != 0: + raise ValueError("DifficultyCurriculumHFDataLoader currently supports dp_world_size=1 only") + + self._sampling_step = 0 + self._curriculum_sampler = difficulty_curriculum.DifficultyCurriculumSampler( + dataset=dataset, + num_samples=max(len(dataset), 1), + config=curriculum_config, + global_step_getter=lambda: self._sampling_step, + ) + self._curriculum_iter = None + + super().__init__( + dataset=dataset, + batch_size=batch_size, + seed=seed, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + work_dir=work_dir, + automatic_reshuffle=automatic_reshuffle, + collator=collator, + device=device, + drop_last=drop_last, + fs_local_rank=fs_local_rank, + max_seq_length=max_seq_length, + ) + + @property + def curriculum_sampler(self) -> difficulty_curriculum.DifficultyCurriculumSampler: + return self._curriculum_sampler + + def set_sampling_step(self, step: int) -> None: + self._sampling_step = int(step) + + def state_dict(self) -> dict[str, Any]: + return { + "epoch": self._epoch, + "batches_processed": self.batches_processed, + "sampling_step": self._sampling_step, + "curriculum_sampler_state": self._curriculum_sampler.state_dict(), + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + 
self._epoch = state_dict["epoch"] + self.batches_processed = state_dict["batches_processed"] + self._sampling_step = state_dict.get("sampling_step", 0) + self._curriculum_sampler.load_state_dict(state_dict["curriculum_sampler_state"]) + self.effective_size = max(len(self._full_dataset), 1) + self.dataset = self._full_dataset + self._curriculum_iter = iter(self._curriculum_sampler) + self._current_iter = None + + def exclude_index(self, index: int) -> None: + self._curriculum_sampler.exclude_index(index) + + def _reshard(self, epoch: int) -> None: + del epoch + self._precomputed_batch_sizes = None + self._num_padding_batches = 0 + self._overflow = [] + self.effective_size = max(len(self._full_dataset), 1) + self.dataset = self._full_dataset + self._curriculum_iter = iter(self._curriculum_sampler) + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + start_example = self.batches_processed + if self._curriculum_iter is None: + self._curriculum_iter = iter(self._curriculum_sampler) + + for batch_offset in range(start_example, self.effective_size): + dataset_index = next(self._curriculum_iter) + example = self._full_dataset[dataset_index] + prompt_id = f"{self._epoch}_{batch_offset}_{dataset_index}" + batch = to_device(self._collator([example | {"prompt_id": prompt_id}]), self._device) + yield batch + + +def create_prompt_dataloader( + dataset: Dataset, + seed: int, + work_dir: str, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None = None, +) -> HFDataLoader: + if curriculum_config is None: + return HFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + automatic_reshuffle=True, + collator=single_example_collator, + ) + return DifficultyCurriculumHFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + automatic_reshuffle=True, + collator=single_example_collator, + curriculum_config=curriculum_config, + ) + + def 
_merge_env_config(base_env_config: EnvConfig, sample_env_config: dict[str, Any] | None) -> EnvConfig: """Merge base and sample env config into canonical payload. Sample env_config overrides any base env_configs with the same name. @@ -718,6 +847,7 @@ def add_prompt_to_generator( ground_truth_overrides: dict[int, Any] | None = None, ) -> None: index = int(example["index"]) + prompt_id = example.get("prompt_id", f"{epoch_number}_{index}") sample_env_config = example.get(ENV_CONFIG_KEY) env_config = _merge_env_config(base_env_config, sample_env_config) @@ -729,7 +859,7 @@ def add_prompt_to_generator( prompt=example[INPUT_IDS_PROMPT_KEY], generation_config=generation_config, index=index, - prompt_id=f"{epoch_number}_{index}", + prompt_id=prompt_id, is_eval=is_eval, active_tools=example.get(TOOLS_COLUMN_KEY), env_config=env_config, @@ -757,7 +887,7 @@ class Group: def process_group( result: data_types.GenerationResult, - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", tokenizer: PreTrainedTokenizer, dataset: Dataset, max_possible_score: float, @@ -831,7 +961,7 @@ def process_group( def make_batch_from_groups( groups: list[Group], - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", training_step: int, actor_manager=None, filtered_prompts: int = 0, @@ -987,7 +1117,7 @@ def make_batch_from_groups( def accumulate_inference_batches( inference_results_Q: ray_queue.Queue, - generation_config: vllm.SamplingParams, + generation_config: "vllm.SamplingParams", num_prompts: int, model_dims: utils.ModelDims, tokenizer: PreTrainedTokenizer, @@ -1093,7 +1223,7 @@ def accumulate_inference_batches( if no_resampling_pass_rate is not None and group.percent_solved >= no_resampling_pass_rate: total_no_resampled += 1 - logging.debug( + logger.debug( f"[Data Preparation Thread] Prompt solved at {group.percent_solved}, " f"will be excluded from resampling, total no resampled: {total_no_resampled}" ) @@ -1103,7 +1233,7 @@ def 
accumulate_inference_batches( groups.append(group) if len(groups) == 0: - logging.warning( + logger.warning( "[Data Preparation Thread] All prompts were filtered during accumulation. " f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" @@ -1123,22 +1253,25 @@ def accumulate_inference_batches( ) -def maybe_mask_truncated_completions(result: data_types.GenerationResult, batch: Batch, enabled: bool) -> Batch: - """If enabled, drop rollouts that didn't finish with 'stop' from result (in place) and batch.""" +def maybe_mask_truncated_completions( + result: data_types.GenerationResult, batch: Batch, scores: np.ndarray, advantages: np.ndarray, enabled: bool +) -> tuple[Batch, np.ndarray, np.ndarray]: + """If enabled, drop rollouts that didn't finish with 'stop' from result, batch, scores, and advantages.""" if not enabled: - return batch + return batch, scores, advantages stop_idxes = [i for i, fr in enumerate(result.finish_reasons) if fr == "stop"] num_truncated = len(result.finish_reasons) - len(stop_idxes) if num_truncated > 0: + retention_rate = len(stop_idxes) / len(result.finish_reasons) if result.finish_reasons else 0.0 logger.info( f"[DataPreparationActor] Filtered {num_truncated} responses that didn't finish with 'stop'. 
" - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + f"Retention rate: {retention_rate:.2%}" ) result.responses = [result.responses[i] for i in stop_idxes] result.masks = [result.masks[i] for i in stop_idxes] result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] result.logprobs = [result.logprobs[i] for i in stop_idxes] - return batch[stop_idxes] + return batch[stop_idxes], scores[stop_idxes], advantages[stop_idxes] def prepare_collated_data_for_workers( @@ -1260,6 +1393,7 @@ def __init__( model_name: str | None, base_env_config: EnvConfig, initial_state: dict | None = None, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None = None, ): self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q @@ -1280,15 +1414,11 @@ def __init__( self.model_name = model_name self.base_env_config = base_env_config - self.iter_dataloader = HFDataLoader( - dataset=dataset, - batch_size=1, - seed=seed, - dp_rank=0, - dp_world_size=1, - work_dir=work_dir, - automatic_reshuffle=True, - collator=single_example_collator, + self.iter_dataloader = create_prompt_dataloader( + dataset=dataset, seed=seed, work_dir=work_dir, curriculum_config=curriculum_config + ) + self.curriculum_dataloader = ( + self.iter_dataloader if isinstance(self.iter_dataloader, DifficultyCurriculumHFDataLoader) else None ) self.prepared_data: dict[int, list[data_types.CollatedBatchData]] = {} @@ -1328,6 +1458,8 @@ def _data_preparation_loop(self): num_initial_prompts = self.config.async_steps * self.global_batch_size logger.info(f"[DataPreparationActor] Pushing {num_initial_prompts} initial prompts to param_prompt_Q") + if self.curriculum_dataloader is not None: + self.curriculum_dataloader.set_sampling_step(self.training_step) for _ in range(num_initial_prompts): add_prompt_to_generator( next(self.iter_dataloader), @@ -1347,6 +1479,8 @@ def _data_preparation_loop(self): ) time.sleep(0.1) generation_idle_wait_time = 
time.perf_counter() - generation_idle_wait_start_time + if self.curriculum_dataloader is not None: + self.curriculum_dataloader.set_sampling_step(self.training_step) logger.info( f"[DataPreparationActor] Step {self.training_step}: calling accumulate_inference_batches for {self.global_batch_size} prompts" @@ -1387,25 +1521,11 @@ def _data_preparation_loop(self): assert batch is not None assert batch_stats is not None - - batch = maybe_mask_truncated_completions(result, batch, self.config.mask_truncated_completions) - - if len(result.responses) == 0: - logger.warning( - f"[DataPreparationActor] 🤡 Step {self.training_step}: no trainable responses after truncation filter; " - "resampling without advancing step counter" - ) - continue - - if self.rubric_manager and batch.decoded_responses: - rubric_metrics, new_overrides = self.rubric_manager.run_step( - decoded_responses=batch.decoded_responses, - ground_truths=batch.ground_truths, - indices=batch.indices, - step=self.training_step, - ) - reward_metrics.update(rubric_metrics) - self.ground_truth_overrides.update(new_overrides) + prompt_dataset_indices: list[int] = [] + if self.curriculum_dataloader is not None and batch.indices is not None: + prompt_dataset_indices = [ + int(index) for index in batch.indices[:: self.config.num_samples_per_prompt_rollout] + ] scores = np.array(batch.scores) scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) @@ -1434,6 +1554,32 @@ def _data_preparation_loop(self): ) self.total_samples_written += len(batch.queries) + batch, scores, advantages = maybe_mask_truncated_completions( + result, batch, scores, advantages, self.config.mask_truncated_completions + ) + + if len(result.responses) == 0: + logger.warning( + f"[DataPreparationActor] 🤡 Step {self.training_step}: no trainable responses after truncation filter; " + "resampling without advancing step counter" + ) + continue + + if self.rubric_manager and batch.decoded_responses: + rubric_metrics, 
new_overrides = self.rubric_manager.run_step( + decoded_responses=batch.decoded_responses, + ground_truths=batch.ground_truths, + indices=batch.indices, + step=self.training_step, + ) + reward_metrics.update(rubric_metrics) + self.ground_truth_overrides.update(new_overrides) + + if self.curriculum_dataloader is not None and batch.indices is not None: + normalized_scores = np.clip(scores / max(self.config.max_possible_score, 1e-8), 0.0, 1.0) + self.curriculum_dataloader.curriculum_sampler.record_observations( + batch.indices, normalized_scores, advantages + ) packed_sequences = pack_sequences( queries=batch.queries, responses=result.responses, @@ -1520,6 +1666,13 @@ def _data_preparation_loop(self): step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time step_metrics["time/getting_response"] = result.token_statistics.generation_time + if self.curriculum_dataloader is not None: + step_metrics.update( + self.curriculum_dataloader.curriculum_sampler.build_metrics( + prompt_dataset_indices, self.training_step + ) + ) + with self.lock: self.prepared_data[self.training_step] = collated_data self.metrics[self.training_step] = step_metrics diff --git a/open_instruct/difficulty_curriculum.py b/open_instruct/difficulty_curriculum.py new file mode 100644 index 0000000000..0b1bc4ad6c --- /dev/null +++ b/open_instruct/difficulty_curriculum.py @@ -0,0 +1,900 @@ +"""Difficulty-aware curriculum sampling for RLVR / GRPO prompt selection.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Literal + +import numpy as np +import torch +from torch.utils.data import Sampler + +from open_instruct import logger_utils + +logger = logger_utils.setup_logger(__name__) + +_DEFAULT_POSTERIOR_MEAN = 0.5 + + +def _resolve_path(value: Any, path: str) -> Any: + current = value + for part in path.split("."): + if not isinstance(current, dict) or 
part not in current: + return None + current = current[part] + return current + + +def _coerce_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + return None + + +def _coerce_float(value: Any) -> float | None: + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + value = float(value) + if math.isnan(value) or math.isinf(value): + return None + return value + return None + + +def _normalize_probs(values: np.ndarray) -> np.ndarray: + total = float(values.sum()) + if total <= 0: + return np.zeros_like(values, dtype=np.float64) + return values.astype(np.float64) / total + + +@dataclass +class DifficultyCurriculumMetadataConfig: + field: str = "difficulty" + posterior_mean_field: str = "posterior_mean" + bucket_index_field: str = "bucket_index" + bucket_count_field: str = "bucket_count" + quantile_field: str = "expected_quantile" + strict: bool = True + + +@dataclass +class DifficultyCurriculumScheduleConfig: + bootstrap_steps: int = 100 + warmup_steps: int = 500 + total_steps: int = 10000 + bootstrap_target: float = 0.125 + warmup_target: float = 0.5 + final_target: float = 1.0 + min_hard_frac: float = 0.05 + max_hard_frac: float = 0.50 + bucket_sigma: float = 0.0 + bootstrap_sigma: float = 0.0 + + def __post_init__(self) -> None: + if self.bootstrap_steps < 0: + raise ValueError("bootstrap_steps must be >= 0") + if self.warmup_steps < 0: + raise ValueError("warmup_steps must be >= 0") + if self.total_steps <= 0: + raise ValueError("total_steps must be > 0") + if not 0.0 <= self.bootstrap_target <= 1.0: + raise ValueError("bootstrap_target must be in [0, 1]") + if not 0.0 <= self.warmup_target <= 1.0: + raise ValueError("warmup_target must be in [0, 1]") + if not 0.0 <= self.final_target <= 1.0: + raise ValueError("final_target must be in [0, 1]") + if not 0.0 <= self.min_hard_frac <= 1.0: + raise 
ValueError("min_hard_frac must be in [0, 1]") + if not 0.0 <= self.max_hard_frac <= 1.0: + raise ValueError("max_hard_frac must be in [0, 1]") + if self.min_hard_frac > self.max_hard_frac: + raise ValueError("min_hard_frac must be <= max_hard_frac") + if self.bucket_sigma < 0: + raise ValueError("bucket_sigma must be >= 0") + if self.bootstrap_sigma < 0: + raise ValueError("bootstrap_sigma must be >= 0") + + +@dataclass +class DifficultyCurriculumAdaptiveConfig: + enabled: bool = False + update_every: int = 50 + learning_weight: float = 0.7 + exploration_weight: float = 0.3 + blend: float = 0.5 + + def __post_init__(self) -> None: + if self.update_every <= 0: + raise ValueError("update_every must be > 0") + if not 0.0 <= self.learning_weight <= 1.0: + raise ValueError("learning_weight must be in [0, 1]") + if not 0.0 <= self.exploration_weight <= 1.0: + raise ValueError("exploration_weight must be in [0, 1]") + if not 0.0 <= self.blend <= 1.0: + raise ValueError("blend must be in [0, 1]") + + +@dataclass +class DifficultyCurriculumConfig: + metadata: DifficultyCurriculumMetadataConfig = field(default_factory=DifficultyCurriculumMetadataConfig) + schedule: DifficultyCurriculumScheduleConfig = field(default_factory=DifficultyCurriculumScheduleConfig) + adaptive: DifficultyCurriculumAdaptiveConfig = field(default_factory=DifficultyCurriculumAdaptiveConfig) + uncertainty_weight: float = 0.5 + min_quantile: float = 0.0 + max_quantile: float = 1.0 + seed: int = 0 + epsilon: float = 1e-8 + + def __post_init__(self) -> None: + if not 0.0 <= self.uncertainty_weight <= 1.0: + raise ValueError("uncertainty_weight must be in [0, 1]") + if not 0.0 <= self.min_quantile <= 1.0: + raise ValueError("min_quantile must be in [0, 1]") + if not 0.0 <= self.max_quantile <= 1.0: + raise ValueError("max_quantile must be in [0, 1]") + if self.min_quantile > self.max_quantile: + raise ValueError("min_quantile must be <= max_quantile") + if self.epsilon <= 0: + raise ValueError("epsilon must 
be > 0") + + +@dataclass +class DifficultyCurriculumArgs: + curriculum: Literal["none", "difficulty"] = "none" + curriculum_metadata_field: str = "difficulty" + curriculum_posterior_mean_field: str = "posterior_mean" + curriculum_bucket_index_field: str = "bucket_index" + curriculum_bucket_count_field: str = "bucket_count" + curriculum_quantile_field: str = "expected_quantile" + curriculum_strict_metadata: bool = True + curriculum_bootstrap_steps: int = 100 + curriculum_warmup_steps: int = 500 + curriculum_total_steps: int = 10000 + curriculum_bootstrap_target: float = 0.125 + curriculum_warmup_target: float = 0.5 + curriculum_final_target: float = 1.0 + curriculum_min_hard_frac: float = 0.05 + curriculum_max_hard_frac: float = 0.50 + curriculum_bucket_sigma: float = 0.0 + curriculum_bootstrap_sigma: float = 0.0 + curriculum_uncertainty_weight: float = 0.5 + curriculum_adaptive: bool = False + curriculum_adaptive_update_every: int = 50 + curriculum_adaptive_learning_weight: float = 0.7 + curriculum_adaptive_exploration_weight: float = 0.3 + curriculum_adaptive_blend: float = 0.5 + curriculum_min_quantile: float = 0.0 + curriculum_max_quantile: float = 1.0 + + def build_metadata_config(self) -> DifficultyCurriculumMetadataConfig: + return DifficultyCurriculumMetadataConfig( + field=self.curriculum_metadata_field, + posterior_mean_field=self.curriculum_posterior_mean_field, + bucket_index_field=self.curriculum_bucket_index_field, + bucket_count_field=self.curriculum_bucket_count_field, + quantile_field=self.curriculum_quantile_field, + strict=self.curriculum_strict_metadata, + ) + + def build_schedule_config(self) -> DifficultyCurriculumScheduleConfig: + return DifficultyCurriculumScheduleConfig( + bootstrap_steps=self.curriculum_bootstrap_steps, + warmup_steps=self.curriculum_warmup_steps, + total_steps=self.curriculum_total_steps, + bootstrap_target=self.curriculum_bootstrap_target, + warmup_target=self.curriculum_warmup_target, + 
final_target=self.curriculum_final_target, + min_hard_frac=self.curriculum_min_hard_frac, + max_hard_frac=self.curriculum_max_hard_frac, + bucket_sigma=self.curriculum_bucket_sigma, + bootstrap_sigma=self.curriculum_bootstrap_sigma, + ) + + def build_adaptive_config(self) -> DifficultyCurriculumAdaptiveConfig: + return DifficultyCurriculumAdaptiveConfig( + enabled=self.curriculum_adaptive, + update_every=self.curriculum_adaptive_update_every, + learning_weight=self.curriculum_adaptive_learning_weight, + exploration_weight=self.curriculum_adaptive_exploration_weight, + blend=self.curriculum_adaptive_blend, + ) + + def verify(self) -> None: + if self.curriculum == "none": + return + if self.curriculum != "difficulty": + raise ValueError(f"Unsupported curriculum type: {self.curriculum}") + + self.build_metadata_config() + self.build_schedule_config() + self.build_adaptive_config() + + def build_curriculum_config(self, *, seed: int) -> DifficultyCurriculumConfig | None: + self.verify() + if self.curriculum == "none": + return None + + return DifficultyCurriculumConfig( + metadata=self.build_metadata_config(), + schedule=self.build_schedule_config(), + adaptive=self.build_adaptive_config(), + uncertainty_weight=self.curriculum_uncertainty_weight, + min_quantile=self.curriculum_min_quantile, + max_quantile=self.curriculum_max_quantile, + seed=seed, + ) + + +class AdaptiveBucketStats: + """Tracks per-bucket learning signal statistics for adaptive sampling.""" + + def __init__(self, learning_signal_weight: float, exploration_weight: float, epsilon: float) -> None: + self.learning_signal_weight = learning_signal_weight + self.exploration_weight = exploration_weight + self.epsilon = epsilon + + self.total_count = 0 + self._count_by_bucket: dict[int, int] = {} + self._reward_sum_by_bucket: dict[int, float] = {} + self._reward_sq_sum_by_bucket: dict[int, float] = {} + self._abs_advantage_sum_by_bucket: dict[int, float] = {} + self._advantage_count_by_bucket: dict[int, int] = 
{} + + def update( + self, + bucket_indices: list[int] | np.ndarray, + rewards: list[float] | np.ndarray, + advantages: list[float] | np.ndarray | None = None, + ) -> None: + if len(bucket_indices) != len(rewards): + raise ValueError("bucket_indices and rewards must have the same length") + if advantages is not None and len(advantages) != len(rewards): + raise ValueError("advantages and rewards must have the same length") + + reward_values = np.clip(np.asarray(rewards, dtype=np.float64), 0.0, 1.0) + advantage_values = None if advantages is None else np.abs(np.asarray(advantages, dtype=np.float64)) + + for position, bucket_index in enumerate(bucket_indices): + bucket = int(bucket_index) + reward = float(reward_values[position]) + + self.total_count += 1 + self._count_by_bucket[bucket] = self._count_by_bucket.get(bucket, 0) + 1 + self._reward_sum_by_bucket[bucket] = self._reward_sum_by_bucket.get(bucket, 0.0) + reward + self._reward_sq_sum_by_bucket[bucket] = self._reward_sq_sum_by_bucket.get(bucket, 0.0) + reward * reward + + if advantage_values is not None: + advantage = float(advantage_values[position]) + self._abs_advantage_sum_by_bucket[bucket] = ( + self._abs_advantage_sum_by_bucket.get(bucket, 0.0) + advantage + ) + self._advantage_count_by_bucket[bucket] = self._advantage_count_by_bucket.get(bucket, 0) + 1 + + def state_dict(self) -> dict[str, Any]: + return { + "total_count": self.total_count, + "count_by_bucket": self._count_by_bucket, + "reward_sum_by_bucket": self._reward_sum_by_bucket, + "reward_sq_sum_by_bucket": self._reward_sq_sum_by_bucket, + "abs_advantage_sum_by_bucket": self._abs_advantage_sum_by_bucket, + "advantage_count_by_bucket": self._advantage_count_by_bucket, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.total_count = int(state_dict.get("total_count", 0)) + self._count_by_bucket = {int(k): int(v) for k, v in state_dict.get("count_by_bucket", {}).items()} + self._reward_sum_by_bucket = {int(k): float(v) for 
k, v in state_dict.get("reward_sum_by_bucket", {}).items()} + self._reward_sq_sum_by_bucket = { + int(k): float(v) for k, v in state_dict.get("reward_sq_sum_by_bucket", {}).items() + } + self._abs_advantage_sum_by_bucket = { + int(k): float(v) for k, v in state_dict.get("abs_advantage_sum_by_bucket", {}).items() + } + self._advantage_count_by_bucket = { + int(k): int(v) for k, v in state_dict.get("advantage_count_by_bucket", {}).items() + } + + def get_count(self, bucket_index: int) -> int: + return self._count_by_bucket.get(bucket_index, 0) + + def get_mean_reward(self, bucket_index: int) -> float: + count = self.get_count(bucket_index) + if count == 0: + return 0.0 + return self._reward_sum_by_bucket.get(bucket_index, 0.0) / count + + def get_mean_abs_advantage(self, bucket_index: int) -> float: + count = self._advantage_count_by_bucket.get(bucket_index, 0) + if count == 0: + return 0.0 + return self._abs_advantage_sum_by_bucket.get(bucket_index, 0.0) / count + + def _get_reward_variance(self, bucket_index: int) -> float: + count = self.get_count(bucket_index) + if count == 0: + return 0.0 + mean_reward = self.get_mean_reward(bucket_index) + mean_reward_sq = self._reward_sq_sum_by_bucket.get(bucket_index, 0.0) / count + return max(0.0, mean_reward_sq - mean_reward * mean_reward) + + def get_bucket_scores(self, bucket_count: int) -> np.ndarray: + scores = np.zeros(bucket_count, dtype=np.float64) + total_count = max(self.total_count, 0) + + for bucket_index in range(bucket_count): + count = self.get_count(bucket_index) + mean_reward = self.get_mean_reward(bucket_index) + mean_abs_advantage = self.get_mean_abs_advantage(bucket_index) + + if self._advantage_count_by_bucket.get(bucket_index, 0) > 0: + learning_signal = mean_abs_advantage * max(0.0, 1.0 - mean_reward) + else: + reward_variance = self._get_reward_variance(bucket_index) + non_saturation = max(0.0, 1.0 - 2.0 * abs(mean_reward - 0.5)) + learning_signal = 0.5 * reward_variance + 0.5 * non_saturation + + 
class _DifficultyCurriculumSchedule:
    """Map a training step to sampling probabilities over difficulty buckets.

    The last bucket is treated as the "hard" bucket and receives an explicit
    probability mass; the remaining buckets share the rest through a Gaussian
    centred on a target bucket that moves as training progresses.
    """

    def __init__(self, config: DifficultyCurriculumScheduleConfig, bucket_count: int) -> None:
        self.config = config
        self.bucket_count = bucket_count
        # Ids of the non-hard buckets only (every bucket except the last one).
        self._bucket_ids = np.arange(max(self.bucket_count - 1, 0), dtype=np.float64)

    def get_progress(self, step: int) -> float:
        """Return linear ramp progress in ``[0, 1]`` for ``step``.

        Progress is 0.0 before ``warmup_steps`` and then rises linearly over
        ``total_steps``. A non-positive ``total_steps`` is treated as a ramp of
        zero length that is already complete, instead of dividing by zero or
        returning a negative value.
        """
        warmup_steps = self.config.warmup_steps
        if step < warmup_steps:
            return 0.0
        total_steps = self.config.total_steps
        if total_steps <= 0:
            return 1.0
        return min(1.0, (step - warmup_steps) / total_steps)

    def _smooth_progress(self, step: int) -> float:
        """Return the linear progress passed through the cubic ``3p^2 - 2p^3`` easing."""
        progress = self.get_progress(step)
        return progress * progress * (3.0 - 2.0 * progress)

    def _get_default_bucket_sigma(self) -> float:
        """Return the Gaussian width used when the config does not set a positive one."""
        return max(0.85, 0.25 * max(self.bucket_count - 1, 1))

    def _get_bucket_sigma(self, step: int) -> float:
        """Return the Gaussian width for ``step``.

        A positive ``bootstrap_sigma`` takes precedence while ``step`` is
        below ``bootstrap_steps``.
        """
        sigma = self.config.bucket_sigma if self.config.bucket_sigma > 0 else self._get_default_bucket_sigma()
        if step < self.config.bootstrap_steps and self.config.bootstrap_sigma > 0:
            return self.config.bootstrap_sigma
        return sigma

    def _bucket_ratio_to_bucket_index(self, bucket_ratio: float) -> float:
        """Scale a ratio to a fractional bucket position in ``[0, bucket_count - 1]``."""
        return float(self.bucket_count - 1) * bucket_ratio

    def _get_target_bucket(self, step: int) -> float:
        """Return the fractional bucket position the Gaussian is centred on at ``step``."""
        warmup_target_bucket = self._bucket_ratio_to_bucket_index(self.config.warmup_target)
        final_target_bucket = self._bucket_ratio_to_bucket_index(self.config.final_target)

        if self.config.bootstrap_steps > 0 and step < self.config.bootstrap_steps:
            # Bootstrap phase: move linearly from the bootstrap target to the warmup target.
            bootstrap_progress = min(1.0, step / self.config.bootstrap_steps)
            bootstrap_target_bucket = self._bucket_ratio_to_bucket_index(self.config.bootstrap_target)
            return bootstrap_target_bucket + (warmup_target_bucket - bootstrap_target_bucket) * bootstrap_progress

        smooth_progress = self._smooth_progress(step)
        return warmup_target_bucket + (final_target_bucket - warmup_target_bucket) * smooth_progress

    def build_probs(self, step: int, available_mask: np.ndarray) -> np.ndarray:
        """Return bucket sampling probabilities for ``step``.

        Args:
            step: Training step that drives the schedule.
            available_mask: Per-bucket multiplier; zero entries remove a bucket.

        Returns:
            An array of length ``bucket_count``. With a single bucket the
            result is always ``[1.0]``. If the mask is entirely zero the
            result is uniform over all buckets.
        """
        if self.bucket_count == 1:
            return np.ones(1, dtype=np.float64)

        smooth_progress = self._smooth_progress(step)
        target_bucket = self._get_target_bucket(step)
        hard_bucket_frac = (
            self.config.min_hard_frac + (self.config.max_hard_frac - self.config.min_hard_frac) * smooth_progress
        )

        sigma = self._get_bucket_sigma(step)
        gaussian_logits = np.exp(-0.5 * ((self._bucket_ids - target_bucket) / sigma) ** 2)
        non_hard_probs = _normalize_probs(gaussian_logits)

        static_probs = np.zeros(self.bucket_count, dtype=np.float64)
        static_probs[:-1] = (1.0 - hard_bucket_frac) * non_hard_probs
        static_probs[-1] = hard_bucket_frac

        static_probs *= available_mask
        if available_mask.sum() == 0:
            return np.ones(self.bucket_count, dtype=np.float64) / self.bucket_count
        if static_probs.sum() <= 0:
            # Scheduled mass fell entirely on masked buckets: spread it over what is left.
            return _normalize_probs(available_mask)
        return _normalize_probs(static_probs)
def _compute_example_weight(posterior_mean: float, uncertainty_weight: float, epsilon: float) -> float:
    """Return the within-bucket sampling weight of one example.

    ``posterior_mean`` is clipped to ``[0, 1]`` and read as a success
    probability ``p``. The weight blends an uncertainty term ``4p(1-p)`` with
    a hardness term ``1-p`` using ``uncertainty_weight``, then adds ``epsilon``.
    """
    p_success = float(np.clip(posterior_mean, 0.0, 1.0))
    p_failure = 1.0 - p_success
    uncertainty_term = uncertainty_weight * (4.0 * p_success * p_failure)
    hardness_term = (1.0 - uncertainty_weight) * p_failure
    return uncertainty_term + hardness_term + epsilon
int(np.clip(parsed.bucket_index, 0, bucket_count - 1)) + posterior_mean = parsed.posterior_mean + if filter_active: + expected_quantile = parsed.expected_quantile + if expected_quantile is None: + if metadata_config.strict: + raise ValueError(f"invalid {metadata_config.quantile_field} for dataset index {dataset_index}") + filtered_out_count += 1 + continue + if expected_quantile < min_quantile or expected_quantile > max_quantile: + filtered_out_count += 1 + continue + + if posterior_mean is None: + posterior_mean = _DEFAULT_POSTERIOR_MEAN + example_weight = _compute_example_weight( + posterior_mean=float(np.clip(posterior_mean, 0.0, 1.0)), + uncertainty_weight=uncertainty_weight, + epsilon=epsilon, + ) + + index_to_bucket[dataset_index] = bucket_index + index_to_bucket_position[dataset_index] = len(bucket_to_indices[bucket_index]) + bucket_to_indices[bucket_index].append(dataset_index) + bucket_weight_lists[bucket_index].append(example_weight) + + return _DifficultyBucketIndex( + index_to_bucket=index_to_bucket, + index_to_bucket_position=index_to_bucket_position, + bucket_to_indices=tuple(tuple(indices) for indices in bucket_to_indices), + bucket_weights=tuple(torch.tensor(weight_list, dtype=torch.float64) for weight_list in bucket_weight_lists), + bucket_count=bucket_count, + metadata_fallback_count=metadata_fallback_count, + filtered_out_count=filtered_out_count, + ) + + +class DifficultyCurriculumSampler(Sampler[int]): + """Bucket-aware curriculum sampler for difficulty-annotated prompt datasets.""" + + def __init__( + self, + dataset, + num_samples: int, + config: DifficultyCurriculumConfig, + global_step_getter: Callable[[], int] | None, + ) -> None: + if num_samples <= 0: + raise ValueError("num_samples must be > 0") + + self.dataset = dataset + self.num_samples = num_samples + self.config = config + self.global_step_getter = global_step_getter + + self._generator = torch.Generator() + self._generator.manual_seed(self.config.seed) + + bucket_index = 
_build_difficulty_bucket_index( + dataset=dataset, + metadata_config=self.config.metadata, + uncertainty_weight=self.config.uncertainty_weight, + epsilon=self.config.epsilon, + min_quantile=self.config.min_quantile, + max_quantile=self.config.max_quantile, + ) + self._index_to_bucket = dict(bucket_index.index_to_bucket) + self._index_to_bucket_position = dict(bucket_index.index_to_bucket_position) + self.bucket_count = bucket_index.bucket_count + self.metadata_fallback_count = bucket_index.metadata_fallback_count + self.filtered_out_count = bucket_index.filtered_out_count + self._schedule = _DifficultyCurriculumSchedule(self.config.schedule, self.bucket_count) + + self._excluded_indices: set[int] = set() + self._base_bucket_indices = [list(indices) for indices in bucket_index.bucket_to_indices] + self._base_bucket_weights = [weights.clone() for weights in bucket_index.bucket_weights] + self._active_bucket_weights = [weights.clone() for weights in self._base_bucket_weights] + self._active_bucket_counts = [len(indices) for indices in self._base_bucket_indices] + self._active_bucket_weight_sums = [float(weights.sum().item()) for weights in self._active_bucket_weights] + + self.adaptive_stats = None + if self.config.adaptive.enabled: + self.adaptive_stats = AdaptiveBucketStats( + learning_signal_weight=self.config.adaptive.learning_weight, + exploration_weight=self.config.adaptive.exploration_weight, + epsilon=self.config.epsilon, + ) + + self._cached_static_probs: np.ndarray | None = None + self._last_static_refresh_step = -1 + self._cached_adaptive_probs: np.ndarray | None = None + self._last_adaptive_refresh_step = -1 + + if self.metadata_fallback_count > 0 and not self.config.metadata.strict: + logger.warning( + "Difficulty curriculum fell back to conservative defaults for %s/%s rows because metadata was missing " + "or invalid.", + self.metadata_fallback_count, + len(self.dataset), + ) + if self.filtered_out_count > 0: + logger.info( + "Difficulty curriculum 
def __len__(self) -> int:
    """Return the number of indices one pass over the sampler yields."""
    return self.num_samples


@property
def bucket_to_indices(self) -> tuple[tuple[int, ...], ...]:
    """Immutable view of the dataset indices assigned to each bucket."""
    return tuple(map(tuple, self._base_bucket_indices))


def _get_current_step(self) -> int:
    """Return the current global step, clamped at zero (0 when no getter is set)."""
    raw_step = self.global_step_getter() if self.global_step_getter is not None else 0
    return max(int(raw_step), 0)


def get_progress(self, step: int | None = None) -> float:
    """Return schedule progress for ``step`` (defaults to the current step)."""
    resolved_step = self._get_current_step() if step is None else step
    return self._schedule.get_progress(resolved_step)


def _available_bucket_mask(self) -> np.ndarray:
    """Return a float mask with 1.0 for buckets that still have active examples."""
    flags = []
    for remaining in self._active_bucket_counts:
        flags.append(1.0 if remaining > 0 else 0.0)
    return np.array(flags, dtype=np.float64)


def _invalidate_static_bucket_probs(self) -> None:
    """Drop the cached schedule-driven probabilities."""
    self._cached_static_probs, self._last_static_refresh_step = None, -1


def _invalidate_adaptive_bucket_probs(self) -> None:
    """Drop the cached adaptive probabilities."""
    self._cached_adaptive_probs, self._last_adaptive_refresh_step = None, -1


def _invalidate_bucket_prob_caches(self) -> None:
    """Drop both probability caches."""
    self._invalidate_static_bucket_probs()
    self._invalidate_adaptive_bucket_probs()


def get_static_bucket_probs(self, step: int | None = None) -> np.ndarray:
    """Return a copy of the schedule-driven bucket probabilities for ``step``.

    The result is cached per step; callers always receive a copy so they can
    mutate it freely.
    """
    resolved_step = self._get_current_step() if step is None else step
    cache_is_fresh = self._cached_static_probs is not None and resolved_step == self._last_static_refresh_step
    if not cache_is_fresh:
        self._cached_static_probs = self._schedule.build_probs(resolved_step, self._available_bucket_mask())
        self._last_static_refresh_step = resolved_step
    return self._cached_static_probs.copy()
def get_bucket_probs(self, step: int | None = None) -> np.ndarray:
    """Return the final bucket probabilities used for sampling at ``step``.

    When adaptive probabilities are available they are blended with the
    scheduled ones using ``config.adaptive.blend``, masked to buckets that
    still have examples, and renormalised. Otherwise, or when the blend has
    no positive mass left, the scheduled probabilities are returned as is.
    """
    scheduled = self.get_static_bucket_probs(step)
    learned = self.get_adaptive_bucket_probs(step)
    if learned is None:
        return scheduled

    blend = self.config.adaptive.blend
    mixed = (1.0 - blend) * scheduled + blend * learned
    mixed *= self._available_bucket_mask()
    return scheduled if mixed.sum() <= 0 else _normalize_probs(mixed)


def bucket_for_dataset_index(self, dataset_index: int) -> int:
    """Return the bucket of ``dataset_index``; raises ``KeyError`` if it is not indexed."""
    key = int(dataset_index)
    return self._index_to_bucket[key]
def sample_index(self, step: int | None = None) -> int:
    """Draw one dataset index: first a bucket, then an example inside it.

    Raises:
        RuntimeError: If every example has been excluded, or the drawn bucket
            has no active examples.
    """
    if self._available_bucket_mask().sum() == 0:
        raise RuntimeError("All dataset examples have been excluded. Cannot continue iteration.")

    probs_tensor = torch.tensor(self.get_bucket_probs(step), dtype=torch.float64)
    chosen_bucket = int(torch.multinomial(probs_tensor, 1, generator=self._generator).item())

    weights_in_bucket = self._active_bucket_weights[chosen_bucket]
    bucket_is_empty = (
        self._active_bucket_counts[chosen_bucket] == 0 or self._active_bucket_weight_sums[chosen_bucket] <= 0
    )
    if bucket_is_empty:
        raise RuntimeError("attempted to sample from an empty curriculum bucket")

    # Excluded examples carry weight 0.0, so the second draw cannot select them.
    offset = int(torch.multinomial(weights_in_bucket, 1, generator=self._generator).item())
    return self._base_bucket_indices[chosen_bucket][offset]


def __iter__(self):
    """Yield ``num_samples`` freshly sampled dataset indices."""
    yield from (self.sample_index() for _ in range(self.num_samples))


def exclude_index(self, dataset_index: int) -> None:
    """Permanently remove ``dataset_index`` from sampling.

    Indices that are unknown to the bucket map or already excluded are
    ignored. A successful exclusion zeroes the example's weight, updates the
    bucket's active count and weight sum, and invalidates both probability
    caches.
    """
    idx = int(dataset_index)
    if idx in self._excluded_indices:
        return

    owning_bucket = self._index_to_bucket.get(idx)
    if owning_bucket is None:
        return

    slot = self._index_to_bucket_position.get(idx)
    live_weight = None
    if slot is not None:
        bucket_weights = self._active_bucket_weights[owning_bucket]
        live_weight = float(bucket_weights[slot].item())

    if live_weight is None or live_weight <= 0:
        # Nothing to subtract: just remember the exclusion.
        self._excluded_indices.add(idx)
        return

    bucket_weights[slot] = 0.0
    self._active_bucket_counts[owning_bucket] -= 1
    remaining_mass = self._active_bucket_weight_sums[owning_bucket] - live_weight
    self._active_bucket_weight_sums[owning_bucket] = max(0.0, remaining_mass)
    self._excluded_indices.add(idx)
    self._invalidate_bucket_prob_caches()
def build_metrics(self, prompt_dataset_indices: list[int], step: int | None = None) -> dict[str, float]:
    """Assemble the ``curriculum/*`` logging metrics for one batch.

    Args:
        prompt_dataset_indices: Dataset indices of the prompts in the batch.
        step: Step used for the probability look-ups (defaults to the current step).

    Returns:
        A flat mapping with the schedule progress, the static, adaptive and
        final per-bucket probabilities, per-bucket sampled counts, and (when
        adaptive statistics are enabled) per-bucket reward and
        absolute-advantage means.
    """
    metrics: dict[str, float] = {"curriculum/progress": self.get_progress(step)}

    scheduled = self.get_static_bucket_probs(step)
    metrics.update(
        {f"curriculum/static_bucket_prob_{bucket_index}": float(p) for bucket_index, p in enumerate(scheduled)}
    )

    learned = self.get_adaptive_bucket_probs(step)
    adaptive_on = self.config.adaptive.enabled
    if adaptive_on:
        # Report zeros while the adaptive component has nothing to say yet.
        shown = np.zeros(self.bucket_count, dtype=np.float64) if learned is None else learned
        metrics.update(
            {f"curriculum/adaptive_bucket_prob_{bucket_index}": float(p) for bucket_index, p in enumerate(shown)}
        )

    blended = self.get_bucket_probs(step)
    metrics.update({f"curriculum/bucket_prob_{bucket_index}": float(p) for bucket_index, p in enumerate(blended)})

    tally = np.zeros(self.bucket_count, dtype=np.float64)
    for raw_index in prompt_dataset_indices:
        owning_bucket = self.bucket_for_dataset_index(int(raw_index))
        tally[owning_bucket] += 1.0
    metrics.update(
        {f"curriculum/sampled_bucket_count_{bucket_index}": float(n) for bucket_index, n in enumerate(tally)}
    )

    stats = self.adaptive_stats
    if self.config.adaptive.enabled and stats is not None:
        for bucket_index in range(self.bucket_count):
            reward_mean = stats.get_mean_reward(bucket_index)
            metrics[f"curriculum/bucket_reward_mean_{bucket_index}"] = float(reward_mean)
            advantage_mean = stats.get_mean_abs_advantage(bucket_index)
            metrics[f"curriculum/bucket_abs_advantage_mean_{bucket_index}"] = float(advantage_mean)

    return metrics
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore sampler state previously produced by ``state_dict``.

    Restores the RNG state (when present), rebuilds the active weights from
    the base weights and re-applies every stored exclusion, reloads the
    adaptive statistics (when both sides have them), clears the static
    probability cache, and restores the adaptive probability cache.
    """
    rng_state = state_dict.get("generator_state")
    if rng_state is not None:
        self._generator.set_state(rng_state)

    self._excluded_indices = {int(index) for index in state_dict.get("excluded_indices", [])}

    # Start from pristine copies so exclusions are replayed exactly once.
    self._active_bucket_weights = [base.clone() for base in self._base_bucket_weights]
    self._active_bucket_counts = list(map(len, self._base_bucket_indices))
    self._active_bucket_weight_sums = [float(active.sum().item()) for active in self._active_bucket_weights]

    for excluded in self._excluded_indices:
        owning_bucket = self._index_to_bucket.get(excluded)
        slot = self._index_to_bucket_position.get(excluded)
        if owning_bucket is not None and slot is not None:
            live_weight = float(self._active_bucket_weights[owning_bucket][slot].item())
            if not live_weight <= 0:
                self._active_bucket_weights[owning_bucket][slot] = 0.0
                self._active_bucket_counts[owning_bucket] -= 1
                remaining_mass = self._active_bucket_weight_sums[owning_bucket] - live_weight
                self._active_bucket_weight_sums[owning_bucket] = max(0.0, remaining_mass)

    if self.adaptive_stats is not None and state_dict.get("adaptive_stats") is not None:
        self.adaptive_stats.load_state_dict(state_dict["adaptive_stats"])

    self._invalidate_static_bucket_probs()
    self._last_adaptive_refresh_step = int(state_dict.get("last_adaptive_refresh_step", -1))
    stored_probs = state_dict.get("cached_adaptive_probs")
    self._cached_adaptive_probs = np.array(stored_probs) if stored_probs is not None else None
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index f9cb2f7f82..1a356f66e3 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -40,7 +40,7 @@ from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF from deepspeed.utils import groups -from open_instruct import data_loader as data_loader_lib +from open_instruct import data_loader as data_loader_lib, difficulty_curriculum from open_instruct import data_types, grpo_utils, utils from open_instruct.data_loader import accumulate_inference_batches, add_prompt_to_generator from open_instruct.data_types import EnvConfig, EnvConfigEntry @@ -1280,6 +1280,7 @@ def create_model_and_optimizer( reward_config: RewardConfig, generation_config, base_env_config: EnvConfig, + curriculum_config: difficulty_curriculum.DifficultyCurriculumConfig | None, tool_definitions: list[dict[str, Any]] | None = None, tools_config: EnvsConfig | None = None, pools: dict[str, ray.actor.ActorHandle] | None = None, @@ -1341,6 +1342,7 @@ def create_model_and_optimizer( model_name=model_config.model_name_or_path, initial_state=None, base_env_config=base_env_config, + curriculum_config=curriculum_config, ) # Create policy group and start model loading BEFORE vLLM engines (matches main branch order). 
@@ -2334,10 +2336,15 @@ def main( streaming_config: data_loader_lib.StreamingDataLoaderConfig, vllm_config: data_loader_lib.VLLMConfig, tools_config: EnvsConfig, + curriculum: difficulty_curriculum.DifficultyCurriculumArgs | None = None, ): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args, streaming_config, tools_config) validate_configs(streaming_config, vllm_config, tuple(args.num_learners_per_node), args.sequence_parallel_size) + if curriculum is None: + curriculum = difficulty_curriculum.DifficultyCurriculumArgs() + curriculum.verify() + curriculum_config = curriculum.build_curriculum_config(seed=args.seed) if args.verbose: logging.getLogger().setLevel(logging.DEBUG) @@ -2394,7 +2401,7 @@ def main( if tc.tokenizer_name_or_path and tc.tokenizer_name_or_path != model_config.model_name_or_path: utils.ensure_hf_repo_cached(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision) - pprint([args, model_config, streaming_config, vllm_config, tools_config]) + pprint([args, model_config, streaming_config, vllm_config, tools_config, curriculum]) # Create Ray queues. 
# Since we now send/receive individual prompts, queue size should accommodate @@ -2450,6 +2457,7 @@ def main( reward_config, generation_configs["train"], base_env_config, + curriculum_config, tool_definitions, tools_config, pools, @@ -2533,10 +2541,11 @@ def main( data_loader_lib.StreamingDataLoaderConfig, data_loader_lib.VLLMConfig, EnvsConfig, + difficulty_curriculum.DifficultyCurriculumArgs, ) ) parser.set_defaults(exp_name="grpo", warmup_ratio=0.0, max_grad_norm=1.0, per_device_train_batch_size=1) - args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config = ( + args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum = ( parser.parse_args_into_dataclasses() ) assert isinstance(args, grpo_utils.GRPOExperimentConfig) @@ -2545,5 +2554,6 @@ def main( assert isinstance(streaming_config, data_loader_lib.StreamingDataLoaderConfig) assert isinstance(vllm_config, data_loader_lib.VLLMConfig) assert isinstance(tools_config, EnvsConfig) + assert isinstance(curriculum, difficulty_curriculum.DifficultyCurriculumArgs) - main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config) + main(args, tokenizer_config, model_config, streaming_config, vllm_config, tools_config, curriculum) diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 2bb6390fde..aa7a6b19fc 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -1,11 +1,13 @@ import tempfile import unittest +import numpy as np import parameterized import torch from datasets import Dataset -from open_instruct import data_loader +from open_instruct import data_loader, data_types +from open_instruct.model_utils import Batch from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO @@ -79,5 +81,70 @@ def test_packing_equal_batches_across_ranks( self.assertEqual(all_indices, expected_indices, f"Missing indices: {expected_indices - all_indices}") 
+class TestMaskTruncatedCompletions(unittest.TestCase): + def _make_batch(self) -> Batch: + return Batch( + queries=[[11], [11], [22], [22]], + ground_truths=[[1], [1], [2], [2]], + datasets=["train", "train", "train", "train"], + raw_queries=["q0a", "q0b", "q1a", "q1b"], + decoded_responses=["r0a", "r0b", "r1a", "r1b"], + indices=[10, 10, 11, 11], + scores=[0.1, 0.2, 0.3, 0.4], + model_steps=[0, 0, 0, 0], + ) + + def _make_result(self, finish_reasons: list[str]) -> data_types.GenerationResult: + return data_types.GenerationResult( + responses=[[100 + i, 200 + i] for i in range(len(finish_reasons))], + finish_reasons=finish_reasons, + masks=[[1, 1] for _ in finish_reasons], + request_info=data_types.RequestInfo( + num_calls=[], timeouts=[], tool_errors=[], tool_outputs=[], tool_runtimes=[], tool_calleds=[] + ), + index=0, + prompt_id="0_0", + logprobs=[[0.1, 0.2] for _ in finish_reasons], + ) + + def test_mask_truncated_completions_keeps_batch_and_arrays_aligned(self): + batch = self._make_batch() + result = self._make_result(["stop", "length", "stop", "tool_calls"]) + scores = np.array(batch.scores, dtype=np.float32) + advantages = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + + filtered_batch, filtered_scores, filtered_advantages = data_loader.maybe_mask_truncated_completions( + result, batch, scores, advantages, enabled=True + ) + + self.assertEqual(filtered_batch.indices, [10, 11]) + self.assertEqual(filtered_batch.scores, [0.1, 0.3]) + self.assertEqual(filtered_batch.decoded_responses, ["r0a", "r1a"]) + self.assertEqual(result.finish_reasons, ["stop", "stop"]) + self.assertEqual(result.responses, [[100, 200], [102, 202]]) + self.assertEqual(result.logprobs, [[0.1, 0.2], [0.1, 0.2]]) + np.testing.assert_allclose(filtered_scores, np.array([0.1, 0.3], dtype=np.float32)) + np.testing.assert_allclose(filtered_advantages, np.array([1.0, 3.0], dtype=np.float32)) + + def test_mask_truncated_completions_handles_all_truncated(self): + batch = self._make_batch() + 
class ListDataset:
    """Minimal list-backed stand-in for a dataset (supports ``len`` and indexing)."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        stored = self.rows
        return len(stored)

    def __getitem__(self, index):
        stored = self.rows
        return stored[index]


def make_difficulty_row(index: int, bucket_index: int, posterior_mean: float, bucket_count: int = 5) -> dict:
    """Build one dataset row carrying a ``difficulty`` metadata blob.

    The expected quantile is the bucket index divided by ``bucket_count - 1``
    (with the divisor floored at 1).
    """
    last_bucket = max(bucket_count - 1, 1)
    difficulty = {
        "value": 1.0 - posterior_mean,
        "posterior_mean": posterior_mean,
        "posterior_lower_bound": 0.0,
        "expected_quantile": bucket_index / last_bucket,
        "bucket_index": bucket_index,
        "bucket_count": bucket_count,
    }
    return {"index": index, "difficulty": difficulty}
+ rows = [ + make_difficulty_row(index=0, bucket_index=0, posterior_mean=0.95, bucket_count=bucket_count), + make_difficulty_row(index=1, bucket_index=1, posterior_mean=0.80, bucket_count=bucket_count), + make_difficulty_row(index=2, bucket_index=2, posterior_mean=0.50, bucket_count=bucket_count), + make_difficulty_row(index=3, bucket_index=3, posterior_mean=0.20, bucket_count=bucket_count), + make_difficulty_row(index=4, bucket_index=4, posterior_mean=0.003, bucket_count=bucket_count), + ] + return ListDataset(rows) + + +def make_plain_hf_dataset(num_examples: int) -> Dataset: + return Dataset.from_dict( + {"text": [f"example_{index}" for index in range(num_examples)], "index": list(range(num_examples))} + ) + + +class TestDifficultyCurriculumSampler(unittest.TestCase): + def _make_metadata(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumMetadataConfig: + return difficulty_curriculum.DifficultyCurriculumMetadataConfig(**overrides) + + def _make_schedule(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumScheduleConfig: + return difficulty_curriculum.DifficultyCurriculumScheduleConfig( + bootstrap_steps=100, warmup_steps=120, total_steps=200, **overrides + ) + + def _make_adaptive(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumAdaptiveConfig: + return difficulty_curriculum.DifficultyCurriculumAdaptiveConfig(**overrides) + + def _make_config(self, **overrides) -> difficulty_curriculum.DifficultyCurriculumConfig: + return difficulty_curriculum.DifficultyCurriculumConfig( + metadata=overrides.pop("metadata", self._make_metadata()), + schedule=overrides.pop("schedule", self._make_schedule()), + adaptive=overrides.pop("adaptive", self._make_adaptive()), + seed=13, + **overrides, + ) + + def _make_sampler(self, dataset, **config_overrides) -> difficulty_curriculum.DifficultyCurriculumSampler: + config = self._make_config(**config_overrides) + return difficulty_curriculum.DifficultyCurriculumSampler( + dataset=dataset, 
num_samples=max(len(dataset), 1), config=config, global_step_getter=lambda: 0 + ) + + def test_missing_metadata_raises_when_strict_metadata(self): + dataset = ListDataset([{"index": 0}]) + with self.assertRaises(ValueError): + self._make_sampler(dataset, metadata=self._make_metadata(strict=True)) + + def test_missing_metadata_falls_back_when_not_strict(self): + dataset = ListDataset( + [make_difficulty_row(index=0, bucket_index=0, posterior_mean=0.9, bucket_count=5), {"index": 1}] + ) + sampler = self._make_sampler(dataset, metadata=self._make_metadata(strict=False)) + self.assertEqual(sampler.metadata_fallback_count, 1) + self.assertIn(1, sampler.bucket_to_indices[2]) + + def test_quantile_filter_keeps_only_hardest_half(self): + sampler = self._make_sampler(make_bucket_dataset(), min_quantile=0.5) + self.assertEqual(sampler.bucket_to_indices, ((), (), (2,), (3,), (4,))) + self.assertEqual(sampler.filtered_out_count, 2) + self.assertEqual(sampler.get_example_probability(0, step=0), 0.0) + self.assertTrue(all(sampler.sample_index(step=0) in {2, 3, 4} for _ in range(20))) + + def test_quantile_filter_reweights_only_available_buckets(self): + sampler = self._make_sampler(make_bucket_dataset(), min_quantile=0.5) + early_probs = sampler.get_static_bucket_probs(step=0) + self.assertAlmostEqual(float(early_probs.sum()), 1.0, places=6) + self.assertEqual(float(early_probs[0]), 0.0) + self.assertEqual(float(early_probs[1]), 0.0) + self.assertGreater(early_probs[2], early_probs[3]) + + def test_bucket_grouping_works(self): + sampler = self._make_sampler(make_bucket_dataset()) + self.assertEqual(sampler.bucket_to_indices, ((0,), (1,), (2,), (3,), (4,))) + + def test_bootstrap_curriculum_heavily_samples_easy_buckets(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probs = sampler.get_static_bucket_probs(step=0) + self.assertGreater(early_probs[0] + early_probs[1], 0.75) + self.assertGreater(early_probs[0], early_probs[2]) + 
self.assertGreater(early_probs[1], early_probs[2]) + self.assertLessEqual(early_probs[4], sampler.config.schedule.min_hard_frac + 1e-6) + + def test_post_bootstrap_curriculum_returns_to_medium_buckets(self): + sampler = self._make_sampler(make_bucket_dataset()) + post_bootstrap_probs = sampler.get_static_bucket_probs(step=sampler.config.schedule.bootstrap_steps) + self.assertEqual(int(post_bootstrap_probs.argmax()), 2) + self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[1]) + self.assertGreater(post_bootstrap_probs[2], post_bootstrap_probs[3]) + + def test_late_curriculum_increases_hard_bucket_probability(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probs = sampler.get_bucket_probs(step=0) + late_step = sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps + late_probs = sampler.get_bucket_probs(step=late_step) + self.assertGreater(late_probs[4], early_probs[4]) + self.assertGreater(late_probs[4], late_probs[2]) + self.assertGreater(late_probs[3], late_probs[2]) + + def test_extremely_hard_example_is_rare_early_but_more_likely_late(self): + sampler = self._make_sampler(make_bucket_dataset()) + early_probability = sampler.get_example_probability(4, step=0) + late_step = sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps + late_probability = sampler.get_example_probability(4, step=late_step) + self.assertLess(early_probability, 0.1) + self.assertGreater(late_probability, early_probability) + + def test_probabilities_always_sum_to_one(self): + sampler = self._make_sampler(make_bucket_dataset()) + for step in ( + 0, + sampler.config.schedule.warmup_steps + 5, + sampler.config.schedule.warmup_steps + sampler.config.schedule.total_steps, + ): + self.assertAlmostEqual(float(sampler.get_static_bucket_probs(step=step).sum()), 1.0, places=6) + self.assertAlmostEqual(float(sampler.get_bucket_probs(step=step).sum()), 1.0, places=6) + + def test_static_bucket_probs_are_cached_per_step(self): + 
sampler = self._make_sampler(make_bucket_dataset()) + + with mock.patch.object(sampler._schedule, "build_probs", wraps=sampler._schedule.build_probs) as build_probs: + sampler.get_bucket_probs(step=7) + sampler.get_bucket_probs(step=7) + sampler.get_static_bucket_probs(step=7) + self.assertEqual(build_probs.call_count, 1) + + sampler.get_bucket_probs(step=8) + self.assertEqual(build_probs.call_count, 2) + + def test_excluding_index_invalidates_cached_bucket_probs(self): + sampler = self._make_sampler(make_bucket_dataset()) + + with mock.patch.object(sampler._schedule, "build_probs", wraps=sampler._schedule.build_probs) as build_probs: + sampler.get_static_bucket_probs(step=0) + self.assertEqual(build_probs.call_count, 1) + + sampler.exclude_index(4) + excluded_probs = sampler.get_static_bucket_probs(step=0) + self.assertEqual(build_probs.call_count, 2) + + self.assertEqual(float(excluded_probs[4]), 0.0) + + def test_adaptive_stats_increase_sampling_probability_for_high_signal_bucket(self): + sampler = self._make_sampler( + make_bucket_dataset(), adaptive=self._make_adaptive(enabled=True, update_every=1, blend=0.5) + ) + static_probs = sampler.get_bucket_probs(step=0) + sampler.record_observations( + dataset_indices=[4, 4, 4, 4, 4, 2, 2], + rewards=[0.3, 0.35, 0.25, 0.4, 0.3, 0.95, 0.9], + advantages=[1.2, 1.1, 1.0, 0.9, 1.3, 0.05, 0.02], + ) + adaptive_probs = sampler.get_bucket_probs(step=1) + self.assertGreater(adaptive_probs[4], static_probs[4]) + + def test_excluded_index_has_zero_probability_and_persists_after_restore(self): + dataset = make_bucket_dataset() + sampler = self._make_sampler(dataset) + sampler.exclude_index(4) + + self.assertEqual(sampler.get_example_probability(4, step=0), 0.0) + self.assertEqual(float(sampler.get_bucket_probs(step=0)[4]), 0.0) + self.assertTrue(all(sampler.sample_index(step=0) != 4 for _ in range(20))) + + restored_sampler = self._make_sampler(dataset) + restored_sampler.load_state_dict(sampler.state_dict()) + 
class TestDifficultyCurriculumLoaderIntegration(unittest.TestCase):
    """Guards the prompt-dataloader factory's behavior when it is called without curriculum options."""

    def test_existing_behavior_is_unchanged_when_curriculum_disabled(self):
        """The factory must return a plain ``HFDataLoader`` that yields the same index order."""
        seed = 7
        examples = make_plain_hf_dataset(20)
        factory_loader = data_loader.create_prompt_dataloader(
            dataset=examples, seed=seed, work_dir=tempfile.gettempdir()
        )
        reference_loader = data_loader.HFDataLoader(
            dataset=examples,
            batch_size=1,
            seed=seed,
            dp_rank=0,
            dp_world_size=1,
            work_dir=tempfile.gettempdir(),
            automatic_reshuffle=True,
            collator=data_loader.single_example_collator,
        )

        # Exact type check: a curriculum subclass must not be substituted silently.
        self.assertIs(type(factory_loader), data_loader.HFDataLoader)

        def drain_indices(loader):
            return [batch["index"].item() for batch in loader]

        factory_order = drain_indices(factory_loader)
        reference_order = drain_indices(reference_loader)
        self.assertEqual(factory_order, reference_order)
+ +### Examples + +Write local difficulty files: + +```bash +uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --output /tmp/dapo_math_qwen3_difficulty +``` + +Write local files and push the single output group to the Hub: + +```bash +uv run scripts/data/difficulty_sampling/create_difficulty_map.py \ + --hf-dataset mnoukhov/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples \ + --hf-split train \ + --task math \ + --output /tmp/dapo_math_qwen3_difficulty \ + --push-to-hub your-org/dapo-math-qwen3-difficulty \ + --split train +``` + +Hub uploads require exactly one task/model output group, so use `--task` or a +single-group input dataset when pushing. + +If you only need continuous difficulty scores, you can set +`--difficulty-buckets 0` to skip discrete bucket assignment. If you want to use +the default difficulty curriculum in `grpo_fast.py`, keep bucket assignment +enabled. + +## Difficulty Metadata Format + +`create_difficulty_map.py` writes a nested `difficulty` blob like this: + +```json +{ + "difficulty": { + "value": 0.9999999997624719, + "posterior_mean": 0.003437858035078528, + "posterior_lower_bound": 2.3752813430506325e-10, + "expected_quantile": 0.10139684528348392, + "bucket_index": 4, + "bucket_count": 5 + } +} +``` + +- `value` is the overall difficulty score, defined as + `1 - posterior_lower_bound`. Higher means harder. +- `posterior_mean` is the estimated solve probability for that prompt. Lower + means harder. +- `posterior_lower_bound` is a conservative lower bound on solve probability at + the configured `--posterior-lower-quantile`. +- `expected_quantile` is a posterior-aware rank in `[0, 1]`. It is useful for + filtering or clamping to a subset of the difficulty range. +- `bucket_index = 0` is the easiest bucket and + `bucket_index = bucket_count - 1` is the hardest. 
+ +`grpo_fast.py` can optionally replace uniform prompt reshuffling with +`DifficultyCurriculumSampler`, which consumes this metadata by default. The +current curriculum implementation uses: + +- `difficulty.posterior_mean` for within-bucket weighting. +- `difficulty.bucket_index` and `difficulty.bucket_count` for bucketed + sampling. +- `difficulty.expected_quantile` for optional + `--curriculum_min_quantile` / `--curriculum_max_quantile` filtering. + +The sampler uses a smooth distribution with a configurable easy-heavy +bootstrap phase, then gradually shifts mass toward harder buckets instead of +hard-switching between discrete phases. Within each bucket, examples are +weighted by a blend of uncertainty (`4 * p * (1 - p)`) and hardness +(`1 - p`), so borderline prompts stay attractive while already-solved prompts +are naturally down-weighted. If `--curriculum_adaptive true` is set, bucket +probabilities are additionally blended with live reward / advantage statistics +so buckets with useful learning signal can get more mass during training. + +## Recommended Starting Point + +If you are using this metadata with `DifficultyCurriculumSampler`, +`bucket_count=5` is a reasonable starting point: + +- Bootstrap (first ~100 steps by default): buckets 0 and 1 dominate so the + model sees easier prompts while it settles into the chat template and task + format. +- Early after bootstrap: bucket 2 highest, buckets 1 and 3 nonzero, bucket 4 + low. +- Mid: buckets 2 and 3 dominate, with bucket 4 increasing. +- Late: buckets 3 and 4 dominate, while buckets 0-2 remain nonzero. 
+ +Useful flags: + +```bash +--curriculum difficulty \ +--curriculum_metadata_field difficulty \ +--curriculum_bootstrap_steps 100 \ +--curriculum_bootstrap_target 0.125 \ +--curriculum_warmup_target 0.5 \ +--curriculum_final_target 1.0 \ +--curriculum_warmup_steps 500 \ +--curriculum_total_steps 10000 \ +--curriculum_min_hard_frac 0.05 \ +--curriculum_max_hard_frac 0.50 \ +--curriculum_bucket_sigma 0.0 \ +--curriculum_bootstrap_sigma 0.0 \ +--curriculum_uncertainty_weight 0.5 \ +--curriculum_strict_metadata true \ +--curriculum_min_quantile 0.0 \ +--curriculum_max_quantile 1.0 \ +--curriculum_adaptive true +``` + +Tuning tips: + +- Increase `curriculum_bootstrap_steps` to keep the easy bootstrap around + longer. +- Lower `curriculum_bootstrap_target` to bias more strongly toward the easiest + buckets early. +- Lower `curriculum_bucket_sigma` or `curriculum_bootstrap_sigma` to + concentrate probability on fewer neighboring buckets. +- Lower `curriculum_warmup_target` if you want the post-bootstrap warmup to + stay easier for longer. +- Raise `curriculum_min_quantile` or lower `curriculum_max_quantile` if you + want to focus on a narrower difficulty band. + +## Metrics + +Some useful curriculum metrics are: + +- `curriculum/progress` +- `curriculum/static_bucket_prob_*` +- `curriculum/adaptive_bucket_prob_*` +- `curriculum/bucket_prob_*` +- `curriculum/sampled_bucket_count_*` +- `curriculum/bucket_reward_mean_*` +- `curriculum/bucket_abs_advantage_mean_*` + +See [qwen3_4b_dapo_math_dc.sh](../../train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc.sh) +for a concrete launch example, and the rest of +`scripts/train/qwen/difficult-curriculum/` for variants. 
diff --git a/scripts/data/difficulty_sampling/create_difficulty_map.py b/scripts/data/difficulty_sampling/create_difficulty_map.py new file mode 100644 index 0000000000..4804b6b044 --- /dev/null +++ b/scripts/data/difficulty_sampling/create_difficulty_map.py @@ -0,0 +1,1147 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = "==3.12.*" +# dependencies = [ +# "datasets>=4.0.0", +# "numpy>=2", +# "scipy>=1.14.0", +# ] +# /// + +""" +Build a per-instance difficulty map from Hugging Face datasets with pass-rate +aggregates. + +The script loads a Hugging Face dataset that already contains per-row pass +counts, expands those counts into binary attempt outcomes, fits a Beta prior +across binary outcomes, estimates per-item difficulty, and writes JSONL +difficulty files plus schema/metadata sidecars. When `--push-to-hub` is set, it +also uploads the validated output dataset to the requested Hugging Face repo. +Hub uploads require exactly one task/model output group, so use `--task` or a +single-group input dataset when pushing. 
@dataclass(frozen=True)
class BetaPrior:
    """Immutable Beta prior together with a label recording how it was obtained."""

    # Shape parameters of the Beta distribution (both 0.5 for the Jeffreys prior constants in this file).
    alpha: float
    beta: float
    # Provenance label, e.g. "empirical_bayes"; presumably one of the keys of
    # PRIOR_SOURCE_FILENAME_ALIASES -- TODO confirm against estimate_beta_prior.
    source: str
@dataclass
class DifficultyPayload:
    """Per-instance difficulty annotation, written out as the nested ``difficulty`` blob.

    Every field defaults to ``None``, so a row can be built before any estimate
    exists.
    """

    # Overall difficulty score, defined as 1 - posterior_lower_bound; higher means harder.
    value: float | None = None
    # Estimated solve probability for the prompt; lower means harder.
    posterior_mean: float | None = None
    # Conservative lower bound on solve probability at the configured lower posterior quantile.
    posterior_lower_bound: float | None = None
    # Posterior-aware rank in [0, 1]; useful for filtering to part of the difficulty range.
    expected_quantile: float | None = None
    # Bucket 0 is the easiest and bucket_count - 1 the hardest.
    # NOTE(review): presumably both stay None when bucketing is disabled
    # (--difficulty-buckets 0) -- confirm against apply_beta_binomial_difficulty.
    bucket_index: int | None = None
    bucket_count: int | None = None
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the difficulty-map builder.

    ``--hf-dataset`` and ``--output`` are required; every other option has a
    default. The ``--hf-*-field`` options are dot-paths into each input row.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Build a per-instance difficulty map from HF pass-rate datasets.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # --- Input dataset location and row schema -------------------------------
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="Hugging Face dataset repo id containing per-row pass-rate aggregates.",
    )
    parser.add_argument("--hf-config", type=str, default=None, help="Optional dataset config for --hf-dataset.")
    parser.add_argument("--hf-split", type=str, default="train", help="Input split to load from --hf-dataset.")
    parser.add_argument(
        "--hf-row-id-field",
        type=str,
        default="extra_info.index",
        help="Dot-path to the stable per-row id field inside --hf-dataset.",
    )
    parser.add_argument(
        "--hf-task-field", type=str, default="dataset", help="Dot-path to the task/verifier field in --hf-dataset."
    )
    parser.add_argument(
        "--hf-model-field",
        type=str,
        default="generator_model",
        help="Dot-path to the generator model field in --hf-dataset.",
    )
    parser.add_argument(
        "--hf-pass-count-field",
        type=str,
        default="pass_count",
        help="Dot-path to the integer pass-count field in --hf-dataset.",
    )
    parser.add_argument(
        "--hf-attempt-count-field",
        type=str,
        default="num_samples",
        help="Dot-path to the total-attempt-count field in --hf-dataset.",
    )
    parser.add_argument(
        "--hf-pass-rate-field",
        type=str,
        default="pass_rate",
        help="Optional dot-path to a pass-rate or fraction field used for validation/fallback in --hf-dataset.",
    )
    parser.add_argument(
        "--task", action="append", default=[], help="Optional task filter. Matches the dataset task/verifier source."
    )

    # --- Output destinations ------------------------------------------------
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        # The file-name placeholders were missing here, leaving a meaningless
        # "____.jsonl"; spell out the <task>__<model>__<tag> pattern.
        help=(
            "Output directory or path-like root. The script writes one file per task/model inside it as "
            "<task>__<model>__<tag>.jsonl plus matching .schema.json and .metadata.json sidecars."
        ),
    )
    parser.add_argument(
        "--push-to-hub",
        type=str,
        default=None,
        help=("Optional dataset repo id to push the validated rows to. Requires exactly one task/model output group."),
    )
    parser.add_argument("--split", type=str, default="train", help="Split to use with --push-to-hub.")

    # --- Row handling ---------------------------------------------------------
    parser.add_argument("--strict", action="store_true", help="Fail if an input dataset row is malformed.")
    parser.add_argument(
        "--allow-nonunit-scores",
        action="store_true",
        help="Keep rows whose rewards cannot be normalized to binary correctness. Difficulty will be null for them.",
    )
    parser.add_argument(
        "--max-instances",
        type=int,
        default=None,
        help="Optional cap for the number of resolved instances written (useful for smoke tests).",
    )

    # --- Difficulty estimation ------------------------------------------------
    parser.add_argument(
        "--beta-prior",
        choices=["empirical-bayes", "jeffreys"],
        default="empirical-bayes",
        help="Global Beta prior to use for smoothing binary solve rates.",
    )
    parser.add_argument(
        "--posterior-lower-quantile",
        type=float,
        default=0.1,
        help="Lower posterior quantile used to define difficulty as 1 - quantile.",
    )
    parser.add_argument(
        "--difficulty-buckets",
        type=int,
        default=DEFAULT_DIFFICULTY_BUCKETS,
        help=(
            "Number of posterior-aware quantile buckets to assign for stratification. "
            "Set to 0 to skip discrete bucket assignment."
        ),
    )
    return parser
+ ) + + score_source_field = ",".join( + field_name + for field_name in ( + input_rows.source_format.pass_count_field, + input_rows.source_format.attempt_count_field, + input_rows.source_format.pass_rate_field, + ) + if field_name + ) + skipped_nonunit = 0 + written_outputs: list[tuple[str, str | None, int, Path, Path, Path]] = [] + + for (task_name, model_name), group_rows in sorted( + rows_by_group.items(), key=lambda item: (item[0][0], item[0][1] or "") + ): + group_rows, score_processing, group_skipped_nonunit = normalize_attempt_scores_for_group( + group_rows, allow_nonunit_scores=args.allow_nonunit_scores + ) + score_processing.source_field = score_source_field + skipped_nonunit += group_skipped_nonunit + + if not group_rows: + logger.warning( + "Skipping task=%s model=%s because no rows remained after reward normalization.", task_name, model_name + ) + continue + + prior, binary_row_count = estimate_beta_prior(group_rows, prior_mode=args.beta_prior) + group_rows = apply_beta_binomial_difficulty( + group_rows, prior=prior, lower_quantile=args.posterior_lower_quantile, num_buckets=args.difficulty_buckets + ) + output_rows, dataset = build_hf_output_dataset(input_rows.source_dataset, group_rows) + + dataset_metadata = build_dataset_metadata( + rows=group_rows, + task_name=task_name, + model_name=model_name, + requested_prior_mode=args.beta_prior, + requested_bucket_count=args.difficulty_buckets, + lower_quantile=args.posterior_lower_quantile, + prior=prior, + binary_row_count=binary_row_count, + score_processing=score_processing, + source_format=input_rows.source_format, + ) + + if prior is not None: + logger.info( + "Using %s Beta prior alpha=%.4f beta=%.4f across %s binary instances for task=%s model=%s.", + prior.source, + prior.alpha, + prior.beta, + binary_row_count, + task_name, + model_name, + ) + else: + logger.warning( + "No binary instances were available for Beta-Binomial difficulty estimation for task=%s model=%s.", + task_name, + model_name, + ) + 
def load_hf_dataset_rows(
    *,
    dataset_name: str,
    config_name: str | None,
    split: str,
    task_filters: set[str],
    strict: bool,
    row_id_field: str,
    task_field: str,
    model_field: str,
    pass_count_field: str,
    attempt_count_field: str,
    pass_rate_field: str | None,
) -> InputRowsBundle:
    """Load one split of a Hugging Face dataset and convert its rows to ``DifficultyRow`` objects.

    A row that ``build_hf_dataset_row`` rejects is counted as malformed. With
    ``strict`` set the first such row aborts the load; otherwise it is logged
    and dropped. When ``task_filters`` is non-empty, a row is kept only if its
    task name, or the base name from ``get_base_task_name``, is in the set.

    Returns:
        An ``InputRowsBundle`` holding the kept rows, the malformed-row count,
        the source-format metadata and the loaded dataset itself.

    Raises:
        ValueError: if ``strict`` is set and a row cannot be converted.
    """
    logger.info(
        "Loading Hugging Face dataset %s (config=%s, split=%s).", dataset_name, config_name or "default", split
    )

    # Only pass a config positionally when one was actually requested.
    load_positionals = (dataset_name, config_name) if config_name else (dataset_name,)
    source_dataset = load_dataset(*load_positionals, split=split)

    # Location/field settings shared by the per-row builder and the source-format metadata.
    shared_kwargs = {
        "dataset_name": dataset_name,
        "config_name": config_name,
        "split": split,
        "row_id_field": row_id_field,
        "task_field": task_field,
        "model_field": model_field,
        "pass_count_field": pass_count_field,
        "attempt_count_field": attempt_count_field,
        "pass_rate_field": pass_rate_field,
    }

    kept_rows: list[DifficultyRow] = []
    malformed_count = 0

    for position, raw_row in enumerate(source_dataset):
        try:
            parsed_row = build_hf_dataset_row(source_row=raw_row, source_row_index=position, **shared_kwargs)
        except Exception as exc:
            # Deliberately broad: any failure while parsing a row marks that row as malformed.
            malformed_count += 1
            message = f"Malformed HF dataset row {dataset_name}[{split}][{position}]: {exc}"
            if strict:
                raise ValueError(message) from exc
            logger.warning(message)
            continue

        if task_filters:
            candidate = parsed_row.task_name
            if candidate not in task_filters and get_base_task_name(candidate) not in task_filters:
                continue
        kept_rows.append(parsed_row)

    return InputRowsBundle(
        rows=kept_rows,
        malformed_records=malformed_count,
        source_format=build_hf_source_format_metadata(**shared_kwargs),
        source_dataset=source_dataset,
    )
def build_hf_output_dataset(
    source_dataset: Dataset, rows: list[DifficultyRow]
) -> tuple[list[dict[str, Any]], Dataset]:
    """Attach the computed output columns to the matching rows of ``source_dataset``.

    Rows are processed in ascending ``source_row_index`` order, so the selected
    dataset rows and the new column values line up one-to-one.

    Returns:
        A pair ``(output_rows, dataset)``. ``output_rows`` holds each row as a
        plain dict without ``source_row_index``. ``dataset`` is the selected
        subset of ``source_dataset`` with every ``HF_OUTPUT_COLUMNS`` column
        replaced or added.
    """
    in_source_order = sorted(rows, key=lambda entry: entry.source_row_index)

    output_rows = [
        {key: item for key, item in asdict(entry).items() if key != "source_row_index"}
        for entry in in_source_order
    ]

    selected_positions = [entry.source_row_index for entry in in_source_order]
    dataset = source_dataset.select(selected_positions)

    for column_name in HF_OUTPUT_COLUMNS:
        column_values = [make_jsonable(getattr(entry, column_name)) for entry in in_source_order]
        # Drop any column of the same name first so add_column does not collide with it.
        if column_name in dataset.column_names:
            dataset = dataset.remove_columns(column_name)
        dataset = dataset.add_column(column_name, column_values)

    return output_rows, dataset
in field_path.split("."): + if not isinstance(current, dict) or field_name not in current: + return None + current = current[field_name] + return current + + +def normalize_identifier(value: Any) -> str | None: + if value is None or isinstance(value, bool): + return None + text = (value if isinstance(value, str) else str(value)).strip() + return text or None + + +def normalize_nonnegative_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value if value >= 0 else None + if isinstance(value, float): + if not math.isfinite(value) or not value.is_integer() or value < 0: + return None + return int(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + parsed = int(stripped) + except ValueError: + return None + return parsed if parsed >= 0 else None + return None + + +def parse_pass_rate_value(value: Any) -> tuple[int | None, int | None, float | None]: + if value is None: + return None, None, None + if isinstance(value, (int, float)) and not isinstance(value, bool): + rate = float(value) + if not math.isfinite(rate): + raise ValueError(f"expected finite pass-rate value in [0, 1], received {value!r}") + if 0.0 <= rate <= 1.0: + return None, None, rate + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + if not isinstance(value, str): + raise ValueError(f"unsupported pass-rate value {value!r}") + + stripped = value.strip() + if not stripped: + return None, None, None + + if "/" in stripped: + numerator_text, denominator_text = stripped.split("/", 1) + numerator = normalize_nonnegative_int(numerator_text) + denominator = normalize_nonnegative_int(denominator_text) + if numerator is None or denominator is None or numerator > denominator: + raise ValueError(f"invalid pass-rate fraction {value!r}") + rate = 0.0 if denominator == 0 else numerator / denominator + return numerator, denominator, rate + + try: + rate = 
float(stripped) + except ValueError as exc: + raise ValueError(f"invalid pass-rate value {value!r}") from exc + if not math.isfinite(rate) or rate < 0.0 or rate > 1.0: + raise ValueError(f"expected pass-rate value in [0, 1], received {value!r}") + return None, None, rate + + +def extract_hf_attempt_summary( + *, row: dict[str, Any], pass_count_field: str, attempt_count_field: str, pass_rate_field: str | None +) -> tuple[int, int]: + pass_count = normalize_nonnegative_int(get_nested_field(row, pass_count_field)) + attempt_count = normalize_nonnegative_int(get_nested_field(row, attempt_count_field)) + + parsed_pass_count = None + parsed_attempt_count = None + parsed_pass_rate = None + if pass_rate_field: + parsed_pass_count, parsed_attempt_count, parsed_pass_rate = parse_pass_rate_value( + get_nested_field(row, pass_rate_field) + ) + + if pass_count is None and parsed_pass_count is not None: + pass_count = parsed_pass_count + if attempt_count is None and parsed_attempt_count is not None: + attempt_count = parsed_attempt_count + + if pass_count is None or attempt_count is None: + raise ValueError( + f"missing pass-count summary fields {pass_count_field!r}/{attempt_count_field!r}" + f"{f' or parseable {pass_rate_field!r}' if pass_rate_field else ''}" + ) + if attempt_count <= 0: + raise ValueError(f"attempt count must be positive, received {attempt_count}") + if pass_count > attempt_count: + raise ValueError(f"pass count {pass_count} exceeds attempt count {attempt_count}") + + if parsed_pass_count is not None and parsed_pass_count != pass_count: + raise ValueError(f"pass-count field {pass_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_attempt_count is not None and parsed_attempt_count != attempt_count: + raise ValueError(f"attempt-count field {attempt_count_field!r} disagrees with {pass_rate_field!r}") + if parsed_pass_rate is not None and not math.isclose( + pass_count / attempt_count, parsed_pass_rate, rel_tol=EPS, abs_tol=EPS + ): + raise 
def expand_binary_attempt_scores(*, pass_count: int, attempt_count: int) -> list[float]:
    """Expand a ``(passes, attempts)`` summary into per-attempt binary scores.

    Returns:
        ``pass_count`` entries of ``1.0`` followed by
        ``attempt_count - pass_count`` entries of ``0.0``.

    Raises:
        ValueError: if ``pass_count`` is negative or exceeds ``attempt_count``.
            Without this check a negative list multiplier would silently
            produce a list whose length does not equal ``attempt_count``.
    """
    if pass_count < 0 or pass_count > attempt_count:
        raise ValueError(
            f"expected 0 <= pass_count <= attempt_count, received "
            f"pass_count={pass_count}, attempt_count={attempt_count}"
        )
    return [1.0] * pass_count + [0.0] * (attempt_count - pass_count)
def normalize_attempt_scores(
    attempt_scores: list[float], score_processing: ScoreProcessingMetadata
) -> list[float] | None:
    """Map raw attempt scores onto ``0.0`` / ``1.0`` using *score_processing*.

    A score close to ``0.0`` (within ``EPS``) becomes ``0.0``. A score close to
    the scheme's success value becomes ``1.0``; the success value is ``1.0`` for
    ``"identity_binary"`` and ``positive_reward_value`` for
    ``"binary_zero_or_constant"``. Any other normalization (for example
    ``"all_zero_binary"``) has no success value, so only zeros are accepted.

    Returns:
        The normalized scores, or ``None`` when the scheme does not support
        binary difficulty or any score matches neither accepted value.
    """
    if not score_processing.supports_binary_difficulty:
        return None

    normalization = score_processing.normalization
    positive_reward_value = score_processing.positive_reward_value

    # Resolve the single accepted "success" value once, not once per score.
    success_value: float | None = None
    if normalization == "identity_binary":
        success_value = 1.0
    elif normalization == "binary_zero_or_constant" and positive_reward_value is not None:
        success_value = float(positive_reward_value)

    normalized_scores: list[float] = []
    for score in attempt_scores:
        if math.isclose(score, 0.0, rel_tol=EPS, abs_tol=EPS):
            normalized_scores.append(0.0)
        elif success_value is not None and math.isclose(score, success_value, rel_tol=EPS, abs_tol=EPS):
            normalized_scores.append(1.0)
        else:
            # One unexpected value makes the whole row non-binary.
            return None

    return normalized_scores
def apply_beta_binomial_difficulty(
    rows: list[DifficultyRow], *, prior: BetaPrior | None, lower_quantile: float, num_buckets: int
) -> list[DifficultyRow]:
    """Attach a Beta-Binomial difficulty payload to every row, in place.

    Each row first receives an empty ``DifficultyPayload``. When a prior is
    given and the row's scores are binary, the payload is replaced with the
    posterior mean, the posterior ``lower_quantile`` bound and
    ``value = 1 - lower bound`` clamped to ``[0, 1]``. Rows that received a
    posterior are then handed to ``assign_posterior_difficulty_buckets``.

    Returns:
        The same ``rows`` list that was passed in.
    """
    scored: list[DifficultyPosteriorRow] = []

    for entry in rows:
        entry.difficulty = DifficultyPayload()

        counts = None if prior is None else extract_binary_counts(entry.attempt_scores)
        if counts is None:
            continue

        wins, trials = counts
        alpha_post = wins + prior.alpha
        beta_post = trials - wins + prior.beta
        mean_post = alpha_post / (alpha_post + beta_post)
        lower_bound = float(beta_distribution.ppf(lower_quantile, alpha_post, beta_post))

        entry.difficulty = DifficultyPayload(
            value=max(0.0, min(1.0, 1.0 - lower_bound)),
            posterior_mean=mean_post,
            posterior_lower_bound=lower_bound,
        )
        # Difficulty is the complement of the success rate, so the Beta
        # parameters are swapped for the difficulty-side distribution.
        scored.append(DifficultyPosteriorRow(row=entry, difficulty_alpha=beta_post, difficulty_beta=alpha_post))

    assign_posterior_difficulty_buckets(scored, num_buckets=num_buckets)
    return rows
effective_bucket_count = min(num_buckets, len(posterior_rows)) + ordered_rows = sorted( + zip(posterior_rows, expected_quantiles, strict=True), + key=lambda item: (item[1], item[0].row.difficulty.value, item[0].row.instance_id), + ) + base_bucket_size, remainder = divmod(len(ordered_rows), effective_bucket_count) + + cursor = 0 + for bucket_index in range(effective_bucket_count): + bucket_size = base_bucket_size + (1 if bucket_index < remainder else 0) + for posterior_row, _expected_quantile in ordered_rows[cursor : cursor + bucket_size]: + posterior_row.row.difficulty.bucket_index = bucket_index + posterior_row.row.difficulty.bucket_count = effective_bucket_count + cursor += bucket_size + + +def estimate_expected_difficulty_quantiles( + posterior_rows: list[DifficultyPosteriorRow], + *, + grid_size: int = POSTERIOR_QUANTILE_GRID_SIZE, + batch_size: int = POSTERIOR_QUANTILE_BATCH_SIZE, +) -> list[float]: + if not posterior_rows: + return [] + if len(posterior_rows) == 1: + return [0.5] + + grid = (np.arange(grid_size, dtype=np.float64) + 0.5) / grid_size + difficulty_alphas = np.asarray([row.difficulty_alpha for row in posterior_rows], dtype=np.float64) + difficulty_betas = np.asarray([row.difficulty_beta for row in posterior_rows], dtype=np.float64) + + mixture_cdf = np.zeros(grid_size, dtype=np.float64) + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_cdf = beta_distribution.cdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + mixture_cdf += np.nan_to_num(batch_cdf, nan=0.0, posinf=1.0, neginf=0.0).sum(axis=0) + mixture_cdf /= len(posterior_rows) + + quantiles = np.zeros(len(posterior_rows), dtype=np.float64) + dx = 1.0 / grid_size + for start in range(0, len(posterior_rows), batch_size): + stop = start + batch_size + batch_pdf = beta_distribution.pdf( + grid[None, :], difficulty_alphas[start:stop, None], difficulty_betas[start:stop, None] + ) + quantiles[start:stop] = 
def write_output_files(
    *, output_jsonl: Path, schema_json: Path, metadata_json: Path, dataset: Dataset, dataset_metadata: DatasetMetadata
) -> None:
    """Write the dataset rows, its feature schema and its metadata to disk.

    Args:
        output_jsonl: Target for one JSON object per dataset row.
        schema_json: Target for ``dataset.features`` (``to_dict()`` when
            available, otherwise its ``str()`` form).
        metadata_json: Target for the JSON form of ``dataset_metadata``.
        dataset: Iterable of rows; also provides ``features``.
        dataset_metadata: Metadata object serialized via ``make_jsonable``.

    Parent directories are created as needed. All files are written as UTF-8:
    rows are dumped with ``ensure_ascii=False``, so relying on the locale's
    default encoding could raise ``UnicodeEncodeError`` or produce files that
    differ between machines.
    """
    output_jsonl.parent.mkdir(parents=True, exist_ok=True)
    with output_jsonl.open("w", encoding="utf-8") as output_file:
        for row in dataset:
            output_file.write(json.dumps(make_jsonable(row), ensure_ascii=False) + "\n")

    schema_json.parent.mkdir(parents=True, exist_ok=True)
    try:
        schema_payload: Any = dataset.features.to_dict()
    except AttributeError:
        # Features object without to_dict(): fall back to its string form.
        schema_payload = str(dataset.features)
    with schema_json.open("w", encoding="utf-8") as output_file:
        json.dump(schema_payload, output_file, indent=2, sort_keys=True)

    metadata_json.parent.mkdir(parents=True, exist_ok=True)
    with metadata_json.open("w", encoding="utf-8") as output_file:
        json.dump(make_jsonable(dataset_metadata), output_file, indent=2, sort_keys=True)
def extract_effective_bucket_count(rows: list[DifficultyRow]) -> int:
    """Return the one bucket count shared by all bucketed rows.

    Rows whose ``difficulty.bucket_count`` is ``None`` are ignored.

    Returns:
        The shared bucket count, or ``0`` when no row carries one.

    Raises:
        ValueError: if the rows carry more than one distinct bucket count.
    """
    distinct_counts = set()
    for row in rows:
        count = row.difficulty.bucket_count
        if count is not None:
            distinct_counts.add(count)

    if len(distinct_counts) > 1:
        raise ValueError(f"Expected a single effective bucket count, found {sorted(distinct_counts)}")
    return distinct_counts.pop() if distinct_counts else 0
def make_jsonable(value: Any) -> Any:
    """Recursively convert *value* into objects ``json.dumps`` accepts.

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned unchanged.
    * Dataclass instances become dicts of their (converted) fields.
    * Lists and tuples become lists of converted items.
    * Dicts keep ``str`` keys, map a ``None`` key to ``""`` and stringify any
      other key; values are converted recursively.
    * Everything else is replaced by its ``str()`` form.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    # is_dataclass() is also true for dataclass *types*, which asdict() rejects
    # with TypeError; let those fall through to the str() fallback instead.
    if is_dataclass(value) and not isinstance(value, type):
        return {key: make_jsonable(item) for key, item in asdict(value).items()}
    if isinstance(value, (list, tuple)):
        return [make_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {
            (key if isinstance(key, str) else "" if key is None else str(key)): make_jsonable(item)
            for key, item in value.items()
        }
    return str(value)
#!/bin/bash
# Launch a difficulty-curriculum GRPO run of Qwen/Qwen3-4B-Base through mason.py.
#
# Usage: qwen3_4b_dapo_math_dc.sh [BEAKER_IMAGE] [extra grpo_fast.py args...]
#   - The first positional argument, if present, is the Beaker image; all
#     remaining arguments are appended to the grpo_fast.py command line.
#   - Every VAR="${VAR:-default}" setting below can be overridden from the
#     environment; the sibling wrapper scripts rely on this.
set -euo pipefail

EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum}"
# Default run name is the experiment name plus a launch timestamp.
RUN_NAME="${RUN_NAME:-${EXP_NAME}_$(date +%Y%m%d_%H%M%S)}"

# NOTE(review): NUM_GPUS only feeds mason's --gpus; --num_learners_per_node and
# --vllm_num_engines below are fixed at 4 each -- confirm they still fit when
# overriding NUM_GPUS.
NUM_GPUS="${NUM_GPUS:-8}"
BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}"
# Drop the image argument so "$@" holds only the pass-through arguments.
if [[ $# -gt 0 ]]; then
    shift
fi

CLUSTER="${CLUSTER:-ai2/jupiter}"
PRIORITY="${PRIORITY:-urgent}"
WORKSPACE="${WORKSPACE:-ai2/open-instruct-dev}"
BUDGET="${BUDGET:-ai2/oe-omai}"

# Difficulty-annotated variant of hamishivi/DAPO-Math-17k-Processed_filtered
DATASET_WITH_DIFFICULTY="undfined/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples-ds"

TOTAL_EPISODES="${TOTAL_EPISODES:-128000}"
NUM_SAMPLES_PER_PROMPT_ROLLOUT="${NUM_SAMPLES_PER_PROMPT_ROLLOUT:-16}"
NUM_UNIQUE_PROMPTS_ROLLOUT="${NUM_UNIQUE_PROMPTS_ROLLOUT:-8}"
LOCAL_EVAL_EVERY="${LOCAL_EVAL_EVERY:-100}"
SAVE_FREQ="${SAVE_FREQ:-100}"
CHECKPOINT_STATE_FREQ="${CHECKPOINT_STATE_FREQ:-100}"

# Integer division: episodes per step = unique prompts * samples per prompt.
NUM_TRAINING_STEPS=$(( TOTAL_EPISODES / (NUM_UNIQUE_PROMPTS_ROLLOUT * NUM_SAMPLES_PER_PROMPT_ROLLOUT) ))

# Keep the bootstrap aligned with the first logging/eval window by default.
CURRICULUM_BOOTSTRAP_STEPS="${CURRICULUM_BOOTSTRAP_STEPS:-${LOCAL_EVAL_EVERY}}"
CURRICULUM_WARMUP_STEPS="${CURRICULUM_WARMUP_STEPS:-${CURRICULUM_BOOTSTRAP_STEPS}}"
# Default curriculum length is whatever remains after warmup, floored at 1 step.
if (( NUM_TRAINING_STEPS <= CURRICULUM_WARMUP_STEPS )); then
    DEFAULT_CURRICULUM_TOTAL_STEPS=1
else
    DEFAULT_CURRICULUM_TOTAL_STEPS=$(( NUM_TRAINING_STEPS - CURRICULUM_WARMUP_STEPS ))
fi
CURRICULUM_TOTAL_STEPS="${CURRICULUM_TOTAL_STEPS:-${DEFAULT_CURRICULUM_TOTAL_STEPS}}"
# Defaults for the --curriculum_* flags assembled in CURRICULUM_ARGS below.
CURRICULUM_METADATA_FIELD="${CURRICULUM_METADATA_FIELD:-difficulty}"
CURRICULUM_POSTERIOR_MEAN_FIELD="${CURRICULUM_POSTERIOR_MEAN_FIELD:-posterior_mean}"
CURRICULUM_BUCKET_INDEX_FIELD="${CURRICULUM_BUCKET_INDEX_FIELD:-bucket_index}"
CURRICULUM_BUCKET_COUNT_FIELD="${CURRICULUM_BUCKET_COUNT_FIELD:-bucket_count}"
CURRICULUM_QUANTILE_FIELD="${CURRICULUM_QUANTILE_FIELD:-expected_quantile}"
CURRICULUM_STRICT_METADATA="${CURRICULUM_STRICT_METADATA:-true}"
CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.125}"
CURRICULUM_WARMUP_TARGET="${CURRICULUM_WARMUP_TARGET:-0.5}"
CURRICULUM_FINAL_TARGET="${CURRICULUM_FINAL_TARGET:-1.0}"
CURRICULUM_MIN_HARD_FRAC="${CURRICULUM_MIN_HARD_FRAC:-0.05}"
CURRICULUM_MAX_HARD_FRAC="${CURRICULUM_MAX_HARD_FRAC:-0.50}"
CURRICULUM_BUCKET_SIGMA="${CURRICULUM_BUCKET_SIGMA:-0.0}"
CURRICULUM_BOOTSTRAP_SIGMA="${CURRICULUM_BOOTSTRAP_SIGMA:-0.0}"
CURRICULUM_UNCERTAINTY_WEIGHT="${CURRICULUM_UNCERTAINTY_WEIGHT:-0.5}"
CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-false}"
CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-50}"
CURRICULUM_ADAPTIVE_LEARNING_WEIGHT="${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT:-0.7}"
CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT="${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT:-0.3}"
CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.5}"
CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.0}"
CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}"

# Collected in an array so each value stays a single argument when expanded.
CURRICULUM_ARGS=(
    --curriculum difficulty
    --curriculum_metadata_field "${CURRICULUM_METADATA_FIELD}"
    --curriculum_posterior_mean_field "${CURRICULUM_POSTERIOR_MEAN_FIELD}"
    --curriculum_bucket_index_field "${CURRICULUM_BUCKET_INDEX_FIELD}"
    --curriculum_bucket_count_field "${CURRICULUM_BUCKET_COUNT_FIELD}"
    --curriculum_quantile_field "${CURRICULUM_QUANTILE_FIELD}"
    --curriculum_bootstrap_steps "${CURRICULUM_BOOTSTRAP_STEPS}"
    --curriculum_bootstrap_target "${CURRICULUM_BOOTSTRAP_TARGET}"
    --curriculum_warmup_target "${CURRICULUM_WARMUP_TARGET}"
    --curriculum_final_target "${CURRICULUM_FINAL_TARGET}"
    --curriculum_warmup_steps "${CURRICULUM_WARMUP_STEPS}"
    --curriculum_total_steps "${CURRICULUM_TOTAL_STEPS}"
    --curriculum_min_hard_frac "${CURRICULUM_MIN_HARD_FRAC}"
    --curriculum_max_hard_frac "${CURRICULUM_MAX_HARD_FRAC}"
    --curriculum_bucket_sigma "${CURRICULUM_BUCKET_SIGMA}"
    --curriculum_bootstrap_sigma "${CURRICULUM_BOOTSTRAP_SIGMA}"
    --curriculum_uncertainty_weight "${CURRICULUM_UNCERTAINTY_WEIGHT}"
    --curriculum_adaptive "${CURRICULUM_ADAPTIVE}"
    --curriculum_adaptive_update_every "${CURRICULUM_ADAPTIVE_UPDATE_EVERY}"
    --curriculum_adaptive_learning_weight "${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT}"
    --curriculum_adaptive_exploration_weight "${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT}"
    --curriculum_adaptive_blend "${CURRICULUM_ADAPTIVE_BLEND}"
    --curriculum_min_quantile "${CURRICULUM_MIN_QUANTILE}"
    --curriculum_max_quantile "${CURRICULUM_MAX_QUANTILE}"
    --curriculum_strict_metadata "${CURRICULUM_STRICT_METADATA}"
)

# Arguments before the bare "--" go to mason.py; everything after it is the
# training command. The curriculum flags and the caller's extra arguments come
# last on that command line.
# NOTE(review): several expansions below are unquoted and therefore subject to
# word splitting -- confirm whether that is relied on (e.g. multi-value CLUSTER).
uv run python mason.py \
    --task_name ${EXP_NAME} \
    --description "${RUN_NAME}" \
    --cluster ${CLUSTER} \
    --workspace ${WORKSPACE} \
    --priority ${PRIORITY} \
    --pure_docker_mode \
    --no_auto_dataset_cache \
    --image ${BEAKER_IMAGE} \
    --preemptible \
    --num_nodes 1 \
    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
    --gpus $NUM_GPUS \
    --budget ${BUDGET} \
    -- \
uv run open_instruct/grpo_fast.py \
    --run_name "${RUN_NAME}" \
    --exp_name "${EXP_NAME}" \
    --eval_pass_at_k 32 \
    --eval_top_p 0.95 \
    --vllm_top_p 1.0 \
    --beta 0.0 \
    --async_steps 4 \
    --active_sampling \
    --inflight_updates \
    --advantage_normalization_type centered \
    --num_samples_per_prompt_rollout ${NUM_SAMPLES_PER_PROMPT_ROLLOUT} \
    --num_unique_prompts_rollout ${NUM_UNIQUE_PROMPTS_ROLLOUT} \
    --num_mini_batches 1 \
    --learning_rate 1e-6 \
    --per_device_train_batch_size 1 \
    --dataset_mixer_list "${DATASET_WITH_DIFFICULTY}" 1.0 \
    --dataset_mixer_list_splits "train" \
    --dataset_mixer_eval_list allenai/aime_2025_openinstruct 1.0 allenai/brumo_2025_openinstruct 1.0 \
    --dataset_mixer_eval_list_splits "train" \
    --max_prompt_token_length 2048 \
    --response_length 8192 \
    --pack_length 10240 \
    --model_name_or_path "Qwen/Qwen3-4B-Base" \
    --non_stop_penalty False \
    --temperature 1.0 \
    --total_episodes ${TOTAL_EPISODES} \
    --deepspeed_stage 2 \
    --num_learners_per_node 4 \
    --vllm_num_engines 4 \
    --vllm_tensor_parallel_size 1 \
    --lr_scheduler_type constant \
    --apply_verifiable_reward true \
    --seed 1 \
    --local_eval_every ${LOCAL_EVAL_EVERY} \
    --save_freq ${SAVE_FREQ} \
    --checkpoint_state_freq ${CHECKPOINT_STATE_FREQ} \
    --gradient_checkpointing \
    --with_tracking \
    --send_slack_alerts \
    --vllm_enable_prefix_caching \
    --clip_higher 0.272 \
    --mask_truncated_completions False \
    --chat_template qwen_instruct_user_boxed_math \
    --load_ref_policy False \
    --keep_last_n_checkpoints -1 \
    --push_to_hub False \
    "${CURRICULUM_ARGS[@]}" "$@"
EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_adaptive_light}" +export CURRICULUM_ADAPTIVE="${CURRICULUM_ADAPTIVE:-true}" +export CURRICULUM_ADAPTIVE_UPDATE_EVERY="${CURRICULUM_ADAPTIVE_UPDATE_EVERY:-20}" +export CURRICULUM_ADAPTIVE_BLEND="${CURRICULUM_ADAPTIVE_BLEND:-0.25}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh new file mode 100755 index 0000000000..cec1aff22c --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Keep under W&B's 64-character tag limit; EXP_NAME is used as a run tag. +export EXP_NAME="${EXP_NAME:-qwen3_4b_dapo_dc_adaptive_light_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# bucket instead of inheriting the base global target near bucket 0. 
#!/bin/bash
# "Adaptive strong" variant of the difficulty-curriculum launcher: fills in
# defaults for the adaptive settings (only where the caller left them unset or
# empty), then hands over to the base script with the original arguments.
set -euo pipefail

script_dir="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

: "${EXP_NAME:=qwen3_4b_base_dapo_difficulty_curriculum_adaptive_strong}"
: "${CURRICULUM_ADAPTIVE:=true}"
# Pull harder toward live curriculum stats than the light variant.
: "${CURRICULUM_ADAPTIVE_UPDATE_EVERY:=10}"
: "${CURRICULUM_ADAPTIVE_LEARNING_WEIGHT:=0.9}"
: "${CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT:=0.1}"
: "${CURRICULUM_ADAPTIVE_BLEND:=0.75}"

export EXP_NAME \
    CURRICULUM_ADAPTIVE \
    CURRICULUM_ADAPTIVE_UPDATE_EVERY \
    CURRICULUM_ADAPTIVE_LEARNING_WEIGHT \
    CURRICULUM_ADAPTIVE_EXPLORATION_WEIGHT \
    CURRICULUM_ADAPTIVE_BLEND

exec bash "${script_dir}/qwen3_4b_dapo_math_dc.sh" "$@"
+export EXP_NAME="${EXP_NAME:-qwen3_4b_dapo_dc_adaptive_strong_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# bucket instead of inheriting the base global target near bucket 0. +export CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.5}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc_adaptive_strong.sh" "$@" diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh new file mode 100755 index 0000000000..61a43cbd3b --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_hardest50}" +export CURRICULUM_MIN_QUANTILE="${CURRICULUM_MIN_QUANTILE:-0.5}" +export CURRICULUM_MAX_QUANTILE="${CURRICULUM_MAX_QUANTILE:-1.0}" +# After filtering out the easy half, start bootstrap at the easiest remaining +# bucket instead of inheriting the base global target near bucket 0. 
+export CURRICULUM_BOOTSTRAP_TARGET="${CURRICULUM_BOOTSTRAP_TARGET:-0.5}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh new file mode 100755 index 0000000000..b9d0a937e6 --- /dev/null +++ b/scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +export EXP_NAME="${EXP_NAME:-qwen3_4b_base_dapo_difficulty_curriculum_longer_easy_bootstrap}" +export CURRICULUM_BOOTSTRAP_STEPS="${CURRICULUM_BOOTSTRAP_STEPS:-200}" +export CURRICULUM_WARMUP_STEPS="${CURRICULUM_WARMUP_STEPS:-200}" + +exec bash "${SCRIPT_DIR}/qwen3_4b_dapo_math_dc.sh" "$@" diff --git a/tests/test_create_difficulty_map.py b/tests/test_create_difficulty_map.py new file mode 100644 index 0000000000..841a0da407 --- /dev/null +++ b/tests/test_create_difficulty_map.py @@ -0,0 +1,499 @@ +"""Unit tests for posterior-aware bucketing in create_difficulty_map.py.""" + +import importlib.util +import json +import math +import sys +import types +import unittest +from collections import Counter +from pathlib import Path +from statistics import NormalDist +from unittest.mock import patch + +import numpy as np + +SCRIPT_PATH = Path(__file__).resolve().parents[1] / "scripts/data/difficulty_sampling/create_difficulty_map.py" + + +def _load_create_difficulty_map_module(): + fake_datasets = types.ModuleType("datasets") + fake_datasets.Dataset = type("Dataset", (), {}) + fake_datasets.load_dataset = lambda *_args, **_kwargs: None + + fake_scipy = types.ModuleType("scipy") + fake_scipy_optimize = types.ModuleType("scipy.optimize") + fake_scipy_optimize.minimize = lambda *_args, **_kwargs: types.SimpleNamespace( + success=False, message="not implemented in test stub", x=(0.0, 0.0) + ) + 
class _ApproximateBetaDistribution:
    """Normal-approximation stand-in for ``scipy.stats.beta`` in the test stub.

    Each Beta(alpha, beta) is approximated by a normal distribution with the
    Beta mean ``alpha / (alpha + beta)`` and the Beta variance, floored at
    ``1e-6`` so the standard deviation is never zero. Only ``cdf``, ``pdf`` and
    ``ppf`` are provided; all three broadcast their arguments against each
    other and return NumPy arrays.
    """

    @staticmethod
    def _mean(alpha, beta):
        """Return the Beta mean ``alpha / (alpha + beta)``."""
        return alpha / (alpha + beta)

    @staticmethod
    def _sigma(alpha, beta):
        """Return the Beta standard deviation with the variance floored at 1e-6."""
        total = alpha + beta
        variance = np.maximum(alpha * beta / (total * total * (total + 1.0)), 1e-6)
        return np.sqrt(variance)

    @staticmethod
    def _broadcast(values, alpha, beta):
        """Coerce the three arguments to float arrays of one common shape."""
        return np.broadcast_arrays(
            np.asarray(values, dtype=float), np.asarray(alpha, dtype=float), np.asarray(beta, dtype=float)
        )

    @classmethod
    def cdf(cls, x, alpha, beta):
        """Return the approximating normal CDF evaluated at ``x``."""
        x_array, alpha_array, beta_array = cls._broadcast(x, alpha, beta)
        mean = cls._mean(alpha_array, beta_array)
        sigma = cls._sigma(alpha_array, beta_array)
        z_scores = (x_array - mean) / (sigma * math.sqrt(2.0))
        # math.erf is scalar-only, hence the element-wise np.vectorize wrapper.
        return 0.5 * (1.0 + np.vectorize(math.erf)(z_scores))

    @classmethod
    def pdf(cls, x, alpha, beta):
        """Return the approximating normal density evaluated at ``x``."""
        x_array, alpha_array, beta_array = cls._broadcast(x, alpha, beta)
        mean = cls._mean(alpha_array, beta_array)
        sigma = cls._sigma(alpha_array, beta_array)
        z_scores = (x_array - mean) / sigma
        normalizer = sigma * math.sqrt(2.0 * math.pi)
        return np.exp(-0.5 * z_scores * z_scores) / normalizer

    @classmethod
    def ppf(cls, q, alpha, beta):
        """Return the approximating normal quantile for ``q``, clipped to [0, 1].

        ``NormalDist.inv_cdf`` only accepts probabilities strictly between 0
        and 1, so the boundaries are handled here the way ``scipy.stats.beta``
        reports them: ``q == 0`` gives 0.0, ``q == 1`` gives 1.0, and any other
        value outside (0, 1), including NaN, gives NaN instead of raising.
        """
        q_array, alpha_array, beta_array = cls._broadcast(q, alpha, beta)
        quantiles = np.empty(q_array.shape, dtype=float)
        for index in np.ndindex(q_array.shape):
            probability = float(q_array[index])
            if probability == 0.0:
                quantiles[index] = 0.0
            elif probability == 1.0:
                quantiles[index] = 1.0
            elif not 0.0 < probability < 1.0:
                # Also catches NaN, for which every comparison is False.
                quantiles[index] = np.nan
            else:
                mean = float(cls._mean(alpha_array[index], beta_array[index]))
                sigma = float(cls._sigma(alpha_array[index], beta_array[index]))
                # The normal has unbounded support; clip back onto the Beta support.
                quantiles[index] = np.clip(NormalDist(mu=mean, sigma=sigma).inv_cdf(probability), 0.0, 1.0)
        return quantiles
fake_datasets, + "scipy": fake_scipy, + "scipy.optimize": fake_scipy_optimize, + "scipy.special": fake_scipy_special, + "scipy.stats": fake_scipy_stats, + } + module_name = "test_create_difficulty_map_module" + spec = importlib.util.spec_from_file_location(module_name, SCRIPT_PATH) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + + with patch.dict(sys.modules, modules): + sys.modules.pop(module_name, None) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +MODULE = _load_create_difficulty_map_module() + + +class TestCreateDifficultyMap(unittest.TestCase): + class FakeHFDataset: + def __init__(self, rows): + self._rows = [dict(row) for row in rows] + + def __getitem__(self, index): + return self._rows[index] + + def __iter__(self): + return iter(self._rows) + + def __len__(self): + return len(self._rows) + + @property + def column_names(self): + return list(self._rows[0].keys()) if self._rows else [] + + def select(self, indices): + return TestCreateDifficultyMap.FakeHFDataset([self._rows[index] for index in indices]) + + def remove_columns(self, column_names): + names = {column_names} if isinstance(column_names, str) else set(column_names) + return TestCreateDifficultyMap.FakeHFDataset( + [{key: value for key, value in row.items() if key not in names} for row in self._rows] + ) + + def add_column(self, name, values): + return TestCreateDifficultyMap.FakeHFDataset( + [{**row, name: value} for row, value in zip(self._rows, values, strict=True)] + ) + + def make_row( + self, + *, + source_row_index=0, + instance_id="row-0", + task_name="math", + source_dataset="mnoukhov/demo", + source_row_id="row-0", + attempt_scores=None, + model_name="demo-model", + warnings=None, + difficulty=None, + ): + return MODULE.DifficultyRow( + source_row_index=source_row_index, + instance_id=instance_id, + task_name=task_name, + base_task_name=MODULE.get_base_task_name(task_name), + 
def test_normalize_attempt_scores_for_group_marks_unsupported_rewards(self):
    # A row whose scores are neither 0/1 nor zero-or-one-constant is kept (and
    # flagged) when non-unit scores are allowed, and dropped when they are not.
    normalize = MODULE.normalize_attempt_scores_for_group
    rows = [self.make_row(instance_id="example", source_row_id="example", attempt_scores=[10.0, 5.0])]

    permissive_result = normalize(rows, allow_nonunit_scores=True)
    retained, processing, skipped_count = permissive_result
    first_retained = retained[0]

    self.assertEqual(skipped_count, 0)
    self.assertFalse(processing.supports_binary_difficulty)
    self.assertEqual(first_retained.attempt_scores, [10.0, 5.0])
    self.assertIn("nonbinary_reward_scores", first_retained.warnings)

    strict_result = normalize(rows, allow_nonunit_scores=False)
    remaining, _unused_processing, dropped_total = strict_result

    self.assertEqual(remaining, [])
    self.assertEqual(dropped_total, 1)
expected_quantile=0.9, + bucket_index=2, + bucket_count=3, + ), + ), + self.make_row(source_row_index=2, instance_id="nonbinary", source_row_id="nonbinary"), + ] + + metadata = MODULE.build_dataset_metadata( + rows=rows, + task_name="math", + model_name="demo-model", + requested_prior_mode="empirical-bayes", + requested_bucket_count=5, + lower_quantile=0.1, + prior=MODULE.BetaPrior(alpha=0.75, beta=1.25, source="empirical_bayes"), + binary_row_count=2, + score_processing=MODULE.ScoreProcessingMetadata( + source_field="reward", + output_field="attempt_scores", + normalization="binary_zero_or_constant", + positive_reward_value=10.0, + supports_binary_difficulty=True, + ), + source_format=MODULE.build_hf_source_format_metadata( + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", + pass_rate_field="pass_rate", + ), + ) + + self.assertEqual(metadata.task_name, "math") + self.assertEqual(metadata.model_name, "demo-model") + self.assertEqual(metadata.row_count, 3) + self.assertEqual(metadata.source_format.kind, MODULE.HF_SOURCE_FORMAT_KIND) + self.assertEqual(metadata.score_processing.normalization, "binary_zero_or_constant") + self.assertEqual(metadata.score_processing.positive_reward_value, 10.0) + self.assertEqual(metadata.difficulty_generation.method, "beta_binomial_posterior_quantiles") + self.assertEqual(metadata.difficulty_generation.posterior_lower_quantile, 0.1) + self.assertEqual(metadata.difficulty_generation.bucket_count_requested, 5) + self.assertEqual(metadata.difficulty_generation.bucket_count_effective, 3) + self.assertEqual(metadata.difficulty_generation.beta_prior_used.source, "empirical_bayes") + self.assertEqual(metadata.difficulty_generation.beta_prior_used.alpha, 0.75) + self.assertEqual(metadata.difficulty_generation.beta_prior_used.beta, 1.25) + 
def test_build_hf_dataset_row_parses_pass_rate_counts(self):
    # A 3-of-5 pass-rate record should expand into per-attempt binary scores
    # and carry the row id, model name and source root through unchanged.
    raw_record = {
        "dataset": "math",
        "extra_info": {"index": "row-7"},
        "pass_count": 3,
        "num_samples": 5,
        "pass_rate": "3/5",
        "generator_model": "Qwen/Qwen3-4B-Base",
    }
    field_mapping = {
        "row_id_field": "extra_info.index",
        "task_field": "dataset",
        "model_field": "generator_model",
        "pass_count_field": "pass_count",
        "attempt_count_field": "num_samples",
        "pass_rate_field": "pass_rate",
    }

    parsed = MODULE.build_hf_dataset_row(
        source_row=raw_record,
        source_row_index=7,
        dataset_name="mnoukhov/demo",
        config_name=None,
        split="train",
        **field_mapping,
    )

    self.assertEqual(parsed.instance_id, "mnoukhov/demo::row-7")
    self.assertEqual(parsed.source_row_id, "row-7")
    self.assertEqual(parsed.attempt_scores, [1.0, 1.0, 1.0, 0.0, 0.0])
    experiment = parsed.experiment_metadata
    self.assertEqual(experiment.model_name, "Qwen/Qwen3-4B-Base")
    self.assertEqual(experiment.source_root, "hf://mnoukhov/demo/default/train")
def test_build_hf_output_dataset_preserves_source_rows_and_order(self):
    # Scored rows are passed in reverse source order; the output is expected
    # to come back in source order with a "difficulty" column appended.
    hf_source = self.FakeHFDataset(
        [
            {"prompt": "first", "extra_info": {"index": "row-0"}},
            {"prompt": "second", "extra_info": {"index": "row-1"}},
        ]
    )
    hard_payload = MODULE.DifficultyPayload(
        value=0.9,
        posterior_mean=0.1,
        posterior_lower_bound=0.1,
        expected_quantile=0.9,
        bucket_index=1,
        bucket_count=2,
    )
    hard_row = self.make_row(
        source_row_index=1,
        instance_id="mnoukhov/demo::row-1",
        source_row_id="row-1",
        attempt_scores=[0.0, 0.0],
        model_name="Qwen/Qwen3-4B-Base",
        difficulty=hard_payload,
    )
    easy_payload = MODULE.DifficultyPayload(
        value=0.1,
        posterior_mean=0.9,
        posterior_lower_bound=0.9,
        expected_quantile=0.1,
        bucket_index=0,
        bucket_count=2,
    )
    easy_row = self.make_row(
        source_row_index=0,
        instance_id="mnoukhov/demo::row-0",
        source_row_id="row-0",
        attempt_scores=[1.0, 1.0],
        model_name="Qwen/Qwen3-4B-Base",
        difficulty=easy_payload,
    )

    output_rows, dataset = MODULE.build_hf_output_dataset(hf_source, [hard_row, easy_row])

    self.assertEqual(len(dataset), 2)
    self.assertEqual(dataset.column_names, ["prompt", "extra_info", "difficulty"])
    for position, expected_row_id in enumerate(["row-0", "row-1"]):
        self.assertEqual(output_rows[position]["source_row_id"], expected_row_id)
    for position, expected_prompt in enumerate(["first", "second"]):
        record = dataset[position]
        self.assertEqual(record["prompt"], expected_prompt)
        self.assertEqual(record["difficulty"]["bucket_index"], position)
FakeInfo() + + dataset = FakeDataset() + dataset_metadata = MODULE.build_dataset_metadata( + rows=[self.make_row(difficulty=MODULE.DifficultyPayload(bucket_index=0, bucket_count=5))], + task_name="math", + model_name="demo-model", + requested_prior_mode="empirical-bayes", + requested_bucket_count=5, + lower_quantile=0.1, + prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="empirical_bayes"), + binary_row_count=1, + score_processing=MODULE.ScoreProcessingMetadata( + source_field="pass_count,num_samples,pass_rate", + output_field="attempt_scores", + normalization="identity_binary", + positive_reward_value=1.0, + supports_binary_difficulty=True, + ), + source_format=MODULE.build_hf_source_format_metadata( + dataset_name="mnoukhov/demo", + config_name=None, + split="train", + row_id_field="extra_info.index", + task_field="dataset", + model_field="generator_model", + pass_count_field="pass_count", + attempt_count_field="num_samples", + pass_rate_field="pass_rate", + ), + ) + + MODULE.annotate_dataset_metadata(dataset, dataset_metadata) + + self.assertEqual(json.loads(dataset.info.description), MODULE.make_jsonable(dataset_metadata)) + + def test_apply_beta_binomial_difficulty_orders_rows_by_expected_quantile(self): + rows = [ + self.make_row(instance_id="easy", source_row_id="easy", attempt_scores=[1.0, 1.0, 1.0, 1.0]), + self.make_row( + source_row_index=1, instance_id="medium", source_row_id="medium", attempt_scores=[1.0, 1.0, 0.0, 0.0] + ), + self.make_row( + source_row_index=2, instance_id="hard", source_row_id="hard", attempt_scores=[0.0, 0.0, 0.0, 0.0] + ), + ] + + result = MODULE.apply_beta_binomial_difficulty( + rows, prior=MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test"), lower_quantile=0.1, num_buckets=3 + ) + difficulties = {row.instance_id: row.difficulty for row in result} + + self.assertLess(difficulties["easy"].expected_quantile, difficulties["medium"].expected_quantile) + self.assertLess(difficulties["medium"].expected_quantile, 
def test_apply_beta_binomial_difficulty_balances_bucket_sizes(self):
    # Five rows with 4, 3, 2, 1 and 0 passes out of 4 attempts, split into two
    # buckets: the test pins a 3/2 split between bucket 0 and bucket 1.
    attempts_per_row = 4
    passes_by_name = [("easiest", 4), ("easy", 3), ("mid", 2), ("hard", 1), ("hardest", 0)]
    rows = [
        self.make_row(
            source_row_index=position,
            instance_id=name,
            source_row_id=name,
            attempt_scores=[1.0] * passes + [0.0] * (attempts_per_row - passes),
        )
        for position, (name, passes) in enumerate(passes_by_name)
    ]
    prior = MODULE.BetaPrior(alpha=0.5, beta=0.5, source="test")

    scored_rows = MODULE.apply_beta_binomial_difficulty(rows, prior=prior, lower_quantile=0.1, num_buckets=2)

    rows_per_bucket = Counter()
    for scored in scored_rows:
        rows_per_bucket[scored.difficulty.bucket_index] += 1

    self.assertEqual(rows_per_bucket[0], 3)
    self.assertEqual(rows_per_bucket[1], 2)
self.assertIsNone(difficulties["nonbinary"].expected_quantile) + self.assertIsNone(difficulties["nonbinary"].bucket_index) + self.assertIsNone(difficulties["nonbinary"].bucket_count) + + +if __name__ == "__main__": + unittest.main()