Skip to content

Commit 72c5a63

Browse files
committed
micro optimizations
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
1 parent c9d56c4 commit 72c5a63

5 files changed

Lines changed: 171 additions & 78 deletions

File tree

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,22 +190,24 @@ py::object te_general_grouped_gemm_for_discrete_out(py::handle A, bool transa, p
190190
**************************************************************************************************/
191191

192192
std::vector<at::Tensor> megacpp_grouped_mlp_forward(
193-
const at::Tensor &input, const at::Tensor &split_sizes, py::handle fc1_weight,
193+
const at::Tensor &input, at::ScalarType act_dtype, const at::Tensor &split_sizes,
194+
py::handle fc1_weight,
194195
py::handle fc1_bias, py::handle fc2_weight, py::handle fc2_bias,
195196
const std::optional<at::Tensor> &act_scales, const std::string &activation,
196197
int64_t glu_interleave_size, double activation_limit, double activation_alpha,
197-
double activation_glu_linear_offset);
198+
double activation_glu_linear_offset, py::handle gemm_scratch);
198199

199200
py::tuple megacpp_grouped_mlp_backward(
200-
const at::Tensor &grad_output, const at::Tensor &split_sizes, const at::Tensor &x_offsets,
201+
const at::Tensor &grad_output, at::ScalarType act_dtype, const at::Tensor &split_sizes,
202+
const at::Tensor &x_offsets,
201203
const at::Tensor &fc1_offsets, const at::Tensor &fc2_offsets, const at::Tensor &fc2_dy_offsets,
202204
const at::Tensor &base_offsets, const at::Tensor &x, const at::Tensor &fc1_activation_input,
203205
const at::Tensor &fc2_x, const std::optional<at::Tensor> &act_scales, py::handle fc1_weight,
204206
py::handle fc2_weight, py::handle fc1_wgrad_output, bool fc1_compute_wgrad,
205207
bool fc1_accumulate_wgrad, py::handle fc2_wgrad_output, bool fc2_compute_wgrad,
206208
bool fc2_accumulate_wgrad, const std::string &activation, int64_t glu_interleave_size,
207209
double activation_limit, double activation_alpha, double activation_glu_linear_offset,
208-
bool act_scales_requires_grad, bool input_requires_grad);
210+
bool act_scales_requires_grad, bool input_requires_grad, py::handle gemm_scratch);
209211

210212
/***************************************************************************************************
211213
* Transpose

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
358358
&transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out,
359359
"Grouped GEMM for discrete output list");
360360
m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward,
361-
"Mega C++ grouped MLP forward", py::arg("input"), py::arg("split_sizes"),
362-
py::arg("fc1_weight"), py::arg("fc1_bias"), py::arg("fc2_weight"), py::arg("fc2_bias"),
361+
"Mega C++ grouped MLP forward", py::arg("input"), py::arg("act_dtype"),
362+
py::arg("split_sizes"), py::arg("fc1_weight"), py::arg("fc1_bias"),
363+
py::arg("fc2_weight"), py::arg("fc2_bias"),
363364
py::arg("act_scales"), py::arg("activation"), py::arg("glu_interleave_size"),
364365
py::arg("activation_limit") = 0.0, py::arg("activation_alpha") = 0.0,
365-
py::arg("activation_glu_linear_offset") = 0.0);
366+
py::arg("activation_glu_linear_offset") = 0.0,
367+
py::arg("gemm_scratch") = py::none());
366368
m.def("megacpp_grouped_mlp_backward", &transformer_engine::pytorch::megacpp_grouped_mlp_backward,
367-
"Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("split_sizes"),
368-
py::arg("x_offsets"), py::arg("fc1_offsets"), py::arg("fc2_offsets"),
369-
py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"),
369+
"Mega C++ grouped MLP backward", py::arg("grad_output"), py::arg("act_dtype"),
370+
py::arg("split_sizes"), py::arg("x_offsets"), py::arg("fc1_offsets"),
371+
py::arg("fc2_offsets"), py::arg("fc2_dy_offsets"), py::arg("base_offsets"), py::arg("x"),
370372
py::arg("fc1_activation_input"), py::arg("fc2_x"), py::arg("act_scales"),
371373
py::arg("fc1_weight"), py::arg("fc2_weight"), py::arg("fc1_wgrad_output"),
372374
py::arg("fc1_compute_wgrad"), py::arg("fc1_accumulate_wgrad"), py::arg("fc2_wgrad_output"),
373375
py::arg("fc2_compute_wgrad"), py::arg("fc2_accumulate_wgrad"), py::arg("activation"),
374376
py::arg("glu_interleave_size"), py::arg("activation_limit") = 0.0,
375377
py::arg("activation_alpha") = 0.0, py::arg("activation_glu_linear_offset") = 0.0,
376-
py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true);
378+
py::arg("act_scales_requires_grad") = true, py::arg("input_requires_grad") = true,
379+
py::arg("gemm_scratch") = py::none());
377380
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
378381
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
379382
py::call_guard<py::gil_scoped_release>());

0 commit comments

Comments
 (0)