Skip to content

Commit 94ef4fa

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5572c97 commit 94ef4fa

3 files changed

Lines changed: 19 additions & 24 deletions

File tree

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
4747
int routing_map_format = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP));
4848

4949
void fused_score_for_moe_aux_loss_bwd(at::Tensor intermediate_output, at::Tensor grad_scores,
50-
at::Tensor grad_logits, int topk,
51-
std::string score_function);
50+
at::Tensor grad_logits, int topk, std::string score_function);
5251

5352
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
5453
at::Tensor tokens_per_expert,

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,7 @@ void init_router_bindings(pybind11::module &m) {
156156
"Fused aux loss with score function fwd");
157157
m.def("fused_score_for_moe_aux_loss_bwd", &fused_score_for_moe_aux_loss_bwd,
158158
py::arg("intermediate_output"), py::arg("grad_scores"), py::arg("grad_logits"),
159-
py::arg("topk"), py::arg("score_function"),
160-
"Fused aux loss with score function bwd");
159+
py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function bwd");
161160
m.def("fused_moe_aux_loss_fwd", &fused_moe_aux_loss_fwd, py::arg("probs"),
162161
py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"),
163162
py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), py::arg("coeff"),

transformer_engine/pytorch/csrc/extensions/router.cpp

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -69,10 +69,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
6969
// match logits exactly; routing_map's trailing dim depends on format).
7070
// No caller-side .view() needed.
7171
at::Tensor probs = at::empty(sizes, at::dtype(logits.scalar_type()).device(at::kCUDA));
72-
at::Tensor routing_map = allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts,
73-
routing_map_format);
74-
at::Tensor intermediate_output =
75-
at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA));
72+
at::Tensor routing_map =
73+
allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts, routing_map_format);
74+
at::Tensor intermediate_output = at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA));
7675

7776
// Wrap with explicit 2D shape for the kernel — the common-layer NVTE_CHECKs
7877
// expect {num_tokens, num_experts} (or {num_tokens, ceil(num_experts/8)} for
@@ -90,10 +89,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
9089

9190
auto logits_cu = makeTransformerEngineTensor(logits.data_ptr(), shape_2d, logits_dtype);
9291
auto probs_cu = makeTransformerEngineTensor(probs.data_ptr(), shape_2d, logits_dtype);
93-
auto routing_map_cu = makeTransformerEngineTensor(
94-
routing_map.data_ptr(), routing_map_shape_2d, routing_map_dtype);
95-
auto intermediate_output_cu = makeTransformerEngineTensor(
96-
intermediate_output.data_ptr(), shape_2d, DType::kFloat32);
92+
auto routing_map_cu =
93+
makeTransformerEngineTensor(routing_map.data_ptr(), routing_map_shape_2d, routing_map_dtype);
94+
auto intermediate_output_cu =
95+
makeTransformerEngineTensor(intermediate_output.data_ptr(), shape_2d, DType::kFloat32);
9796
auto expert_bias_cu = TensorWrapper(); // empty expert_bias_cu tensor
9897
if (expert_bias.has_value()) {
9998
expert_bias_cu = makeTransformerEngineTensor(expert_bias.value());
@@ -111,8 +110,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fw
111110

112111
void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor intermediate_output,
113112
at::Tensor grad_probs, at::Tensor grad_logits, int topk,
114-
bool use_pre_softmax,
115-
std::optional<float> scaling_factor,
113+
bool use_pre_softmax, std::optional<float> scaling_factor,
116114
std::string score_function, int routing_map_format) {
117115
// grad_probs / grad_logits are N-D matching the caller's logits shape; the
118116
// kernel sees a 2D {num_tokens, num_experts} view.
@@ -137,8 +135,8 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter
137135
auto grad_dtype = GetTransformerEngineDType(grad_probs.scalar_type());
138136
auto routing_map_dtype = GetTransformerEngineDType(routing_map.scalar_type());
139137

140-
auto routing_map_cu = makeTransformerEngineTensor(routing_map.data_ptr(),
141-
routing_map_shape_2d, routing_map_dtype);
138+
auto routing_map_cu =
139+
makeTransformerEngineTensor(routing_map.data_ptr(), routing_map_shape_2d, routing_map_dtype);
142140
auto intermediate_output_cu =
143141
makeTransformerEngineTensor(intermediate_output.data_ptr(), shape_2d, DType::kFloat32);
144142
auto grad_probs_cu = makeTransformerEngineTensor(grad_probs.data_ptr(), shape_2d, grad_dtype);
@@ -169,10 +167,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
169167

170168
// N-D allocations matching logits shape (except routing_map trailing dim).
171169
at::Tensor scores = at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA));
172-
at::Tensor routing_map = allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts,
173-
routing_map_format);
174-
at::Tensor intermediate_output =
175-
at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA));
170+
at::Tensor routing_map =
171+
allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts, routing_map_format);
172+
at::Tensor intermediate_output = at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA));
176173

177174
const std::vector<size_t> shape_2d = {static_cast<size_t>(num_tokens),
178175
static_cast<size_t>(num_experts)};
@@ -186,10 +183,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
186183

187184
auto logits_cu = makeTransformerEngineTensor(logits.data_ptr(), shape_2d, logits_dtype);
188185
auto scores_cu = makeTransformerEngineTensor(scores.data_ptr(), shape_2d, DType::kFloat32);
189-
auto routing_map_cu = makeTransformerEngineTensor(
190-
routing_map.data_ptr(), routing_map_shape_2d, routing_map_dtype);
191-
auto intermediate_output_cu = makeTransformerEngineTensor(
192-
intermediate_output.data_ptr(), shape_2d, DType::kFloat32);
186+
auto routing_map_cu =
187+
makeTransformerEngineTensor(routing_map.data_ptr(), routing_map_shape_2d, routing_map_dtype);
188+
auto intermediate_output_cu =
189+
makeTransformerEngineTensor(intermediate_output.data_ptr(), shape_2d, DType::kFloat32);
193190

194191
nvte_fused_score_for_moe_aux_loss_forward_v2(
195192
logits_cu.data(), static_cast<int>(num_tokens), static_cast<int>(num_experts), topk,

0 commit comments

Comments (0)