fix(distillation): teacher block drops user policy overrides via YAML anchor #2504
bzantium wants to merge 1 commit into
`examples/configs/distillation_math.yaml` and
`examples/configs/distillation_math_megatron.yaml` define
`policy: &POLICY_BASE` and merge it into `teacher:` with
`<<: *POLICY_BASE`. YAML's merge operator is a load-time copy, so
fields like `max_total_sequence_length`, `train_micro_batch_size`,
`logprob_batch_size`, and `precision` are frozen on the teacher at
parse time. A user who overrides `policy.max_total_sequence_length`
in a derived recipe leaves the teacher on the literal anchor value
(8192) and creates a silent mismatch. The relative interpolations
inside `policy.dynamic_batching` (`${..max_total_sequence_length}`)
re-anchor to teacher and read the stale 8192, so
`teacher.dynamic_batching.logprob_mb_tokens` ends up wrong.
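A minimal sketch of the failing pattern, with illustrative values (8192 is the literal value from the exemplars; the batch sizes, precision, and the `mul:` resolver name are assumptions for illustration):

```yaml
policy: &POLICY_BASE
  max_total_sequence_length: 8192   # literal anchor value
  train_micro_batch_size: 1         # illustrative
  logprob_batch_size: 2             # illustrative
  precision: "bfloat16"             # illustrative
  dynamic_batching:
    # relative interpolation: resolves against whichever block holds
    # this node, so the copy under teacher: reads teacher's frozen value
    logprob_mb_tokens: ${mul:${..max_total_sequence_length},${..logprob_batch_size}}

teacher:
  <<: *POLICY_BASE   # YAML merge is a load-time copy; values are frozen at parse time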
The mismatch surfaces only mid-training when a student rollout
exceeds the teacher's frozen sequence length, with an
`AssertionError` deep in `shard_by_batch_size` after model loading
and rollout. Hours of GPU time wasted per occurrence.
Re-bind the four fields whose semantics require the teacher to track
the policy with absolute interpolation (`${policy.<field>}`).
OmegaConf resolves interpolations at use time, so user overrides on
`policy.<field>` now propagate. An explicit `teacher.<field>`
override still wins because OmegaConf prefers the most specific
value at the call site.
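A sketch of the re-bound fields on the teacher after the fix (other teacher keys unchanged):

```yaml
teacher:
  <<: *POLICY_BASE
  # absolute interpolations: resolved by OmegaConf at use time, so a
  # user override of policy.<field> now propagates to the teacher
  max_total_sequence_length: ${policy.max_total_sequence_length}
  train_micro_batch_size: ${policy.train_micro_batch_size}
  logprob_batch_size: ${policy.logprob_batch_size}
  precision: ${policy.precision}
```

An explicit `teacher.<field>` value would simply replace the interpolation node, which is why teacher-specific overrides still win.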
Both `distillation_math.yaml` (DTensor exemplar) and
`distillation_math_megatron.yaml` (Megatron exemplar) carry the same
anchor pattern and need the same fix.
Backwards compatible: recipes that do not override the affected
fields on `policy:` get the same effective values as before.
Signed-off-by: Minho Ryu <ryumin93@gmail.com>
What does this PR do?
`examples/configs/distillation_math.yaml` and `examples/configs/distillation_math_megatron.yaml` define `policy: &POLICY_BASE` and merge it into `teacher:` with `<<: *POLICY_BASE`. YAML's merge operator is a load-time copy, so fields like `max_total_sequence_length`, `train_micro_batch_size`, `logprob_batch_size`, and `precision` are frozen on the teacher at parse time. Any user override of `policy.<field>` in a derived recipe leaves the teacher reading the original anchor value, with no warning at config load.

The interpolations inside `policy.dynamic_batching` are relative (`${..max_total_sequence_length}`) and re-anchor to `teacher` when copied, so they read the stale frozen value. The result is a silent mismatch: training runs cleanly until a student rollout actually exceeds the teacher's frozen sequence length, at which point the teacher's logprob inference asserts deep in `shard_by_batch_size`. Hours of GPU time can be wasted before the assertion fires.
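A hedged sketch of the effective merged config when a derived recipe raises the policy's sequence length (the override value and the `mul:` resolver name are illustrative):

```yaml
# derived recipe sets:
policy:
  max_total_sequence_length: 16384   # user override applies here

# but after the YAML merge the teacher still carries the parse-time copy:
teacher:
  max_total_sequence_length: 8192    # frozen anchor value, silently stale
  dynamic_batching:
    # ${..max_total_sequence_length} re-anchors to teacher: and reads 8192
    logprob_mb_tokens: ${mul:${..max_total_sequence_length},${..logprob_batch_size}}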
Fix
Re-bind the four fields whose semantics require the teacher to track the policy with absolute interpolation (`${policy.<field>}`). OmegaConf resolves interpolations at use time, so user overrides on `policy.<field>` propagate. An explicit `teacher.<field>` override still wins because OmegaConf prefers the most specific value at the call site.

The same patch applies to both the DTensor and Megatron exemplars; both carry the identical anchor pattern.
Why these four fields
- `max_total_sequence_length`: `shard_by_batch_size` asserts on it, and it feeds `dynamic_batching.{train_mb_tokens, logprob_mb_tokens}`.
- `train_micro_batch_size`, `logprob_batch_size`: feed the `dynamic_batching` and `sequence_packing` token budgets via `${..max_total_sequence_length} * ${..<batch_size>}`.
- `precision`: the teacher's logprobs should be computed in the same precision as the policy's unless explicitly overridden.

A teacher with a materially different architecture (different vocab, different TP/CP shape) still needs its own `dtensor_cfg`/`megatron_cfg` overrides; those are inherently teacher-shaped and not interpolated.

Files touched

- `examples/configs/distillation_math.yaml`: switched the four fields to `${policy.<field>}` interpolations in the `teacher:` block.
- `examples/configs/distillation_math_megatron.yaml`: same change.

Issues
None filed; this is a small consistency fix discovered while running long-context distillation with a `policy.max_total_sequence_length` override.

Usage
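An illustrative derived-recipe override after the patch (the value is hypothetical):

```yaml
policy:
  max_total_sequence_length: 16384
# teacher.max_total_sequence_length now resolves to 16384 through
# ${policy.max_total_sequence_length}; no matching teacher: override needed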
After this patch, a derived recipe only needs to override on `policy:`.

Recipes that already explicitly override `teacher.<field>` keep working unchanged; OmegaConf still honours the most specific value.

Before your PR is "Ready for review"
Pre checks:
Additional Information
A unit test that sets `policy.max_total_sequence_length=16384` on the merged config and asserts `teacher.max_total_sequence_length == 16384` can be added if reviewers prefer.

Backwards compatible: recipes that do not override the affected fields on `policy:` get the same effective values as before.