Skip to content

Commit 19304b4

Browse files
pre-commit-ci[bot]zhongbozhu
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c2819b1 commit 19304b4

6 files changed

Lines changed: 85 additions & 109 deletions

File tree

tests/pytorch/megacpp/test_grouped_mlp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,10 @@ def test_megacpp_grouped_mlp_delay_wgrad_raises(monkeypatch):
464464
glu_interleave_size=None,
465465
single_grouped_param=False,
466466
)
467-
x = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16).requires_grad_()
468-
act_scales = torch.rand(
469-
total_tokens, device="cuda", dtype=torch.bfloat16
467+
x = torch.randn(
468+
total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16
470469
).requires_grad_()
470+
act_scales = torch.rand(total_tokens, device="cuda", dtype=torch.bfloat16).requires_grad_()
471471
dy = torch.randn(total_tokens, _HIDDEN_SIZE, device="cuda", dtype=torch.bfloat16)
472472

473473
monkeypatch.setenv("NVTE_MEGACPP_GROUPED_LINEAR", "1")

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,14 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
198198

199199
py::tuple megacpp_grouped_mlp_backward(
200200
const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets,
201-
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
202-
const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets,
203-
const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x,
204-
const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
205-
py::handle fc2_weight,
206-
py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad,
207-
py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad,
208-
const std::string &activation, int64_t glu_interleave_size, double activation_limit,
209-
double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad,
210-
bool input_requires_grad);
201+
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
202+
const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
203+
const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
204+
py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
205+
bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
206+
bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
207+
double activation_limit, double activation_alpha, double activation_glu_linear_offset,
208+
bool act_scales_requires_grad, bool input_requires_grad);
211209

212210
/***************************************************************************************************
213211
* Transpose

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
357357
m.def("te_general_grouped_gemm_for_discrete_out",
358358
&transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out,
359359
"Grouped GEMM for discrete output list");
360-
m.def("megacpp_grouped_mlp_forward",
361-
&transformer_engine::pytorch::megacpp_grouped_mlp_forward,
360+
m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward,
362361
"Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"),
363362
py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"),
364363
py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"),
365364
py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0,
366365
py::arg("activation_glu_linear_offset") = 0.0);
367-
m.def("megacpp_grouped_mlp_backward",
368-
&transformer_engine::pytorch::megacpp_grouped_mlp_backward,
366+
m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward,
369367
"Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"),
370368
py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"),
371369
py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"),

transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp

Lines changed: 53 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
* See LICENSE for license information.
55
************************************************************************/
66

7+
#include <ATen/cuda/CUDAContext.h>
8+
#include <c10/cuda/CUDAGuard.h>
79
#include <pybind11/pybind11.h>
810
#include <pybind11/stl.h>
911

@@ -12,9 +14,6 @@
1214
#include <tuple>
1315
#include <vector>
1416

15-
#include <ATen/cuda/CUDAContext.h>
16-
#include <c10/cuda/CUDAGuard.h>
17-
1817
#include "../extensions.h"
1918
#include "../pybind.h"
2019
#include "common/util/cuda_runtime.h"
@@ -58,7 +57,8 @@ size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes,
5857
}
5958

6059
GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prepared_split_sizes,
61-
const at::Tensor &tensor_offsets, int64_t logical_last_dim) {
60+
const at::Tensor &tensor_offsets,
61+
int64_t logical_last_dim) {
6262
const auto num_groups = static_cast<size_t>(prepared_split_sizes.numel());
6363
const auto total_tokens = static_cast<size_t>(data.numel() / logical_last_dim);
6464
auto grouped = GroupedTensorWrapper(
@@ -75,9 +75,8 @@ GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prep
7575
GroupedTensorWrapper make_uniform_grouped_tensor(at::Tensor data, size_t num_groups,
7676
int64_t first_dim, int64_t last_dim) {
7777
auto grouped = GroupedTensorWrapper(
78-
num_groups,
79-
std::vector<size_t>{num_groups * static_cast<size_t>(first_dim),
80-
static_cast<size_t>(last_dim)});
78+
num_groups, std::vector<size_t>{num_groups * static_cast<size_t>(first_dim),
79+
static_cast<size_t>(last_dim)});
8180
grouped.set_rowwise_data(data.data_ptr(), GetTransformerEngineDType(data.scalar_type()),
8281
tensor_shape_1d(data));
8382
return grouped;
@@ -94,9 +93,7 @@ struct GroupedWeightArg {
9493
int64_t rows = 0;
9594
int64_t cols = 0;
9695

97-
c10::Device device() const {
98-
return is_grouped ? packed.device() : discrete[0].device();
99-
}
96+
c10::Device device() const { return is_grouped ? packed.device() : discrete[0].device(); }
10097
};
10198

10299
GroupedWeightArg weight_arg_from_py(py::handle arg, size_t num_groups, at::ScalarType dtype,
@@ -201,9 +198,9 @@ struct GroupedGemmResources {
201198
te_alpha(makeTransformerEngineTensor(alpha)),
202199
te_beta_zero(makeTransformerEngineTensor(beta_zero)),
203200
te_beta_one(makeTransformerEngineTensor(beta_one)),
204-
te_setup(makeTransformerEngineTensor(setup.data_ptr(),
205-
std::vector<size_t>{static_cast<size_t>(setup.numel())},
206-
DType::kByte)),
201+
te_setup(makeTransformerEngineTensor(
202+
setup.data_ptr(), std::vector<size_t>{static_cast<size_t>(setup.numel())},
203+
DType::kByte)),
207204
te_cublas(makeTransformerEngineTensor(
208205
cublas.data_ptr(), std::vector<size_t>{static_cast<size_t>(cublas.numel())},
209206
DType::kByte)) {
@@ -220,9 +217,7 @@ struct GroupedGemmResources {
220217
}
221218
}
222219

223-
NVTETensor beta(bool accumulate) {
224-
return accumulate ? te_beta_one.data() : te_beta_zero.data();
225-
}
220+
NVTETensor beta(bool accumulate) { return accumulate ? te_beta_one.data() : te_beta_zero.data(); }
226221

227222
NVTEGroupedMatmulConfig config_data() {
228223
return config.has_value() ? static_cast<NVTEGroupedMatmulConfig>(*config) : nullptr;
@@ -243,14 +238,12 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
243238
nvte_grouped_gemm(A->data(), transa, B->data(), transb, D->data(), D->data(),
244239
resources->te_alpha.data(), resources->beta(accumulate),
245240
resources->te_setup.data(), resources->te_cublas.data(),
246-
resources->config_data(),
247-
at::cuda::getCurrentCUDAStream());
241+
resources->config_data(), at::cuda::getCurrentCUDAStream());
248242
});
249243
}
250244

251245
std::vector<at::Tensor> output_tensor_list_from_arg(py::handle arg, size_t num_groups,
252-
at::ScalarType dtype,
253-
const std::string &name) {
246+
at::ScalarType dtype, const std::string &name) {
254247
std::vector<at::Tensor> out;
255248
if (is_none(arg)) {
256249
return out;
@@ -303,8 +296,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
303296
// Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the
304297
// allocation. Single grouped weight keeps this packed as [G, N, K];
305298
// discrete weights split the same packed allocation into per-expert views.
306-
out.packed = at::empty({static_cast<int64_t>(num_groups), rows, cols},
307-
at::device(device).dtype(dtype));
299+
out.packed =
300+
at::empty({static_cast<int64_t>(num_groups), rows, cols}, at::device(device).dtype(dtype));
308301
out.owns_storage = true;
309302
out.is_grouped = prefer_grouped_output;
310303
if (out.is_grouped) {
@@ -345,9 +338,8 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight,
345338
if (weights->is_grouped) {
346339
// Single grouped weight case: weights are packed as [G, N, K]. Wrap the
347340
// packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM.
348-
auto grouped_weight =
349-
make_uniform_grouped_tensor(weights->packed, input->num_tensors(), weights->rows,
350-
weights->cols);
341+
auto grouped_weight = make_uniform_grouped_tensor(weights->packed, input->num_tensors(),
342+
weights->rows, weights->cols);
351343
grouped_gemm(&grouped_weight, trans_weight, input, trans_input, output, resources, false);
352344
} else {
353345
// Discrete weight case: weights are a list of per-expert tensors. Use the
@@ -413,7 +405,8 @@ GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups
413405
NVTE_CHECK(bias.defined(), "Bias tensor must be defined.");
414406
auto grouped = GroupedTensorWrapper(
415407
num_groups, std::vector<size_t>{num_groups, static_cast<size_t>(out_features)});
416-
grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype), tensor_shape_1d(bias));
408+
grouped.set_rowwise_data(bias.data_ptr(), GetTransformerEngineDType(dtype),
409+
tensor_shape_1d(bias));
417410
return grouped;
418411
}
419412

@@ -498,7 +491,8 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a
498491
} else if (activation == "sreglu") {
499492
nvte_sreglu(te_input.data(), te_output.data(), stream);
500493
} else if (activation == "clamped_swiglu") {
501-
nvte_clamped_swiglu_v2(te_input.data(), te_output.data(), static_cast<float>(activation_limit),
494+
nvte_clamped_swiglu_v2(te_input.data(), te_output.data(),
495+
static_cast<float>(activation_limit),
502496
static_cast<float>(activation_alpha),
503497
static_cast<float>(activation_glu_linear_offset), stream);
504498
} else if (activation == "srelu") {
@@ -520,8 +514,7 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a
520514

521515
at::Tensor activation_backward_impl(const at::Tensor &grad, const at::Tensor &input,
522516
const std::string &activation, double activation_limit,
523-
double activation_alpha,
524-
double activation_glu_linear_offset) {
517+
double activation_alpha, double activation_glu_linear_offset) {
525518
auto output = at::empty_like(input);
526519
auto te_grad = makeTransformerEngineTensor(grad);
527520
auto te_input = makeTransformerEngineTensor(input);
@@ -568,7 +561,7 @@ at::Tensor grouped_mlp_activation_forward(
568561
double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) {
569562
auto activation_input = maybe_deinterleave_glu(input, glu_interleave_size);
570563
auto activation_output = activation_forward_impl(activation_input, activation, activation_limit,
571-
activation_alpha, activation_glu_linear_offset);
564+
activation_alpha, activation_glu_linear_offset);
572565
if (!act_scales.has_value()) {
573566
return activation_output;
574567
}
@@ -607,10 +600,9 @@ ActivationBackwardResult grouped_mlp_activation_backward(
607600
}
608601

609602
auto grad_activation_input =
610-
activation_backward_impl(grad_activation_output, activation_input, activation, activation_limit,
611-
activation_alpha, activation_glu_linear_offset);
612-
return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size),
613-
grad_act_scales};
603+
activation_backward_impl(grad_activation_output, activation_input, activation,
604+
activation_limit, activation_alpha, activation_glu_linear_offset);
605+
return {maybe_reinterleave_glu_grad(grad_activation_input, glu_interleave_size), grad_act_scales};
614606
}
615607

616608
} // namespace
@@ -653,8 +645,7 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
653645
split_sizes, x.device(),
654646
std::vector<int64_t>{1, in_features, fc1_out_features, fc2_in_features, fc2_out_features},
655647
std::vector<bool>{true, true, true, true, true},
656-
std::vector<at::ScalarType>{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong},
657-
true);
648+
std::vector<at::ScalarType>{at::kLong, at::kLong, at::kLong, at::kLong, at::kLong}, true);
658649
// splits_to_offsets_multi returns the canonical int64 CUDA split sizes and
659650
// offsets in the same order as the stride list above. The CuTe path also asks
660651
// for int32 split_points, but cuBLAS grouped GEMM does not consume them.
@@ -675,10 +666,9 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
675666
&gemm_resources);
676667
add_grouped_bias(&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features);
677668

678-
auto fc2_x =
679-
grouped_mlp_activation_forward(fc1_preact, act_scales, activation, glu_interleave_size,
680-
activation_limit, activation_alpha,
681-
activation_glu_linear_offset, dtype);
669+
auto fc2_x = grouped_mlp_activation_forward(
670+
fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha,
671+
activation_glu_linear_offset, dtype);
682672

683673
std::vector<int64_t> out_shape = input.sizes().vec();
684674
out_shape.back() = fc2_out_features;
@@ -692,22 +682,20 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
692682
&gemm_resources);
693683
add_grouped_bias(&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features);
694684

695-
return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets,
696-
output_offsets, fc1_preact, fc2_x};
685+
return {output, x, split_sizes_i64, base_offsets, x_offsets,
686+
fc1_offsets, fc2_offsets, output_offsets, fc1_preact, fc2_x};
697687
}
698688

699689
py::tuple megacpp_grouped_mlp_backward(
700690
const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets,
701-
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
702-
const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets,
703-
const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x,
704-
const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
705-
py::handle fc2_weight,
706-
py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad,
707-
py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad,
708-
const std::string &activation, int64_t glu_interleave_size, double activation_limit,
709-
double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad,
710-
bool input_requires_grad) {
691+
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
692+
const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
693+
const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
694+
py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
695+
bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
696+
bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
697+
double activation_limit, double activation_alpha, double activation_glu_linear_offset,
698+
bool act_scales_requires_grad, bool input_requires_grad) {
711699
(void)base_offsets;
712700
NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output.");
713701
at::cuda::CUDAGuard device_guard(grad_output.device());
@@ -737,23 +725,20 @@ py::tuple megacpp_grouped_mlp_backward(
737725
fc2_x_for_wgrad = fc2_x_for_wgrad.view({-1, fc2_in_features});
738726
auto grouped_fc2_x_for_wgrad =
739727
make_grouped_tensor(fc2_x_for_wgrad.view({-1}), split_sizes, fc2_offsets, fc2_in_features);
740-
fc2_wgrads =
741-
grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
742-
fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
743-
fc2_out_features, fc2_in_features, "fc2_wgrad_output",
744-
fc2_weights.is_grouped);
728+
fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
729+
fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
730+
fc2_out_features, fc2_in_features, "fc2_wgrad_output",
731+
fc2_weights.is_grouped);
745732
}
746733

747734
auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options());
748735
auto grouped_fc2_dx =
749736
make_grouped_tensor(fc2_dx.view({-1}), split_sizes, fc2_offsets, fc2_in_features);
750-
grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx,
751-
&gemm_resources);
737+
grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources);
752738

753739
auto activation_grads = grouped_mlp_activation_backward(
754-
fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size,
755-
activation_limit, activation_alpha, activation_glu_linear_offset, dtype,
756-
act_scales_requires_grad);
740+
fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit,
741+
activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad);
757742
auto fc1_dy = activation_grads.grad_input;
758743
auto grad_act_scales = activation_grads.grad_act_scales;
759744
auto grouped_fc1_dy =
@@ -766,11 +751,10 @@ py::tuple megacpp_grouped_mlp_backward(
766751
x_for_wgrad = x_for_wgrad.view({-1, in_features});
767752
auto grouped_x_for_wgrad =
768753
make_grouped_tensor(x_for_wgrad.view({-1}), split_sizes, x_offsets, in_features);
769-
fc1_wgrads =
770-
grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
771-
fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
772-
fc1_out_features, in_features, "fc1_wgrad_output",
773-
fc1_weights.is_grouped);
754+
fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
755+
fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
756+
fc1_out_features, in_features, "fc1_wgrad_output",
757+
fc1_weights.is_grouped);
774758
}
775759

776760
at::Tensor grad_input;
@@ -779,8 +763,8 @@ py::tuple megacpp_grouped_mlp_backward(
779763
grad_input_shape.back() = in_features;
780764
grad_input = at::empty(grad_input_shape, dy.options());
781765
auto grad_input_2d = grad_input.view({-1, in_features});
782-
auto grouped_grad_input = make_grouped_tensor(grad_input_2d.view({-1}), split_sizes,
783-
x_offsets, in_features);
766+
auto grouped_grad_input =
767+
make_grouped_tensor(grad_input_2d.view({-1}), split_sizes, x_offsets, in_features);
784768
grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input,
785769
&gemm_resources);
786770
} else {

0 commit comments

Comments
 (0)