Skip to content

Commit fd6d304

Browse files
authored
win: fix cuda build (#3204)
1 parent e1e1399 commit fd6d304

File tree

8 files changed

+25
-7
lines changed

8 files changed

+25
-7
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,10 @@ target_compile_options(
118118
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
119119

120120
# Required for generating optimized CUTLASS code.
121-
target_compile_options(
122-
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-strict-aliasing>")
121+
if(NOT MSVC)
122+
target_compile_options(
123+
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-strict-aliasing>")
124+
endif()
123125

124126
# Suppress nvcc warnings on C++ headers.
125127
target_compile_options(

mlx/backend/cuda/device.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ Device::~Device() {
6666

6767
void Device::make_current() {
6868
// We need to set/get current CUDA device very frequently, cache it to reduce
69-
// actual calls of CUDA APIs.
70-
static thread_local int current = 0;
69+
// actual calls of CUDA APIs. Use -1 as sentinel so the first call on each
70+
// new thread always calls cudaSetDevice (which establishes the CUDA primary
71+
// context). Without this, device 0 would never get set on a new thread.
72+
static thread_local int current = -1;
7173
if (current != device_) {
7274
CHECK_CUDA_ERROR(cudaSetDevice(device_));
7375
current = device_;

mlx/backend/cuda/eval.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ void new_stream(Stream s) {
2121

2222
void eval(array& arr) {
2323
nvtx3::scoped_range r("gpu::eval");
24+
// Ensure CUDA context is active on this thread. Required when MLX is called
25+
// from threads that have not yet established a CUDA context (e.g. thread
26+
// pools, language runtimes that migrate work across OS threads).
27+
cu::device(arr.primitive().stream().device).make_current();
2428
auto outputs = arr.outputs();
2529
{
2630
// If the array is a tracer hold a reference

mlx/backend/cuda/quantized/qmm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
target_sources(
22
mlx
3-
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cpp
3+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu
44
${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu
55
${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu
66
${CMAKE_CURRENT_SOURCE_DIR}/qmm_impl_sm90_m128_n16_m1.cu

mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ void qmm_sm90(
111111
reinterpret_cast<void*>(kernel),
112112
gemm.get_grid_shape(gemm.params()),
113113
GemmKernel::get_block_shape(),
114-
{get<0>(cluster), get<1>(cluster), get<2>(cluster)},
114+
{static_cast<unsigned>(get<0>(cluster)),
115+
static_cast<unsigned>(get<1>(cluster)),
116+
static_cast<unsigned>(get<2>(cluster))},
115117
GemmKernel::SharedStorageSize,
116118
kernel_params);
117119
}

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,15 @@ bool supports_sdpa_cudnn(
318318
bool has_arr_mask,
319319
bool do_causal,
320320
Stream s) {
321+
#ifdef _WIN32
322+
// On Windows (WDDM), cuDNN SDPA has severe performance issues due to
323+
// high per-kernel-launch overhead in the WDDM driver model. cuDNN's
324+
// multi-kernel SDPA amplifies this, making it much slower than the
325+
// single-kernel sdpa_vector path for both prefill and generation.
326+
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SDPA", 0);
327+
#else
321328
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SDPA", 1);
329+
#endif
322330
if (!enabled) {
323331
return false;
324332
}

mlx/distributed/nccl/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
if(MLX_BUILD_CUDA)
1+
if(MLX_BUILD_CUDA AND NOT WIN32)
22
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
33
find_package(NCCL)
44
if(NCCL_FOUND)

0 commit comments

Comments (0)