Skip to content

Commit 7ab8bc6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b3847a2 commit 7ab8bc6

3 files changed

Lines changed: 42 additions & 54 deletions

File tree

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,23 +191,22 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p
191191

192192
std::vector<at::Tensor> megacpp_grouped_mlp_forward(
193193
const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes,
194-
py::handle fc1_weight,
195-
py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
194+
py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
196195
const std::optional<at::Tensor> &act_scales, const std::string &activation,
197196
int64_t glu_interleave_size, double activation_limit, double activation_alpha,
198197
double activation_glu_linear_offset, py::handle gemm_scratch);
199198

200199
py::tuple megacpp_grouped_mlp_backward(
201200
const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes,
202-
const at::Tensor &x_offsets,
203-
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
204-
const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
205-
const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
206-
py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
207-
bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
208-
bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
209-
double activation_limit, double activation_alpha, double activation_glu_linear_offset,
210-
bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch);
201+
const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
202+
const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x,
203+
const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x,
204+
const std::optional<at::Tensor> &act_scales, py::handle fc1_weight, py::handle fc2_weight,
205+
py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad,
206+
py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad,
207+
const std::string &activation, int64_t glu_interleave_size, double activation_limit,
208+
double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad,
209+
bool input_requires_grad, py::handle gemm_scratch);
211210

212211
/***************************************************************************************************
213212
* Transpose

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
359359
"Grouped GEMM for discrete output list");
360360
m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward,
361361
"Mega C++ grouped MLP forward", py::arg("input"), py::arg("act_dtype"),
362-
py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"),
363-
py::arg("fc2_weight"), py::arg("fc2_bias"),
364-
py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"),
365-
py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0,
366-
py::arg("activation_glu_linear_offset") = 0.0,
362+
py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"),
363+
py::arg("fc2_bias"), py::arg("act_scales"), py::arg("activation"),
364+
py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0,
365+
py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0,
367366
py::arg("gemm_scratch") = py::none());
368367
m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward,
369368
"Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("act_dtype"),

transformer_engine/pytorch/csrc/megacpp/grouped_mlp.cpp

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,7 @@ struct GroupedGemmResources {
256256
};
257257

258258
GroupedGemmResources make_grouped_mlp_backend_resources(const c10::Device &device,
259-
size_t num_groups,
260-
py::handle scratch) {
259+
size_t num_groups, py::handle scratch) {
261260
// Keep the backend resource policy private to megacpp. Today this is cuBLAS
262261
// grouped GEMM scratch; future backends can change this helper without
263262
// changing the Python or pybind contract.
@@ -275,9 +274,8 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
275274
});
276275
}
277276

278-
std::vector<at::Tensor> output_tensor_list_from_arg(py::handle arg, size_t num_groups,
279-
int64_t rows, int64_t cols,
280-
const std::string &name) {
277+
std::vector<at::Tensor> output_tensor_list_from_arg(py::handle arg, size_t num_groups, int64_t rows,
278+
int64_t cols, const std::string &name) {
281279
std::vector<at::Tensor> out;
282280
if (is_none(arg)) {
283281
return out;
@@ -327,8 +325,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
327325
// allocation. Single grouped weight keeps this packed as [G, N, K];
328326
// discrete weights split the same packed allocation into per-expert views.
329327
out.packed =
330-
at::empty({static_cast<int64_t>(num_groups), rows, cols},
331-
at::device(device).dtype(dtype));
328+
at::empty({static_cast<int64_t>(num_groups), rows, cols}, at::device(device).dtype(dtype));
332329
out.owns_storage = true;
333330
out.is_grouped = prefer_grouped_output;
334331
if (out.is_grouped) {
@@ -389,14 +386,12 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight,
389386
}
390387

391388
std::vector<at::Tensor> grouped_gemm_wgrad(GroupedTensorWrapper *x, GroupedTensorWrapper *dy,
392-
py::handle output, bool compute_wgrad, bool accumulate,
393-
GroupedGemmResources *resources,
394-
at::ScalarType dtype, int64_t rows,
395-
int64_t cols, const std::string &name,
396-
bool prefer_grouped_output) {
397-
auto prepared =
398-
wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype,
399-
resources->device, rows, cols, name, prefer_grouped_output);
389+
py::handle output, bool compute_wgrad, bool accumulate,
390+
GroupedGemmResources *resources, at::ScalarType dtype,
391+
int64_t rows, int64_t cols, const std::string &name,
392+
bool prefer_grouped_output) {
393+
auto prepared = wgrad_output_from_arg(output, compute_wgrad, resources->num_groups, dtype,
394+
resources->device, rows, cols, name, prefer_grouped_output);
400395
NVTE_CHECK(!(prepared.owns_storage && accumulate), name,
401396
" cannot accumulate into a newly allocated wgrad buffer.");
402397
std::vector<at::Tensor> returned_wgrads;
@@ -643,8 +638,7 @@ ActivationBackwardResult grouped_mlp_activation_backward(
643638

644639
std::vector<at::Tensor> megacpp_grouped_mlp_forward(
645640
const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes,
646-
py::handle fc1_weight,
647-
py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
641+
py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
648642
const std::optional<at::Tensor> &act_scales, const std::string &activation,
649643
int64_t glu_interleave_size, double activation_limit, double activation_alpha,
650644
double activation_glu_linear_offset, py::handle gemm_scratch) {
@@ -712,8 +706,7 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
712706
std::vector<int64_t> out_shape = input.sizes().vec();
713707
out_shape.back() = fc2_out_features;
714708
auto output = at::empty(out_shape, x.options());
715-
auto grouped_fc2_x =
716-
make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features);
709+
auto grouped_fc2_x = make_grouped_tensor(fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features);
717710
auto grouped_output =
718711
make_grouped_tensor(output, split_sizes_i64, output_offsets, fc2_out_features);
719712
grouped_gemm_fwd_dgrad(&fc2_weights, true, &grouped_fc2_x, false, &grouped_output,
@@ -726,15 +719,15 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
726719

727720
py::tuple megacpp_grouped_mlp_backward(
728721
const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes,
729-
const at::Tensor &x_offsets,
730-
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
731-
const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
732-
const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
733-
py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
734-
bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
735-
bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
736-
double activation_limit, double activation_alpha, double activation_glu_linear_offset,
737-
bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch) {
722+
const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
723+
const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets, const at::Tensor &x,
724+
const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x,
725+
const std::optional<at::Tensor> &act_scales, py::handle fc1_weight, py::handle fc2_weight,
726+
py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad,
727+
py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad,
728+
const std::string &activation, int64_t glu_interleave_size, double activation_limit,
729+
double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad,
730+
bool input_requires_grad, py::handle gemm_scratch) {
738731
(void)base_offsets;
739732
NVTE_CHECK(grad_output.is_cuda(), "megacpp_grouped_mlp_backward requires CUDA grad_output.");
740733
at::cuda::CUDAGuard device_guard(grad_output.device());
@@ -769,23 +762,21 @@ py::tuple megacpp_grouped_mlp_backward(
769762
auto grouped_fc2_x_for_wgrad =
770763
make_grouped_tensor(fc2_x_for_wgrad, split_sizes, fc2_offsets, fc2_in_features);
771764
fc2_wgrads = grouped_gemm_wgrad(&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
772-
fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources,
773-
dtype, fc2_out_features, fc2_in_features, "fc2_wgrad_output",
765+
fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
766+
fc2_out_features, fc2_in_features, "fc2_wgrad_output",
774767
fc2_weights.is_grouped);
775768
}
776769

777770
auto fc2_dx = at::empty({total_tokens, fc2_in_features}, dy.options());
778-
auto grouped_fc2_dx =
779-
make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features);
771+
auto grouped_fc2_dx = make_grouped_tensor(fc2_dx, split_sizes, fc2_offsets, fc2_in_features);
780772
grouped_gemm_fwd_dgrad(&fc2_weights, false, &grouped_dy, false, &grouped_fc2_dx, &gemm_resources);
781773

782774
auto activation_grads = grouped_mlp_activation_backward(
783775
fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit,
784776
activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad);
785777
auto fc1_dy = activation_grads.grad_input;
786778
auto grad_act_scales = activation_grads.grad_act_scales;
787-
auto grouped_fc1_dy =
788-
make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features);
779+
auto grouped_fc1_dy = make_grouped_tensor(fc1_dy, split_sizes, fc1_offsets, fc1_out_features);
789780

790781
std::vector<at::Tensor> fc1_wgrads;
791782
if (fc1_compute_wgrad) {
@@ -794,8 +785,8 @@ py::tuple megacpp_grouped_mlp_backward(
794785
auto grouped_x_for_wgrad =
795786
make_grouped_tensor(x_for_wgrad, split_sizes, x_offsets, in_features);
796787
fc1_wgrads = grouped_gemm_wgrad(&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
797-
fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources,
798-
dtype, fc1_out_features, in_features, "fc1_wgrad_output",
788+
fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
789+
fc1_out_features, in_features, "fc1_wgrad_output",
799790
fc1_weights.is_grouped);
800791
}
801792

@@ -804,8 +795,7 @@ py::tuple megacpp_grouped_mlp_backward(
804795
std::vector<int64_t> grad_input_shape = grad_output.sizes().vec();
805796
grad_input_shape.back() = in_features;
806797
grad_input = at::empty(grad_input_shape, dy.options());
807-
auto grouped_grad_input =
808-
make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features);
798+
auto grouped_grad_input = make_grouped_tensor(grad_input, split_sizes, x_offsets, in_features);
809799
grouped_gemm_fwd_dgrad(&fc1_weights, false, &grouped_fc1_dy, false, &grouped_grad_input,
810800
&gemm_resources);
811801
} else {

0 commit comments

Comments
 (0)