[None][feat] DSv4 prep: compressor and mHC primitives#15379
Conversation
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
📝 WalkthroughWalkthroughAdds two new CUDA kernel libraries for DeepSeek-V4: ChangesDeepSeek-V4 mHC and KV Compressor Kernel Stack
Sequence Diagram(s)sequenceDiagram
participant PyLayer as Model Layer (Python)
participant mHC as mHC Module
participant MhcFusedHcRunner
participant mhcFusedHcLaunch as CUDA: mhcFusedHcLaunch
participant mhcBigFuseLaunch as CUDA: mhcBigFuseLaunch
participant Compressor as Compressor Module
participant compressorKernels as CUDA: Compressor Kernels
participant PagedKVCache as Paged KV Cache
PyLayer->>mHC: fused_hc(x_prev, residual_prev, post_mix_prev, comb_mix_prev, norm_weight)
mHC->>MhcFusedHcRunner: AutoTuner.choose_one() selects backend tactic
MhcFusedHcRunner->>mhcFusedHcLaunch: dispatch (FMA/TF32/all-in-one) with workspaces
mhcFusedHcLaunch->>mhcBigFuseLaunch: y_acc, r_acc -> split-K reduce + Sinkhorn + optional RMSNorm
mhcBigFuseLaunch-->>mHC: residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur
PyLayer->>Compressor: forward(x, metadata)
Compressor->>compressorKernels: prefillReductionLaunch (prefill tokens)
Compressor->>compressorKernels: pagedKvCompressLaunch (decode tokens)
compressorKernels->>PagedKVCache: update online softmax state, emit kv_comp
Compressor->>compressorKernels: postProcessScatterLaunch (RMSNorm+RoPE+Hadamard, cache_scale_type)
compressorKernels->>PagedKVCache: scatter normalized/quantized KV (FP8/MXFP4/BF16)
Compressor-->>PyLayer: (kv_out or quant_output, scale_output)
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 7
🧹 Nitpick comments (4)
tensorrt_llm/_torch/attention_backend/sparse/deepseek_v4/compressor.py (1)
56-59: 💤 Low valueUnknown string keys raise cryptic
KeyError.
resolve_kv_cache_dtypewill raise a bareKeyErrorif passed an unrecognized string. A more informative error message would help users diagnose configuration issues.Suggested improvement
def resolve_kv_cache_dtype(kv_cache_dtype: Union[str, KVCacheDtype]) -> KVCacheDtype: if isinstance(kv_cache_dtype, str): - return _KV_CACHE_DTYPE_MAP[kv_cache_dtype] + if kv_cache_dtype not in _KV_CACHE_DTYPE_MAP: + raise ValueError( + f"Unknown kv_cache_dtype: '{kv_cache_dtype}'. " + f"Valid options: {list(_KV_CACHE_DTYPE_MAP.keys())}" + ) + return _KV_CACHE_DTYPE_MAP[kv_cache_dtype] return kv_cache_dtype🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/attention_backend/sparse/deepseek_v4/compressor.py` around lines 56 - 59, The resolve_kv_cache_dtype function currently raises a bare KeyError when given an unrecognized string key, providing no helpful guidance to users. Add error handling around the _KV_CACHE_DTYPE_MAP dictionary lookup in resolve_kv_cache_dtype to catch the KeyError and re-raise a more informative error message that includes the invalid key that was provided and optionally lists the valid/supported options from the _KV_CACHE_DTYPE_MAP dictionary. This will help users diagnose configuration issues more easily.tensorrt_llm/_torch/modules/mhc/mhc_cuda.py (1)
150-150: ⚡ Quick winConsider raising a descriptive error instead of assertion.
The assertion guards against calling the function when DeepGEMM is unavailable, but assertions are removed in optimized Python (
-O). A clearRuntimeErrorwould be more robust.♻️ Suggested refactor
- assert dg_fn is not None, "DeepGEMM is not available" + if dg_fn is None: + raise RuntimeError( + "DeepGEMM is not available. Install deep_gemm or use FMA backend." + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/mhc/mhc_cuda.py` at line 150, The assertion at line 150 checking if dg_fn is not None can be removed when Python runs in optimized mode (-O flag), making the guard ineffective. Replace the assert statement with an explicit if condition that raises a RuntimeError instead, maintaining the descriptive error message "DeepGEMM is not available" so the check is always enforced regardless of Python optimization settings.tests/unittest/_torch/modules/test_mhc.py (1)
959-961: ⚡ Quick winAdd
strict=Trueto zip for safety.Python 3.10+ supports
zip(..., strict=True), which raisesValueErrorif the iterables have different lengths. This catches bugs where the returned tuple doesn't match the expected structure.♻️ Suggested fix
- for ge, ee, name in zip( - graph_out, eager_out, ["residual", "post_mix", "comb_mix", "layer_input"] - ): + for ge, ee, name in zip( + graph_out, eager_out, ["residual", "post_mix", "comb_mix", "layer_input"], strict=True + ):Apply the same change to line 976.
Also applies to: 976-978
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_mhc.py` around lines 959 - 961, The zip() function calls in the test file lack the strict=True parameter, which would help catch bugs where iterables have different lengths. In tests/unittest/_torch/modules/test_mhc.py at lines 959-961 (the zip call with graph_out, eager_out, and the name list), add strict=True as a parameter to the zip() function to enable length validation. Apply the identical change at lines 976-978 where another zip() call appears that needs the same strict=True parameter added.tensorrt_llm/_torch/modules/mhc/hyper_connection.py (1)
114-116: ⚡ Quick winConsider replacing assertions with explicit ValueError for dtype/shape checks.
Assertions are removed when Python runs with
-O(optimized mode), which can lead to silent failures in production. ExplicitValueErrororTypeErroris more robust for runtime validation in forward methods.♻️ Suggested refactor
- assert x.dtype == torch.bfloat16 - assert self.mult == x.shape[-2] - assert self.hidden_size == x.shape[-1] + if x.dtype != torch.bfloat16: + raise TypeError(f"pre_mapping requires bfloat16 input, got {x.dtype}") + if self.mult != x.shape[-2]: + raise ValueError(f"Expected shape[−2]={self.mult}, got {x.shape[-2]}") + if self.hidden_size != x.shape[-1]: + raise ValueError(f"Expected shape[−1]={self.hidden_size}, got {x.shape[-1]}")Apply the same pattern to lines 188-189 in
fused_hc.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/mhc/hyper_connection.py` around lines 114 - 116, Replace the assertions that validate dtype and shape in the forward method (checking x.dtype against torch.bfloat16, self.mult against x.shape[-2], and self.hidden_size against x.shape[-1]) with explicit ValueError or TypeError raises instead, since assertions are stripped when Python runs with -O flag and will silently fail in production. Apply the same pattern to the assertions at lines 188-189 in the fused_hc method to ensure consistent runtime validation across both locations.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/compressorKernels/compressorKernels.cu`:
- Around line 162-176: The packE2M1x2 function silently returns 0 on pre-SM100
architectures instead of failing, which causes data corruption when
kMXFP4Blockwise mode is enabled on unsupported GPUs. Add a runtime architecture
validation in postProcessScatterLaunch (around line 1692) that rejects
kMXFP4Blockwise mode if the GPU compute capability is less than SM100 (major
version less than 10). Use TLLM_CHECK_WITH_INFO to perform this validation and
report the actual GPU architecture in the error message so users know why their
request was rejected. This prevents silent data corruption by catching the
incompatibility at runtime rather than letting packE2M1x2 silently degrade.
In `@cpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cu`:
- Around line 530-542: The FMA kernel launchers use tile_n and tile_m parameters
in modulo and division operations without first validating that these values are
positive, which can cause host crashes or undefined behavior if they are zero or
negative. Add TLLM_CHECK_WITH_INFO validation calls to ensure tile_n is greater
than zero before the modulo operation FHC_SHAPE_N % tile_n (around line 534),
and ensure tile_m is greater than zero before the division operation m_batches =
... / tile_m (around line 804). Apply the same validation checks to the second
FMA launcher function in the consolidated_sites range (lines 795-820) to prevent
the same crash paths in both implementations.
In `@cpp/tensorrt_llm/kernels/mhcKernels/mhcKernels.cu`:
- Around line 867-875: The mhcHcHeadApplyLaunch wrapper function accepts a
runtime parameter mult without validation, but the kernel mhcHcHeadApplyKernel
only allocates s_pre[8] in shared memory. If mult exceeds 8, it causes
out-of-bounds memory writes and reads. Add a validation check in
mhcHcHeadApplyLaunch before the kernel launch statement to ensure mult does not
exceed 8, either through an assertion, error return, or exception to prevent
invalid kernel behavior.
- Around line 847-863: The default case in the switch statement for tileN in
mhcGemmSqrsumFmaLaunch incorrectly maps all unrecognized tileN values to the
TN=24 kernel, which causes out-of-bounds writes when tileN is unsupported (e.g.,
tileN=16). Add a validation check using TLLM_CHECK_WITH_INFO after the existing
divisibility check on line 849 to explicitly reject unsupported tileN values,
ensuring only the valid cases (1, 2, 3, 4, 6, 8, 12, and 24) are allowed to
proceed to the switch statement.
In `@cpp/tensorrt_llm/thop/compressorOp.cpp`:
- Around line 78-118: The compressorPostProcessScatterOp function is missing
contiguity validation checks for several input tensors that the kernel expects
to have contiguous memory layout. Add TORCH_CHECK statements after the existing
position_ids contiguity check (around line 100) to validate that kv_comp,
rms_weight, block_offsets, and kv_cache are all contiguous, using the same
pattern as the existing cos_sin_table and position_ids checks with the
is_contiguous() method and appropriate error messages identifying each tensor.
In `@cpp/tensorrt_llm/thop/mhcOp.cpp`:
- Around line 94-100: The norm_weight tensor validation in the mhc_fused_hc
function is missing device checks, which could allow CPU or different-device
tensors to be passed to the CUDA kernel, causing illegal memory access. Add
TORCH_CHECK validations to ensure that norm_weight is a CUDA tensor and is on
the same device as the other tensors being used in the operation. Insert these
device checks in the existing norm_weight validation block (after checking
dtype, contiguity, and numel) to verify the tensor is on a CUDA device before
dereferencing its pointer.
- Around line 103-155: The code currently handles backend values 3, 2, and 1
with explicit if conditions, but any other value silently falls through to the
final mhcFusedHcLaunch call (backend 0 fallback), which masks bugs like typos or
autotuner errors. Add a guard after the if conditions to reject unknown backend
values (anything other than 0, 1, 2, or 3) by throwing an error or asserting,
before reaching the mhcFusedHcLaunch call. This ensures invalid backend values
fail fast instead of silently executing the wrong kernel.
---
Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/deepseek_v4/compressor.py`:
- Around line 56-59: The resolve_kv_cache_dtype function currently raises a bare
KeyError when given an unrecognized string key, providing no helpful guidance to
users. Add error handling around the _KV_CACHE_DTYPE_MAP dictionary lookup in
resolve_kv_cache_dtype to catch the KeyError and re-raise a more informative
error message that includes the invalid key that was provided and optionally
lists the valid/supported options from the _KV_CACHE_DTYPE_MAP dictionary. This
will help users diagnose configuration issues more easily.
In `@tensorrt_llm/_torch/modules/mhc/hyper_connection.py`:
- Around line 114-116: Replace the assertions that validate dtype and shape in
the forward method (checking x.dtype against torch.bfloat16, self.mult against
x.shape[-2], and self.hidden_size against x.shape[-1]) with explicit ValueError
or TypeError raises instead, since assertions are stripped when Python runs with
-O flag and will silently fail in production. Apply the same pattern to the
assertions at lines 188-189 in the fused_hc method to ensure consistent runtime
validation across both locations.
In `@tensorrt_llm/_torch/modules/mhc/mhc_cuda.py`:
- Line 150: The assertion at line 150 checking if dg_fn is not None can be
removed when Python runs in optimized mode (-O flag), making the guard
ineffective. Replace the assert statement with an explicit if condition that
raises a RuntimeError instead, maintaining the descriptive error message
"DeepGEMM is not available" so the check is always enforced regardless of Python
optimization settings.
In `@tests/unittest/_torch/modules/test_mhc.py`:
- Around line 959-961: The zip() function calls in the test file lack the
strict=True parameter, which would help catch bugs where iterables have
different lengths. In tests/unittest/_torch/modules/test_mhc.py at lines 959-961
(the zip call with graph_out, eager_out, and the name list), add strict=True as
a parameter to the zip() function to enable length validation. Apply the
identical change at lines 976-978 where another zip() call appears that needs
the same strict=True parameter added.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: e0251fc7-d3e0-4b27-b4fc-d3a7189d27f1
📒 Files selected for processing (23)
cpp/tensorrt_llm/CMakeLists.txtcpp/tensorrt_llm/kernels/CMakeLists.txtcpp/tensorrt_llm/kernels/compressorKernels/CMakeLists.txtcpp/tensorrt_llm/kernels/compressorKernels/compressorKernels.cucpp/tensorrt_llm/kernels/compressorKernels/compressorKernels.hcpp/tensorrt_llm/kernels/mhcKernels/CMakeLists.txtcpp/tensorrt_llm/kernels/mhcKernels/fused_tf32_pmap_gemm.cuhcpp/tensorrt_llm/kernels/mhcKernels/mhcFusedHcKernel.cucpp/tensorrt_llm/kernels/mhcKernels/mhcKernels.cucpp/tensorrt_llm/kernels/mhcKernels/mhcKernels.hcpp/tensorrt_llm/kernels/mhcKernels/mhc_fused_fma.cuhcpp/tensorrt_llm/thop/CMakeLists.txtcpp/tensorrt_llm/thop/compressorOp.cppcpp/tensorrt_llm/thop/mhcOp.cpptensorrt_llm/_torch/attention_backend/sparse/deepseek_v4/__init__.pytensorrt_llm/_torch/attention_backend/sparse/deepseek_v4/compressor.pytensorrt_llm/_torch/modules/mhc/__init__.pytensorrt_llm/_torch/modules/mhc/hyper_connection.pytensorrt_llm/_torch/modules/mhc/mhc_cuda.pytests/unittest/_torch/attention/sparse/deepseek_v4/__init__.pytests/unittest/_torch/attention/sparse/deepseek_v4/test_compressor_kernel.pytests/unittest/_torch/attention/sparse/deepseek_v4/test_compressor_tf32.pytests/unittest/_torch/modules/test_mhc.py
|
/bot run --disable-fail-fast |
|
PR_Github #54321 [ run ] triggered by Bot. Commit: |
|
PR_Github #54321 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #54397 [ run ] triggered by Bot. Commit: |
09f114d to
8b66876
Compare
|
/bot kill |
|
/bot run --disable-fail-fast |
|
PR_Github #54411 [ run ] triggered by Bot. Commit: |
|
PR_Github #54412 [ kill ] triggered by Bot. Commit: |
|
PR_Github #54411 [ run ] completed with state |
|
PR_Github #54397 [ run ] completed with state |
|
PR_Github #54412 [ kill ] completed with state |
Signed-off-by: Mingyang Hao <200044211+mingyangHao@users.noreply.github.com>
Signed-off-by: Mingyang Hao <200044211+mingyangHao@users.noreply.github.com>
c75b5fc to
6e4e004
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #54516 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #54601 [ run ] triggered by Bot. Commit: |
|
PR_Github #54601 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #54696 [ run ] triggered by Bot. Commit: |
|
PR_Github #54696 [ run ] completed with state
|
|
/bot run |
|
PR_Github #54779 [ run ] triggered by Bot. Commit: |
|
PR_Github #54779 [ run ] completed with state
|
|
/bot run |
|
PR_Github #54811 [ run ] triggered by Bot. Commit: |
|
PR_Github #54811 [ run ] completed with state |
|
|
||
| TORCH_LIBRARY_IMPL(trtllm, CUDA, m) | ||
| { | ||
| m.impl("mhc_big_fuse", &mhcBigFuseOp); |
There was a problem hiding this comment.
Should we add register_fake for all ops to tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py?
There was a problem hiding this comment.
All the newly added custom ops are in-place, so there's no need to add them to register_fake. I just updated the inplace_map in tensorrt_llm/_torch/compilation/utils.py.
|
|
||
| TORCH_LIBRARY_IMPL(trtllm, CUDA, m) | ||
| { | ||
| m.impl("compressor_paged_kv_compress", &compressorPagedKvCompressOp); |
There was a problem hiding this comment.
Should we add register_fake for all ops to tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py?
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com>
|
/bot run |
|
PR_Github #55244 [ run ] triggered by Bot. Commit: |
|
PR_Github #55244 [ run ] completed with state |
Description
This is PR-2 from the DSv4 umbrella split. It lands standalone compressor and mHC primitives from #14751 without pulling in the DSv4 sparse cache manager, sparse MLA backend, MoE routing, or model/tokenizer changes.
Included:
deepseek_v4/__init__.pyso standalone compressor import does not import the full DSv4 backendIntentionally excluded:
test_compressor_module.pyandcache_manager.py(PR-6)Verification
Base/source:
aea4ae426b619bbc1e8411e8028de7fc77747664.github/main20b606838773ed96194a20f37b54441f4281e4d0.Build/install:
python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt --benchmarks --use_ccache --cuda_architectures "90-real;100-real" --configure_cmakebuild/tensorrt_llm-1.3.0rc18-cp312-cp312-linux_x86_64.whlmhcOp.cppandcompressorOp.cppcompiled intoth_common.python -m pip install --force-reinstall --no-deps build/tensorrt_llm-1.3.0rc18-cp312-cp312-linux_x86_64.whlninja -t inputs returned no results for wheel targets, but wheel build exited 0.Import/custom-op smoke:
from tensorrt_llm.bindings.internal import thopsucceeded.deepseek_v4.compressorwas imported from this worktree path.cache_managerand backenddeepseek_v4.pymodules were not loaded by compressor import.torch.ops.trtllmexposed:compressor_prefill_reductioncompressor_paged_kv_compresscompressor_postprocess_scattermhc_big_fusemhc_gemm_sqrsum_fmamhc_post_mappingmhc_fused_hcmhc_hc_head_applyTests:
CUDA_VISIBLE_DEVICES=0.timeout 1200 python -m pytest -q --tb=short -ra tests/unittest/_torch/attention/sparse/deepseek_v4/test_compressor_kernel.py: 63 passed, 22 skipped, 3 warnings in 4.41stimeout 1200 python -m pytest -q --tb=short -ra tests/unittest/_torch/attention/sparse/deepseek_v4/test_compressor_tf32.py: 4 passed, 2 warnings in 2.53stimeout 1800 python -m pytest -q --tb=short -ra tests/unittest/_torch/modules/test_mhc.py: 50 passed, 3 warnings in 4.16sScope/pre-commit:
pre-commit run --files $(git diff --name-only HEAD)passed before commit; commit hooks passed.git diff --name-only HEAD | rg 'cache_manager.py|attention_backend/sparse/deepseek_v4/deepseek_v4.py|IndexerTopK|indexerTopK|RoutingKernelTopK|modeling_deepseekv4|tokenizer/deepseek_v4|fused_moe|moeGate|deepseekV4QNorm|fp8Quantize.cpp|mlaRopeInplaceOp|triton_fused_inv_rope'Summary by CodeRabbit
New Features
Tests