[#12699][feat] AutoDeploy: Support Piecewise CG for VLMs #12749
nvchenghaoz wants to merge 30 commits into NVIDIA:main from
Conversation
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
This was good for Qwen3.5 with long input sequences (>= 15000). Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Fix lm_head not sharded in Qwen3_5MoeForConditionalGeneration export
The Qwen3_5MoeForConditionalGeneration factory exports only
model.language_model, leaving lm_head outside the graph.
This causes lm_head to run unsharded (248320x4096 on every GPU) and prevents gather_logits_before_lm_head from optimizing it.
Graft lm_head into the exported graph during post_process:
- Capture lm_head from the parent model in from_autoinferred()
- Insert auto_deploy.torch_linear_simple + aten.to.dtype nodes with explicit names for filter matching
- Set _lm_head_grafted flag so the parent forward skips redundant
lm_head during cache init
- Add "lm_head": "gather" to the manual tp_plan for column-split sharding capture.
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
1. Checks position_ids is not None and not has_images and not has_videos: this is the AD runtime text-only path (every call in your benchmark).
2. Expands position_ids to 3D (same logic as the original).
3. Calls self.language_model(inputs_embeds=..., position_ids=...) with just 2 args, no **kwargs.

This skips all of the following overhead that was running on every forward step:
- 12 kwargs.get() calls for multimodal metadata extraction
- has_chunk_mm_layout check with .numel() and .item() calls
- mrope_delta_cache lookup loop over all kwargs
- 3D position_ids conditional branching with cu_seqlen tensor ops
- 10-item kwargs.pop() loop + key-suffix scan
- **kwargs passthrough to language_model (forces flatten/hash)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
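The dispatch condition above can be sketched without torch; `text_only_fast_path` and the `has_images`/`has_videos` key names are hypothetical stand-ins for the real AD runtime kwargs, not the actual code:

```python
from typing import Any, Dict, Optional

def text_only_fast_path(position_ids: Optional[Any], kwargs: Dict[str, Any]) -> bool:
    """Return True when the cheap text-only path can be taken: position_ids
    is present and no image/video inputs were passed. Hypothetical helper
    mirroring the check described in the commit message."""
    return (
        position_ids is not None
        and not kwargs.get("has_images", False)
        and not kwargs.get("has_videos", False)
    )

# Text-only decode step: fast path applies.
print(text_only_fast_path([0, 1, 2], {}))                    # True
# Multimodal step: falls back to the full forward.
print(text_only_fast_path([0, 1, 2], {"has_images": True}))  # False
```

The point of the fast path is that it avoids touching **kwargs entirely, so the language model call stays cheap to flatten and hash.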
Since commit 93d99f1, the Qwen3.5 model was created by Qwen3_5MoeFactory. However, it exported only the inner text model as a GraphModule, wrapping it in a non-GraphModule wrapper. This broke piecewise CUDA graph capture. This commit fixes it by exporting the full model as a single GraphModule. Also updated _init_dynamic_shape_lookup() to return a 2D position_ids spec for full-model export (the 2D-to-3D expansion happens inside the traced graph). Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…by forward(). Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
remove redundant comment Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…wise cuda graph. Exported only the text model (decoder + lm_head) as a graph module and added a named_args preprocessing hook on CachedSequenceInterface to convert input_ids to inputs_embeds outside the graph. Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
📝 Walkthrough

The PR updates the Qwen3.5 MoE 400B model deployment infrastructure by extending model capacity configuration, restructuring lm_head ownership to the text model, refactoring CUDA graph capture for dynamic dimensions and nested modules, and extending TP-sharding with MoE-specific quantization scale handling.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 7
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py`:
- Around line 226-237: The capture uses the wrong truncated sizes: after copying
each input from args_batched you compute size_i but then build inputs_truncated
using a fixed bs, which records incorrect shapes for dynamic dims; update the
truncation to narrow each self._input_buffers[i] using the corresponding size
from args_batched (the same size_i computed in the copy loop) so
inputs_truncated is created with per-input extents (refer to args_batched,
dynamic_dims, self._input_buffers, inputs_truncated and bs) — e.g., replace the
list comprehension that uses bs with one that uses the per-index size computed
from args_batched.
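As a hedged illustration of that fix, here is a torch-free sketch in which Python lists stand in for tensors and slicing stands in for `narrow`; all names (`build_truncated_views` and its parameters) are hypothetical, not the actual backend symbols:

```python
def build_truncated_views(input_buffers, args_batched, dynamic_dims):
    """Build per-input truncated views of the static capture buffers.

    input_buffers: pre-allocated max-size buffers (lists stand in for tensors)
    args_batched:  the runtime inputs whose extents must be mirrored
    dynamic_dims:  per-input dynamic dim index, or None for a static input
    """
    views = []
    for i, buf in enumerate(input_buffers):
        if dynamic_dims[i] is None:
            views.append(buf)          # static input: use the whole buffer
        else:
            size_i = len(args_batched[i])  # per-input extent, not one shared bs
            views.append(buf[:size_i])     # stand-in for buf.narrow(dim, 0, size_i)
    return views

bufs = [[0] * 8, [0] * 8]
runtime = [[1, 2, 3], [1, 2, 3, 4, 5]]
print([len(v) for v in build_truncated_views(bufs, runtime, [0, 0])])  # [3, 5]
```

The key change from the reviewed code is that each view uses its own input's extent, so two inputs with different dynamic sizes no longer record the same truncated shape.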
- Around line 537-543: In _copy_to_static_buffers the code replaces kwargs[key]
with the full pre-allocated buffer (buf), which changes the logical shape;
instead assign a narrowed view of the buffer that matches the source's runtime
size so address stability is preserved but the graph sees the original shape.
Concretely, after copying (buf.narrow(dyn_dim, 0,
src.shape[dyn_dim]).copy_(src)), set kwargs[key] = buf.narrow(dyn_dim, 0,
src.shape[dyn_dim]) (or assign that view to a variable) instead of kwargs[key] =
buf; operate on symbols _copy_to_static_buffers, _static_input_buffers, buf,
dyn_dim, and kwargs.
- Around line 724-730: The fallback path iterates over result and thus corrupts
plain torch.Tensor outputs (returning a tuple of slices); add an early branch
that detects if result is a torch.Tensor and return _narrow(result) (or the
appropriately truncated tensor) before the existing hasattr(result, "to_tuple")
and isinstance(result, abc.Mapping) checks so that tensor outputs keep their
original type; reference the symbols result, _narrow, hasattr(result,
"to_tuple"), and abc.Mapping when making the change.
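A torch-free sketch of the suggested early branch, where a Python list stands in for torch.Tensor and `narrow` is passed in as a callable; the container handling is simplified relative to the real fallback path:

```python
from collections import abc

def truncate_result(result, narrow):
    """Truncate model outputs while preserving their container type.

    The early branch returns the narrowed 'tensor' directly, so a plain
    tensor output is not corrupted into a tuple of slices by the
    iteration-based fallback below it.
    """
    if isinstance(result, list):  # stand-in for isinstance(result, torch.Tensor)
        return narrow(result)
    if hasattr(result, "to_tuple"):
        return tuple(narrow(v) for v in result.to_tuple())
    if isinstance(result, abc.Mapping):
        return {k: narrow(v) for k, v in result.items()}
    return tuple(narrow(v) for v in result)

narrow = lambda v: v[:2]
print(truncate_result([1, 2, 3, 4], narrow))            # [1, 2] -- still a list
print(truncate_result(([1, 2, 3], [4, 5, 6]), narrow))  # ([1, 2], [4, 5])
```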
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py`:
- Line 830: The model double-registers the same nn.Linear under two paths
causing state_dict key mismatches; to fix, remove the duplicate registration by
ensuring only one attribute owns the linear (either keep self.lm_head in the
parent OR let the child call set_lm_head but do NOT assign self.lm_head twice);
specifically update Qwen3_5MoeForCausalLM / Qwen3_5MoeForConditionalGeneration
to stop assigning the same module to both self.lm_head and via set_lm_head(),
and add a proper _tied_weights_keys tuple in Qwen3_5MoeForConditionalGeneration
(and verify the one in Qwen3_5MoeForCausalLM) that maps the single top-level
lm_head key (e.g. ("lm_head.weight","lm_head.bias") or the correct pair used in
your code) so load_state_dict() only expects the HF checkpoint keys.
In `@tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py`:
- Around line 153-160: The loop collecting compile_targets lets child
GraphModules through when the root mod is a GraphModule because seen contains ""
but the current filter (if p) ignores it; update the logic so that when mod
itself is a GraphModule you mark that fact in seen and then treat an
empty-string entry as matching all children. Concretely: when detecting the root
GraphModule (isinstance(mod, GraphModule)) add "" to seen (or otherwise record
the root), and change the membership check from if any(name.startswith(p + ".")
for p in seen if p): to something that treats p == "" as matching (e.g., if
any(p == "" or name.startswith(p + ".") for p in seen):) so child GraphModules
are skipped when the parent/root GraphModule is already scheduled for
compilation (affecting variables compile_targets, seen, mod.named_modules(),
name, submod).
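The suggested membership check can be sketched without torch.fx; `is_graph_module` stands in for `isinstance(mod, GraphModule)` and the `(name, module)` pairs mimic `named_modules()`, where the root is named `""`:

```python
def collect_compile_targets(named_modules, is_graph_module):
    """Collect only the outermost GraphModules for compilation.

    named_modules() names the root module "". Recording it in `seen` and
    treating "" as matching every child ensures nested GraphModules are
    skipped once an ancestor (including the root) is already scheduled.
    """
    seen, compile_targets = set(), []
    for name, submod in named_modules:
        if not is_graph_module(submod):
            continue
        # "" matches all children; "a" matches "a.b", "a.b.c", ...
        if any(p == "" or name.startswith(p + ".") for p in seen):
            continue
        seen.add(name)
        compile_targets.append(name)
    return compile_targets

mods = [("", "gm"), ("child", "gm"), ("child.sub", "gm"), ("other", "plain")]
print(collect_compile_targets(mods, lambda m: m == "gm"))  # ['']
```

With the original `if p` filter, the root's `""` entry would be ignored and `child` would be scheduled a second time inside an already-compiled parent.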
In `@tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py`:
- Around line 1926-1942: The loop currently TP-shards any scale named
"weight_scale" (from FP8EPShardingInfo.scale_names()) treating it like
NVFP4/CUTLASS; to fix, skip TP-sharding for plain FP8 weight_scale by adding a
guard before calling _tp_shard_moe_scale: if s_name == "weight_scale" and the
sharding/layout indicates plain FP8 (not NVF4/CUTLASS), continue (leave
replicated). Update the condition around _BLOCKED_SCALE_NAMES or add an explicit
check using the sharding/layout flag available in the context (e.g., inspect the
EP sharding info or a format enum), ensuring tp_size, scale_names,
_tp_shard_moe_scale and FP8EPShardingInfo.scale_names() are used to decide
whether to call _tp_shard_moe_scale.
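A simplified, hypothetical guard illustrating the suggested decision; the `quant_format` flag and the startswith rule are assumptions for illustration, not the actual sharding.py symbols:

```python
def should_tp_shard_moe_scale(scale_name: str, quant_format: str) -> bool:
    """Decide whether an MoE quantization scale should be TP-sharded.

    Plain FP8 uses a per-tensor weight_scale, which must stay replicated;
    block-scaled formats (NVFP4/CUTLASS-style) carry per-block scales that
    follow the sharded weight and so must be sharded with it.
    """
    if scale_name == "weight_scale" and quant_format == "fp8":
        return False  # per-tensor scale: leave replicated
    return scale_name.startswith("weight_scale")

print(should_tp_shard_moe_scale("weight_scale", "fp8"))    # False
print(should_tp_shard_moe_scale("weight_scale", "nvfp4"))  # True
print(should_tp_shard_moe_scale("input_scale", "nvfp4"))   # False
```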
In
`@tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py`:
- Around line 1438-1457: The test's loop over gm_transformed.named_parameters()
is filtering on "experts" but NVFP4MoEOpModel registers parameters as
"expert_{i}_w*", so the shard-shape assertions never run; update the condition
in the loop to detect the actual parameter names (e.g., check for "expert_" or
match r"expert_\d+_w" / name.startswith("expert_") and then the specific
"w1"/"w2"/"w3" suffixes) so the branches for w1, w2, w3 execute and validate
dim0/dim1 against intermediate_size // world_size (and packed k logic) for the
NVFP4 TP-sharded weights in gm_transformed.
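For instance, the name check could be written with a small regex (a sketch only; the actual test may prefer simpler startswith/endswith checks):

```python
import re

_EXPERT_WEIGHT = re.compile(r"expert_\d+_(w[123])")

def expert_weight_kind(param_name: str):
    """Return 'w1'/'w2'/'w3' for NVFP4MoEOpModel-style names like
    'expert_0_w1', or None for unrelated parameters."""
    m = _EXPERT_WEIGHT.search(param_name)
    return m.group(1) if m else None

print(expert_weight_kind("expert_3_w2"))     # w2
print(expert_weight_kind("lm_head.weight"))  # None
```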
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e90b7651-1a5f-4af9-bb7d-07e4c2f0186e
📒 Files selected for processing (8)
examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py
tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
tensorrt_llm/_torch/auto_deploy/utils/_graph.py
tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py
/bot run
@CodeRabbit please review the piecewise related code.
Review: Piecewise CUDA Graph (
@CodeRabbit more review please
Deeper review and fixes:
Integration touchpoints verified:
Proposed diffs:
# In PiecewiseCapturedGraph.forward(...)
- return self.original_model(*args, **kwargs)
+ return self.original_model(*args, num_tokens=num_tokens, **kwargs)
# In PiecewiseCapturedGraph._reconstruct_output(...)
try:
return self._out_spec.unflatten(list(result))
- except Exception:
- return result
+ except Exception as e:
+ ad_logger.warning("_reconstruct_output: failed to unflatten output (%s); returning raw tuple", e)
+ return result
# In PiecewiseCapturedGraph._allocate_static_input_buffers(...)
- if dyn_dim is not None:
- self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
+ if dyn_dim is not None:
+ self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
+ else:
+ # Address is unstable but no dynamic dim — still buffer and copy full tensor
+ self._static_input_buffers[key] = (torch.empty_like(v1), None)
# In PiecewiseCapturedGraph.forward(...) where copying kwargs into static buffers
- buf, dyn_dim = self._static_input_buffers[key]
- if dyn_dim is None:
- continue
- bs = args_batched[0].shape[dyn_dim]
- buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
- kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)
+ buf, dyn_dim = self._static_input_buffers[key]
+ if dyn_dim is None:
+ buf.copy_(v_current)
+ kwargs_batched[key] = buf
+ else:
+ bs = args_batched[0].shape[dyn_dim]
+ buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
+ kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)
# In DualModeCapturedGraph.forward(...)
- result = self.piecewise(*args, num_tokens=bucket, **kwargs)
- ADPiecewiseRunner.set_current_num_tokens(None)
+ try:
+ result = self.piecewise(*args, num_tokens=bucket, **kwargs)
+ finally:
+ ADPiecewiseRunner.set_current_num_tokens(None)
# In DualModeCapturedGraph.__getattr__(...)
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.monolithic.model, name)
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ try:
+ return getattr(self.monolithic.model, name)
+ except AttributeError:
+            raise AttributeError(f"{type(self).__name__} and its inner model have no attribute '{name}'")
# In DualModeCapturedGraph._truncate_output(...)
- for d in range(v.ndim):
- if v.shape[d] == bucket:
- return v.narrow(d, 0, num_tokens)
- return v
+ matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
+ if len(matching) != 1:
+        ad_logger.warning("_truncate_output: ambiguous dim for shape %s, bucket=%d; dims=%s",
+                          tuple(v.shape), bucket, matching)
+ d = matching[0] if matching else -1
+ return v.narrow(d, 0, num_tokens) if d >= 0 else v
# In get_lm_head_node(...)
- if node.target is torch.ops.auto_deploy.trtllm_dist_all_gather:
+ if node.target in {
+ torch.ops.auto_deploy.trtllm_dist_all_gather.default,
+ torch.ops.auto_deploy.torch_dist_all_gather.default,
+ }:
        node = list(node.all_input_nodes)[0]

Why these now:
If you want, I can also add a tiny unit test to assert:
Happy to prepare those next.
What looks good
Actionable fixes and improvements
Patch:

# tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
class PiecewiseCapturedGraph(nn.Module):
    def forward(self, *args, num_tokens: Optional[int] = None, **kwargs) -> Any:
        if self.split_gm is not None:
            self._copy_to_static_buffers(kwargs)
            ADPiecewiseRunner.set_current_num_tokens(num_tokens)
            try:
                result = self.split_gm(*args, **kwargs)
            finally:
                ADPiecewiseRunner.set_current_num_tokens(None)
            return self._reconstruct_output(result)
-        return self.original_model(*args, **kwargs)
+        return self.original_model(*args, num_tokens=num_tokens, **kwargs)
Patch (conceptual; adjust to your _copy_to_static_buffers/_allocate_static_input_buffers structure):

# When detecting instability:
if dyn_dim is not None:
    self._static_input_buffers[key] = (torch.empty_like(v1), dyn_dim)
else:
    # address unstable but static shape — still buffer it
    self._static_input_buffers[key] = (torch.empty_like(v1), None)

# When copying into buffers before replay:
buf, dyn_dim = self._static_input_buffers[key]
if dyn_dim is None:
    buf.copy_(v_current)
    kwargs_batched[key] = buf
else:
    bs = args_batched[0].shape[dyn_dim]
    buf.narrow(dyn_dim, 0, bs).copy_(v_current.narrow(dyn_dim, 0, bs))
    kwargs_batched[key] = buf.narrow(dyn_dim, 0, bs)
Patch:

def _truncate_output(self, result: Any, num_tokens: int, bucket: int) -> Any:
    def cut(v: torch.Tensor) -> torch.Tensor:
        matching = [d for d in range(v.ndim) if v.shape[d] == bucket]
        if len(matching) != 1:
            ad_logger.warning(
                "_truncate_output: ambiguous dim for shape %s, bucket=%d; dims=%s",
                tuple(v.shape), bucket, matching
            )
        d = matching[0] if matching else -1
        return v.narrow(d, 0, num_tokens) if d >= 0 else v

    if isinstance(result, tuple):
        return tuple(cut(v) if isinstance(v, torch.Tensor) else v for v in result)
    return cut(result) if isinstance(result, torch.Tensor) else result

Optional but recommended: during capture, compare each output tensor's shape at bucket vs probe to record its own token-dim index (e.g., self._output_dynamic_dims[i]), then use that instead of the heuristic.
Patch:

def _reconstruct_output(self, result: Any) -> Any:
    if not isinstance(result, tuple) or self._out_spec is None:
        return result
    try:
        return self._out_spec.unflatten(list(result))
    except Exception as e:
        ad_logger.warning("_reconstruct_output: failed to unflatten output (%s); returning raw tuple", e)
        return result
Patch:

# tensorrt_llm/_torch/auto_deploy/utils/_graph.py
from torch.fx import GraphModule, Node
import torch

def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Node:
    ...
-    if is_op(lm_head_node, torch.ops.auto_deploy.trtllm_dist_all_gather):
-        lm_head_node = lm_head_node.all_input_nodes[0]
+    if is_op(lm_head_node, {
+        torch.ops.auto_deploy.trtllm_dist_all_gather,
+        torch.ops.auto_deploy.torch_dist_all_gather,
+    }):
+        lm_head_node = lm_head_node.all_input_nodes[0]
    return lm_head_node

Nice-to-have
Would you like me to stage these as a small follow-up patch and add a couple of focused unit tests (num_tokens passthrough and truncation-dim disambiguation)? If yes, I can prepare a plan next.
Closing this one since it contains other work. This branch is just to check whether the change fixes the Qwen VLMs. Will create a new PR with the Piecewise CG changes only.
#12699
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests