Preserve drafter vocab mapping when fine-tuning from a checkpoint #534
luv-bansal wants to merge 1 commit into sgl-project:main
Conversation
Code Review
This pull request introduces the ability to reuse vocabulary mappings when fine-tuning from checkpoints in `train_eagle3.py`, preventing accidental misalignment of the drafter's `lm_head`. It also adds configuration options for SGLang backend limits and includes new chat templates for Minimax models. A performance optimization was suggested for `scripts/train_eagle3.py` to avoid redundant dataset construction and tokenization when vocabulary mapping is skipped in offline mode.
```python
        cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
        cache_key=cache_key,
    )
    if skip_vocab_mapping:
```
The `build_eagle3_dataset` call (lines 501-513) is executed unconditionally. However, in offline training mode (`is_online=False`), this dataset is only used for vocabulary mapping generation and is subsequently overwritten by `build_offline_eagle3_dataset` at line 530.
When `skip_vocab_mapping` is `True` and `is_online` is `False` (which is common when fine-tuning from a checkpoint in offline mode), building the `train_eagle3_dataset` is redundant and can be very time-consuming, as it involves tokenizing the entire training corpus. Consider wrapping the dataset construction in a condition like `if is_online or not skip_vocab_mapping:` to avoid this unnecessary overhead.
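A minimal sketch of the suggested guard, reusing the argument names visible in the diff above; the remaining call arguments in `scripts/train_eagle3.py` are assumptions, not verified against the file:

```python
# Sketch only: build the mapping dataset solely when it will actually be used.
# Arguments other than cache_dir / cache_key (which appear in the diff) are
# placeholders and may not match the real signatures.
if is_online or not skip_vocab_mapping:
    train_eagle3_dataset = build_eagle3_dataset(
        dataset=train_dataset,        # assumed argument name
        tokenizer=tokenizer,          # assumed argument name
        cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
        cache_key=cache_key,
    )

if not is_online:
    # Offline mode overwrites the dataset anyway, so the expensive
    # tokenization above can be skipped when the vocab mapping is reused.
    train_eagle3_dataset = build_offline_eagle3_dataset(...)  # existing call, args unchanged
```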
Summary
When fine-tuning an EAGLE-3 drafter from an existing checkpoint (via `--ckpt-dir` or `--resume`), `scripts/train_eagle3.py` currently regenerates the `d2t`/`t2d` vocab mapping from the new training data and overwrites the mapping loaded from the checkpoint. This silently misaligns every `lm_head` slot with a different target token ID. Specifically, `lm_head` column `i` was trained to predict token `i + d2t_old[i]`, but is now asked to predict `i + d2t_new[i]`.

The result is a catastrophic drop in draft acceptance that is not visible in `train/acc_*` or `eval/acc_*`, since those metrics are computed against the new mapping rather than the true target IDs. The drafter effectively has to relearn the slot-to-token alignment from scratch, negating the benefit of warm-starting.

This PR makes the safe behavior the default: when a checkpoint is loaded, its `d2t`/`t2d` mappings (stored in `model.safetensors` as registered buffers) are reused, and vocab-mapping regeneration is skipped. An opt-in flag `--regenerate-vocab-mapping` preserves the old behavior for users who explicitly want to recompute the reduced vocab (e.g., when the training distribution has shifted significantly).
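To make the failure mode described above concrete, here is a small illustrative sketch with made-up numbers (the real mappings are per-model tensors, not the toy dicts used here):

```python
# Toy illustration of the misalignment (numbers are invented).
# Slot i of the drafter's lm_head encodes target token id i + d2t[i].
d2t_old = {7: 1000}   # at training time, slot 7 meant target token 1007
d2t_new = {7: 1250}   # after regeneration, slot 7 is interpreted as token 1257

slot = 7
trained_for = slot + d2t_old[slot]    # 1007: what the column's weights encode
now_scored_as = slot + d2t_new[slot]  # 1257: what acceptance is now checked against

assert trained_for != now_scored_as
# train/acc_* still looks healthy because it also uses d2t_new, but the target
# model only accepts drafts whose true token ids match, so acceptance collapses.
```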
Changes

`specforge/modeling/draft/base.py`
- `load_vocab_mapping(file_path)` now accepts `None` as a no-op, allowing callers to skip overwriting cleanly.
- Added `has_nondefault_vocab_mapping()` to detect whether `d2t`/`t2d` still contain their default initialization (`d2t == 0`, `t2d == True`). Used as a sanity check before training (see the sketch after this list).
- Typo fix: `buffersare` → `buffers are`.

`scripts/train_eagle3.py`
- `--regenerate-vocab-mapping` flag (default: off).
- `build_draft_model()` now returns an additional `warm_started: bool`.
- `build_dataloaders()` accepts `skip_vocab_mapping` and skips `generate_vocab_mapping_file()` when set.
- In `main()`: `skip_vocab_mapping = warm_started and not args.regenerate_vocab_mapping`.
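A rough sketch of what the two `base.py` helpers might look like, based only on the behavior described above; the real class is `Eagle3DraftModel`, and the buffer shapes, file format, and internals here are assumptions:

```python
import torch

# Sketch only: signatures follow the PR description, internals are guesses.
class Eagle3DraftModelSketch(torch.nn.Module):
    def __init__(self, draft_vocab_size: int, target_vocab_size: int):
        super().__init__()
        # d2t/t2d are registered buffers, so they persist alongside the weights.
        self.register_buffer("d2t", torch.zeros(draft_vocab_size, dtype=torch.long))
        self.register_buffer("t2d", torch.ones(target_vocab_size, dtype=torch.bool))

    def load_vocab_mapping(self, file_path):
        # Accept None as a no-op so callers can skip overwriting cleanly.
        if file_path is None:
            return
        mapping = torch.load(file_path)  # on-disk format is an assumption
        self.d2t.copy_(mapping["d2t"])
        self.t2d.copy_(mapping["t2d"])

    def has_nondefault_vocab_mapping(self) -> bool:
        # Default initialization is d2t == 0 and t2d == True; anything else
        # means a real mapping has been loaded (e.g., from a checkpoint).
        return bool(self.d2t.any()) or not bool(self.t2d.all())
```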
Behavior Matrix

| Scenario | default | with `--regenerate-vocab-mapping` |
| --- | --- | --- |
| Fresh training (no `--ckpt-dir`, no `--resume`) | generate mapping from training data (unchanged) | generate mapping from training data (unchanged) |
| `--ckpt-dir <dir>` | reuse `d2t`/`t2d` from the checkpoint | regenerate mapping (old behavior) |
| `--resume` | reuse `d2t`/`t2d` from the checkpoint | regenerate mapping (old behavior) |

Cold-start training remains unchanged. The silent misalignment issue during warm-start training is eliminated by default.
Why the Default Changed
Regenerating the mapping during warm-start is almost never desirable:
- `lm_head` is a `32K × hidden` matrix whose columns are trained against specific target token IDs.
- The mapping (`d2t[i] = target_id_of_slot_i − i`) is saved alongside these weights in `model.safetensors` and restored via `from_pretrained()`.

Fresh training is unaffected because `warm_started = False` when no checkpoint is used.
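For context, a minimal self-contained illustration of why a registered buffer travels with the weights (the class and names below are invented for the example; in SpecForge the buffers end up in `model.safetensors` and come back via `from_pretrained()`):

```python
import torch

# Registered buffers are part of state_dict(), so they are serialized with the
# weights and restored when the checkpoint is loaded.
class TinyDrafterHead(torch.nn.Module):
    def __init__(self, draft_vocab: int = 8, hidden: int = 4):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, draft_vocab, bias=False)
        self.register_buffer("d2t", torch.zeros(draft_vocab, dtype=torch.long))

m = TinyDrafterHead()
m.d2t[3] = 1000               # pretend slot 3 maps to target token 1003
state = m.state_dict()        # contains both lm_head.weight and d2t

fresh = TinyDrafterHead()
fresh.load_state_dict(state)  # the mapping is restored together with the weights
assert fresh.d2t[3].item() == 1000
```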
Tests

Unit-style sanity checks for `Eagle3DraftModel`: