Skip to content

Commit 5572c97

Browse files
committed
[PyTorch] Address PR #3009 review: remove .view() calls, int routing_map_format
Apply the four CPU-overhead fixes the reviewer asked for and the CLAUDE.md "CPU overhead in PyTorch wrappers" section codifies: 1. _validate_routing_map_format returns plain int (not enum); the autograd Function + tex.* bindings only see ints. Validates via precomputed frozenset and a single dict.get with canonical lowercase keys (no .lower()/.upper()). 2. Type annotations on Function.forward use int (not the string forward-ref 'RoutingMapFormat'). 3. Removed every .view() from FusedTopkScoreFunction.{forward,backward} and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++ extension now accepts N-D logits/grad_probs, computes num_tokens from the product of leading dims, num_experts from the last dim, allocates outputs at the user-facing N-D shape, and wraps tensors with an explicit 2D shape via makeTransformerEngineTensor only for the kernel call. Asserts is_contiguous() on inputs. 4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D) instead of allocate-2D-then-view. PyTorch-extension boundary takes 'int routing_map_format' and casts to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2) keeps the enum. Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent af30717 commit 5572c97

4 files changed

Lines changed: 217 additions & 159 deletions

File tree

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,29 @@ namespace transformer_engine::pytorch {
2626
* Router fusion
2727
**************************************************************************************************/
2828

29+
// PyTorch-extension boundary uses int for routing_map_format (not the enum) to
30+
// keep the Python hot path free of pybind11 enum construction. The int is
31+
// cast to NVTERoutingMapFormat once inside each function. See CLAUDE.md
32+
// "CPU overhead in PyTorch wrappers".
2933
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
3034
at::Tensor logits, int topk, bool use_pre_softmax, std::optional<int> num_groups,
3135
std::optional<int> group_topk, std::optional<float> scaling_factor, std::string score_function,
3236
std::optional<at::Tensor> expert_bias,
33-
NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP);
37+
int routing_map_format = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP));
3438

3539
void fused_topk_with_score_function_bwd(
36-
int num_tokens, int num_experts, at::Tensor routing_map, at::Tensor intermediate_output,
37-
at::Tensor grad_probs, at::Tensor grad_logits, int topk, bool use_pre_softmax,
38-
std::optional<float> scaling_factor, std::string score_function,
39-
NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP);
40+
at::Tensor routing_map, at::Tensor intermediate_output, at::Tensor grad_probs,
41+
at::Tensor grad_logits, int topk, bool use_pre_softmax, std::optional<float> scaling_factor,
42+
std::string score_function,
43+
int routing_map_format = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP));
4044

4145
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
4246
at::Tensor logits, int topk, std::string score_function,
43-
NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP);
47+
int routing_map_format = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP));
4448

45-
void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
46-
at::Tensor intermediate_output, at::Tensor grad_probs,
47-
at::Tensor grad_logits, int topk, std::string score_function);
49+
void fused_score_for_moe_aux_loss_bwd(at::Tensor intermediate_output, at::Tensor grad_scores,
50+
at::Tensor grad_logits, int topk,
51+
std::string score_function);
4852

4953
std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
5054
at::Tensor tokens_per_expert,

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,24 +136,27 @@ void init_router_bindings(pybind11::module &m) {
136136
pybind11::enum_<NVTERoutingMapFormat>(m, "NVTERoutingMapFormat", pybind11::module_local())
137137
.value("BYTEMAP", NVTE_ROUTING_MAP_FORMAT_BYTEMAP)
138138
.value("BITMAP_U8", NVTE_ROUTING_MAP_FORMAT_BITMAP_U8);
139+
// routing_map_format is passed as int (not the enum) on the PyTorch hot
140+
// path; see CLAUDE.md "CPU overhead in PyTorch wrappers".
139141
m.def("fused_topk_with_score_function_fwd", &fused_topk_with_score_function_fwd,
140142
py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"),
141143
py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"),
142-
py::arg("expert_bias"), py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
144+
py::arg("expert_bias"),
145+
py::arg("routing_map_format") = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP),
143146
"Fused topk with score function fwd");
144147
m.def("fused_topk_with_score_function_bwd", &fused_topk_with_score_function_bwd,
145-
py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"),
146-
py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("grad_logits"),
147-
py::arg("topk"), py::arg("use_pre_softmax"), py::arg("scaling_factor"),
148-
py::arg("score_function"), py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
148+
py::arg("routing_map"), py::arg("intermediate_output"), py::arg("grad_probs"),
149+
py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"),
150+
py::arg("scaling_factor"), py::arg("score_function"),
151+
py::arg("routing_map_format") = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP),
149152
"Fused topk with score function bwd");
150153
m.def("fused_score_for_moe_aux_loss_fwd", &fused_score_for_moe_aux_loss_fwd, py::arg("logits"),
151154
py::arg("topk"), py::arg("score_function"),
152-
py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP,
155+
py::arg("routing_map_format") = static_cast<int>(NVTE_ROUTING_MAP_FORMAT_BYTEMAP),
153156
"Fused aux loss with score function fwd");
154157
m.def("fused_score_for_moe_aux_loss_bwd", &fused_score_for_moe_aux_loss_bwd,
155-
py::arg("num_tokens"), py::arg("num_experts"), py::arg("intermediate_output"),
156-
py::arg("grad_scores"), py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"),
158+
py::arg("intermediate_output"), py::arg("grad_scores"), py::arg("grad_logits"),
159+
py::arg("topk"), py::arg("score_function"),
157160
"Fused aux loss with score function bwd");
158161
m.def("fused_moe_aux_loss_fwd", &fused_moe_aux_loss_fwd, py::arg("probs"),
159162
py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"),

0 commit comments

Comments
 (0)