@@ -69,10 +69,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
6969 // match logits exactly; routing_map's trailing dim depends on format).
7070 // No caller-side .view() needed.
7171 at::Tensor probs = at::empty (sizes, at::dtype (logits.scalar_type ()).device (at::kCUDA ));
72- at::Tensor routing_map = allocate_routing_map (sizes.slice (0 , sizes.size () - 1 ), num_experts,
73- routing_map_format);
74- at::Tensor intermediate_output =
75- at::empty (sizes, at::dtype (at::kFloat ).device (at::kCUDA ));
72+ at::Tensor routing_map =
73+ allocate_routing_map (sizes.slice (0 , sizes.size () - 1 ), num_experts, routing_map_format);
74+ at::Tensor intermediate_output = at::empty (sizes, at::dtype (at::kFloat ).device (at::kCUDA ));
7675
7776 // Wrap with explicit 2D shape for the kernel — the common-layer NVTE_CHECKs
7877 // expect {num_tokens, num_experts} (or {num_tokens, ceil(num_experts/8)} for
@@ -90,10 +89,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
9089
9190 auto logits_cu = makeTransformerEngineTensor (logits.data_ptr (), shape_2d, logits_dtype);
9291 auto probs_cu = makeTransformerEngineTensor (probs.data_ptr (), shape_2d, logits_dtype);
93- auto routing_map_cu = makeTransformerEngineTensor (
94- routing_map.data_ptr (), routing_map_shape_2d, routing_map_dtype);
95- auto intermediate_output_cu = makeTransformerEngineTensor (
96- intermediate_output.data_ptr (), shape_2d, DType::kFloat32 );
92+ auto routing_map_cu =
93+ makeTransformerEngineTensor ( routing_map.data_ptr (), routing_map_shape_2d, routing_map_dtype);
94+ auto intermediate_output_cu =
95+ makeTransformerEngineTensor ( intermediate_output.data_ptr (), shape_2d, DType::kFloat32 );
9796 auto expert_bias_cu = TensorWrapper (); // empty expert_bias_cu tensor
9897 if (expert_bias.has_value ()) {
9998 expert_bias_cu = makeTransformerEngineTensor (expert_bias.value ());
@@ -111,8 +110,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
111110
112111void fused_topk_with_score_function_bwd (at::Tensor routing_map, at::Tensor intermediate_output,
113112 at::Tensor grad_probs, at::Tensor grad_logits, int topk,
114- bool use_pre_softmax,
115- std::optional<float > scaling_factor,
113+ bool use_pre_softmax, std::optional<float > scaling_factor,
116114 std::string score_function, int routing_map_format) {
117115 // grad_probs / grad_logits are N-D matching the caller's logits shape; the
118116 // kernel sees a 2D {num_tokens, num_experts} view.
@@ -137,8 +135,8 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter
137135 auto grad_dtype = GetTransformerEngineDType (grad_probs.scalar_type ());
138136 auto routing_map_dtype = GetTransformerEngineDType (routing_map.scalar_type ());
139137
140- auto routing_map_cu = makeTransformerEngineTensor (routing_map. data_ptr (),
141- routing_map_shape_2d, routing_map_dtype);
138+ auto routing_map_cu =
139+ makeTransformerEngineTensor (routing_map. data_ptr (), routing_map_shape_2d, routing_map_dtype);
142140 auto intermediate_output_cu =
143141 makeTransformerEngineTensor (intermediate_output.data_ptr (), shape_2d, DType::kFloat32 );
144142 auto grad_probs_cu = makeTransformerEngineTensor (grad_probs.data_ptr (), shape_2d, grad_dtype);
@@ -169,10 +167,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
169167
170168 // N-D allocations matching logits shape (except routing_map trailing dim).
171169 at::Tensor scores = at::empty (sizes, at::dtype (at::kFloat ).device (at::kCUDA ));
172- at::Tensor routing_map = allocate_routing_map (sizes.slice (0 , sizes.size () - 1 ), num_experts,
173- routing_map_format);
174- at::Tensor intermediate_output =
175- at::empty (sizes, at::dtype (at::kFloat ).device (at::kCUDA ));
170+ at::Tensor routing_map =
171+ allocate_routing_map (sizes.slice (0 , sizes.size () - 1 ), num_experts, routing_map_format);
172+ at::Tensor intermediate_output = at::empty (sizes, at::dtype (at::kFloat ).device (at::kCUDA ));
176173
177174 const std::vector<size_t > shape_2d = {static_cast <size_t >(num_tokens),
178175 static_cast <size_t >(num_experts)};
@@ -186,10 +183,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
186183
187184 auto logits_cu = makeTransformerEngineTensor (logits.data_ptr (), shape_2d, logits_dtype);
188185 auto scores_cu = makeTransformerEngineTensor (scores.data_ptr (), shape_2d, DType::kFloat32 );
189- auto routing_map_cu = makeTransformerEngineTensor (
190- routing_map.data_ptr (), routing_map_shape_2d, routing_map_dtype);
191- auto intermediate_output_cu = makeTransformerEngineTensor (
192- intermediate_output.data_ptr (), shape_2d, DType::kFloat32 );
186+ auto routing_map_cu =
187+ makeTransformerEngineTensor ( routing_map.data_ptr (), routing_map_shape_2d, routing_map_dtype);
188+ auto intermediate_output_cu =
189+ makeTransformerEngineTensor ( intermediate_output.data_ptr (), shape_2d, DType::kFloat32 );
193190
194191 nvte_fused_score_for_moe_aux_loss_forward_v2 (
195192 logits_cu.data (), static_cast <int >(num_tokens), static_cast <int >(num_experts), topk,
0 commit comments