Skip to content

Commit f8261d3

Browse files
committed
cleaned public facing C++ api for CommOverlapCore
Signed-off-by: Alp Dener <adener@nvidia.com>
1 parent 94e90b9 commit f8261d3

4 files changed

Lines changed: 68 additions & 77 deletions

File tree

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 42 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -386,50 +386,8 @@ CublasMpDims compute_rs_dims(const TensorWrapper &A, bool transa, const TensorWr
386386
return {m, n, k};
387387
}
388388

389-
CublasMpDims compute_ar_dims(const TensorWrapper &A, bool transa, const TensorWrapper &B,
390-
bool transb, int tp_size) {
391-
// AR shares the same m/n/k semantics as RS at descriptor level.
392-
return compute_rs_dims(A, transa, B, transb, tp_size);
393-
}
394-
395389
} // namespace
396390

397-
void CommOverlapCore::cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B,
398-
bool transb, TensorWrapper &D, TensorWrapper &bias,
399-
TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
400-
cudaStream_t stream_main) {
401-
auto [m, n, k] = compute_ag_dims(A, transa, B, transb, _tp_size);
402-
// col-major GEMM compute overlapped with all-gather on input B
403-
// (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
404-
nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
405-
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
406-
stream_main, _algo_type);
407-
}
408-
409-
void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
410-
bool transb, TensorWrapper &D, TensorWrapper &bias,
411-
TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
412-
cudaStream_t stream_main) {
413-
auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size);
414-
// col-major GEMM compute overlapped with reduce-scatter on the output
415-
// (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
416-
nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
417-
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
418-
stream_main, _algo_type);
419-
}
420-
421-
void CommOverlapCore::cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B,
422-
bool transb, TensorWrapper &D, TensorWrapper &bias,
423-
TensorWrapper &pre_gelu_out, bool grad, bool accumulate,
424-
cudaStream_t stream_main) {
425-
auto [m, n, k] = compute_ar_dims(A, transa, B, transb, _tp_size);
426-
// col-major GEMM compute overlapped with all-reduce on the output
427-
// (M, K/P) x (K/P, N) = (M, N) -(AR)-> (M, N)
428-
nvte_gemm_all_reduce(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
429-
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
430-
stream_main, _algo_type);
431-
}
432-
433391
/***************************************************************************************************
434392
* Comm+GEMM Overlap Base (Pipelined / Collective)
435393
**************************************************************************************************/
@@ -549,8 +507,13 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa
549507
bool use_split_accumulator, TensorWrapper &rs_output,
550508
cudaStream_t stream_main) {
551509
if (_with_cublasmp) {
552-
return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
553-
stream_main);
510+
auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size);
511+
// col-major GEMM compute overlapped with reduce-scatter on the output
512+
// (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
513+
nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
514+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
515+
stream_main, _algo_type);
516+
return;
554517
}
555518

556519
int ori_sms = _ub_comm->sms;
@@ -651,8 +614,13 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
651614
bool grad, bool accumulate, bool use_split_accumulator,
652615
TensorWrapper &rs_output, cudaStream_t stream_main) {
653616
if (_with_cublasmp) {
654-
return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
655-
stream_main);
617+
auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size);
618+
// col-major GEMM compute overlapped with reduce-scatter on the output
619+
// (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
620+
nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
621+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
622+
stream_main, _algo_type);
623+
return;
656624
}
657625

658626
// Get GEMM dimensions
@@ -968,8 +936,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(
968936
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
969937
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
970938
if (_with_cublasmp) {
971-
return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
972-
stream_main);
939+
auto [m, n, k] = compute_ag_dims(A, transa, B, transb, _tp_size);
940+
// col-major GEMM compute overlapped with all-gather on input B
941+
// (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
942+
nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
943+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
944+
stream_main, _algo_type);
945+
return;
973946
}
974947

975948
int ori_sms = _ub_comm->sms;
@@ -1075,8 +1048,13 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
10751048
bool use_split_accumulator, TensorWrapper &B_copy,
10761049
cudaStream_t stream_main) {
10771050
if (_with_cublasmp) {
1078-
return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1079-
stream_main);
1051+
auto [m, n, k] = compute_ag_dims(A, transa, B, transb, _tp_size);
1052+
// col-major GEMM compute overlapped with all-gather on input B
1053+
// (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N)
1054+
nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
1055+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
1056+
stream_main, _algo_type);
1057+
return;
10801058
}
10811059

10821060
int ori_sms = _ub_comm->sms;
@@ -1247,8 +1225,13 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(
12471225
bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
12481226
cudaStream_t stream_main) {
12491227
if (_with_cublasmp) {
1250-
return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1251-
stream_main);
1228+
auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size);
1229+
// col-major GEMM compute overlapped with reduce-scatter on the output
1230+
// (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
1231+
nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
1232+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
1233+
stream_main, _algo_type);
1234+
return;
12521235
}
12531236

12541237
int ori_sms = _ub_comm->sms;
@@ -1316,8 +1299,13 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
13161299
bool use_split_accumulator, TensorWrapper &rs_output,
13171300
cudaStream_t stream_main) {
13181301
if (_with_cublasmp) {
1319-
return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate,
1320-
stream_main);
1302+
auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size);
1303+
// col-major GEMM compute overlapped with reduce-scatter on the output
1304+
// (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P)
1305+
nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(),
1306+
pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm,
1307+
stream_main, _algo_type);
1308+
return;
13211309
}
13221310

13231311
int ori_sms = _ub_comm->sms;

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,6 @@ class CommOverlapCore {
128128

129129
bool with_cublasmp() { return _with_cublasmp; }
130130

131-
void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
132-
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
133-
bool grad, bool accumulate, cudaStream_t stream_main);
134-
135-
void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
136-
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
137-
bool grad, bool accumulate, cudaStream_t stream_main);
138-
139-
void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
140-
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
141-
bool grad, bool accumulate, cudaStream_t stream_main);
142-
143131
virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
144132
bool transb, TensorWrapper &D, TensorWrapper &bias,
145133
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,

transformer_engine/jax/csrc/extensions/gemm.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,17 @@ Error_Type GemmInitV2FFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type
170170
std::vector<size_t>{static_cast<size_t>(bias.element_count())});
171171
}
172172
TensorWrapper pre_gelu_out_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING));
173+
TensorWrapper dummy;
173174
// Match GemmV2FFI's operand swap: rhs becomes A, lhs becomes B.
174175
cudaStream_t prepare_stream = cudaStreamPerThread;
175176
if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) {
176-
executor->cublasmp_ag_gemm(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_,
177-
bias_, pre_gelu_out_, false /*grad*/, false /*accumulate*/,
178-
prepare_stream);
177+
executor->split_overlap_ag(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_,
178+
bias_, pre_gelu_out_, dummy, false /*grad*/, false /*accumulate*/,
179+
false /*use_split_accumulator*/, dummy, prepare_stream);
179180
} else if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
180-
executor->cublasmp_gemm_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_,
181-
bias_, pre_gelu_out_, false /*grad*/, false /*accumulate*/,
182-
prepare_stream);
181+
executor->split_overlap_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_,
182+
bias_, pre_gelu_out_, dummy, false /*grad*/, false /*accumulate*/,
183+
false /*use_split_accumulator*/, dummy, prepare_stream);
183184
}
184185
NVTE_CHECK_CUDA(cudaStreamSynchronize(prepare_stream));
185186
}

transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,18 +274,32 @@ void cublasmp_capture_warmup(te::CommOverlapCore *core, int tp_size, te::CommOve
274274
NVTE_CHECK_CUDA(cudaMemset(a_ptr, 0, a_bytes));
275275
NVTE_CHECK_CUDA(cudaMemset(b_ptr, 0, b_bytes));
276276

277-
te::TensorWrapper A_tw, B_tw, D_tw, bias_tw, pre_gelu_tw;
277+
te::TensorWrapper A_tw, B_tw, D_tw, bias_tw, pre_gelu_tw, dummy;
278278
A_tw.set_rowwise_data(a_ptr, te::DType::kBFloat16, a_shape);
279279
B_tw.set_rowwise_data(b_ptr, te::DType::kBFloat16, b_shape);
280280
D_tw.set_rowwise_data(d_ptr, te::DType::kBFloat16, d_shape);
281281

282282
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
283283
if (comm_type == te::CommOverlapType::AG) {
284-
core->cublasmp_ag_gemm(A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw,
285-
pre_gelu_tw, /*grad=*/false, /*accumulate=*/false, stream);
284+
if (core->is_atomic_gemm()) {
285+
core->atomic_gemm_overlap_ag(
286+
A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, pre_gelu_tw, dummy,
287+
/*grad=*/false, /*accumulate=*/false, /*use_split_accumulator=*/false, dummy, stream);
288+
} else {
289+
core->split_overlap_ag(
290+
A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, pre_gelu_tw, dummy,
291+
/*grad=*/false, /*accumulate=*/false, /*use_split_accumulator=*/false, dummy, stream);
292+
}
286293
} else {
287-
core->cublasmp_gemm_rs(A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw,
288-
pre_gelu_tw, /*grad=*/false, /*accumulate=*/false, stream);
294+
if (core->is_atomic_gemm()) {
295+
core->atomic_gemm_overlap_rs(
296+
A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, pre_gelu_tw, dummy,
297+
/*grad=*/false, /*accumulate=*/false, /*use_split_accumulator=*/false, dummy, stream);
298+
} else {
299+
core->split_overlap_rs(
300+
A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, pre_gelu_tw, dummy,
301+
/*grad=*/false, /*accumulate=*/false, /*use_split_accumulator=*/false, dummy, stream);
302+
}
289303
}
290304
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
291305
cudaFree(a_ptr);

0 commit comments

Comments
 (0)