Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .claude/skills/trtllm-moe-develop/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,26 @@ Checklist:
- Existing legacy `forward` methods can be read for compatibility context, but
they are not the default pattern for new backend work.

### Imported Kernel ABI Checklist

When importing or wrapping an upstream kernel, derive the TRT-LLM adapter
contract from the lowest-level kernel consumer. Comments, docs, design notes,
and parameter names are useful hints, but they are not proof of the runtime ABI.

- Derive weight shape and layout from the kernel entrypoint, `make_layout`, TMA,
MMA/GEMM transforms, and stride usage. Record required tensor shape, stride,
physical storage layout, and boundary view layout.
- Derive alpha and scale semantics from kernel consumption points. Trace where
alpha, norm constants, block scales, activation scales, and weight scales are
loaded and multiplied before deciding how upper layers compute or pack them.
Treat weight bytes, block scales (scale factors, SF), and global alpha/norm
constants as separate contracts.
- Design the upper-layer adapter from the kernel ABI upward. Map each kernel
input/output to an adapter responsibility: storage tensor, view/transposition,
dtype reinterpretation, padding, scale packing, workspace ownership,
synchronization, and output reduction. Validate parity with upstream
invocation dumps, not just final output.

### Quantization And Weights

Role:
Expand Down
38 changes: 38 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,25 @@ common-files: &common_files |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
tensorrt_llm/_torch/cute_dsl_utils.py |
tensorrt_llm/_torch/debug/__init__.py |
tensorrt_llm/_torch/debug/debug_hook.py |
Expand Down Expand Up @@ -1658,6 +1677,25 @@ legacy-files: &legacy_files |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
tensorrt_llm/_torch/cute_dsl_utils.py |
tensorrt_llm/_torch/debug/__init__.py |
tensorrt_llm/_torch/debug/debug_hook.py |
Expand Down
19 changes: 19 additions & 0 deletions legacy-files.txt
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,25 @@ tensorrt_llm/_torch/cute_dsl_kernels/blackwell/__init__.py
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py
tensorrt_llm/_torch/cute_dsl_utils.py
tensorrt_llm/_torch/debug/__init__.py
tensorrt_llm/_torch/debug/debug_hook.py
Expand Down
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,25 @@ exclude = [
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py",
"tensorrt_llm/_torch/cute_dsl_utils.py",
"tensorrt_llm/_torch/debug/__init__.py",
"tensorrt_llm/_torch/debug/debug_hook.py",
Expand Down
19 changes: 19 additions & 0 deletions ruff-legacy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,25 @@ include = [
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py",
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py",
"tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py",
"tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py",
"tensorrt_llm/_torch/cute_dsl_utils.py",
"tensorrt_llm/_torch/debug/__init__.py",
"tensorrt_llm/_torch/debug/debug_hook.py",
Expand Down
3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,8 +1511,7 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor,
if dtype == torch.float4_e2m1fn_x2:
return (torch.rand(shapes, device=device) * 10 - 5).to(
torch.uint8).view(dtype)
else:
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)
return (torch.rand(shapes, device=device) * 10 - 5).to(dtype)

def _prepare_input_tensors(
self, profile: OptimizationProfile,
Expand Down
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def inplace_slice_copy(dest: torch.Tensor, src: torch.Tensor, dim1_start: int,
'cute_dsl_nvfp4_dense_gemm_swiglu_fp4out_blackwell',
]

# The MegaMoE NVFP4 op requires a strict superset of what
# IS_CUTLASS_DSL_AVAILABLE checks (cutlass.torch + cutlass._mlir +
# cute_nvgpu MMA atoms + the ported CuteDSL kernel package). The
# cute_dsl_megamoe_custom_op module sets ``IS_MEGAMOE_OP_AVAILABLE``
# from its own try/except probe, so importing the module is safe
# regardless of the result: on partial cutlass-dsl installs it just logs
# and leaves ``IS_MEGAMOE_OP_AVAILABLE = False`` so callers can fall back via the factory.
from .cute_dsl_megamoe_custom_op import IS_MEGAMOE_OP_AVAILABLE
if IS_MEGAMOE_OP_AVAILABLE:
from .cute_dsl_megamoe_custom_op import cute_dsl_megamoe_nvfp4_blackwell
__all__ += ['cute_dsl_megamoe_nvfp4_blackwell']

if IS_CUDA_TILE_AVAILABLE:
from .cuda_tile_custom_ops import (cuda_tile_rms_norm,
cuda_tile_rms_norm_fuse_residual_)
Expand Down
Loading
Loading