6 changes: 3 additions & 3 deletions scripts/regenerate_train_data.py
@@ -231,9 +231,9 @@ def call_sglang(
"content": response_text,
}
if args.is_reasoning_model:
resp_msg["reasoning_content"] = resp.choices[
0
].message.reasoning_content
reasoning = getattr(resp.choices[0].message, "reasoning_content", None)
if reasoning is not None:
resp_msg["reasoning_content"] = reasoning
regenerated_messages.append(resp_msg)
else:
data["status"] = "error"
102 changes: 89 additions & 13 deletions scripts/train_eagle3.py
@@ -155,6 +155,23 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
default=None,
help="directory includes the checkpoint to start training with",
)
training_group.add_argument(
"--regenerate-vocab-mapping",
action="store_true",
help=(
"Recompute the d2t/t2d vocab mapping from the new training data "
"even when fine-tuning from a checkpoint. By default, when "
"--ckpt-dir is provided or --resume picks up a previous run, the "
"drafter reuses the mapping stored in the loaded checkpoint so "
"that the lm_head slots stay aligned with the target token ids "
"they were trained against. Setting this flag overwrites that "
"mapping with a fresh one derived from the new corpus's token "
"frequencies, which silently realigns slot->target_id and "
"collapses acceptance until lm_head is retrained. Only use this "
"for true cold-start continuation where the data distribution has "
"shifted enough to justify re-picking the reduced 32K vocab."
),
)
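
A toy sketch of the slot -> target-id realignment this help text warns about. It assumes only what the rest of this PR states: the reduced draft vocab keeps the most frequent target token ids, d2t/t2d record that choice, and the lm_head slots are trained against it. pick_mapping below is a made-up illustration, not the real generate_vocab_mapping_file, and whether slots are ordered by ascending id or by frequency is an assumed detail.

import torch

target_vocab_size = 10  # toy sizes, not the real target vocab / 32K draft vocab

def pick_mapping(token_frequencies, draft_vocab_size):
    # Keep the most frequent target ids (here in ascending id order, one per slot);
    # t2d is the boolean mask of target ids that made it into the reduced vocab.
    top_ids = torch.topk(token_frequencies, draft_vocab_size).indices.sort().values
    t2d = torch.zeros(target_vocab_size, dtype=torch.bool)
    t2d[top_ids] = True
    return top_ids, t2d

# Frequencies from the corpus the checkpoint was trained on...
old_slots, _ = pick_mapping(torch.tensor([1.0, 9, 0, 7, 3, 0, 8, 2, 0, 5]), 3)
# ...versus frequencies from the new fine-tuning corpus.
new_slots, _ = pick_mapping(torch.tensor([6.0, 0, 9, 1, 0, 8, 2, 0, 7, 3]), 3)

print(old_slots.tolist())  # [1, 3, 6]: the target ids each lm_head slot was trained to score
print(new_slots.tolist())  # [2, 5, 8]: after regeneration the same slots point at different
                           # target ids, so acceptance collapses until lm_head is retrained
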
training_group.add_argument("--eval-interval", type=int, default=5000)
training_group.add_argument("--save-interval", type=int, default=5000)
training_group.add_argument(
@@ -366,9 +383,26 @@ def sp_sanity_check(args: Namespace) -> None:
)


def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
def build_draft_model(
args: Namespace,
) -> Tuple[
AutoDraftModelConfig, nn.Module, Tuple[int, int], Optional[dict], bool
]:
"""Build (or load) the draft model.

Returns:
(draft_model_config, draft_model, ckpt_info, resume_state, warm_started)

``warm_started`` is True if the drafter weights were loaded from an
existing checkpoint (via ``--ckpt-dir`` or a detected ``--resume``
checkpoint). When True, the loaded checkpoint's d2t/t2d buffers must
be preserved: the drafter's lm_head slots are aligned to that
specific mapping, and overwriting it with a freshly-regenerated
mapping silently mis-aligns every slot.
"""
# ckpt info(epoch, step)
ckpt_info = (0, 0)
warm_started = False

# Handle draft model config
if args.draft_model_config is None:
@@ -409,6 +443,7 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]
attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16,
).cuda()
warm_started = True
else:
draft_model = AutoEagle3DraftModel.from_config(
draft_model_config,
@@ -433,14 +468,15 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]

draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
return draft_model_config, draft_model, ckpt_info, resume_state
return draft_model_config, draft_model, ckpt_info, resume_state, warm_started


def build_dataloaders(
args: Namespace,
draft_model_config: AutoDraftModelConfig,
processor: Optional[AutoProcessor] = None,
) -> Tuple[DataLoader, str, Optional[DataLoader]]:
skip_vocab_mapping: bool = False,
) -> Tuple[DataLoader, Optional[str], Optional[DataLoader]]:
# build dataloaders
tokenizer = AutoTokenizer.from_pretrained(
args.target_model_path, trust_remote_code=args.trust_remote_code
@@ -475,13 +511,20 @@ def build_dataloaders(
num_proc=args.build_dataset_num_proc,
train_only_last_turn=args.train_only_last_turn,
)
vocab_mapping_path = generate_vocab_mapping_file(
dataset=train_eagle3_dataset,
target_vocab_size=draft_model_config.vocab_size,
draft_vocab_size=draft_model_config.draft_vocab_size,
cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
cache_key=cache_key,
)
if skip_vocab_mapping:
Review comment (Contributor, severity: medium):
The build_eagle3_dataset call (lines 501-513) is executed unconditionally. However, in offline training mode (is_online=False), this dataset is only used for vocabulary mapping generation and is subsequently overwritten by build_offline_eagle3_dataset at line 530.

When skip_vocab_mapping is True and is_online is False (which is common when fine-tuning from a checkpoint in offline mode), building the train_eagle3_dataset is redundant and can be very time-consuming as it involves tokenizing the entire training corpus. Consider wrapping the dataset construction in a condition like if is_online or not skip_vocab_mapping: to avoid this unnecessary overhead.
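
A stubbed sketch of the control flow this comment suggests (the real build_eagle3_dataset arguments are collapsed out of this diff, so a placeholder stands in for the expensive tokenization step):

def build_tokenized_dataset_stub():
    # Placeholder for build_eagle3_dataset(...), which tokenizes the whole corpus.
    print("tokenizing full corpus (expensive)")
    return "tokenized_dataset"

def plan_dataset_and_mapping(is_online: bool, skip_vocab_mapping: bool):
    train_dataset = None
    if is_online or not skip_vocab_mapping:
        # Only pay the tokenization cost when the dataset is actually consumed:
        # online training uses it directly, and a fresh vocab mapping is derived from it.
        train_dataset = build_tokenized_dataset_stub()
    vocab_mapping_path = None if skip_vocab_mapping else "vocab_mapping.pt"  # placeholder path
    return train_dataset, vocab_mapping_path

# Offline fine-tune from a checkpoint: skips both the tokenization and the mapping.
print(plan_dataset_and_mapping(is_online=False, skip_vocab_mapping=True))  # (None, None)
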

vocab_mapping_path = None
print_on_rank0(
"Skipping vocab mapping generation: reusing d2t/t2d from the "
"loaded checkpoint (lm_head slots are aligned to that mapping)."
)
else:
vocab_mapping_path = generate_vocab_mapping_file(
dataset=train_eagle3_dataset,
target_vocab_size=draft_model_config.vocab_size,
draft_vocab_size=draft_model_config.draft_vocab_size,
cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
cache_key=cache_key,
)

if not is_online:
train_eagle3_dataset = build_offline_eagle3_dataset(
@@ -748,18 +791,51 @@ def main():
# ================================================
# 2. Build models
# ================================================
draft_model_config, draft_model, ckpt_info, resume_state = build_draft_model(args)
draft_model_config, draft_model, ckpt_info, resume_state, warm_started = (
build_draft_model(args)
)
target_model, processor = build_target_model(args, draft_model_config, is_online)

# ================================================
# 3. Build dataloader
# ================================================
# When warm-starting from a checkpoint, the d2t/t2d buffers loaded from
# model.safetensors are the ground truth: the drafter's lm_head slots were
# trained against exactly that mapping. Regenerating from the new training
# data would silently remap every slot to a different target token id and
# collapse acceptance. Reuse unless the user explicitly opts out via
# --regenerate-vocab-mapping (for true cold-start continuation).
skip_vocab_mapping = warm_started and not args.regenerate_vocab_mapping
if warm_started and args.regenerate_vocab_mapping:
print_on_rank0(
"WARNING: --regenerate-vocab-mapping is set while warm-starting "
"from a checkpoint. The d2t/t2d mapping will be recomputed from "
"the new training data, which realigns the drafter's lm_head "
"slots to different target token ids. Acceptance will drop "
"significantly until lm_head retrains against the new mapping. "
"Confirm this is intended."
)
train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders(
args, draft_model_config, processor
args, draft_model_config, processor, skip_vocab_mapping=skip_vocab_mapping
)

# we load the vocab mapping then
# Load vocab mapping (no-op when skip_vocab_mapping=True; the buffers were
# already populated by from_pretrained() in build_draft_model).
draft_model.load_vocab_mapping(vocab_mapping_path)
if vocab_mapping_path is None:
# Defensive sanity check: if we're reusing the checkpoint mapping,
# verify the buffers actually contain a real mapping (i.e. the
# checkpoint wasn't saved before load_vocab_mapping was called).
if not draft_model.has_nondefault_vocab_mapping():
raise RuntimeError(
"Refusing to start training: --ckpt-dir was provided (or "
"--resume picked up a checkpoint) but the drafter's d2t/t2d "
"buffers still match the uninitialized defaults (d2t all "
"zeros, t2d all True). This means the checkpoint was saved "
"without a vocab mapping. Re-run with "
"--regenerate-vocab-mapping to build one from the training "
"data (this will reinitialize lm_head alignment from scratch)."
)
print_with_rank("Loaded vocab mapping")

# Calculate total steps if not provided
26 changes: 22 additions & 4 deletions specforge/args.py
@@ -174,6 +174,18 @@ def add_args(parser: argparse.ArgumentParser) -> None:
default=1,
help="The ep size of the SGLang backend",
)
parser.add_argument(
"--sglang-max-running-requests",
type=int,
default=None,
help="Override auto-computed max_running_requests for SGLang backend (default: target_batch_size)",
)
parser.add_argument(
"--sglang-max-total-tokens",
type=int,
default=None,
help="Override auto-computed max_total_tokens for SGLang backend (default: target_batch_size * max_length)",
)

@staticmethod
def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
@@ -191,12 +203,18 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens,
sglang_ep_size=args.sglang_ep_size,
sglang_max_running_requests=(
args.target_batch_size if hasattr(args, "target_batch_size") else None
args.sglang_max_running_requests
if getattr(args, "sglang_max_running_requests", None) is not None
else (args.target_batch_size if hasattr(args, "target_batch_size") else None)
),
sglang_max_total_tokens=(
args.target_batch_size * args.max_length
if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
else None
args.sglang_max_total_tokens
if getattr(args, "sglang_max_total_tokens", None) is not None
else (
args.target_batch_size * args.max_length
if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
else None
)
),
)

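
The resolution above can be exercised on its own; a small sketch with made-up values, mirroring the two expressions in from_args:

from argparse import Namespace

args = Namespace(
    target_batch_size=8,
    max_length=4096,
    sglang_max_running_requests=None,  # not overridden -> falls back to target_batch_size
    sglang_max_total_tokens=65536,     # overridden -> used as-is
)
max_running_requests = (
    args.sglang_max_running_requests
    if getattr(args, "sglang_max_running_requests", None) is not None
    else (args.target_batch_size if hasattr(args, "target_batch_size") else None)
)
max_total_tokens = (
    args.sglang_max_total_tokens
    if getattr(args, "sglang_max_total_tokens", None) is not None
    else (
        args.target_batch_size * args.max_length
        if hasattr(args, "target_batch_size") and hasattr(args, "max_length")
        else None
    )
)
print(max_running_requests, max_total_tokens)  # 8 65536
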
20 changes: 20 additions & 0 deletions specforge/data/template.py
@@ -324,3 +324,23 @@ def get_all_template_names(self) -> List[str]:
enable_thinking=True,
),
)

TEMPLATE_REGISTRY.register(
name="minimax-m2.5",
template=ChatTemplate(
assistant_header="]~b]ai\n",
user_header="]~b]user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="[e~[\n",
),
)

TEMPLATE_REGISTRY.register(
name="minimax-m2.7",
template=ChatTemplate(
assistant_header="]~b]ai\n",
user_header="]~b]user\n",
system_prompt="You are a helpful assistant.",
end_of_turn_token="[e~[\n",
),
)
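
Both registrations use identical header and end-of-turn strings, so the two names currently resolve to the same template content. A hedged check that the names are picked up (TEMPLATE_REGISTRY and get_all_template_names both appear in this file; the import path is assumed from the diff's file location):

from specforge.data.template import TEMPLATE_REGISTRY

names = TEMPLATE_REGISTRY.get_all_template_names()
assert "minimax-m2.5" in names and "minimax-m2.7" in names
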
42 changes: 39 additions & 3 deletions specforge/modeling/draft/base.py
@@ -173,17 +173,53 @@ def load_embedding(
local_cache_path = snapshot_download(repo_id=model_path)
self.load_embedding(local_cache_path, embedding_key)

def load_vocab_mapping(self, file_path: str) -> None:
def load_vocab_mapping(self, file_path: Optional[str]) -> None:
"""
Load the vocab buffers of the draft model.

Args:
file_path (str): The path to the vocab mapping file.
file_path: Path to the vocab mapping file. If ``None``, this is a
no-op and the caller is expected to have already populated the
``d2t``/``t2d`` buffers (for example by loading them from a
checkpoint via ``from_pretrained``). This is the correct
behaviour when fine-tuning from an existing drafter: the
mapping stored in the checkpoint is the one the ``lm_head``
slots were trained against, so it must not be overwritten with
a freshly-computed mapping derived from a different training
corpus.
"""
assert hasattr(self, "t2d") and hasattr(
self, "d2t"
), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation"
), "t2d and d2t buffers are not found in the draft model, please check your draft model implementation"
if file_path is None:
self.vocab_mapping_loaded = True
return
vocab_mapping = torch.load(file_path)
self.t2d.copy_(vocab_mapping["t2d"])
self.d2t.copy_(vocab_mapping["d2t"])
self.vocab_mapping_loaded = True

def has_nondefault_vocab_mapping(self) -> bool:
"""
Return ``True`` if ``d2t``/``t2d`` buffers look like a real, trained
mapping (i.e. they differ from the all-zeros / all-True defaults set
in the draft model's ``__init__``).

A fresh draft model constructed via ``from_config`` has
``d2t = zeros(draft_vocab_size)`` and ``t2d = ones(vocab_size, bool)``.
The buffers should be populated either with a mapping generated from the
training data (cold start) or with the mapping loaded from a checkpoint's
``model.safetensors`` via ``from_pretrained`` (warm start). This helper
lets callers sanity-check that one of those paths actually populated the
buffers before training begins.
"""
if not (hasattr(self, "d2t") and hasattr(self, "t2d")):
return False
# d2t defaults to all zeros; any real mapping has non-zero entries
# because token id 0 cannot be the most-frequent id across all 32K
# drafter slots.
d2t_is_default = bool((self.d2t == 0).all().item())
# t2d defaults to all True; a real mapping keeps only the top-K
# targets marked True and zeroes out the rest.
t2d_is_default = bool(self.t2d.all().item())
return not (d2t_is_default and t2d_is_default)