Skip to content

Commit ee517d3

Browse files
committed
linting
Signed-off-by: Alp Dener <adener@nvidia.com>
1 parent d79bf21 commit ee517d3

4 files changed

Lines changed: 7 additions & 5 deletions

File tree

tests/pytorch/distributed/run_gemm_with_overlap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
MXFP8Quantizer,
2525
)
2626
import transformer_engine.pytorch.cpp_extensions as tex
27-
from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
27+
2828
from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
2929

3030
warnings.filterwarnings("ignore", category=DeprecationWarning)

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,11 @@ CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size
8484
"Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1.");
8585
_with_cublasmp = true;
8686

87-
nvte_comm_gemm_ctx_create(reinterpret_cast<ncclComm_t>(nccl_comm_ptr), tp_size, tp_rank);
87+
_cublasmp_ctx = nvte_comm_gemm_ctx_create(reinterpret_cast<ncclComm_t>(nccl_comm_ptr), tp_size,
88+
tp_rank);
8889

90+
_tp_id = tp_rank;
91+
_tp_size = tp_size;
8992
_num_comm_sm = num_comm_sm;
9093
_is_p2p = is_p2p;
9194
_atomic_gemm = atomic_gemm;

transformer_engine/pytorch/module/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,6 @@ def add_ub(
331331
comm_priority: int = 0,
332332
gemm_priority: int = 0,
333333
pipeline_rs_overlap_first_gemm: bool = False,
334-
with_cublasmp: bool = False,
335334
) -> None:
336335
if atomic_gemm:
337336
warnings.warn(
@@ -506,7 +505,7 @@ def fill_userbuffers_buffer_for_all_gather(
506505
"""
507506
# cuBlasMp already handles its own buffer filling and quantization factors
508507
if comm.with_cublasmp():
509-
return
508+
return local_tensor, local_tensor
510509

511510
# Tensor dimensions
512511
local_shape = local_tensor.size()

0 commit comments

Comments
 (0)