Skip to content

Commit c5f2ee1

Browse files
committed
Support MoE per-expert finalize input scales
1 parent db57b9c commit c5f2ee1

5 files changed

Lines changed: 94 additions & 12 deletions

File tree

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -700,6 +700,19 @@ namespace tg = batchedGemm::trtllm::gen;
700700

701701
////////////////////////////////////////////////////////////////////////////////////////////////////
702702

703+
// Returns the expert index associated with the expanded index `expandedIdx`.
//
// Lookup order:
//   1. If params.topKIds is non-null, the value params.topKIds[expandedIdx] is returned.
//   2. Otherwise the `.idx` member of params.packedExpertIndexes[expandedIdx] is returned,
//      converted to int32_t.
//
// No bounds check is performed on expandedIdx, and packedExpertIndexes is dereferenced without
// a null check, so at least one of the two index sources must be non-null when this is called.
// NOTE(review): expandedIdx is presumably tokenIdx * topK + k - confirm against the call sites.
template <typename KernelParams>
704+
__device__ __forceinline__ int32_t getExpertIdx(KernelParams const& params, int32_t expandedIdx)
705+
{
706+
if (params.topKIds != nullptr)
707+
{
708+
return params.topKIds[expandedIdx];
709+
}
710+
711+
// topKIds is absent: fall back to the index stored in the packed score/index entries.
return static_cast<int32_t>(params.packedExpertIndexes[expandedIdx].idx);
712+
}
713+
714+
////////////////////////////////////////////////////////////////////////////////////////////////////
715+
703716
template <typename KernelParams>
704717
__global__ void finalizeKernel(KernelParams params)
705718
{
@@ -735,15 +748,22 @@ __global__ void finalizeKernel(KernelParams params)
735748
continue;
736749
}
737750

751+
float data_k = float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
752+
738753
if (params.expertWeightsPtr != nullptr)
739754
{
740755
TypeExpW const scale = params.expertWeightsPtr[expandedIdx];
741-
data += float{scale} * float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
756+
data_k *= float{scale};
742757
}
743-
else
758+
759+
if (params.inScalePtr != nullptr)
744760
{
745-
data += float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
761+
int const expertIdx = getExpertIdx(params, expandedIdx);
762+
int const inScaleIdx = expertIdx * params.hiddenDimPadded + hiddenIdx;
763+
data_k *= params.inScalePtr[inScaleIdx];
746764
}
765+
766+
data += data_k;
747767
}
748768

749769
params.outPtr[tokenIdx * params.hiddenDim + hiddenIdx] = static_cast<Type>(data);
@@ -823,6 +843,17 @@ __global__ void finalizeKernelVecLoad(KernelParams params)
823843
float4 input = vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
824844
InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
825845
ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
846+
if (params.inScalePtr != nullptr)
847+
{
848+
int const expertIdx = getExpertIdx(params, expandedIdx);
849+
#pragma unroll
850+
for (int idx = 0; idx < FINALIZE_ELEM_PER_THREAD; ++idx)
851+
{
852+
int const hiddenIdx = elemIndex * FINALIZE_ELEM_PER_THREAD + idx;
853+
int const inScaleIdx = expertIdx * params.hiddenDimPadded + hiddenIdx;
854+
expertResult[idx] *= params.inScalePtr[inScaleIdx];
855+
}
856+
}
826857

827858
threadOutput = threadOutput + scale * expertResult;
828859
}
@@ -873,9 +904,17 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
873904
int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
874905
float const blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1;
875906

907+
float inputScale = 1.0f;
908+
if (params.inScalePtr != nullptr)
909+
{
910+
int const expertIdx = getExpertIdx(params, expandedIdx);
911+
int const inScaleIdx = expertIdx * params.hiddenDimPadded + hiddenIdx;
912+
inputScale = params.inScalePtr[inScaleIdx];
913+
}
914+
876915
float const expertProb = (float) params.expertWeightsPtr[tokenIdx * params.topK + k];
877916

878-
float const scale = expertProb * blockScale;
917+
float const scale = inputScale * expertProb * blockScale;
879918
acc += scale * static_cast<float>(params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]);
880919
}
881920

@@ -909,6 +948,12 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
909948
////////////////////////////////////////////////////////////////////////////////////////////////////
910949
void run(Data const& data, void* stream)
911950
{
951+
if (data.inScalePtr != nullptr)
952+
{
953+
TLLM_CHECK_WITH_INFO(data.topKIds != nullptr || data.packedExpertIndexes != nullptr,
954+
"Finalize input scales require expert indexes.");
955+
}
956+
912957
if (data.mUseDeepSeekFp8)
913958
{
914959
int const numThreads = 128;

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818

19+
#include "RoutingKernel.h"
1920
#include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
2021
#include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h"
2122
#include <cuda.h>
@@ -547,9 +548,12 @@ struct Data
547548
void* outPtr;
548549
float* inDqSfsPtr = nullptr;
549550
float* outDqSfsPtr = nullptr;
551+
float const* inScalePtr = nullptr;
550552

551553
void* expertWeightsPtr;
552554
int32_t* expandedIdxToPermutedIdx;
555+
int32_t const* topKIds = nullptr;
556+
void const* packedExpertIndexes = nullptr;
553557

554558
int32_t numTokens;
555559
int32_t numExperts;
@@ -574,8 +578,11 @@ struct KernelParams
574578

575579
float* inDqSfsPtr = nullptr;
576580
float* outDqSfsPtr = nullptr;
581+
float const* inScalePtr = nullptr;
577582

578583
int32_t* expandedIdxToPermutedIdx;
584+
int32_t const* topKIds = nullptr;
585+
routing::PackedScoreIdx<TypeExpW> const* packedExpertIndexes = nullptr;
579586

580587
int32_t hiddenDim;
581588
int32_t hiddenDimPadded;
@@ -594,8 +601,11 @@ struct KernelParams
594601
params.outPtr = (Type*) data.outPtr;
595602
params.inDqSfsPtr = data.inDqSfsPtr;
596603
params.outDqSfsPtr = data.outDqSfsPtr;
604+
params.inScalePtr = data.inScalePtr;
597605

598606
params.expandedIdxToPermutedIdx = data.expandedIdxToPermutedIdx;
607+
params.topKIds = data.topKIds;
608+
params.packedExpertIndexes = static_cast<routing::PackedScoreIdx<TypeExpW> const*>(data.packedExpertIndexes);
599609

600610
params.hiddenDim = data.hiddenDim;
601611
params.hiddenDimPadded = data.hiddenDimPadded;

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -510,6 +510,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
510510
finalizeData.outPtr = args.output;
511511
finalizeData.inDqSfsPtr = workspace.gemm2_output_scale;
512512
finalizeData.outDqSfsPtr = args.output_scale;
513+
finalizeData.inScalePtr = args.finalize_input_scale;
513514
if (args.mUseRoutingScalesOnInput)
514515
{
515516
finalizeData.expertWeightsPtr = nullptr;
@@ -519,6 +520,8 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace
519520
finalizeData.expertWeightsPtr = workspace.expert_weights;
520521
}
521522
finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
523+
finalizeData.topKIds = args.topk_ids;
524+
finalizeData.packedExpertIndexes = workspace.routing_expert_indexes;
522525
finalizeData.numTokens = args.num_tokens;
523526
finalizeData.numExperts = args.num_experts;
524527
finalizeData.topK = args.top_k;
@@ -633,6 +636,9 @@ void Runner::run(
633636
if (args.do_finalize)
634637
{
635638
// Run finalize
639+
TLLM_CHECK_WITH_INFO(args.finalize_input_scale == nullptr || args.topk_ids != nullptr
640+
|| workspace.routing_expert_indexes != nullptr,
641+
"Finalize input scale factors require expert indexes.");
636642
moe::dev::finalize::run(finalizeData, stream);
637643
sync_check_cuda_error(stream);
638644
}

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ struct MoERunnerArgs
311311
float* output1_scales_gate_scalar = nullptr;
312312
float* output2_scales_scalar = nullptr;
313313

314+
// Optional per-expert factors applied inside the finalize kernel.
315+
// input: [num_experts, hidden_size].
316+
float* finalize_input_scale = nullptr;
317+
314318
// Output:
315319
void* output = nullptr;
316320
float* output_scale = nullptr;

cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
5050
std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
5151
btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t moeConfigIndex,
5252
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids,
53-
torch::optional<torch::Tensor> const& out_tensor)
53+
torch::optional<torch::Tensor> const& out_tensor, torch::optional<torch::Tensor> const& finalize_input_scale)
5454
{
5555
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by MXFP4 block scale MOE");
5656
TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64
@@ -173,6 +173,8 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
173173
= output1_scale_gate_scalar.has_value() ? output1_scale_gate_scalar.value().data_ptr<float>() : nullptr;
174174
args.output2_scales_scalar
175175
= output2_scale_scalar.has_value() ? output2_scale_scalar.value().data_ptr<float>() : nullptr;
176+
args.finalize_input_scale
177+
= finalize_input_scale.has_value() ? finalize_input_scale.value().data_ptr<float>() : nullptr;
176178
args.num_tokens = hidden_states.sizes()[0];
177179
args.num_experts = num_experts;
178180
// Hidden dimension input of MoE block. It might be padded.
@@ -421,6 +423,19 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
421423
output2_scale_scalar->sizes()[0] == local_num_experts, "output2_scales_scalar has incorrect dim 0.");
422424
}
423425

426+
if (finalize_input_scale.has_value())
427+
{
428+
TORCH_CHECK(finalize_input_scale->scalar_type() == at::ScalarType::Float,
429+
"finalize_input_scale must be float, got %s.", c10::toString(finalize_input_scale->scalar_type()));
430+
TORCH_CHECK(finalize_input_scale->dim() == 2, "finalize_input_scale must be 2D.");
431+
TORCH_CHECK(finalize_input_scale->sizes()[0] == num_experts, "finalize_input_scale has incorrect dim 0.");
432+
TORCH_CHECK(finalize_input_scale->sizes()[1] == args.output_hidden_size.value_or(args.hidden_size),
433+
"finalize_input_scale has incorrect dim 1.");
434+
TORCH_CHECK(finalize_input_scale->device() == hidden_states.device(),
435+
"finalize_input_scale must be on the input device.");
436+
TORCH_CHECK(finalize_input_scale->is_contiguous(), "finalize_input_scale must be contiguous.");
437+
}
438+
424439
// allocate or use provided output
425440
at::Tensor output;
426441
if (out_tensor.has_value())
@@ -531,7 +546,8 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
531546
int64_t local_expert_offset, int64_t local_num_experts, std::optional<double> routed_scaling_factor,
532547
int64_t routing_method_type, std::vector<int64_t> moeConfigIndex,
533548
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids,
534-
torch::optional<torch::Tensor> const& output = torch::nullopt)
549+
torch::optional<torch::Tensor> const& output = torch::nullopt,
550+
torch::optional<torch::Tensor> const& finalize_input_scale = torch::nullopt)
535551

536552
{
537553
// moeConfigIndex corresponds to pair (tileN, config)
@@ -556,7 +572,7 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
556572
gemm2_weights_scale, gemm2_bias, std::nullopt, std::nullopt, std::nullopt, num_experts, top_k, n_group,
557573
topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset,
558574
local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config,
559-
topk_weights, topk_ids, output);
575+
topk_weights, topk_ids, output, finalize_input_scale);
560576
}
561577

562578
private:
@@ -626,7 +642,8 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
626642
int64_t local_expert_offset, int64_t local_num_experts, std::optional<double> routed_scaling_factor,
627643
int64_t routing_method_type, std::vector<int64_t> tile_config_pair,
628644
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids,
629-
torch::optional<torch::Tensor> const& output)
645+
torch::optional<torch::Tensor> const& output,
646+
torch::optional<torch::Tensor> const& finalize_input_scale = torch::nullopt)
630647
{
631648
// tile_config_pair corresponds to pair (tileN, config)
632649
auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]);
@@ -650,7 +667,7 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
650667
gemm2_weights_scale, gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar,
651668
num_experts, top_k, n_group, topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size,
652669
local_expert_offset, local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct,
653-
*mRunners[tileN], config, topk_weights, topk_ids, output);
670+
*mRunners[tileN], config, topk_weights, topk_ids, output, finalize_input_scale);
654671
}
655672

656673
/**

0 commit comments

Comments
 (0)