Skip to content

Commit 6fc6ac3

Browse files
committed
Merge branch 'chenjiel/support_mtp_qwen_next' of github.com:NVIDIA/TensorRT-Model-Optimizer into chenjiel/support_mtp_qwen_next
2 parents 219bb2d + 46d124d commit 6fc6ac3

25 files changed

Lines changed: 1794 additions & 490 deletions

File tree

.github/workflows/example_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ jobs:
8686
pip_install_extras: "[hf,dev-test]"
8787
runner: linux-amd64-gpu-h100-latest-2
8888

89-
##### Speculative Decoding Example Tests (requires 25.08 image) #####
89+
##### Speculative Decoding Example Tests (requires 26.01 image) #####
9090
speculative-decoding-pr:
9191
needs: [check-file-changes, wait-checks]
9292
if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
9393
uses: ./.github/workflows/_example_tests_runner.yml
9494
secrets: inherit
9595
with:
96-
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
96+
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
9797
example: speculative_decoding
9898
pip_install_extras: "[hf,dev-test]"
9999
runner: linux-amd64-gpu-l4-latest-1
@@ -103,7 +103,7 @@ jobs:
103103
uses: ./.github/workflows/_example_tests_runner.yml
104104
secrets: inherit
105105
with:
106-
docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
106+
docker_image: "nvcr.io/nvidia/pytorch:26.01-py3"
107107
example: speculative_decoding
108108
pip_install_extras: "[hf,dev-test]"
109109
runner: linux-amd64-gpu-h100-latest-2

examples/speculative_decoding/eagle_utils.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from packaging.version import Version
3232
from PIL import Image
3333
from scripts.ar_validate import validate_ar
34-
from torch.distributed.tensor.experimental._attention import _SDPAMerger
3534
from torch.utils.data import Dataset
3635
from transformers import AutoProcessor, Trainer, TrainerCallback
3736
from transformers.trainer_pt_utils import LabelSmoother
@@ -581,7 +580,7 @@ def on_step_end(self, args, state, control, **kwargs):
581580
def get_patched_templated_ring_attn(orig_templated_attn: Callable):
582581
"""
583582
Return patched version of
584-
torch.distributed.tensor.experimental._attention._templated_ring_attention
583+
torch.distributed.tensor.experimental._context_parallel._attention._templated_ring_attention
585584
to support TTT.
586585
"""
587586

@@ -630,7 +629,7 @@ def _get_sharded_ttt_msk(i, rank, size, q_len, ttt_step, dtype):
630629
return attn_bias
631630

632631
def patched_templated_attn(*args, **kwargs):
633-
"""Patched version of torch.distributed.tensor.experimental._attention._templated_ring_attention."""
632+
"""Patched version of _templated_ring_attention."""
634633
# Get original attention op
635634
# Sensitive to impl of _templated_ring_attention
636635
original_op = args[2]
@@ -678,40 +677,35 @@ def patch_ring_attention_for_ttt():
678677
"""Patch torch ring attention to support context parallelism for TTT."""
679678
# Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
680679

681-
if not (
682-
Version(torch.__version__) > Version("2.7.1")
683-
and Version(torch.__version__) < Version("2.9.0")
684-
):
680+
if Version(torch.__version__) < Version("2.10.0"):
685681
raise RuntimeError(
686-
f"Context parallel TTT only supported for PyTorch 2.8.0 now. "
682+
f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
687683
f"Got {torch.__version__}. "
688-
f"Please use nvcr.io/nvidia/pytorch:25.08-py3 or torch 2.8.0 or cp_size=1."
684+
f"Please use torch 2.10.0 or cp_size=1."
689685
)
690686

687+
from torch.distributed.tensor.experimental._context_parallel import _attention
688+
691689
# 1. Disable load balance, which is designed for causal mask.
692690
# This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
693-
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance = False
691+
_attention._cp_options.enable_load_balance = False
694692

695693
# 2. Patch templated ring attention for TTT mask.
696-
original_templated_ring_attention = (
697-
torch.distributed.tensor.experimental._attention._templated_ring_attention
698-
)
699-
original_templated_ring_attention_backward = (
700-
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward
701-
)
702-
torch.distributed.tensor.experimental._attention._templated_ring_attention = (
703-
get_patched_templated_ring_attn(original_templated_ring_attention)
694+
original_templated_ring_attention = _attention._templated_ring_attention
695+
original_templated_ring_attention_backward = _attention._templated_ring_attention_backward
696+
_attention._templated_ring_attention = get_patched_templated_ring_attn(
697+
original_templated_ring_attention
704698
)
705-
torch.distributed.tensor.experimental._attention._templated_ring_attention_backward = (
706-
get_patched_templated_ring_attn(original_templated_ring_attention_backward)
699+
_attention._templated_ring_attention_backward = get_patched_templated_ring_attn(
700+
original_templated_ring_attention_backward
707701
)
708702

709703
# 3. Patch merger to skip the blank shard to avoid difference in output.
710-
original_sdpa_merger_step = _SDPAMerger.step
704+
original_sdpa_merger_step = _attention._SDPAMerger.step
711705

712706
def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
713707
if lse.sum() <= 0:
714708
return
715709
return original_sdpa_merger_step(self, out, lse, partial)
716710

717-
_SDPAMerger.step = patched_sdpa_merger_step
711+
_attention._SDPAMerger.step = patched_sdpa_merger_step

experimental/README.md

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Experimental Optimization Techniques
2+
3+
Experimental optimization algorithms and research prototypes under active development.
4+
5+
## Purpose
6+
7+
For new optimization techniques (quantization, pruning, sparsity, etc.) that are:
8+
9+
- Novel or research-stage algorithms
10+
- Not yet production-ready
11+
- May have unstable APIs
12+
13+
**⚠️ Warning**: Experimental features are not guaranteed to work across releases. APIs may change or features may be removed without notice. Use at your own risk.
14+
15+
## Requirements
16+
17+
Each experimental technique must include:
18+
19+
- **README.md** - Explains what the technique does, how to use it, current status, model support, and references
20+
- **Working code** - Clear, readable implementation
21+
- **Comprehensive tests** - Good test coverage demonstrating correctness
22+
- **Detailed documentation** - Clear docs on usage, APIs, and behavior
23+
- **Example** - Demonstrating usage
24+
- **Model support list** - Which models/frameworks are supported
25+
- **Deployment info** - Supported deployment frameworks (TensorRT-LLM, vLLM, SGLang, etc.) and whether custom kernels are required
26+
- **requirements.txt** - Additional dependencies beyond base modelopt
27+
- **License headers** - Apache 2.0 headers on all Python files
28+
29+
## Example Structures
30+
31+
Organize your code however makes sense. Here are some examples:
32+
33+
**Simple flat structure:**
34+
35+
```text
36+
experimental/my_technique/
37+
├── README.md
38+
├── requirements.txt
39+
├── my_technique.py
40+
├── test_my_technique.py
41+
└── example.py
42+
```
43+
44+
**Package structure:**
45+
46+
```text
47+
experimental/my_technique/
48+
├── README.md
49+
├── requirements.txt
50+
├── my_technique/
51+
│ ├── __init__.py
52+
│ ├── core.py
53+
│ └── config.py
54+
├── tests/
55+
│ └── test_core.py
56+
└── examples/
57+
└── example_usage.py
58+
```
59+
60+
## Quality Standards
61+
62+
Experimental code must meet quality standards:
63+
64+
- Comprehensive test coverage required
65+
- Clear documentation required
66+
- Pass all pre-commit checks
67+
68+
## PR Guidelines
69+
70+
Keep PRs focused and reviewable:
71+
72+
- **Split large features**: Break complex techniques into multiple PRs if needed
73+
- **Reasonable scope**: PRs with tens of thousands of lines are difficult to review
74+
- **Incremental development**: Consider submitting core functionality first, then enhancements
75+
- If your technique is large, discuss the implementation plan in an issue first
76+
77+
## Example Documentation Template
78+
79+
Your technique's README.md should include:
80+
81+
```markdown
82+
# Your Technique Name
83+
84+
Brief description of the optimization technique.
85+
86+
## Model Support
87+
88+
| Model/Framework | Supported | Notes |
|-----------------|-----------|-------|
| LLMs (Llama, GPT, etc.) | ✅ | Tested on Llama 3.1 |
| Diffusion Models | ❌ | Not yet supported |
| Vision Models | ⚠️ | Experimental |
93+
94+
## Deployment
95+
96+
| Framework | Supported | Notes |
|-----------|-----------|-------|
| TensorRT-LLM | ✅ | Requires custom kernel |
| vLLM | ❌ | Not yet supported |
| SGLang | ✅ | Uses standard ops |
101+
102+
## Usage
103+
104+
\`\`\`python
105+
from experimental.my_technique import my_optimize
106+
...
107+
\`\`\`
108+
109+
## Status
110+
111+
Current state: Prototype
112+
113+
Known issues:
114+
- Issue 1
115+
- Issue 2
116+
117+
## References
118+
119+
- [Paper](link)
120+
- [Code repository](link)
121+
- [Project page](link)
122+
- [Related work](link)
123+
```
124+
125+
## Path to Production
126+
127+
When a technique is ready for production (proven effective, stable API, full tests, comprehensive docs), it can be promoted to the main `modelopt` package.
128+
129+
**Contributors**: Open an issue proposing graduation with evidence of effectiveness and stability.
130+
131+
**Users**: If you find an experimental feature valuable, open a GitHub issue requesting promotion to production. User demand is a key signal for production readiness.
132+
133+
## Questions?
134+
135+
Open a GitHub issue with `[experimental]` prefix.

experimental/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Experimental optimization techniques for Model Optimizer.

Research-stage algorithms under active development live in this package.
Nothing here is covered by semantic versioning: APIs may change or be
removed between releases without notice.

Warning:
    Treat everything in this package as unstable; use at your own risk
    in production environments.
"""

import warnings

# Emitted once on first import so downstream users see the stability caveat.
_EXPERIMENTAL_NOTICE = (
    "The 'experimental' package contains unstable APIs that may change. "
    "Use at your own risk in production environments."
)

# stacklevel=2 points the warning at the importer, not at this module.
warnings.warn(_EXPERIMENTAL_NOTICE, FutureWarning, stacklevel=2)

# No public names are re-exported from the package root.
__all__: list[str] = []

modelopt/onnx/autocast/referencerunner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,16 @@ def _validate_inputs(self, data_loader):
8989
if sorted(self.input_names) != sorted(data_loader[0].keys()):
9090
raise ValueError("Input names from ONNX model do not match provided input names.")
9191
for inp_name, inp_shape in data_loader[0].items():
92-
if self.input_shapes[inp_name] != list(inp_shape.shape):
92+
# Get model and data shapes as numpy arrays
93+
inp_shape_model = np.array(self.input_shapes[inp_name])
94+
inp_shape_data = np.array(inp_shape.shape)
95+
# Compare input rank
96+
raise_value_error = len(inp_shape_model) != len(inp_shape_data)
97+
if not raise_value_error:
98+
# Compare input shape, skipping check for unknown dimensions
99+
mask = inp_shape_model > 0
100+
raise_value_error = np.any(inp_shape_model[mask] != inp_shape_data[mask])
101+
if raise_value_error:
93102
raise ValueError(
94103
f"Input shape from '{inp_name}' does not match provided input shape: "
95104
f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "

modelopt/onnx/op_types.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,22 @@ def get_set_ops():
367367
"Unique",
368368
"NonZero",
369369
}
370+
371+
372+
def get_symmetric_ops():
    """Returns set of commutative/symmetric operations where operand order doesn't matter.

    Covers arithmetic reductions, logical/comparison ops, and bitwise ops
    from the ONNX operator set.
    """
    # Grouped by category for readability; the union is what callers receive.
    arithmetic = ("Add", "Mul", "Max", "Min", "Sum", "Mean")
    logical = ("And", "Or", "Xor", "Equal")
    bitwise = ("BitwiseAnd", "BitwiseOr", "BitwiseXor")
    return {*arithmetic, *logical, *bitwise}

0 commit comments

Comments
 (0)