diff --git a/scripts/regenerate_train_data.py b/scripts/regenerate_train_data.py
index d38392b6..d53ed7aa 100644
--- a/scripts/regenerate_train_data.py
+++ b/scripts/regenerate_train_data.py
@@ -231,9 +231,9 @@ def call_sglang(
                 "content": response_text,
             }
             if args.is_reasoning_model:
-                resp_msg["reasoning_content"] = resp.choices[
-                    0
-                ].message.reasoning_content
+                reasoning = getattr(resp.choices[0].message, "reasoning_content", None)
+                if reasoning is not None:
+                    resp_msg["reasoning_content"] = reasoning
             regenerated_messages.append(resp_msg)
         else:
             data["status"] = "error"
diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py
index 0bd157b3..7870062f 100644
--- a/scripts/train_eagle3.py
+++ b/scripts/train_eagle3.py
@@ -155,6 +155,23 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
         default=None,
         help="directory includes the checkpoint to start training with",
     )
+    training_group.add_argument(
+        "--regenerate-vocab-mapping",
+        action="store_true",
+        help=(
+            "Recompute the d2t/t2d vocab mapping from the new training data "
+            "even when fine-tuning from a checkpoint. By default, when "
+            "--ckpt-dir is provided or --resume picks up a previous run, the "
+            "drafter reuses the mapping stored in the loaded checkpoint so "
+            "that the lm_head slots stay aligned with the target token ids "
+            "they were trained against. Setting this flag overwrites that "
+            "mapping with a fresh one derived from the new corpus's token "
+            "frequencies, which silently realigns slot->target_id and "
+            "collapses acceptance until lm_head is retrained. Only use this "
+            "for true cold-start continuation where the data distribution has "
+            "shifted enough to justify re-picking the reduced 32K vocab."
+        ),
+    )
     training_group.add_argument("--eval-interval", type=int, default=5000)
     training_group.add_argument("--save-interval", type=int, default=5000)
     training_group.add_argument(
@@ -366,9 +383,26 @@ def sp_sanity_check(args: Namespace) -> None:
     )
 
 
-def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
+def build_draft_model(
+    args: Namespace,
+) -> Tuple[
+    AutoDraftModelConfig, nn.Module, Tuple[int, int], Optional[dict], bool
+]:
+    """Build (or load) the draft model.
+
+    Returns:
+        (draft_model_config, draft_model, ckpt_info, resume_state, warm_started)
+
+    ``warm_started`` is True if the drafter weights were loaded from an
+    existing checkpoint (via ``--ckpt-dir`` or a detected ``--resume``
+    checkpoint). When True, the loaded checkpoint's d2t/t2d buffers must
+    be preserved: the drafter's lm_head slots are aligned to that
+    specific mapping, and overwriting it with a freshly-regenerated
+    mapping silently mis-aligns every slot.
+ """ # ckpt info(epoch, step) ckpt_info = (0, 0) + warm_started = False # Handle draft model config if args.draft_model_config is None: @@ -409,6 +443,7 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, ).cuda() + warm_started = True else: draft_model = AutoEagle3DraftModel.from_config( draft_model_config, @@ -433,14 +468,15 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) draft_model.freeze_embedding() - return draft_model_config, draft_model, ckpt_info, resume_state + return draft_model_config, draft_model, ckpt_info, resume_state, warm_started def build_dataloaders( args: Namespace, draft_model_config: AutoDraftModelConfig, processor: Optional[AutoProcessor] = None, -) -> Tuple[DataLoader, str, Optional[DataLoader]]: + skip_vocab_mapping: bool = False, +) -> Tuple[DataLoader, Optional[str], Optional[DataLoader]]: # build dataloaders tokenizer = AutoTokenizer.from_pretrained( args.target_model_path, trust_remote_code=args.trust_remote_code @@ -475,13 +511,20 @@ def build_dataloaders( num_proc=args.build_dataset_num_proc, train_only_last_turn=args.train_only_last_turn, ) - vocab_mapping_path = generate_vocab_mapping_file( - dataset=train_eagle3_dataset, - target_vocab_size=draft_model_config.vocab_size, - draft_vocab_size=draft_model_config.draft_vocab_size, - cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), - cache_key=cache_key, - ) + if skip_vocab_mapping: + vocab_mapping_path = None + print_on_rank0( + "Skipping vocab mapping generation: reusing d2t/t2d from the " + "loaded checkpoint (lm_head slots are aligned to that mapping)." + ) + else: + vocab_mapping_path = generate_vocab_mapping_file( + dataset=train_eagle3_dataset, + target_vocab_size=draft_model_config.vocab_size, + draft_vocab_size=draft_model_config.draft_vocab_size, + cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), + cache_key=cache_key, + ) if not is_online: train_eagle3_dataset = build_offline_eagle3_dataset( @@ -748,18 +791,51 @@ def main(): # ================================================ # 2. Build models # ================================================ - draft_model_config, draft_model, ckpt_info, resume_state = build_draft_model(args) + draft_model_config, draft_model, ckpt_info, resume_state, warm_started = ( + build_draft_model(args) + ) target_model, processor = build_target_model(args, draft_model_config, is_online) # ================================================ # 3. Build dataloader # ================================================ + # When warm-starting from a checkpoint, the d2t/t2d buffers loaded from + # model.safetensors are the ground truth: the drafter's lm_head slots were + # trained against exactly that mapping. Regenerating from the new training + # data would silently remap every slot to a different target token id and + # collapse acceptance. Reuse unless the user explicitly opts out via + # --regenerate-vocab-mapping (for true cold-start continuation). + skip_vocab_mapping = warm_started and not args.regenerate_vocab_mapping + if warm_started and args.regenerate_vocab_mapping: + print_on_rank0( + "WARNING: --regenerate-vocab-mapping is set while warm-starting " + "from a checkpoint. The d2t/t2d mapping will be recomputed from " + "the new training data, which realigns the drafter's lm_head " + "slots to different target token ids. 
Acceptance will drop " + "significantly until lm_head retrains against the new mapping. " + "Confirm this is intended." + ) train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders( - args, draft_model_config, processor + args, draft_model_config, processor, skip_vocab_mapping=skip_vocab_mapping ) - # we load the vocab mapping then + # Load vocab mapping (no-op when skip_vocab_mapping=True; the buffers were + # already populated by from_pretrained() in build_draft_model). draft_model.load_vocab_mapping(vocab_mapping_path) + if vocab_mapping_path is None: + # Defensive sanity check: if we're reusing the checkpoint mapping, + # verify the buffers actually contain a real mapping (i.e. the + # checkpoint wasn't saved before load_vocab_mapping was called). + if not draft_model.has_nondefault_vocab_mapping(): + raise RuntimeError( + "Refusing to start training: --ckpt-dir was provided (or " + "--resume picked up a checkpoint) but the drafter's d2t/t2d " + "buffers still match the uninitialized defaults (d2t all " + "zeros, t2d all True). This means the checkpoint was saved " + "without a vocab mapping. Re-run with " + "--regenerate-vocab-mapping to build one from the training " + "data (this will reinitialize lm_head alignment from scratch)." + ) print_with_rank("Loaded vocab mapping") # Calculate total steps if not provided diff --git a/specforge/args.py b/specforge/args.py index 2cd5efc3..f4a89f53 100644 --- a/specforge/args.py +++ b/specforge/args.py @@ -174,6 +174,18 @@ def add_args(parser: argparse.ArgumentParser) -> None: default=1, help="The ep size of the SGLang backend", ) + parser.add_argument( + "--sglang-max-running-requests", + type=int, + default=None, + help="Override auto-computed max_running_requests for SGLang backend (default: target_batch_size)", + ) + parser.add_argument( + "--sglang-max-total-tokens", + type=int, + default=None, + help="Override auto-computed max_total_tokens for SGLang backend (default: target_batch_size * max_length)", + ) @staticmethod def from_args(args: argparse.Namespace) -> "SGLangBackendArgs": @@ -191,12 +203,18 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs": sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens, sglang_ep_size=args.sglang_ep_size, sglang_max_running_requests=( - args.target_batch_size if hasattr(args, "target_batch_size") else None + args.sglang_max_running_requests + if getattr(args, "sglang_max_running_requests", None) is not None + else (args.target_batch_size if hasattr(args, "target_batch_size") else None) ), sglang_max_total_tokens=( - args.target_batch_size * args.max_length - if hasattr(args, "target_batch_size") and hasattr(args, "max_length") - else None + args.sglang_max_total_tokens + if getattr(args, "sglang_max_total_tokens", None) is not None + else ( + args.target_batch_size * args.max_length + if hasattr(args, "target_batch_size") and hasattr(args, "max_length") + else None + ) ), ) diff --git a/specforge/data/template.py b/specforge/data/template.py index 4dde000f..e468bc2d 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -324,3 +324,23 @@ def get_all_template_names(self) -> List[str]: enable_thinking=True, ), ) + +TEMPLATE_REGISTRY.register( + name="minimax-m2.5", + template=ChatTemplate( + assistant_header="]~b]ai\n", + user_header="]~b]user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="[e~[\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="minimax-m2.7", + template=ChatTemplate( + 
assistant_header="]~b]ai\n", + user_header="]~b]user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="[e~[\n", + ), +) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index b5584a75..75670d2a 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -173,17 +173,53 @@ def load_embedding( local_cache_path = snapshot_download(repo_id=model_path) self.load_embedding(local_cache_path, embedding_key) - def load_vocab_mapping(self, file_path: str) -> None: + def load_vocab_mapping(self, file_path: Optional[str]) -> None: """ Load the vocab buffers of the draft model. Args: - file_path (str): The path to the vocab mapping file. + file_path: Path to the vocab mapping file. If ``None``, this is a + no-op and the caller is expected to have already populated the + ``d2t``/``t2d`` buffers (for example by loading them from a + checkpoint via ``from_pretrained``). This is the correct + behaviour when fine-tuning from an existing drafter: the + mapping stored in the checkpoint is the one the ``lm_head`` + slots were trained against, so it must not be overwritten with + a freshly-computed mapping derived from a different training + corpus. """ assert hasattr(self, "t2d") and hasattr( self, "d2t" - ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" + ), "t2d and d2t buffers are not found in the draft model, please check your draft model implementation" + if file_path is None: + self.vocab_mapping_loaded = True + return vocab_mapping = torch.load(file_path) self.t2d.copy_(vocab_mapping["t2d"]) self.d2t.copy_(vocab_mapping["d2t"]) self.vocab_mapping_loaded = True + + def has_nondefault_vocab_mapping(self) -> bool: + """ + Return ``True`` if ``d2t``/``t2d`` buffers look like a real, trained + mapping (i.e. they differ from the all-zeros / all-True defaults set + in the draft model's ``__init__``). + + A fresh draft model constructed via ``from_config`` has + ``d2t = zeros(draft_vocab_size)`` and ``t2d = ones(vocab_size, bool)``. + Either the first forward pass should receive a mapping generated from + the training data (cold start) or the mapping should be loaded from a + checkpoint's ``model.safetensors`` via ``from_pretrained`` (warm + start). This helper lets callers sanity-check that one of those paths + actually populated the buffers before training begins. + """ + if not (hasattr(self, "d2t") and hasattr(self, "t2d")): + return False + # d2t defaults to all zeros; any real mapping has non-zero entries + # because token id 0 cannot be the most-frequent id across all 32K + # drafter slots. + d2t_is_default = bool((self.d2t == 0).all().item()) + # t2d defaults to all True; a real mapping keeps only the top-K + # targets marked True and zeroes out the rest. + t2d_is_default = bool(self.t2d.all().item()) + return not (d2t_is_default and t2d_is_default)
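
Background sketch for the d2t/t2d reuse logic in this patch. The reduced draft vocab is picked from target-token frequencies in the training corpus, so two corpora generally produce two different slot-to-target alignments; that is why a regenerated mapping invalidates an lm_head trained against the old one. The code below is illustrative only: it is not SpecForge's generate_vocab_mapping_file, the helper name build_vocab_mapping is invented, and the offset convention target_id = slot + d2t[slot] is an assumption.

import torch

def build_vocab_mapping(token_ids, target_vocab_size, draft_vocab_size):
    # Keep the draft_vocab_size most frequent target ids, in ascending id order,
    # mirroring the "re-picking the reduced 32K vocab" described in the help text.
    counts = torch.bincount(token_ids, minlength=target_vocab_size)
    kept = torch.topk(counts, k=draft_vocab_size).indices.sort().values
    d2t = kept - torch.arange(draft_vocab_size)  # assumed convention: target_id = slot + d2t[slot]
    t2d = torch.zeros(target_vocab_size, dtype=torch.bool)
    t2d[kept] = True  # only mapped target ids stay True; the rest are False
    return {"d2t": d2t, "t2d": t2d}

# Two corpora with different token frequencies yield different alignments, so a
# mapping regenerated under a warm start re-points slots the lm_head was trained on.
old_map = build_vocab_mapping(torch.randint(0, 1000, (10_000,)), 1000, 32)
new_map = build_vocab_mapping(torch.randint(0, 1000, (10_000,)), 1000, 32)
print(int((old_map["d2t"] != new_map["d2t"]).sum()), "of 32 slots re-pointed")

Under this construction, keeping the checkpoint's mapping on warm start (the patch's default) is the only choice that leaves every lm_head slot pointing at the target id it was trained against.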