From 4dc191e679be9d07920a0466817f84321d665b70 Mon Sep 17 00:00:00 2001 From: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> Date: Fri, 3 Apr 2026 19:08:58 -0700 Subject: [PATCH] [None][feat] add Mistral4 Eagle speculative decoding to AutoDeploy Add Eagle one-model speculative decoding support for Mistral-Small-4-119B: - Mistral4-specific Eagle layer (Mistral4EagleMLA, Mistral4EagleMLP, Mistral4EagleDecoderLayer) with dispatch table entry - EagleOneModelFactory gains TargetModelExportInfo/DraftModelExportInfo with scalar-sentinel DCE prevention matching hf.py pattern - Factory delegation so Mistral4's custom target factory is used - Mistral3 wrapper delegation methods (get_input/output_embeddings, config) - FP8 checkpoint loading improvements for Mistral4 - Hidden-state capture guard against double-apply - MLA RoPE deinterleave hook ordering fix for Eagle path - LlmArgs Eagle config defaults and MTP one-model routing Configs: mistral4_eagle_119b.yaml (8-GPU), mistral_small_4_119b_eagle.yaml, mistral_small_4_119b_lite.yaml, mistral_small_4_119b_torch_mla.yaml Tests: hierarchical unit tests (AD ops vs PyTorch reference), hidden-state capture detection, torch.export + AD pipeline integration, 1-layer and 3-layer Eagle one-model E2E smoke, layer subgraph debug, framework-level spec-dec config/KV-cache tests. 
Signed-off-by: Govind Ramnarayan Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com> --- .../configs/mistral4_eagle_119b.yaml | 22 + .../configs/mistral_small_4_119b_eagle.yaml | 15 + .../configs/mistral_small_4_119b_lite.yaml | 12 + .../mistral_small_4_119b_torch_mla.yaml | 12 + tensorrt_llm/_torch/auto_deploy/llm_args.py | 38 +- .../models/custom/modeling_eagle.py | 58 +- .../models/custom/modeling_mistral3.py | 335 ++++++++++ .../_torch/auto_deploy/models/eagle.py | 162 ++++- tensorrt_llm/_torch/auto_deploy/models/hf.py | 56 +- .../transform/library/hidden_states.py | 5 +- .../_torch/auto_deploy/utils/node_utils.py | 72 +- .../models/test_mistral4_eagle_modeling.py | 624 ++++++++++++++++++ .../test_mistral4_hidden_state_capture.py | 371 +++++++++++ .../_utils_test/_model_test_utils.py | 24 + .../singlegpu/models/test_mistral4_eagle.py | 367 ++++++++++ .../smoke/test_ad_speculative_decoding.py | 90 ++- .../smoke/test_layer_subgraph_debug.py | 232 +++++++ 17 files changed, 2420 insertions(+), 75 deletions(-) create mode 100644 examples/auto_deploy/model_registry/configs/mistral4_eagle_119b.yaml create mode 100644 examples/auto_deploy/model_registry/configs/mistral_small_4_119b_eagle.yaml create mode 100644 examples/auto_deploy/model_registry/configs/mistral_small_4_119b_lite.yaml create mode 100644 examples/auto_deploy/model_registry/configs/mistral_small_4_119b_torch_mla.yaml create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_eagle_modeling.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_hidden_state_capture.py create mode 100644 tests/unittest/auto_deploy/singlegpu/models/test_mistral4_eagle.py create mode 100644 tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py diff --git a/examples/auto_deploy/model_registry/configs/mistral4_eagle_119b.yaml b/examples/auto_deploy/model_registry/configs/mistral4_eagle_119b.yaml new file 
mode 100644 index 00000000000..3d9145d744d --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/mistral4_eagle_119b.yaml @@ -0,0 +1,22 @@ +# AutoDeploy config for serving Mistral-Small-4-119B with Eagle speculative decoding. +# The Eagle checkpoint (mistralai/Mistral-Small-4-119B-2603-eagle) is in native Mistral +# format (params.json + consolidated.safetensors); its config is loaded from the target model. +runtime: trtllm +compile_backend: torch-simple +model_factory: Mistral3ForConditionalGeneration +skip_loading_weights: false +max_seq_len: 512 +world_size: 8 +tokenizer: mistralai/Mistral-Small-4-119B-2603 +transforms: + insert_cached_mla_attention: + backend: torch_mla +speculative_config: + decoding_type: Eagle3 + max_draft_len: 3 + speculative_model: mistralai/Mistral-Small-4-119B-2603-eagle + eagle3_one_model: true + eagle3_model_arch: mistral_large3 + eagle3_layers_to_capture: [-1] +speculative_model_kwargs: + num_hidden_layers: 2 diff --git a/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_eagle.yaml b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_eagle.yaml new file mode 100644 index 00000000000..e43e291daae --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_eagle.yaml @@ -0,0 +1,15 @@ +# Config for Mistral Small 4 119B with Eagle3 speculative decoding. 
+compile_backend: torch-simple +model_factory: Mistral3ForConditionalGeneration +tokenizer: mistralai/Mistral-Small-4-119B-2603 +max_seq_len: 512 +world_size: 8 +speculative_config: + decoding_type: Eagle3 + max_draft_len: 3 + speculative_model: mistralai/Mistral-Small-4-119B-2603-eagle + eagle3_one_model: true + eagle3_model_arch: mistral_large3 +transforms: + insert_cached_mla_attention: + backend: torch_mla diff --git a/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_lite.yaml b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_lite.yaml new file mode 100644 index 00000000000..885ca859248 --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_lite.yaml @@ -0,0 +1,12 @@ +runtime: trtllm +compile_backend: torch-simple +model_factory: Mistral3ForConditionalGeneration +skip_loading_weights: false +max_seq_len: 512 +world_size: 8 +tokenizer: mistralai/Mistral-Small-4-119B-2603 +transforms: + insert_cached_mla_attention: + backend: torch_mla +model_kwargs: + num_hidden_layers: 5 diff --git a/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_torch_mla.yaml b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_torch_mla.yaml new file mode 100644 index 00000000000..e94e77922b5 --- /dev/null +++ b/examples/auto_deploy/model_registry/configs/mistral_small_4_119b_torch_mla.yaml @@ -0,0 +1,12 @@ +# Standalone AutoDeploy config for Mistral Small 4 119B using torch MLA. 
+runtime: trtllm +attn_backend: trtllm +compile_backend: torch-simple +model_factory: AutoModelForImageTextToText +skip_loading_weights: false +max_seq_len: 512 +world_size: 8 +tokenizer: tensorrt_llm/_torch/auto_deploy/tokenizers/mistral_small_4_119b +transforms: + insert_cached_mla_attention: + backend: torch_mla diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 99e97aaf9d0..b8bb71c03d5 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -116,6 +116,11 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A @model_validator(mode="after") def setup_hidden_state_capture(self): + """Enable the hidden state capture transform if the speculative config requires it. + + This validator only configures transforms — factory selection is handled by + create_factory() to avoid mutating model_factory in place. + """ spec_config = self.speculative_config if spec_config is None: return self @@ -131,7 +136,6 @@ def setup_hidden_state_capture(self): "enabled. Ensure num_nextn_predict_layers is set in the model config." ) capture_layers = {-1} - self.model_factory = "eagle_one_model" elif isinstance(spec_config, EagleDecodingConfig): if spec_config.max_draft_len is None: raise ValueError( @@ -139,8 +143,8 @@ def setup_hidden_state_capture(self): "Provide a positive integer for max_draft_len." 
) capture_layers = spec_config.eagle3_layers_to_capture - if spec_config.eagle3_one_model: - self.model_factory = "eagle_one_model" + if not spec_config.eagle3_one_model: + return self else: return self @@ -351,22 +355,38 @@ def update_cuda_graph_batch_sizes(self): return self ### UTILITY METHODS ############################################################################ + def _requires_eagle_one_model(self) -> bool: + """Check if the speculative config requires Eagle one-model factory.""" + spec_config = self.speculative_config + if spec_config is None: + return False + if isinstance(spec_config, MTPDecodingConfig): + return spec_config.mtp_eagle_one_model + if isinstance(spec_config, EagleDecodingConfig): + return spec_config.eagle3_one_model + return False + def create_factory(self) -> ModelFactory: """Create a model factory from the arguments.""" - - # TODO (lucaslie): consider supporting Path objects in the model factory - return ModelFactoryRegistry.get(self.model_factory)( + common_kwargs = dict( model=str(self.model), model_kwargs=self.model_kwargs, tokenizer=None if self.tokenizer is None else str(self.tokenizer), tokenizer_kwargs=self.tokenizer_kwargs, skip_loading_weights=self.skip_loading_weights, max_seq_len=self.max_seq_len, - # Extra kwargs consumed by EagleOneModelFactory (ignored by others via **kwargs) - speculative_config=self.speculative_config, - speculative_model_kwargs=self.speculative_model_kwargs or None, ) + if self._requires_eagle_one_model(): + return ModelFactoryRegistry.get("eagle_one_model")( + **common_kwargs, + speculative_config=self.speculative_config, + speculative_model_kwargs=self.speculative_model_kwargs or None, + target_factory_cls_name=self.model_factory, + ) + + return ModelFactoryRegistry.get(self.model_factory)(**common_kwargs) + def is_cuda_graph_enabled(self) -> bool: return self.compile_backend in ["torch-cudagraph", "torch-opt"] diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py 
b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py index 28f553dc148..5e8ad686102 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py @@ -33,6 +33,7 @@ import torch import torch.nn as nn +from torch.fx import GraphModule from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.utils import ModelOutput @@ -41,6 +42,7 @@ from ...shim.interface import CachedSequenceInterface from ...utils._config import deep_merge_dicts from ...utils.logger import ad_logger +from .modeling_mistral3 import build_mistral4_eagle_layers from .modeling_nemotron_h import build_nemotron_eagle_layers # ============================================================================= @@ -72,10 +74,12 @@ def get_eagle_layers(config, model_type: str) -> Union[nn.ModuleList, nn.Module] layers = build_llama_eagle_layers(config) case "nemotron_h": layers = build_nemotron_eagle_layers(config) + case "mistral4": + layers = build_mistral4_eagle_layers(config) case _: raise ValueError( f"Model type '{model_type}' not supported for Eagle drafter. " - f"Supported types: llama, nemotron_h" + f"Supported types: llama, nemotron_h, mistral4" ) if len(layers) == 1: @@ -134,6 +138,24 @@ class EagleConfig(PretrainedConfig): r"^mtp\.": "model.", }, }, + "mistral4": { + "load_embedding_from_target": True, + "load_lm_head_from_target": True, + "num_capture_layers": 1, + # PyTorch backend captures post-norm hidden states for Mistral3/4 + # (layers_to_capture={-1} captures after final RMSNorm). AutoDeploy + # captures at the residual add (pre-norm), so we normalize afterwards. 
+ "normalize_target_hidden_state": True, + "layers_handle_final_norm": False, + # Mistral4 Eagle checkpoint (native Mistral format): + # eagle_linear.weight [hidden, 2*hidden] -> model.layers.0.eagle_proj.weight + # layers.* -> model.layers.* + # norm.weight stays as-is (maps to EagleDrafterForCausalLM.norm) + "_checkpoint_conversion_mapping": { + r"^eagle_linear": "model.layers.0.eagle_proj", + r"^layers": "model.layers", + }, + }, } # Some custom HF config classes expose backward-compatibility fields as properties instead of # storing them directly in __dict__. Those values do not survive config.to_dict(), so carry @@ -495,7 +517,7 @@ def __init__(self, config, layers: Union[nn.ModuleList, nn.Module]): self.embed_tokens = ( None if load_embedding_from_target - else nn.Embedding(config.vocab_size, config.hidden_size) + else nn.Embedding(config.vocab_size, config.hidden_size, dtype=self.dtype) ) # Vocab mapping for draft -> target token conversion @@ -585,9 +607,9 @@ class EagleDrafterForCausalLM(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = False - _no_split_modules = ["LlamaEagleLayer", "NemotronHEagleLayer"] + _no_split_modules = ["LlamaEagleLayer", "NemotronHEagleLayer", "Mistral4EagleLayer"] - def __init__(self, config, layers: Optional[Union[nn.ModuleList, nn.Module]] = None): + def __init__(self, config, layers: Optional[Union[nn.ModuleList, nn.Module]] = None, **kwargs): super().__init__(config) # Read checkpoint conversion mapping from config (set by EagleConfig based on model_type) @@ -803,7 +825,13 @@ def apply_draft_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: def apply_lm_head(self, hidden_states: torch.Tensor) -> torch.Tensor: """Apply lm_head to get logits from hidden states.""" if self.load_lm_head_from_target: - lm_head_weights = self.target_model.get_output_embeddings()(hidden_states) + lm_head = self.target_model.get_output_embeddings() + # Cast weight to hidden_states dtype: 
quantize_fp8_linear_from_config may have + # converted the lm_head weight to FP8 in-place for the FX graph, but apply_lm_head + # calls it as a plain nn.Linear outside the graph. + lm_head_weights = torch.nn.functional.linear( + hidden_states, lm_head.weight.to(hidden_states.dtype), lm_head.bias + ) return lm_head_weights.to(self._draft_dtype) else: return self.draft_model.get_output_embeddings()(hidden_states) @@ -886,8 +914,24 @@ def _forward_prefill_only(self, input_ids: torch.Tensor, position_ids: torch.Ten @staticmethod def _filter_kwargs_for_submodule(kwargs: dict, submodule: nn.Module) -> dict: - """Filter kwargs to only include those accepted by submodule's forward (GraphModule).""" - expected_names = {node.name for node in submodule.graph.nodes if node.op == "placeholder"} + """Filter kwargs to only include those accepted by submodule's forward (GraphModule). + + Graph transforms (KV cache insertion, sharding, etc.) add placeholder nodes to the + exported GraphModule. The placeholder names are the authoritative set of kwargs that + the submodule's forward accepts at inference time — all cache / attention metadata + belongs to the inner GraphModule, not to any eager wrapper around it. + + For VLM targets (e.g., Mistral3ForConditionalGenerationAD wrapping Mistral4ForCausalLM), + only the language model is exported to a GraphModule while the outer wrapper stays in + eager mode. We walk direct children to locate the inner GraphModule in that case. 
+ """ + gm = submodule + if not isinstance(gm, GraphModule): + for child in submodule.children(): + if isinstance(child, GraphModule): + gm = child + break + expected_names = {node.name for node in gm.graph.nodes if node.op == "placeholder"} return {k: v for k, v in kwargs.items() if k in expected_names} @staticmethod diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py index 6b3763b0a35..cbb8f11f1ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_mistral3.py @@ -661,6 +661,9 @@ def get_output_embeddings(self): def get_decoder(self): return self.model + def get_final_normalization(self): + return self.model.norm + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -714,6 +717,12 @@ def __init__(self, config: Mistral3Config, **kwargs): def get_input_embeddings(self): return self.language_model.get_input_embeddings() + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def get_final_normalization(self): + return self.language_model.get_final_normalization() + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -861,3 +870,329 @@ def init_processor(self) -> Optional[Any]: Mistral3ForConditionalGenerationFactory.register_custom_model_cls( "Mistral3Config", Mistral3ForConditionalGenerationAD ) + + +# ============================================================================= +# Eagle Layer Builder for Mistral4 (Eagle speculative decoding) +# ============================================================================= + + +class Mistral4EagleMLP(nn.Module): + """Dense SwiGLU MLP for Mistral4 Eagle drafter. + + Uses w1/w2/w3 parameter names to match the native Mistral Eagle checkpoint. + Applies FP8 dequantization if the checkpoint contains FP8-quantized weights. + In SwiGLU convention: w1=gate, w3=up, w2=down. 
+ """ + + def __init__(self, config: Mistral4TextConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + dtype = getattr(config, "torch_dtype", None) + self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype) + self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype) + self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype) + self.act_fn = ACT2FN[config.hidden_act] + self._register_load_state_dict_pre_hook(self._dequantize_fp8_weights) + + def _dequantize_fp8_weights(self, state_dict, prefix, *args): + """Dequantize FP8-quantized MLP weights (w1/w2/w3) from checkpoint.""" + for proj in ("w1", "w2", "w3"): + weight_key = prefix + f"{proj}.weight" + scale_key = prefix + f"{proj}.qscale_weight" + act_scale_key = prefix + f"{proj}.qscale_act" + if weight_key not in state_dict or scale_key not in state_dict: + state_dict.pop(act_scale_key, None) + continue + weight = state_dict[weight_key] + if weight.dtype not in {torch.float8_e4m3fn, torch.float8_e5m2}: + state_dict.pop(act_scale_key, None) + continue + scale = state_dict.pop(scale_key).to(torch.float32) + target_dtype = getattr(self, proj).weight.dtype + state_dict[weight_key] = weight.to(torch.float32).mul(scale).to(target_dtype) + state_dict.pop(act_scale_key, None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + +class Mistral4EagleMLA(nn.Module): + """MLA attention for Mistral4 Eagle drafter. + + Reuses Mistral4YarnRotaryEmbedding and torch_mla, mirroring Mistral4Attention. + Weight names (wq_a, wq_b, wkv_a_with_mqa, wkv_b, q_a_norm, kv_a_norm, wo) + match the native Mistral Eagle checkpoint. 
+ """ + + def __init__(self, config: Mistral4TextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.softmax_scale = self.q_head_dim ** (-0.5) + self.rope_theta = config.rope_parameters.get("rope_theta", 10000.0) + + # Apply YARN magnitude scale to softmax if configured + rope_scaling = config.rope_scaling + if rope_scaling is not None: + scale = rope_scaling.get("factor", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) + if mscale_all_dim: + yarn_scale = Mistral4YarnRotaryEmbedding._yarn_get_mscale(scale, mscale_all_dim) + self.softmax_scale = self.softmax_scale * yarn_scale * yarn_scale + + rms_eps = config.rms_norm_eps + # Priority: Eagle config's own 'dtype' → Eagle config's own 'torch_dtype' (HF standard, + # deprecated) → None. The outer multimodal config's dtype is propagated into 'dtype' + # by EagleDrafterFactory._get_model_config when neither field is present. 
+ dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None) + + # Projection layers — named to match Eagle checkpoint keys + self.wq_a = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False, dtype=dtype) + self.q_a_norm = Mistral4RMSNorm(self.q_lora_rank, eps=rms_eps) + self.wq_b = nn.Linear( + self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False, dtype=dtype + ) + self.wkv_a_with_mqa = nn.Linear( + self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, dtype=dtype + ) + self.kv_a_norm = Mistral4RMSNorm(self.kv_lora_rank, eps=rms_eps) + # wkv_b is absorbed into torch_mla; must be dequantized from FP8 before absorption. + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + dtype=dtype, + ) + self.wo = nn.Linear( + self.num_heads * self.v_head_dim, self.hidden_size, bias=False, dtype=dtype + ) + + # RoPE deinterleave hook (same native Mistral format as base model: rope_interleave=True) + if getattr(config, "rope_interleave", True): + self._register_load_state_dict_pre_hook(self._rope_deinterleave_load_hook) + + # FP8 dequant hooks: wkv_b first (absorbed), then remaining weights + self._register_load_state_dict_pre_hook(self._dequantize_fp8_wkv_b) + self._register_load_state_dict_pre_hook(self._dequantize_fp8_weights) + + self._init_rope() + + def _init_rope(self) -> None: + rope_scaling = self.config.rope_scaling + max_pos = self.config.max_position_embeddings + if ( + rope_scaling is None + or rope_scaling.get("type", rope_scaling.get("rope_type")) != "yarn" + ): + self.rotary_emb = Mistral4RotaryEmbedding( + self.qk_rope_head_dim, max_pos, self.rope_theta + ) + else: + self.rotary_emb = Mistral4YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_pos, + self.rope_theta, + rope_scaling["factor"], + rope_scaling.get("original_max_position_embeddings", 8192), + rope_scaling.get("beta_fast", 32.0), + rope_scaling.get("beta_slow", 1.0), + 
rope_scaling.get("mscale", 1.0), + rope_scaling.get("mscale_all_dim", 1.0), + ) + + def _rope_deinterleave_load_hook(self, state_dict, prefix, *args): + """Permute RoPE weight columns from interleaved to non-interleaved layout. + + The native Mistral checkpoint stores RoPE dims in interleaved order + (0, 2, 4, ..., 1, 3, 5, ...) inside wq_b and wkv_a_with_mqa. This + hook applies the same permutation as mla_rope_utils._rope_deinterleave_load_hook + but targets the Eagle-specific parameter names. + """ + d = self.qk_rope_head_dim + perm = torch.cat([torch.arange(0, d, 2), torch.arange(1, d, 2)]) + + # wq_b.weight: [num_heads * q_head_dim, q_lora_rank] + wq_b_key = prefix + "wq_b.weight" + if wq_b_key in state_dict: + w = state_dict[wq_b_key] + w = w.view(self.num_heads, self.q_head_dim, -1) + w_nope = w[:, : self.qk_nope_head_dim, :] + w_rope = w[:, self.qk_nope_head_dim :, :] + w_rope = mla_rope_utils._index_select_with_float8_cpu_workaround(w_rope, 1, perm) + state_dict[wq_b_key] = torch.cat([w_nope, w_rope], dim=1).view(-1, w.shape[-1]) + + # wkv_a_with_mqa.weight: [kv_lora_rank + qk_rope_head_dim, hidden_size] + wkv_key = prefix + "wkv_a_with_mqa.weight" + if wkv_key in state_dict: + w = state_dict[wkv_key] + w_kv = w[: self.kv_lora_rank, :] + w_pe = w[self.kv_lora_rank :, :] + w_pe = mla_rope_utils._index_select_with_float8_cpu_workaround(w_pe, 0, perm) + state_dict[wkv_key] = torch.cat([w_kv, w_pe], dim=0) + + def _dequantize_fp8_wkv_b(self, state_dict, prefix, *args): + """Dequantize wkv_b from FP8 before torch_mla absorbs it.""" + weight_key = prefix + "wkv_b.weight" + scale_key = prefix + "wkv_b.qscale_weight" + act_scale_key = prefix + "wkv_b.qscale_act" + if weight_key not in state_dict or scale_key not in state_dict: + state_dict.pop(act_scale_key, None) + return + weight = state_dict[weight_key] + if weight.dtype not in {torch.float8_e4m3fn, torch.float8_e5m2}: + state_dict.pop(act_scale_key, None) + return + target_dtype = self.wkv_b.weight.dtype + 
scale = state_dict.pop(scale_key).to(torch.float32) + state_dict[weight_key] = weight.to(torch.float32).mul(scale).to(target_dtype) + state_dict.pop(act_scale_key, None) + + def _dequantize_fp8_weights(self, state_dict, prefix, *args): + """Dequantize FP8-quantized attention projection weights from checkpoint.""" + for proj in ("wq_a", "wq_b", "wkv_a_with_mqa", "wo"): + weight_key = prefix + f"{proj}.weight" + scale_key = prefix + f"{proj}.qscale_weight" + act_scale_key = prefix + f"{proj}.qscale_act" + if weight_key not in state_dict or scale_key not in state_dict: + state_dict.pop(act_scale_key, None) + continue + weight = state_dict[weight_key] + if weight.dtype not in {torch.float8_e4m3fn, torch.float8_e5m2}: + state_dict.pop(act_scale_key, None) + continue + scale = state_dict.pop(scale_key).to(torch.float32) + target_dtype = getattr(self, proj).weight.dtype + state_dict[weight_key] = weight.to(torch.float32).mul(scale).to(target_dtype) + state_dict.pop(act_scale_key, None) + + def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + q = self.wq_b(self.q_a_norm(self.wq_a(hidden_states))) + q = q.view(batch_size, seq_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv = self.wkv_a_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + compressed_kv = self.kv_a_norm(compressed_kv) + k_pe = k_pe.view(batch_size, seq_len, 1, self.qk_rope_head_dim) + + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + cos = cos[position_ids] + sin = sin[position_ids] + q_pe_rotated, k_pe_rotated = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin( + q_pe, + k_pe, + cos, + sin, + 2, + ) + + attn_output = torch.ops.auto_deploy.torch_mla( + q_nope, + q_pe_rotated, + compressed_kv, + k_pe_rotated, + self.wkv_b.weight, + True, + 
self.softmax_scale, + "bsnd", + ) + attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.v_head_dim) + return self.wo(attn_output) + + +class Mistral4EagleLayer(nn.Module): + """Transformer layer for Mistral4 Eagle drafter. + + Architecture: + - First layer (has_eagle_proj=True): projects cat([inputs_embeds, hidden_states]) + via eagle_proj (Linear: 2*hidden → hidden) to produce initial hidden_states. + - attention_norm → MLA attention → residual add + - ffn_norm → dense SwiGLU MLP → residual add + """ + + def __init__(self, config: Mistral4TextConfig, layer_idx: int, has_eagle_proj: bool = False): + super().__init__() + self.layer_idx = layer_idx + rms_eps = config.rms_norm_eps + dtype = getattr(config, "torch_dtype", None) + + # Eagle projection fuses inputs_embeds + hidden_states on the first layer + if has_eagle_proj: + self.eagle_proj = nn.Linear( + 2 * config.hidden_size, config.hidden_size, bias=False, dtype=dtype + ) + self._register_load_state_dict_pre_hook(self._dequantize_fp8_eagle_proj) + else: + self.eagle_proj = None + + self.attention_norm = Mistral4RMSNorm(config.hidden_size, eps=rms_eps) + self.attention = Mistral4EagleMLA(config, layer_idx) + self.ffn_norm = Mistral4RMSNorm(config.hidden_size, eps=rms_eps) + self.feed_forward = Mistral4EagleMLP(config) + + def _dequantize_fp8_eagle_proj(self, state_dict, prefix, *args): + """Dequantize eagle_proj from FP8 if present in checkpoint.""" + weight_key = prefix + "eagle_proj.weight" + scale_key = prefix + "eagle_proj.qscale_weight" + act_scale_key = prefix + "eagle_proj.qscale_act" + if weight_key not in state_dict or scale_key not in state_dict: + state_dict.pop(act_scale_key, None) + return + weight = state_dict[weight_key] + if weight.dtype not in {torch.float8_e4m3fn, torch.float8_e5m2}: + state_dict.pop(act_scale_key, None) + return + target_dtype = self.eagle_proj.weight.dtype + scale = state_dict.pop(scale_key).to(torch.float32) + state_dict[weight_key] = 
weight.to(torch.float32).mul(scale).to(target_dtype) + state_dict.pop(act_scale_key, None) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + """Forward pass with unified Eagle interface. + + Args: + hidden_states: Hidden states from target model [batch, seq, hidden_size] + inputs_embeds: Token embeddings [batch, seq, hidden_size] + position_ids: Position IDs for RoPE [batch, seq] + + Returns: + Updated hidden states [batch, seq, hidden_size] + """ + if self.eagle_proj is not None: + hidden_states = self.eagle_proj(torch.cat([inputs_embeds, hidden_states], dim=-1)) + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + hidden_states = self.attention(hidden_states, position_ids) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + return residual + hidden_states + + +def build_mistral4_eagle_layers(config) -> list[nn.Module]: + """Build Mistral4 Eagle transformer layers. + + Called by get_eagle_layers() in modeling_eagle.py when model_type == "mistral4". + """ + return [ + Mistral4EagleLayer(config, layer_idx=i, has_eagle_proj=(i == 0)) + for i in range(config.num_hidden_layers) + ] diff --git a/tensorrt_llm/_torch/auto_deploy/models/eagle.py b/tensorrt_llm/_torch/auto_deploy/models/eagle.py index 9dda8de7da4..a54a92a8001 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/eagle.py +++ b/tensorrt_llm/_torch/auto_deploy/models/eagle.py @@ -21,6 +21,7 @@ Eagle speculative decoding. 
""" +import operator import types from contextlib import nullcontext from typing import Any, Dict, List, Optional @@ -31,8 +32,10 @@ from torch._prims_common import DeviceLikeType from torch.export import Dim from torch.fx import GraphModule +from transformers import AutoConfig from ....llmapi.llm_args import MTPDecodingConfig +from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .custom.modeling_eagle import ( EagleConfig, @@ -54,8 +57,85 @@ class EagleDrafterFactory(AutoModelForCausalLMFactory): The checkpoint config is expected to have the base model's model_type (e.g., "llama") along with Eagle-specific fields like draft_vocab_size. + + Args: + config_model: Optional path/name of a model whose HF config should be + used as the architecture config instead of the Eagle checkpoint's own + config. Useful when the Eagle checkpoint is in native (non-HF) format + and lacks a standard config.json (e.g., Mistral4 Eagle). """ + def __init__( + self, + model: str, + config_model: Optional[str] = None, + **kwargs, + ): + super().__init__(model=model, **kwargs) + self._config_model = config_model + + def _get_model_config(self): + # Prefetch Eagle checkpoint so weights are available for later loading. + self.prefetch_checkpoint(skip_loading_weights=True) + # Load architecture config from config_model (e.g. target model) if provided; + # otherwise fall back to the Eagle checkpoint path itself. + config_source = self._config_model if self._config_model is not None else self.model + model_config, unused = AutoConfig.from_pretrained( + config_source, return_unused_kwargs=True, trust_remote_code=True + ) + # For multimodal target models (e.g. Mistral3Config wrapping Mistral4TextConfig), + # extract the inner text config so that model_type reflects the text backbone + # (e.g. 'mistral4') rather than the outer wrapper (e.g. 'mistral3'). 
+ if hasattr(model_config, "text_config") and model_config.text_config is not None: + ad_logger.info( + f"EagleDrafterFactory: extracting text_config from multimodal config " + f"(outer model_type='{model_config.model_type}')" + ) + # The inner text_config may not carry the compute dtype set on the outer wrapper. + # Extract it from either 'dtype' or 'torch_dtype' (deprecated), normalizing to + # a torch.dtype object, then propagate to text_config.dtype. + outer_dtype: Optional[torch.dtype] = None + for dtype_key in ("dtype", "torch_dtype"): + val = getattr(model_config, dtype_key, None) + if val is not None: + if isinstance(val, str) and val != "auto": + val = getattr(torch, val) + assert isinstance(val, torch.dtype), f"Invalid dtype string: {val}" + if isinstance(val, torch.dtype): + outer_dtype = val + break + model_config = model_config.text_config + # Only fall back to the outer dtype when the text_config has neither field set — + # any explicitly set Eagle-model dtype/torch_dtype takes priority over the outer one. + if ( + outer_dtype is not None + and getattr(model_config, "dtype", None) is None + and getattr(model_config, "torch_dtype", None) is None + ): + model_config.dtype = outer_dtype + model_config, nested = self._recursive_update_config(model_config, self.model_kwargs or {}) + return model_config, deep_merge_dicts(unused, nested) + + def _get_checkpoint_file(self, checkpoint): + """Extend the standard checkpoint file search to include native Mistral format. + + Native Mistral checkpoints use ``consolidated.safetensors`` rather than the + HuggingFace-standard ``model.safetensors``. Fall back to the consolidated file + if none of the standard names are found. 
+ """ + try: + return super()._get_checkpoint_file(checkpoint) + except ValueError: + import os + + consolidated = os.path.join(str(checkpoint), "consolidated.safetensors") + if os.path.isfile(consolidated): + ad_logger.info( + f"Native-format Eagle checkpoint detected; loading from {consolidated}" + ) + return consolidated + raise + def _build_model(self, device: DeviceLikeType) -> nn.Module: model_config, unused_kwargs = self._get_model_config() @@ -99,8 +179,8 @@ def build_and_load_model(self, _device: DeviceLikeType) -> nn.Module: class TargetModelExportInfo(SubModuleExportInfo): """Export info for the target model inside EagleWrapper.""" - def __init__(self, load_lm_head_from_target: bool): - super().__init__("target_model") + def __init__(self, load_lm_head_from_target: bool, submodule_name: str = "target_model"): + super().__init__(submodule_name) self.load_lm_head_from_target = load_lm_head_from_target def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: @@ -111,11 +191,31 @@ def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: "position_ids": {0: batch_size_dyn, 1: seq_len_dyn}, } + @staticmethod + def _add_sticky_sentinel(sub_gm: GraphModule, attr_path: str) -> None: + """Insert a scalar-valued sentinel node so the submodule at attr_path is never DCE'd. + + Mirrors TextModelExportInfo in hf.py: we derive a scalar (num rows ≥ 0) from the + weight tensor rather than asserting on the tensor itself, which would fail for + non-scalar tensors under fake-tensor shape propagation. 
+ """ + output_node = next(node for node in sub_gm.graph.nodes if node.op == "output") + with sub_gm.graph.inserting_before(output_node): + n_weight = sub_gm.graph.get_attr(f"{attr_path}.weight") + n_rows = sub_gm.graph.call_function(torch.ops.aten.sym_size.int, args=(n_weight, 0)) + n_ok = sub_gm.graph.call_function(operator.ge, args=(n_rows, 0)) + sub_gm.graph.call_function( + torch._assert, args=(n_ok, f"Avoid {attr_path} getting deleted from graph.") + ) + def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): - """Preserve embedding (always) and optionally lm_head on the exported GraphModule.""" + """Preserve embedding (always) and optionally lm_head on the exported GraphModule. + + Follows the same pattern as TextModelExportInfo.post_process in hf.py: + __func__ binding + set_submodule + scalar sym_size sentinel. + """ # --- Embedding: always needed (target embeds input_ids for both target and draft) --- embed_tokens = sub_mod.get_input_embeddings() - # Find the submodule path for the embedding for embed_name, subsubmod in sub_mod.named_modules(): if subsubmod is embed_tokens: break @@ -123,13 +223,9 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): raise RuntimeError("Could not find embedding module in target model.") sub_gm.set_submodule(embed_name, embed_tokens) sub_gm.get_input_embeddings = types.MethodType( - lambda self, _n=embed_name: self.get_submodule(_n), sub_gm - ) - # Add impure node to prevent GC - n_embed = sub_gm.graph.get_attr(f"{embed_name}.weight") - sub_gm.graph.call_function( - torch._assert, args=(n_embed, "Avoid embedding getting deleted from graph.") + sub_mod.get_input_embeddings.__func__, sub_gm ) + self._add_sticky_sentinel(sub_gm, embed_name) # --- lm_head: only if draft model loads it from target --- if self.load_lm_head_from_target: @@ -141,12 +237,9 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): raise RuntimeError("Could not find lm_head module in target model.") 
sub_gm.set_submodule(lm_head_name, lm_head) sub_gm.get_output_embeddings = types.MethodType( - lambda self, _n=lm_head_name: self.get_submodule(_n), sub_gm - ) - n_lm_head = sub_gm.graph.get_attr(f"{lm_head_name}.weight") - sub_gm.graph.call_function( - torch._assert, args=(n_lm_head, "Avoid lm_head getting deleted from graph.") + sub_mod.get_output_embeddings.__func__, sub_gm ) + self._add_sticky_sentinel(sub_gm, lm_head_name) # --- Final normalization: only if target model exposes it (e.g., NemotronH for MTP) --- if hasattr(sub_mod, "get_final_normalization"): @@ -158,12 +251,9 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): raise RuntimeError("Could not find final normalization module in target model.") sub_gm.set_submodule(norm_name, norm_module) sub_gm.get_final_normalization = types.MethodType( - lambda self, _n=norm_name: self.get_submodule(_n), sub_gm - ) - n_norm = sub_gm.graph.get_attr(f"{norm_name}.weight") - sub_gm.graph.call_function( - torch._assert, args=(n_norm, "Avoid final norm getting deleted from graph.") + sub_mod.get_final_normalization.__func__, sub_gm ) + self._add_sticky_sentinel(sub_gm, norm_name) class DraftModelExportInfo(SubModuleExportInfo): @@ -204,8 +294,10 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): lambda self, _n=embed_name: self.get_submodule(_n), sub_gm ) n_embed = sub_gm.graph.get_attr(f"{embed_name}.weight") + n_rows = sub_gm.graph.call_function(torch.ops.aten.sym_size.int, args=(n_embed, 0)) + n_ok = sub_gm.graph.call_function(operator.ge, args=(n_rows, 0)) sub_gm.graph.call_function( - torch._assert, args=(n_embed, "Avoid draft embedding getting deleted.") + torch._assert, args=(n_ok, "Avoid draft embedding getting deleted.") ) # --- lm_head (only if draft model has its own) --- @@ -221,23 +313,23 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): lambda self, _n=lm_head_name: self.get_submodule(_n), sub_gm ) n_lm_head = 
sub_gm.graph.get_attr(f"{lm_head_name}.weight") + n_rows = sub_gm.graph.call_function(torch.ops.aten.sym_size.int, args=(n_lm_head, 0)) + n_ok = sub_gm.graph.call_function(operator.ge, args=(n_rows, 0)) sub_gm.graph.call_function( - torch._assert, args=(n_lm_head, "Avoid draft lm_head getting deleted.") + torch._assert, args=(n_ok, "Avoid draft lm_head getting deleted.") ) # --- fc module (fuses hidden states from multiple layers) --- fc_module = getattr(inner_model, "fc", None) if fc_module is not None: sub_gm.set_submodule("model.fc", fc_module) - n_fc = sub_gm.graph.get_attr("model.fc.weight") - sub_gm.graph.call_function(torch._assert, args=(n_fc, "Avoid fc getting deleted.")) + sub_gm.graph.get_attr("model.fc.weight") # --- d2t parameter (draft-to-target vocab mapping) --- d2t = getattr(inner_model, "d2t", None) if d2t is not None: inner_gm.register_parameter("d2t", d2t) - n_d2t = sub_gm.graph.get_attr("model.d2t") - sub_gm.graph.call_function(torch._assert, args=(n_d2t, "Avoid d2t getting deleted.")) + sub_gm.graph.get_attr("model.d2t") # --- model dtype (used by apply_eagle3_fc) --- model_dtype = getattr(inner_model, "dtype", None) @@ -268,6 +360,7 @@ def __init__( max_seq_len: int = 512, speculative_config: Any = None, speculative_model_kwargs: Optional[Dict[str, Any]] = None, + target_factory_cls_name: str = "AutoModelForCausalLM", **kwargs, ): super().__init__( @@ -291,8 +384,9 @@ def __init__( if draft_model_path is None: raise ValueError("speculative_config.speculative_model must be set.") - # Create target factory (AutoModelForCausalLM) - self.target_factory = AutoModelForCausalLMFactory( + # Create target factory using the configured factory class + target_factory_cls = ModelFactoryRegistry.get(target_factory_cls_name) + self.target_factory = target_factory_cls( model=model, model_kwargs=model_kwargs, tokenizer=tokenizer, @@ -301,9 +395,11 @@ def __init__( max_seq_len=max_seq_len, ) - # Create draft factory (EagleDrafter) + # Create draft factory 
(EagleDrafter), passing target model path as config_model + # so that drafters with native (non-HF) checkpoints can reuse the target's config. self.draft_factory = EagleDrafterFactory( model=str(draft_model_path), + config_model=model, model_kwargs=speculative_model_kwargs, tokenizer=tokenizer, skip_loading_weights=skip_loading_weights, @@ -352,9 +448,15 @@ def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: draft_config = model.draft_model.config load_embedding_from_target = getattr(draft_config, "load_embedding_from_target", True) load_lm_head_from_target = getattr(draft_config, "load_lm_head_from_target", True) + target_export_infos = self.target_factory.get_export_infos(model.target_model) + target_sub_name = target_export_infos[0].submodule_name if target_export_infos else "" + if target_sub_name: + target_submodule_name = f"target_model.{target_sub_name}" + else: + target_submodule_name = "target_model" return [ - TargetModelExportInfo(load_lm_head_from_target), + TargetModelExportInfo(load_lm_head_from_target, submodule_name=target_submodule_name), DraftModelExportInfo(load_embedding_from_target, load_lm_head_from_target), ] diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index e01ec6fe2c0..be4adcc8270 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -462,9 +462,41 @@ def _load_checkpoint_with_preload( self, model: nn.Module, ckpt_file: str, device: DeviceLikeType ): all_weights = self._load_full_checkpoint_to_cpu(ckpt_file) + model_keys = list(model.state_dict().keys()) + model_moe_keys = [key for key in model_keys if _MOE_EXPERT_KEY_RE.search(key) is not None] + if model_moe_keys: + ad_logger.info( + f"Model expects {len(model_moe_keys)} MoE expert keys spanning expert ids: " + f"{_summarize_moe_expert_keys(model_moe_keys)}" + ) ad_logger.info(f"Loading weights into model (device: {device})...") - 
model.load_state_dict(all_weights, strict=False) + incompatible = model.load_state_dict(all_weights, strict=False) + if incompatible.missing_keys or incompatible.unexpected_keys: + ad_logger.warning( + "Checkpoint load completed with " + f"{len(incompatible.missing_keys)} missing and " + f"{len(incompatible.unexpected_keys)} unexpected keys" + ) + if incompatible.missing_keys: + ad_logger.warning( + "Sample missing keys: " + ", ".join(sorted(incompatible.missing_keys)[:20]) + ) + if incompatible.unexpected_keys: + ad_logger.warning( + "Sample unexpected keys: " + + ", ".join(sorted(incompatible.unexpected_keys)[:20]) + ) + unexpected_moe_keys = [ + key + for key in incompatible.unexpected_keys + if _MOE_EXPERT_KEY_RE.search(key) is not None + ] + if unexpected_moe_keys: + ad_logger.warning( + f"Unexpected MoE expert keys: {len(unexpected_moe_keys)} spanning expert " + f"ids {_summarize_moe_expert_keys(unexpected_moe_keys)}" + ) ad_logger.info("Checkpoint loading completed") @@ -824,3 +856,25 @@ def _example_image_dims(self) -> Tuple[int, int]: def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: return [TextModelExportInfo.from_autoinferred(model)] + + +_MOE_EXPERT_KEY_RE = re.compile(r"\.mlp\.experts\.(\d+)\.") + + +def _summarize_moe_expert_keys(keys: List[str]) -> str: + expert_ids = sorted( + { + int(match.group(1)) + for key in keys + if (match := _MOE_EXPERT_KEY_RE.search(key)) is not None + } + ) + if not expert_ids: + return "none" + if len(expert_ids) <= 16: + return ",".join(str(idx) for idx in expert_ids) + return ( + ",".join(str(idx) for idx in expert_ids[:8]) + + " ... 
" + + ",".join(str(idx) for idx in expert_ids[-8:]) + ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py index 8c92d9e6f5a..6e918ba8f79 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py @@ -59,8 +59,9 @@ def cached_residual_add( t1: torch.Tensor, t2: torch.Tensor, hidden_states_cache: torch.Tensor ) -> torch.Tensor: ret = torch.ops.aten.add(t1, t2) - b, s, _ = ret.shape - num_tokens = b * s + # Support both 3D [batch, seq, hidden] (standard attention) and + # 2D [num_tokens, hidden] (flat MLA pipeline where the AD pipeline captures flat inputs). + num_tokens = ret.shape[0] if ret.dim() == 2 else ret.shape[0] * ret.shape[1] hidden_states_cache[:num_tokens].copy_(ret.view(num_tokens, -1), non_blocking=True) return ret diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 71f4bab4e52..a958c86d03d 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -854,43 +854,64 @@ def get_user_if_pattern_match(node, ops, numusers, user_idx: int = 0): def identify_regions_between_residuals(gm: GraphModule) -> List[Node]: - """Identify regions of the graph that we can investigate further for patterning matching. + """Identify regions of the graph that we can investigate further for pattern matching. - Right now, we split the regions according to the following structure: - 1. Input node - 2. Embedding node - 3. Residual nodes from the embedding node onwards (no other nodes in-between) + Returns boundary nodes that split the graph into regions: + 1. Seed placeholder (input_ids or inputs_embeds) + 2. Embedding node (if present — only when input_ids feeds into aten.embedding) + 3. Residual add nodes forming the skip-connection chain 4. 
Output node - The list will contain the boundary nodes between the regions. + Seed selection: + - If the graph contains an ``aten.embedding`` op, the placeholder feeding it + (``input_ids``) is used as the seed, and the embedding node is the first + boundary after the seed. + - Otherwise, the ``inputs_embeds`` placeholder is used directly as the seed. + This handles models exported with pre-embedded inputs (e.g., Eagle targets + where a CausalLM is exported with ``inputs_embeds`` provided and the + ``input_ids`` → embedding branch is not traced, leaving ``input_ids`` as + an unused placeholder in the graph). """ assert gm.graph.nodes, "Graph is empty" - # get first input node and last output node - input_id_node = None + # Collect all placeholders and the output node. + placeholders = [] output_node = None for node in gm.graph.nodes: - if input_id_node is None and node.op == "placeholder": - input_id_node = node + if node.op == "placeholder": + placeholders.append(node) if node.op == "output": output_node = node - assert input_id_node, "Could not find input node" + assert placeholders, "Could not find input node" assert output_node, "Could not find output node" - # start list of boundary nodes - boundary_nodes = [input_id_node] - - # find embedding node which we assume to be the first node in a sequence of residual nodes - for n_user in input_id_node.users: - if is_op(n_user, torch.ops.aten.embedding): - break + # Find the right seed placeholder for the residual chain. + # 1. Look for an aten.embedding op — its input placeholder (input_ids) is the seed. + # 2. Otherwise, use the "inputs_embeds" placeholder by name. + seed_node = None + embed_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.embedding.default) + embed_node = embed_nodes[0] if embed_nodes else None + if embed_node is not None: + # aten.embedding(weight, indices, ...) — args[1] is the input_ids placeholder. 
+ seed_node = embed_node.args[1] else: - # we could not identify any boundary regions via embedding nodes - boundary_nodes.append(output_node) - return boundary_nodes + # No embedding — use the inputs_embeds placeholder by name. + placeholder_by_name = {ph.name: ph for ph in placeholders} + seed_node = placeholder_by_name.get("inputs_embeds") + + if seed_node is None: + # No usable seed found — return minimal boundary list so callers + # (e.g., get_all_layer_subgraphs) gracefully find no layers. + ad_logger.debug( + f"Could not find residual chain seed: no aten.embedding op and no " + f"'inputs_embeds' placeholder. Placeholders: {[ph.name for ph in placeholders]}" + ) + return [placeholders[0], output_node] - # add embedding node to boundary nodes - boundary_nodes.append(n_user) + # start list of boundary nodes + boundary_nodes = [seed_node] + if embed_node is not None: + boundary_nodes.append(embed_node) # find residual nodes from here on while True: @@ -1384,8 +1405,9 @@ def filter_condition(node: Node, dim: int) -> bool: layer_type=LayerType.UNKNOWN, ) + src_node = linear_nodes[start_lin_index] forward_subgraph = subgraph( - sources=[linear_nodes[start_lin_index]], + sources=[src_node], boundary_condition=lambda n: boundary_condition(n, dim=0), ) lin_nodes_in_subgraph = list( @@ -1528,7 +1550,7 @@ def classify_layer_type() -> [LayerType, int]: min_local_shape=head_size, ) assert linear_nodes[start_lin_index] in opening_linear_nodes, ( - f"Linear node not found in opening linear nodes - " + f"Linear node ({linear_nodes[start_lin_index].name}) not found in opening linear nodes - " f"terminating_linear_node:{terminating_linear_node.name}, " f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}" ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_eagle_modeling.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_eagle_modeling.py new file mode 100644 index 00000000000..8145185b0a3 --- /dev/null +++ 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_eagle_modeling.py @@ -0,0 +1,624 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Hierarchical unit tests for the Mistral4 Eagle AutoDeploy layer implementation. + +Tests every AD canonical op (torch_mla, torch_rope_with_explicit_cos_sin, torch_rmsnorm) +against a plain-PyTorch reference, then tests the full Mistral4EagleLayer and the +EagleDrafterForCausalLM export path. + +Reference implementations below are minimal faithful copies of the checkpoint's modeling +semantics written in plain PyTorch (no AD canonical ops), used exclusively for numerical +comparison. +""" + +import math + +import pytest +import torch +import torch.nn.functional as F +from torch import nn + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( + EagleConfig, + EagleDrafterForCausalLM, +) +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_mistral3 import ( + Mistral4EagleLayer, + Mistral4EagleMLA, + Mistral4EagleMLP, + Mistral4RMSNorm, + Mistral4TextConfig, +) + +# ============================================================================= +# Helpers +# ============================================================================= + + +def assert_rmse_close( + actual: torch.Tensor, expected: torch.Tensor, rmse_ratio_tol: float, msg: str = "" +) -> None: + actual = actual.float() + expected = expected.float() + rmse = torch.sqrt(torch.mean((actual - expected) ** 2)) + denom = torch.sqrt(torch.mean(expected**2)).clamp_min(1e-8) + ratio = (rmse / denom).item() + assert ratio <= rmse_ratio_tol, f"{msg}rmse_ratio={ratio:.6f} > {rmse_ratio_tol:.6f}" + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def 
_small_eagle_text_config() -> Mistral4TextConfig: + """Tiny Mistral4TextConfig suitable for Eagle unit testing.""" + return Mistral4TextConfig( + vocab_size=256, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + q_lora_rank=32, + kv_lora_rank=16, + qk_head_dim=16, + qk_nope_head_dim=8, + qk_rope_head_dim=8, + v_head_dim=8, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + n_group=1, + topk_group=1, + routed_scaling_factor=1.0, + first_k_dense_replace=0, + moe_layer_freq=1, + max_position_embeddings=128, + rope_parameters={ + "type": "yarn", + "rope_type": "yarn", + "factor": 8.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 32, + "rope_theta": 10000.0, + "llama_4_scaling_beta": 0.1, + }, + pad_token_id=0, + ) + + +@pytest.fixture(autouse=True) +def set_seed(): + torch.manual_seed(42) + + +# ============================================================================= +# Reference implementations (plain PyTorch, no AD canonical ops) +# ============================================================================= + + +class RefMistral4EagleRMSNorm(nn.Module): + """Plain-PyTorch RMSNorm reference.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.float() + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight.float() * hidden_states).to(input_dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class 
RefMistral4EagleRotaryEmbedding(nn.Module): + """Plain-PyTorch YaRN RoPE reference (same math as Mistral4YarnRotaryEmbedding).""" + + def __init__(self, config: Mistral4TextConfig): + super().__init__() + rs = config.rope_scaling + self.dim = config.qk_rope_head_dim + self.base = rs.get("rope_theta", 10000.0) + self.scale = rs.get("factor", 1.0) + self.beta_fast = rs.get("beta_fast", 32.0) + self.beta_slow = rs.get("beta_slow", 1.0) + self.mscale = rs.get("mscale", 1.0) + self.mscale_all_dim = rs.get("mscale_all_dim", 1.0) + self.original_max_position_embeddings = rs.get("original_max_position_embeddings", 8192) + self.max_position_embeddings = config.max_position_embeddings + self._build_cache() + + @staticmethod + def _find_correction_dim(num_rotations, dim, base, max_position_embeddings): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + @classmethod + def _find_correction_range(cls, low_rot, high_rot, dim, base, max_position_embeddings): + low = math.floor(cls._find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(cls._find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + @staticmethod + def _get_mscale(scale, mscale): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def _build_cache(self): + dim = self.dim + freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + freq_inter = 1.0 / ( + self.scale * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + low, high = self._find_correction_range( + self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings + ) + mask = 1.0 - torch.clamp( + (torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0, 1 + ) + inv_freq = freq_inter * (1 - mask) + freq_extra * mask + t = torch.arange(self.max_position_embeddings, dtype=torch.float32) + freqs 
= torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + mscale = self._get_mscale(self.scale, self.mscale) / self._get_mscale( + self.scale, self.mscale_all_dim + ) + self.register_buffer("cos_cached", emb.cos() * mscale, persistent=False) + self.register_buffer("sin_cached", emb.sin() * mscale, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + cos = self.cos_cached.to(dtype=x.dtype, device=x.device)[position_ids] + sin = self.sin_cached.to(dtype=x.dtype, device=x.device)[position_ids] + return cos, sin + + +class RefMistral4EagleMLP(nn.Module): + """Plain-PyTorch SwiGLU MLP with w1/w2/w3 names matching Eagle checkpoint.""" + + def __init__(self, config: Mistral4TextConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False) + self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False) + self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RefMistral4EagleMLA(nn.Module): + """Plain-PyTorch MLA reference using standard SDPA. + + Implements the same computation as Mistral4EagleMLA but using: + - F.scaled_dot_product_attention instead of torch_mla + - Manual rotate_half instead of torch_rope_with_explicit_cos_sin + - Inline variance norm instead of torch_rmsnorm + + Weight names match Mistral4EagleMLA / Eagle checkpoint naming. 
+ """ + + def __init__(self, config: Mistral4TextConfig, layer_idx: int): + super().__init__() + self.num_heads = config.num_attention_heads + self.q_lora_rank = config.q_lora_rank + self.kv_lora_rank = config.kv_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.v_head_dim = config.v_head_dim + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.softmax_scale = self.q_head_dim ** (-0.5) + + rope_parameters = config.rope_parameters + mscale_all_dim = rope_parameters.get("mscale_all_dim", 0.0) + if mscale_all_dim: + scale = rope_parameters.get("factor", 1.0) + yarn_mscale = RefMistral4EagleRotaryEmbedding._get_mscale(scale, mscale_all_dim) + self.softmax_scale = self.softmax_scale * yarn_mscale * yarn_mscale + + rms_eps = config.rms_norm_eps + self.wq_a = nn.Linear(config.hidden_size, self.q_lora_rank, bias=False) + self.q_a_norm = RefMistral4EagleRMSNorm(self.q_lora_rank, eps=rms_eps) + self.wq_b = nn.Linear(self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.wkv_a_with_mqa = nn.Linear( + config.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_a_norm = RefMistral4EagleRMSNorm(self.kv_lora_rank, eps=rms_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + self.wo = nn.Linear(self.num_heads * self.v_head_dim, config.hidden_size, bias=False) + self.rotary_emb = RefMistral4EagleRotaryEmbedding(config) + + def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + q = self.wq_b(self.q_a_norm(self.wq_a(hidden_states))) + q = q.view(batch_size, seq_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv = self.wkv_a_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(kv, [self.kv_lora_rank, 
self.qk_rope_head_dim], dim=-1)
+        compressed_kv = self.kv_a_norm(compressed_kv)
+        k_pe = k_pe.view(batch_size, seq_len, 1, self.qk_rope_head_dim)
+
+        cos, sin = self.rotary_emb(hidden_states, position_ids)
+        # Apply RoPE via rotate_half; cos/sin come back as [B,S,rope_dim], so we
+        # unsqueeze a singleton head dim to broadcast over per-head q_pe / k_pe.
+        cos_h = cos.unsqueeze(2)  # [B,S,1,rope_dim]
+        sin_h = sin.unsqueeze(2)  # [B,S,1,rope_dim]
+        q_pe = q_pe * cos_h + _rotate_half(q_pe) * sin_h
+        k_pe = k_pe * cos_h + _rotate_half(k_pe) * sin_h
+
+        # Absorb wkv_b: expand compressed_kv to full K and V
+        kv_expanded = F.linear(compressed_kv, self.wkv_b.weight)
+        kv_expanded = kv_expanded.view(
+            batch_size, seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+        )
+        k_nope, v = torch.split(kv_expanded, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        q = torch.cat([q_nope, q_pe], dim=-1).permute(0, 2, 1, 3).float()
+        k = (
+            torch.cat([k_nope, k_pe.expand(-1, -1, self.num_heads, -1)], dim=-1)
+            .permute(0, 2, 1, 3)
+            .float()
+        )
+        v = v.permute(0, 2, 1, 3).float()
+
+        scores = torch.matmul(q, k.transpose(-1, -2)) * self.softmax_scale
+        causal_mask = torch.triu(
+            torch.ones(seq_len, seq_len, device=scores.device, dtype=torch.bool), diagonal=1
+        )
+        scores = scores.masked_fill(causal_mask, float("-inf"))
+        probs = torch.softmax(scores, dim=-1)
+        output = torch.matmul(probs, v).permute(0, 2, 1, 3).to(hidden_states.dtype)
+        output = output.reshape(batch_size, seq_len, self.num_heads * self.v_head_dim)
+        return self.wo(output)
+
+
+class RefMistral4EagleLayer(nn.Module):
+    """Plain-PyTorch Mistral4 Eagle transformer layer reference."""
+
+    def __init__(self, config: Mistral4TextConfig, layer_idx: int, has_eagle_proj: bool = False):
+        super().__init__()
+        rms_eps = config.rms_norm_eps
+        self.eagle_proj = (
+            nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+            if has_eagle_proj
+            else None
+        )
+        self.attention_norm = RefMistral4EagleRMSNorm(config.hidden_size, eps=rms_eps)
+        self.attention = 
RefMistral4EagleMLA(config, layer_idx) + self.ffn_norm = RefMistral4EagleRMSNorm(config.hidden_size, eps=rms_eps) + self.feed_forward = RefMistral4EagleMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + if self.eagle_proj is not None: + hidden_states = self.eagle_proj(torch.cat([inputs_embeds, hidden_states], dim=-1)) + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + hidden_states = self.attention(hidden_states, position_ids) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + return residual + hidden_states + + +# ============================================================================= +# Weight-copy helpers +# ============================================================================= + + +def _copy_weights(target: nn.Module, source: nn.Module) -> None: + """Load source state dict into target (strict, same parameter names).""" + target.load_state_dict(source.state_dict(), strict=True) + + +# ============================================================================= +# Tests: per-op AD canonical op vs. 
reference +# ============================================================================= + + +def test_mistral4_eagle_rmsnorm_equivalence(): + """Mistral4RMSNorm (torch_rmsnorm AD op) == RefMistral4EagleRMSNorm (plain PyTorch).""" + device = _device() + dtype = torch.bfloat16 + ad_mod = Mistral4RMSNorm(64, eps=1e-6).to(device=device, dtype=dtype) + ref_mod = RefMistral4EagleRMSNorm(64, eps=1e-6).to(device=device, dtype=dtype) + _copy_weights(ref_mod, ad_mod) + x = torch.randn(2, 8, 64, device=device, dtype=dtype) + actual = ad_mod(x) + expected = ref_mod(x) + torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3) + + +def test_mistral4_eagle_mlp_equivalence(): + """Mistral4EagleMLP (torch ops) == RefMistral4EagleMLP (plain PyTorch F.silu + linear).""" + device = _device() + dtype = torch.bfloat16 + config = _small_eagle_text_config() + ad_mod = Mistral4EagleMLP(config).to(device=device, dtype=dtype) + ref_mod = RefMistral4EagleMLP(config).to(device=device, dtype=dtype) + _copy_weights(ref_mod, ad_mod) + x = torch.randn(2, 8, config.hidden_size, device=device, dtype=dtype) + actual = ad_mod(x) + expected = ref_mod(x) + torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3) + + +def test_mistral4_eagle_mla_equivalence(): + """Mistral4EagleMLA (torch_mla + torch_rope AD ops) == RefMistral4EagleMLA (plain PyTorch SDPA). + + Tests: torch_mla, torch_rope_with_explicit_cos_sin, torch_rmsnorm (via q_a_norm / kv_a_norm). 
+ """ + if not torch.cuda.is_available(): + pytest.skip("torch_mla requires CUDA.") + device = "cuda" + dtype = torch.bfloat16 + config = _small_eagle_text_config() + ad_mod = Mistral4EagleMLA(config, layer_idx=0).to(device=device, dtype=dtype) + ref_mod = RefMistral4EagleMLA(config, layer_idx=0).to(device=device, dtype=dtype) + _copy_weights(ref_mod, ad_mod) + hidden_states = torch.randn(2, 6, config.hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(6, device=device).unsqueeze(0).expand(2, -1) + actual = ad_mod(hidden_states, position_ids) + expected = ref_mod(hidden_states, position_ids) + assert_rmse_close(actual, expected, rmse_ratio_tol=0.10, msg="MLA: ") + + +def test_mistral4_eagle_layer_no_proj_equivalence(): + """Mistral4EagleLayer (no eagle_proj, layer_idx>0) == Ref layer.""" + if not torch.cuda.is_available(): + pytest.skip("torch_mla requires CUDA.") + device = "cuda" + dtype = torch.bfloat16 + config = _small_eagle_text_config() + ad_mod = Mistral4EagleLayer(config, layer_idx=1, has_eagle_proj=False).to( + device=device, dtype=dtype + ) + ref_mod = RefMistral4EagleLayer(config, layer_idx=1, has_eagle_proj=False).to( + device=device, dtype=dtype + ) + _copy_weights(ref_mod, ad_mod) + hidden_states = torch.randn(2, 5, config.hidden_size, device=device, dtype=dtype) + inputs_embeds = torch.randn(2, 5, config.hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(5, device=device).unsqueeze(0).expand(2, -1) + actual = ad_mod(hidden_states, inputs_embeds, position_ids) + expected = ref_mod(hidden_states, inputs_embeds, position_ids) + assert_rmse_close(actual, expected, rmse_ratio_tol=0.05, msg="EagleLayer (no proj): ") + + +def test_mistral4_eagle_layer_with_proj_equivalence(): + """Mistral4EagleLayer (with eagle_proj, layer_idx=0) == Ref layer.""" + if not torch.cuda.is_available(): + pytest.skip("torch_mla requires CUDA.") + device = "cuda" + dtype = torch.bfloat16 + config = _small_eagle_text_config() + ad_mod = 
Mistral4EagleLayer(config, layer_idx=0, has_eagle_proj=True).to( + device=device, dtype=dtype + ) + ref_mod = RefMistral4EagleLayer(config, layer_idx=0, has_eagle_proj=True).to( + device=device, dtype=dtype + ) + _copy_weights(ref_mod, ad_mod) + hidden_states = torch.randn(2, 5, config.hidden_size, device=device, dtype=dtype) + inputs_embeds = torch.randn(2, 5, config.hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(5, device=device).unsqueeze(0).expand(2, -1) + actual = ad_mod(hidden_states, inputs_embeds, position_ids) + expected = ref_mod(hidden_states, inputs_embeds, position_ids) + assert_rmse_close(actual, expected, rmse_ratio_tol=0.05, msg="EagleLayer (with proj): ") + + +# ============================================================================= +# Tests: FP8 dequant hooks +# ============================================================================= + + +def _make_fp8_weight_entry( + proj_name: str, + shape: tuple, + device: str, +) -> tuple[dict, torch.Tensor]: + """Create a fake FP8 checkpoint entry (weight + scale) and the expected dequant result.""" + dequantized = torch.randn(*shape, dtype=torch.float32, device=device) + scale = torch.tensor(0.5, dtype=torch.float32, device=device) + quantized = (dequantized / scale).clamp(-448, 448).to(torch.float8_e4m3fn) + state_dict = { + f"{proj_name}.weight": quantized, + f"{proj_name}.qscale_weight": scale, + f"{proj_name}.qscale_act": torch.tensor(3.0, device=device), + } + return state_dict, (quantized.float() * scale) + + +def test_mistral4_eagle_mlp_fp8_dequant_hook(): + """Mistral4EagleMLP._dequantize_fp8_weights correctly dequantizes w1/w2/w3 FP8 weights.""" + device = _device() + config = _small_eagle_text_config() + mlp = Mistral4EagleMLP(config).to(device) + + state_dict = {} + expected = {} + for proj, shape in [ + ("w1", (config.intermediate_size, config.hidden_size)), + ("w2", (config.hidden_size, config.intermediate_size)), + ("w3", (config.intermediate_size, 
config.hidden_size)), + ]: + entry, deq = _make_fp8_weight_entry(proj, shape, device) + state_dict.update(entry) + expected[proj] = deq + + mlp._dequantize_fp8_weights(state_dict, "") + + for proj in ("w1", "w2", "w3"): + assert f"{proj}.qscale_weight" not in state_dict, f"{proj}.qscale_weight not removed" + assert f"{proj}.qscale_act" not in state_dict, f"{proj}.qscale_act not removed" + weight = state_dict[f"{proj}.weight"] + assert weight.dtype != torch.float8_e4m3fn, f"{proj}.weight not dequantized" + torch.testing.assert_close(weight.float(), expected[proj], rtol=0, atol=1e-3) + + +def test_mistral4_eagle_mla_fp8_dequant_hooks(): + """Mistral4EagleMLA FP8 hooks dequantize wkv_b and attention projection weights.""" + device = _device() + config = _small_eagle_text_config() + mla = Mistral4EagleMLA(config, layer_idx=0).to(device) + + # Test wkv_b hook + wkv_b_shape = ( + config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim), + config.kv_lora_rank, + ) + state_dict_wkv_b, expected_wkv_b = _make_fp8_weight_entry("wkv_b", wkv_b_shape, device) + mla._dequantize_fp8_wkv_b(state_dict_wkv_b, "") + assert "wkv_b.qscale_weight" not in state_dict_wkv_b + assert "wkv_b.qscale_act" not in state_dict_wkv_b + assert state_dict_wkv_b["wkv_b.weight"].dtype != torch.float8_e4m3fn + torch.testing.assert_close( + state_dict_wkv_b["wkv_b.weight"].float(), expected_wkv_b, rtol=0, atol=1e-3 + ) + + # Test remaining projections hook + proj_shapes = { + "wq_a": (config.q_lora_rank, config.hidden_size), + "wq_b": ( + config.num_attention_heads * (config.qk_nope_head_dim + config.qk_rope_head_dim), + config.q_lora_rank, + ), + "wkv_a_with_mqa": (config.kv_lora_rank + config.qk_rope_head_dim, config.hidden_size), + "wo": (config.hidden_size, config.num_attention_heads * config.v_head_dim), + } + state_dict_proj = {} + expected_proj = {} + for proj, shape in proj_shapes.items(): + entry, deq = _make_fp8_weight_entry(proj, shape, device) + state_dict_proj.update(entry) 
+ expected_proj[proj] = deq + + mla._dequantize_fp8_weights(state_dict_proj, "") + for proj in proj_shapes: + assert f"{proj}.qscale_weight" not in state_dict_proj + assert f"{proj}.qscale_act" not in state_dict_proj + assert state_dict_proj[f"{proj}.weight"].dtype != torch.float8_e4m3fn + torch.testing.assert_close( + state_dict_proj[f"{proj}.weight"].float(), expected_proj[proj], rtol=0, atol=1e-3 + ) + + +def test_mistral4_eagle_layer_fp8_eagle_proj_hook(): + """Mistral4EagleLayer._dequantize_fp8_eagle_proj dequantizes eagle_proj.""" + device = _device() + config = _small_eagle_text_config() + layer = Mistral4EagleLayer(config, layer_idx=0, has_eagle_proj=True).to(device) + + proj_shape = (config.hidden_size, 2 * config.hidden_size) + state_dict, expected = _make_fp8_weight_entry("eagle_proj", proj_shape, device) + + layer._dequantize_fp8_eagle_proj(state_dict, "") + assert "eagle_proj.qscale_weight" not in state_dict + assert "eagle_proj.qscale_act" not in state_dict + assert state_dict["eagle_proj.weight"].dtype != torch.float8_e4m3fn + torch.testing.assert_close(state_dict["eagle_proj.weight"].float(), expected, rtol=0, atol=1e-3) + + +def test_mistral4_eagle_layer_no_eagle_proj_hook_is_noop(): + """Layers without eagle_proj don't have the FP8 hook registered.""" + config = _small_eagle_text_config() + layer = Mistral4EagleLayer(config, layer_idx=1, has_eagle_proj=False) + # No pre-hooks registered on the layer itself (hooks are on submodules) + assert layer.eagle_proj is None + + +# ============================================================================= +# Tests: EagleDrafterForCausalLM export +# ============================================================================= + + +def _make_eagle_config(config: Mistral4TextConfig) -> EagleConfig: + """Wrap a Mistral4TextConfig in EagleConfig with mistral4 defaults.""" + return EagleConfig(config, "mistral4") + + +def test_mistral4_eagle_drafter_forward(): + """EagleDrafterForCausalLM forward runs without 
error and returns finite outputs.""" + if not torch.cuda.is_available(): + pytest.skip("torch_mla requires CUDA.") + device = "cuda" + dtype = torch.bfloat16 + config = _make_eagle_config(_small_eagle_text_config()) + model = EagleDrafterForCausalLM(config).to(device=device, dtype=dtype) + model.eval() + + batch, seq = 2, 6 + hidden_size = config.hidden_size + inputs_embeds = torch.randn(batch, seq, hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(seq, device=device).unsqueeze(0).expand(batch, -1) + hidden_states = torch.randn(batch, seq, hidden_size, device=device, dtype=dtype) + + with torch.no_grad(): + out = model( + inputs_embeds=inputs_embeds, position_ids=position_ids, hidden_states=hidden_states + ) + + assert out.norm_hidden_state is not None + assert torch.isfinite(out.norm_hidden_state).all() + assert out.norm_hidden_state.shape == (batch, seq, hidden_size) + + +def test_mistral4_eagle_drafter_export(): + """EagleDrafterForCausalLM can be exported with torch.export.""" + if not torch.cuda.is_available(): + pytest.skip("torch_mla requires CUDA.") + device = "cuda" + dtype = torch.bfloat16 + config = _make_eagle_config(_small_eagle_text_config()) + model = EagleDrafterForCausalLM(config).to(device=device, dtype=dtype) + model.eval() + + batch, seq = 1, 8 + hidden_size = config.hidden_size + inputs_embeds = torch.randn(batch, seq, hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(seq, device=device).unsqueeze(0) + hidden_states = torch.randn(batch, seq, hidden_size, device=device, dtype=dtype) + + # Note: dynamic_shapes is not passed since hidden_states arrives via **kwargs and + # PyTorch's dynamic_shapes pytree matching doesn't support **kwargs args well. + # We verify export succeeds and outputs are finite; shape flexibility is tested at + # the EagleOneModelFactory level during full-pipeline smoke tests. 
+ gm = torch_export_to_gm( + model, + args=(inputs_embeds, position_ids), + kwargs={"hidden_states": hidden_states}, + ) + assert gm is not None + + with torch.no_grad(): + out = gm(inputs_embeds, position_ids, hidden_states=hidden_states) + assert out.norm_hidden_state is not None + assert torch.isfinite(out.norm_hidden_state).all() + assert out.norm_hidden_state.shape == (batch, seq, hidden_size) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_hidden_state_capture.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_hidden_state_capture.py new file mode 100644 index 00000000000..eef8e2808b9 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral4_hidden_state_capture.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for detect_hidden_states_for_capture on Mistral4 models. + +Tests that collect_residual_add_nodes correctly identifies decoder-layer residual +add nodes in the Mistral4 graph (MLA attention + MoE FFN) and that: + - {0} captures the single layer of a 1-layer model + - {-1} resolves to layer 0 and captures the same node + - the captured node is placed correctly (between layers or at model output) + +Uses Mistral4ForCausalLM directly (text-only, no multimodal wrapper) so the +graph does not have the language_model. prefix that complicates layer detection. 
+""" + +import pytest +import torch +from torch.fx import GraphModule + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_mistral3 import ( + Mistral4ForCausalLM, + Mistral4TextConfig, +) +from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig +from tensorrt_llm._torch.auto_deploy.transform.library.hidden_states import ( + DetectHiddenStatesForCapture, +) + + +def _small_1layer_config() -> Mistral4TextConfig: + """Minimal 1-layer Mistral4TextConfig for fast export.""" + return Mistral4TextConfig( + vocab_size=256, + hidden_size=4096, + intermediate_size=12288, + moe_intermediate_size=2048, + num_hidden_layers=1, + num_attention_heads=32, + num_key_value_heads=32, + q_lora_rank=1024, + kv_lora_rank=256, + qk_head_dim=128, + qk_nope_head_dim=64, + qk_rope_head_dim=64, + v_head_dim=128, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + n_group=1, + topk_group=1, + moe_layer_freq=1, + first_k_dense_replace=0, + max_position_embeddings=128, + rope_parameters={ + "type": "yarn", + "rope_type": "yarn", + "factor": 8.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 32, + "rope_theta": 10000.0, + "llama_4_scaling_beta": 0.1, + }, + pad_token_id=0, + ) + + +def _export_mistral4_text(config: Mistral4TextConfig) -> GraphModule: + """Export a Mistral4ForCausalLM to a GraphModule on CUDA.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA required for Mistral4 export.") + device = "cuda" + dtype = torch.bfloat16 + model = Mistral4ForCausalLM(config).to(device=device, dtype=dtype).eval() + inputs_embeds = torch.randn(1, 4, config.hidden_size, device=device, dtype=dtype) + position_ids = torch.arange(4, device=device).unsqueeze(0) + return torch_export_to_gm( + model, + kwargs={"inputs_embeds": inputs_embeds, 
"position_ids": position_ids}, + clone=True, + ) + + +def _make_transform(layers_to_capture) -> DetectHiddenStatesForCapture: + return DetectHiddenStatesForCapture( + config=TransformConfig( + stage="pattern_matcher", + eagle3_layers_to_capture=layers_to_capture, + ) + ) + + +@pytest.fixture(autouse=True) +def set_seed(): + torch.manual_seed(42) + + +# --------------------------------------------------------------------------- +# Helper: dump graph to /tmp for manual inspection +# --------------------------------------------------------------------------- + + +def _dump_graph(gm: GraphModule, path: str) -> None: + """Write a human-readable graph dump to `path` for visualization.""" + lines = [] + for node in gm.graph.nodes: + users = [u.name for u in node.users] + if node.op == "call_function": + w_info = "" + from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op + + if is_linear_op(node) and len(node.args) > 1: + w = node.args[1] + if hasattr(w, "op") and w.op == "get_attr": + w_info = f" weight={w.target}" + lines.append( + f"{node.op:15s} {str(node.target):60s} {node.name:40s}{w_info}" + f" [{', '.join(users)}]" + ) + else: + lines.append( + f"{node.op:15s} {str(node.target)[:60]:60s} {node.name:40s} [{', '.join(users)}]" + ) + with open(path, "w") as f: + f.write("\n".join(lines)) + + +# --------------------------------------------------------------------------- +# Test: collect_residual_add_nodes on 1-layer Mistral4 +# --------------------------------------------------------------------------- + + +def test_mistral4_collect_residual_add_nodes_one_layer(): + """collect_residual_add_nodes finds exactly 1 residual add for a 1-layer Mistral4. + + Dumps the graph to /tmp/mistral4_1layer_graph.txt for manual inspection. 
+ """ + config = _small_1layer_config() + gm = _export_mistral4_text(config) + + dump_path = "/tmp/mistral4_1layer_graph.txt" + _dump_graph(gm, dump_path) + print(f"Graph dumped to {dump_path} ({len(list(gm.graph.nodes))} nodes)") + + t = DetectHiddenStatesForCapture(config=TransformConfig(stage="pattern_matcher")) + residual_add_nodes = t.collect_residual_add_nodes(gm) + + print(f"collect_residual_add_nodes: {sorted(residual_add_nodes.keys())}") + assert len(residual_add_nodes) >= 1, ( + f"Expected at least 1 residual add node for the single decoder layer, " + f"got {residual_add_nodes}. Check {dump_path} for graph structure." + ) + assert 0 in residual_add_nodes, ( + f"Expected layer 0 residual add but found keys={sorted(residual_add_nodes.keys())}. " + f"Check {dump_path} for graph structure." + ) + + +def test_mistral4_capture_explicit_layer_0(): + """detect_hidden_states_for_capture with eagle3_layers_to_capture={0} succeeds.""" + config = _small_1layer_config() + gm = _export_mistral4_text(config) + t = _make_transform(layers_to_capture={0}) + gm_out, info = t._apply(gm, None, None, None) + assert info.num_matches == 1, f"Expected 1 match for {{0}}, got {info.num_matches}" + capture_nodes = [ + n + for n in gm_out.graph.nodes + if n.op == "call_function" + and n.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + assert len(capture_nodes) == 1, f"Expected 1 capture node, got {len(capture_nodes)}" + + +def test_mistral4_capture_neg1_same_as_layer_0(): + """eagle3_layers_to_capture={-1} captures the same node as {0} for a 1-layer model. + + For a single-layer model layer 0 IS the last layer, so {-1} must resolve to {0} + and capture the identical residual add. 
+ """ + config = _small_1layer_config() + + # Capture with explicit {0} + gm0 = _export_mistral4_text(config) + t0 = _make_transform(layers_to_capture={0}) + gm0_out, info0 = t0._apply(gm0, None, None, None) + + # Capture with {-1} (separate export to avoid in-place mutation interference) + gm_neg = _export_mistral4_text(config) + t_neg = _make_transform(layers_to_capture={-1}) + gm_neg_out, info_neg = t_neg._apply(gm_neg, None, None, None) + + assert info0.num_matches == 1, f"{{0}} gave {info0.num_matches} matches" + assert info_neg.num_matches == 1, f"{{-1}} gave {info_neg.num_matches} matches" + + # Both should have inserted exactly 1 capture node + def _get_capture_nodes(g): + return [ + n + for n in g.graph.nodes + if n.op == "call_function" + and n.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + + caps0 = _get_capture_nodes(gm0_out) + caps_neg = _get_capture_nodes(gm_neg_out) + assert len(caps0) == 1 and len(caps_neg) == 1 + + # The capture nodes should have the same args (same original residual add position) + args0 = tuple(a.name if isinstance(a, torch.fx.Node) else a for a in caps0[0].args) + args_neg = tuple(a.name if isinstance(a, torch.fx.Node) else a for a in caps_neg[0].args) + assert args0 == args_neg, ( + f"{{0}} captured at {args0} but {{-1}} captured at {args_neg}; " + "they should be identical for a 1-layer model." + ) + + +def test_mistral4_collect_residual_adds_counts_layers(): + """collect_residual_add_nodes returns exactly N entries for an N-layer model. + + This is the implicit 'count hidden layers' test the user described — the number + of entries in the dict tells you how many decoder layers were detected. 
+ """ + for num_layers in (1, 2, 3): + config = _small_1layer_config() + config.num_hidden_layers = num_layers + gm = _export_mistral4_text(config) + t = DetectHiddenStatesForCapture(config=TransformConfig(stage="pattern_matcher")) + residual_add_nodes = t.collect_residual_add_nodes(gm) + assert len(residual_add_nodes) == num_layers, ( + f"num_hidden_layers={num_layers}: expected {num_layers} residual adds, " + f"got {sorted(residual_add_nodes.keys())}" + ) + assert set(residual_add_nodes.keys()) == set(range(num_layers)), ( + f"Expected layer indices {{0..{num_layers - 1}}}, " + f"got {sorted(residual_add_nodes.keys())}" + ) + + +def test_mistral4_is_linear_op_finds_correct_nodes(): + """is_linear_op identifies exactly the attention + shared-expert linear nodes. + + For Mistral4 with n_routed_experts=4, each decoder layer has: + - 4 MLA attention linears: wq_a, wq_b, wkv_a_with_mqa, wo + - 3 shared expert linears: gate_proj, up_proj, down_proj + Plus 1 lm_head → 7*N + 1 total. Routed expert linears are inside torch_moe + and are NOT is_linear_op nodes. 
+ """ + from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op + + for num_layers in (1, 2): + config = _small_1layer_config() + config.num_hidden_layers = num_layers + gm = _export_mistral4_text(config) + lin_nodes = [n for n in gm.graph.nodes if is_linear_op(n)] + expected = 7 * num_layers + 1 # 7 per layer + lm_head + assert len(lin_nodes) == expected, ( + f"num_layers={num_layers}: expected {expected} linear nodes, got {len(lin_nodes)}" + ) + + +def test_mistral4_capture_default_layers_requires_enough_hidden_layers(): + """eagle3_layers_to_capture=None (default) requires num_hidden_layers > 6.""" + config = _small_1layer_config() + config.num_hidden_layers = 1 # too few for default capture set + gm = _export_mistral4_text(config) + t = DetectHiddenStatesForCapture( + config=TransformConfig(stage="pattern_matcher", eagle3_layers_to_capture=None) + ) + import pytest + + with pytest.raises(ValueError, match="Not enough hidden layers"): + t._apply(gm, None, None, None) + + +def test_mistral4_capture_default_layers_7_layer_model(): + """eagle3_layers_to_capture=None succeeds and picks 3 layers for a 7-layer model. + + set_default_eagle3_layers_to_capture({1, num_layers//2-1, num_layers-4}) for num_layers=7 + gives {1, 2, 3}. 
+ """ + config = _small_1layer_config() + config.num_hidden_layers = 7 + gm = _export_mistral4_text(config) + t = DetectHiddenStatesForCapture( + config=TransformConfig(stage="pattern_matcher", eagle3_layers_to_capture=None) + ) + gm_out, info = t._apply(gm, None, None, None) + assert info.num_matches == 3, ( + f"Expected 3 matches (default 3-layer capture), got {info.num_matches}" + ) + capture_nodes = [ + n + for n in gm_out.graph.nodes + if n.op == "call_function" + and n.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + assert len(capture_nodes) == 3, f"Expected 3 capture nodes, got {len(capture_nodes)}" + + +def test_mistral4_capture_neg1_resolves_correctly_for_various_depths(): + """{-1} always resolves to the LAST layer, regardless of model depth.""" + for num_layers in (1, 2, 4): + config = _small_1layer_config() + config.num_hidden_layers = num_layers + gm = _export_mistral4_text(config) + t = _make_transform(layers_to_capture={-1}) + gm_out, info = t._apply(gm, None, None, None) + assert info.num_matches == 1, ( + f"num_layers={num_layers}: expected 1 match for {{-1}}, got {info.num_matches}" + ) + # {-1} should resolve to last layer index = num_layers - 1 + # Verify the captured node is from the last layer by checking that + # {num_layers-1} explicit capture finds the same node. 
+ gm2 = _export_mistral4_text(config) + t2 = _make_transform(layers_to_capture={num_layers - 1}) + gm2_out, info2 = t2._apply(gm2, None, None, None) + assert info2.num_matches == 1 + + def _cap_args(g): + caps = [ + n + for n in g.graph.nodes + if n.op == "call_function" + and n.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + return tuple(a.name if isinstance(a, torch.fx.Node) else a for a in caps[0].args) + + assert _cap_args(gm_out) == _cap_args(gm2_out), ( + f"num_layers={num_layers}: {{-1}} and {{{num_layers - 1}}} captured at " + f"different positions: {_cap_args(gm_out)} vs {_cap_args(gm2_out)}" + ) + + +def test_mistral4_capture_2layer_last_vs_first(): + """For a 2-layer model, {-1} captures layer 1 and {0} captures layer 0 (different nodes).""" + config = _small_1layer_config() + config.num_hidden_layers = 2 # override to 2 layers + + # {0}: capture first layer + gm0 = _export_mistral4_text(config) + t0 = _make_transform(layers_to_capture={0}) + gm0_out, info0 = t0._apply(gm0, None, None, None) + + # {-1}: capture last layer (= layer 1) + gm_last = _export_mistral4_text(config) + t_last = _make_transform(layers_to_capture={-1}) + gm_last_out, info_last = t_last._apply(gm_last, None, None, None) + + assert info0.num_matches == 1, f"{{0}} matches: {info0.num_matches}" + assert info_last.num_matches == 1, f"{{-1}} matches: {info_last.num_matches}" + + def _capture_args(g): + caps = [ + n + for n in g.graph.nodes + if n.op == "call_function" + and n.target == torch.ops.auto_deploy.residual_add_for_capture.default + ] + return tuple(a.name if isinstance(a, torch.fx.Node) else a for a in caps[0].args) + + args0 = _capture_args(gm0_out) + args_last = _capture_args(gm_last_out) + # For a 2-layer model, the two captures should be at different graph positions + assert args0 != args_last, ( + f"{{0}} and {{-1}} captured the same node {args0}; they should differ for a 2-layer model." 
+ ) diff --git a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py index e0a57da550d..42e587056b1 100644 --- a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py @@ -581,6 +581,30 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): "hidden_size": 64, } }, + "mistralai/Mistral-Small-4-119B-2603": { + "model_factory": "Mistral3ForConditionalGeneration", + "model_kwargs": { + "text_config": { + "num_hidden_layers": 2, + "hidden_size": 64, + "intermediate_size": 128, + "moe_intermediate_size": 32, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "q_lora_rank": 32, + "kv_lora_rank": 16, + "qk_head_dim": 16, + "qk_nope_head_dim": 8, + "qk_rope_head_dim": 8, + "v_head_dim": 8, + "n_routed_experts": 4, + "n_shared_experts": 1, + "num_experts_per_tok": 2, + "n_group": 1, + "topk_group": 1, + } + }, + }, } diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_mistral4_eagle.py b/tests/unittest/auto_deploy/singlegpu/models/test_mistral4_eagle.py new file mode 100644 index 00000000000..4612da6c536 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/models/test_mistral4_eagle.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Standalone tests for the Mistral4 Eagle drafter head with AutoDeploy. + +Two tests mirror the pattern from test_eagle.py (Llama Eagle): + +1. test_mistral4_eagle_model_torch_export — verify the Eagle drafter can be exported + with torch.export (architecture traceability, no weights needed). + +2. test_build_ad_mistral4_eagle — verify the Eagle drafter runs through the + full AutoDeploy compile/run pipeline using a mock that injects random hidden states, + so the Eagle head can be validated in isolation without the target model. + +Both tests require: + - mistralai/Mistral-Small-4-119B-2603-eagle (Eagle checkpoint, for weights/structure) + - mistralai/Mistral-Small-4-119B-2603 (target model, for HF config) +resolved via hf_id_to_local_model_dir. Tests are skipped automatically if either path +is unavailable. + +The Mistral4 Eagle checkpoint is in native Mistral format (no config.json), so +EagleDrafterFactory must be initialised with config_model= to +load the architecture config from the target model instead. 
+""" + +from contextlib import nullcontext +from pathlib import Path + +import pytest +import torch +from _model_test_utils import get_small_model_config +from accelerate import init_empty_weights +from build_and_run_ad import ExperimentConfig, main + +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( + Eagle3DraftOutput, + EagleConfig, + EagleDrafterForCausalLM, +) +from tensorrt_llm._torch.auto_deploy.models.custom.modeling_mistral3 import ADMistralSmall4Processor +from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry +from tests.test_common.llm_data import hf_id_to_local_model_dir + +EAGLE_MODEL_HUB_ID = "mistralai/Mistral-Small-4-119B-2603-eagle" +TARGET_MODEL_HUB_ID = "mistralai/Mistral-Small-4-119B-2603" + + +# --------------------------------------------------------------------------- +# Helpers to resolve and skip +# --------------------------------------------------------------------------- + + +def _require_paths(): + """Return (eagle_path, target_path) or skip the test if either is missing.""" + eagle_path = hf_id_to_local_model_dir(EAGLE_MODEL_HUB_ID) + target_path = hf_id_to_local_model_dir(TARGET_MODEL_HUB_ID) + if eagle_path is None or not Path(eagle_path).is_dir(): + pytest.skip(f"Eagle checkpoint not found: {EAGLE_MODEL_HUB_ID}") + if target_path is None or not Path(target_path).is_dir(): + pytest.skip(f"Target model not found: {TARGET_MODEL_HUB_ID}") + return eagle_path, target_path + + +# --------------------------------------------------------------------------- +# Mock classes for standalone Mistral4 Eagle testing +# --------------------------------------------------------------------------- + + +class MockMistral4EagleConfig(EagleConfig): + """EagleConfig variant for standalone Eagle testing. 
+ + Disables loading embedding/lm_head from target (since there is no target model + in the standalone test) and forces random initialisation for these modules so + the model can generate logits on its own. + + Sets torch_dtype explicitly to bfloat16 so all Eagle layers (Linear, Embedding) + initialize in BF16, which is required for flashinfer kernels (rmsnorm, mla). + The nested text_config extracted from Mistral3Config may not carry torch_dtype. + """ + + _drafter_defaults = { + "mistral4": { + "load_embedding_from_target": False, + "load_lm_head_from_target": False, + "num_capture_layers": 1, + "normalize_target_hidden_state": False, + "layers_handle_final_norm": False, + # Ensure BF16 so flashinfer kernels can dispatch correctly. + # Use string form to avoid omegaconf serialization issues with torch.dtype objects. + "torch_dtype": "bfloat16", + "_checkpoint_conversion_mapping": { + r"^eagle_linear": "model.layers.0.eagle_proj", + r"^layers": "model.layers", + }, + } + } + + +class MockMistral4EagleDrafterForCausalLM(EagleDrafterForCausalLM): + """Eagle drafter that injects random hidden states for standalone testing. + + In production, hidden states come from the target model. Here we generate + them randomly so the Eagle head can be exercised without a target model. + The forward signature is changed to accept input_ids (like a normal LM) and + produce logits, making it compatible with build_and_run_ad / demollm. + """ + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self._hidden_size = config.hidden_size + self._dtype = getattr(config, "torch_dtype", torch.bfloat16) + + def forward(self, input_ids, position_ids, **kwargs): + assert self.model.embed_tokens is not None, ( + "embed_tokens must be initialised for standalone Mistral4 Eagle testing." + ) + assert self.lm_head is not None, ( + "lm_head must be initialised for standalone Mistral4 Eagle testing." 
+        )
+        inputs_embeds = self.model.embed_tokens(input_ids)
+        if "hidden_states" not in kwargs:
+            batch_size, seq_len = input_ids.shape
+            kwargs["hidden_states"] = torch.randn(
+                (batch_size, seq_len, self._hidden_size),
+                dtype=inputs_embeds.dtype,
+                device=input_ids.device,
+            )
+        draft_output = super().forward(inputs_embeds, position_ids, **kwargs)
+        logits = self.lm_head(draft_output.norm_hidden_state)
+        return Eagle3DraftOutput(logits=logits, last_hidden_state=draft_output.last_hidden_state)
+
+
+class MockMistral4EagleDrafterFactory(EagleDrafterFactory):
+    """Factory that builds MockMistral4EagleDrafterForCausalLM for standalone testing.
+
+    Passes config_model= so that the architecture config is loaded from
+    the HF-format target model rather than from the native-format Eagle checkpoint.
+    """
+
+    def _build_model(self, device):
+        """Instantiate the mock drafter on ``device`` ("meta" or a real device)."""
+        model_config, unused_kwargs = self._get_model_config()
+        model_config = MockMistral4EagleConfig(model_config, model_config.model_type)
+
+        with (init_empty_weights if device == "meta" else nullcontext)():
+            model = MockMistral4EagleDrafterForCausalLM._from_config(model_config, **unused_kwargs)
+
+        if device == "meta":
+            if hasattr(model, "post_init"):
+                model.post_init()
+        else:
+            # Cast to bfloat16 so that flashinfer kernels (rmsnorm, mla) can dispatch
+            # correctly. Random-init weights default to float32 without this.
+            model.to(device=device, dtype=torch.bfloat16)
+
+        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
+        model.eval()
+        return model
+
+    def _load_quantization_config(self, fetched_dir):
+        """Skip quant-config detection for standalone mock testing.
+
+        The Eagle checkpoint's params.json declares fp8 quantization. When running
+        with skip_loading_weights=True (random BF16 init), applying FP8 graph
+        transforms on top of BF16 weights causes runtime dtype mismatches in
+        flashinfer kernels. Disable quant config auto-detection for mock tests.
+        """
+        # Explicit return documents the "no quant config" contract instead of
+        # relying on the implicit None of a docstring-only body.
+        return None
+
+    def init_tokenizer(self):
+        """Load the Mistral4 tekken.json tokenizer via ADMistralSmall4Processor."""
+        if self.tokenizer is None:
+            return None
+        processor = ADMistralSmall4Processor.from_pretrained(self.tokenizer)
+        return processor.tokenizer
+
+
+@pytest.fixture
+def register_mock_mistral4_eagle_factory():
+    """Temporarily register MockMistral4EagleDrafterFactory in the model registry.
+
+    Saves any pre-existing registration under the key and restores it in a
+    finally block, so a failing test can neither leak the mock entry nor
+    clobber a real registration.
+    """
+    key = "MockMistral4EagleDrafter"
+    _missing = object()
+    previous = ModelFactoryRegistry._registry.get(key, _missing)
+    ModelFactoryRegistry._registry[key] = MockMistral4EagleDrafterFactory
+    try:
+        yield
+    finally:
+        if previous is _missing:
+            ModelFactoryRegistry._registry.pop(key, None)
+        else:
+            ModelFactoryRegistry._registry[key] = previous
+
+
+# ---------------------------------------------------------------------------
+# Test 1 — torch.export traceability
+# ---------------------------------------------------------------------------
+
+
+def test_mistral4_eagle_model_torch_export():
+    """Mistral4 Eagle drafter can be traced with torch.export.
+
+    Validates that the model architecture (Mistral4EagleMLA + torch_mla,
+    Mistral4EagleMLP, eagle_proj) is fully torch.export-compatible. Weights are
+    not loaded (skip_loading_weights=True); random initialisations are used so
+    only the graph structure matters.
+    """
+    eagle_path, target_path = _require_paths()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.bfloat16
+
+    factory = EagleDrafterFactory(
+        model=str(eagle_path),
+        config_model=str(target_path),
+        skip_loading_weights=True,
+        # Reduce to a tiny model so the export runs quickly
+        model_kwargs={"num_hidden_layers": 2},
+    )
+    model = factory.build_model(device)
+    model = model.to(dtype=dtype)
+    config = model.config
+
+    batch_size, seq_len = 1, 8
+    hidden_dim = config.hidden_size
+    inputs_embeds = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+    position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
+    hidden_states = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
+
+    print(
+        f"\nExporting Mistral4 Eagle drafter: hidden_size={hidden_dim}, "
+        f"num_layers={config.num_hidden_layers}"
+    )
+
+    try:
+        exported = torch.export.export(
+            model,
+            args=(inputs_embeds, position_ids),
+            kwargs={"hidden_states": hidden_states},
+        )
+        print("torch.export succeeded.")
+        print("Graph (first 20 lines):")
+        print("\n".join(exported.graph_module.code.split("\n")[:20]))
+    except Exception as e:
+        pytest.fail(f"torch.export failed: {e}")
+
+
+# ---------------------------------------------------------------------------
+# Test 2 — full AutoDeploy compile + run pipeline
+# ---------------------------------------------------------------------------
+
+
+def test_build_ad_mistral4_eagle(register_mock_mistral4_eagle_factory):
+    """Mistral4 Eagle drafter runs through the full AutoDeploy compile/run pipeline.
+
+    Uses MockMistral4EagleDrafterFactory which:
+    - Loads architecture config from the target model (Mistral-Small-4-119B-2603).
+    - Injects random hidden states so the Eagle head can run without a target model.
+    - Uses skip_loading_weights=True and tiny model_kwargs for fast execution.
+
+    The MLA attention transform (insert_cached_mla_attention) is applied since
+    Mistral4EagleMLA uses the torch_mla canonical op.
+    """
+    eagle_path, target_path = _require_paths()
+
+    llm_extra_args = {
+        "model_factory": "MockMistral4EagleDrafter",
+        # config_model is injected via the monkey-patch on the factory class below
+        "transforms": {
+            "insert_cached_mla_attention": {"backend": "torch_mla"},
+            "compile_model": {"backend": "torch-simple"},
+        },
+    }
+    experiment_config = get_small_model_config(TARGET_MODEL_HUB_ID, **llm_extra_args)
+    # Point the experiment at the Eagle checkpoint; config comes from target via config_model
+    experiment_config["args"]["model"] = str(eagle_path)
+    experiment_config["args"]["runtime"] = "demollm"
+    experiment_config["args"]["world_size"] = 0
+    experiment_config["args"]["tokenizer"] = str(target_path)
+    experiment_config["args"]["skip_loading_weights"] = True
+
+    # config_model is consumed by MockMistral4EagleDrafterFactory.__init__ via **kwargs
+    # passed through ModelFactoryRegistry. We inject it via a monkey-patch on the factory
+    # class so that target_path is captured without changing the LlmArgs schema.
+    original_init = MockMistral4EagleDrafterFactory.__init__
+
+    def _patched_init(self, model, **kwargs):
+        kwargs.setdefault("config_model", str(target_path))
+        original_init(self, model=model, **kwargs)
+
+    MockMistral4EagleDrafterFactory.__init__ = _patched_init
+    try:
+        cfg = ExperimentConfig(**experiment_config)
+        main(cfg)
+    finally:
+        MockMistral4EagleDrafterFactory.__init__ = original_init
+
+
+# ---------------------------------------------------------------------------
+# Test 3 & 4 — E2E smoke: Eagle one-model spec-dec (skip_loading_weights)
+# ---------------------------------------------------------------------------
+
+
+def _run_mistral4_eagle_one_model_smoke(num_hidden_layers: int):
+    """Shared implementation for Eagle one-model smoke tests.
+
+    Builds the full Eagle one-model pipeline (target + draft) with
+    skip_loading_weights=True and reduced model_kwargs. Constructs LLM
+    directly (not via ExperimentConfig + main) because
+    Eagle3DecodingConfig.eagle3_layers_to_capture (Set[int]) does not survive
+    the model_dump -> LlmArgs OmegaConf round-trip that main() performs.
+    TODO: fix the Set[int] OmegaConf issue and switch to ExperimentConfig + main().
+    """
+    from tensorrt_llm._torch.auto_deploy.llm import LLM as ADLLM
+    from tensorrt_llm.llmapi import Eagle3DecodingConfig, SamplingParams
+
+    eagle_path, target_path = _require_paths()
+
+    small_config = get_small_model_config(TARGET_MODEL_HUB_ID)
+    small_dims = dict(small_config["args"]["model_kwargs"]["text_config"])
+    small_dims["num_hidden_layers"] = num_hidden_layers
+
+    spec_config = Eagle3DecodingConfig(
+        max_draft_len=3,
+        speculative_model=str(eagle_path),
+        eagle3_one_model=True,
+        eagle3_model_arch="mistral_large3",
+    )
+
+    with ADLLM(
+        model=str(target_path),
+        model_factory="Mistral3ForConditionalGeneration",
+        model_kwargs={"text_config": small_dims},
+        skip_loading_weights=True,
+        transforms={"insert_cached_mla_attention": {"backend": "torch_mla"}},
+        speculative_config=spec_config,
+        speculative_model_kwargs=dict(small_dims),
+        disable_overlap_scheduler=True,
+        compile_backend="torch-simple",
+        max_num_tokens=256,
+        world_size=1,
+    ) as llm:
+        outputs = llm.generate(
+            ["What is the capital of France?"],
+            SamplingParams(max_tokens=16, top_k=None, temperature=0.0, seed=42),
+        )
+
+    assert len(outputs) == 1
+    assert len(outputs[0].outputs[0].text) > 0
+
+
+def test_mistral4_eagle_one_model_smoke():
+    """Eagle one-model spec-dec with a tiny (1-layer) Mistral-Small-4-119B target."""
+    _run_mistral4_eagle_one_model_smoke(num_hidden_layers=1)
+
+
+def test_mistral4_eagle_one_model_smoke_3layers():
+    """Eagle one-model spec-dec with 3 layers.
+
+    Exercises multi-layer graph transforms (layer boundary detection, residual add
+    identification) to surface boundary-detection issues such as MLA's
+    multi-projection pattern or unused placeholders in the exported graph.
+    """
+    _run_mistral4_eagle_one_model_smoke(num_hidden_layers=3)
diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py
index c833b4b680e..be5299c5257 100644
--- a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py
+++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_speculative_decoding.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
+from pathlib import Path
+
 import pytest
 import torch
 from _model_test_utils import get_small_model_config
@@ -162,6 +164,92 @@ def test_super_mtp_smoke():
     assert prompt == test_prompt
 
 
+def test_mistral4_target_only_e2e_real_weights():
+    """End-to-end test of the real-weight Mistral4 target alone (no Eagle).
+
+    Brings up the base pipeline on 8 GPUs with full target weights and no
+    speculative decoding, then checks that generation produces output.
+    """
+    from tensorrt_llm._torch.auto_deploy.llm import LLM as ADLLM
+    from tensorrt_llm.llmapi import SamplingParams
+
+    prompt = "How big is the universe?"
+    target_hub_id = "mistralai/Mistral-Small-4-119B-2603"
+    target_dir = hf_id_to_local_model_dir(target_hub_id)
+    if target_dir is None or not Path(target_dir).is_dir():
+        pytest.skip(f"Target model path does not exist: {target_dir}")
+
+    with ADLLM(
+        model=str(target_dir),
+        tokenizer=str(target_dir),
+        model_factory="Mistral3ForConditionalGeneration",
+        transforms={"insert_cached_mla_attention": {"backend": "torch_mla"}},
+        runtime="trtllm",
+        compile_backend="torch-simple",
+        disable_overlap_scheduler=True,
+        max_seq_len=512,
+        world_size=8,
+    ) as ad_llm:
+        results = ad_llm.generate(
+            [prompt],
+            SamplingParams(max_tokens=50, top_k=None, temperature=0.0, seed=42),
+        )
+
+    assert len(results) == 1
+    decoded = results[0].outputs[0].text
+    print(f"Generated: {decoded}")
+    assert len(decoded) > 0
+
+
+def test_mistral4_eagle_one_model_e2e_real_weights():
+    """End-to-end test of real Mistral4 + Eagle weights (no model_kwargs reduction).
+
+    Brings up the 8-GPU pipeline with full target and drafter weights under
+    Eagle speculative decoding, then checks that generation produces output.
+    """
+    from tensorrt_llm._torch.auto_deploy.llm import LLM as ADLLM
+    from tensorrt_llm.llmapi import SamplingParams
+
+    prompt = "How big is the universe?"
+    target_hub_id = "mistralai/Mistral-Small-4-119B-2603"
+    target_dir = hf_id_to_local_model_dir(target_hub_id)
+    if target_dir is None or not Path(target_dir).is_dir():
+        pytest.skip(f"Target model path does not exist: {target_dir}")
+
+    drafter_hub_id = "mistralai/Mistral-Small-4-119B-2603-eagle"
+    eagle_dir = hf_id_to_local_model_dir(drafter_hub_id)
+    if eagle_dir is None or not Path(eagle_dir).is_dir():
+        pytest.skip(f"Eagle model path does not exist: {eagle_dir}")
+
+    eagle_cfg = Eagle3DecodingConfig(
+        max_draft_len=3,
+        speculative_model=str(eagle_dir),
+        eagle3_one_model=True,
+        eagle3_model_arch="mistral_large3",
+    )
+
+    with ADLLM(
+        model=str(target_dir),
+        tokenizer=str(target_dir),
+        model_factory="Mistral3ForConditionalGeneration",
+        transforms={"insert_cached_mla_attention": {"backend": "torch_mla"}},
+        speculative_config=eagle_cfg,
+        disable_overlap_scheduler=True,
+        compile_backend="torch-simple",
+        max_seq_len=512,
+        world_size=8,
+    ) as ad_llm:
+        results = ad_llm.generate(
+            [prompt],
+            SamplingParams(max_tokens=100, top_k=None, temperature=0.0, seed=42),
+        )
+
+    assert len(results) == 1
+    decoded = results[0].outputs[0].text
+    print(f"Generated: {decoded}")
+    assert len(decoded) > 0
+
+
 def test_kv_cache_extra_seq_len_for_spec_dec():
     """Test that get_extra_seq_len_for_kv_cache computes correct extra capacity."""
     from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
@@ -220,7 +308,7 @@ def test_mtp_autodeploy_uses_eagle_one_model_capture():
     )
 
     assert isinstance(args.speculative_config, MTPDecodingConfig)
-    assert args.model_factory == "eagle_one_model"
+    assert args._requires_eagle_one_model()
     assert args.transforms["detect_hidden_states_for_capture"]["enabled"] is True
     assert args.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] == {-1}
 
diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py b/tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py
new file mode
100644
index 00000000000..ae4d561a8b8
--- /dev/null
+++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py
@@ -0,0 +1,235 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Diagnostic test for Eagle-style layer subgraph partitioning.
+
+Export Llama and Mistral4 as Eagle-style targets (inputs_embeds, no embedding op)
+with 3 hidden layers, then run get_all_layer_subgraphs and collect_residual_add_nodes to print
+every op in every subgraph for comparison.
+
+Usage:
+    pytest tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py -sv
+    pytest tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py::test_llama_subgraphs -sv
+    pytest tests/unittest/auto_deploy/singlegpu/smoke/test_layer_subgraph_debug.py::test_mistral4_subgraphs -sv
+"""
+
+import traceback
+from pathlib import Path
+
+import pytest
+import torch
+from _model_test_utils import get_small_model_config
+from test_common.llm_data import hf_id_to_local_model_dir
+
+from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
+from tensorrt_llm._torch.auto_deploy.transform.library.hidden_states import (
+    DetectHiddenStatesForCapture,
+)
+from tensorrt_llm._torch.auto_deploy.utils.node_utils import (
+    get_all_layer_subgraphs,
+    identify_regions_between_residuals,
+)
+
+NUM_HIDDEN_LAYERS = 3
+SEQ_LEN = 8
+
+
+def _print_all_graph_nodes(gm, label=""):
+    """Print every node in the FX graph with op, target, name."""
+    print(f"\n{'=' * 80}")
+    print(f" ALL GRAPH NODES {label}")
+    print(f"{'=' * 80}")
+    for i, node in enumerate(gm.graph.nodes):
+        target_str = str(node.target)  # non-call_function targets stringify identically
+        print(f"  [{i:3d}] op={node.op:<16s} name={node.name:<60s} target={target_str}")
+    print(f"{'=' * 80}\n")
+
+
+def _print_subgraph_detail(layer_subgraphs, label=""):
+    """Print detailed info about each layer subgraph."""
+    print(f"\n{'=' * 80}")
+    print(f" LAYER SUBGRAPHS {label}")
+    print(f"{'=' * 80}")
+    for i, sg in enumerate(layer_subgraphs):
+        print(f"\n--- Subgraph {i}: type={sg.layer_type} ---")
+        print(f"  Opening nodes ({len(sg.opening_nodes)}):")
+        for n in sg.opening_nodes:
+            print(f"    {n.name}")
+        print(f"  Terminating node: {sg.terminating_node.name if sg.terminating_node else None}")
+        print(f"  Interior nodes ({len(sg.subgraph_nodes)}):")
+        for n in sg.subgraph_nodes:
+            if n.op == "call_function":
+                print(f"    {n.name:<60s} target={n.target}")
+            else:
+                print(f"    {n.name:<60s} op={n.op}")
+    print(f"{'=' * 80}\n")
+
+
+def _print_residual_boundaries(gm, label=""):
+    """Print the boundary nodes returned by identify_regions_between_residuals."""
+    residuals = identify_regions_between_residuals(gm)
+    print(f"\n{'=' * 80}")
+    print(f" RESIDUAL BOUNDARIES {label}")
+    print(f"{'=' * 80}")
+    for i, node in enumerate(residuals):
+        print(f"  [{i}] name={node.name:<60s} op={node.op}")
+    print(f"  Total boundary nodes: {len(residuals)}")
+    if len(residuals) == 2:
+        print("  WARNING: Only input+output — no residual adds found (no embedding op?)")
+    print(f"{'=' * 80}\n")
+
+
+def _export_as_eagle_target(model, hidden_size):
+    """Export model using inputs_embeds (like Eagle target), not input_ids.
+
+    This means no aten.embedding in the graph — matching what detect_hidden_states_for_capture sees.
+    """
+    inputs_embeds = torch.randn(1, SEQ_LEN, hidden_size, dtype=torch.bfloat16)
+    position_ids = torch.arange(SEQ_LEN, dtype=torch.int64).unsqueeze(0)
+    gm = torch_export_to_gm(
+        model,
+        kwargs={"inputs_embeds": inputs_embeds, "position_ids": position_ids},
+    )
+    return gm
+
+
+def _export_with_input_ids(model):
+    """Export model using input_ids (standard path with embedding op)."""
+    input_ids = torch.ones((1, SEQ_LEN), dtype=torch.int64)
+    position_ids = torch.arange(SEQ_LEN, dtype=torch.int64).unsqueeze(0)
+    gm = torch_export_to_gm(
+        model,
+        args=(input_ids, position_ids),
+    )
+    return gm
+
+
+def _build_llama_model(num_hidden_layers):
+    """Build a small Llama model."""
+    config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
+    config["args"]["model_kwargs"]["num_hidden_layers"] = num_hidden_layers
+
+    from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
+
+    factory_cls = ModelFactoryRegistry.get("AutoModelForCausalLM")
+    factory = factory_cls(
+        model=config["args"]["model"],
+        model_kwargs=config["args"]["model_kwargs"],
+        skip_loading_weights=True,
+    )
+    model = factory.build_model("meta")
+    hidden_size = config["args"]["model_kwargs"]["hidden_size"]
+    return model, hidden_size
+
+
+def _build_mistral4_model(num_hidden_layers):
+    """Build a small Mistral4 model."""
+    model_hub_id = "mistralai/Mistral-Small-4-119B-2603"
+    model_path = hf_id_to_local_model_dir(model_hub_id)
+    if model_path is None or not Path(model_path).is_dir():
+        pytest.skip(f"Target model path does not exist: {model_path}")
+
+    config = get_small_model_config(model_hub_id)
+    small_dims = dict(config["args"]["model_kwargs"]["text_config"])
+    small_dims["num_hidden_layers"] = num_hidden_layers
+
+    from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactoryRegistry
+
+    factory_cls = ModelFactoryRegistry.get("Mistral3ForConditionalGeneration")
+    factory = factory_cls(
+        model=model_path,
+        model_kwargs={"text_config": small_dims},
+        skip_loading_weights=True,
+    )
+    model = factory.build_model("meta")
+    hidden_size = small_dims["hidden_size"]
+    return model, hidden_size
+
+
+def _run_diagnostics(gm, label):
+    """Run all diagnostic prints on a GraphModule."""
+    _print_all_graph_nodes(gm, label)
+    _print_residual_boundaries(gm, label)
+
+    print(f"\n--- Running get_all_layer_subgraphs {label} ---")
+    try:
+        layer_subgraphs, unprocessed = get_all_layer_subgraphs(gm)
+        _print_subgraph_detail(layer_subgraphs, label)
+        print(f"Unprocessed linear nodes: {[n.name for n in unprocessed]}")
+    except Exception as e:
+        print(f"get_all_layer_subgraphs FAILED: {type(e).__name__}: {e}")
+        # Full traceback — this file exists for debugging, so keep the details.
+        traceback.print_exc()
+        return
+
+    print(f"\n--- Running collect_residual_add_nodes {label} ---")
+    try:
+        transform = DetectHiddenStatesForCapture(
+            config=TransformConfig(
+                stage="pattern_matcher",
+                eagle3_layers_to_capture={-1},
+            )
+        )
+        residual_add_nodes = transform.collect_residual_add_nodes(gm)
+        print(f"  Found residual add nodes for layers: {sorted(residual_add_nodes.keys())}")
+        for layer_num, node in sorted(residual_add_nodes.items()):
+            print(f"    layer {layer_num}: {node.name}")
+    except Exception as e:
+        print(f"collect_residual_add_nodes FAILED: {type(e).__name__}: {e}")
+        # Full traceback — this file exists for debugging, so keep the details.
+        traceback.print_exc()
+
+
+def test_llama_subgraphs():
+    """Diagnostic: Llama 3-layer as Eagle target (inputs_embeds, no embedding).
+
+    Also shows the standard input_ids path for comparison.
+    """
+    model, hidden_size = _build_llama_model(NUM_HIDDEN_LAYERS)
+
+    # Standard path (with embedding) — for comparison
+    print("\n" + "#" * 80)
+    print("# LLAMA — STANDARD (input_ids, with embedding)")
+    print("#" * 80)
+    gm_standard = _export_with_input_ids(model)
+    _run_diagnostics(gm_standard, "[Llama standard]")
+
+    # Eagle target path (inputs_embeds, no embedding)
+    print("\n" + "#" * 80)
+    print("# LLAMA — EAGLE TARGET (inputs_embeds, no embedding)")
+    print("#" * 80)
+    gm_eagle = _export_as_eagle_target(model, hidden_size)
+    _run_diagnostics(gm_eagle, "[Llama eagle-target]")
+
+
+def test_mistral4_subgraphs():
+    """Diagnostic: Mistral4 3-layer as Eagle target (inputs_embeds, no embedding)."""
+    model, hidden_size = _build_mistral4_model(NUM_HIDDEN_LAYERS)
+
+    # Eagle target path (inputs_embeds, no embedding)
+    print("\n" + "#" * 80)
+    print("# MISTRAL4 — EAGLE TARGET (inputs_embeds, no embedding)")
+    print("#" * 80)
+    gm_eagle = _export_as_eagle_target(model, hidden_size)
+    _run_diagnostics(gm_eagle, "[Mistral4 eagle-target]")
+
+
+if __name__ == "__main__":
+    import sys
+
+    test_name = sys.argv[1] if len(sys.argv) > 1 else "all"
+    if test_name in ("llama", "all"):
+        test_llama_subgraphs()
+    if test_name in ("mistral4", "all"):
+        test_mistral4_subgraphs()