44 * See LICENSE for license information.
55 ************************************************************************/
66
7+ #include < ATen/cuda/CUDAContext.h>
8+ #include < c10/cuda/CUDAGuard.h>
79#include < pybind11/pybind11.h>
810#include < pybind11/stl.h>
911
1214#include < tuple>
1315#include < vector>
1416
15- #include < ATen/cuda/CUDAContext.h>
16- #include < c10/cuda/CUDAGuard.h>
17-
1817#include " ../extensions.h"
1918#include " ../pybind.h"
2019#include " common/util/cuda_runtime.h"
@@ -58,7 +57,8 @@ size_t num_groups_from_prepared_split_sizes(const at::Tensor &split_sizes,
5857}
5958
6059GroupedTensorWrapper make_grouped_tensor (at::Tensor data, const at::Tensor &prepared_split_sizes,
61- const at::Tensor &tensor_offsets, int64_t logical_last_dim) {
60+ const at::Tensor &tensor_offsets,
61+ int64_t logical_last_dim) {
6262 const auto num_groups = static_cast <size_t >(prepared_split_sizes.numel ());
6363 const auto total_tokens = static_cast <size_t >(data.numel () / logical_last_dim);
6464 auto grouped = GroupedTensorWrapper (
@@ -75,9 +75,8 @@ GroupedTensorWrapper make_grouped_tensor(at::Tensor data, const at::Tensor &prep
7575GroupedTensorWrapper make_uniform_grouped_tensor (at::Tensor data, size_t num_groups,
7676 int64_t first_dim, int64_t last_dim) {
7777 auto grouped = GroupedTensorWrapper (
78- num_groups,
79- std::vector<size_t >{num_groups * static_cast <size_t >(first_dim),
80- static_cast <size_t >(last_dim)});
78+ num_groups, std::vector<size_t >{num_groups * static_cast <size_t >(first_dim),
79+ static_cast <size_t >(last_dim)});
8180 grouped.set_rowwise_data (data.data_ptr (), GetTransformerEngineDType (data.scalar_type ()),
8281 tensor_shape_1d (data));
8382 return grouped;
@@ -94,9 +93,7 @@ struct GroupedWeightArg {
9493 int64_t rows = 0 ;
9594 int64_t cols = 0 ;
9695
97- c10::Device device () const {
98- return is_grouped ? packed.device () : discrete[0 ].device ();
99- }
96+ c10::Device device () const { return is_grouped ? packed.device () : discrete[0 ].device (); }
10097};
10198
10299GroupedWeightArg weight_arg_from_py (py::handle arg, size_t num_groups, at::ScalarType dtype,
@@ -201,9 +198,9 @@ struct GroupedGemmResources {
201198 te_alpha(makeTransformerEngineTensor(alpha)),
202199 te_beta_zero(makeTransformerEngineTensor(beta_zero)),
203200 te_beta_one(makeTransformerEngineTensor(beta_one)),
204- te_setup(makeTransformerEngineTensor(setup.data_ptr(),
205- std::vector<size_t>{static_cast <size_t >(setup.numel ())},
206- DType::kByte )),
201+ te_setup(makeTransformerEngineTensor(
202+ setup.data_ptr(), std::vector<size_t>{static_cast <size_t >(setup.numel ())},
203+ DType::kByte )),
207204 te_cublas(makeTransformerEngineTensor(
208205 cublas.data_ptr(), std::vector<size_t>{static_cast <size_t >(cublas.numel ())},
209206 DType::kByte )) {
@@ -220,9 +217,7 @@ struct GroupedGemmResources {
220217 }
221218 }
222219
223- NVTETensor beta (bool accumulate) {
224- return accumulate ? te_beta_one.data () : te_beta_zero.data ();
225- }
220+ NVTETensor beta (bool accumulate) { return accumulate ? te_beta_one.data () : te_beta_zero.data (); }
226221
227222 NVTEGroupedMatmulConfig config_data () {
228223 return config.has_value () ? static_cast <NVTEGroupedMatmulConfig>(*config) : nullptr ;
@@ -243,14 +238,12 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
243238 nvte_grouped_gemm (A->data (), transa, B->data (), transb, D->data (), D->data (),
244239 resources->te_alpha .data (), resources->beta (accumulate),
245240 resources->te_setup .data (), resources->te_cublas .data (),
246- resources->config_data (),
247- at::cuda::getCurrentCUDAStream ());
241+ resources->config_data (), at::cuda::getCurrentCUDAStream ());
248242 });
249243}
250244
251245std::vector<at::Tensor> output_tensor_list_from_arg (py::handle arg, size_t num_groups,
252- at::ScalarType dtype,
253- const std::string &name) {
246+ at::ScalarType dtype, const std::string &name) {
254247 std::vector<at::Tensor> out;
255248 if (is_none (arg)) {
256249 return out;
@@ -303,8 +296,8 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
303296 // Cases 1 and 2: no external wgrad buffer was provided, so C++ owns the
304297 // allocation. Single grouped weight keeps this packed as [G, N, K];
305298 // discrete weights split the same packed allocation into per-expert views.
306- out.packed = at::empty ({ static_cast < int64_t >(num_groups), rows, cols},
307- at::device (device).dtype (dtype));
299+ out.packed =
300+ at::empty ({ static_cast < int64_t >(num_groups), rows, cols}, at::device (device).dtype (dtype));
308301 out.owns_storage = true ;
309302 out.is_grouped = prefer_grouped_output;
310303 if (out.is_grouped ) {
@@ -345,9 +338,8 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight,
345338 if (weights->is_grouped ) {
346339 // Single grouped weight case: weights are packed as [G, N, K]. Wrap the
347340 // packed buffer as a uniform GroupedTensor and use the grouped-tensor GEMM.
348- auto grouped_weight =
349- make_uniform_grouped_tensor (weights->packed , input->num_tensors (), weights->rows ,
350- weights->cols );
341+ auto grouped_weight = make_uniform_grouped_tensor (weights->packed , input->num_tensors (),
342+ weights->rows , weights->cols );
351343 grouped_gemm (&grouped_weight, trans_weight, input, trans_input, output, resources, false );
352344 } else {
353345 // Discrete weight case: weights are a list of per-expert tensors. Use the
@@ -413,7 +405,8 @@ GroupedTensorWrapper make_grouped_bias(const at::Tensor &bias, size_t num_groups
413405 NVTE_CHECK (bias.defined (), " Bias tensor must be defined." );
414406 auto grouped = GroupedTensorWrapper (
415407 num_groups, std::vector<size_t >{num_groups, static_cast <size_t >(out_features)});
416- grouped.set_rowwise_data (bias.data_ptr (), GetTransformerEngineDType (dtype), tensor_shape_1d (bias));
408+ grouped.set_rowwise_data (bias.data_ptr (), GetTransformerEngineDType (dtype),
409+ tensor_shape_1d (bias));
417410 return grouped;
418411}
419412
@@ -498,7 +491,8 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a
498491 } else if (activation == " sreglu" ) {
499492 nvte_sreglu (te_input.data (), te_output.data (), stream);
500493 } else if (activation == " clamped_swiglu" ) {
501- nvte_clamped_swiglu_v2 (te_input.data (), te_output.data (), static_cast <float >(activation_limit),
494+ nvte_clamped_swiglu_v2 (te_input.data (), te_output.data (),
495+ static_cast <float >(activation_limit),
502496 static_cast <float >(activation_alpha),
503497 static_cast <float >(activation_glu_linear_offset), stream);
504498 } else if (activation == " srelu" ) {
@@ -520,8 +514,7 @@ at::Tensor activation_forward_impl(const at::Tensor &input, const std::string &a
520514
521515at::Tensor activation_backward_impl (const at::Tensor &grad, const at::Tensor &input,
522516 const std::string &activation, double activation_limit,
523- double activation_alpha,
524- double activation_glu_linear_offset) {
517+ double activation_alpha, double activation_glu_linear_offset) {
525518 auto output = at::empty_like (input);
526519 auto te_grad = makeTransformerEngineTensor (grad);
527520 auto te_input = makeTransformerEngineTensor (input);
@@ -568,7 +561,7 @@ at::Tensor grouped_mlp_activation_forward(
568561 double activation_alpha, double activation_glu_linear_offset, at::ScalarType dtype) {
569562 auto activation_input = maybe_deinterleave_glu (input, glu_interleave_size);
570563 auto activation_output = activation_forward_impl (activation_input, activation, activation_limit,
571- activation_alpha, activation_glu_linear_offset);
564+ activation_alpha, activation_glu_linear_offset);
572565 if (!act_scales.has_value ()) {
573566 return activation_output;
574567 }
@@ -607,10 +600,9 @@ ActivationBackwardResult grouped_mlp_activation_backward(
607600 }
608601
609602 auto grad_activation_input =
610- activation_backward_impl (grad_activation_output, activation_input, activation, activation_limit,
611- activation_alpha, activation_glu_linear_offset);
612- return {maybe_reinterleave_glu_grad (grad_activation_input, glu_interleave_size),
613- grad_act_scales};
603+ activation_backward_impl (grad_activation_output, activation_input, activation,
604+ activation_limit, activation_alpha, activation_glu_linear_offset);
605+ return {maybe_reinterleave_glu_grad (grad_activation_input, glu_interleave_size), grad_act_scales};
614606}
615607
616608} // namespace
@@ -653,8 +645,7 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
653645 split_sizes, x.device (),
654646 std::vector<int64_t >{1 , in_features, fc1_out_features, fc2_in_features, fc2_out_features},
655647 std::vector<bool >{true , true , true , true , true },
656- std::vector<at::ScalarType>{at::kLong , at::kLong , at::kLong , at::kLong , at::kLong },
657- true );
648+ std::vector<at::ScalarType>{at::kLong , at::kLong , at::kLong , at::kLong , at::kLong }, true );
658649 // splits_to_offsets_multi returns the canonical int64 CUDA split sizes and
659650 // offsets in the same order as the stride list above. The CuTe path also asks
660651 // for int32 split_points, but cuBLAS grouped GEMM does not consume them.
@@ -675,10 +666,9 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
675666 &gemm_resources);
676667 add_grouped_bias (&grouped_fc1_preact, fc1_bias_tensor, num_groups, dtype, fc1_out_features);
677668
678- auto fc2_x =
679- grouped_mlp_activation_forward (fc1_preact, act_scales, activation, glu_interleave_size,
680- activation_limit, activation_alpha,
681- activation_glu_linear_offset, dtype);
669+ auto fc2_x = grouped_mlp_activation_forward (
670+ fc1_preact, act_scales, activation, glu_interleave_size, activation_limit, activation_alpha,
671+ activation_glu_linear_offset, dtype);
682672
683673 std::vector<int64_t > out_shape = input.sizes ().vec ();
684674 out_shape.back () = fc2_out_features;
@@ -692,22 +682,20 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
692682 &gemm_resources);
693683 add_grouped_bias (&grouped_output, fc2_bias_tensor, num_groups, dtype, fc2_out_features);
694684
695- return {output, x, split_sizes_i64, base_offsets, x_offsets, fc1_offsets, fc2_offsets ,
696- output_offsets, fc1_preact, fc2_x};
685+ return {output, x, split_sizes_i64, base_offsets, x_offsets,
686+ fc1_offsets, fc2_offsets, output_offsets, fc1_preact, fc2_x};
697687}
698688
699689py::tuple megacpp_grouped_mlp_backward (
700690 const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets,
701- const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
702- const at::Tensor &fc2_dy_offsets, const at::Tensor &base_offsets,
703- const at::Tensor &x, const at::Tensor &fc1_activation_input, const at::Tensor &fc2_x,
704- const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
705- py::handle fc2_weight,
706- py::handle fc1_wgrad_output, bool fc1_compute_wgrad, bool fc1_accumulate_wgrad,
707- py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad,
708- const std::string &activation, int64_t glu_interleave_size, double activation_limit,
709- double activation_alpha, double activation_glu_linear_offset, bool act_scales_requires_grad,
710- bool input_requires_grad) {
691+ const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
692+ const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
693+ const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
694+ py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
695+ bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
696+ bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
697+ double activation_limit, double activation_alpha, double activation_glu_linear_offset,
698+ bool act_scales_requires_grad, bool input_requires_grad) {
711699 (void )base_offsets;
712700 NVTE_CHECK (grad_output.is_cuda (), " megacpp_grouped_mlp_backward requires CUDA grad_output." );
713701 at::cuda::CUDAGuard device_guard (grad_output.device ());
@@ -737,23 +725,20 @@ py::tuple megacpp_grouped_mlp_backward(
737725 fc2_x_for_wgrad = fc2_x_for_wgrad.view ({-1 , fc2_in_features});
738726 auto grouped_fc2_x_for_wgrad =
739727 make_grouped_tensor (fc2_x_for_wgrad.view ({-1 }), split_sizes, fc2_offsets, fc2_in_features);
740- fc2_wgrads =
741- grouped_gemm_wgrad (&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
742- fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
743- fc2_out_features, fc2_in_features, " fc2_wgrad_output" ,
744- fc2_weights.is_grouped );
728+ fc2_wgrads = grouped_gemm_wgrad (&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
729+ fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
730+ fc2_out_features, fc2_in_features, " fc2_wgrad_output" ,
731+ fc2_weights.is_grouped );
745732 }
746733
747734 auto fc2_dx = at::empty ({total_tokens, fc2_in_features}, dy.options ());
748735 auto grouped_fc2_dx =
749736 make_grouped_tensor (fc2_dx.view ({-1 }), split_sizes, fc2_offsets, fc2_in_features);
750- grouped_gemm_fwd_dgrad (&fc2_weights, false , &grouped_dy, false , &grouped_fc2_dx,
751- &gemm_resources);
737+ grouped_gemm_fwd_dgrad (&fc2_weights, false , &grouped_dy, false , &grouped_fc2_dx, &gemm_resources);
752738
753739 auto activation_grads = grouped_mlp_activation_backward (
754- fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size,
755- activation_limit, activation_alpha, activation_glu_linear_offset, dtype,
756- act_scales_requires_grad);
740+ fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit,
741+ activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad);
757742 auto fc1_dy = activation_grads.grad_input ;
758743 auto grad_act_scales = activation_grads.grad_act_scales ;
759744 auto grouped_fc1_dy =
@@ -766,11 +751,10 @@ py::tuple megacpp_grouped_mlp_backward(
766751 x_for_wgrad = x_for_wgrad.view ({-1 , in_features});
767752 auto grouped_x_for_wgrad =
768753 make_grouped_tensor (x_for_wgrad.view ({-1 }), split_sizes, x_offsets, in_features);
769- fc1_wgrads =
770- grouped_gemm_wgrad (&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
771- fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
772- fc1_out_features, in_features, " fc1_wgrad_output" ,
773- fc1_weights.is_grouped );
754+ fc1_wgrads = grouped_gemm_wgrad (&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
755+ fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
756+ fc1_out_features, in_features, " fc1_wgrad_output" ,
757+ fc1_weights.is_grouped );
774758 }
775759
776760 at::Tensor grad_input;
@@ -779,8 +763,8 @@ py::tuple megacpp_grouped_mlp_backward(
779763 grad_input_shape.back () = in_features;
780764 grad_input = at::empty (grad_input_shape, dy.options ());
781765 auto grad_input_2d = grad_input.view ({-1 , in_features});
782- auto grouped_grad_input = make_grouped_tensor (grad_input_2d. view ({- 1 }), split_sizes,
783- x_offsets, in_features);
766+ auto grouped_grad_input =
767+ make_grouped_tensor (grad_input_2d. view ({- 1 }), split_sizes, x_offsets, in_features);
784768 grouped_gemm_fwd_dgrad (&fc1_weights, false , &grouped_fc1_dy, false , &grouped_grad_input,
785769 &gemm_resources);
786770 } else {
0 commit comments