Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7543abb
ep1
taylor-yb-lee Mar 16, 2026
b65229b
moe sharding tp8
taylor-yb-lee Mar 16, 2026
9ead146
Exclude lm_head from cuda graph
taylor-yb-lee Mar 16, 2026
0ed7977
Turn on lm_head_sharding
taylor-yb-lee Mar 16, 2026
8bcf4dd
Fix tp8 sharding for fused moe checkpoint
taylor-yb-lee Mar 20, 2026
7e0172a
Added qwen3.5 config for long context length
taylor-yb-lee Mar 20, 2026
f4cb16f
Qwen3.5 configs
taylor-yb-lee Mar 22, 2026
693ed16
Added comment
taylor-yb-lee Mar 24, 2026
331af8a
Added moe sharding for tp8/ep1
taylor-yb-lee Mar 24, 2026
4a0cf8d
Added unittest for tp sharding for NVFP4 MoE
taylor-yb-lee Mar 24, 2026
93b48a2
- Revert freemem size
taylor-yb-lee Mar 25, 2026
3777753
Revert graphing of lm_head and add lm_head in the model graph
taylor-yb-lee Apr 1, 2026
e872bb8
The fix adds a text-only fast path at the top of forward() that:
taylor-yb-lee Mar 31, 2026
22e7154
Allow Qwen3.5 to use AutoModel
taylor-yb-lee Apr 1, 2026
ee96ad4
Fix piecewise CUDA graph for Qwen3.5 MoE
taylor-yb-lee Apr 2, 2026
0a0ccf8
config for text only case
taylor-yb-lee Apr 2, 2026
aaf33aa
removed unnecessary code
taylor-yb-lee Apr 2, 2026
22ed629
Rename variable
taylor-yb-lee Apr 2, 2026
f92e78e
Add assert
taylor-yb-lee Apr 2, 2026
5d589ae
Revised comment and added a method to clarify is_full_model
taylor-yb-lee Apr 2, 2026
7cc985d
Revert fast path (it does not affect performance)
taylor-yb-lee Apr 2, 2026
14c5351
Fixed set_output_embeddings() no longer updates module actually used …
taylor-yb-lee Apr 2, 2026
c26c050
remove tp8 config
taylor-yb-lee Apr 2, 2026
bac09ce
Extract text-model (graph module) in Qwen3.5 model for enabling piece…
taylor-yb-lee Apr 2, 2026
fa5f485
Remove piecewise CUDA graph workaround
taylor-yb-lee Apr 2, 2026
c44662d
Revert unnecessary change for using AutoModel for Qwen3.5 text model
taylor-yb-lee Apr 2, 2026
ce45322
Update Piecewise CG to support VLMs
nvchenghaoz Apr 3, 2026
0041583
address CR's reviews
nvchenghaoz Apr 6, 2026
d2e6fa4
Merge branch 'main' into chenghao/piecewise_update_0402
nvchenghaoz Apr 6, 2026
ee798a3
further address CR's review
nvchenghaoz Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ max_batch_size: 32
cuda_graph_config:
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
enable_chunked_prefill: true
# Use AutoModelForCausalLM for text only mode until issue #12699 is resolved
model_factory: Qwen3_5MoeForConditionalGeneration
kv_cache_config:
enable_block_reuse: false
Expand All @@ -15,13 +16,18 @@ kv_cache_config:
model_kwargs:
torch_dtype: bfloat16
transforms:
# disable for text only use case
initialize_mrope_delta_cache:
enabled: true
export_to_gm:
num_moe_experts_for_export: 2
fuse_gemms_mixed_children:
enabled: true
fuse_nvfp4_moe:
backend: trtllm_gen
detect_sharding:
# for long input, tp8ep1 gives better performance
# dist_mapping: {moe_tp: 8, moe_ep: 1}
allreduce_strategy: SYMM_MEM
shard_all_unprocessed: true
simple_shard_filter: "lm_head"
Expand All @@ -37,6 +43,9 @@ transforms:
"k_proj": "colwise"
"v_proj": "colwise"
"o_proj": "rowwise"
# lm_head: "gather" = column split + all_gather (not "colwise" which
# requires a LayerSubgraph and crashes for standalone unprocessed nodes)
"lm_head": "gather"
# replicating shared experts (keep them commented out)
# "shared_expert_gate_proj": "colwise"
# "shared_expert_up_proj": "colwise"
Expand Down
335 changes: 282 additions & 53 deletions tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,11 @@ class Qwen3_5MoeCausalLMOutput(ModelOutput):


class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel):
"""Qwen3.5 MoE text model (embed + decoder layers + final norm)."""
"""Qwen3.5 MoE text model (embed + decoder layers + final norm + lm_head).

lm_head is included so that the exported GraphModule contains it directly,
allowing sharding and gather_logits_before_lm_head transforms to see it.
"""

def __init__(self, config: Qwen3_5MoeTextConfig):
super().__init__(config)
Expand All @@ -746,10 +750,15 @@ def __init__(self, config: Qwen3_5MoeTextConfig):
)
self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config)
self.lm_head = None # set by parent model via set_lm_head()

# Initialize weights and apply final processing
self.post_init()

def set_lm_head(self, lm_head: nn.Module):
    """Set the lm_head from the parent model.

    The parent model owns the nn.Linear head; storing the same module
    reference here places lm_head inside this submodule so the exported
    GraphModule contains it directly (see the class docstring).
    """
    self.lm_head = lm_head

def get_input_embeddings(self):
    """Return the token embedding module (``embed_tokens``)."""
    return self.embed_tokens

Expand Down Expand Up @@ -801,7 +810,11 @@ def forward(
hidden_states = decoder_layer(hidden_states, position_embeddings=position_embeddings)

hidden_states = self.norm(hidden_states)
return Qwen3_5MoeOutput(last_hidden_state=hidden_states)
assert self.lm_head is not None, (
"lm_head not set — call set_lm_head() from the parent model before forward()"
)
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
return Qwen3_5MoeCausalLMOutput(logits=logits)


class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
Expand All @@ -814,6 +827,7 @@ def __init__(self, config: Qwen3_5MoeTextConfig, **kwargs):
self.model = Qwen3_5MoeTextModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.model.set_lm_head(self.lm_head)

# Initialize weights and apply final processing
self.post_init()
Expand All @@ -829,6 +843,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
    """Replace the output projection (lm_head).

    Also propagates the new module to the inner text model via
    set_lm_head(): the text model's forward() is what actually applies
    lm_head, so updating only ``self.lm_head`` would leave the old head
    in use.
    """
    self.lm_head = new_embeddings
    self.model.set_lm_head(new_embeddings)

def forward(
self,
Expand All @@ -848,8 +863,7 @@ def forward(
rope_cos=rope_cos,
rope_sin=rope_sin,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
logits = outputs.logits
return Qwen3_5MoeCausalLMOutput(logits=logits)


Expand Down Expand Up @@ -2565,10 +2579,19 @@ def __init__(self, config: Qwen3_5MoeConfig, **kwargs):
self.lm_head = nn.Linear(
config.text_config.hidden_size, config.text_config.vocab_size, bias=False
)
# Share lm_head with the text model so it's inside the exported graph
self.model.language_model.set_lm_head(self.lm_head)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
    """Delegate to the inner language model's token embeddings."""
    return self.model.language_model.get_input_embeddings()

def set_output_embeddings(self, new_embeddings):
    """Replace the output projection (lm_head).

    Keeps the language model's shared lm_head reference in sync via
    set_lm_head(), since the text model's forward() applies the head.
    """
    self.lm_head = new_embeddings
    self.model.language_model.set_lm_head(new_embeddings)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand All @@ -2590,8 +2613,7 @@ def forward(
video_grid_thw=video_grid_thw,
**kwargs,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
logits = outputs.logits
return Qwen3_5MoeConditionalOutput(logits=logits)


Expand All @@ -2607,6 +2629,9 @@ class Qwen3_5MoeTextExportInfo(TextModelExportInfo):
(batch, sequence) are dynamic.
"""

def __init__(self, submodule_name: str):
    # NOTE(review): pure pass-through override — it adds nothing beyond
    # the parent __init__ and could be removed entirely.
    super().__init__(submodule_name)

def _init_dynamic_shape_lookup(self):
base = super()._init_dynamic_shape_lookup()
batch_size_dyn = Dim.DYNAMIC
Expand Down Expand Up @@ -2858,4 +2883,7 @@ def init_input_processor(self, base):
# Register the custom text config so transformers can resolve it by name.
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeTextConfig)

# Map configs to model classes for both factories: the causal-LM factory is
# used for the text-only path, the Qwen3.5 factory for the full VLM.
AutoModelForCausalLMFactory.register_custom_model_cls("Qwen3_5MoeTextConfig", Qwen3_5MoeForCausalLM)
AutoModelForCausalLMFactory.register_custom_model_cls(
    "Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration
)
Qwen3_5MoeFactory.register_custom_model_cls("Qwen3_5MoeConfig", Qwen3_5MoeForConditionalGeneration)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch.nn as nn
from pydantic import Field
from torch.fx import GraphModule

from ...compile import ArgsKwargs, CompileBackendRegistry
from ...models.factory import ModelFactory
Expand All @@ -16,6 +17,15 @@
)


def _set_submodule(model: nn.Module, key: str, new_module: nn.Module) -> None:
"""Replace a nested submodule given a dotted key path (e.g. 'model.language_model')."""
parts = key.split(".")
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], new_module)


def _generate_default_piecewise_num_tokens(max_num_tokens: int) -> List[int]:
"""Generate default piecewise bucket sizes when none are specified.

Expand Down Expand Up @@ -138,13 +148,41 @@ def _get_args_kwargs(bs: int) -> ArgsKwargs:
config_dict = self.config.model_dump()
config_dict.update(config_overrides)

compiler_backend = CompileBackendRegistry.get(self.config.backend)(
mod,
get_args_kwargs_for_compile=_get_args_kwargs,
**extra_kwargs,
**config_dict,
)
mod_compiled = compiler_backend.compile()
# Walk the module tree and collect the top-level GraphModules to compile.
# Once a GM is found, its children are skipped (they're part of the GM).
compile_targets = []
seen = set()
if isinstance(mod, GraphModule):
compile_targets.append(("", mod))
seen.add("")
for name, submod in mod.named_modules():
if any(p == "" or name.startswith(p + ".") for p in seen):
continue
if isinstance(submod, GraphModule):
compile_targets.append((name, submod))
seen.add(name)

if compile_targets:
ad_logger.info(
f"CompileModel: compiling {len(compile_targets)} GraphModule(s): "
f"{[name or '(root)' for name, _ in compile_targets]}"
)

for gm_key, gm in compile_targets:
full_model = mod if gm_key else None
compiler_backend = CompileBackendRegistry.get(self.config.backend)(
gm,
get_args_kwargs_for_compile=_get_args_kwargs,
full_model=full_model,
**extra_kwargs,
**config_dict,
)
compiled_gm = compiler_backend.compile()
if gm_key:
_set_submodule(mod, gm_key, compiled_gm)
else:
mod = compiled_gm
mod_compiled = mod

# store info object about the transform
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,10 @@ def build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple:
return ([scales["input_scale"]], [scales["weight_scale"], scales["alpha"]], [], [])

def load_hook(self, state_dict, prefix, *args, weight_name):
# Prepend prefix so the hook works when the GraphModule is a submodule
# of the model on which load_state_dict is called (e.g., VLM models
# where the text model lives at model.language_model.*).
weight_name = prefix + weight_name
if weight_name in state_dict:
input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale"
alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha"
Expand Down
Loading
Loading