
Preserve drafter vocab mapping when fine-tuning from a checkpoint #534

Open
luv-bansal wants to merge 1 commit into sgl-project:main from luv-bansal:fix-vocab-mapping

Conversation

@luv-bansal

Summary

When fine-tuning an EAGLE-3 drafter from an existing checkpoint (via --ckpt-dir or --resume), scripts/train_eagle3.py currently regenerates the d2t / t2d vocab mapping from the new training data and overwrites the mapping loaded from the checkpoint.

This silently misaligns every lm_head slot with a different target token ID. Specifically, lm_head column i was trained to predict token i + d2t_old[i], but is now asked to predict i + d2t_new[i].

The result is a catastrophic drop in draft acceptance that is not visible in train/acc_* or eval/acc_*, since those metrics are computed against the new mapping rather than the true target IDs. The drafter effectively has to relearn the slot-to-token alignment from scratch, negating the benefit of warm-starting.

This PR makes the safe behavior the default: when a checkpoint is loaded, its d2t / t2d mappings (stored in model.safetensors as registered buffers) are reused, and vocab-mapping regeneration is skipped.

An opt-in flag --regenerate-vocab-mapping preserves the old behavior for users who explicitly want to recompute the reduced vocab (e.g., when the training distribution has shifted significantly).
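
To make the failure mode concrete, here is a toy illustration (the numbers are made up; real drafters use a reduced vocab on the order of 32K slots):

import torch

# d2t maps each draft slot i to target token i + d2t[i].
d2t_old = torch.tensor([3, 5, 5, 6])  # mapping the checkpoint was trained with
d2t_new = torch.tensor([1, 2, 4, 6])  # mapping regenerated from the new dataset

slots = torch.arange(4)
print((slots + d2t_old).tolist())  # [3, 6, 7, 9] -> tokens each lm_head column learned to predict
print((slots + d2t_new).tolist())  # [1, 3, 6, 9] -> tokens each column is now credited with

# A column trained to emit high logits for token 6 is now scored as if it predicted
# token 3, so train/acc_* and eval/acc_* look fine while real acceptance collapses.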


Changes

specforge/modeling/draft/base.py

  • load_vocab_mapping(file_path) now accepts None as a no-op, allowing callers to skip overwriting cleanly.
  • Added has_nondefault_vocab_mapping() to detect whether d2t / t2d still contain their default initialization (d2t == 0, t2d == True). Used as a sanity check before training; a sketch of both helpers follows this list.
  • Fixed typo: buffersare → buffers are.
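
A rough sketch of how these helpers behave (the method names match the PR description, but the bodies below are an assumption, not the actual diff):

import torch

class DrafterSketch:
    def __init__(self, draft_vocab=8, target_vocab=16):
        # Default initialization mirrors the description: d2t all zeros, t2d all True.
        self.d2t = torch.zeros(draft_vocab, dtype=torch.long)
        self.t2d = torch.ones(target_vocab, dtype=torch.bool)
        self.vocab_mapping_loaded = False

    def load_vocab_mapping(self, file_path):
        # None is a no-op, so callers can skip overwriting a checkpoint's mapping.
        if file_path is not None:
            mapping = torch.load(file_path)
            self.d2t.copy_(mapping["d2t"])
            self.t2d.copy_(mapping["t2d"])
        self.vocab_mapping_loaded = True

    def has_nondefault_vocab_mapping(self):
        # True once either buffer differs from its default initialization.
        return bool((self.d2t != 0).any() or (~self.t2d).any())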

scripts/train_eagle3.py

  • Added --regenerate-vocab-mapping flag (default: off).
  • build_draft_model() now returns an additional warm_started: bool.
  • build_dataloaders() accepts skip_vocab_mapping and skips generate_vocab_mapping_file() when set.
  • In main() (a sketch of this wiring follows the list):
    • skip_vocab_mapping = warm_started and not args.regenerate_vocab_mapping
    • Logs the chosen behavior
    • Raises a clear error if a reused mapping is found to be uninitialized
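
A hedged sketch of how these pieces could fit together in main() (variable names and the logger call are assumptions; argument plumbing is elided):

draft_model, warm_started = build_draft_model(args)

skip_vocab_mapping = warm_started and not args.regenerate_vocab_mapping
if skip_vocab_mapping:
    logger.info("Reusing d2t/t2d from checkpoint; skipping vocab-mapping regeneration")
    if not draft_model.has_nondefault_vocab_mapping():
        raise RuntimeError(
            "Checkpoint loaded but d2t/t2d are still at their default values; "
            "pass --regenerate-vocab-mapping or verify the checkpoint contents"
        )
else:
    logger.info("Generating vocab mapping from the training data")

train_loader, eval_loader = build_dataloaders(args, skip_vocab_mapping=skip_vocab_mapping)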

Behavior Matrix

Mode                                     | Default (new)            | --regenerate-vocab-mapping
Cold start (no --ckpt-dir, no --resume)  | Generate from data       | Same
--ckpt-dir <dir>                         | Reuse checkpoint mapping | Regenerate + overwrite (old behavior)
--resume                                 | Reuse checkpoint mapping | Regenerate + overwrite

Cold-start training remains unchanged. The silent misalignment issue during warm-start training is eliminated by default.


Why the Default Changed

Regenerating the mapping during warm-start is almost never desirable:

  1. The drafter’s lm_head is a 32K × hidden matrix whose columns are trained against specific target token IDs.
  2. The mapping (d2t[i] = target_id_of_slot_i − i) is saved alongside these weights in model.safetensors and restored via from_pretrained().
  3. Recomputing the mapping from a different dataset reshuffles the slot-to-token assignments, causing every learned column to point to the wrong softmax row.

Fresh training is unaffected because warm_started = False when no checkpoint is used.
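
For example, whether a checkpoint already carries its mapping can be checked directly from model.safetensors (the path and buffer key names below are illustrative assumptions):

from safetensors import safe_open

with safe_open("ckpt_dir/model.safetensors", framework="pt") as f:
    print("d2t" in f.keys(), "t2d" in f.keys())  # registered buffers saved with the weights
    d2t = f.get_tensor("d2t")
    print(d2t[:5])  # per-slot offsets: d2t[i] = target_id_of_slot_i - i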


Tests

Unit-style sanity checks for Eagle3DraftModel:

# Note: FakeDrafter is assumed here to be a minimal test double for Eagle3DraftModel
# (a helper defined alongside the tests, not shown in this description).
import torch

# Fresh buffers are detected as default
m = FakeDrafter()
assert not m.has_nondefault_vocab_mapping()

# load_vocab_mapping(None) is a safe no-op
m.load_vocab_mapping(None)
assert m.vocab_mapping_loaded

# Populated buffers are detected as non-default
m.d2t[0] = 9
assert m.has_nondefault_vocab_mapping()

# Round-trip load works
torch.save({"d2t": ..., "t2d": ...}, path)
m2 = FakeDrafter()
m2.load_vocab_mapping(path)
assert m2.d2t[0].item() == 42

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces the ability to reuse vocabulary mappings when fine-tuning from checkpoints in train_eagle3.py, preventing accidental misalignment of the drafter's lm_head. It also adds configuration options for SGLang backend limits and includes new chat templates for Minimax models. A performance optimization was suggested for scripts/train_eagle3.py to avoid redundant dataset construction and tokenization when vocabulary mapping is skipped in offline mode.

Comment thread: scripts/train_eagle3.py
cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
cache_key=cache_key,
)
if skip_vocab_mapping:

Severity: medium

The build_eagle3_dataset call (lines 501-513) is executed unconditionally. However, in offline training mode (is_online=False), this dataset is only used for vocabulary mapping generation and is subsequently overwritten by build_offline_eagle3_dataset at line 530.

When skip_vocab_mapping is True and is_online is False (which is common when fine-tuning from a checkpoint in offline mode), building the train_eagle3_dataset is redundant and can be very time-consuming as it involves tokenizing the entire training corpus. Consider wrapping the dataset construction in a condition like if is_online or not skip_vocab_mapping: to avoid this unnecessary overhead.
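
A minimal sketch of the suggested guard (argument lists elided; the surrounding variable names are taken from the comment above):

# Offline mode only needs the tokenized dataset to (re)generate the vocab mapping,
# so skip building it when the mapping is being reused from a checkpoint.
if is_online or not skip_vocab_mapping:
    train_eagle3_dataset = build_eagle3_dataset(...)

if not skip_vocab_mapping:
    vocab_mapping_path = generate_vocab_mapping_file(...)

if not is_online:
    train_eagle3_dataset = build_offline_eagle3_dataset(...)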
