We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 90e38f7 commit 185b06dCopy full SHA for 185b06d
1 file changed
mlx/backend/cuda/allocator.cpp
@@ -168,6 +168,9 @@ CudaAllocator::CudaAllocator()
168
free_limit_ = total_memory_ - memory_limit_;
169
max_pool_size_ = memory_limit_;
170
171
+ int curr;
172
+ CHECK_CUDA_ERROR(cudaGetDevice(&curr));
173
+
174
int device_count = gpu::device_count();
175
free_streams_.resize(device_count);
176
mem_pools_.resize(device_count);
@@ -178,6 +181,7 @@ CudaAllocator::CudaAllocator()
178
181
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));
179
182
}
180
183
184
+ CHECK_CUDA_ERROR(cudaSetDevice(curr));
185
186
187
Buffer
0 commit comments