Commit 5572c97
[PyTorch] Address PR #3009 review: remove .view() calls, int routing_map_format
Apply the four CPU-overhead fixes the reviewer asked for, as codified in
the CLAUDE.md "CPU overhead in PyTorch wrappers" section:
1. _validate_routing_map_format returns plain int (not enum); the
   autograd Function + tex.* bindings only see ints. Validates via
   precomputed frozenset and a single dict.get with canonical
   lowercase keys (no .lower()/.upper()).
2. Type annotations on Function.forward use int (not the string
   forward-ref 'RoutingMapFormat').
3. Removed every .view() from FusedTopkScoreFunction.{forward,backward}
   and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++
   extension now accepts N-D logits/grad_probs, computes num_tokens
   from the product of leading dims, num_experts from the last dim,
   allocates outputs at the user-facing N-D shape, and wraps tensors
   with an explicit 2D shape via makeTransformerEngineTensor only for
   the kernel call. Asserts is_contiguous() on inputs.
4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D)
   instead of allocate-2D-then-view.
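
The shape rule in items 3 and 4 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the extension's code: `kernel_2d_shape` is a hypothetical helper name, the example shape is made up, and the real logic lives in the C++ extension.

```python
import math
from typing import Sequence, Tuple


def kernel_2d_shape(shape: Sequence[int]) -> Tuple[int, int]:
    """Return (num_tokens, num_experts) for an N-D logits shape.

    Hypothetical helper: num_tokens is the product of the leading dims,
    num_experts is the last dim, as described in item 3.
    """
    if len(shape) < 1:
        raise ValueError("logits must have at least one dimension")
    num_experts = shape[-1]
    # math.prod of an empty sequence is 1, so a 1-D input counts as one token.
    num_tokens = math.prod(shape[:-1])
    return num_tokens, num_experts


# Illustrative (batch, seq, experts) shape. Outputs would keep this N-D
# shape; only the kernel call would see the 2D (num_tokens, num_experts) view.
user_shape = (4, 128, 8)
print(kernel_2d_shape(user_shape))  # (512, 8)
```

Because the 2D shape is derived from the N-D shape without touching the data, the Python wrapper would no longer need a .view() before or after the call, provided the inputs are contiguous.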
The PyTorch-extension boundary takes 'int routing_map_format' and casts
it to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2)
keeps the enum.
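
One way the int-only validation from item 1 and the cast at the boundary might look is sketched below. Everything here is an assumption for illustration: the enum members `FORMAT_A` and `FORMAT_B`, the function name `validate_routing_map_format`, and the choice to check ints against the frozenset and strings through dict.get are hypothetical, since the commit message does not show the actual values or code.

```python
from enum import IntEnum
from typing import Union


class RoutingMapFormat(IntEnum):
    # Hypothetical members; the real enum values are not shown in this commit.
    FORMAT_A = 0
    FORMAT_B = 1


# Built once at import time so the hot path does no string normalization.
_VALID_INTS = frozenset(int(m) for m in RoutingMapFormat)
_NAME_TO_INT = {m.name.lower(): int(m) for m in RoutingMapFormat}


def validate_routing_map_format(fmt: Union[int, str]) -> int:
    """Return a plain int for a valid int or canonical lowercase name."""
    if isinstance(fmt, int) and not isinstance(fmt, bool):
        if fmt in _VALID_INTS:
            return int(fmt)  # int() strips any IntEnum wrapper
        raise ValueError(f"invalid routing_map_format: {fmt!r}")
    # Single lookup; keys are lowercase only, so no .lower()/.upper() call.
    value = _NAME_TO_INT.get(fmt)
    if value is None:
        raise ValueError(f"invalid routing_map_format: {fmt!r}")
    return value


# The binding side would receive the plain int and cast it back to the
# enum, analogous to the cast to NVTERoutingMapFormat described above.
as_int = validate_routing_map_format("format_b")
print(type(as_int).__name__, RoutingMapFormat(as_int).name)
```

Passing a plain int across the boundary keeps the per-call work to a set membership test or one dict lookup, which is the kind of CPU overhead the fixes above target.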
Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent af30717 commit 5572c97
4 files changed
Lines changed: 217 additions & 159 deletions