Skip to content

Commit 4bd2f5b

Browse files
committed
pytorch: guard cusolvermp/newton_schulz pybind + decl on NVTE_WITH_CUSOLVERMP
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent 93ee116 commit 4bd2f5b

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -631,12 +631,14 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at:
631631
* Newton-Schulz (cuSolverMp)
632632
**************************************************************************************************/
633633

634+
#ifdef NVTE_WITH_CUSOLVERMP
634635
int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank);
635636

636637
void cusolvermp_ctx_destroy(int64_t ctx_ptr);
637638

638639
void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations,
639640
std::vector<float> coefficients);
641+
#endif // NVTE_WITH_CUSOLVERMP
640642

641643
} // namespace transformer_engine::pytorch
642644

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -589,6 +589,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
589589
&transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda,
590590
"Fused compute E8M0 scale_inv from amax", py::call_guard<py::gil_scoped_release>());
591591

592+
#ifdef NVTE_WITH_CUSOLVERMP
592593
// Newton-Schulz (cuSolverMp)
593594
m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create,
594595
"Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"),
@@ -599,6 +600,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
599600
"Newton-Schulz matrix orthogonalization", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"),
600601
py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"),
601602
py::call_guard<py::gil_scoped_release>());
603+
#endif // NVTE_WITH_CUSOLVERMP
602604

603605
// Comm+GEMM Overlap
604606
m.def("bulk_overlap_ag_with_external_gemm",

0 commit comments

Comments
 (0)