Commit 670d675

fix tests

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 59a0984

10 files changed: 101 additions & 76 deletions


.github/workflows/example_tests.yml

Lines changed: 26 additions & 3 deletions
```diff
@@ -63,7 +63,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
+        example: [llm_distill, llm_qat, llm_sparsity]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
@@ -77,7 +77,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding]
+        example: [llm_distill, llm_qat, llm_sparsity]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:
@@ -86,6 +86,28 @@ jobs:
       pip_install_extras: "[hf,dev-test]"
       runner: linux-amd64-gpu-h100-latest-2
 
+  ##### Speculative Decoding Example Tests (requires 25.08 image) #####
+  speculative-decoding-pr:
+    needs: [check-file-changes, wait-checks]
+    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
+    uses: ./.github/workflows/_example_tests_runner.yml
+    secrets: inherit
+    with:
+      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
+      example: speculative_decoding
+      pip_install_extras: "[hf,dev-test]"
+      runner: linux-amd64-gpu-l4-latest-1
+
+  speculative-decoding-non-pr:
+    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
+    uses: ./.github/workflows/_example_tests_runner.yml
+    secrets: inherit
+    with:
+      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
+      example: speculative_decoding
+      pip_install_extras: "[hf,dev-test]"
+      runner: linux-amd64-gpu-h100-latest-2
+
   ##### TensorRT-LLM Example Tests #####
   trtllm-pr:
     needs: [check-file-changes, wait-checks]
@@ -150,14 +172,15 @@ jobs:
   example-pr-required-check:
     # Run even if example tests are skipped
     if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
-    needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr]
+    needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr]
     runs-on: ubuntu-latest
     steps:
       - name: Required GPU tests did not succeed
         if: |
           needs.check-file-changes.result != 'success' ||
           (needs.check-file-changes.outputs.any_changed == 'true' && (
             needs.torch-pr.result != 'success' ||
+            needs.speculative-decoding-pr.result != 'success' ||
             needs.trtllm-pr.result != 'success' ||
             needs.onnx-pr.result != 'success'
           ))
```
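The two new jobs pin `nvcr.io/nvidia/pytorch:25.08-py3` so the speculative-decoding tests run against torch 2.8.0. A rough local approximation of what the runner does (the repository URL, clone path, and test entrypoint here are assumptions, not taken from the workflow):

```bash
# Approximate the new speculative-decoding CI job locally (illustrative only;
# the repo URL and pytest path below are assumptions, not from the workflow).
docker run --gpus all --rm nvcr.io/nvidia/pytorch:25.08-py3 bash -c '
  git clone https://github.com/NVIDIA/Model-Optimizer.git
  cd Model-Optimizer
  pip install -e ".[hf,dev-test]"
  pytest tests/examples/speculative_decoding
'
```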

examples/speculative_decoding/README.md

Lines changed: 3 additions & 5 deletions
````diff
@@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,
 
 ### Docker
 
-Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
+Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
 
 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.
 
@@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
 ## Getting Started: Simplified Workflow
 
 ```bash
-bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
 ```
 
 This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
@@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data input_conversations/daring-anteater.jsonl \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
 ```
 
-This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
+FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
 
 ## Training Draft Model with Offline Base Model
@@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json \
     --offline-data $HIDDEN_STATES_DIR
````
examples/speculative_decoding/eagle_utils.py

Lines changed: 19 additions & 29 deletions
```diff
@@ -28,14 +28,15 @@
 import torch
 import transformers
 from datasets import load_dataset
+from packaging.version import Version
 from PIL import Image
 from scripts.ar_validate import validate_ar
 from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
 
-import modelopt.torch.speculative.plugins.transformers
+import modelopt
 from modelopt.torch.speculative.utils import get_ttt_msk_func
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import is_master
@@ -577,27 +578,6 @@ def on_step_end(self, args, state, control, **kwargs):
         return control
 
 
-def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor:
-    """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl."""
-
-    msk_func = get_ttt_msk_func(seq_length, ttt_step)
-
-    dtypemin = torch.finfo(dtype).min
-    q_len = seq_length
-    kv_len = seq_length * (1 + ttt_step)
-    # Return tensor mask for non-flex attention
-    tensor_mask = msk_func(
-        None,
-        None,
-        torch.arange(q_len).view(1, 1, q_len, 1),
-        torch.arange(kv_len).view(1, 1, 1, kv_len),
-    ).to(torch.cuda.current_device())
-    tensor_mask = torch.full_like(
-        tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device()
-    ).masked_fill(~tensor_mask, dtypemin)
-    return tensor_mask
-
-
 def get_patched_templated_ring_attn(orig_templated_attn: Callable):
     """
     Return patched version of
@@ -659,14 +639,14 @@ def patched_templated_attn(*args, **kwargs):
         patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
 
         if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
-            raise ValueError(f"CP TTT only supports cuddn attention now. Got: {original_op}")
+            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
 
         # Unset is_causal to use custom attn mask
         if patch_enbabled:
             kwargs["is_causal"] = False
 
         def patched_op(*args, **kwargs):
-            # Inpect the parent frame to get current shard info
+            # Inspect the parent frame to get current shard info
             # This is sensitive to torch _templated_ring_attention impl
             try:
                 frame: FrameType = inspect.currentframe()
@@ -678,7 +658,9 @@ def patched_op(*args, **kwargs):
                 i = f_back.f_locals["i"]
                 ttt_step = (key.shape[2] // query.shape[2]) - 1
             except Exception as e:
-                print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}")
+                raise RuntimeError(
+                    f"Failed to capture loop variables in patched _templated_ring_attention: {e}"
+                ) from e
             # Set attn mask to permuted TTT mask
             if "attn_bias" in kwargs:
                 kwargs["attn_bias"] = _get_sharded_ttt_msk(
@@ -696,8 +678,18 @@ def patch_ring_attention_for_ttt():
 def patch_ring_attention_for_ttt():
     """Patch torch ring attention to support context parallelism for TTT."""
     # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
 
+    if not (
+        Version(torch.__version__) > Version("2.7.1")
+        and Version(torch.__version__) < Version("2.9.0")
+    ):
+        raise RuntimeError(
+            f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
+            f"Got {torch.__version__}. "
+            f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
+        )
+
     # 1. Disable load balance, which is designed for causal mask.
-    # This affect how buffers are sharded. So need to be done permenantly before accelerate/hf trainer init.
+    # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
     torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
 
     # 2. Patch templated ring attention for TTT mask.
@@ -717,9 +709,7 @@ def patch_ring_attention_for_ttt():
     # 3. Patch merger to skip the blank shard to avoid difference in output.
     original_sdpa_merger_step = _SDPAMerger.step
 
-    def patched_sdpa_merger_step(
-        self, out: torch.Tensor, lse: torch.Tensor, partial: bool
-    ) -> torch.Tensor:
+    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
         if lse.sum() <= 0:
             return
         return original_sdpa_merger_step(self, out, lse, partial)
```
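The new gate in `patch_ring_attention_for_ttt` rejects anything outside torch >2.7.1,<2.9.0. A quick preflight check before launching a `--cp_size > 1` run (a convenience sketch, not part of this commit):

```bash
# Check that the installed torch falls inside the CP TTT gate (>2.7.1, <2.9.0).
python - <<'EOF'
import torch
from packaging.version import Version

v = Version(torch.__version__.split("+")[0])  # drop local build suffix, e.g. +cu128
if Version("2.7.1") < v < Version("2.9.0"):
    print(f"torch {torch.__version__}: OK for cp_size > 1")
else:
    print(f"torch {torch.__version__}: use cp_size=1 or the 25.08 NGC image")
EOF
```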

examples/speculative_decoding/launch_train.sh

Lines changed: 10 additions & 2 deletions
```diff
@@ -170,6 +170,15 @@ else
     VLM_ARGS=""
 fi
 
+if [[ "$GPU_COUNT" -gt 1 ]]; then
+    #Use FSDP2 when multi GPU available
+    FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json"
+else
+    #Otherwise, single GPU training
+    FSDP_ARGS=""
+fi
+
+
 # Disable tokenizers parallelism to avoid warning
 export TOKENIZERS_PARALLELISM=False
 CMD="accelerate launch --mixed_precision bf16 main.py \
@@ -201,8 +210,7 @@ CMD="accelerate launch --mixed_precision bf16 main.py \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS \
-    --fsdp 'full_shard' \
-    --fsdp_config fsdp_config.json \
+    $FSDP_ARGS \
     --cp_size $CP_SIZE \
     --dp_shard_size $DP_SHARD_SIZE \
     "
```
examples/speculative_decoding/requirements.txt

Lines changed: 0 additions & 2 deletions

```diff
@@ -1,4 +1,2 @@
 accelerate==1.12.0
-torch==2.8.0
 transformers==5.0.0rc1
-wandb
```

examples/speculative_decoding/train_eagle3_and_export.sh

Lines changed: 2 additions & 17 deletions
```diff
@@ -17,23 +17,18 @@
 
 set -eo pipefail
 
-# Set default values for BASE_MODEL, NUM_GPU, and DATA
+# Set default values for BASE_MODEL and DATA
 BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct
-NUM_GPU=1
 DATA=input_conversations/daring-anteater.jsonl
 
-# Parse input arguments --base_model, --num_gpu, and --data
+# Parse input arguments --base_model and --data
 while [[ $# -gt 0 ]]; do
     key="$1"
     case $key in
         --base_model)
             BASE_MODEL="$2"
             shift; shift
             ;;
-        --num_gpu)
-            NUM_GPU="$2"
-            shift; shift
-            ;;
         --data)
             DATA="$2"
             shift; shift
@@ -49,15 +44,6 @@ while [[ $# -gt 0 ]]; do
     esac
 done
 
-
-if [[ "$NUM_GPU" == 1 ]]; then
-    export CUDA_VISIBLE_DEVICES=0
-else
-    # Export as 0,1,...,N-1 for NUM_GPU GPUs
-    devs="$(seq -s, 0 $((NUM_GPU-1)))"
-    export CUDA_VISIBLE_DEVICES="$devs"
-fi
-
 if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
     OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH"
 else
@@ -73,7 +59,6 @@ mkdir -p "$(dirname "$OUTPUT_DIR")"
     --output_dir $OUTPUT_DIR \
     $OFFLINE_DATA_ARGS \
     --data $DATA \
-    --num_gpu $NUM_GPU \
     --num_epochs 2 \
     --eagle_config eagle_config.json
 
```
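With `--num_gpu` removed, the script no longer exports `CUDA_VISIBLE_DEVICES` itself, so `accelerate` uses whatever GPUs are visible. To restrict the run to specific devices, set the variable yourself (illustrative):

```bash
# Train on the first two GPUs only (replaces the old --num_gpu 2 behavior).
CUDA_VISIBLE_DEVICES=0,1 bash train_eagle3_and_export.sh \
    --base_model meta-llama/Llama-3.2-1B-Instruct
```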

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 17 additions & 6 deletions
```diff
@@ -34,6 +34,8 @@
 from typing import Any
 
 import torch
+import transformers
+from packaging.version import Version
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
@@ -61,11 +63,22 @@
     temporary_set_config_value,
 )
 
+__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"]
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 ENABLE_CP_TTT_PATCH = False
+# module variable to cache attention mask for cp ttt
 CACHED_SHARD_TTT_MASKS = {}
 
 
+def _get_empty_cache(config):
+    """Return an empty cache. Handle different versions of transformers for unit tests."""
+    if Version(transformers.__version__) >= Version("4.54"):
+        return DynamicCache(config=config)
+    else:
+        return DynamicCache()
+
+
 @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
 class HFMedusaModel(MedusaModel):
     """Medusa Model Class for huggingface models."""
@@ -852,9 +865,7 @@ def forward(
                 base_model_logits = base_outputs["base_model_logits"]
             else:
                 base_model_logits = self.lm_head(base_model_hidden_states)
-            base_model_loss = None
-            past_key_values = DynamicCache()  # Dummy cache
-
+            base_model_loss, past_key_values = None, None
         else:
             base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = (
                 self._base_model_forward(
@@ -869,9 +880,9 @@ def forward(
             )
 
         if not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache(config=self._base_llm_config)
+            past_key_values = _get_empty_cache(self._base_llm_config)
         if not isinstance(eagle_cache, Cache):
-            eagle_cache = DynamicCache(config=self.eagle_module.config)
+            eagle_cache = _get_empty_cache(self.eagle_module.config)
 
         # ====Run eagle forward====
         eagle_loss = None
@@ -908,7 +919,7 @@ def forward(
                 if ttt_step == 0
                 else self._get_ttt_attention_mask(b, seq_length, ttt_step)
             )
-            with enable_cp_ttt_patch():
+            with enable_cp_ttt_patch() if self.training else contextlib.nullcontext():
                 _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(
                     eagle_input_hidden_states,
                     inputs_embeds,
```

modelopt/torch/speculative/utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -30,8 +30,6 @@
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers.cache_utils import DynamicCache
 
-import modelopt.torch.speculative.plugins.transformers
-
 KIMI_K2_REPO_ID = "moonshotai/Kimi-K2-Thinking"
 KIMI_K2_PACKAGE_NAME = "kimi_k2_temp"
 
@@ -462,6 +460,8 @@ def ttt_msk_func(b, h, q_idx, kv_idx):
 @contextlib.contextmanager
 def enable_cp_ttt_patch():
     """Context manager to enable CP TTT patch."""
+    import modelopt.torch.speculative.plugins.transformers
+
     modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
     with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
         try:
```
