Adds difficulty sampling curriculum dataloader and dataset builder #1661
Open
undfined wants to merge 45 commits into
…/difficulty-sampling
Pull request overview
Adds an end-to-end difficulty-aware curriculum pipeline: a dataset-side difficulty metadata builder (Beta-Binomial posterior + bucketing) and a training-side sampler/dataloader path that can shift prompt sampling over training steps (optionally adaptive based on observed rewards/advantages).
Changes:
- Add `create_difficulty_map.py` + docs/tests to generate per-row `difficulty.*` metadata from HF pass-rate aggregates and write JSONL + schema/metadata sidecars (optionally push to Hub).
- Add `open_instruct/difficulty_curriculum.py` sampler/config + integrate into `open_instruct/data_loader.py` and `grpo_fast.py` CLI/config wiring.
- Add launch scripts + reference curriculum artifacts under `configs/curriculum/...`, and update changelog.
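
As a rough illustration of the "Beta-Binomial posterior + bucketing" step mentioned in the overview, the sketch below computes a posterior mean pass rate from pass/attempt counts and assigns equal-frequency buckets. It is not the code from `create_difficulty_map.py`: the function names, the uniform `Beta(1, 1)` prior, the choice of `value = 1 - posterior_mean`, and the rank-based bucketing rule are all assumptions made for the example. Only the field names (`posterior_mean`, `value`, `bucket_index`, `bucket_count`) mirror the metadata keys listed in this PR.

```python
from dataclasses import dataclass


@dataclass
class DifficultyRecord:
    posterior_mean: float  # posterior mean pass rate
    value: float           # difficulty; ASSUMED here to be 1 - posterior_mean
    bucket_index: int      # 0 = easiest bucket
    bucket_count: int


def posterior_mean_pass_rate(passes, attempts, alpha=1.0, beta=1.0):
    """Mean of the Beta(alpha + passes, beta + attempts - passes) posterior."""
    return (alpha + passes) / (alpha + beta + attempts)


def build_difficulty_records(counts, bucket_count=5, alpha=1.0, beta=1.0):
    """counts is a list of (passes, attempts) pairs, one per dataset row."""
    means = [posterior_mean_pass_rate(p, n, alpha, beta) for p, n in counts]
    difficulties = [1.0 - m for m in means]
    # Rank rows from easiest to hardest; sorted() is stable, so ties keep input order.
    order = sorted(range(len(counts)), key=lambda i: difficulties[i])
    records = [None] * len(counts)
    for rank, row in enumerate(order):
        # Equal-frequency bucketing by rank (hypothetical rule for this sketch).
        bucket = min(bucket_count - 1, rank * bucket_count // len(counts))
        records[row] = DifficultyRecord(means[row], difficulties[row], bucket, bucket_count)
    return records


if __name__ == "__main__":
    rows = [(32, 32), (16, 32), (0, 32), (8, 32)]
    for record in build_difficulty_records(rows, bucket_count=2):
        print(record)
```

The prior pulls rows with few attempts toward the middle, so a prompt with 0 passes out of 2 attempts is not ranked as confidently hard as one with 0 out of 32.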
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_create_difficulty_map.py | Unit tests for difficulty-map generation logic (with stubs for external deps). |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc.sh | Example launch script enabling difficulty curriculum flags. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_longer_easy_bootstrap.sh | Variant launcher adjusting bootstrap/warmup steps. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_hardest50.sh | Variant launcher filtering to hardest quantiles. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong.sh | Variant launcher enabling stronger adaptive sampling. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_strong_hardest50.sh | Adaptive+hardest50 launcher variant. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light.sh | Variant launcher enabling lighter adaptive sampling. |
| scripts/train/qwen/difficult-curriculum/qwen3_4b_dapo_math_dc_adaptive_light_hardest50.sh | Adaptive-light + hardest50 launcher variant. |
| scripts/data/difficulty_sampling/README.md | User-facing documentation for difficulty metadata + curriculum usage. |
| scripts/data/difficulty_sampling/create_difficulty_map.py | Difficulty-map builder script (load HF dataset, estimate posterior, bucket, write/push outputs). |
| open_instruct/test_difficulty_curriculum.py | Tests for sampler behavior, quantile filtering, adaptive stats, and loader integration. |
| open_instruct/test_data_loader.py | Extends tests to ensure truncation masking keeps arrays/batch aligned. |
| open_instruct/grpo_fast.py | Wires curriculum args/config into GRPO entrypoint and DataPreparationActor creation. |
| open_instruct/difficulty_curriculum.py | New curriculum sampler implementation (schedule, weighting, adaptive reweighting, metrics, state). |
| open_instruct/data_loader.py | Adds curriculum-backed prompt dataloader + adaptive observation recording + curriculum metrics. |
| configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.schema.json | Reference schema artifact for difficulty-annotated dataset. |
| configs/curriculum/Qwen_Qwen3-4B-Base/math__Qwen_Qwen3-4B-Base__bbq-eb-q10-k5.metadata.json | Reference metadata artifact describing generation configuration. |
| CHANGELOG.md | Notes the new difficulty curriculum + builder feature. |
Comments suppressed due to low confidence (1)
open_instruct/data_loader.py:1542
`advantages` (and `scores_per_prompt` / `mean_grouped_rewards`) are computed before `maybe_mask_truncated_completions` filters out non-`stop` rollouts. When masking is enabled, truncated rollouts still influence the per-prompt mean/std used for advantage normalization, which biases the remaining trainable samples (and can make the grouping logic inconsistent if some rollouts are removed). Consider applying the truncation mask before computing per-prompt statistics, or recomputing per-prompt means/stds from the retained rollouts (grouping by prompt_id / index) so only trainable samples contribute to the advantage calculation.
scores = np.array(batch.scores)
scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout)
mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0)
std_grouped_rewards = scores_per_prompt.std(axis=-1)
std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0)
if self.config.advantage_normalization_type == "standard":
    advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
elif self.config.advantage_normalization_type == "centered":
    advantages = scores - mean_grouped_rewards
else:
    raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}")
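
One way to act on the second suggestion (recomputing per-prompt statistics from the retained rollouts only) is sketched below. This is a standalone illustration, not a patch for `data_loader.py`: the function name `masked_group_advantages` and the `keep_mask` argument are hypothetical, and the sketch assumes rollouts for each prompt are stored contiguously, as the `reshape` in the snippet above implies.

```python
import numpy as np


def masked_group_advantages(scores, keep_mask, num_samples_per_prompt, normalization="standard"):
    """Per-prompt advantages computed only from rollouts where keep_mask is True.

    Returns a flat array containing advantages for the retained rollouts, in order.
    """
    scores = np.asarray(scores, dtype=np.float64).reshape(-1, num_samples_per_prompt)
    keep = np.asarray(keep_mask, dtype=bool).reshape(-1, num_samples_per_prompt)

    # Number of retained rollouts per prompt; clamp to 1 to avoid division by zero
    # for prompts whose rollouts were all masked out.
    counts = np.maximum(keep.sum(axis=-1, keepdims=True), 1)
    mean = np.where(keep, scores, 0.0).sum(axis=-1, keepdims=True) / counts
    centered = np.where(keep, scores - mean, 0.0)

    if normalization == "centered":
        advantages = centered
    elif normalization == "standard":
        # Population standard deviation over retained rollouts, matching np.std's default.
        std = np.sqrt((centered ** 2).sum(axis=-1, keepdims=True) / counts)
        advantages = centered / (std + 1e-8)
    else:
        raise ValueError(f"Invalid advantage normalization type: {normalization}")

    return advantages.reshape(-1)[keep.reshape(-1)]


if __name__ == "__main__":
    scores = [1, 0, 0, 0, 1, 1, 0, 0]
    keep = [True, True, False, False, True, True, True, True]
    print(masked_group_advantages(scores, keep, num_samples_per_prompt=4, normalization="centered"))
```

In the first group, the unmasked mean would be 0.25, while the mean over the two retained rollouts is 0.5, which shows how truncated rollouts can shift the baseline for the samples that are actually trained on.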
Comment on lines +1524 to 1529
prompt_dataset_indices: list[int] = []
if self.curriculum_dataloader is not None and batch.indices is not None:
    prompt_dataset_indices = [
        int(index) for index in batch.indices[:: self.config.num_samples_per_prompt_rollout]
    ]
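
The strided slice in the snippet above appears to rely on each prompt's dataset index being repeated once per rollout, in contiguous groups. The sketch below shows that assumption and a defensive check for it. The helper name `prompt_indices_from_rollouts` and the validation are hypothetical and are not part of the PR.

```python
def prompt_indices_from_rollouts(rollout_indices, num_samples_per_prompt):
    """Recover one dataset index per prompt from per-rollout indices.

    ASSUMES rollout_indices looks like [i0, i0, ..., i1, i1, ...] with each
    prompt index repeated num_samples_per_prompt times in a row.
    """
    if len(rollout_indices) % num_samples_per_prompt != 0:
        raise ValueError("rollout count is not a multiple of num_samples_per_prompt")

    # Same strided slice as in the reviewed code: take the first rollout of each group.
    prompt_indices = [int(i) for i in rollout_indices[::num_samples_per_prompt]]

    # Defensive check: every rollout in a group must share the group's index.
    for group, prompt_index in enumerate(prompt_indices):
        start = group * num_samples_per_prompt
        members = rollout_indices[start:start + num_samples_per_prompt]
        if any(int(i) != prompt_index for i in members):
            raise ValueError(f"group {group} mixes rollouts from different prompts")
    return prompt_indices


if __name__ == "__main__":
    print(prompt_indices_from_rollouts([7, 7, 7, 2, 2, 2], num_samples_per_prompt=3))
```

If rollouts are dropped before this point, the groups no longer line up and the plain slice would silently attribute observations to the wrong prompt, which is why the check raises instead.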

BUDGET="${BUDGET:-ai2/oe-omai}"

# Difficulty-annotated variant of hamishivi/DAPO-Math-17k-Processed_filtered
DATASET_WITH_DIFFICULTY="undfined/dapo-math-17k-processed-filtered-qwen3-4b-base-32samples-ds"
Summary
Adds a difficulty-map generation pipeline and a difficulty-aware prompt sampling path in the dataloader.
- Adds `scripts/data/difficulty_sampling/create_difficulty_map.py` to build per-instance difficulty metadata from Hugging Face datasets with pass-count / attempt-count aggregates.
- Per-row metadata fields: `difficulty.value`, `difficulty.posterior_mean`, `difficulty.posterior_lower_bound`, `difficulty.expected_quantile`, `difficulty.bucket_index`, and `difficulty.bucket_count`.
- Writes `.schema.json` and `.metadata.json` sidecars, and supports optional push-to-hub for a single task/model output group.
- Adds `open_instruct/difficulty_curriculum.py`, which defines the difficulty curriculum config, metadata parsing, bucket schedule, within-bucket weighting, optional adaptive bucket reweighting, and quantile-based filtering over difficulty metadata.
- Extends `open_instruct/data_loader.py` with `DifficultyCurriculumHFDataLoader` and prompt-loader integration so training can sample prompts through the difficulty curriculum instead of uniform reshuffling.
- Adds reference curriculum artifacts under `configs/curriculum/Qwen_Qwen3-4B-Base/`.
- Adds documentation in `scripts/data/difficulty_sampling/README.md`.

Tests
- `tests/test_create_difficulty_map.py` for difficulty-map generation.
- `open_instruct/test_difficulty_curriculum.py` for curriculum sampling, parser wiring, filtering, and loader integration.
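
To make the idea of a bucket schedule concrete, here is a minimal sketch of sampling that shifts from easy to hard buckets as training progresses. It is not the sampler in `open_instruct/difficulty_curriculum.py`: the function names, the linear progress target, the exponential weighting, and the `sharpness` parameter are all assumptions chosen for illustration, and the sketch omits within-bucket weighting and adaptive reweighting.

```python
import math
import random


def bucket_weights(step, total_steps, bucket_count, sharpness=2.0):
    """Normalized bucket weights that favor easy buckets early and hard buckets late."""
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    target = progress * (bucket_count - 1)  # bucket favored at this step
    raw = [math.exp(-sharpness * abs(b - target)) for b in range(bucket_count)]
    total = sum(raw)
    return [w / total for w in raw]


def sample_prompt(step, total_steps, buckets, rng):
    """Pick a bucket by schedule weight, then a prompt uniformly within it.

    buckets[b] is a list of dataset indices whose bucket_index is b.
    """
    weights = bucket_weights(step, total_steps, len(buckets))
    # Skip empty buckets so their weight is redistributed to the others.
    candidates = [(b, w) for b, w in enumerate(weights) if buckets[b]]
    if not candidates:
        raise ValueError("all buckets are empty")
    bucket = rng.choices([b for b, _ in candidates], weights=[w for _, w in candidates], k=1)[0]
    return rng.choice(buckets[bucket])


if __name__ == "__main__":
    rng = random.Random(0)
    buckets = [[0, 1, 2], [3, 4], [5]]
    for step in (0, 50, 100):
        print(step, [round(w, 3) for w in bucket_weights(step, 100, len(buckets))])
    print(sample_prompt(50, 100, buckets, rng))
```

Keeping every bucket's weight strictly positive means easy prompts are still revisited late in training, which is one possible design choice rather than a description of what this PR does.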