1515
1616#ifdef USE_ROCM
1717#include < inttypes.h>
18+ #include " amd_smi/amdsmi.h"
1819#include " hip/hip_runtime.h"
19- #include " rocm_smi/rocm_smi.h"
2020
21- #define RSMI_CHECK (fn ) \
22- do { \
23- rsmi_status_t ret = (fn); \
24- TORCH_CHECK_EQ ((ret), RSMI_STATUS_SUCCESS); \
21+ #define AMDSMI_CHECK (fn ) \
22+ do { \
23+ amdsmi_status_t ret = (fn); \
24+ TORCH_CHECK_EQ ((ret), AMDSMI_STATUS_SUCCESS); \
2525 } while (0 )
2626
27- #define RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16
27+ #define AMDSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16
2828
2929namespace fbgemm_gpu {
3030AdjacencyMatrix<Links> get_nvlink_matrix () {
3131 auto world_size = at::cuda::getNumGPUs ();
32- RSMI_CHECK ( rsmi_init ( 0 ));
32+ AMDSMI_CHECK ( amdsmi_init (AMDSMI_INIT_AMD_GPUS ));
3333
34- // Note that ROCm_SMI uses a different numbering method to ROCm runtime,
34+ // Note that AMD SMI uses a different numbering method to ROCm runtime,
3535 // so we need to learn the mapping by using the bus ID.
36- uint32_t device_count;
37- RSMI_CHECK (rsmi_num_monitor_devices (&device_count));
3836
39- std::unordered_map<Node, uint32_t > rocm_device_to_rsmi_device;
37+ // Get all sockets, then collect all GPU processor handles across sockets.
38+ uint32_t socket_count = 0 ;
39+ AMDSMI_CHECK (amdsmi_get_socket_handles (&socket_count, nullptr ));
40+ std::vector<amdsmi_socket_handle> sockets (socket_count);
41+ AMDSMI_CHECK (amdsmi_get_socket_handles (&socket_count, sockets.data ()));
42+
43+ std::vector<amdsmi_processor_handle> processor_handles;
44+ for (uint32_t s = 0 ; s < socket_count; s++) {
45+ uint32_t device_count = 0 ;
46+ AMDSMI_CHECK (amdsmi_get_processor_handles (sockets[s], &device_count, nullptr ));
47+ std::vector<amdsmi_processor_handle> socket_handles (device_count);
48+ AMDSMI_CHECK (amdsmi_get_processor_handles (
49+ sockets[s], &device_count, socket_handles.data ()));
50+ processor_handles.insert (
51+ processor_handles.end (), socket_handles.begin (), socket_handles.end ());
52+ }
4053
41- for (const auto i : c10::irange (device_count)) {
54+ std::unordered_map<Node, amdsmi_processor_handle> hip_device_to_handle;
55+
56+ for (const auto & handle : processor_handles) {
4257 uint64_t pci_info;
43- RSMI_CHECK ( rsmi_dev_pci_id_get (i , &pci_info));
58+ AMDSMI_CHECK ( amdsmi_get_gpu_bdf_id (handle , &pci_info));
4459 uint64_t domain, bus, device, function;
4560 domain = (pci_info >> 32 ) & 0xffffffff ;
4661 bus = (pci_info >> 8 ) & 0xff ;
4762 device = (pci_info >> 3 ) & 0x1f ;
4863 function = pci_info & 0x7 ;
4964 // Different from CUDA, we do not get the PCI BUS ID as a char* and we need
5065 // to reconstruct it.
51- char pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE ];
66+ char pci_bus_id_str[AMDSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE ];
5267 sprintf (
5368 pci_bus_id_str,
5469 " %04" PRIu64 " :%02" PRIu64 " :%02" PRIu64 " .%0" PRIu64,
@@ -57,15 +72,15 @@ AdjacencyMatrix<Links> get_nvlink_matrix() {
5772 device,
5873 function);
5974
60- std::array<char , RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE > pci_bus_id;
75+ std::array<char , AMDSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE > pci_bus_id;
6176 std::copy (
6277 &pci_bus_id_str[0 ],
63- &pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE ],
78+ &pci_bus_id_str[AMDSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE ],
6479 pci_bus_id.data ());
6580 int32_t node = 0 ;
6681 auto err = hipDeviceGetByPCIBusId (&node, pci_bus_id.data ());
6782 if (err == hipSuccess) {
68- rocm_device_to_rsmi_device .insert ({node, i });
83+ hip_device_to_handle .insert ({node, handle });
6984 } else {
7085 // flush the last error - this can occur when e.g. we set
7186 // HIP_VISIBLE_DEVICES to a subset of the available GPUs in the system.
@@ -75,22 +90,22 @@ AdjacencyMatrix<Links> get_nvlink_matrix() {
7590
7691 std::vector<Links> links (world_size * world_size);
7792 for (const auto i : c10::irange (world_size)) {
78- auto src_rsmi_device = rocm_device_to_rsmi_device .find (i);
79- if (src_rsmi_device != rocm_device_to_rsmi_device .end ()) {
93+ auto src = hip_device_to_handle .find (i);
94+ if (src != hip_device_to_handle .end ()) {
8095 for (const auto j : c10::irange (world_size)) {
81- auto dst_rsmi_device = rocm_device_to_rsmi_device .find (j);
82- if (dst_rsmi_device != rocm_device_to_rsmi_device .end ()) {
96+ auto dst = hip_device_to_handle .find (j);
97+ if (dst != hip_device_to_handle .end ()) {
8398 bool is_active;
84- RSMI_CHECK ( rsmi_is_P2P_accessible (
85- src_rsmi_device ->second , dst_rsmi_device ->second , &is_active));
99+ AMDSMI_CHECK (
100+ amdsmi_is_P2P_accessible (src ->second , dst ->second , &is_active));
86101 if (is_active) {
87102 links[i * world_size + j] += 1 ;
88103 }
89104 }
90105 }
91106 }
92107 }
93- RSMI_CHECK ( rsmi_shut_down ());
108+ AMDSMI_CHECK ( amdsmi_shut_down ());
94109 return [=](Node i, Node j) {
95110 TORCH_CHECK_LT (i, world_size);
96111 TORCH_CHECK_LT (j, world_size);
0 commit comments