@@ -386,50 +386,8 @@ CublasMpDims compute_rs_dims(const TensorWrapper &A, bool transa, const TensorWr
386386 return {m, n, k};
387387}
388388
389- CublasMpDims compute_ar_dims (const TensorWrapper &A, bool transa, const TensorWrapper &B,
390- bool transb, int tp_size) {
391- // AR shares the same m/n/k semantics as RS at descriptor level.
392- return compute_rs_dims (A, transa, B, transb, tp_size);
393- }
394-
395389} // namespace
396390
397- void CommOverlapCore::cublasmp_ag_gemm (const TensorWrapper &A, bool transa, const TensorWrapper &B,
398- bool transb, TensorWrapper &D, TensorWrapper &bias,
399- TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
400- cudaStream_t stream_main) {
401- auto [m, n, k] = compute_ag_dims (A, transa, B, transb, _tp_size);
402- // col-major GEMM compute overlapped with all-gather on input B
403- // (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
404- nvte_all_gather_gemm (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
405- pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
406- stream_main, _algo_type);
407- }
408-
409- void CommOverlapCore::cublasmp_gemm_rs (const TensorWrapper &A, bool transa, const TensorWrapper &B,
410- bool transb, TensorWrapper &D, TensorWrapper &bias,
411- TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
412- cudaStream_t stream_main) {
413- auto [m, n, k] = compute_rs_dims (A, transa, B, transb, _tp_size);
414- // col-major GEMM compute overlapped with reduce-scatter on the output
415- // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
416- nvte_gemm_reduce_scatter (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
417- pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
418- stream_main, _algo_type);
419- }
420-
421- void CommOverlapCore::cublasmp_gemm_ar (const TensorWrapper &A, bool transa, const TensorWrapper &B,
422- bool transb, TensorWrapper &D, TensorWrapper &bias,
423- TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
424- cudaStream_t stream_main) {
425- auto [m, n, k] = compute_ar_dims (A, transa, B, transb, _tp_size);
426- // col-major GEMM compute overlapped with all-reduce on the output
427- // (M, K/P) x (K/P, N) = (M, N) -(AR)-> (M, N)
428- nvte_gemm_all_reduce (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
429- pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
430- stream_main, _algo_type);
431- }
432-
433391/* **************************************************************************************************
434392 * Comm+GEMM Overlap Base (Pipelined / Collective)
435393 **************************************************************************************************/
@@ -549,8 +507,13 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa
549507 bool use_split_accumulator, TensorWrapper &rs_output,
550508 cudaStream_t stream_main) {
551509 if (_with_cublasmp) {
552- return cublasmp_gemm_rs (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
553- stream_main);
510+ auto [m, n, k] = compute_rs_dims (A, transa, B, transb, _tp_size);
511+ // col-major GEMM compute overlapped with reduce-scatter on the output
512+ // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
513+ nvte_gemm_reduce_scatter (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
514+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
515+ stream_main, _algo_type);
516+ return ;
554517 }
555518
556519 int ori_sms = _ub_comm->sms ;
@@ -651,8 +614,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
651614 bool grad, bool accumulate, bool use_split_accumulator,
652615 TensorWrapper &rs_output, cudaStream_t stream_main) {
653616 if (_with_cublasmp) {
654- return cublasmp_gemm_rs (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
655- stream_main);
617+ auto [m, n, k] = compute_rs_dims (A, transa, B, transb, _tp_size);
618+ // col-major GEMM compute overlapped with reduce-scatter on the output
619+ // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
620+ nvte_gemm_reduce_scatter (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
621+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
622+ stream_main, _algo_type);
623+ return ;
656624 }
657625
658626 // Get GEMM dimensions
@@ -968,8 +936,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
968936 TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
969937 bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
970938 if (_with_cublasmp) {
971- return cublasmp_ag_gemm (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
972- stream_main);
939+ auto [m, n, k] = compute_ag_dims (A, transa, B, transb, _tp_size);
940+ // col-major GEMM compute overlapped with all-gather on input B
941+ // (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
942+ nvte_all_gather_gemm (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
943+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
944+ stream_main, _algo_type);
945+ return ;
973946 }
974947
975948 int ori_sms = _ub_comm->sms ;
@@ -1075,8 +1048,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
10751048 bool use_split_accumulator, TensorWrapper &B_copy,
10761049 cudaStream_t stream_main) {
10771050 if (_with_cublasmp) {
1078- return cublasmp_ag_gemm (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1079- stream_main);
1051+ auto [m, n, k] = compute_ag_dims (A, transa, B, transb, _tp_size);
1052+ // col-major GEMM compute overlapped with all-gather on input B
1053+ // (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
1054+ nvte_all_gather_gemm (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
1055+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
1056+ stream_main, _algo_type);
1057+ return ;
10801058 }
10811059
10821060 int ori_sms = _ub_comm->sms ;
@@ -1247,8 +1225,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
12471225 bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
12481226 cudaStream_t stream_main) {
12491227 if (_with_cublasmp) {
1250- return cublasmp_gemm_rs (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1251- stream_main);
1228+ auto [m, n, k] = compute_rs_dims (A, transa, B, transb, _tp_size);
1229+ // col-major GEMM compute overlapped with reduce-scatter on the output
1230+ // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
1231+ nvte_gemm_reduce_scatter (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
1232+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
1233+ stream_main, _algo_type);
1234+ return ;
12521235 }
12531236
12541237 int ori_sms = _ub_comm->sms ;
@@ -1316,8 +1299,13 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
13161299 bool use_split_accumulator, TensorWrapper &rs_output,
13171300 cudaStream_t stream_main) {
13181301 if (_with_cublasmp) {
1319- return cublasmp_gemm_rs (A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1320- stream_main);
1302+ auto [m, n, k] = compute_rs_dims (A, transa, B, transb, _tp_size);
1303+ // col-major GEMM compute overlapped with reduce-scatter on the output
1304+ // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
1305+ nvte_gemm_reduce_scatter (_cublasmp_ctx, m, n, k, A.data (), B.data (), D.data (), bias.data (),
1306+ pre_gelu_out.data (), transa, transb, grad, accumulate, _num_comm_sm,
1307+ stream_main, _algo_type);
1308+ return ;
13211309 }
13221310
13231311 int ori_sms = _ub_comm->sms ;
0 commit comments