@@ -243,36 +243,33 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
243243}
244244
245245std::vector<at::Tensor> output_tensor_list_from_arg (py::handle arg, size_t num_groups,
246- at::ScalarType dtype, const std::string &name) {
246+ int64_t rows, int64_t cols,
247+ const std::string &name) {
247248 std::vector<at::Tensor> out;
248249 if (is_none (arg)) {
249250 return out;
250251 }
251252 out.reserve (num_groups);
252- if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
253- auto seq = py::reinterpret_borrow<py::sequence>(arg);
254- NVTE_CHECK (static_cast <size_t >(seq.size ()) == num_groups, name, " must have " , num_groups,
255- " tensors." );
256- for (size_t i = 0 ; i < num_groups; ++i) {
257- auto tensor = seq[i].cast <at::Tensor>();
258- NVTE_CHECK (tensor.is_cuda (), name, " tensors must be CUDA tensors." );
259- NVTE_CHECK (tensor.scalar_type () == dtype, name, " tensors must have the requested dtype." );
260- NVTE_CHECK (tensor.dim () == 2 , name, " tensors must be rank-2 wgrad buffers." );
261- check_contiguous (tensor, name);
262- out.emplace_back (tensor);
263- }
264- return out;
265- }
266-
267- auto packed = arg.cast <at::Tensor>();
268- NVTE_CHECK (packed.is_cuda (), name, " must be a CUDA tensor." );
269- NVTE_CHECK (packed.scalar_type () == dtype, name, " must have the requested dtype." );
270- NVTE_CHECK (packed.dim () == 3 , name, " must have shape [num_groups, rows, cols]." );
271- NVTE_CHECK (static_cast <size_t >(packed.size (0 )) == num_groups, name, " first dimension must be " ,
272- num_groups, " ." );
273- check_contiguous (packed, name);
253+ // This helper is intentionally only for the discrete-weight external wgrad
254+ // path, where Megatron provides one main_grad tensor per expert. The packed
255+ // [G, rows, cols] external buffer used by single grouped weight is handled in
256+ // wgrad_output_from_arg so it can stay packed and use grouped-tensor GEMM.
257+ NVTE_CHECK (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg), name,
258+ " must be a list or tuple of wgrad output tensors." );
259+ auto seq = py::reinterpret_borrow<py::sequence>(arg);
260+ NVTE_CHECK (static_cast <size_t >(seq.size ()) == num_groups, name, " must have " , num_groups,
261+ " tensors." );
274262 for (size_t i = 0 ; i < num_groups; ++i) {
275- out.emplace_back (packed.select (0 , static_cast <int64_t >(i)));
263+ auto tensor = seq[i].cast <at::Tensor>();
264+ NVTE_CHECK (tensor.is_cuda (), name, " tensors must be CUDA tensors." );
265+ // Do not require tensor.scalar_type() == compute dtype. Caller-owned
266+ // main_grad buffers are allocated by Megatron and may be FP32 even when TE
267+ // grouped MLP compute is BF16.
268+ NVTE_CHECK (tensor.dim () == 2 , name, " tensors must be rank-2 wgrad buffers." );
269+ NVTE_CHECK (tensor.size (0 ) == rows && tensor.size (1 ) == cols, name,
270+ " tensors must have shape [rows, cols]." );
271+ check_contiguous (tensor, name);
272+ out.emplace_back (tensor);
276273 }
277274 return out;
278275}
@@ -315,7 +312,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
315312 // should not receive a newly allocated grad tensor from this helper.
316313 out.packed = arg.cast <at::Tensor>();
317314 NVTE_CHECK (out.packed .is_cuda (), name, " must be a CUDA tensor." );
318- NVTE_CHECK (out.packed .scalar_type () == dtype, name, " must have the requested dtype." );
315+ // Do not require out.packed.scalar_type() == compute dtype. Caller-owned
316+ // main_grad buffers keep the dtype chosen by Megatron's grad-buffer config.
319317 NVTE_CHECK (out.packed .dim () == 3 , name, " must have shape [num_groups, rows, cols]." );
320318 NVTE_CHECK (static_cast <size_t >(out.packed .size (0 )) == num_groups, name,
321319 " first dimension must be " , num_groups, " ." );
@@ -328,7 +326,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
328326 // Case 4: discrete weights with externally-owned per-expert buffers, e.g.
329327 // Megatron main_grad list. GEMM writes each tensor in-place and returns no
330328 // allocated grad list to Python.
331- out.tensors = output_tensor_list_from_arg (arg, num_groups, dtype , name);
329+ out.tensors = output_tensor_list_from_arg (arg, num_groups, rows, cols , name);
332330 return out;
333331}
334332
0 commit comments