[None][feat] add MXFP8 weight format + CUTLASS W8A8 Linear and MoE #14962
WeiHaocheng wants to merge 3 commits into
📝 Walkthrough
This PR introduces MXFP8 (OCP microscaling FP8) quantization support across TensorRT-LLM. The implementation adds MXFP8 mode flags to the quantization system, provides SM100 CUTLASS kernel templates for MXFP8×MXFP8 GEMM operations, extends MoE and Linear layer quantization frameworks, and includes comprehensive unit tests and HuggingFace model config loading support.
Changes
MXFP8 Quantization Implementation
MXFP8 Test Coverage
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (12)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
This file is being meaningfully modified in 2026 and should reflect the year of latest modification. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)".
📅 Proposed fix
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu` at line 2, Update the copyright header in fp4_gemm_bf16.cu to reflect the current year 2026 by changing the existing "Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved." line to use "2026" (e.g., "2020-2026" or "2026" per project convention) so the file header reflects the latest modification year.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/include/fp4_gemm.h (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
This file is being meaningfully modified in 2026 and should reflect the year of latest modification. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)".
📅 Proposed fix
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/include/fp4_gemm.h` at line 2, Update the copyright header string that currently reads "Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved." to reflect the latest modification year 2026; locate the top-of-file header containing "NVIDIA CORPORATION" in fp4_gemm.h and change the year range to "2020-2026" (or to include 2026 as appropriate) so the file header matches the 2026 update.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
This file is being meaningfully modified in 2026 and should reflect the year of latest modification. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)".
📅 Proposed fix
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu` at line 2, Update the copyright header string "Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved." to reflect the current year by changing the year range to 2020-2026 (i.e., "Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved."); locate the header near the top of fp4_gemm_fp32.cu and replace only the year portion so the rest of the header text remains unchanged.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
This file is being meaningfully modified in 2026 and should reflect the year of latest modification. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)".
📅 Proposed fix
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu` at line 2, Update the copyright header in the file's top comment: change the year range "2020-2023" to "2020-2026" so the NVIDIA copyright line at the top of fp4_gemm_fp16.cu reflects the 2026 modification.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
This file is being meaningfully modified in 2026 and should reflect the year of latest modification. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)".
📅 Proposed fix
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+/*
+ * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h` at line 2, Update the copyright header in cutlass_heuristic.h to reflect the current modification year (change "2020-2023" to "2020-2026") so the file header matches the 2026 update; modify the top-of-file comment in cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h where the NVIDIA COPYRIGHT line appears to include 2026.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (1)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Copyright year should be updated.
This file has meaningful modifications but the copyright year still shows 2020-2023. Per coding guidelines, it should be updated to include 2026.
Proposed fix
-* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+* Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp` at line 2, Update the copyright header at the top of the file that currently reads "Copyright (c) 2020-2023, NVIDIA CORPORATION." to include 2026 (e.g., "2020-2026" or "2020-2023, 2026" per project style); locate the header comment block (the top-of-file copyright line) and change the year range, and also scan the same file for any other occurrences of the old year range and update them consistently.
cpp/include/tensorrt_llm/common/quantization.h (2)
2-2: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update the copyright year in this modified C++ header.
This file was meaningfully modified in this PR, so the header year should include the latest modification year.
As per coding guidelines, “NVIDIA copyright header on ALL new files (update year on modified files)”.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/include/tensorrt_llm/common/quantization.h` at line 2, The file header in quantization.h still shows the old copyright year; update the copyright line at the top of cpp/include/tensorrt_llm/common/quantization.h to include the current modification year (e.g., change the year in the existing copyright comment near the top of the file), ensuring the header reflects the latest modification year per project guidelines.
Source: Coding guidelines
242-246: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
C++ quant-mode builders still cannot materialize MXFP8.
`mxfp8()` was added, but `fromDescription(...)` has no MXFP8 parameter and `fromQuantAlgo(...)` has no `"MXFP8"` branch. This leaves the new flag unreachable through the standard C++ config paths.
Also applies to: 417-432
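The reachability gap described in this finding can be illustrated with a small Python model. This is only a sketch under stated assumptions: the `QuantMode` class, its bit positions, the `_ALGO_TO_MODE` table, and the helper names are hypothetical stand-ins and do not mirror the real C++ `QuantMode` layout. It shows the kind of check that would flag a mode bit that no builder can produce.

```python
from enum import IntFlag
from typing import Optional, Set


class QuantMode(IntFlag):
    """Hypothetical model of a quant-mode bit set (bit positions are illustrative)."""
    NONE = 0
    FP8_QDQ = 1 << 0
    NVFP4 = 1 << 1
    MXFP8 = 1 << 2


# Hypothetical string-to-flag table, playing the role of a fromQuantAlgo-style builder.
_ALGO_TO_MODE = {
    "FP8": QuantMode.FP8_QDQ,
    "NVFP4": QuantMode.NVFP4,
    "MXFP8": QuantMode.MXFP8,  # the kind of branch the finding reports as missing
}


def from_quant_algo(algo: Optional[str]) -> QuantMode:
    """Map an algorithm name to its flag, rejecting unknown names explicitly."""
    if algo is None:
        return QuantMode.NONE
    if algo not in _ALGO_TO_MODE:
        raise ValueError(f"unsupported quant algo: {algo!r}")
    return _ALGO_TO_MODE[algo]


def unreachable_flags() -> Set[QuantMode]:
    """Return every non-zero flag that no entry of the builder table can produce."""
    reachable = QuantMode.NONE
    for mode in _ALGO_TO_MODE.values():
        reachable |= mode
    return {m for m in QuantMode if m.value and not (reachable & m)}
```

With the `"MXFP8"` entry present, `unreachable_flags()` returns an empty set; deleting that entry makes the helper report `QuantMode.MXFP8`, which is the situation the finding describes for the C++ builders.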
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/include/tensorrt_llm/common/quantization.h` around lines 242 - 246, The new MXFP8 quant mode (mxfp8) was added but is unreachable because static constexpr QuantMode fromDescription(...) and fromQuantAlgo(...) lack support for MXFP8; update fromDescription (the function signature shown) to include a boolean parameter (e.g., bool useMxfp8) and propagate it into the QuantMode construction logic, and update fromQuantAlgo(...) to handle the "MXFP8" string branch (returning the appropriate QuantMode or setting the corresponding flag) so that mxfp8() can be materialized via the C++ config paths; ensure you mirror the same changes at the other occurrence noted (around lines 417-432) so both builders accept and map the MXFP8 flag.
cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (1)
1-15: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update the NVIDIA copyright year in this modified header.
This file was meaningfully modified in 2026, but the header still ends at 2023. Please bump the year range before merge.
As per coding guidelines, modified source files must have the NVIDIA header "update year on modified files."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h` around lines 1 - 15, Update the NVIDIA copyright header year range in the file by changing the ending year "2023" to "2026" in the top-of-file NVIDIA copyright comment (the header block containing "Copyright (c) 2020-2023, NVIDIA CORPORATION"); ensure the modified header preserves the exact wording and license text aside from the year range.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
1-15: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
The copyright header shows 2020-2023 but this file has been meaningfully modified in 2026 with MXFP8 support. Per coding guidelines, update the year to reflect the latest meaningful modification.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h` around lines 1 - 15, Update the copyright header range in the file-level comment block to include 2026 (e.g., change "2020-2023" to "2020-2026") so the top-of-file license notice reflects the 2026 MXFP8 modifications; confirm any other occurrences of the same year range in the file header are updated consistently.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
1-15: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update copyright year to 2026.
The copyright header shows 2020-2023 but this file has been meaningfully modified in 2026 with MXFP8 dispatch support. Per coding guidelines, update the year to reflect the latest meaningful modification.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h` around lines 1 - 15, Update the copyright header year range in the file by replacing the old range "2020-2023" with the new range "2020-2026" in the top-of-file comment block; locate the comment block containing the string "Copyright (c) 2020-2023, NVIDIA CORPORATION." (in moe_gemm_template_dispatch_tma_ws.h) and change the year to reflect the 2026 modification.
Source: Coding guidelines
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu (1)
1-15: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update the NVIDIA copyright year range.
This file is modified in this PR, but the header still ends at 2025. Please bump it to include 2026. As per coding guidelines, "NVIDIA copyright header on ALL new files (update year on modified files)" and "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu` around lines 1 - 15, Update the top-of-file copyright header that currently reads "Copyright (c) 2020-2025, NVIDIA CORPORATION." to include 2026 (e.g., "Copyright (c) 2020-2026, NVIDIA CORPORATION."); locate the header comment block (the lines beginning with "Copyright (c) 2020-2025, NVIDIA CORPORATION.") in moe_kernels.cu and change the year range only, leaving the rest of the license text unchanged.
Source: Coding guidelines
🧹 Nitpick comments (7)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/mxfp8_mxfp8_gemm_template_sm100.h (1)
44-45: 💤 Low value
`using namespace` in header pollutes includers' namespace scope.
These using-directives at file scope in a header will inject `cute` and `tensorrt_llm::kernels::cutlass_kernels` symbols into every translation unit that includes this file. Consider moving them inside the `namespace tensorrt_llm::kernels::cutlass_kernels` block or removing them entirely and qualifying types explicitly.
However, reviewing the sibling header `mxfp8_mxfp4_gemm_template_sm100.h` shows the same pattern is already established, so this is consistent with existing code.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/mxfp8_mxfp8_gemm_template_sm100.h` around lines 44 - 45, The header currently injects symbols via "using namespace cute;" and "using namespace tensorrt_llm::kernels::cutlass_kernels;" at global scope which pollutes includers; fix by removing these file-scope using-directives and either (a) move them inside the existing namespace tensorrt_llm::kernels::cutlass_kernels { ... } block so they only affect that namespace, or (b) delete them entirely and fully qualify types with the cute:: and tensorrt_llm::kernels::cutlass_kernels:: prefixes in the functions/classes in this header (look for references to cute symbols and types defined in tensorrt_llm::kernels::cutlass_kernels to update).
tests/unittest/quantization/test_mxfp8_format.py (2)
37-57: ⚡ Quick win
Add MXFP8 parser failure-mode tests (invalid config inputs).
Please add negative tests for at least invalid `weight_block_size` and invalid `activation_scheme` so parser contract regressions are caught early.
As per coding guidelines, tests should cover “happy path, important edge cases, and failure modes relevant to the feature or fix.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/quantization/test_mxfp8_format.py` around lines 37 - 57, Add negative unit tests to tests/unittest/quantization/test_mxfp8_format.py that call ModelConfig.load_hf_quant_config (from tensorrt_llm._torch.model_config) with invalid MXFP8 inputs: one test where "weight_block_size" is malformed (e.g., not a two-element list or contains non-integer/zero) and another where "activation_scheme" is an unsupported value; assert that the parser raises the expected exception (or returns an error) and that the exception message references the offending field so regressions in ModelConfig.load_hf_quant_config's MXFP8 parsing are caught.
Source: Coding guidelines
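The negative tests requested above might look roughly like the following. This is a minimal sketch, not the project's parser: `parse_mxfp8_config` is a hypothetical stand-in for the MXFP8 branch of `ModelConfig.load_hf_quant_config`, and the accepted values (`[1, 32]` for `weight_block_size`, `"dynamic"` for `activation_scheme`) are assumptions chosen for illustration. The `rejects` helper checks both that an error is raised and that the message names the offending field.

```python
from typing import Any, Mapping, Tuple


def parse_mxfp8_config(cfg: Mapping[str, Any]) -> Tuple[int, int]:
    """Hypothetical MXFP8 config validator (stand-in, not the real loader)."""
    block = cfg.get("weight_block_size")
    is_pair = isinstance(block, (list, tuple)) and len(block) == 2
    # bool is a subclass of int, so exclude it explicitly.
    if not is_pair or not all(
            isinstance(b, int) and not isinstance(b, bool) and b > 0
            for b in block):
        raise ValueError(
            f"weight_block_size must be two positive ints, got {block!r}")
    scheme = cfg.get("activation_scheme")
    if scheme != "dynamic":
        raise ValueError(
            f"activation_scheme must be 'dynamic' for MXFP8, got {scheme!r}")
    return block[0], block[1]


def rejects(cfg: Mapping[str, Any], field: str) -> bool:
    """True if parsing fails with a ValueError whose message names `field`."""
    try:
        parse_mxfp8_config(cfg)
    except ValueError as err:
        return field in str(err)
    return False


# Assumed-valid baseline config; each negative case overrides one field.
GOOD = {"weight_block_size": [1, 32], "activation_scheme": "dynamic"}

NEGATIVE_CASES = [
    ({**GOOD, "weight_block_size": [32]}, "weight_block_size"),
    ({**GOOD, "weight_block_size": [1, 0]}, "weight_block_size"),
    ({**GOOD, "weight_block_size": "1x32"}, "weight_block_size"),
    ({**GOOD, "activation_scheme": "static"}, "activation_scheme"),
]
```

In a pytest suite, `NEGATIVE_CASES` would typically feed `pytest.mark.parametrize` together with `pytest.raises(ValueError, match=field)`; plain helpers are used here only to keep the sketch dependency-free.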
1-57: QA list update is unnecessary for this PR scope.
This file adds unit tests only (`tests/unittest/...`), so integration QA test-list updates under `tests/integration/test_lists/qa/` are not needed.
As per coding guidelines, QA list hygiene updates are required when integration/release test definitions are added or materially altered.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/quantization/test_mxfp8_format.py` around lines 1 - 57, The PR only adds unit tests (functions like test_mxfp8_quant_algo_exists, test_mxfp8_quant_mode_helpers, test_mxfp8_from_description, test_load_mxfp8_hf_quant_config) so remove any unrelated QA list changes from this PR; specifically undo or exclude modifications to the integration QA test-list that were added alongside these unit test files so the diff contains only the new unit tests and no integration QA list hygiene updates.
Source: Coding guidelines
cpp/tensorrt_llm/thop/moeOp.cpp (1)
1071-1091: ⚡ Quick win
Missing detailed shape validation for MXFP8 weight block scales.
The MXFP4 paths (e.g., `isWMxfp4AMxfp8Quant` at lines 1024-1046) validate the exact shapes of weight block tensors against expected dimensions including alignment and block scale vector size. This MXFP8 branch only checks `dim() == 3`.
Without explicit shape validation, incorrectly-shaped tensors would only surface as TMA/kernel errors at runtime, making debugging harder.
Consider adding shape validation similar to the MXFP4 path:
constexpr int FP8_PER_INT32 = 4;
TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
        && fc1_weight_block.sizes()[1]
            == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                   inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
                * expand_ratio
        && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
            == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
    "fc1_weight_block shape mismatch for MXFP8");
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/thop/moeOp.cpp` around lines 1071 - 1091, The MXFP8 branch only checks dim()==3 but needs explicit shape validation like the MXFP4 path; add TORCH_CHECKs for fc1_weight_block and fc2_weight_block using TmaWarpSpecializedGroupedGemmInput::alignToSfDim, MinNDimAlignmentMXFPX, MinKDimAlignmentMXFPX and MXFPXBlockScaleVectorSize (define FP8_PER_INT32 = 4) to assert sizes[0]==num_experts_on_rank, sizes[1]==alignToSfDim(inter_size, MinNDimAlignmentMXFPX)*expand_ratio, and sizes[2]*FP8_PER_INT32*MXFPXBlockScaleVectorSize==alignToSfDim(hidden_size, MinKDimAlignmentMXFPX) (and analogous checks for fc2_weight_block) so malformed weight-block tensors are caught early; add these checks before returning kernels::QuantParams::MXFP8MXFP4.
tests/unittest/_torch/modules/test_mxfp8_moe.py (1)
117-180: QA test-list update is not required for this change set.
This file adds unit tests under `tests/unittest/...`; no `tests/integration/test_lists/qa/*` entry is needed for this PR scope.
As per coding guidelines, unittest-only changes should be explicitly called out as not requiring QA integration list updates.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_mxfp8_moe.py` around lines 117 - 180, Add an explicit note in the test file indicating this is a unittest-only change that does not require updates to the QA integration test lists: insert a short comment above the test function test_mxfp8_moe_forward_smoke (or at the top of tests/unittest/_torch/modules/test_mxfp8_moe.py) stating that no tests/integration/test_lists/qa/* entry is required for this PR (or alternatively add a pytest marker like `@pytest.mark.no_qa_required` and document that marker in the repo), so reviewers and CI maintainers have an explicit signal that QA list changes are not needed.
Source: Coding guidelines
tests/unittest/_torch/modules/test_mxfp8_linear.py (1)
1-149: QA list update status: unnecessary for this change set.
This PR slice adds/updates only unit tests under `tests/unittest/...`; no `tests/integration/defs/...` changes, so QA integration test-list updates are not required.
As per coding guidelines, QA test-list hygiene updates are only needed when integration/release-run definitions are added or materially changed.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_mxfp8_linear.py` around lines 1 - 149, Remove the extraneous QA list update status comment block that was added to the test file; specifically delete the standalone text block "**QA list update status: unnecessary for this change set.**" (and its surrounding explanatory lines) from tests/unittest/_torch/modules/test_mxfp8_linear.py so the file contains only test code and relevant comments (no QA-process note).
Source: Coding guidelines
tensorrt_llm/_torch/modules/linear.py (1)
2788-2791: ⚡ Quick win
Avoid per-call allocation of `global_scale` in the CUTLASS path.
Allocating `torch.ones([1], ...)` on every forward adds unnecessary hot-path overhead. Reuse a preallocated tensor/buffer initialized in `create_weights()`.
Suggested fix
 def create_weights(self, module: Linear, in_features: int,
                    out_features: int, bias: bool, dtype: torch.dtype):
@@
     if self.use_cutlass:
@@
         module.weight_scale = Parameter(torch.empty(
             [self._swizzled_scale_size(out_features, in_features)],
             dtype=torch.uint8), requires_grad=False)
+        module.scale_one = Parameter(
+            torch.ones([1], dtype=torch.float32), requires_grad=False
+        )
@@
     if self.use_cutlass:
@@
-        global_scale = torch.ones([1],
-                                  dtype=torch.float32,
-                                  device=input.device)
+        global_scale = module.scale_one.to(device=input.device)
         output = torch.ops.trtllm.mxfp8_mxfp8_gemm(act_e4m3, act_sf,
                                                    module.weight,
                                                    module.weight_scale,
                                                    global_scale,
                                                    module.dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/linear.py` around lines 2788 - 2791, The CUTLASS path currently allocates global_scale with torch.ones([1], dtype=torch.float32, device=input.device) on every forward; move this allocation into the module initialization (create_weights()) as a persistent, non‑grad tensor (e.g., self.global_scale) and reuse it in the forward/CUTLASS branch instead of per-call torch.ones; ensure the stored tensor is created with the correct device/dtype (or moved to input.device at runtime if necessary), has requires_grad=False, and is shaped or expanded to match the gemm call interface used by torch.ops.trtllm.mxfp8_mxfp8_gemm so the forward uses self.global_scale (or self.global_scale.to(input.device) if needed) rather than allocating a new tensor each call.
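One way to realize the preallocation suggested in the `linear.py` nitpick is a non-persistent PyTorch buffer rather than a `Parameter`. The sketch below is an assumption-laden illustration, not the PR's code: `ScaledLinearSketch` and `scale_one` are hypothetical names, and a plain matmul stands in for `torch.ops.trtllm.mxfp8_mxfp8_gemm`. A buffer follows `module.to(device)`, is never trained, and with `persistent=False` stays out of `state_dict()`, so checkpoints do not need to carry the constant.

```python
import torch
from torch import nn


class ScaledLinearSketch(nn.Module):
    """Hypothetical module showing a preallocated unit scale (not the real Linear)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False)
        # Allocated once at construction; moves with .to(device) and is
        # excluded from state_dict() because persistent=False.
        self.register_buffer(
            "scale_one", torch.ones(1, dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the quantized GEMM call: reuses the preallocated
        # scale instead of creating torch.ones(...) on every forward.
        return (x @ self.weight.t()) * self.scale_one


module = ScaledLinearSketch(4, 3)
output = module(torch.randn(2, 4))  # shape (2, 3)
```

Whether a buffer or a frozen `Parameter` fits better depends on how the surrounding weight-loading code enumerates module parameters, which this sketch does not model.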
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/include/tensorrt_llm/common/quantization.h`:
- Around line 137-140: The mxfp8() static method in QuantMode currently uses
BaseType(1u) << 17 which is out-of-sync with the Python QuantMode bit position;
update QuantMode::mxfp8() to use the same bit shift value used by the Python
enum (i.e., change the shift from 17 to the Python-defined bit index for MXFP8)
so serialized/deserialized raw QuantMode values and hasMxfp8() checks match
across languages, and run/update any cross-language serialization tests to
verify parity.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h`:
- Around line 861-878: The MXFPX return in getScalingType() is too broad:
tighten the FP8 branch so MXFPX is returned only for the e4m3/e4m3
instantiation. In getScalingType(), keep the existing checks (use_wfp4afp8,
use_fp4, use_fp8 && std::is_same_v<T, WeightType>) but add an additional
constexpr guard that verifies the template instantiation is specifically the
e4m3/e4m3 variant before honoring use_mxfp8_weight_scaling_; otherwise return
NONE. Reference: getScalingType(), use_fp8, std::is_same_v<T, WeightType>, and
use_mxfp8_weight_scaling_.
In `@tensorrt_llm/_torch/custom_ops/torch_custom_ops.py`:
- Around line 131-143: The cached runner key omits the use_fused_finalize flag
so MoERunner.runner_dict can return a FusedMoeRunner configured with the wrong
finalize-fusion setting; update the instance_key tuple construction in
torch_custom_ops.py to include use_fused_finalize (matching unique_id()
behavior) before using it to index MoERunner.runner_dict and replicate this
change for the other similar block around the 149-167 region so that
instance_key fully reflects all constructor parameters passed to
torch.classes.trtllm.FusedMoeRunner.
In `@tensorrt_llm/_torch/model_config.py`:
- Around line 478-495: In the MXFP8 branch (where you set
quant_config.quant_algo = QuantAlgo.MXFP8), read and validate
hf_quant_config.get("activation_scheme") and reject any value that is not the
expected dynamic scheme for MXFP8 (e.g., assert activation_scheme == "dynamic"
or compare against an allowed set) so mismatched checkpoints fail fast; if
valid, assign the value to quant_config.activation_scheme (or map it to the
internal enum) and raise a clear error message referencing
hf_quant_config["activation_scheme"] and QuantAlgo.MXFP8 when validation fails.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 4566-4611: The code currently treats missing MXFP8 scale keys
(returned None from _get_scale_key) as optional and then copies into
preallocated uninitialized buffers (dst_w3_w1_u8/dst_w2_u8), which yields
nondeterministic results; update the block inside the VANILLA branch (where
_get_scale_key is used) to fail fast: after computing w1_key, w3_key, w2_key
(and before using weights or allocating/copying into
module.w3_w1_weight_scale/module.w2_weight_scale), check if any key is None and
raise a clear RuntimeError (or assert) that includes the module identity and
expert_id and lists which of w1/w3/w2 scale keys are missing so the caller sees
a deterministic load-time error instead of undefined bytes.
- Around line 4543-4544: The online EPLB path currently only warns in
_online_eplb_not_verified and then calls setup_quant_scales, but it never stages
or finalizes the shared MXFP8 scales (w3_w1_weight_scale / w2_weight_scale),
which can leave stale buffers after expert migration; fix this by either (A)
mirroring the shared-scale staging/finalization used by the other block-scale
methods (ensure the same code paths that manage staging, copying, and finalizing
shared w3_w1_weight_scale and w2_weight_scale are invoked for the online EPLB
branch alongside setup_quant_scales), or (B) immediately mark/short-circuit the
online EPLB flow as NOT_SUPPORTED (e.g., have _online_eplb_not_verified raise or
set a NOT_SUPPORTED flag) until shared MXFP8 scale migration is implemented, so
no unsafe path executes. Ensure references to _online_eplb_not_verified,
setup_quant_scales, w3_w1_weight_scale, and w2_weight_scale are updated
accordingly.
- Around line 4528-4540: The current multi-line nn.Parameter(...) calls for
w3_w1_weight_scale and w2_weight_scale produce hanging-indent E126; reformat
their argument lists so continuation lines align under the opening parenthesis
or use a consistent hanging indent (e.g., align the second line with the first
argument or place each argument on its own line indented once) for the
nn.Parameter(torch.empty(..., dtype=self.BLOCK_SCALES_DTYPE),
requires_grad=False) expressions and the
module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale) call to
remove the hanging-indent errors while keeping the same arguments and types
(references: w3_w1_weight_scale, w2_weight_scale, nn.Parameter, torch.empty,
self.BLOCK_SCALES_DTYPE, module.register_parameter).
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 2864-2877: The code calls load_weight_scales(...) then
torch.cat(scales, dim=0) without checking that all fused Q/K/V scale tensors are
present; add a fast-fail check in load_weights_fused_qkv_linear (and the
adjacent fused-QKV loader at the other occurrence) that verifies the returned
scales list contains three non-None tensors (and/or that none are missing and
all have compatible shapes) before concatenation, and raise a clear ValueError
mentioning module name/module.tp_rank and which scale(s) are missing if the
validation fails; keep using _store_scale only after this validation succeeds.
In `@tests/unittest/_torch/modules/test_mxfp8_linear.py`:
- Around line 118-122: Replace the hardcoded CKPT path with an
environment-driven value: read a root from os.getenv("LLM_MODELS_ROOT") (or a
similarly named env var) and construct CKPT via os.path.join(root,
"hidden_trail", "minimax-m3-preview_vv1"); update the skip guard used by
test_load_real_mxfp8_dense_layer (and the other block at lines 132-142) to skip
when the env var is empty or the constructed CKPT directory does not exist (use
os.path.isdir) and ensure the test message mentions the env var so CI/devs know
how to enable the test.
In `@tests/unittest/_torch/modules/test_mxfp8_moe.py`:
- Around line 40-47: The SM gating is too broad (sm_major < 10) and should be
tightened to only allow the exact supported SMs (sm100 and sm103); update the
checks in test_mxfp8_moe_supported_on_sm_in_table and the nearby tests that
currently use sm_major to instead read both torch.cuda.get_device_capability()
components and only proceed when (sm_major, sm_minor) equals (10, 0) or (10, 3)
before calling CutlassFusedMoE.can_implement(QuantAlgo.MXFP8), skipping
otherwise.
---
Outside diff comments:
In `@cpp/include/tensorrt_llm/common/quantization.h`:
- Line 2: The file header in quantization.h still shows the old copyright year;
update the copyright line at the top of
cpp/include/tensorrt_llm/common/quantization.h to include the current
modification year (e.g., change the year in the existing copyright comment near
the top of the file), ensuring the header reflects the latest modification year
per project guidelines.
- Around line 242-246: The new MXFP8 quant mode (mxfp8) was added but is
unreachable because static constexpr QuantMode fromDescription(...) and
fromQuantAlgo(...) lack support for MXFP8; update fromDescription (the function
signature shown) to include a boolean parameter (e.g., bool useMxfp8) and
propagate it into the QuantMode construction logic, and update
fromQuantAlgo(...) to handle the "MXFP8" string branch (returning the
appropriate QuantMode or setting the corresponding flag) so that mxfp8() can be
materialized via the C++ config paths; ensure you mirror the same changes at the
other occurrence noted (around lines 417-432) so both builders accept and map
the MXFP8 flag.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp`:
- Line 2: Update the copyright header at the top of the file that currently
reads "Copyright (c) 2020-2023, NVIDIA CORPORATION." to include 2026 (e.g.,
"2020-2026" or "2020-2023, 2026" per project style); locate the header comment
block (the top-of-file copyright line) and change the year range, and also scan
the same file for any other occurrences of the old year range and update them
consistently.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h`:
- Line 2: Update the copyright header in cutlass_heuristic.h to reflect the
current modification year (change "2020-2023" to "2020-2026") so the file header
matches the 2026 update; modify the top-of-file comment in
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h where the NVIDIA
COPYRIGHT line appears to include 2026.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu`:
- Line 2: Update the copyright header in fp4_gemm_bf16.cu to reflect the current
year 2026 by changing the existing "Copyright (c) 2020-2023, NVIDIA CORPORATION.
All rights reserved." line to use "2026" (e.g., "2020-2026" or "2026" per
project convention) so the file header reflects the latest modification year.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu`:
- Line 2: Update the copyright header in the file's top comment: change the year
range "2020-2023" to "2020-2026" so the NVIDIA copyright line at the top of
fp4_gemm_fp16.cu reflects the 2026 modification.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu`:
- Line 2: Update the copyright header string "Copyright (c) 2020-2023, NVIDIA
CORPORATION. All rights reserved." to reflect the current year by changing the
year range to 2020-2026 (i.e., "Copyright (c) 2020-2026, NVIDIA CORPORATION.
All rights reserved."); locate the header near the top of fp4_gemm_fp32.cu and
replace only the year portion so the rest of the header text remains unchanged.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/include/fp4_gemm.h`:
- Line 2: Update the copyright header string that currently reads "Copyright (c)
2020-2023, NVIDIA CORPORATION. All rights reserved." to reflect the latest
modification year 2026; locate the top-of-file header containing "NVIDIA
CORPORATION" in fp4_gemm.h and change the year range to "2020-2026" (or to
include 2026 as appropriate) so the file header matches the 2026 update.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h`:
- Around line 1-15: Update the NVIDIA copyright header year range in the file by
changing the ending year "2023" to "2026" in the top-of-file NVIDIA copyright
comment (the header block containing "Copyright (c) 2020-2023, NVIDIA
CORPORATION"); ensure the modified header preserves the exact wording and
license text aside from the year range.
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`:
- Around line 1-15: Update the copyright header year range in the file by
replacing the old range "2020-2023" with the new range "2020-2026" in the
top-of-file comment block; locate the comment block containing the string
"Copyright (c) 2020-2023, NVIDIA CORPORATION." (in
moe_gemm_template_dispatch_tma_ws.h) and change the year to reflect the 2026
modification.
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h`:
- Around line 1-15: Update the copyright header range in the file-level comment
block to include 2026 (e.g., change "2020-2023" to "2020-2026") so the
top-of-file license notice reflects the 2026 MXFP8 modifications; confirm any
other occurrences of the same year range in the file header are updated
consistently.
In `@cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu`:
- Around line 1-15: Update the top-of-file copyright header that currently reads
"Copyright (c) 2020-2025, NVIDIA CORPORATION." to include 2026 (e.g., "Copyright
(c) 2020-2026, NVIDIA CORPORATION."); locate the header comment block (the lines
beginning with "Copyright (c) 2020-2025, NVIDIA CORPORATION.") in moe_kernels.cu
and change the year range only, leaving the rest of the license text unchanged.
---
Nitpick comments:
In
`@cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/mxfp8_mxfp8_gemm_template_sm100.h`:
- Around line 44-45: The header currently injects symbols via "using namespace
cute;" and "using namespace tensorrt_llm::kernels::cutlass_kernels;" at global
scope which pollutes includers; fix by removing these file-scope
using-declarations and either (a) move them inside the existing namespace
tensorrt_llm::kernels::cutlass_kernels { ... } block so they only affect that
namespace, or (b) delete them entirely and fully qualify types with the cute::
and tensorrt_llm::kernels::cutlass_kernels:: prefixes in the functions/classes
in this header (look for references to cute symbols and types defined in
tensorrt_llm::kernels::cutlass_kernels to update).
In `@cpp/tensorrt_llm/thop/moeOp.cpp`:
- Around line 1071-1091: The MXFP8 branch only checks dim()==3 but needs
explicit shape validation like the MXFP4 path; add TORCH_CHECKs for
fc1_weight_block and fc2_weight_block using
TmaWarpSpecializedGroupedGemmInput::alignToSfDim, MinNDimAlignmentMXFPX,
MinKDimAlignmentMXFPX and MXFPXBlockScaleVectorSize (define FP8_PER_INT32 = 4)
to assert sizes[0]==num_experts_on_rank, sizes[1]==alignToSfDim(inter_size,
MinNDimAlignmentMXFPX)*expand_ratio, and
sizes[2]*FP8_PER_INT32*MXFPXBlockScaleVectorSize==alignToSfDim(hidden_size,
MinKDimAlignmentMXFPX) (and analogous checks for fc2_weight_block) so malformed
weight-block tensors are caught early; add these checks before returning
kernels::QuantParams::MXFP8MXFP4.
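For reference, the shape arithmetic behind the requested checks can be sketched in Python. This is an illustration under stated assumptions, not the C++ code: `align_up` is assumed to behave like `alignToSfDim` (round up to a multiple), the alignment values are passed in as parameters because the real `MinNDimAlignmentMXFPX` and `MinKDimAlignmentMXFPX` constants are not shown here, and a block-scale vector size of 32 is assumed from the per-32-element scales described in the PR.

```python
FP8_PER_INT32 = 4     # four 8-bit values per int32, as defined in the comment
SF_VECTOR_SIZE = 32   # assumed value of MXFPXBlockScaleVectorSize


def align_up(value, alignment):
    """Round value up to a multiple of alignment (assumed alignToSfDim behavior)."""
    return (value + alignment - 1) // alignment * alignment


def expected_fc1_block_shape(num_experts, inter_size, hidden_size,
                             expand_ratio, n_align, k_align):
    """Shape the fc1 weight-block tensor must have to pass the suggested checks."""
    rows = align_up(inter_size, n_align) * expand_ratio
    # sizes[2] * FP8_PER_INT32 * SF_VECTOR_SIZE must equal the aligned hidden size.
    cols = align_up(hidden_size, k_align) // (FP8_PER_INT32 * SF_VECTOR_SIZE)
    return (num_experts, rows, cols)
```

Comparing a tensor's `sizes` against this tuple up front is what lets a malformed weight-block tensor fail with a clear message rather than inside the kernel.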
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 2788-2791: The CUTLASS path currently allocates global_scale with
torch.ones([1], dtype=torch.float32, device=input.device) on every forward; move
this allocation into the module initialization (create_weights()) as a
persistent, non‑grad tensor (e.g., self.global_scale) and reuse it in the
forward/CUTLASS branch instead of per-call torch.ones; ensure the stored tensor
is created with the correct device/dtype (or moved to input.device at runtime if
necessary), has requires_grad=False, and is shaped or expanded to match the gemm
call interface used by torch.ops.trtllm.mxfp8_mxfp8_gemm so the forward uses
self.global_scale (or self.global_scale.to(input.device) if needed) rather than
allocating a new tensor each call.
In `@tests/unittest/_torch/modules/test_mxfp8_linear.py`:
- Around line 1-149: Remove the extraneous QA list update status comment block
that was added to the test file; specifically delete the standalone text block
"**QA list update status: unnecessary for this change set.**" (and its
surrounding explanatory lines) from
tests/unittest/_torch/modules/test_mxfp8_linear.py so the file contains only
test code and relevant comments (no QA-process note).
In `@tests/unittest/_torch/modules/test_mxfp8_moe.py`:
- Around line 117-180: Add an explicit note in the test file indicating this is
a unittest-only change that does not require updates to the QA integration test
lists: insert a short comment above the test function
test_mxfp8_moe_forward_smoke (or at the top of
tests/unittest/_torch/modules/test_mxfp8_moe.py) stating that no
tests/integration/test_lists/qa/* entry is required for this PR (or
alternatively add a pytest marker like `@pytest.mark.no_qa_required` and document
that marker in the repo), so reviewers and CI maintainers have an explicit
signal that QA list changes are not needed.
In `@tests/unittest/quantization/test_mxfp8_format.py`:
- Around line 37-57: Add negative unit tests to
tests/unittest/quantization/test_mxfp8_format.py that call
ModelConfig.load_hf_quant_config (from tensorrt_llm._torch.model_config) with
invalid MXFP8 inputs: one test where "weight_block_size" is malformed (e.g., not
a two-element list or contains non-integer/zero) and another where
"activation_scheme" is an unsupported value; assert that the parser raises the
expected exception (or returns an error) and that the exception message
references the offending field so regressions in
ModelConfig.load_hf_quant_config's MXFP8 parsing are caught.
- Around line 1-57: The PR only adds unit tests (functions like
test_mxfp8_quant_algo_exists, test_mxfp8_quant_mode_helpers,
test_mxfp8_from_description, test_load_mxfp8_hf_quant_config) so remove any
unrelated QA list changes from this PR; specifically undo or exclude
modifications to the integration QA test-list that were added alongside these
unit test files so the diff contains only the new unit tests and no integration
QA list hygiene updates.
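The negative tests suggested above for `weight_block_size` and `activation_scheme` depend on the parser rejecting bad input with a message that names the field. A stand-in validator shows the behavior such tests would assert; `validate_mxfp8_quant_config` and the `allowed_schemes` argument are hypothetical and do not reflect how `ModelConfig.load_hf_quant_config` is actually implemented.

```python
def validate_mxfp8_quant_config(cfg, allowed_schemes):
    """Raise ValueError naming the offending field. Hypothetical stand-in."""
    block = cfg.get("weight_block_size")
    well_formed = (
        isinstance(block, (list, tuple))
        and len(block) == 2
        # bool is a subclass of int, so exclude it explicitly.
        and all(isinstance(b, int) and not isinstance(b, bool) and b > 0
                for b in block)
    )
    if not well_formed:
        raise ValueError(
            f"weight_block_size must be two positive integers, got {block!r}")
    scheme = cfg.get("activation_scheme")
    if scheme not in allowed_schemes:
        raise ValueError(
            f"activation_scheme {scheme!r} is not supported; "
            f"expected one of {sorted(allowed_schemes)}")
```

With pytest, each negative case would be a `with pytest.raises(ValueError, match="weight_block_size"):` (or `match="activation_scheme"`) block around the real loader call.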
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 55a06ef3-48e5-43fa-bc46-7f4bd2b3893b
📒 Files selected for processing (32)
cpp/include/tensorrt_llm/common/quantization.h
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/mxfp8_mxfp8_gemm_template_sm100.h
cpp/tensorrt_llm/kernels/cutlass_kernels/include/fp4_gemm.h
cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py
cpp/tensorrt_llm/thop/CMakeLists.txt
cpp/tensorrt_llm/thop/moeOp.cpp
cpp/tensorrt_llm/thop/mxfp8Gemm.cpp
ruff-legacy-baseline.json
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
tensorrt_llm/_torch/model_config.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
tensorrt_llm/_torch/modules/fused_moe/interface.py
tensorrt_llm/_torch/modules/fused_moe/quantization.py
tensorrt_llm/_torch/modules/linear.py
tensorrt_llm/_torch/modules/mxfp8_utils.py
tensorrt_llm/quantization/mode.py
tests/unittest/_torch/modules/test_mxfp8_linear.py
tests/unittest/_torch/modules/test_mxfp8_moe.py
tests/unittest/quantization/__init__.py
tests/unittest/quantization/test_mxfp8_format.py
tests/unittest/trt/quantization/test_mode.py
66ccfff to c41c23a
@djns99 Could you help review the kernel code?
djns99
left a comment
I took a look at the MoE code changes and will leave the rest to others. Thank you for this contribution!
No functional issues, but a lot of tidying up and refactoring is needed to put things in the right places.
c41c23a to 1e62250
/bot run
PR_Github #54542 [ run ] triggered by Bot. Commit:
PR_Github #54542 [ run ] completed with state
73340aa to 7ac1c41
/bot run
PR_Github #54608 [ run ] triggered by Bot. Commit:
PR_Github #54608 [ run ] completed with state
djns99
left a comment
Thanks for addressing my comments. A few extra tidy-ups remain, but overall the MoE code is looking relatively good now.
7ac1c41 to f4bec5c
/bot run --disable-fail-fast
f4bec5c to a46a1f0
/bot run --disable-fail-fast
PR_Github #55001 [ run ] triggered by Bot. Commit:
PR_Github #55001 [ run ] completed with state
e113da0 to dd28a09
/bot run --disable-fail-fast
PR_Github #55014 [ run ] triggered by Bot. Commit:
PR_Github #55014 [ run ] completed with state
dd28a09 to 96f5476
96f5476 to afce527
/bot run --disable-fail-fast
afce527 to af94a3f
PR_Github #55132 [ run ] triggered by Bot. Commit:
13ac76d to 0c7693c
/bot run --disable-fail-fast
PR_Github #55202 [ run ] triggered by Bot. Commit:
PR_Github #55132 [ run ] completed with state
PR_Github #55202 [ run ] completed with state
0c7693c to ffc4188
Add first-class MXFP8 (OCP microscaling: e4m3 elements + per-32-element
UE8M0 block scales) weight quantization to the PyTorch backend and
execute MXFP8xMXFP8 W8A8 GEMMs through CUTLASS on Blackwell sm_100/103
for both dense Linear layers and fused MoE.
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
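To make the format in the commit message concrete, here is a rough pure-Python sketch of how one 32-element block could get its UE8M0 scale. It follows the usual OCP microscaling recipe (shared exponent = floor(log2(amax)) minus the element format's largest exponent) as an assumption; it is not taken from this PR's kernels, `quantize_block` is a hypothetical name, and the scaled values are only clamped to the e4m3 range, not rounded to the e4m3 grid.

```python
import math

BLOCK_SIZE = 32     # elements sharing one scale, per the commit message
E4M3_MAX = 448.0    # largest finite e4m3 magnitude
E4M3_EMAX = 8       # 448 = 1.75 * 2**8
UE8M0_BIAS = 127    # UE8M0 stores a biased power-of-two exponent


def quantize_block(values):
    """Return (ue8m0_byte, scaled_values) for one block of BLOCK_SIZE floats.

    Illustrative only: scaled values are clamped, not rounded to e4m3.
    """
    assert len(values) == BLOCK_SIZE
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        exponent = 0  # any scale represents an all-zero block; pick 2**0
    else:
        # frexp gives amax = m * 2**e with 0.5 <= m < 1,
        # so floor(log2(amax)) == e - 1.
        exponent = math.frexp(amax)[1] - 1 - E4M3_EMAX
    exponent = max(-UE8M0_BIAS, min(UE8M0_BIAS, exponent))
    scale = 2.0 ** exponent
    scaled = [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in values]
    return exponent + UE8M0_BIAS, scaled
```

Because the scale is a pure power of two, dequantization is just `scaled[i] * 2.0 ** (byte - UE8M0_BIAS)`, with no mantissa stored per block.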
ffc4188 to 1c7fde6
/bot run --disable-fail-fast
PR_Github #55256 [ run ] triggered by Bot. Commit:
Summary by CodeRabbit
Release Notes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.