Commit ef8c0c4

squash: speculative decoding recipe lib
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent: 7f5fd65

File tree

16 files changed: +446 −507 lines


examples/speculative_decoding/README.md

Lines changed: 31 additions & 65 deletions
@@ -73,14 +73,16 @@ This one-line command runs a minimal example workflow of training and exporting
 For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command:
 
 ```bash
-./launch_train.sh --model $BASE_MODEL \
-    --output_dir $OUTPUT_DIR \
-    --data input_conversations/train.jsonl \
-    --num_epochs $NUM_EPOCH \
-    --eagle_config eagle_config.json
+./launch_train.sh \
+    --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \
+    model.model_name_or_path=meta-llama/Llama-3.2-1B \
+    data.data_path=input_conversations/train.jsonl \
+    training.output_dir=ckpts/llama-3.2-1b-online
 ```
 
-FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
+All default training settings live in `eagle3.yaml`; override any field via OmegaConf dotlist arguments on the command line.
+
+To enable context parallelism for long-context training, add `training.cp_size=<N>` to the overrides.
 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
 
 ## Training Draft Model with Offline Base Model
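The dotlist overrides introduced in this hunk follow the OmegaConf `key.path=value` convention. The merge they perform can be sketched in plain Python (a minimal illustration of the mechanism, not the recipe's actual implementation; function name and default values here are made up):

```python
def apply_dotlist_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Merge OmegaConf-style ``a.b.c=value`` overrides into a nested config dict."""
    for item in overrides:
        path, _, value = item.partition("=")
        node = cfg
        *parents, leaf = path.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value  # values arrive as strings; OmegaConf also coerces types
    return cfg


# Defaults as they might appear in eagle3.yaml (illustrative values only)
defaults = {
    "model": {"model_name_or_path": None},
    "data": {"data_path": None},
    "training": {"output_dir": "ckpts/default", "cp_size": 1},
}

cfg = apply_dotlist_overrides(
    defaults,
    [
        "model.model_name_or_path=meta-llama/Llama-3.2-1B",
        "data.data_path=input_conversations/train.jsonl",
        "training.output_dir=ckpts/llama-3.2-1b-online",
    ],
)
print(cfg["training"]["output_dir"])  # ckpts/llama-3.2-1b-online
```

Fields not mentioned on the command line (here `training.cp_size`) keep their YAML defaults.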
@@ -113,15 +115,14 @@ python collect_hidden_states/compute_hidden_states_hf.py \
 
 ### Train Draft Model with Dumped Hidden States
 
-Once we finish dumping hidden states, launch offline training with an extra `--offline-data` argument:
+Once we finish dumping hidden states, launch offline training pointing to the hidden states directory:
 
 ```bash
-./launch_train.sh --model $BASE_MODEL \
-    --output_dir $OUTPUT_DIR \
-    --data $DATA \
-    --num_epochs $NUM_EPOCH \
-    --eagle_config eagle_config.json \
-    --offline-data $HIDDEN_STATES_DIR
+./launch_train.sh \
+    --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \
+    model.model_name_or_path=meta-llama/Llama-3.2-1B \
+    data.offline_data_path=$HIDDEN_STATES_DIR \
+    training.output_dir=ckpts/llama-3.2-1b-offline
 ```
 
 ## Model Validation
@@ -244,13 +245,13 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
 
 ### Configuring Draft Model
 
-For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. E.g. To use 2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to:
+For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings via `eagle.eagle_architecture_config` in the YAML. E.g. to use a 2-layer EAGLE head with 8192 intermediate size:
 
-```json
-{
-    "num_hidden_layers": 2,
-    "intermediate_size":8192
-}
+```yaml
+eagle:
+  eagle_architecture_config:
+    num_hidden_layers: 2
+    intermediate_size: 8192
 ```
 
 ### Draft Vocabulary Compression
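The override mechanism in this hunk amounts to merging the YAML's `eagle_architecture_config` fields over the defaults. A minimal sketch with made-up default values (not the actual defaults from `config.py`):

```python
# Default EAGLE-3 architecture settings (illustrative subset, invented values)
default_arch = {
    "num_hidden_layers": 1,
    "intermediate_size": 4096,
    "hidden_size": 2048,
}

# Overrides as they appear under eagle.eagle_architecture_config in the YAML
yaml_overrides = {
    "num_hidden_layers": 2,
    "intermediate_size": 8192,
}

# For flat fields a shallow dict merge is all the override needs:
# overridden keys win, unspecified keys keep their defaults.
arch = {**default_arch, **yaml_overrides}
print(arch)  # {'num_hidden_layers': 2, 'intermediate_size': 8192, 'hidden_size': 2048}
```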
@@ -263,61 +264,26 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
 
-Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache <path_to_d2t.pt>` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export.
+Then, set `eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export.
 
 ### Interact with `modelopt.torch.speculative`
 
-`main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps:
-First, load the base model and tokenizer from Hugging Face:
-
-```python
-model = transformers.AutoModelForCausalLM.from_pretrained(
-    "<path to your pretrained model>"
-)
-```
-
-Then, load default eagle config and make necessary overwrites:
+`main.py` provides a complete example for converting a HF base model for speculative decoding and training it. The core steps are loading the base model, converting it with an eagle config dict, and training with HF Trainer:
 
 ```python
-# Load default config
-config = {
-    "eagle1": EAGLE1_DEFAULT_CFG,
-    "eagle3": EAGLE3_DEFAULT_CFG,
-}[training_args.mode]["config"]
-
-# overwrite config with custom config
-config["eagle_architecture_config"].update({"<overwrite_keys>": "<overwrite_values>"})
-
-# Mandatory: hidden size, vocab size and max position embeddings must match base model
-config["eagle_architecture_config"].update(
-    {
-        "hidden_size": model.config.hidden_size,
-        "vocab_size": model.config.vocab_size,
-        "max_position_embeddings": model.config.max_position_embeddings,
-    }
-)
-```
+import modelopt.torch.speculative as mtsp
 
-Then, we convert model to a speculative decoding model:
+# Convert base model in-place to an EAGLE speculative decoding model
+eagle_cfg = {"eagle_decoder_type": "llama", ...}  # fields from EagleConfig
+mtsp.convert(model, [("eagle", eagle_cfg)])
 
-```python
-mtsp.convert(model, [("eagle", config)])
+# Train with HF Trainer as usual
+trainer = transformers.Trainer(model=model, ...)
+trainer.train()
+trainer.save_model("<output_dir>")
 ```
 
-This will modify the model in-place with eagle training forward, making it compatible with HF trainer:
-
-```python
-# Create a trainer
-trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
-trainer._move_model_to_device(model, trainer.args.device)
-
-# Enable HF checkpointing so that the saved model will contain the speculative decoding module
-mto.enable_huggingface_checkpointing()
-
-trainer.train(resume_from_checkpoint=checkpoint)
-trainer.save_state()
-trainer.save_model("<path to the output directory>")
-```
+See `main.py` for the full example including tokenizer setup, dataset loading, and checkpoint handling.
 
 ## Support Matrix
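The `d2t.pt` offset mapping described in the README hunk above (`target_token = draft_token + d2t[draft_token]`) can be illustrated with a toy example. Plain Python lists stand in for the saved tensor, and the vocab sizes are invented:

```python
# Toy setup: a draft vocab of 4 tokens selected from a target vocab of 10.
# The draft model only ever emits draft ids 0..3; d2t[i] stores the offset
# from draft token id i to its original target token id.
kept_target_ids = [0, 3, 7, 9]  # target ids kept in the compressed draft vocab

# Offset table: target id minus draft id, indexed by draft id.
d2t = [t - d for d, t in enumerate(kept_target_ids)]  # [0, 2, 5, 6]

# At inference time, map a draft token back to the target vocabulary.
draft_token = 2
target_token = draft_token + d2t[draft_token]
print(target_token)  # 7
```

Storing offsets rather than target ids directly lets the mapping be applied with a single vectorized add on a batch of draft tokens.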

examples/speculative_decoding/eagle_config.json

Lines changed: 0 additions & 2 deletions
This file was deleted.

examples/speculative_decoding/eagle_utils.py

Lines changed: 4 additions & 2 deletions
@@ -44,7 +44,9 @@
 
 try:
     import wandb
-except ImportError:
+
+    wandb.log  # Verify wandb is functional (not a stub module).
+except (ImportError, AttributeError):
     wandb = None
 
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@@ -194,7 +196,7 @@ class EagleTrainingPlot(TrainerCallback):
 
     def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
         self.ar_validate_steps = ar_validate_steps
-        if wandb and is_master():
+        if hasattr(wandb, "init") and is_master():
             wandb.init()
         self.estimate_ar = estimate_ar
 
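The import guard changed in this hunk handles the case where `wandb` imports successfully but is a non-functional stub (e.g. a namespace placeholder without the real API). The pattern can be exercised with a synthetic stub module; this is a standalone sketch, not the repo's code:

```python
import sys
import types

# Simulate the failure mode: a module named wandb that lacks the real API.
sys.modules["wandb"] = types.ModuleType("wandb")

try:
    import wandb

    wandb.log  # Attribute probe: raises AttributeError on a stub module.
except (ImportError, AttributeError):
    wandb = None

print(wandb)  # None — the stub was caught by the AttributeError branch
```

Catching `AttributeError` alongside `ImportError` is what lets the probe fall through to `wandb = None` instead of crashing at import time.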

examples/speculative_decoding/fsdp_config.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments