Skip to content

Commit c9d56c4

Browse files
committed
fixes for E2E run
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
1 parent 19304b4 commit c9d56c4

2 files changed

Lines changed: 40 additions & 34 deletions

File tree

tests/pytorch/megacpp/test_grouped_mlp.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,20 @@ def _copy_grouped_mlp_params(dst: te_ops.Sequential, src: te_ops.Sequential) ->
9898
)
9999

100100

101-
def _init_main_grads(module: te_ops.Sequential) -> None:
101+
def _init_main_grads(module: te_ops.Sequential, dtype: torch.dtype) -> None:
102102
for linear in (module[0], module[2]):
103103
if linear.single_grouped_weight:
104104
linear.weight.main_grad = torch.zeros(
105105
linear.num_groups,
106106
linear.out_features,
107107
linear.in_features,
108108
device="cuda",
109-
dtype=torch.bfloat16,
109+
dtype=dtype,
110110
)
111111
else:
112112
for group_idx in range(linear.num_groups):
113113
weight = getattr(linear, f"weight{group_idx}")
114-
weight.main_grad = torch.zeros_like(weight)
114+
weight.main_grad = torch.zeros_like(weight, dtype=dtype)
115115

116116

117117
def _run_grouped_mlp(
@@ -241,6 +241,7 @@ def _run_megacpp_against_python(
241241
activation_kind: str = "scaled_swiglu",
242242
single_grouped_param: bool = False,
243243
accumulate_into_main_grad: bool = False,
244+
main_grad_dtype: torch.dtype | None = None,
244245
compare_zero_expert_grads: bool = True,
245246
monkeypatch,
246247
) -> None:
@@ -274,8 +275,10 @@ def _run_megacpp_against_python(
274275
)
275276
_copy_grouped_mlp_params(test, ref)
276277
if accumulate_into_main_grad:
277-
_init_main_grads(ref)
278-
_init_main_grads(test)
278+
if main_grad_dtype is None:
279+
raise ValueError("main_grad_dtype must be set when using Megatron-owned main_grad.")
280+
_init_main_grads(ref, main_grad_dtype)
281+
_init_main_grads(test, main_grad_dtype)
279282

280283
# Paged stashing passes a static physical buffer to the op while m_splits
281284
# describe only the valid prefix. Rows after sum(m_splits) are garbage and
@@ -332,13 +335,17 @@ def _run_megacpp_against_python(
332335
ids=["discrete_weight", "packed_weight"],
333336
)
334337
@pytest.mark.parametrize(
335-
"accumulate_into_main_grad",
336-
[False, True],
337-
ids=["cpp_allocated_wgrad", "megatron_main_grad"],
338+
"accumulate_into_main_grad,main_grad_dtype",
339+
[
340+
pytest.param(False, None, id="cpp_allocated_wgrad"),
341+
pytest.param(True, torch.bfloat16, id="megatron_main_grad_bf16"),
342+
pytest.param(True, torch.float32, id="megatron_main_grad_fp32"),
343+
],
338344
)
339345
def test_megacpp_grouped_mlp_wgrad_storage_matches_python(
340346
single_grouped_param,
341347
accumulate_into_main_grad,
348+
main_grad_dtype,
342349
monkeypatch,
343350
):
344351
torch.manual_seed(1234)
@@ -349,6 +356,7 @@ def test_megacpp_grouped_mlp_wgrad_storage_matches_python(
349356
split_device="cuda",
350357
single_grouped_param=single_grouped_param,
351358
accumulate_into_main_grad=accumulate_into_main_grad,
359+
main_grad_dtype=main_grad_dtype,
352360
monkeypatch=monkeypatch,
353361
)
354362

transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -243,36 +243,33 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
243243
}
244244

245245
std::vector<at::Tensor> output_tensor_list_from_arg(py::handle arg, size_t num_groups,
246-
at::ScalarType dtype, const std::string &name) {
246+
int64_t rows, int64_t cols,
247+
const std::string &name) {
247248
std::vector<at::Tensor> out;
248249
if (is_none(arg)) {
249250
return out;
250251
}
251252
out.reserve(num_groups);
252-
if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
253-
auto seq = py::reinterpret_borrow<py::sequence>(arg);
254-
NVTE_CHECK(static_cast<size_t>(seq.size()) == num_groups, name, " must have ", num_groups,
255-
" tensors.");
256-
for (size_t i = 0; i < num_groups; ++i) {
257-
auto tensor = seq[i].cast<at::Tensor>();
258-
NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors.");
259-
NVTE_CHECK(tensor.scalar_type() == dtype, name, " tensors must have the requested dtype.");
260-
NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers.");
261-
check_contiguous(tensor, name);
262-
out.emplace_back(tensor);
263-
}
264-
return out;
265-
}
266-
267-
auto packed = arg.cast<at::Tensor>();
268-
NVTE_CHECK(packed.is_cuda(), name, " must be a CUDA tensor.");
269-
NVTE_CHECK(packed.scalar_type() == dtype, name, " must have the requested dtype.");
270-
NVTE_CHECK(packed.dim() == 3, name, " must have shape [num_groups, rows, cols].");
271-
NVTE_CHECK(static_cast<size_t>(packed.size(0)) == num_groups, name, " first dimension must be ",
272-
num_groups, ".");
273-
check_contiguous(packed, name);
253+
// This helper is intentionally only for the discrete-weight external wgrad
254+
// path, where Megatron provides one main_grad tensor per expert. The packed
255+
// [G, rows, cols] external buffer used by single grouped weight is handled in
256+
// wgrad_output_from_arg so it can stay packed and use grouped-tensor GEMM.
257+
NVTE_CHECK(py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg), name,
258+
" must be a list or tuple of wgrad output tensors.");
259+
auto seq = py::reinterpret_borrow<py::sequence>(arg);
260+
NVTE_CHECK(static_cast<size_t>(seq.size()) == num_groups, name, " must have ", num_groups,
261+
" tensors.");
274262
for (size_t i = 0; i < num_groups; ++i) {
275-
out.emplace_back(packed.select(0, static_cast<int64_t>(i)));
263+
auto tensor = seq[i].cast<at::Tensor>();
264+
NVTE_CHECK(tensor.is_cuda(), name, " tensors must be CUDA tensors.");
265+
// Do not require tensor.scalar_type() == compute dtype. Caller-owned
266+
// main_grad buffers are allocated by Megatron and may be FP32 even when TE
267+
// grouped MLP compute is BF16.
268+
NVTE_CHECK(tensor.dim() == 2, name, " tensors must be rank-2 wgrad buffers.");
269+
NVTE_CHECK(tensor.size(0) == rows && tensor.size(1) == cols, name,
270+
" tensors must have shape [rows, cols].");
271+
check_contiguous(tensor, name);
272+
out.emplace_back(tensor);
276273
}
277274
return out;
278275
}
@@ -315,7 +312,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
315312
// should not receive a newly allocated grad tensor from this helper.
316313
out.packed = arg.cast<at::Tensor>();
317314
NVTE_CHECK(out.packed.is_cuda(), name, " must be a CUDA tensor.");
318-
NVTE_CHECK(out.packed.scalar_type() == dtype, name, " must have the requested dtype.");
315+
// Do not require out.packed.scalar_type() == compute dtype. Caller-owned
316+
// main_grad buffers keep the dtype chosen by Megatron's grad-buffer config.
319317
NVTE_CHECK(out.packed.dim() == 3, name, " must have shape [num_groups, rows, cols].");
320318
NVTE_CHECK(static_cast<size_t>(out.packed.size(0)) == num_groups, name,
321319
" first dimension must be ", num_groups, ".");
@@ -328,7 +326,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
328326
// Case 4: discrete weights with externally-owned per-expert buffers, e.g.
329327
// Megatron main_grad list. GEMM writes each tensor in-place and returns no
330328
// allocated grad list to Python.
331-
out.tensors = output_tensor_list_from_arg(arg, num_groups, dtype, name);
329+
out.tensors = output_tensor_list_from_arg(arg, num_groups, rows, cols, name);
332330
return out;
333331
}
334332

0 commit comments

Comments
 (0)