@@ -358,22 +358,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
358358 &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out,
359359 " Grouped GEMM for discrete output list" );
360360 m.def (" megacpp_grouped_mlp_forward" , &transformer_engine::pytorch::megacpp_grouped_mlp_forward,
361- " Mega C++ grouped MLP forward" , py::arg (" input" ), py::arg (" split_sizes" ),
362- py::arg (" fc1_weight" ), py::arg (" fc1_bias" ), py::arg (" fc2_weight" ), py::arg (" fc2_bias" ),
361+ " Mega C++ grouped MLP forward" , py::arg (" input" ), py::arg (" act_dtype" ),
362+ py::arg (" split_sizes" ), py::arg (" fc1_weight" ), py::arg (" fc1_bias" ),
363+ py::arg (" fc2_weight" ), py::arg (" fc2_bias" ),
363364 py::arg (" act_scales" ), py::arg (" activation" ), py::arg (" glu_interleave_size" ),
364365 py::arg (" activation_limit" ) = 0.0 , py::arg (" activation_alpha" ) = 0.0 ,
365- py::arg (" activation_glu_linear_offset" ) = 0.0 );
366+ py::arg (" activation_glu_linear_offset" ) = 0.0 ,
367+ py::arg (" gemm_scratch" ) = py::none ());
366368 m.def (" megacpp_grouped_mlp_backward" , &transformer_engine::pytorch::megacpp_grouped_mlp_backward,
367- " Mega C++ grouped MLP backward" , py::arg (" grad_output" ), py::arg (" split_sizes " ),
368- py::arg (" x_offsets " ), py::arg (" fc1_offsets " ), py::arg (" fc2_offsets " ),
369- py::arg (" fc2_dy_offsets" ), py::arg (" base_offsets" ), py::arg (" x" ),
369+ " Mega C++ grouped MLP backward" , py::arg (" grad_output" ), py::arg (" act_dtype " ),
370+ py::arg (" split_sizes " ), py::arg (" x_offsets " ), py::arg (" fc1_offsets " ),
371+ py::arg (" fc2_offsets " ), py::arg ( " fc2_dy_offsets" ), py::arg (" base_offsets" ), py::arg (" x" ),
370372 py::arg (" fc1_activation_input" ), py::arg (" fc2_x" ), py::arg (" act_scales" ),
371373 py::arg (" fc1_weight" ), py::arg (" fc2_weight" ), py::arg (" fc1_wgrad_output" ),
372374 py::arg (" fc1_compute_wgrad" ), py::arg (" fc1_accumulate_wgrad" ), py::arg (" fc2_wgrad_output" ),
373375 py::arg (" fc2_compute_wgrad" ), py::arg (" fc2_accumulate_wgrad" ), py::arg (" activation" ),
374376 py::arg (" glu_interleave_size" ), py::arg (" activation_limit" ) = 0.0 ,
375377 py::arg (" activation_alpha" ) = 0.0 , py::arg (" activation_glu_linear_offset" ) = 0.0 ,
376- py::arg (" act_scales_requires_grad" ) = true , py::arg (" input_requires_grad" ) = true );
378+ py::arg (" act_scales_requires_grad" ) = true , py::arg (" input_requires_grad" ) = true ,
379+ py::arg (" gemm_scratch" ) = py::none ());
377380 m.def (" fp8_transpose" , &transformer_engine::pytorch::fp8_transpose, " Transpose with FP8 I/O" ,
378381 py::arg (" input" ), py::arg (" dtype" ), py::kw_only (), py::arg (" out" ),
379382 py::call_guard<py::gil_scoped_release>());
0 commit comments