Skip to content

Commit 185b06d

Browse files
author
Awni Hannun
authored
Patch for multi device CUDA (#3100)
1 parent 90e38f7 commit 185b06d

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

mlx/backend/cuda/allocator.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ CudaAllocator::CudaAllocator()
168168
free_limit_ = total_memory_ - memory_limit_;
169169
max_pool_size_ = memory_limit_;
170170

171+
int curr;
172+
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
173+
171174
int device_count = gpu::device_count();
172175
free_streams_.resize(device_count);
173176
mem_pools_.resize(device_count);
@@ -178,6 +181,7 @@ CudaAllocator::CudaAllocator()
178181
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));
179182
}
180183
}
184+
CHECK_CUDA_ERROR(cudaSetDevice(curr));
181185
}
182186

183187
Buffer

0 commit comments

Comments
 (0)