Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# AutoDeploy config for serving Mistral-Small-4-119B with Eagle speculative decoding.
# The Eagle checkpoint (mistralai/Mistral-Small-4-119B-2603-eagle) is in native Mistral
# format (params.json + consolidated.safetensors); its config is loaded from the target model.

# --- Runtime / compilation ---
runtime: trtllm
compile_backend: torch-simple

# --- Target model ---
model_factory: Mistral3ForConditionalGeneration
skip_loading_weights: false
max_seq_len: 512
world_size: 8
tokenizer: mistralai/Mistral-Small-4-119B-2603

# --- Graph transforms ---
transforms:
  insert_cached_mla_attention:
    backend: torch_mla

# --- Speculative decoding (Eagle3, one-model) ---
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 3
  speculative_model: mistralai/Mistral-Small-4-119B-2603-eagle
  eagle3_one_model: true
  eagle3_model_arch: mistral_large3
  eagle3_layers_to_capture: [-1]

# Overrides applied to the draft (Eagle) model config only.
speculative_model_kwargs:
  num_hidden_layers: 2
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should not be overridden

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Config for Mistral Small 4 119B with Eagle3 speculative decoding.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait — which one is used? This or the other configs?

# --- Compilation / target model ---
compile_backend: torch-simple
model_factory: Mistral3ForConditionalGeneration
tokenizer: mistralai/Mistral-Small-4-119B-2603
max_seq_len: 512
world_size: 8

# --- Speculative decoding (Eagle3, one-model) ---
speculative_config:
  decoding_type: Eagle3
  max_draft_len: 3
  speculative_model: mistralai/Mistral-Small-4-119B-2603-eagle
  eagle3_one_model: true
  eagle3_model_arch: mistral_large3

# --- Graph transforms ---
transforms:
  insert_cached_mla_attention:
    backend: torch_mla
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
runtime: trtllm
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be here; tests only

# --- Compilation / target model ---
compile_backend: torch-simple
model_factory: Mistral3ForConditionalGeneration
skip_loading_weights: false
max_seq_len: 512
world_size: 8
tokenizer: mistralai/Mistral-Small-4-119B-2603

# --- Graph transforms ---
transforms:
  insert_cached_mla_attention:
    backend: torch_mla

# Shrink the model for tests: only 5 decoder layers are instantiated.
model_kwargs:
  num_hidden_layers: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Standalone AutoDeploy config for Mistral Small 4 119B using torch MLA.

# --- Runtime / compilation ---
runtime: trtllm
attn_backend: trtllm
compile_backend: torch-simple

# --- Target model ---
model_factory: AutoModelForImageTextToText
skip_loading_weights: false
max_seq_len: 512
world_size: 8
tokenizer: tensorrt_llm/_torch/auto_deploy/tokenizers/mistral_small_4_119b

# --- Graph transforms ---
transforms:
  insert_cached_mla_attention:
    backend: torch_mla
38 changes: 29 additions & 9 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A

@model_validator(mode="after")
def setup_hidden_state_capture(self):
"""Enable the hidden state capture transform if the speculative config requires it.

This validator only configures transforms — factory selection is handled by
create_factory() to avoid mutating model_factory in place.
"""
spec_config = self.speculative_config
if spec_config is None:
return self
Expand All @@ -131,16 +136,15 @@ def setup_hidden_state_capture(self):
"enabled. Ensure num_nextn_predict_layers is set in the model config."
)
capture_layers = {-1}
self.model_factory = "eagle_one_model"
elif isinstance(spec_config, EagleDecodingConfig):
if spec_config.max_draft_len is None:
raise ValueError(
"EagleDecodingConfig.max_draft_len must not be None. "
"Provide a positive integer for max_draft_len."
)
capture_layers = spec_config.eagle3_layers_to_capture
if spec_config.eagle3_one_model:
self.model_factory = "eagle_one_model"
if not spec_config.eagle3_one_model:
return self
else:
return self

Expand Down Expand Up @@ -351,22 +355,38 @@ def update_cuda_graph_batch_sizes(self):
return self

### UTILITY METHODS ############################################################################
def _requires_eagle_one_model(self) -> bool:
"""Check if the speculative config requires Eagle one-model factory."""
spec_config = self.speculative_config
if spec_config is None:
return False
if isinstance(spec_config, MTPDecodingConfig):
return spec_config.mtp_eagle_one_model
if isinstance(spec_config, EagleDecodingConfig):
return spec_config.eagle3_one_model
return False

def create_factory(self) -> ModelFactory:
    """Create a model factory from the arguments.

    Returns the Eagle one-model factory when the speculative config requires it
    (see ``_requires_eagle_one_model``); otherwise the factory registered under
    ``self.model_factory``.
    """
    # TODO (lucaslie): consider supporting Path objects in the model factory
    common_kwargs = dict(
        model=str(self.model),
        model_kwargs=self.model_kwargs,
        tokenizer=None if self.tokenizer is None else str(self.tokenizer),
        tokenizer_kwargs=self.tokenizer_kwargs,
        skip_loading_weights=self.skip_loading_weights,
        max_seq_len=self.max_seq_len,
        # Extra kwargs consumed by EagleOneModelFactory (ignored by others via **kwargs)
        speculative_config=self.speculative_config,
        speculative_model_kwargs=self.speculative_model_kwargs or None,
    )

    if self._requires_eagle_one_model():
        # NOTE: speculative_config / speculative_model_kwargs are already part of
        # common_kwargs above — passing them again here would raise
        # "TypeError: got multiple values for keyword argument".
        return ModelFactoryRegistry.get("eagle_one_model")(
            **common_kwargs,
            target_factory_cls_name=self.model_factory,
        )

    return ModelFactoryRegistry.get(self.model_factory)(**common_kwargs)

def is_cuda_graph_enabled(self) -> bool:
    """Return True when the configured compile backend captures CUDA graphs."""
    return self.compile_backend in ["torch-cudagraph", "torch-opt"]

Expand Down
58 changes: 51 additions & 7 deletions tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import torch
import torch.nn as nn
from torch.fx import GraphModule
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput
Expand All @@ -41,6 +42,7 @@
from ...shim.interface import CachedSequenceInterface
from ...utils._config import deep_merge_dicts
from ...utils.logger import ad_logger
from .modeling_mistral3 import build_mistral4_eagle_layers
from .modeling_nemotron_h import build_nemotron_eagle_layers

# =============================================================================
Expand Down Expand Up @@ -72,10 +74,12 @@ def get_eagle_layers(config, model_type: str) -> Union[nn.ModuleList, nn.Module]
layers = build_llama_eagle_layers(config)
case "nemotron_h":
layers = build_nemotron_eagle_layers(config)
case "mistral4":
layers = build_mistral4_eagle_layers(config)
case _:
raise ValueError(
f"Model type '{model_type}' not supported for Eagle drafter. "
f"Supported types: llama, nemotron_h"
f"Supported types: llama, nemotron_h, mistral4"
)

if len(layers) == 1:
Expand Down Expand Up @@ -134,6 +138,24 @@ class EagleConfig(PretrainedConfig):
r"^mtp\.": "model.",
},
},
"mistral4": {
"load_embedding_from_target": True,
"load_lm_head_from_target": True,
"num_capture_layers": 1,
# PyTorch backend captures post-norm hidden states for Mistral3/4
# (layers_to_capture={-1} captures after final RMSNorm). AutoDeploy
# captures at the residual add (pre-norm), so we normalize afterwards.
"normalize_target_hidden_state": True,
"layers_handle_final_norm": False,
# Mistral4 Eagle checkpoint (native Mistral format):
# eagle_linear.weight [hidden, 2*hidden] -> model.layers.0.eagle_proj.weight
# layers.* -> model.layers.*
# norm.weight stays as-is (maps to EagleDrafterForCausalLM.norm)
"_checkpoint_conversion_mapping": {
r"^eagle_linear": "model.layers.0.eagle_proj",
r"^layers": "model.layers",
},
},
}
# Some custom HF config classes expose backward-compatibility fields as properties instead of
# storing them directly in __dict__. Those values do not survive config.to_dict(), so carry
Expand Down Expand Up @@ -495,7 +517,7 @@ def __init__(self, config, layers: Union[nn.ModuleList, nn.Module]):
self.embed_tokens = (
None
if load_embedding_from_target
else nn.Embedding(config.vocab_size, config.hidden_size)
else nn.Embedding(config.vocab_size, config.hidden_size, dtype=self.dtype)
)

# Vocab mapping for draft -> target token conversion
Expand Down Expand Up @@ -585,9 +607,9 @@ class EagleDrafterForCausalLM(PreTrainedModel):

base_model_prefix = "model"
supports_gradient_checkpointing = False
_no_split_modules = ["LlamaEagleLayer", "NemotronHEagleLayer"]
_no_split_modules = ["LlamaEagleLayer", "NemotronHEagleLayer", "Mistral4EagleLayer"]

def __init__(self, config, layers: Optional[Union[nn.ModuleList, nn.Module]] = None):
def __init__(self, config, layers: Optional[Union[nn.ModuleList, nn.Module]] = None, **kwargs):
super().__init__(config)

# Read checkpoint conversion mapping from config (set by EagleConfig based on model_type)
Expand Down Expand Up @@ -803,7 +825,13 @@ def apply_draft_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
def apply_lm_head(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply lm_head to get logits from hidden states.

    When the lm_head is borrowed from the target model, its weight may have been
    converted in place for the FX graph (e.g. to FP8 by
    quantize_fp8_linear_from_config); this eager-mode call therefore casts the
    weight back to the hidden-state dtype before the matmul.
    """
    if self.load_lm_head_from_target:
        lm_head = self.target_model.get_output_embeddings()
        # Cast weight to hidden_states dtype: quantize_fp8_linear_from_config may have
        # converted the lm_head weight to FP8 in-place for the FX graph, but apply_lm_head
        # calls it as a plain nn.Linear outside the graph.
        lm_head_weights = torch.nn.functional.linear(
            hidden_states, lm_head.weight.to(hidden_states.dtype), lm_head.bias
        )
        return lm_head_weights.to(self._draft_dtype)
    else:
        return self.draft_model.get_output_embeddings()(hidden_states)
Expand Down Expand Up @@ -886,8 +914,24 @@ def _forward_prefill_only(self, input_ids: torch.Tensor, position_ids: torch.Ten

@staticmethod
def _filter_kwargs_for_submodule(kwargs: dict, submodule: nn.Module) -> dict:
"""Filter kwargs to only include those accepted by submodule's forward (GraphModule)."""
expected_names = {node.name for node in submodule.graph.nodes if node.op == "placeholder"}
"""Filter kwargs to only include those accepted by submodule's forward (GraphModule).

Graph transforms (KV cache insertion, sharding, etc.) add placeholder nodes to the
exported GraphModule. The placeholder names are the authoritative set of kwargs that
the submodule's forward accepts at inference time — all cache / attention metadata
belongs to the inner GraphModule, not to any eager wrapper around it.

For VLM targets (e.g., Mistral3ForConditionalGenerationAD wrapping Mistral4ForCausalLM),
only the language model is exported to a GraphModule while the outer wrapper stays in
eager mode. We walk direct children to locate the inner GraphModule in that case.
"""
gm = submodule
if not isinstance(gm, GraphModule):
for child in submodule.children():
if isinstance(child, GraphModule):
gm = child
break
expected_names = {node.name for node in gm.graph.nodes if node.op == "placeholder"}
return {k: v for k, v in kwargs.items() if k in expected_names}

@staticmethod
Expand Down
Loading
Loading