|
14 | 14 |
|
15 | 15 | #include "algorithm/types.hpp" |
16 | 16 | #include "chase_mpi_matrices.hpp" |
| 17 | +#include "mpi_wrapper.hpp" |
17 | 18 |
|
18 | 19 | namespace chase |
19 | 20 | { |
@@ -436,6 +437,41 @@ class ChaseMpiProperties |
436 | 437 | #else |
437 | 438 | V_.reset(new T[N_ * max_block_]()); |
438 | 439 | #endif |
| 440 | + |
| 441 | + comm_2 row_comm_dup; |
| 442 | + comm_2 col_comm_dup; |
| 443 | +#if defined(HAS_NCCL) |
| 444 | + ncclUniqueId nccl_id, nccl_ids[nprocs_]; |
| 445 | + ncclGetUniqueId(&nccl_id); |
| 446 | + MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 447 | + &nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 448 | + |
| 449 | + |
| 450 | + for(auto i = 0; i < dims_[0]; i++){ |
| 451 | + if(coord_[0] == i){ |
| 452 | + ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]); |
| 453 | + } |
| 454 | + } |
| 455 | + |
| 456 | + //col_comm |
| 457 | + ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_]; |
| 458 | + ncclGetUniqueId(&nccl_id_2); |
| 459 | + MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 460 | + &nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 461 | + |
| 462 | + |
| 463 | + for(auto i = 0; i < dims_[1]; i++){ |
| 464 | + if(coord_[1] == i){ |
| 465 | + ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]); |
| 466 | + } |
| 467 | + } |
| 468 | +#else |
| 469 | + MPI_Comm_dup(row_comm_, &row_comm_dup); |
| 470 | + MPI_Comm_dup(col_comm_, &col_comm_dup); |
| 471 | +#endif |
| 472 | + mpi_wrapper_.add(row_comm_, row_comm_dup); |
| 473 | + mpi_wrapper_.add(col_comm_, col_comm_dup); |
| 474 | + |
439 | 475 | #ifdef USE_NSIGHT |
440 | 476 | nvtxRangePop(); |
441 | 477 | #endif |
@@ -649,6 +685,41 @@ class ChaseMpiProperties |
649 | 685 | #else |
650 | 686 | V_.reset(new T[N_ * max_block_]()); |
651 | 687 | #endif |
| 688 | + |
| 689 | + comm_2 row_comm_dup; |
| 690 | + comm_2 col_comm_dup; |
| 691 | +#if defined(HAS_NCCL) |
| 692 | + ncclUniqueId nccl_id, nccl_ids[nprocs_]; |
| 693 | + ncclGetUniqueId(&nccl_id); |
| 694 | + MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 695 | + &nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 696 | + |
| 697 | + |
| 698 | + for(auto i = 0; i < dims_[0]; i++){ |
| 699 | + if(coord_[0] == i){ |
| 700 | + ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]); |
| 701 | + } |
| 702 | + } |
| 703 | + |
| 704 | + //col_comm |
| 705 | + ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_]; |
| 706 | + ncclGetUniqueId(&nccl_id_2); |
| 707 | + MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 708 | + &nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 709 | + |
| 710 | + |
| 711 | + for(auto i = 0; i < dims_[1]; i++){ |
| 712 | + if(coord_[1] == i){ |
| 713 | + ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]); |
| 714 | + } |
| 715 | + } |
| 716 | +#else |
| 717 | + MPI_Comm_dup(row_comm_, &row_comm_dup); |
| 718 | + MPI_Comm_dup(col_comm_, &col_comm_dup); |
| 719 | +#endif |
| 720 | + mpi_wrapper_.add(row_comm_, row_comm_dup); |
| 721 | + mpi_wrapper_.add(col_comm_, col_comm_dup); |
| 722 | + |
652 | 723 | #ifdef USE_NSIGHT |
653 | 724 | nvtxRangePop(); |
654 | 725 | #endif |
@@ -847,11 +918,47 @@ class ChaseMpiProperties |
847 | 918 | #else |
848 | 919 | V_.reset(new T[N_ * max_block_]()); |
849 | 920 | #endif |
| 921 | + |
| 922 | + comm_2 row_comm_dup; |
| 923 | + comm_2 col_comm_dup; |
| 924 | +#if defined(HAS_NCCL) |
| 925 | + ncclUniqueId nccl_id, nccl_ids[nprocs_]; |
| 926 | + ncclGetUniqueId(&nccl_id); |
| 927 | + MPI_Allgather(&nccl_id, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 928 | + &nccl_ids[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 929 | + |
| 930 | + |
| 931 | + for(auto i = 0; i < dims_[0]; i++){ |
| 932 | + if(coord_[0] == i){ |
| 933 | + ncclCommInitRank(&row_comm_dup, dims_[1], nccl_ids[i], coord_[1]); |
| 934 | + } |
| 935 | + } |
| 936 | + |
| 937 | + //col_comm |
| 938 | + ncclUniqueId nccl_id_2, nccl_ids_2[nprocs_]; |
| 939 | + ncclGetUniqueId(&nccl_id_2); |
| 940 | + MPI_Allgather(&nccl_id_2, sizeof(ncclUniqueId), MPI_UINT8_T, |
| 941 | + &nccl_ids_2[0], sizeof(ncclUniqueId), MPI_UINT8_T, comm); |
| 942 | + |
| 943 | + |
| 944 | + for(auto i = 0; i < dims_[1]; i++){ |
| 945 | + if(coord_[1] == i){ |
| 946 | + ncclCommInitRank(&col_comm_dup, dims_[0], nccl_ids_2[i * dims_[0]], coord_[0]); |
| 947 | + } |
| 948 | + } |
| 949 | +#else |
| 950 | + MPI_Comm_dup(row_comm_, &row_comm_dup); |
| 951 | + MPI_Comm_dup(col_comm_, &col_comm_dup); |
| 952 | +#endif |
| 953 | + mpi_wrapper_.add(row_comm_, row_comm_dup); |
| 954 | + mpi_wrapper_.add(col_comm_, col_comm_dup); |
| 955 | + |
850 | 956 | #ifdef USE_NSIGHT |
851 | 957 | nvtxRangePop(); |
852 | 958 | #endif |
853 | 959 | } |
854 | 960 |
|
| 961 | + Comm_t get_mpi_wrapper() { return mpi_wrapper_;} /* Returns the mpi_wrapper_ member by value (return type is Comm_t, not a reference). NOTE(review): whether copying a Comm_t is cheap or intended is not visible here - confirm against mpi_wrapper.hpp. */ |
855 | 962 | #if defined(HAS_SCALAPACK) |
856 | 963 | int get_colcomm_ctxt() { return colcomm_ctxt_; } /* Returns the colcomm_ctxt_ member; this accessor sits under the HAS_SCALAPACK preprocessor guard. */ |
857 | 964 |
|
@@ -1600,6 +1707,8 @@ class ChaseMpiProperties |
1600 | 1707 | //! It is allocated only when no ScaLAPACK is detected. |
1601 | 1708 | std::unique_ptr<T[]> V_; |
1602 | 1709 | #endif |
| 1710 | + |
| 1710 | + Comm_t mpi_wrapper_; /* The constructors call mpi_wrapper_.add(row_comm_, row_comm_dup) and mpi_wrapper_.add(col_comm_, col_comm_dup) on this member; get_mpi_wrapper() returns it. */ |
1603 | 1712 | }; |
1604 | 1713 | } // namespace mpi |
1605 | 1714 | } // namespace chase |
0 commit comments