Skip to content

Commit 384f4a9

Browse files
authored
Merge branch 'main' into dev-bench-moe
2 parents a8c9840 + 0e74256 commit 384f4a9

217 files changed

Lines changed: 34429 additions & 4710 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.claude/skills/trtllm-moe-develop/SKILL.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,26 @@ Checklist:
268268
- Existing legacy `forward` methods can be read for compatibility context, but
269269
they are not the default pattern for new backend work.
270270

271+
### Imported Kernel ABI Checklist
272+
273+
When importing or wrapping an upstream kernel, derive the TRT-LLM adapter
274+
contract from the lowest-level kernel consumer. Comments, docs, design notes,
275+
and parameter names are useful hints, but they are not proof of the runtime ABI.
276+
277+
- Derive weight shape and layout from the kernel entrypoint, `make_layout`, TMA,
278+
MMA/GEMM transforms, and stride usage. Record required tensor shape, stride,
279+
physical storage layout, and boundary view layout.
280+
- Derive alpha and scale semantics from kernel consumption points. Trace where
281+
alpha, norm constants, block scales, activation scales, and weight scales are
282+
loaded and multiplied before deciding how upper layers compute or pack them.
283+
Treat weight bytes, block scales/SF, and global alpha/norm constants as
284+
separate contracts.
285+
- Design the upper-layer adapter from the kernel ABI upward. Map each kernel
286+
input/output to an adapter responsibility: storage tensor, view/transposition,
287+
dtype reinterpretation, padding, scale packing, workspace ownership,
288+
synchronization, and output reduction. Validate parity with upstream
289+
invocation dumps, not just final output.
290+
271291
### Quantization And Weights
272292

273293
Role:

.pre-commit-config.yaml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,25 @@ common-files: &common_files |
302302
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
303303
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
304304
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
305+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
306+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
307+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
308+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
309+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
310+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
311+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
312+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
313+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
314+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
315+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
316+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
317+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
318+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
319+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
320+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
321+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
322+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
323+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
305324
tensorrt_llm/_torch/cute_dsl_utils.py |
306325
tensorrt_llm/_torch/debug/__init__.py |
307326
tensorrt_llm/_torch/debug/debug_hook.py |
@@ -1658,6 +1677,25 @@ legacy-files: &legacy_files |
16581677
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py |
16591678
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py |
16601679
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py |
1680+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py |
1681+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py |
1682+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py |
1683+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py |
1684+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dynamic_mainloop.py |
1685+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue_refactor.py |
1686+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py |
1687+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py |
1688+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py |
1689+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py |
1690+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py |
1691+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py |
1692+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py |
1693+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py |
1694+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py |
1695+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py |
1696+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py |
1697+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/token_comm.py |
1698+
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/topk_reduce.py |
16611699
tensorrt_llm/_torch/cute_dsl_utils.py |
16621700
tensorrt_llm/_torch/debug/__init__.py |
16631701
tensorrt_llm/_torch/debug/debug_hook.py |

cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ class MaxRequestsScheduler : public BaseCapacityScheduler
8787

8888
/// @brief Schedule requests using the MAX_UTILIZATION policy
8989
/// @details Try reserving resources to advance requests by one step,
90-
/// may pause previously started requests.
90+
/// may pause previously started requests. When a
91+
/// ``crossKvCacheManager`` is supplied, requests in the
92+
/// ``ENCODER_INIT`` state may be admitted for encoder compute
93+
/// without consuming self- or cross-KV blocks; the later
94+
/// ``CONTEXT_INIT`` decoder admission owns cross-pool budgeting.
9195
class MaxUtilizationScheduler : public BaseCapacityScheduler
9296
{
9397
public:
@@ -96,8 +100,9 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
96100
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
97101

98102
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
99-
kv_cache_manager::BaseKVCacheManager& kvCacheManager, OptionalRef<BasePeftCacheManager const> peftCacheManager,
100-
RequestList const& activeRequests) const;
103+
kv_cache_manager::BaseKVCacheManager& kvCacheManager,
104+
OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager,
105+
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
101106

102107
private:
103108
SizeType32 mMaxNumRequests;
@@ -106,6 +111,10 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
106111
};
107112

108113
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
114+
/// @details When a ``crossKvCacheManager`` is supplied, requests in the
115+
/// ``ENCODER_INIT`` state may be admitted for encoder compute
116+
/// without consuming self- or cross-KV blocks. The later
117+
/// ``CONTEXT_INIT`` decoder admission owns cross-pool budgeting.
109118
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
110119
{
111120
public:
@@ -158,7 +167,11 @@ class CapacityScheduler : public Algorithm
158167
*
159168
* @param kvCacheManager Required in MaxUtilizationScheduler (as a ref) and in GuaranteedNoEvictScheduler and
160169
* StaticBatchScheduler (as a const ref).
161-
* @param crossKvCacheManager Optional used in GuaranteedNoEvictScheduler and StaticBatchScheduler.
170+
* @param crossKvCacheManager Optional cross-attention KV cache manager. Used by
171+
* MaxUtilizationScheduler (mutates: ``startScheduling`` / ``schedulingRemoveSequence``)
172+
* and GuaranteedNoEvictScheduler / StaticBatchScheduler (read-only). Required for
173+
* encoder-decoder admission. Encoder-init requests only require this pool
174+
* to be configured; decoder context admission budgets blocks from it.
162175
* @param peftCacheManager Optional; used in MaxUtilizationScheduler, GuaranteedNoEvictScheduler and
163176
* StaticBatchScheduler.
164177
* @param activeRequests
@@ -168,7 +181,7 @@ class CapacityScheduler : public Algorithm
168181
[[nodiscard]] std::tuple<RequestVector, RequestVector, RequestVector> operator()(RequestList const& activeRequests,
169182
OptionalRef<kv_cache_manager::BaseKVCacheManager> kvCacheManager = std::nullopt,
170183
OptionalRef<BasePeftCacheManager const> peftCacheManager = std::nullopt,
171-
OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager = std::nullopt) const;
184+
OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager = std::nullopt) const;
172185

173186
/// @brief Sets the reorder policy to use AgentTreePolicy with the given configuration.
174187
/// @param agentPercentage The ratio of agent requests to schedule (0.0-1.0, -1.0 for random).

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ namespace kvc = tensorrt_llm::executor::kv_cache;
2424

2525
#pragma once
2626

27+
// Forward declaration of the class that KVCacheTransferManager declares as a friend.
namespace tensorrt_llm::testing
28+
{
29+
class KVCacheTransferManagerTestAccess;
30+
} // namespace tensorrt_llm::testing
31+
2732
namespace tensorrt_llm::batch_manager::kv_cache_manager
2833
{
2934

@@ -76,10 +81,15 @@ class KVCacheTransferManager
7681
[[nodiscard]] KvCacheTransferStats getAndResetTransferStats();
7782

7883
private:
84+
friend class ::tensorrt_llm::testing::KVCacheTransferManagerTestAccess;
85+
7986
//! \brief Get pointer to pool specified by cache block.
8087
static tr::ITensor::SharedPtr computeBlockPointer(
8188
BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools, size_t poolIdx);
8289

90+
//! \brief Get pool-qualified index for pending transfer tracking.
91+
[[nodiscard]] static kernels::KVCacheIndex::UnderlyingType getPendingTransferIndex(BlockPtr const& block);
92+
8393
/*!
8494
* \brief The key method that copies the src block to the dst block.
8595
*
@@ -107,8 +117,8 @@ class KVCacheTransferManager
107117
runtime::BufferManager mOnboardManager;
108118
runtime::BufferManager mOffloadManager;
109119

110-
// Track reads and writes for blocks. Note that it is the memory pool index that
111-
// identifies the raw memory blocks involved in I/O, not the block Id.
120+
// Track reads and writes for blocks. Note that it is the pool-qualified index into the memory pool
121+
// (see getPendingTransferIndex), not the block Id, that identifies the raw memory blocks involved in I/O.
112122
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
113123
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
114124
// Reference to parent loopback agent

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -665,9 +665,9 @@ class GenericLlmRequest
665665
return mEncoderUniqueTokens;
666666
}
667667

668-
/// @brief Get length of encoder input (could be tokens or features length)
669-
/// @return An integer.
670-
[[nodiscard]] SizeType32 getEncoderInputLen() const
668+
/// @brief Get length of encoder input when present, without throwing for decoder-only requests.
669+
/// @return Encoder input length, or nullopt when this request has no encoder side.
670+
[[nodiscard]] std::optional<SizeType32> tryGetEncoderInputLen() const
671671
{
672672
if (mEncoderInputFeatures.has_value())
673673
{
@@ -678,19 +678,45 @@ class GenericLlmRequest
678678
return getEncoderTokens().value()->size();
679679
}
680680

681-
TLLM_THROW("GenericLlmRequest::getEncoderInputLen - Do not have encoder length!");
681+
return std::nullopt;
682682
}
683683

684-
/// @brief Get length of encoder output. Fall back to encoder input length if not present
684+
/// @brief Get length of encoder input (could be tokens or features length)
685685
/// @return The length as an integer; raises via TLLM_THROW when no encoder length is available.
686-
[[nodiscard]] SizeType32 getEncoderOutputLen() const
686+
[[nodiscard]] SizeType32 getEncoderInputLen() const
687+
{
688+
auto const encoderInputLen = tryGetEncoderInputLen();
689+
if (encoderInputLen.has_value())
690+
{
691+
return encoderInputLen.value();
692+
}
693+
694+
TLLM_THROW("GenericLlmRequest::getEncoderInputLen - Do not have encoder length!");
695+
}
696+
697+
/// @brief Get length of encoder output when present, without throwing for decoder-only requests.
698+
/// @return The explicit encoder output length when set; otherwise the encoder input length, or nullopt when that is absent too.
699+
[[nodiscard]] std::optional<SizeType32> tryGetEncoderOutputLen() const
687700
{
688701
if (mEncoderOutputLength.has_value())
689702
{
690703
return mEncoderOutputLength.value();
691704
}
692705

693-
return getEncoderInputLen();
706+
return tryGetEncoderInputLen();
707+
}
708+
709+
/// @brief Get length of encoder output, or throw (via TLLM_THROW) if the request has no encoder side.
710+
/// @return Explicit encoder output length, or encoder input length when the output length is not present.
711+
[[nodiscard]] SizeType32 getEncoderOutputLen() const
712+
{
713+
auto const encoderOutputLen = tryGetEncoderOutputLen();
714+
if (encoderOutputLen.has_value())
715+
{
716+
return encoderOutputLen.value();
717+
}
718+
719+
// Name this accessor (not getEncoderInputLen) so the error points at the call that actually failed.
TLLM_THROW("GenericLlmRequest::getEncoderOutputLen - Do not have encoder length!");
694720
}
695721

696722
[[nodiscard]] std::optional<std::shared_ptr<std::vector<SizeType32>>> getPositionIds() const

cpp/include/tensorrt_llm/common/optionalRef.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ class OptionalRef
7878
{
7979
}
8080

81+
// Implicit conversion from OptionalRef<non-const T> to OptionalRef<const T>.
// Enabled only when T is const-qualified (see the enable_if_t on the defaulted parameter U = T).
// When `other` tests false the result holds std::nullopt; otherwise it wraps a reference to `*other`.
82+
template <typename U = T, typename = std::enable_if_t<std::is_const_v<U>>>
83+
OptionalRef(OptionalRef<std::remove_const_t<T>> const& other)
84+
: opt(other ? std::optional<std::reference_wrapper<T>>(std::ref(*other)) : std::nullopt)
85+
{
86+
}
87+
8188
T* operator->() const
8289
{
8390
return opt ? &(opt->get()) : nullptr;

0 commit comments

Comments
 (0)