From 4b7645238da1c23c624f3062f468225eed0b4f6d Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 02:40:42 +0000 Subject: [PATCH 01/15] nvtx annotations Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1b85c342e75..d482e35fa44 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -35,6 +35,7 @@ from typing import Any import torch +from torch.cuda import nvtx import transformers from packaging.version import Version from torch import nn @@ -657,6 +658,7 @@ def _prepare_decoder_attention_mask( return combined_attention_mask + @nvtx.range("prepare_eagle_inputs") def _prepare_eagle_inputs( self, input_ids, @@ -746,6 +748,7 @@ def _compute_ttt_attention_mask( tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask + @nvtx.range("base_model_forward") def _base_model_forward( self, input_ids, @@ -794,6 +797,7 @@ def _map_logits_to_draft_vocab(self, full_logits): ) return full_logits[:, :, reverse_mapping] + @nvtx.range("eagle_forward") def _eagle_forward( self, eagle_input_hidden_states, @@ -977,6 +981,7 @@ def forward( train_acc=train_accs, ) + @nvtx.range("eagle_loss") def _eagle_loss( self, base_model_logits, From 7e10294fc0894cf1cd5b92b892d9d610b16fc8a8 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 05:09:15 +0000 Subject: [PATCH 02/15] fix fsdp crash Signed-off-by: Benjamin Chislett --- examples/speculative_decoding/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 25817ee94fc..e3454c9b2f4 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -145,9 +145,10 @@ def train(): model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) - training_args.parallelism_config = ParallelismConfig( - cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size - ) + if training_args.cp_size > 1 or training_args.dp_shard_size > 1: + training_args.parallelism_config = ParallelismConfig( + cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size + ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 From 90cbd091fe8fb129ebe3a8930a12fdf78ad8f716 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 05:10:21 +0000 Subject: [PATCH 03/15] fix rope init Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index d482e35fa44..56cf6882d77 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -285,6 +285,9 @@ def __init__(self, config, decoder_layer_cls, bias=False): ) self.layers[0].hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.eagle_decoder_type == "llama": + self.rotary_emb = LlamaRotaryEmbedding(config=config) + if self.config.parallel_draft_step > 1: self.parallel_draft_heads = ParallelDraft( config.hidden_size, @@ -373,11 +376,6 @@ def forward( self._input_embeds = self.layers[0].input_layernorm(inputs_embeds) if self.config.eagle_decoder_type == "llama": - # Lazy init rope to avoid save/load meta tensor error - if not hasattr(self, "rotary_emb"): - self.rotary_emb = LlamaRotaryEmbedding( - config=self.config, device=hidden_states.device - ) position_embeddings = self.rotary_emb(hidden_states, position_ids) else: position_embeddings = None From 3d40373c31874dd79675186bbef3f0de9573cb4a Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 05:12:23 +0000 Subject: [PATCH 04/15] optimize loss calculation by avoiding extra softmax calls and .item() calls Signed-off-by: Benjamin Chislett --- .../torch/speculative/plugins/transformers.py | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 56cf6882d77..69069076615 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -716,7 +716,14 @@ def _prepare_eagle_inputs( else: eagle_position_ids = position_ids.view(-1, seq_length).long() - return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids + base_model_logits = base_outputs.logits + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) + base_output_predict_tok = base_model_logits.argmax(dim=-1).detach() + base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach() + + return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids, \ + base_output_predict_tok, base_output_softmax_logits def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step @@ -892,13 +899,17 @@ def forward( # ====Prepare inputs for the first eagle forward pass==== eagle_loss = None - train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)] + num_parallel = self.eagle_config.parallel_draft_step + num_ttt = self.eagle_ttt_steps + train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device) b, seq_length, _ = base_outputs.out_hiddens.shape ( eagle_input_embeds, eagle_input_hiddens, eagle_attn_mask_0, eagle_position_ids, + base_output_predict_tok, + base_output_softmax_logits ) = self._prepare_eagle_inputs( input_ids, attention_mask, @@ -951,7 +962,8 @@ def forward( # base model predict +1 tok, while eagle predict +2 # so we shift base model outputs compared to eagle outputs # additionally, we mask the first n tok of eagle outputs at nth TTT step - base_outputs.logits[:, 1 + i + ttt_step :], + base_output_softmax_logits[:, 1 + i + ttt_step :], + base_output_predict_tok[:, 1 + i + ttt_step :], eagle_logit[:, ttt_step : -(1 + i)], loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], ) @@ -960,10 +972,13 @@ def forward( eagle_loss = ( classification_loss if eagle_loss is None else eagle_loss + classification_loss ) - train_accs[i].append(acc) + train_accs[i, ttt_step] = acc if not self.training: break + # Slice by actual number of steps taken, in case of early return + train_accs = train_accs[:, : ttt_step + 1].tolist() + # Merge base model loss and eagle loss if base_outputs.loss is None and eagle_loss is None: loss = None @@ -982,27 +997,23 @@ def forward( @nvtx.range("eagle_loss") def _eagle_loss( self, - base_model_logits, + base_output_softmax_logits, + base_output_predict_tok, eagle_logits, loss_mask, ): """Function for EAGLE loss computing.""" - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) loss_mask = loss_mask[:, : eagle_logits.shape[1], None] - classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( - eagle_logits - ) - classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / ( + eagle_logsoft = torch.log_softmax(eagle_logits, dim=2) + classification_loss = -torch.sum(torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2)) / ( loss_mask.sum() + 1e-5 ) - # Compute accuracy - base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1) - eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1) + # Compute accuracy (returned as tensor to avoid sync; .item() called after TTT loop) + eagle_predict_tok = eagle_logits.detach().argmax(dim=-1) valid = loss_mask[:, :, 0].bool() - correct = (base_predict_tok == eagle_predict_tok) & valid + correct = (base_output_predict_tok == eagle_predict_tok) & valid denom = valid.sum().clamp_min(1).float() - accuracy = round(correct.sum().float().div(denom).item(), 3) + accuracy = correct.sum().float() / denom return classification_loss, accuracy From 0309c192757f73814fb0d8eb54d1531948c07583 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 05:34:49 +0000 Subject: [PATCH 05/15] fix precommit Signed-off-by: Benjamin Chislett --- .../torch/speculative/plugins/transformers.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 69069076615..f8d79b6dd6c 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -35,10 +35,10 @@ from typing import Any import torch -from torch.cuda import nvtx import transformers from packaging.version import Version from torch import nn +from torch.cuda import nvtx from torch.nn import CrossEntropyLoss from torch.nn.attention.flex_attention import BlockMask, create_block_mask from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel @@ -722,8 +722,14 @@ def _prepare_eagle_inputs( base_output_predict_tok = base_model_logits.argmax(dim=-1).detach() base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach() - return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids, \ - base_output_predict_tok, base_output_softmax_logits + return ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attention_mask, + eagle_position_ids, + base_output_predict_tok, + base_output_softmax_logits, + ) def _compute_ttt_attention_mask( self, batch_size, seq_length, ttt_step @@ -909,7 +915,7 @@ def forward( eagle_attn_mask_0, eagle_position_ids, base_output_predict_tok, - base_output_softmax_logits + base_output_softmax_logits, ) = self._prepare_eagle_inputs( input_ids, attention_mask, @@ -1005,9 +1011,9 @@ def _eagle_loss( """Function for EAGLE loss computing.""" loss_mask = loss_mask[:, : eagle_logits.shape[1], None] eagle_logsoft = torch.log_softmax(eagle_logits, dim=2) - classification_loss = -torch.sum(torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2)) / ( - loss_mask.sum() + 1e-5 - ) + classification_loss = -torch.sum( + torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2) + ) / (loss_mask.sum() + 1e-5) # Compute accuracy (returned as tensor to avoid sync; .item() called after TTT loop) eagle_predict_tok = eagle_logits.detach().argmax(dim=-1) valid = loss_mask[:, :, 0].bool() From 44c00b85d2b62dc03fb88a55dac20302340255a7 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 19:20:54 +0000 Subject: [PATCH 06/15] fix tests Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index f8d79b6dd6c..b86f602279c 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -285,9 +285,6 @@ def __init__(self, config, decoder_layer_cls, bias=False): ) self.layers[0].hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if config.eagle_decoder_type == "llama": - self.rotary_emb = LlamaRotaryEmbedding(config=config) - if self.config.parallel_draft_step > 1: self.parallel_draft_heads = ParallelDraft( config.hidden_size, @@ -296,6 +293,10 @@ def __init__(self, config, decoder_layer_cls, bias=False): num_layers=self.config.parallel_draft_heads_num_layers, ) + def _maybe_init_rope(self): + if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + def _expand_first_attn_in_dim(self, first_layer_attn): """Modify qkv projection in first layer to accept 2h hidden size.""" # Find Linear modules to expand @@ -924,6 +925,8 @@ def forward( base_outputs, ) + self.eagle_module._maybe_init_rope() + # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.eagle_ttt_steps): # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. From b30d95cc268d98fc588d13bc795b8140a7f37b66 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 19:43:17 +0000 Subject: [PATCH 07/15] fix comment Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index b86f602279c..eff44cf6259 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -1062,6 +1062,7 @@ def pseudo_speculative_generate( else: eagle_input_hidden_states = base_model_hidden_states + self.eagle_module._maybe_init_rope() draft_tokens = [] for step in range(steps): b, seq_length = eagle_ids.shape From 281577370657273933901c0dffaceffdba93a8d8 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 21:02:29 +0000 Subject: [PATCH 08/15] torch.compile annotations Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index eff44cf6259..76560b520dd 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -77,6 +77,10 @@ # module variable to cache attention mask for cp ttt CACHED_SHARD_TTT_MASKS = {} +# ALL_ATTENTION_FUNCTIONS["flex_attention"] = partial( +# ALL_ATTENTION_FUNCTIONS["flex_attention"], kernel_options={"BACKEND": "FLASH"} +# ) + def _get_empty_cache(config): """Return an empty cache. Handle different versions of transformers for unit tests.""" @@ -658,6 +662,7 @@ def _prepare_decoder_attention_mask( return combined_attention_mask @nvtx.range("prepare_eagle_inputs") + @torch.compile(dynamic=False) def _prepare_eagle_inputs( self, input_ids, @@ -810,6 +815,7 @@ def _map_logits_to_draft_vocab(self, full_logits): return full_logits[:, :, reverse_mapping] @nvtx.range("eagle_forward") + @torch.compile(dynamic=False, mode="max-autotune") def _eagle_forward( self, eagle_input_hidden_states, @@ -1004,6 +1010,7 @@ def forward( ) @nvtx.range("eagle_loss") + @torch.compile(dynamic=False, fullgraph=True) def _eagle_loss( self, base_output_softmax_logits, From 67e4beed444f4153becb44b2abed17a71fc64de9 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 16 Mar 2026 21:03:19 +0000 Subject: [PATCH 09/15] remove dead code Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 76560b520dd..641ad664d53 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -77,10 +77,6 @@ # module variable to cache attention mask for cp ttt CACHED_SHARD_TTT_MASKS = {} -# ALL_ATTENTION_FUNCTIONS["flex_attention"] = partial( -# ALL_ATTENTION_FUNCTIONS["flex_attention"], kernel_options={"BACKEND": "FLASH"} -# ) - def _get_empty_cache(config): """Return an empty cache. Handle different versions of transformers for unit tests.""" From a637eef5f5f7471df7a6f7fd4fca773a910a4cdd Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Wed, 18 Mar 2026 16:55:47 +0000 Subject: [PATCH 10/15] make nvtx safe and torch.compile optional (on by default) Signed-off-by: Benjamin Chislett --- examples/speculative_decoding/launch_train.sh | 6 ++++++ examples/speculative_decoding/main.py | 5 +++++ modelopt/torch/speculative/config.py | 5 +++++ .../torch/speculative/eagle/eagle_model.py | 1 + modelopt/torch/speculative/eagle/utils.py | 13 ++++++++++++ .../torch/speculative/plugins/transformers.py | 21 +++++++++++-------- 6 files changed, 42 insertions(+), 9 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 074151c5a0c..4cc3c76c4d7 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -114,6 +114,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi MIX_HIDDEN_STATES="${1#*=}" ;; + --disable_torch_compile*) + if [[ "$1" != *=* ]]; then shift; fi + DISABLE_TORCH_COMPILE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -154,6 +158,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} +DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} if [[ "$MODE" == "eagle3" ]]; then @@ -240,6 +245,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ --mix_hidden_states $MIX_HIDDEN_STATES \ + --disable_torch_compile $DISABLE_TORCH_COMPILE \ $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index e3454c9b2f4..29741ac3688 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -130,6 +130,10 @@ class EagleArguments: default=False, metadata={"help": "Whether to mix hidden states from previous TTT step."}, ) + disable_torch_compile: bool = field( + default=False, + metadata={"help": "Disable torch.compile on eagle forward/loss methods."}, + ) def train(): @@ -209,6 +213,7 @@ def train(): "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, "eagle_mix_hidden_states": eagle_args.mix_hidden_states, + "eagle_use_torch_compile": not eagle_args.disable_torch_compile, "eagle_architecture_config": custom_config, } diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index b28dca61f57..e22f6726f1d 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -110,3 +110,8 @@ class EagleConfig(ModeloptBaseConfig): "Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost." ), ) + + eagle_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on eagle forward/loss methods for faster training.", + ) diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index 85251c86a21..b9532b6c54a 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -39,3 +39,4 @@ def modify( self.eagle_decoder_type = config.eagle_decoder_type self.eagle_ttt_steps = config.eagle_ttt_steps self.eagle_mix_hidden_states = config.eagle_mix_hidden_states + self.eagle_use_torch_compile = config.eagle_use_torch_compile diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index d77ed298acc..0a21b31aa6e 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -35,6 +35,8 @@ """Eagle model utils.""" +from contextlib import nullcontext + import torch @@ -70,3 +72,14 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def maybe_nvtx_range(*args, **kwargs): + """Helper function to create NVTX ranges if NVTX is available.""" + try: + import torch.cuda.nvtx as nvtx + + return nvtx.range(*args, **kwargs) + except (ImportError, RuntimeError): + # If NVTX is not available, return a no-op context manager + return nullcontext() diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 641ad664d53..e632f2d9303 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -38,7 +38,6 @@ import transformers from packaging.version import Version from torch import nn -from torch.cuda import nvtx from torch.nn import CrossEntropyLoss from torch.nn.attention.flex_attention import BlockMask, create_block_mask from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel @@ -58,7 +57,7 @@ ) from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel -from ..eagle.utils import expand_mask, make_causal_mask +from ..eagle.utils import expand_mask, make_causal_mask, maybe_nvtx_range from ..medusa.conversion import MedusaDMRegistry from ..medusa.medusa_model import MedusaModel from ..utils import ( @@ -618,6 +617,13 @@ def modify( # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566 self.is_quantized = False + if self.eagle_use_torch_compile: + self._prepare_eagle_inputs = torch.compile(self._prepare_eagle_inputs, dynamic=False) + self._eagle_forward = torch.compile( + self._eagle_forward, dynamic=False, mode="max-autotune" + ) + self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True) + self._cached_attn_blk_masks = {} def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): @@ -657,8 +663,7 @@ def _prepare_decoder_attention_mask( return combined_attention_mask - @nvtx.range("prepare_eagle_inputs") - @torch.compile(dynamic=False) + @maybe_nvtx_range("prepare_eagle_inputs") def _prepare_eagle_inputs( self, input_ids, @@ -761,7 +766,7 @@ def _compute_ttt_attention_mask( tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask - @nvtx.range("base_model_forward") + @maybe_nvtx_range("base_model_forward") def _base_model_forward( self, input_ids, @@ -810,8 +815,7 @@ def _map_logits_to_draft_vocab(self, full_logits): ) return full_logits[:, :, reverse_mapping] - @nvtx.range("eagle_forward") - @torch.compile(dynamic=False, mode="max-autotune") + @maybe_nvtx_range("eagle_forward") def _eagle_forward( self, eagle_input_hidden_states, @@ -1005,8 +1009,7 @@ def forward( train_acc=train_accs, ) - @nvtx.range("eagle_loss") - @torch.compile(dynamic=False, fullgraph=True) + @maybe_nvtx_range("eagle_loss") def _eagle_loss( self, base_output_softmax_logits, From d5ce31d573e585c353f57dd190b1f5a940bf8289 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Wed, 18 Mar 2026 19:55:52 +0000 Subject: [PATCH 11/15] torch.compile safely Signed-off-by: Benjamin Chislett --- .../torch/speculative/plugins/transformers.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index e632f2d9303..4be730d3716 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -618,13 +618,32 @@ def modify( self.is_quantized = False if self.eagle_use_torch_compile: + self._activate_torch_compile() + + self._cached_attn_blk_masks = {} + + def _activate_torch_compile(self): + import torch._dynamo + + torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode + + # Individual try-catch for each function to maximize torch.compile usage + try: self._prepare_eagle_inputs = torch.compile(self._prepare_eagle_inputs, dynamic=False) + except Exception: + print("Disabling torch.compile for _prepare_eagle_inputs due to compilation error.") + + try: self._eagle_forward = torch.compile( self._eagle_forward, dynamic=False, mode="max-autotune" ) - self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True) + except Exception: + print("Disabling torch.compile for _eagle_forward due to compilation error.") - self._cached_attn_blk_masks = {} + try: + self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True) + except Exception: + print("Disabling torch.compile for _eagle_loss due to compilation error.") def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call From b58be4ac776caa90081f0822864c130ce86d5b98 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Wed, 18 Mar 2026 20:53:43 +0000 Subject: [PATCH 12/15] nvtx sniff test Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/eagle/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index 0a21b31aa6e..1ca1f0c7fc6 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -79,6 +79,9 @@ def maybe_nvtx_range(*args, **kwargs): try: import torch.cuda.nvtx as nvtx + nvtx.range_push("nvtx init") + nvtx.range_pop("nvtx init") + return nvtx.range(*args, **kwargs) except (ImportError, RuntimeError): # If NVTX is not available, return a no-op context manager From cc0ef499b2e8d44b8ce57362b7e9650dc25ebf1e Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 19 Mar 2026 01:03:55 +0000 Subject: [PATCH 13/15] fix Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/eagle/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index 1ca1f0c7fc6..29db8d33a19 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -80,9 +80,9 @@ def maybe_nvtx_range(*args, **kwargs): import torch.cuda.nvtx as nvtx nvtx.range_push("nvtx init") - nvtx.range_pop("nvtx init") + nvtx.range_pop() return nvtx.range(*args, **kwargs) - except (ImportError, RuntimeError): + except Exception: # If NVTX is not available, return a no-op context manager return nullcontext() From 7bc055ee6c03eebac5394797c1c17340f5de95ec Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:08:42 +0000 Subject: [PATCH 14/15] fix test: disable nvtx by default Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/config.py | 5 + .../torch/speculative/eagle/eagle_model.py | 1 + modelopt/torch/speculative/eagle/utils.py | 16 --- .../torch/speculative/plugins/transformers.py | 101 ++++++++++-------- 4 files changed, 64 insertions(+), 59 deletions(-) diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 9b77391b497..69491c65994 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -115,3 +115,8 @@ class EagleConfig(ModeloptBaseConfig): default=True, description="Whether to use torch.compile on eagle forward/loss methods for faster training.", ) + + eagle_enable_nvtx: bool = ModeloptField( + default=False, + description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.", + ) diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index b9532b6c54a..e2a08c5252a 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -40,3 +40,4 @@ def modify( self.eagle_ttt_steps = config.eagle_ttt_steps self.eagle_mix_hidden_states = config.eagle_mix_hidden_states self.eagle_use_torch_compile = config.eagle_use_torch_compile + self.eagle_enable_nvtx = config.eagle_enable_nvtx diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py index 29db8d33a19..d77ed298acc 100644 --- a/modelopt/torch/speculative/eagle/utils.py +++ b/modelopt/torch/speculative/eagle/utils.py @@ -35,8 +35,6 @@ """Eagle model utils.""" -from contextlib import nullcontext - import torch @@ -72,17 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -def maybe_nvtx_range(*args, **kwargs): - """Helper function to create NVTX ranges if NVTX is available.""" - try: - import torch.cuda.nvtx as nvtx - - nvtx.range_push("nvtx init") - nvtx.range_pop() - - return nvtx.range(*args, **kwargs) - except Exception: - # If NVTX is not available, return a no-op context manager - return nullcontext() diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 4be730d3716..5519110aa88 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -57,7 +57,7 @@ ) from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel -from ..eagle.utils import expand_mask, make_causal_mask, maybe_nvtx_range +from ..eagle.utils import expand_mask, make_causal_mask from ..medusa.conversion import MedusaDMRegistry from ..medusa.medusa_model import MedusaModel from ..utils import ( @@ -453,6 +453,23 @@ def _draft_model_config(self): """Return the llm config for the draft model.""" return self.eagle_config + def _enable_cp_ttt(self): + if self.training and not self.eagle_mix_hidden_states: + return enable_cp_ttt_patch() + return contextlib.nullcontext() + + def _nvtx_range(self, name): + """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set.""" + if not self.eagle_enable_nvtx: + return contextlib.nullcontext() + try: + import torch.cuda.nvtx as nvtx + + return nvtx.range(name) + except Exception as e: + print(f"Failed to create NVTX range {name}: {e}") + return contextlib.nullcontext() + def get_exporter(self) -> SpeculativeDecodingExporter: """Get the exporter for the draft model.""" exporter_cls = ( @@ -682,7 +699,6 @@ def _prepare_decoder_attention_mask( return combined_attention_mask - @maybe_nvtx_range("prepare_eagle_inputs") def _prepare_eagle_inputs( self, input_ids, @@ -785,7 +801,6 @@ def _compute_ttt_attention_mask( tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask - @maybe_nvtx_range("base_model_forward") def _base_model_forward( self, input_ids, @@ -834,7 +849,6 @@ def _map_logits_to_draft_vocab(self, full_logits): ) return full_logits[:, :, reverse_mapping] - @maybe_nvtx_range("eagle_forward") def _eagle_forward( self, eagle_input_hidden_states, @@ -913,15 +927,16 @@ def forward( base_outputs.logits = self.lm_head(base_outputs.out_hiddens) past_key_values = None else: - base_outputs, past_key_values = self._base_model_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - self.eagle_freeze_base_model, - labels, - **kwargs, - ) + with self._nvtx_range("base_model_forward"): + base_outputs, past_key_values = self._base_model_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + self.eagle_freeze_base_model, + labels, + **kwargs, + ) if not isinstance(past_key_values, Cache): past_key_values = _get_empty_cache(self._base_llm_config) @@ -935,20 +950,21 @@ def forward( num_ttt = self.eagle_ttt_steps train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device) b, seq_length, _ = base_outputs.out_hiddens.shape - ( - eagle_input_embeds, - eagle_input_hiddens, - eagle_attn_mask_0, - eagle_position_ids, - base_output_predict_tok, - base_output_softmax_logits, - ) = self._prepare_eagle_inputs( - input_ids, - attention_mask, - position_ids, - eagle_cache, - base_outputs, - ) + with self._nvtx_range("prepare_eagle_inputs"): + ( + eagle_input_embeds, + eagle_input_hiddens, + eagle_attn_mask_0, + eagle_position_ids, + base_output_predict_tok, + base_output_softmax_logits, + ) = self._prepare_eagle_inputs( + input_ids, + attention_mask, + position_ids, + eagle_cache, + base_outputs, + ) self.eagle_module._maybe_init_rope() @@ -960,11 +976,7 @@ def forward( if self.eagle_mix_hidden_states or ttt_step == 0 else self._get_ttt_attention_mask(b, seq_length, ttt_step) ) - with ( - enable_cp_ttt_patch() - if self.training and not self.eagle_mix_hidden_states - else contextlib.nullcontext() - ): + with self._enable_cp_ttt(), self._nvtx_range("eagle_forward"): _, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hiddens, eagle_input_embeds, @@ -992,15 +1004,16 @@ def forward( for i in range(self.eagle_config.parallel_draft_step): eagle_logit = eagle_logits[i] - classification_loss, acc = self._eagle_loss( - # base model predict +1 tok, while eagle predict +2 - # so we shift base model outputs compared to eagle outputs - # additionally, we mask the first n tok of eagle outputs at nth TTT step - base_output_softmax_logits[:, 1 + i + ttt_step :], - base_output_predict_tok[:, 1 + i + ttt_step :], - eagle_logit[:, ttt_step : -(1 + i)], - loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], - ) + with self._nvtx_range("eagle_loss"): + classification_loss, acc = self._eagle_loss( + # base model predict +1 tok, while eagle predict +2 + # so we shift base model outputs compared to eagle outputs + # additionally, we mask the first n tok of eagle outputs at nth TTT step + base_output_softmax_logits[:, 1 + i + ttt_step :], + base_output_predict_tok[:, 1 + i + ttt_step :], + eagle_logit[:, ttt_step : -(1 + i)], + loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], + ) # Apply loss decay factor to focus on early steps classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) eagle_loss = ( @@ -1028,7 +1041,6 @@ def forward( train_acc=train_accs, ) - @maybe_nvtx_range("eagle_loss") def _eagle_loss( self, base_output_softmax_logits, @@ -1100,7 +1112,10 @@ def pseudo_speculative_generate( ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): + with ( + temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"), + self._nvtx_range("eagle_forward"), + ): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids), From 2933ced778cf269fcb650c83bd9b1a2098d08583 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:18:06 +0000 Subject: [PATCH 15/15] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../torch/speculative/plugins/transformers.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 5519110aa88..37517d768d4 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -644,23 +644,16 @@ def _activate_torch_compile(self): torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode - # Individual try-catch for each function to maximize torch.compile usage - try: - self._prepare_eagle_inputs = torch.compile(self._prepare_eagle_inputs, dynamic=False) - except Exception: - print("Disabling torch.compile for _prepare_eagle_inputs due to compilation error.") - - try: - self._eagle_forward = torch.compile( - self._eagle_forward, dynamic=False, mode="max-autotune" - ) - except Exception: - print("Disabling torch.compile for _eagle_forward due to compilation error.") - - try: - self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True) - except Exception: - print("Disabling torch.compile for _eagle_loss due to compilation error.") + compile_targets = [ + ("_prepare_eagle_inputs", {}), + ("_eagle_forward", {"mode": "max-autotune"}), + ("_eagle_loss", {"fullgraph": True}), + ] + for name, kwargs in compile_targets: + try: + setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) + except Exception: # noqa: PERF203 + print(f"Disabling torch.compile for {name} due to compilation error.") def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call