diff --git a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml index 3056ab839b1..7e5e967078e 100644 --- a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml +++ b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml @@ -7,6 +7,7 @@ max_batch_size: 32 cuda_graph_config: batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] enable_chunked_prefill: true +# Use AutoModelForCausalLM for text only mode until issue #12699 is resolved model_factory: Qwen3_5MoeForConditionalGeneration kv_cache_config: enable_block_reuse: false @@ -15,13 +16,18 @@ kv_cache_config: model_kwargs: torch_dtype: bfloat16 transforms: + # disable for text only use case initialize_mrope_delta_cache: enabled: true export_to_gm: num_moe_experts_for_export: 2 fuse_gemms_mixed_children: enabled: true + fuse_nvfp4_moe: + backend: trtllm_gen detect_sharding: + # for long input, tp8ep1 gives better performance + # dist_mapping: {moe_tp: 8, moe_ep: 1} allreduce_strategy: SYMM_MEM shard_all_unprocessed: true simple_shard_filter: "lm_head" @@ -37,6 +43,9 @@ transforms: "k_proj": "colwise" "v_proj": "colwise" "o_proj": "rowwise" + # lm_head: "gather" = column split + all_gather (not "colwise" which + # requires a LayerSubgraph and crashes for standalone unprocessed nodes) + "lm_head": "gather" # replicating shared experts (keep them commented out) # "shared_expert_gate_proj": "colwise" # "shared_expert_up_proj": "colwise" diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 85c87afe278..a4a149fcdea 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -110,15 +110,19 @@ def __init__( self, model: nn.Module, num_batched_inputs: Optional[int] = None, # number of batched, dynamic inputs... + dynamic_dims: Optional[List[int]] = None, ): super().__init__() self.model = model self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1 + self.dynamic_dims = dynamic_dims or [0] * self.num_batched_inputs + assert len(self.dynamic_dims) == self.num_batched_inputs self.cudagraphs: Dict[Tuple[int, ...], CUDAGraph] = {} self._input_buffers: List[torch.Tensor] = [ torch.empty(0, 1) for _ in range(self.num_batched_inputs) ] self._out_buffer_flat: List[torch.Tensor] = None + self._output_dynamic_dim: int = 0 self._args_hash: Optional[Tuple[int, ...]] = None self._cuda_graph_mem_pool = None @@ -139,13 +143,14 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph: # capture graph now torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() + od = self._output_dynamic_dim with torch.cuda.graph(graph, pool=self._cuda_graph_mem_pool): # compute output out = self.model(*args, **kwargs) # write out into output buffer up to out batch size out_flat = tree_flatten_spec(out, self._out_spec) for o_buffer, o in zip(self._out_buffer_flat, out_flat): - o_buffer[: o.shape[0]] = o + o_buffer.narrow(od, 0, o.shape[od]).copy_(o) torch.cuda.synchronize() self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool() return graph @@ -167,6 +172,26 @@ def capture_graph(self, get_args_kwargs: GetArgsKwargsForBatchSize, batch_sizes: args_batched = all_args_flat[: self.num_batched_inputs] args_static = all_args_flat[self.num_batched_inputs :] + # Auto-detect dynamic dims: compare two different batch sizes to find + # which dim changes per batched input. Probe below max to stay within + # buffer capacity (e.g. cu_seqlen is sized for max_batch_size + 1). + probe_bs = max(1, batch_sizes[0] - 1) + args2, kwargs2 = get_args_kwargs(probe_bs) + flat2 = _args_kwargs_flatten_spec(self._in_spec, *args2, **kwargs2) + batched2 = flat2[: self.num_batched_inputs] + detected_dims = [] + for t1, t2 in zip(args_batched, batched2): + dim_found = 0 + for d in range(t1.ndim): + if t1.shape[d] != t2.shape[d]: + dim_found = d + break + detected_dims.append(dim_found) + self.dynamic_dims = detected_dims + + # Detect output dynamic dim from the first batched input's dynamic dim + self._output_dynamic_dim = self.dynamic_dims[0] + # set the args hash --> this is used to compare the static inputs during graph replay self._args_hash = self._get_hash(args_static) @@ -198,14 +223,21 @@ def capture_graph(self, get_args_kwargs: GetArgsKwargsForBatchSize, batch_sizes: "Static args mismatch during capture" ) - # copy new inputs to input buffers + # copy new inputs to input buffers along their respective dynamic dims + input_sizes: List[int] = [] for i, input_tensor in enumerate(args_batched): - self._input_buffers[i][: input_tensor.shape[0]].copy_( + dim_i = self.dynamic_dims[i] + size_i = input_tensor.shape[dim_i] + input_sizes.append(size_i) + self._input_buffers[i].narrow(dim_i, 0, size_i).copy_( input_tensor, non_blocking=True ) - # setup args, kwargs - inputs_truncated = [in_buffer[:bs] for in_buffer in self._input_buffers] + # truncate input buffers along their respective dynamic dims + inputs_truncated = [ + buf.narrow(self.dynamic_dims[i], 0, input_sizes[i]) + for i, buf in enumerate(self._input_buffers) + ] args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static) # capture graph for truncated inputs @@ -232,16 +264,19 @@ def forward(self, *args, **kwargs) -> Any: if combined_shape not in self.cudagraphs: return self.model(*args, **kwargs) - # copy inputs to input buffers + # copy inputs to input buffers along their respective dynamic dims for i, input_tensor in enumerate(args_batched): - self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True) + dim_i = self.dynamic_dims[i] + size_i = input_tensor.shape[dim_i] + self._input_buffers[i].narrow(dim_i, 0, size_i).copy_(input_tensor, non_blocking=True) # run forward pass via graph self.cudagraphs[combined_shape].replay() # retrieve output from buffer, cut to batch size, and unflatten - bs = args_batched[0].shape[0] - out_flat = [o_b[:bs] for o_b in self._out_buffer_flat] + od = self._output_dynamic_dim + bs = args_batched[0].shape[self.dynamic_dims[0]] + out_flat = [o_b.narrow(od, 0, bs) for o_b in self._out_buffer_flat] return self._out_spec.unflatten(out_flat) @@ -263,6 +298,7 @@ def __init__( piecewise_num_tokens: Optional[List[int]] = None, capture_lm_head: bool = False, max_batch_size: Optional[int] = None, + out_spec: Optional[TreeSpec] = None, ): super().__init__() self.original_model = model @@ -273,6 +309,14 @@ def __init__( self.split_gm: Optional[GraphModule] = None self._is_prepared = False self._wrapped_dynamic_indices: Set[int] = set() + # Pre-allocated static buffers for kwargs whose addresses change between + # calls. Allocated during warmup_and_capture, used at runtime to ensure + # CUDA graph replay sees stable addresses. + # Format: {kwarg_name: (static_buffer, dynamic_dim_or_none)} + self._static_input_buffers: Dict[str, Tuple[torch.Tensor, Optional[int]]] = {} + # Output tree spec for reconstructing structured outputs (e.g. + # ModelOutput) from flat tuples returned by split_gm. + self._out_spec = out_spec def prepare(self) -> None: """Split the model, wrap static segments in runners, wrap Group 3 dynamic ops.""" @@ -393,19 +437,14 @@ def prepare(self) -> None: num_wrapped_dynamic += 1 self._is_prepared = True + num_dynamic_eager = ( + len(self.split_info.dynamic_submod_indices) - num_wrapped_dynamic - num_metadata_wrapped + ) ad_logger.info( - "PiecewiseCapturedGraph: prepared with %d submodules " - "(%d static runners, %d trivial skipped, %d dynamic wrapped, " - "%d metadata wrapped, %d dynamic eager), piecewise_num_tokens=%s", - self.split_info.num_submodules, - num_wrapped_static, - num_skipped_static, - num_wrapped_dynamic, - num_metadata_wrapped, - len(self.split_info.dynamic_submod_indices) - - num_wrapped_dynamic - - num_metadata_wrapped, - self.piecewise_num_tokens, + f"PiecewiseCapturedGraph: prepared with {self.split_info.num_submodules} submodules " + f"({num_wrapped_static} static runners, {num_skipped_static} trivial skipped, " + f"{num_wrapped_dynamic} dynamic wrapped, {num_metadata_wrapped} metadata wrapped, " + f"{num_dynamic_eager} dynamic eager), piecewise_num_tokens={self.piecewise_num_tokens}" ) def _discover_dynamic_output_shapes(self, args: Tuple, kwargs: Dict) -> Dict[int, OutputInfo]: @@ -458,6 +497,70 @@ def _set_dynamic_out_info_on_runners(self, discovered: Dict[int, OutputInfo]) -> f"Cannot pre-allocate out= buffers — downstream static runners require stable addresses." ) + def _allocate_static_input_buffers( + self, + get_args_kwargs: Callable[[int], Any], + ) -> None: + """Allocate static buffers for kwargs whose addresses change between calls. + + Calls `get_args_kwargs` twice with the largest bucket to check address + stability (data_ptr), and once with a different size to detect the + dynamic dimension by shape comparison. Any kwarg with unstable + addresses gets a pre-allocated static buffer. + """ + max_bucket = max(self.piecewise_num_tokens) + _, kw1 = get_args_kwargs(max_bucket) + _, kw2 = get_args_kwargs(max_bucket) + _, kw_probe = get_args_kwargs(max(1, max_bucket - 1)) + + for key in kw1: + v1, v2 = kw1.get(key), kw2.get(key) + if ( + isinstance(v1, torch.Tensor) + and isinstance(v2, torch.Tensor) + and v1.data_ptr() != v2.data_ptr() + ): + v_probe = kw_probe.get(key) + dyn_dim = None + if isinstance(v_probe, torch.Tensor): + for d in range(v1.ndim): + if v1.shape[d] != v_probe.shape[d]: + dyn_dim = d + break + if dyn_dim is not None or ( + isinstance(v_probe, torch.Tensor) and v_probe.shape == v1.shape + ): + # Static-shape kwargs still need buffering when their addresses + # change across calls. In that case dyn_dim stays None and we + # copy the full buffer at runtime. + self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim) + else: + ad_logger.warning( + "PiecewiseCapturedGraph: kwarg '%s' has unstable address but " + "no dynamic dim found; leaving it unbuffered", + key, + ) + + if self._static_input_buffers: + ad_logger.info( + "PiecewiseCapturedGraph: allocated %d static input buffer(s): %s", + len(self._static_input_buffers), + {k: (v[0].shape, f"dyn_dim={v[1]}") for k, v in self._static_input_buffers.items()}, + ) + + def _copy_to_static_buffers(self, kwargs: Dict[str, Any]) -> None: + """Copy kwargs into pre-allocated static buffers for address stability.""" + for key, (buf, dyn_dim) in self._static_input_buffers.items(): + src = kwargs.get(key) + if src is not None and isinstance(src, torch.Tensor): + if dyn_dim is None: + buf.copy_(src) + kwargs[key] = buf + else: + buf_view = buf.narrow(dyn_dim, 0, src.shape[dyn_dim]) + buf_view.copy_(src) + kwargs[key] = buf_view + def warmup_and_capture( self, get_args_kwargs: Callable[[int], Any], @@ -474,6 +577,10 @@ def warmup_and_capture( 4. Capture: run split_gm once more; runners capture CUDA graphs and allocate dynamic output buffers inside torch.cuda.graph(). 5. Cleanup: gc.collect() + empty_cache() between buckets. + + Before the per-bucket loop, calls get_args_kwargs twice to detect any + kwargs whose tensor addresses are unstable. Those kwargs are copied + into static buffers for both capture and runtime replay. """ if not self._is_prepared: self.prepare() @@ -481,10 +588,13 @@ def warmup_and_capture( if self.split_gm is None: return + self._allocate_static_input_buffers(get_args_kwargs) + num_tokens_list = sorted(self.piecewise_num_tokens, reverse=True) for nt in num_tokens_list: - ad_logger.info("PiecewiseCapturedGraph: warming up for num_tokens=%d", nt) + ad_logger.info(f"PiecewiseCapturedGraph: warming up for num_tokens={nt}") args, kwargs = get_args_kwargs(nt) + self._copy_to_static_buffers(kwargs) ADPiecewiseRunner.set_current_num_tokens(nt) @@ -504,7 +614,7 @@ def warmup_and_capture( ADPiecewiseRunner.set_current_phase("capture") self.split_gm(*args, **kwargs) - ad_logger.info("PiecewiseCapturedGraph: captured graphs for num_tokens=%d", nt) + ad_logger.info(f"PiecewiseCapturedGraph: captured graphs for num_tokens={nt}") torch.cuda.synchronize() gc.collect() @@ -513,11 +623,30 @@ def warmup_and_capture( ADPiecewiseRunner.set_current_num_tokens(None) ADPiecewiseRunner.set_current_phase("replay") + def _reconstruct_output(self, result: Any) -> Any: + """Reconstruct structured output from a flat tuple using the output tree spec.""" + if not isinstance(result, tuple) or self._out_spec is None: + return result + try: + return self._out_spec.unflatten(list(result)) + except Exception as e: + ad_logger.warning( + "PiecewiseCapturedGraph._reconstruct_output: failed to unflatten output " + "(%s); returning raw tuple", + e, + ) + return result + def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any: """Forward pass: static segments replay graphs, dynamic segments run eagerly.""" if self.split_gm is not None: + self._copy_to_static_buffers(kwargs) ADPiecewiseRunner.set_current_num_tokens(num_tokens) - return self.split_gm(*args, **kwargs) + try: + result = self.split_gm(*args, **kwargs) + finally: + ADPiecewiseRunner.set_current_num_tokens(None) + return self._reconstruct_output(result) return self.original_model(*args, **kwargs) @@ -555,6 +684,17 @@ def __init__( # Sorted list of pre-captured bucket sizes for nearest-bucket lookup self._captured_num_tokens_sorted: List[int] = sorted(piecewise.piecewise_num_tokens) + def __getattr__(self, name: str): + """Proxy attribute lookups to the underlying model. + + When this module replaces an inner submodule, the parent may access + methods like get_input_embeddings() that live on the original model. + """ + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.monolithic.model, name) + def _is_decode_only(self, **kwargs) -> bool: """Check if the current batch is decode-only using batch_info_host. @@ -578,15 +718,16 @@ def _is_decode_only(self, **kwargs) -> bool: return True def _get_num_tokens(self, **kwargs) -> int: - """Extract total num_tokens from the batched inputs. - - For prefill/mixed with flattened layout: input_ids shape = [1, total_num_tokens] - We use numel() which works for both [1, N] and [N] layouts. - """ + """Extract total num_tokens from batch_info_host or batched inputs.""" + batch_info = kwargs.get(self.batch_info_kwarg_name) + if batch_info is not None and isinstance(batch_info, torch.Tensor): + # batch_info_host layout: [0]=num_prefill, [1]=num_prefill_tokens, + # [2]=num_extend, [3]=num_extend_tokens, [4]=num_decode, [5]=num_decode_tokens + return int((batch_info[1] + batch_info[3] + batch_info[5]).item()) for name in self.batched_input_names: v = kwargs.get(name) - if v is not None and isinstance(v, torch.Tensor): - return v.numel() + if v is not None and isinstance(v, torch.Tensor) and v.ndim >= 1: + return int(v.numel()) return 0 def _find_nearest_bucket(self, num_tokens: int) -> Optional[int]: @@ -596,6 +737,46 @@ def _find_nearest_bucket(self, num_tokens: int) -> Optional[int]: return bucket return None + def _truncate_output(self, result: Any, num_tokens: int, bucket: int) -> Any: + """Slice padded outputs from bucket size to real num_tokens. + + Finds the token dimension by looking for the dim whose size equals + the bucket size, then narrows it to num_tokens. + """ + output_dynamic_dim = getattr(self.monolithic, "_output_dynamic_dim", None) + + def _narrow(v): + if not isinstance(v, torch.Tensor): + return v + if ( + isinstance(output_dynamic_dim, int) + and 0 <= output_dynamic_dim < v.ndim + and v.shape[output_dynamic_dim] == bucket + ): + return v.narrow(output_dynamic_dim, 0, num_tokens) + matching_dims = [d for d in range(v.ndim) if v.shape[d] == bucket] + if not matching_dims: + return v + if len(matching_dims) > 1: + ad_logger.warning( + "DualModeCapturedGraph._truncate_output: ambiguous token dim for shape %s, " + "bucket=%d; falling back to dim %d", + tuple(v.shape), + bucket, + matching_dims[0], + ) + return v.narrow(matching_dims[0], 0, num_tokens) + + if isinstance(result, torch.Tensor): + return _narrow(result) + if hasattr(result, "to_tuple"): + sliced = {k: _narrow(v) for k, v in result.items()} + return type(result)(**sliced) + elif isinstance(result, abc.Mapping): + return {k: _narrow(v) for k, v in result.items()} + else: + return tuple(_narrow(r) for r in result) + def forward(self, *args, **kwargs) -> Any: # NOTE: AD calls model(**named_args) so everything is in kwargs, args is empty if self._is_decode_only(**kwargs): @@ -606,18 +787,12 @@ def forward(self, *args, **kwargs) -> Any: num_tokens = self._get_num_tokens(**kwargs) bucket = self._find_nearest_bucket(num_tokens) if bucket is not None: - result = self.piecewise(*args, num_tokens=bucket, **kwargs) - ADPiecewiseRunner.set_current_num_tokens(None) + try: + result = self.piecewise(*args, num_tokens=bucket, **kwargs) + finally: + ADPiecewiseRunner.set_current_num_tokens(None) if bucket > num_tokens: - # HF ModelOutput iterates over field names (e.g. "logits"), not - # tensor values. Normalize to the payload tuple before slicing. - if hasattr(result, "to_tuple"): - result = result.to_tuple() - elif isinstance(result, abc.Mapping): - result = tuple(result.values()) - else: - result = tuple(result) - result = tuple(r[:, :num_tokens] if r.ndim >= 2 else r for r in result) + result = self._truncate_output(result, num_tokens, bucket) return result # No bucket large enough -- eager fallback @@ -687,6 +862,26 @@ def _setup_piecewise_mixed_batch(seq_info: Any, num_tokens: int) -> None: ) +def _capture_inner_kwargs( + full_model: nn.Module, + inner_module: nn.Module, + top_level_kwargs: Dict[str, Any], +) -> Dict[str, Any]: + """Run full model once and intercept kwargs passed to the inner module.""" + captured: Dict[str, Any] = {} + + def hook(module, args, kwargs): + captured.update(kwargs) + return args, kwargs + + handle = inner_module.register_forward_pre_hook(hook, with_kwargs=True) + try: + full_model(**top_level_kwargs) + finally: + handle.remove() + return captured + + @CompileBackendRegistry.register("torch-cudagraph") class TorchCudagraphCompiler(CompilerBackend): """Compiler that uses CUDA graphs. @@ -694,6 +889,10 @@ class TorchCudagraphCompiler(CompilerBackend): Supports two modes: - piecewise_enabled=False (default): monolithic CG only (decode-only batches) - piecewise_enabled=True: dual-mode (monolithic for decode + piecewise for prefill/mixed) + + When the top-level model is a wrapper (not a GraphModule), the compiler + auto-discovers the inner GraphModule (e.g. text model) and compiles it. + The wrapper (e.g. vision tower, embed merge) runs eagerly. """ def __init__( @@ -706,6 +905,7 @@ def __init__( piecewise_num_tokens: Optional[List[int]] = None, piecewise_seq_info: Any = None, piecewise_named_args_fn: Optional[Callable[[], Dict[str, Any]]] = None, + full_model: Optional[nn.Module] = None, **kwargs_for_init, ): super().__init__(*args_for_init, **kwargs_for_init) @@ -716,6 +916,22 @@ def __init__( self.piecewise_num_tokens = piecewise_num_tokens or [] self.piecewise_seq_info = piecewise_seq_info self.piecewise_named_args_fn = piecewise_named_args_fn + self.full_model = full_model + + def _get_inner_args_kwargs_fn(self, inner_gm: GraphModule) -> GetArgsKwargsForBatchSize: + """Return a function that generates inner-model args for a given batch size. + + Runs the full model with top-level args and captures the kwargs that the + wrapper passes to the inner GraphModule via a forward pre-hook. + """ + assert self.full_model is not None + + def get_inner_args(batch_size: int): + _, top_level_kwargs = self.get_args_kwargs_for_compile(batch_size) + inner_kwargs = _capture_inner_kwargs(self.full_model, inner_gm, top_level_kwargs) + return (), inner_kwargs + + return get_inner_args @torch.inference_mode() def compile(self) -> nn.Module: @@ -723,27 +939,35 @@ def compile(self) -> nn.Module: "get_args_kwargs_for_compile must be provided" ) - # wrap get_args_kwargs_for_compile with CudaGraphWarmUpPhase. Note that host-side prepare - # functions may be called as part of get_args_kwargs. We want to let these functions know it's - # a warm-up phase. - def get_args_kwargs_warmup(batch_size: int): + # Build args functions once — unified for both wrapper and direct GM cases. + # self.model is always the target GM (compile_model.py extracts it). + target_gm = self.model + + if self.full_model is not None: + ad_logger.info("TorchCudagraphCompiler: wrapper detected, compiling inner GraphModule") + get_capture_args_fn = self._get_inner_args_kwargs_fn(target_gm) + else: + get_capture_args_fn = self.get_args_kwargs_for_compile + + def get_capture_args_with_warmup(batch_size: int): with CudaGraphWarmUpPhase(): - return self.get_args_kwargs_for_compile(batch_size) + return get_capture_args_fn(batch_size) - monolithic = CapturedGraph(self.model, num_batched_inputs=self.num_batched_inputs) - monolithic.capture_graph(get_args_kwargs_warmup, self.cuda_graph_batch_sizes) + monolithic = CapturedGraph(target_gm, num_batched_inputs=self.num_batched_inputs) + monolithic.capture_graph(get_capture_args_with_warmup, self.cuda_graph_batch_sizes) piecewise = None if self.piecewise_enabled: ad_logger.info("TorchCudagraphCompiler: dual-mode enabled (monolithic + piecewise)") piecewise = PiecewiseCapturedGraph( - model=self.model, + model=target_gm, piecewise_num_tokens=self.piecewise_num_tokens, max_batch_size=( self.piecewise_seq_info.max_batch_size if self.piecewise_seq_info is not None else None ), + out_spec=monolithic._out_spec, ) piecewise.prepare() @@ -753,11 +977,16 @@ def get_args_kwargs_warmup(batch_size: int): and self.piecewise_num_tokens ): - def get_mixed_args_kwargs(num_tokens: int): + def get_piecewise_args(num_tokens: int): _setup_piecewise_mixed_batch(self.piecewise_seq_info, num_tokens) - return (), self.piecewise_named_args_fn() + top_level_kwargs = self.piecewise_named_args_fn() + if self.full_model is not None: + return (), _capture_inner_kwargs( + self.full_model, target_gm, top_level_kwargs + ) + return (), top_level_kwargs - piecewise.warmup_and_capture(get_mixed_args_kwargs) + piecewise.warmup_and_capture(get_piecewise_args) if piecewise is not None: return DualModeCapturedGraph(monolithic, piecewise) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py index e227bc7ebec..ff03589b2f3 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py @@ -730,7 +730,11 @@ class Qwen3_5MoeCausalLMOutput(ModelOutput): class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel): - """Qwen3.5 MoE text model (embed + decoder layers + final norm).""" + """Qwen3.5 MoE text model (embed + decoder layers + final norm + lm_head). + + lm_head is included so that the exported GraphModule contains it directly, + allowing sharding and gather_logits_before_lm_head transforms to see it. + """ def __init__(self, config: Qwen3_5MoeTextConfig): super().__init__(config) @@ -746,10 +750,15 @@ def __init__(self, config: Qwen3_5MoeTextConfig): ) self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config) + self.lm_head = None # set by parent model via set_lm_head() # Initialize weights and apply final processing self.post_init() + def set_lm_head(self, lm_head: nn.Module): + """Set the lm_head from the parent model.""" + self.lm_head = lm_head + def get_input_embeddings(self): return self.embed_tokens @@ -801,7 +810,11 @@ def forward( hidden_states = decoder_layer(hidden_states, position_embeddings=position_embeddings) hidden_states = self.norm(hidden_states) - return Qwen3_5MoeOutput(last_hidden_state=hidden_states) + assert self.lm_head is not None, ( + "lm_head not set — call set_lm_head() from the parent model before forward()" + ) + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + return Qwen3_5MoeCausalLMOutput(logits=logits) class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin): @@ -814,6 +827,7 @@ def __init__(self, config: Qwen3_5MoeTextConfig, **kwargs): self.model = Qwen3_5MoeTextModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.model.set_lm_head(self.lm_head) # Initialize weights and apply final processing self.post_init() @@ -829,6 +843,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + self.model.set_lm_head(new_embeddings) def forward( self, @@ -848,8 +863,7 @@ def forward( rope_cos=rope_cos, rope_sin=rope_sin, ) - hidden_states = outputs[0] - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = outputs.logits return Qwen3_5MoeCausalLMOutput(logits=logits) @@ -2565,10 +2579,19 @@ def __init__(self, config: Qwen3_5MoeConfig, **kwargs): self.lm_head = nn.Linear( config.text_config.hidden_size, config.text_config.vocab_size, bias=False ) + # Share lm_head with the text model so it's inside the exported graph + self.model.language_model.set_lm_head(self.lm_head) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + self.model.language_model.set_lm_head(new_embeddings) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -2590,8 +2613,7 @@ def forward( video_grid_thw=video_grid_thw, **kwargs, ) - hidden_states = outputs.last_hidden_state - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = outputs.logits return Qwen3_5MoeConditionalOutput(logits=logits) @@ -2607,6 +2629,9 @@ class Qwen3_5MoeTextExportInfo(TextModelExportInfo): (batch, sequence) are dynamic. """ + def __init__(self, submodule_name: str): + super().__init__(submodule_name) + def _init_dynamic_shape_lookup(self): base = super()._init_dynamic_shape_lookup() batch_size_dyn = Dim.DYNAMIC @@ -2858,4 +2883,7 @@ def init_input_processor(self, base): AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeTextConfig) AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3_5MoeTextConfig", Qwen3_5MoeForCausalLM) +AutoModelForCausalLMFactory.register_custom_model_cls( + "Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration +) Qwen3_5MoeFactory.register_custom_model_cls("Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py index 62cc2003582..aa4b24f507d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -2,6 +2,7 @@ import torch.nn as nn from pydantic import Field +from torch.fx import GraphModule from ...compile import ArgsKwargs, CompileBackendRegistry from ...models.factory import ModelFactory @@ -16,6 +17,15 @@ ) +def _set_submodule(model: nn.Module, key: str, new_module: nn.Module) -> None: + """Replace a nested submodule given a dotted key path (e.g. 'model.language_model').""" + parts = key.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) + + def _generate_default_piecewise_num_tokens(max_num_tokens: int) -> List[int]: """Generate default piecewise bucket sizes when none are specified. @@ -138,13 +148,41 @@ def _get_args_kwargs(bs: int) -> ArgsKwargs: config_dict = self.config.model_dump() config_dict.update(config_overrides) - compiler_backend = CompileBackendRegistry.get(self.config.backend)( - mod, - get_args_kwargs_for_compile=_get_args_kwargs, - **extra_kwargs, - **config_dict, - ) - mod_compiled = compiler_backend.compile() + # Walk the module tree and collect the top-level GraphModules to compile. + # Once a GM is found, its children are skipped (they're part of the GM). + compile_targets = [] + seen = set() + if isinstance(mod, GraphModule): + compile_targets.append(("", mod)) + seen.add("") + for name, submod in mod.named_modules(): + if any(p == "" or name.startswith(p + ".") for p in seen): + continue + if isinstance(submod, GraphModule): + compile_targets.append((name, submod)) + seen.add(name) + + if compile_targets: + ad_logger.info( + f"CompileModel: compiling {len(compile_targets)} GraphModule(s): " + f"{[name or '(root)' for name, _ in compile_targets]}" + ) + + for gm_key, gm in compile_targets: + full_model = mod if gm_key else None + compiler_backend = CompileBackendRegistry.get(self.config.backend)( + gm, + get_args_kwargs_for_compile=_get_args_kwargs, + full_model=full_model, + **extra_kwargs, + **config_dict, + ) + compiled_gm = compiler_backend.compile() + if gm_key: + _set_submodule(mod, gm_key, compiled_gm) + else: + mod = compiled_gm + mod_compiled = mod # store info object about the transform info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index a3cf9661ee3..633d0500e9f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -405,6 +405,10 @@ def build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple: return ([scales["input_scale"]], [scales["weight_scale"], scales["alpha"]], [], []) def load_hook(self, state_dict, prefix, *args, weight_name): + # Prepend prefix so the hook works when the GraphModule is a submodule + # of the model on which load_state_dict is called (e.g., VLM models + # where the text model lives at model.language_model.*). + weight_name = prefix + weight_name if weight_name in state_dict: input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale" alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index fff763b437a..7dcb129d909 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -637,7 +637,8 @@ def shard_load_hook( world_size: int, min_local_shape: int = 1, ) -> None: - scale_key = weight_name + "_scale_inv" + # Prepend prefix for VLM models where gm is a submodule + scale_key = prefix + weight_name + "_scale_inv" if scale_key in state_dict: scale = state_dict[scale_key] state_dict[scale_key] = self._split_scale(scale, dim, rank, world_size) @@ -720,7 +721,8 @@ def shard_load_hook( world_size: int, min_local_shape: int = 1, ) -> None: - key = weight_name + "_scale" + # Prepend prefix for VLM models where gm is a submodule + key = prefix + weight_name + "_scale" if key in state_dict: state_dict[key] = _shard_fp4_weight_scale( state_dict[key], @@ -1568,19 +1570,23 @@ def f_split( sharded_weight = f_split(weight_tensor) sharded_shape = sharded_weight.shape - # Register load hook - gm._register_load_state_dict_pre_hook( + # Update the parameter in the module + modname, _, param_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + + # Register load hook on the owning submodule (not the top-level gm). + # This ensures the hook runs *after* any parent-level hooks that transform + # the state_dict (e.g., unfusing fused MoE checkpoint weights into + # individual expert keys). With the hook on gm, it would run before + # unfusing and fail to find the individual expert keys. + submod._register_load_state_dict_pre_hook( partial( _load_hook, f_split=f_split, - param_key=param_key, + param_key=param_name, param_shape=sharded_shape, ) ) - - # Update the parameter in the module - modname, _, param_name = param_key.rpartition(".") - submod = gm.get_submodule(modname) param_new = nn.Parameter(sharded_weight.detach().clone(), requires_grad=requires_grad) setattr(submod, param_name, param_new) @@ -1747,6 +1753,84 @@ def _merge_arg(current_arg: Any, stored_arg: Any) -> Any: ad_logger.debug(f"Updated node {node}: sharded arguments are now {node.args}.") +def _shard_nvfp4_moe_scale( + scale: torch.Tensor, + orig_weight_shape: torch.Size, + dim: int, + rank: int, + world_size: int, +) -> torch.Tensor: + """Shard NVFP4 weight_scale for MoE TP, preserving 2D cutlass format. + + Unlike _shard_fp4_weight_scale (which returns 1D), this returns a 2D tensor + with the correct padded shape, matching the format expected by MoE stacking. + """ + weight_shape_elements = list(orig_weight_shape) + weight_shape_elements[-1] *= 2 # uint8 -> element count (FP4 packs 2 per byte) + modelopt_scale = cutlass_fp4_scale_to_modelopt_fp4_scale(scale, tuple(weight_shape_elements)) + sharded = _split_tensor_for_tp(modelopt_scale, dim, rank, world_size) + m, n = sharded.shape + # Pad to match CUTLASS FP4 scale swizzle alignment requirements: + # 128 rows (4 * 32 tile in M dim) and 4 columns (N dim grouping). + # See modelopt_fp4_scale_to_cutlass_fp4_scale in quantization_utils.py. + pad_m = (128 - m % 128) % 128 + pad_n = (4 - n % 4) % 4 + result_1d = modelopt_fp4_scale_to_cutlass_fp4_scale(sharded) + return result_1d.reshape(m + pad_m, n + pad_n) + + +def _tp_shard_moe_scale( + gm: GraphModule, + scale_node: Node, + scale_name: str, + dim: int, + rank: int, + world_size: int, + orig_weight_shape: torch.Size, +) -> None: + """TP-shard a single MoE expert's blocked scale tensor. + + For NVFP4 (weight_scale): converts from cutlass format, splits, reconverts to 2D. + For FineGrained FP8 (weight_scale_inv): directly splits the 2D scale tensor. + """ + param_key = scale_node.target + modname, _, attr_name = param_key.rpartition(".") + submod = gm.get_submodule(modname) + scale_tensor = submod.get_buffer(attr_name) + + if scale_name == "weight_scale": + f_split = partial( + _shard_nvfp4_moe_scale, + orig_weight_shape=orig_weight_shape, + dim=dim, + rank=rank, + world_size=world_size, + ) + elif scale_name == "weight_scale_inv": + f_split = partial( + FineGrainedFP8WeightShardingInfo._split_scale, + dim=dim, + rank=rank, + world_size=world_size, + ) + else: + return + + sharded_scale = f_split(scale_tensor) + submod.register_buffer(attr_name, sharded_scale) + + # Register load hook on the owning submodule so it runs after any + # parent-level checkpoint format conversion hooks (e.g., fused MoE unfusing). + submod._register_load_state_dict_pre_hook( + partial( + _load_hook, + f_split=f_split, + param_key=attr_name, + param_shape=sharded_scale.shape, + ) + ) + + def _insert_sharded_moe( gm: GraphModule, node: Node, @@ -1807,6 +1891,11 @@ def get_partition(lst, world_size, rank): # if tp_size > 1, we do 2D EP+TP sharding. if tp_size > 1: + # Capture original weight shapes before TP sharding (needed for scale TP sharding) + w_up_orig_shapes = [gm.get_parameter(w.target).shape for w in w_up_list_sharded] + w_down_orig_shapes = [gm.get_parameter(w.target).shape for w in w_down_list_sharded] + w_gate_orig_shapes = [gm.get_parameter(w.target).shape for w in w_gate_list_sharded] + # we add TP sharding of all expert weights. for w_up in w_up_list_sharded + w_gate_list_sharded: shard_weight_tensor( @@ -1843,6 +1932,27 @@ def get_partition(lst, world_size, rank): args[6 + i] = sharded scales_to_remove.extend(to_remove) + # ===================================================================================== + # TP-shard blocked scales (weight_scale for NVFP4, weight_scale_inv for FineGrained FP8) + # ===================================================================================== + if tp_size > 1 and scale_names: + _BLOCKED_SCALE_NAMES = {"weight_scale", "weight_scale_inv"} + for s_idx, s_name in enumerate(scale_names): + if s_name not in _BLOCKED_SCALE_NAMES: + continue + # For each scale_name, the 3 lists correspond to w_up, w_down, w_gate + # w_up/w_gate use COLUMN split (dim=0), w_down uses ROW split (dim=1) + scale_dim_groups = [ + (6 + s_idx * 3 + 0, SplitDimension.COLUMN, w_up_orig_shapes), + (6 + s_idx * 3 + 1, SplitDimension.ROW, w_down_orig_shapes), + (6 + s_idx * 3 + 2, SplitDimension.COLUMN, w_gate_orig_shapes), + ] + for arg_idx, dim, orig_shapes in scale_dim_groups: + for j, scale_node in enumerate(args[arg_idx]): + _tp_shard_moe_scale( + gm, scale_node, s_name, dim, tp_rank, tp_size, orig_shapes[j] + ) + if enable_alltoall: # --------------------------------------------------------------------------- # ALL-TO-ALL PARADIGM diff --git a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py index d54c86d0c37..b57b2d805ff 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/_graph.py @@ -571,6 +571,16 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod if is_op(lm_head_node, torch.ops.aten.to): lm_head_node = lm_head_node.all_input_nodes[0] + # Unwrap all_gather for sharded lm_head: when lm_head weight is column- + # sharded the graph contains lm_head_linear -> all_gather -> output. + # We look through the all_gather so that callers (e.g. + # gather_logits_before_lm_head) see the underlying linear and can insert + # gather_tokens *before* the sharded GEMM + all_gather, keeping both out + # of the main CUDA graph and avoiding NVLink contention with layer + # AllReduces. + if is_op(lm_head_node, torch.ops.auto_deploy.trtllm_dist_all_gather): + lm_head_node = lm_head_node.all_input_nodes[0] + return lm_head_node diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py index ce417ba0252..b81ef759f4b 100644 --- a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py @@ -11,7 +11,8 @@ import torch.nn.functional as F from _dist_test_utils import get_device_counts from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm -from _model_test_utils import FakeFineGrainedFP8Linear, FakeFP8Linear +from _model_test_utils import FakeFineGrainedFP8Linear, FakeFP8Linear, MoEOpModel +from _torch_test_utils import fp4_compatible, trtllm_ops_available from torch._inductor.pattern_matcher import stable_topological_sort import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common @@ -31,6 +32,7 @@ from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op, is_weight_node from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import ( cutlass_fp4_scale_to_modelopt_fp4_scale, + fp4_global_scale, modelopt_fp4_scale_to_cutlass_fp4_scale, ) from tensorrt_llm.functional import AllReduceStrategy @@ -1299,3 +1301,244 @@ def test_pad_nvfp4_weight_scale_roundtrip(n, k): if k_padded > k: pad_region_k = padded_modelopt[:n, k // block_size :] assert (pad_region_k.float() == 0).all(), "k-padding region should be zero" + + +class NVFP4MoEOpModel(nn.Module): + """NVFP4-quantized MoE model using torch.ops.trtllm.fp4_quantize. + + Mimics the real Qwen3.5-MoE NVFP4 checkpoint loading path where weights are + quantized via fp4_quantize and scales are in cutlass swizzled 2D format. + Dimensions must be compatible with NVFP4 block size (16) and cutlass alignment. + """ + + SCALING_VECTOR_SIZE = 16 + + def __init__(self, hidden_size=128, intermediate_size=256, num_experts=4, top_k=2): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts = num_experts + self.top_k = top_k + + self.gate = nn.Linear(hidden_size, num_experts) + + for i in range(num_experts): + w1_bf16 = (torch.randn(intermediate_size, hidden_size) * 0.1).to(torch.bfloat16) + w2_bf16 = (torch.randn(hidden_size, intermediate_size) * 0.1).to(torch.bfloat16) + w3_bf16 = (torch.randn(intermediate_size, hidden_size) * 0.1).to(torch.bfloat16) + + inp_scale = fp4_global_scale(w1_bf16) + + for prefix, w_bf16 in [("w1", w1_bf16), ("w2", w2_bf16), ("w3", w3_bf16)]: + wt_scale_2 = fp4_global_scale(w_bf16) + w_fp4, w_scale_1d = torch.ops.trtllm.fp4_quantize( + w_bf16.cuda(), wt_scale_2.cuda(), self.SCALING_VECTOR_SIZE, False + ) + _, k_packed = w_fp4.shape + k_elements = k_packed * 2 # uint8 packs 2 fp4 values + n_scale = k_elements // self.SCALING_VECTOR_SIZE + m_scale = w_scale_1d.numel() // n_scale + w_scale_2d = w_scale_1d.reshape(m_scale, n_scale).contiguous() + + self.register_parameter( + f"expert_{i}_{prefix}", + nn.Parameter(w_fp4.cpu(), requires_grad=False), + ) + self.register_buffer(f"expert_{i}_{prefix}_input_scale", inp_scale) + self.register_buffer(f"expert_{i}_{prefix}_weight_scale", w_scale_2d.cpu()) + self.register_buffer( + f"expert_{i}_{prefix}_alpha", + (1.0 / (inp_scale * wt_scale_2)).cpu(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + router_logits = self.gate(x) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(x.dtype) + + w1 = [getattr(self, f"expert_{i}_w1") for i in range(self.num_experts)] + w2 = [getattr(self, f"expert_{i}_w2") for i in range(self.num_experts)] + w3 = [getattr(self, f"expert_{i}_w3") for i in range(self.num_experts)] + w1_is = [getattr(self, f"expert_{i}_w1_input_scale") for i in range(self.num_experts)] + w2_is = [getattr(self, f"expert_{i}_w2_input_scale") for i in range(self.num_experts)] + w3_is = [getattr(self, f"expert_{i}_w3_input_scale") for i in range(self.num_experts)] + w1_ws = [getattr(self, f"expert_{i}_w1_weight_scale") for i in range(self.num_experts)] + w2_ws = [getattr(self, f"expert_{i}_w2_weight_scale") for i in range(self.num_experts)] + w3_ws = [getattr(self, f"expert_{i}_w3_weight_scale") for i in range(self.num_experts)] + w1_a = [getattr(self, f"expert_{i}_w1_alpha") for i in range(self.num_experts)] + w2_a = [getattr(self, f"expert_{i}_w2_alpha") for i in range(self.num_experts)] + w3_a = [getattr(self, f"expert_{i}_w3_alpha") for i in range(self.num_experts)] + + return torch.ops.auto_deploy.torch_quant_nvfp4_moe( + x, + selected_experts, + routing_weights, + w1, + w2, + w3, + w1_is, + w2_is, + w3_is, + w1_ws, + w2_ws, + w3_ws, + w1_a, + w2_a, + w3_a, + ) + + def get_input(self, device, dtype=torch.bfloat16): + return torch.randn(2, self.hidden_size, device=device, dtype=dtype) + + +def _run_nvfp4_moe_tp_shard_job( + num_experts: int, + _rank: int, + world_size: int, +) -> None: + """Run NVFP4 MoE TP sharding test. See NVFP4MoEOpModel for scale format details.""" + device = "cuda" + hidden_size = 128 + intermediate_size = 256 + model = NVFP4MoEOpModel( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + ).to(device=device) + x = model.get_input(device=device, dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Apply MoE TP sharding with moe_tp=world_size, moe_ep=1 + gm_transformed = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "sharding_dims": ["ep"], + "dist_mapping": {"moe_tp": world_size, "moe_ep": 1}, + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) + + # Verify all_reduce is inserted after MoE node + allreduce_correct = any( + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm_transformed.graph.nodes + ) == (world_size > 1) + assert allreduce_correct, f"Expected all_reduce present={world_size > 1} after MoE TP sharding" + + # Verify: NVFP4 expert weights should be sharded along TP dimension + # FP4 weights are uint8-packed (2 elements per byte), so packed dim is halved + if world_size > 1: + for name, param in gm_transformed.named_parameters(): + if "experts" in name and "w1" in name: + # w1 (up_proj) column-sharded: intermediate_size // world_size rows + assert param.shape[0] == intermediate_size // world_size, ( + f"w1 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + elif "experts" in name and "w2" in name: + # w2 (down_proj) row-sharded: packed k dim = intermediate_size // world_size // 2 + expected_k_packed = intermediate_size // world_size // 2 + assert param.shape[1] == expected_k_packed, ( + f"w2 {name} shape {param.shape} not TP-sharded " + f"(expected packed dim1={expected_k_packed})" + ) + elif "experts" in name and "w3" in name: + # w3 (gate_proj) column-sharded: intermediate_size // world_size rows + assert param.shape[0] == intermediate_size // world_size, ( + f"w3 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + + +def _run_moe_tp_shard_job( + num_experts: int, + _rank: int, + world_size: int, +) -> None: + """Run BF16 MoE TP sharding test.""" + device = "cuda" + hidden_size = 32 + intermediate_size = 16 + model = MoEOpModel( + hidden_size=hidden_size, + num_experts=num_experts, + intermediate_size=intermediate_size, + ).to(device=device, dtype=torch.bfloat16) + x = model.get_input(device=device, dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "detect_sharding": { + "stage": "sharding", + "sharding_dims": ["ep"], + "dist_mapping": {"moe_tp": world_size, "moe_ep": 1}, + }, + "sharding_transform_executor": { + "stage": "sharding", + }, + }, + )(None, gm) + + # Verify: TP sharding should insert all_reduce after MoE node + allreduce_correct = any( + is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm_transformed.graph.nodes + ) == (world_size > 1) + assert allreduce_correct, ( + f"Expected all_reduce present={world_size > 1} after MoE TP sharding, " + f"world_size={world_size}" + ) + + # Verify: expert weights should be sharded along TP dimension + if world_size > 1: + for name, param in gm_transformed.named_parameters(): + if "experts" in name and "w1" in name: + # w1 (up_proj) is column-sharded: intermediate_size // world_size + assert param.shape[0] == intermediate_size // world_size, ( + f"w1 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + elif "experts" in name and "w2" in name: + # w2 (down_proj) is row-sharded: hidden_size x (intermediate_size // world_size) + assert param.shape[1] == intermediate_size // world_size, ( + f"w2 {name} shape {param.shape} not TP-sharded " + f"(expected dim1={intermediate_size // world_size})" + ) + elif "experts" in name and "w3" in name: + # w3 (gate_proj) is column-sharded: intermediate_size // world_size + assert param.shape[0] == intermediate_size // world_size, ( + f"w3 {name} shape {param.shape} not TP-sharded " + f"(expected dim0={intermediate_size // world_size})" + ) + + +@pytest.mark.parametrize("device_count", get_device_counts([2, 8])) +@pytest.mark.parametrize("num_experts", [4, 8]) +def test_moe_tp_shard_bf16(device_count: int, num_experts: int): + """Test MoE TP sharding with BF16 weights.""" + dist_common.spawn_multiprocess_job( + job=partial(_run_moe_tp_shard_job, num_experts), + size=device_count, + ) + + +@pytest.mark.skipif( + not (fp4_compatible() and trtllm_ops_available()), + reason="Requires NVFP4 support (SM100+) and TRTLLM ops", +) +@pytest.mark.parametrize("device_count", get_device_counts([2, 8])) +@pytest.mark.parametrize("num_experts", [4, 8]) +def test_moe_tp_shard_nvfp4(device_count: int, num_experts: int): + """Test MoE TP sharding with NVFP4 quantized weights (Qwen3.5-like).""" + dist_common.spawn_multiprocess_job( + job=partial(_run_nvfp4_moe_tp_shard_job, num_experts), + size=device_count, + ) diff --git a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py index b1c1e604e0f..c241294aa57 100644 --- a/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py +++ b/tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py @@ -17,10 +17,12 @@ PiecewiseCapturedGraph, _args_kwargs_flatten_spec, ) +from tensorrt_llm._torch.auto_deploy.compile.piecewise_runner import ADPiecewiseRunner from tensorrt_llm._torch.auto_deploy.compile.piecewise_utils import submod_has_cuda_ops from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm from tensorrt_llm._torch.auto_deploy.shim.ad_executor import _round_up_to_closest from tensorrt_llm._torch.auto_deploy.transform.library.compile_model import ( + CompileModel, _generate_default_piecewise_num_tokens, ) @@ -172,6 +174,49 @@ def get_args_kwargs(bs): ) +# ============================================================================ +# Tests for CapturedGraph capture-time truncation +# ============================================================================ + + +class TestCapturedGraphCapture: + """Tests for capture-time input truncation in CapturedGraph.""" + + def test_capture_graph_uses_per_input_extents_for_truncation(self, monkeypatch): + class ModelWithDifferentDynamicDims(nn.Module): + def forward(self, x, y): + return x.sum() + y.sum() + + compiled_model = CapturedGraph( + ModelWithDifferentDynamicDims(), + num_batched_inputs=2, + ) + captured_shapes = [] + + def fake_capture_one_graph(self, *args, **kwargs): + captured_shapes.append(tuple(arg.shape for arg in args)) + return object() + + monkeypatch.setattr(CapturedGraph, "_capture_one_graph", fake_capture_one_graph) + + def get_args_kwargs(bs): + x = torch.arange(bs * 2, dtype=torch.float32).reshape(bs, 2) + y = torch.arange(2 * (bs + 1), dtype=torch.float32).reshape(2, bs + 1) + return (x, y), {} + + compiled_model.capture_graph(get_args_kwargs, [5, 3]) + + assert compiled_model.dynamic_dims == [0, 1] + assert captured_shapes == [ + (torch.Size([5, 2]), torch.Size([2, 6])), + (torch.Size([3, 2]), torch.Size([2, 4])), + ] + assert set(compiled_model.cudagraphs) == { + (5, 2, 2, 6), + (3, 2, 2, 4), + } + + # ============================================================================ # Helpers for piecewise / submod_has_cuda_ops tests # ============================================================================ @@ -325,6 +370,27 @@ def test_get_num_tokens_no_input(self): dual = self._make_dual_mode() assert dual._get_num_tokens() == 0 + def test_truncate_output_preserves_tensor_type(self): + dual = self._make_dual_mode() + result = torch.arange(12, dtype=torch.float32).reshape(3, 4) + + truncated = dual._truncate_output(result, num_tokens=2, bucket=4) + + assert isinstance(truncated, torch.Tensor) + assert truncated.shape == (3, 2) + assert torch.equal(truncated, result[:, :2]) + + def test_truncate_output_prefers_monolithic_output_dynamic_dim(self): + dual = self._make_dual_mode() + dual.monolithic._output_dynamic_dim = 1 + result = torch.arange(16, dtype=torch.float32).reshape(4, 4) + + truncated = dual._truncate_output(result, num_tokens=2, bucket=4) + + assert isinstance(truncated, torch.Tensor) + assert truncated.shape == (4, 2) + assert torch.equal(truncated, result[:, :2]) + @pytest.mark.parametrize( "num_tokens, expected_bucket", [ @@ -372,6 +438,94 @@ def test_prepare_is_idempotent(self): assert pcg._is_prepared is True +# ============================================================================ +# Tests for PiecewiseCapturedGraph output handling +# ============================================================================ + + +class TestPiecewiseCapturedGraphOutputHandling: + """Tests for output reconstruction and forward-state cleanup.""" + + def test_reconstruct_output_warns_on_unflatten_failure(self, monkeypatch): + pcg = PiecewiseCapturedGraph(nn.Linear(4, 4), piecewise_num_tokens=[8]) + pcg._out_spec = MagicMock() + pcg._out_spec.unflatten.side_effect = ValueError("boom") + warnings = [] + + monkeypatch.setattr( + "tensorrt_llm._torch.auto_deploy.compile.backends.torch_cudagraph.ad_logger.warning", + lambda msg, *args: warnings.append(msg % args if args else msg), + ) + + result = (torch.tensor([1.0]),) + + assert pcg._reconstruct_output(result) is result + assert len(warnings) == 1 + assert "failed to unflatten output" in warnings[0] + + def test_forward_clears_num_tokens_on_error(self): + pcg = PiecewiseCapturedGraph(nn.Linear(4, 4), piecewise_num_tokens=[8]) + pcg.split_gm = MagicMock(side_effect=RuntimeError("boom")) + ADPiecewiseRunner.set_current_num_tokens(None) + + with pytest.raises(RuntimeError, match="boom"): + pcg.forward(num_tokens=8) + + assert ADPiecewiseRunner._current_num_tokens is None + + +# ============================================================================ +# Tests for PiecewiseCapturedGraph static input buffers +# ============================================================================ + + +class TestPiecewiseCapturedGraphStaticInputBuffers: + """Tests for static kwarg buffers used by piecewise capture.""" + + @pytest.mark.parametrize( + ("buf_shape", "src_shape", "dyn_dim"), + [ + ((8, 4), (3, 4), 0), + ((2, 8), (2, 3), 1), + ], + ) + def test_copy_to_static_buffers_preserves_runtime_shape(self, buf_shape, src_shape, dyn_dim): + pcg = PiecewiseCapturedGraph(nn.Linear(4, 4), piecewise_num_tokens=[8]) + static_buffer = torch.full(buf_shape, fill_value=-1.0) + src = torch.arange(torch.Size(src_shape).numel(), dtype=torch.float32).reshape(src_shape) + pcg._static_input_buffers["input_ids"] = (static_buffer, dyn_dim) + kwargs = {"input_ids": src} + + pcg._copy_to_static_buffers(kwargs) + + copied = kwargs["input_ids"] + assert copied.shape == src.shape + assert copied.data_ptr() == static_buffer.data_ptr() + assert copied is not static_buffer + assert torch.equal(copied, src) + + def test_allocate_static_input_buffers_handles_static_shape_unstable_kwarg(self): + pcg = PiecewiseCapturedGraph(nn.Linear(4, 4), piecewise_num_tokens=[8]) + + def get_args_kwargs(_): + return (), {"input_ids": torch.arange(8, dtype=torch.float32)} + + pcg._allocate_static_input_buffers(get_args_kwargs) + + static_buffer, dyn_dim = pcg._static_input_buffers["input_ids"] + assert dyn_dim is None + + src = torch.arange(8, dtype=torch.float32) + kwargs = {"input_ids": src} + pcg._copy_to_static_buffers(kwargs) + + copied = kwargs["input_ids"] + assert copied.shape == src.shape + assert copied.data_ptr() == static_buffer.data_ptr() + assert copied is static_buffer + assert torch.equal(copied, src) + + # ============================================================================ # Tests for _generate_default_piecewise_num_tokens (compile_model.py) # ============================================================================ @@ -422,3 +576,47 @@ def test_no_duplicates_when_max_is_power_of_two(self): result = _generate_default_piecewise_num_tokens(4096) # 4096 is already a power of 2, should not be duplicated assert result.count(4096) == 1 + + +# ============================================================================ +# Tests for CompileModel GraphModule target collection +# ============================================================================ + + +class TestCompileModelGraphModuleTargetCollection: + """Tests for selecting GraphModule compile targets.""" + + def test_root_graphmodule_skips_child_graphmodules(self, monkeypatch): + root_gm = _build_trivial_graphmodule() + child_gm = _build_trivial_graphmodule() + root_gm.child = child_gm + compiled_models = [] + + class FakeBackend: + def __init__(self, model, **compiler_kwargs): + self.model = model + + def compile(self): + compiled_models.append(self.model) + return self.model + + monkeypatch.setattr( + "tensorrt_llm._torch.auto_deploy.transform.library.compile_model.CompileBackendRegistry.get", + lambda backend: FakeBackend, + ) + + transform = CompileModel.from_kwargs(stage="compile", backend="torch-simple") + cm = MagicMock() + cm.info = MagicMock() + cm.named_args = {} + + mod_compiled, info = transform._apply_to_full_model( + root_gm, + cm=cm, + factory=MagicMock(), + shared_config=MagicMock(), + ) + + assert mod_compiled is root_gm + assert info.skipped is False + assert compiled_models == [root_gm]