Skip to content

fix(distillation): teacher block drops user policy overrides via YAML anchor#2504

Open
bzantium wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
bzantium:fix/distillation-teacher-tracks-policy-overrides
Open

fix(distillation): teacher block drops user policy overrides via YAML anchor#2504
bzantium wants to merge 1 commit into
NVIDIA-NeMo:mainfrom
bzantium:fix/distillation-teacher-tracks-policy-overrides

Conversation

@bzantium
Copy link
Copy Markdown

What does this PR do ?

examples/configs/distillation_math.yaml and examples/configs/distillation_math_megatron.yaml define policy: &POLICY_BASE and merge it into teacher: with <<: *POLICY_BASE. YAML's merge operator is a load-time copy, so fields like max_total_sequence_length, train_micro_batch_size, logprob_batch_size, and precision are frozen on the teacher at parse time. Any user override of policy.<field> in a derived recipe leaves the teacher reading the original anchor value, with no warning at config load.

The interpolations inside policy.dynamic_batching are relative (${..max_total_sequence_length}) and re-anchor to teacher when copied, so they read the stale frozen value. The result is a silent mismatch:

policy:
  max_total_sequence_length: 32768   # user override
  generation:
    max_new_tokens: 8192

teacher:
  # no override -> still 8192 frozen from the anchor
  # teacher.dynamic_batching.logprob_mb_tokens = 8192 * 1 = 8192 (stale)

Training runs cleanly until a student rollout actually exceeds the teacher's frozen sequence length, at which point the teacher's logprob inference asserts deep in shard_by_batch_size:

File ".../nemo_rl/distributed/batched_data_dict.py", line 635, in shard_by_batch_size
    assert max_seqlen_this_shard_indice <= max_tokens_per_microbatch, (
AssertionError: got an input of padded (64) sequence length of 8260, however
max microbatch size is 8192 tokens

Hours of GPU time can be wasted before the assertion fires.

Fix

Re-bind the four fields whose semantics require the teacher to track the policy with absolute interpolation (${policy.<field>}). OmegaConf resolves interpolations at use time, so user overrides on policy.<field> propagate. An explicit teacher.<field> override still wins because OmegaConf prefers the most specific value at the call site.

teacher:
    <<: *POLICY_BASE
    model_name: \"Qwen/Qwen3-4B\"
    max_total_sequence_length: ${policy.max_total_sequence_length}
    train_micro_batch_size: ${policy.train_micro_batch_size}
    logprob_batch_size: ${policy.logprob_batch_size}
    precision: ${policy.precision}
    dtensor_cfg:
        ...

Same patch applies to both DTensor and Megatron exemplars; both carry the identical anchor pattern.

Why these four fields

Field Why teacher must track policy
max_total_sequence_length Teacher must accept whatever sequences the student rolls out, otherwise shard_by_batch_size asserts. Also feeds dynamic_batching.{train_mb_tokens, logprob_mb_tokens}.
train_micro_batch_size, logprob_batch_size Both feed the dynamic_batching and sequence_packing token budgets via ${..max_total_sequence_length} * ${..<batch_size>}.
precision Student rollout dtype must match teacher logit dtype for the KL loss arithmetic to stay numerically consistent.

A teacher with materially different architecture (different vocab, different TP/CP shape) still needs its own dtensor_cfg / megatron_cfg overrides; those are inherently teacher-shaped and not interpolated.

Files touched

File Change
examples/configs/distillation_math.yaml Add four ${policy.<field>} interpolations to the teacher: block.
examples/configs/distillation_math_megatron.yaml Same change to the Megatron exemplar.

Issues

None filed; this is a small consistency fix discovered while running long-context distillation with a policy.max_total_sequence_length override.

Usage

After this patch, a derived recipe only needs to override on policy::

defaults: ../../examples/configs/distillation_math.yaml

policy:
  max_total_sequence_length: 32768
  generation:
    max_new_tokens: 8192

teacher:
  model_name: ...                # other teacher overrides as needed
  # max_total_sequence_length etc. now follow policy automatically

Recipes that already explicitly override teacher.<field> keep working unchanged — OmegaConf still honours the most specific value.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • No new tests. Happy to add an OmegaConf-load test that overrides policy.max_total_sequence_length=16384 on the merged config and asserts teacher.max_total_sequence_length == 16384 if reviewers prefer.
  • Backwards compatible: recipes that do not override the affected fields on policy: get the same effective values as before.

`examples/configs/distillation_math.yaml` and
`examples/configs/distillation_math_megatron.yaml` define
`policy: &POLICY_BASE` and merge it into `teacher:` with
`<<: *POLICY_BASE`. YAML's merge operator is a load-time copy, so
fields like `max_total_sequence_length`, `train_micro_batch_size`,
`logprob_batch_size`, and `precision` are frozen on the teacher at
parse time. A user who overrides `policy.max_total_sequence_length`
in a derived recipe leaves the teacher on the literal anchor value
(8192) and creates a silent mismatch. The relative interpolations
inside `policy.dynamic_batching` (`${..max_total_sequence_length}`)
re-anchor to teacher and read the stale 8192, so
`teacher.dynamic_batching.logprob_mb_tokens` ends up wrong.

The mismatch surfaces only mid-training when a student rollout
exceeds the teacher's frozen sequence length, with an
`AssertionError` deep in `shard_by_batch_size` after model loading
and rollout. Hours of GPU time wasted per occurrence.

Re-bind the four fields whose semantics require the teacher to track
the policy with absolute interpolation (`${policy.<field>}`).
OmegaConf resolves interpolations at use time, so user overrides on
`policy.<field>` now propagate. An explicit `teacher.<field>`
override still wins because OmegaConf prefers the most specific
value at the call site.

Both `distillation_math.yaml` (DTensor exemplar) and
`distillation_math_megatron.yaml` (Megatron exemplar) carry the same
anchor pattern and need the same fix.

Backwards compatible: recipes that do not override the affected
fields on `policy:` get the same effective values as before.

Signed-off-by: Minho Ryu <ryumin93@gmail.com>
@bzantium bzantium requested a review from a team as a code owner May 15, 2026 10:09
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants