@@ -50,7 +50,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
5050 std::optional<double > const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
5151 btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t moeConfigIndex,
5252 torch::optional<torch::Tensor> const & topk_weights, torch::optional<torch::Tensor> const & topk_ids,
53- torch::optional<torch::Tensor> const & out_tensor)
53+ torch::optional<torch::Tensor> const & out_tensor, torch::optional<torch::Tensor> const & finalize_input_scale )
5454{
5555 TORCH_CHECK (tensorrt_llm::common::isSM100Family (), " Only SM100f is supported by MXFP4 block scale MOE" );
5656 TORCH_CHECK (tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64
@@ -173,6 +173,8 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
173173 = output1_scale_gate_scalar.has_value () ? output1_scale_gate_scalar.value ().data_ptr <float >() : nullptr ;
174174 args.output2_scales_scalar
175175 = output2_scale_scalar.has_value () ? output2_scale_scalar.value ().data_ptr <float >() : nullptr ;
176+ args.finalize_input_scale
177+ = finalize_input_scale.has_value () ? finalize_input_scale.value ().data_ptr <float >() : nullptr ;
176178 args.num_tokens = hidden_states.sizes ()[0 ];
177179 args.num_experts = num_experts;
178180 // Hidden dimension input of MoE block. It might be padded.
@@ -421,6 +423,19 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
421423 output2_scale_scalar->sizes ()[0 ] == local_num_experts, " output2_scales_scalar has incorrect dim 0." );
422424 }
423425
426+ if (finalize_input_scale.has_value ())
427+ {
428+ TORCH_CHECK (finalize_input_scale->scalar_type () == at::ScalarType::Float,
429+ " finalize_input_scale must be float, got %s." , c10::toString (finalize_input_scale->scalar_type ()));
430+ TORCH_CHECK (finalize_input_scale->dim () == 2 , " finalize_input_scale must be 2D." );
431+ TORCH_CHECK (finalize_input_scale->sizes ()[0 ] == num_experts, " finalize_input_scale has incorrect dim 0." );
432+ TORCH_CHECK (finalize_input_scale->sizes ()[1 ] == args.output_hidden_size .value_or (args.hidden_size ),
433+ " finalize_input_scale has incorrect dim 1." );
434+ TORCH_CHECK (finalize_input_scale->device () == hidden_states.device (),
435+ " finalize_input_scale must be on the input device." );
436+ TORCH_CHECK (finalize_input_scale->is_contiguous (), " finalize_input_scale must be contiguous." );
437+ }
438+
424439 // allocate or use provided output
425440 at::Tensor output;
426441 if (out_tensor.has_value ())
@@ -531,7 +546,8 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
531546 int64_t local_expert_offset, int64_t local_num_experts, std::optional<double > routed_scaling_factor,
532547 int64_t routing_method_type, std::vector<int64_t > moeConfigIndex,
533548 torch::optional<torch::Tensor> const & topk_weights, torch::optional<torch::Tensor> const & topk_ids,
534- torch::optional<torch::Tensor> const & output = torch::nullopt )
549+ torch::optional<torch::Tensor> const & output = torch::nullopt ,
550+ torch::optional<torch::Tensor> const & finalize_input_scale = torch::nullopt )
535551
536552 {
537553 // moeConfigIndex corresponds to pair (tileN, config)
@@ -556,7 +572,7 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
556572 gemm2_weights_scale, gemm2_bias, std::nullopt , std::nullopt , std::nullopt , num_experts, top_k, n_group,
557573 topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset,
558574 local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct , *mRunners [tileN], config,
559- topk_weights, topk_ids, output);
575+ topk_weights, topk_ids, output, finalize_input_scale );
560576 }
561577
562578private:
@@ -626,7 +642,8 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
626642 int64_t local_expert_offset, int64_t local_num_experts, std::optional<double > routed_scaling_factor,
627643 int64_t routing_method_type, std::vector<int64_t > tile_config_pair,
628644 torch::optional<torch::Tensor> const & topk_weights, torch::optional<torch::Tensor> const & topk_ids,
629- torch::optional<torch::Tensor> const & output)
645+ torch::optional<torch::Tensor> const & output,
646+ torch::optional<torch::Tensor> const & finalize_input_scale = torch::nullopt )
630647 {
631648 // tile_config_pair corresponds to pair (tileN, config)
632649 auto [tileN, config] = std::tie (tile_config_pair[0 ], tile_config_pair[1 ]);
@@ -650,7 +667,7 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
650667 gemm2_weights_scale, gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar,
651668 num_experts, top_k, n_group, topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size,
652669 local_expert_offset, local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct ,
653- *mRunners [tileN], config, topk_weights, topk_ids, output);
670+ *mRunners [tileN], config, topk_weights, topk_ids, output, finalize_input_scale );
654671 }
655672
656673 /* *
0 commit comments