@@ -256,8 +256,7 @@ struct GroupedGemmResources {
256256};
257257
258258GroupedGemmResources make_grouped_mlp_backend_resources (const c10::Device &device,
259- size_t num_groups,
260- py::handle scratch) {
259+ size_t num_groups, py::handle scratch) {
261260 // Keep the backend resource policy private to megacpp. Today this is cuBLAS
262261 // grouped GEMM scratch; future backends can change this helper without
263262 // changing the Python or pybind contract.
@@ -275,9 +274,8 @@ void grouped_gemm(GroupedTensorWrapper *A, bool transa, GroupedTensorWrapper *B,
275274 });
276275}
277276
278- std::vector<at::Tensor> output_tensor_list_from_arg (py::handle arg, size_t num_groups,
279- int64_t rows, int64_t cols,
280- const std::string &name) {
277+ std::vector<at::Tensor> output_tensor_list_from_arg (py::handle arg, size_t num_groups, int64_t rows,
278+ int64_t cols, const std::string &name) {
281279 std::vector<at::Tensor> out;
282280 if (is_none (arg)) {
283281 return out;
@@ -327,8 +325,7 @@ WgradOutput wgrad_output_from_arg(py::handle arg, bool compute_wgrad, size_t num
327325 // allocation. Single grouped weight keeps this packed as [G, N, K];
328326 // discrete weights split the same packed allocation into per-expert views.
329327 out.packed =
330- at::empty ({static_cast <int64_t >(num_groups), rows, cols},
331- at::device (device).dtype (dtype));
328+ at::empty ({static_cast <int64_t >(num_groups), rows, cols}, at::device (device).dtype (dtype));
332329 out.owns_storage = true ;
333330 out.is_grouped = prefer_grouped_output;
334331 if (out.is_grouped ) {
@@ -389,14 +386,12 @@ void grouped_gemm_fwd_dgrad(GroupedWeightArg *weights, bool trans_weight,
389386}
390387
391388std::vector<at::Tensor> grouped_gemm_wgrad (GroupedTensorWrapper *x, GroupedTensorWrapper *dy,
392- py::handle output, bool compute_wgrad, bool accumulate,
393- GroupedGemmResources *resources,
394- at::ScalarType dtype, int64_t rows,
395- int64_t cols, const std::string &name,
396- bool prefer_grouped_output) {
397- auto prepared =
398- wgrad_output_from_arg (output, compute_wgrad, resources->num_groups , dtype,
399- resources->device , rows, cols, name, prefer_grouped_output);
389+ py::handle output, bool compute_wgrad, bool accumulate,
390+ GroupedGemmResources *resources, at::ScalarType dtype,
391+ int64_t rows, int64_t cols, const std::string &name,
392+ bool prefer_grouped_output) {
393+ auto prepared = wgrad_output_from_arg (output, compute_wgrad, resources->num_groups , dtype,
394+ resources->device , rows, cols, name, prefer_grouped_output);
400395 NVTE_CHECK (!(prepared.owns_storage && accumulate), name,
401396 " cannot accumulate into a newly allocated wgrad buffer." );
402397 std::vector<at::Tensor> returned_wgrads;
@@ -643,8 +638,7 @@ ActivationBackwardResult grouped_mlp_activation_backward(
643638
644639std::vector<at::Tensor> megacpp_grouped_mlp_forward (
645640 const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes,
646- py::handle fc1_weight,
647- py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
641+ py::handle fc1_weight, py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
648642 const std::optional<at::Tensor> &act_scales, const std::string &activation,
649643 int64_t glu_interleave_size, double activation_limit, double activation_alpha,
650644 double activation_glu_linear_offset, py::handle gemm_scratch) {
@@ -712,8 +706,7 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
712706 std::vector<int64_t > out_shape = input.sizes ().vec ();
713707 out_shape.back () = fc2_out_features;
714708 auto output = at::empty (out_shape, x.options ());
715- auto grouped_fc2_x =
716- make_grouped_tensor (fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features);
709+ auto grouped_fc2_x = make_grouped_tensor (fc2_x, split_sizes_i64, fc2_offsets, fc2_in_features);
717710 auto grouped_output =
718711 make_grouped_tensor (output, split_sizes_i64, output_offsets, fc2_out_features);
719712 grouped_gemm_fwd_dgrad (&fc2_weights, true , &grouped_fc2_x, false , &grouped_output,
@@ -726,15 +719,15 @@ std::vector<at::Tensor> megacpp_grouped_mlp_forward(
726719
727720py::tuple megacpp_grouped_mlp_backward (
728721 const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes,
729- const at::Tensor &x_offsets,
730- const at::Tensor &fc1_offsets , const at::Tensor &fc2_offsets , const at::Tensor &fc2_dy_offsets ,
731- const at::Tensor &base_offsets , const at::Tensor &x, const at::Tensor &fc1_activation_input ,
732- const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
733- py::handle fc2_weight, py::handle fc1_wgrad_output , bool fc1_compute_wgrad ,
734- bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
735- bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
736- double activation_limit , double activation_alpha, double activation_glu_linear_offset ,
737- bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch) {
722+ const at::Tensor &x_offsets, const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets,
723+ const at::Tensor &fc2_dy_offsets , const at::Tensor &base_offsets , const at::Tensor &x ,
724+ const at::Tensor &fc1_activation_input , const at::Tensor &fc2_x ,
725+ const std::optional<at::Tensor> &act_scales, py::handle fc1_weight, py::handle fc2_weight ,
726+ py::handle fc1_wgrad_output, bool fc1_compute_wgrad , bool fc1_accumulate_wgrad ,
727+ py::handle fc2_wgrad_output, bool fc2_compute_wgrad, bool fc2_accumulate_wgrad ,
728+ const std::string &activation, int64_t glu_interleave_size, double activation_limit ,
729+ double activation_alpha , double activation_glu_linear_offset, bool act_scales_requires_grad ,
730+ bool input_requires_grad, py::handle gemm_scratch) {
738731 (void )base_offsets;
739732 NVTE_CHECK (grad_output.is_cuda (), " megacpp_grouped_mlp_backward requires CUDA grad_output." );
740733 at::cuda::CUDAGuard device_guard (grad_output.device ());
@@ -769,23 +762,21 @@ py::tuple megacpp_grouped_mlp_backward(
769762 auto grouped_fc2_x_for_wgrad =
770763 make_grouped_tensor (fc2_x_for_wgrad, split_sizes, fc2_offsets, fc2_in_features);
771764 fc2_wgrads = grouped_gemm_wgrad (&grouped_fc2_x_for_wgrad, &grouped_dy, fc2_wgrad_output,
772- fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources,
773- dtype, fc2_out_features, fc2_in_features, " fc2_wgrad_output" ,
765+ fc2_compute_wgrad, fc2_accumulate_wgrad, &gemm_resources, dtype,
766+ fc2_out_features, fc2_in_features, " fc2_wgrad_output" ,
774767 fc2_weights.is_grouped );
775768 }
776769
777770 auto fc2_dx = at::empty ({total_tokens, fc2_in_features}, dy.options ());
778- auto grouped_fc2_dx =
779- make_grouped_tensor (fc2_dx, split_sizes, fc2_offsets, fc2_in_features);
771+ auto grouped_fc2_dx = make_grouped_tensor (fc2_dx, split_sizes, fc2_offsets, fc2_in_features);
780772 grouped_gemm_fwd_dgrad (&fc2_weights, false , &grouped_dy, false , &grouped_fc2_dx, &gemm_resources);
781773
782774 auto activation_grads = grouped_mlp_activation_backward (
783775 fc2_dx, fc1_activation_input, act_scales, activation, glu_interleave_size, activation_limit,
784776 activation_alpha, activation_glu_linear_offset, dtype, act_scales_requires_grad);
785777 auto fc1_dy = activation_grads.grad_input ;
786778 auto grad_act_scales = activation_grads.grad_act_scales ;
787- auto grouped_fc1_dy =
788- make_grouped_tensor (fc1_dy, split_sizes, fc1_offsets, fc1_out_features);
779+ auto grouped_fc1_dy = make_grouped_tensor (fc1_dy, split_sizes, fc1_offsets, fc1_out_features);
789780
790781 std::vector<at::Tensor> fc1_wgrads;
791782 if (fc1_compute_wgrad) {
@@ -794,8 +785,8 @@ py::tuple megacpp_grouped_mlp_backward(
794785 auto grouped_x_for_wgrad =
795786 make_grouped_tensor (x_for_wgrad, split_sizes, x_offsets, in_features);
796787 fc1_wgrads = grouped_gemm_wgrad (&grouped_x_for_wgrad, &grouped_fc1_dy, fc1_wgrad_output,
797- fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources,
798- dtype, fc1_out_features, in_features, " fc1_wgrad_output" ,
788+ fc1_compute_wgrad, fc1_accumulate_wgrad, &gemm_resources, dtype,
789+ fc1_out_features, in_features, " fc1_wgrad_output" ,
799790 fc1_weights.is_grouped );
800791 }
801792
@@ -804,8 +795,7 @@ py::tuple megacpp_grouped_mlp_backward(
804795 std::vector<int64_t > grad_input_shape = grad_output.sizes ().vec ();
805796 grad_input_shape.back () = in_features;
806797 grad_input = at::empty (grad_input_shape, dy.options ());
807- auto grouped_grad_input =
808- make_grouped_tensor (grad_input, split_sizes, x_offsets, in_features);
798+ auto grouped_grad_input = make_grouped_tensor (grad_input, split_sizes, x_offsets, in_features);
809799 grouped_gemm_fwd_dgrad (&fc1_weights, false , &grouped_fc1_dy, false , &grouped_grad_input,
810800 &gemm_resources);
811801 } else {
0 commit comments