From 59a09841a4cb0b098a7e0ad90ed51bc73f6d75df Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 21 Jan 2026 06:53:55 +0000 Subject: [PATCH 1/2] squash: cp ttt Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 159 ++++++++++++++++++ .../speculative_decoding/fsdp_config.json | 1 + examples/speculative_decoding/launch_train.sh | 31 ++-- examples/speculative_decoding/main.py | 17 +- .../speculative_decoding/requirements.txt | 9 +- .../torch/speculative/plugins/transformers.py | 38 ++--- modelopt/torch/speculative/utils.py | 29 ++++ 7 files changed, 242 insertions(+), 42 deletions(-) create mode 100644 examples/speculative_decoding/fsdp_config.json diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 45c9c66321..ade92d21e1 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -13,9 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import os +from collections.abc import Callable from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import FrameType from typing import Any import numpy as np @@ -24,10 +30,13 @@ from datasets import load_dataset from PIL import Image from scripts.ar_validate import validate_ar +from torch.distributed.tensor.experimental._attention import _SDPAMerger from torch.utils.data import Dataset from transformers import AutoProcessor, Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother +import modelopt.torch.speculative.plugins.transformers +from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master @@ -566,3 +575,153 @@ def on_step_end(self, args, state, control, **kwargs): except Exception: print_rank_0("AR validation not available.") return control + + +def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor: + """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" + + msk_func = get_ttt_msk_func(seq_length, ttt_step) + + dtypemin = torch.finfo(dtype).min + q_len = seq_length + kv_len = seq_length * (1 + ttt_step) + # Return tensor mask for non-flex attention + tensor_mask = msk_func( + None, + None, + torch.arange(q_len).view(1, 1, q_len, 1), + torch.arange(kv_len).view(1, 1, 1, kv_len), + ).to(torch.cuda.current_device()) + tensor_mask = torch.full_like( + tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device() + ).masked_fill(~tensor_mask, dtypemin) + return tensor_mask + + +def get_patched_templated_ring_attn(orig_templated_attn: Callable): + """ + Return patched version of + torch.distributed.tensor.experimental._attention._templated_ring_attention + to support TTT. + """ + + def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype): + """Get chunk-interleaved TTT mask for current rank. + e.g.: + 2 ranks, ttt_step=1; + full_ttt_mask = [[0, 0, 0, 0, x, 0, 0, 0], + [x, 0, 0, 0, 0, x, 0, 0], + [x, x, 0, 0, 0, 0, x, 0], + [x, x, x, 0, 0, 0, 0, x], + + rank 0, step0: [[0, 0, x, 0], + [x, 0, 0, x]] + + rank 1, step0: [[0, 0, x, 0], + [x, 0, 0, x]] + + rank 0, step1: [[0, 0, 0, 0], + [0, 0, 0, 0]] + + rank 1, step1: [[x, x, 0, 0], + [x, x, 0, 0]] + + """ + device = torch.cuda.current_device() + q_indices = torch.arange(q_len * rank, q_len * (rank + 1), device=device) + kv_indices = ( + torch.arange(q_len * size * (ttt_step + 1), device=device) + .view(ttt_step + 1, size, q_len)[:, (rank - i) % size, :] + .reshape(-1) + ) + msk_func = get_ttt_msk_func(q_len * size, ttt_step) + attn_mask = msk_func( + None, + None, + q_indices.view(1, 1, -1, 1), + kv_indices.view(1, 1, 1, -1), + ) + attn_bias = torch.where( + attn_mask, + torch.zeros((), dtype=dtype, device=attn_mask.device), + torch.full((), torch.finfo(dtype).min, dtype=dtype, device=attn_mask.device), + ) + + return attn_bias + + def patched_templated_attn(*args, **kwargs): + """Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention.""" + # Get original attention op + # Sensitive to impl of _templated_ring_attention + original_op = args[2] + + # This patch is only enabled for eagle model by context manager, not base model. + patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH + + if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention: + raise ValueError(f"CP TTT only supports cuddn attention now. Got: {original_op}") + + # Unset is_causal to use custom attn mask + if patch_enbabled: + kwargs["is_causal"] = False + + def patched_op(*args, **kwargs): + # Inpect the parent frame to get current shard info + # This is sensitive to torch _templated_ring_attention impl + try: + frame: FrameType = inspect.currentframe() + f_back: FrameType = frame.f_back + rank = f_back.f_locals["rank"] + size = f_back.f_locals["size"] + query = f_back.f_locals["query"] + key = f_back.f_locals["key"] + i = f_back.f_locals["i"] + ttt_step = (key.shape[2] // query.shape[2]) - 1 + except Exception as e: + print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}") + # Set attn mask to permuted TTT mask + if "attn_bias" in kwargs: + kwargs["attn_bias"] = _get_sharded_ttt_msk( + i, rank, size, query.shape[2], ttt_step, query.dtype + ) + # Perform shard attention + return original_op(*args, **kwargs) + + return orig_templated_attn(args[0], args[1], patched_op, *args[3:], **kwargs) + + return patched_templated_attn + + +def patch_ring_attention_for_ttt(): + """Patch torch ring attention to support context parallelism for TTT.""" + # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask. + + # 1. Disable load balance, which is designed for causal mask. + # This affect how buffers are sharded. So need to be done permenantly before accelerate/hf trainer init. + torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False + + # 2. Patch templated ring attention for TTT mask. + original_templated_ring_attention = ( + torch.distributed.tensor.experimental._attention._templated_ring_attention + ) + original_templated_ring_attention_backward = ( + torch.distributed.tensor.experimental._attention._templated_ring_attention_backward + ) + torch.distributed.tensor.experimental._attention._templated_ring_attention = ( + get_patched_templated_ring_attn(original_templated_ring_attention) + ) + torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = ( + get_patched_templated_ring_attn(original_templated_ring_attention_backward) + ) + + # 3. Patch merger to skip the blank shard to avoid difference in output. + original_sdpa_merger_step = _SDPAMerger.step + + def patched_sdpa_merger_step( + self, out: torch.Tensor, lse: torch.Tensor, partial: bool + ) -> torch.Tensor: + if lse.sum() <= 0: + return + return original_sdpa_merger_step(self, out, lse, partial) + + _SDPAMerger.step = patched_sdpa_merger_step diff --git a/examples/speculative_decoding/fsdp_config.json b/examples/speculative_decoding/fsdp_config.json new file mode 100644 index 0000000000..6d934182fe --- /dev/null +++ b/examples/speculative_decoding/fsdp_config.json @@ -0,0 +1 @@ +{"fsdp_version":2} \ No newline at end of file diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index e3b6c5a21d..ad49d614f4 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -74,14 +74,6 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi EAGLE_CONFIG="${1#*=}" ;; - --fsdp_transformer_layer_cls_to_wrap*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" - ;; - --num_gpu*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_GPU="${1#*=}" - ;; --disable_tqdm*) if [[ "$1" != *=* ]]; then shift; fi DISABLE_TQDM="${1#*=}" @@ -102,6 +94,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi AR_VALIDATE_STEPS="${1#*=}" ;; + --cp_size*) + if [[ "$1" != *=* ]]; then shift; fi + CP_SIZE="${1#*=}" + ;; + --dp_size*) + if [[ "$1" != *=* ]]; then shift; fi + DP_SHARD_SIZE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -129,8 +129,6 @@ LR=${LR:-"1e-4"} TRAIN_BS=${TRAIN_BS:-4} MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} -NUM_GPU=${NUM_GPU:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} DISABLE_TQDM=${DISABLE_TQDM:-False} @@ -138,6 +136,8 @@ VLM_PROCESSOR=${VLM_PROCESSOR:-} VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} +CP_SIZE=${CP_SIZE:-1} +DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -163,11 +163,6 @@ else OFFLINE_TRAINING_ARGS="" fi -if [[ "$NUM_GPU" == 1 ]]; then - MULTI_GPU="" -else - MULTI_GPU="--multi_gpu" -fi if [[ "$VLM_PROCESSOR" != "" ]]; then VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" @@ -177,7 +172,7 @@ fi # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ +CMD="accelerate launch --mixed_precision bf16 main.py \ --mode $MODE \ --eagle_decoder_type $EAGLE_DECODER_TYPE \ --model_name_or_path $MODEL \ @@ -206,6 +201,10 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ + --fsdp 'full_shard' \ + --fsdp_config fsdp_config.json \ + --cp_size $CP_SIZE \ + --dp_shard_size $DP_SHARD_SIZE \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index cd1af9563b..f8452cd906 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -36,7 +36,13 @@ import torch import transformers -from eagle_utils import EagleTrainerWithAccLog, EagleTrainingPlot, make_eagle_supervised_data_module +from accelerate import ParallelismConfig +from eagle_utils import ( + EagleTrainerWithAccLog, + EagleTrainingPlot, + make_eagle_supervised_data_module, + patch_ring_attention_for_ttt, +) from medusa_utils import make_medusa_supervised_data_module from transformers.trainer_utils import get_last_checkpoint @@ -100,6 +106,8 @@ class TrainingArguments(transformers.TrainingArguments): remove_unused_columns: bool = field( default=False, metadata={"help": "Set to False to keep extra args for VLM."} ) + cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) + dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) @dataclass @@ -130,6 +138,13 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) + if training_args.cp_size > 1: + patch_ring_attention_for_ttt() + # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 + training_args.parallelism_config.sp_backend = None print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") # Detecting last checkpoint. diff --git a/examples/speculative_decoding/requirements.txt b/examples/speculative_decoding/requirements.txt index 765af61041..176e43a65c 100644 --- a/examples/speculative_decoding/requirements.txt +++ b/examples/speculative_decoding/requirements.txt @@ -1,5 +1,4 @@ -flash-attn -openai -py7zr -sentencepiece>=0.2.0 -tensorboardX +accelerate==1.12.0 +torch==2.8.0 +transformers==5.0.0rc1 +wandb diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 517ddd9b4a..561dc9cf22 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -56,10 +56,14 @@ AcceptanceRateValidation, ResBlock, _setup_kimi_k2_decoder, + enable_cp_ttt_patch, + get_ttt_msk_func, temporary_set_config_value, ) IGNORE_TOKEN_ID = LabelSmoother.ignore_index +ENABLE_CP_TTT_PATCH = False +CACHED_SHARD_TTT_MASKS = {} @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) @@ -370,7 +374,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -678,16 +682,7 @@ def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step ) -> BlockMask | torch.Tensor: """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" - - def msk_func(b, h, q_idx, kv_idx): - mask = kv_idx <= (q_idx - ttt_step) - for i in range(1, ttt_step + 1): - mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( - kv_idx >= seq_length * i - ) - mask = mask | mask_block_i - return mask - + msk_func = get_ttt_msk_func(seq_length, ttt_step) dtypemin = torch.finfo(self._base_llm_config.dtype).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) @@ -874,9 +869,9 @@ def forward( ) if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values = DynamicCache(config=self._base_llm_config) if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) + eagle_cache = DynamicCache(config=self.eagle_module.config) # ====Run eagle forward==== eagle_loss = None @@ -907,18 +902,20 @@ def forward( # ====Perform training-time-testing with 3 extra eagle forward passes==== for ttt_step in range(self.num_ttt_steps): + # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. attention_mask = ( attention_mask_0 if ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states, - inputs_embeds, - attention_mask, - position_ids, - eagle_cache, - ) + with enable_cp_ttt_patch(): + _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( + eagle_input_hidden_states, + inputs_embeds, + attention_mask, + position_ids, + eagle_cache, + ) eagle_input_hidden_states = torch.cat( ( torch.zeros( @@ -989,6 +986,7 @@ def _eagle_loss( assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) loss_mask = loss_mask[:, :, None] + loss_mask = loss_mask[:, : eagle_logits.shape[1]] classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits ) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 1f919de065..a3f91ce252 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -27,8 +27,11 @@ import torch.distributed from huggingface_hub import snapshot_download from torch import nn +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers.cache_utils import DynamicCache +import modelopt.torch.speculative.plugins.transformers + KIMI_K2_REPO_ID = "moonshotai/Kimi-K2-Thinking" KIMI_K2_PACKAGE_NAME = "kimi_k2_temp" @@ -439,3 +442,29 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") + + +def get_ttt_msk_func(seq_length, ttt_step): + """Return mask function for Eagle3 Training Time Test.""" + + def ttt_msk_func(b, h, q_idx, kv_idx): + mask = kv_idx <= (q_idx - ttt_step) + for i in range(1, ttt_step + 1): + mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & ( + kv_idx >= seq_length * i + ) + mask = mask | mask_block_i + return mask + + return ttt_msk_func + + +@contextlib.contextmanager +def enable_cp_ttt_patch(): + """Context manager to enable CP TTT patch.""" + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + try: + yield + finally: + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False From 670d6758ce79d89b0905e4af252524d972118ba6 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 24 Jan 2026 01:01:42 +0000 Subject: [PATCH 2/2] fix tests Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .github/workflows/example_tests.yml | 29 +++++++++-- examples/speculative_decoding/README.md | 8 ++-- examples/speculative_decoding/eagle_utils.py | 48 ++++++++----------- examples/speculative_decoding/launch_train.sh | 12 ++++- .../speculative_decoding/requirements.txt | 2 - .../train_eagle3_and_export.sh | 19 +------- .../torch/speculative/plugins/transformers.py | 23 ++++++--- modelopt/torch/speculative/utils.py | 4 +- .../speculative_decoding/test_eagle.py | 25 +++++++--- .../speculative_decoding/test_medusa.py | 7 +-- 10 files changed, 101 insertions(+), 76 deletions(-) diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index fef3dc4db8..a64528dd57 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -63,7 +63,7 @@ jobs: strategy: fail-fast: false matrix: - example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding] + example: [llm_distill, llm_qat, llm_sparsity] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: @@ -77,7 +77,7 @@ jobs: strategy: fail-fast: false matrix: - example: [llm_distill, llm_qat, llm_sparsity, speculative_decoding] + example: [llm_distill, llm_qat, llm_sparsity] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: @@ -86,6 +86,28 @@ jobs: pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-h100-latest-2 + ##### Speculative Decoding Example Tests (requires 25.08 image) ##### + speculative-decoding-pr: + needs: [check-file-changes, wait-checks] + if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:25.08-py3" + example: speculative_decoding + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-l4-latest-1 + + speculative-decoding-non-pr: + if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + uses: ./.github/workflows/_example_tests_runner.yml + secrets: inherit + with: + docker_image: "nvcr.io/nvidia/pytorch:25.08-py3" + example: speculative_decoding + pip_install_extras: "[hf,dev-test]" + runner: linux-amd64-gpu-h100-latest-2 + ##### TensorRT-LLM Example Tests ##### trtllm-pr: needs: [check-file-changes, wait-checks] @@ -150,7 +172,7 @@ jobs: example-pr-required-check: # Run even if example tests are skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr] + needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed @@ -158,6 +180,7 @@ jobs: needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && ( needs.torch-pr.result != 'success' || + needs.speculative-decoding-pr.result != 'success' || needs.trtllm-pr.result != 'success' || needs.onnx-pr.result != 'success' )) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index c495809bb9..3294ba653c 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM, ### Docker -Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information. +Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information. Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies. @@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst ## Getting Started: Simplified Workflow ```bash -bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4 +bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct ``` This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it @@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m ./launch_train.sh --model $BASE_MODEL \ --output_dir $OUTPUT_DIR \ --data input_conversations/daring-anteater.jsonl \ - --num_gpu $NUM_GPU \ --num_epochs $NUM_EPOCH \ --eagle_config eagle_config.json ``` -This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details. +FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. ## Training Draft Model with Offline Base Model @@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o ./launch_train.sh --model $BASE_MODEL \ --output_dir $OUTPUT_DIR \ --data $DATA \ - --num_gpu $NUM_GPU \ --num_epochs $NUM_EPOCH \ --eagle_config eagle_config.json \ --offline-data $HIDDEN_STATES_DIR diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index ade92d21e1..05ec3f864c 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -28,6 +28,7 @@ import torch import transformers from datasets import load_dataset +from packaging.version import Version from PIL import Image from scripts.ar_validate import validate_ar from torch.distributed.tensor.experimental._attention import _SDPAMerger @@ -35,7 +36,7 @@ from transformers import AutoProcessor, Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother -import modelopt.torch.speculative.plugins.transformers +import modelopt from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master @@ -577,27 +578,6 @@ def on_step_end(self, args, state, control, **kwargs): return control -def _compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) -> torch.Tensor: - """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl.""" - - msk_func = get_ttt_msk_func(seq_length, ttt_step) - - dtypemin = torch.finfo(dtype).min - q_len = seq_length - kv_len = seq_length * (1 + ttt_step) - # Return tensor mask for non-flex attention - tensor_mask = msk_func( - None, - None, - torch.arange(q_len).view(1, 1, q_len, 1), - torch.arange(kv_len).view(1, 1, 1, kv_len), - ).to(torch.cuda.current_device()) - tensor_mask = torch.full_like( - tensor_mask, 0, dtype=dtype, device=torch.cuda.current_device() - ).masked_fill(~tensor_mask, dtypemin) - return tensor_mask - - def get_patched_templated_ring_attn(orig_templated_attn: Callable): """ Return patched version of @@ -659,14 +639,14 @@ def patched_templated_attn(*args, **kwargs): patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention: - raise ValueError(f"CP TTT only supports cuddn attention now. Got: {original_op}") + raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}") # Unset is_causal to use custom attn mask if patch_enbabled: kwargs["is_causal"] = False def patched_op(*args, **kwargs): - # Inpect the parent frame to get current shard info + # Inspect the parent frame to get current shard info # This is sensitive to torch _templated_ring_attention impl try: frame: FrameType = inspect.currentframe() @@ -678,7 +658,9 @@ def patched_op(*args, **kwargs): i = f_back.f_locals["i"] ttt_step = (key.shape[2] // query.shape[2]) - 1 except Exception as e: - print(f"Failed to capture loop variables in patched _templated_ring_attention: {e}") + raise RuntimeError( + f"Failed to capture loop variables in patched _templated_ring_attention: {e}" + ) from e # Set attn mask to permuted TTT mask if "attn_bias" in kwargs: kwargs["attn_bias"] = _get_sharded_ttt_msk( @@ -696,8 +678,18 @@ def patch_ring_attention_for_ttt(): """Patch torch ring attention to support context parallelism for TTT.""" # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask. + if not ( + Version(torch.__version__) > Version("2.7.1") + and Version(torch.__version__) < Version("2.9.0") + ): + raise RuntimeError( + f"Context parallel TTT only supported for PyTorch 2.8.0 now. " + f"Got {torch.__version__}. " + f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1." + ) + # 1. Disable load balance, which is designed for causal mask. - # This affect how buffers are sharded. So need to be done permenantly before accelerate/hf trainer init. + # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init. torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False # 2. Patch templated ring attention for TTT mask. @@ -717,9 +709,7 @@ def patch_ring_attention_for_ttt(): # 3. Patch merger to skip the blank shard to avoid difference in output. original_sdpa_merger_step = _SDPAMerger.step - def patched_sdpa_merger_step( - self, out: torch.Tensor, lse: torch.Tensor, partial: bool - ) -> torch.Tensor: + def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool): if lse.sum() <= 0: return return original_sdpa_merger_step(self, out, lse, partial) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index ad49d614f4..c937d5b097 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -170,6 +170,15 @@ else VLM_ARGS="" fi +if [[ "$GPU_COUNT" -gt 1 ]]; then + #Use FSDP2 when multi GPU available + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config fsdp_config.json" +else + #Otherwise, single GPU training + FSDP_ARGS="" +fi + + # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False CMD="accelerate launch --mixed_precision bf16 main.py \ @@ -201,8 +210,7 @@ CMD="accelerate launch --mixed_precision bf16 main.py \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ - --fsdp 'full_shard' \ - --fsdp_config fsdp_config.json \ + $FSDP_ARGS \ --cp_size $CP_SIZE \ --dp_shard_size $DP_SHARD_SIZE \ " diff --git a/examples/speculative_decoding/requirements.txt b/examples/speculative_decoding/requirements.txt index 176e43a65c..6324bac62b 100644 --- a/examples/speculative_decoding/requirements.txt +++ b/examples/speculative_decoding/requirements.txt @@ -1,4 +1,2 @@ accelerate==1.12.0 -torch==2.8.0 transformers==5.0.0rc1 -wandb diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 60c57d1c3a..4e117d1238 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -17,12 +17,11 @@ set -eo pipefail -# Set default values for BASE_MODEL, NUM_GPU, and DATA +# Set default values for BASE_MODEL and DATA BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct -NUM_GPU=1 DATA=input_conversations/daring-anteater.jsonl -# Parse input arguments --base_model, --num_gpu, and --data +# Parse input arguments --base_model and --data while [[ $# -gt 0 ]]; do key="$1" case $key in @@ -30,10 +29,6 @@ while [[ $# -gt 0 ]]; do BASE_MODEL="$2" shift; shift ;; - --num_gpu) - NUM_GPU="$2" - shift; shift - ;; --data) DATA="$2" shift; shift @@ -49,15 +44,6 @@ while [[ $# -gt 0 ]]; do esac done - -if [[ "$NUM_GPU" == 1 ]]; then - export CUDA_VISIBLE_DEVICES=0 -else - # Export as 0,1,...,N-1 for NUM_GPU GPUs - devs="$(seq -s, 0 $((NUM_GPU-1)))" - export CUDA_VISIBLE_DEVICES="$devs" -fi - if [[ "$OFFLINE_DATA_PATH" != "" ]]; then OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH" else @@ -73,7 +59,6 @@ mkdir -p "$(dirname "$OUTPUT_DIR")" --output_dir $OUTPUT_DIR \ $OFFLINE_DATA_ARGS \ --data $DATA \ - --num_gpu $NUM_GPU \ --num_epochs 2 \ --eagle_config eagle_config.json diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 561dc9cf22..3090297aa1 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -34,6 +34,8 @@ from typing import Any import torch +import transformers +from packaging.version import Version from torch import nn from torch.nn import CrossEntropyLoss from torch.nn.attention.flex_attention import BlockMask, create_block_mask @@ -61,11 +63,22 @@ temporary_set_config_value, ) +__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] + IGNORE_TOKEN_ID = LabelSmoother.ignore_index ENABLE_CP_TTT_PATCH = False +# module variable to cache attention mask for cp ttt CACHED_SHARD_TTT_MASKS = {} +def _get_empty_cache(config): + """Return an empty cache. Handle different versions of transformers for unit tests.""" + if Version(transformers.__version__) >= Version("4.54"): + return DynamicCache(config=config) + else: + return DynamicCache() + + @MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFMedusaModel(MedusaModel): """Medusa Model Class for huggingface models.""" @@ -852,9 +865,7 @@ def forward( base_model_logits = base_outputs["base_model_logits"] else: base_model_logits = self.lm_head(base_model_hidden_states) - base_model_loss = None - past_key_values = DynamicCache() # Dummy cache - + base_model_loss, past_key_values = None, None else: base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = ( self._base_model_forward( @@ -869,9 +880,9 @@ def forward( ) if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache(config=self._base_llm_config) + past_key_values = _get_empty_cache(self._base_llm_config) if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache(config=self.eagle_module.config) + eagle_cache = _get_empty_cache(self.eagle_module.config) # ====Run eagle forward==== eagle_loss = None @@ -908,7 +919,7 @@ def forward( if ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - with enable_cp_ttt_patch(): + with enable_cp_ttt_patch() if self.training else contextlib.nullcontext(): _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, inputs_embeds, diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index a3f91ce252..d259a1fce6 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -30,8 +30,6 @@ from torch.nn.attention import SDPBackend, sdpa_kernel from transformers.cache_utils import DynamicCache -import modelopt.torch.speculative.plugins.transformers - KIMI_K2_REPO_ID = "moonshotai/Kimi-K2-Thinking" KIMI_K2_PACKAGE_NAME = "kimi_k2_temp" @@ -462,6 +460,8 @@ def ttt_msk_func(b, h, q_idx, kv_idx): @contextlib.contextmanager def enable_cp_ttt_patch(): """Context manager to enable CP TTT patch.""" + import modelopt.torch.speculative.plugins.transformers + modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): try: diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 6bf1c79d2b..deabba5702 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -17,7 +17,9 @@ import pytest import safetensors.torch +import torch from _test_utils.examples.run_command import run_example_command +from packaging.version import Version from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER @@ -29,8 +31,17 @@ def eagle_output_dir(tmp_path_factory): # fmt: off -def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path, eagle_output_dir): - """Test Eagle3 training with a tiny llama model.""" +@pytest.mark.parametrize("cp_size", [1, 2]) +def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, cp_size): + """Test Eagle3 training with a tiny llama model, using different cp_size values.""" + available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + if cp_size == 2 and available_gpus < 2: + pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus)) + if cp_size == 2 and not ( + Version(torch.__version__) > Version("2.7.0") + and Version(torch.__version__) < Version("2.9.0") + ): + pytest.skip("cp_size=2 requires torch 2.8.0") # Create an ultra-tiny EAGLE config for testing to reduce memory usage tiny_eagle_config = { "max_position_embeddings": 128, @@ -42,7 +53,7 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ } # Write the tiny config to a temporary file - config_file = tmp_path / "tiny_eagle_config.json" + config_file = tmp_path / f"tiny_eagle_config_cp{cp_size}.json" with open(config_file, "w") as f: json.dump(tiny_eagle_config, f) @@ -53,11 +64,11 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--num_gpu", str(num_gpus), "--mode", "eagle3", "--eagle_config", str(config_file), - "--output_dir", eagle_output_dir / "eagle-tinyllama", + "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}", "--training_seq_len", "128", # Match max_position_embeddings + "--cp_size", str(cp_size), ], "speculative_decoding", ) @@ -68,7 +79,7 @@ def test_ar_validate(eagle_output_dir): run_example_command( [ "python", "./scripts/ar_validate.py", - "--model_path", eagle_output_dir / "eagle-tinyllama", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", "--osl", "20", "--num_samples", "10", "--steps", "3" @@ -82,7 +93,7 @@ def test_export_hf_checkpoint(eagle_output_dir): run_example_command( [ "python", "./scripts/export_hf_checkpoint.py", - "--model_path", eagle_output_dir / "eagle-tinyllama", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", "--export_path", eagle_output_dir / "eagle-tinyllama-export", ], "speculative_decoding", diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py index 488e24855b..545b79d7ea 100644 --- a/tests/examples/speculative_decoding/test_medusa.py +++ b/tests/examples/speculative_decoding/test_medusa.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import pytest from _test_utils.examples.run_command import run_example_command @@ -32,7 +32,7 @@ def _run_hf_ptq(model_path, output_dir, qformat): ) -def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path): +def test_llama_medusa_fp8_qat(tiny_llama_path, tiny_daring_anteater_path, tmp_path): medusa_path = tmp_path / "medusa-tinyllama" # Test Medusa @@ -43,7 +43,6 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--num_gpu", str(num_gpus), "--mode", "medusa", "--output_dir", medusa_path, "--medusa_num_heads", "2", @@ -52,6 +51,8 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa "speculative_decoding", ) + pytest.skip("speculative decoding uses transformers 5.x, quantization example uses transformers 4.x") + # Test PTQ on Medusa _run_hf_ptq(medusa_path, tmp_path / "medusa-tinyllama-hf", "fp8")