Skip to content

Commit 8ef5395

Browse files
authored
Fix failing python tests on Windows (ml-explore#3076)
1 parent 212077f commit 8ef5395

10 files changed

Lines changed: 63 additions & 68 deletions

File tree

mlx/CMakeLists.txt

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ set_target_properties(
3232
CXX_VISIBILITY_PRESET hidden
3333
CUDA_VISIBILITY_PRESET hidden)
3434

35-
# Define MLX_EXPORT for shared libraries.
36-
set_target_properties(mlx mlx_version PROPERTIES DEFINE_SYMBOL MLX_EXPORT)
37-
# Define MLX_STATIC for static libraries.
38-
if(NOT BUILD_SHARED_LIBS)
35+
# Define MLX_EXPORT for shared libraries, MLX_STATIC for static libraries.
36+
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL MLX_EXPORT)
37+
if(BUILD_SHARED_LIBS)
38+
target_compile_definitions(mlx_version PUBLIC MLX_EXPORT)
39+
else()
3940
target_compile_definitions(mlx PUBLIC MLX_STATIC)
4041
target_compile_definitions(mlx_version PUBLIC MLX_STATIC)
4142
endif()
@@ -49,20 +50,20 @@ endif()
4950

5051
if(MSVC)
5152
# Some of CUDA's headers include windows.h, which defines min/max macros.
52-
target_compile_definitions(mlx PRIVATE NOMINMAX)
53+
target_compile_definitions(mlx PRIVATE NOMINMAX WIN32_LEAN_AND_MEAN)
54+
# Unicode support in fmt does not compile in .cu files.
55+
target_compile_definitions(mlx PRIVATE FMT_UNICODE=0)
5356
# Disable some MSVC warnings to speed up compilation.
5457
target_compile_options(
5558
mlx
56-
PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/wd4068
57-
/wd4244
58-
/wd4267
59-
/wd4700
60-
/wd4804>
61-
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/wd4068
62-
-Xcompiler=/wd4244
63-
-Xcompiler=/wd4267
64-
-Xcompiler=/wd4700
65-
-Xcompiler=/wd4804>)
59+
PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/wd4244 /wd4267>
60+
PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/wd4068
61+
/wd4146
62+
/wd4700
63+
/wd4804
64+
/wd4805>
65+
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/wd4244
66+
-Xcompiler=/wd4267>)
6667
# Enable /bigobj for heavily templated code (e.g., binary.cpp) that exceeds
6768
# the default 65,535 section limit in COFF object files.
6869
target_compile_options(

mlx/array.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@ class MLX_API array {
489489
int64_t offset{0};
490490

491491
// The size in elements of the data buffer the array accesses
492-
size_t data_size;
492+
size_t data_size{0};
493493

494494
// Contains useful meta data about the array
495-
Flags flags;
495+
Flags flags{true, true, true};
496496

497497
std::vector<array> inputs;
498498
// An array to keep track of the siblings from a multi-output

mlx/backend/cpu/device_info.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include <sys/sysctl.h>
77
#include <sys/utsname.h>
88
#elif defined(_WIN32)
9-
#define WIN32_LEAN_AND_MEAN
10-
#define NOMINMAX
119
#include <windows.h>
1210
#else
1311
#include <sys/utsname.h>

mlx/backend/cuda/allocator.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
196196
if (device == -1) {
197197
data = unified_malloc(size);
198198
} else {
199-
if (free_streams_[device]) { // supports memory pools
199+
if (mem_pools_[device]) { // supports memory pools
200200
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
201201
} else {
202202
CHECK_CUDA_ERROR(cudaMalloc(&data, size));
@@ -283,12 +283,13 @@ void CudaAllocator::move_to_unified_memory(
283283
void* data = unified_malloc(buf.size);
284284
cudaMemcpyKind kind =
285285
supports_managed_memory() ? cudaMemcpyDefault : cudaMemcpyDeviceToHost;
286-
if (stream) {
286+
if (stream && mem_pools_[buf.device]) {
287287
CHECK_CUDA_ERROR(cudaMemcpyAsync(data, buf.data, buf.size, kind, stream));
288+
free_async(buf, stream);
288289
} else {
289290
CHECK_CUDA_ERROR(cudaMemcpy(data, buf.data, buf.size, kind));
291+
free_async(buf);
290292
}
291-
cuda_free(buf);
292293
buf.data = data;
293294
buf.device = -1;
294295
}
@@ -298,17 +299,20 @@ void CudaAllocator::free_cuda_buffer(CudaBuffer* buf) {
298299
if (scalar_pool_.in_pool(buf)) {
299300
scalar_pool_.free(buf);
300301
} else {
301-
cuda_free(*buf);
302+
free_async(*buf);
302303
delete buf;
303304
}
304305
}
305306

306-
void CudaAllocator::cuda_free(CudaBuffer& buf) {
307+
void CudaAllocator::free_async(CudaBuffer& buf, cudaStream_t stream) {
307308
if (buf.device == -1) {
308309
unified_free(buf.data);
309310
} else {
310-
cudaStream_t stream = free_streams_[buf.device];
311-
if (stream) {
311+
// Free asynchronously when memory pools is supported.
312+
if (mem_pools_[buf.device]) {
313+
if (!stream) {
314+
stream = free_streams_[buf.device];
315+
}
312316
CHECK_CUDA_ERROR(cudaFreeAsync(buf.data, stream));
313317
} else {
314318
CHECK_CUDA_ERROR(cudaFree(buf.data));

mlx/backend/cuda/allocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class CudaAllocator : public allocator::Allocator {
6969

7070
private:
7171
void free_cuda_buffer(CudaBuffer* buf);
72-
void cuda_free(CudaBuffer& buf);
72+
void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr);
7373

7474
CudaAllocator();
7575
friend CudaAllocator& allocator();

mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,12 @@ struct GemmConfiguration : public CommonGemmConfiguration<T, Arch, 1> {
124124
};
125125

126126
// Specialized GEMM configuration for sm80 and later.
127-
template <typename T, typename Arch, int kAlignmentC, bool kEnableTF32>
127+
template <typename T, typename Arch, int kAlignmentC>
128128
struct GemmConfiguration<
129129
T,
130130
Arch,
131131
kAlignmentC,
132-
kEnableTF32,
132+
true,
133133
std::enable_if_t<Arch::kMinComputeCapability >= 80 && sizeof(T) <= 4>>
134134
: public CommonGemmConfiguration<T, cutlass::arch::Sm80, kAlignmentC> {
135135
using OpClass = cutlass::arch::OpClassTensorOp;

mlx/backend/cuda/quantized/qmv.cu

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ void fp_qmv(
232232
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
233233
if constexpr (!std::is_same_v<T, double>) {
234234
dim3 block_dims{WARP_SIZE, rows_per_block};
235-
uint B = out.size() / (M * N);
236-
uint blocks_y = (N + rows_per_block - 1) / rows_per_block;
235+
uint32_t B = out.size() / (M * N);
236+
uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;
237237
const uint32_t* mat_ptr = gpu_ptr<uint32_t>(mat);
238238
const T* vec_ptr = gpu_ptr<T>(vec);
239239
int n = 1;
@@ -249,16 +249,17 @@ void fp_qmv(
249249
}
250250
dispatch_1_2_4(n, [&](auto n) {
251251
dispatch_bool(B > 1, [&](auto batched) {
252-
if (!batched()) {
253-
auto kernel = fp_qmv_single<T, rows_per_block, n(), 4, 32, true>;
252+
if (!batched.value) {
253+
auto kernel =
254+
fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;
254255
if (bits == 8) {
255-
kernel = fp_qmv_single<T, rows_per_block, n(), 8, 32, true>;
256+
kernel = fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;
256257
} else if (group_size == 16) {
257-
kernel = fp_qmv_single<T, rows_per_block, n(), 4, 16, false>;
258+
kernel = fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;
258259
}
259260
encoder.add_kernel_node(
260261
kernel,
261-
{static_cast<uint>(M), blocks_y},
262+
{static_cast<uint32_t>(M), blocks_y},
262263
block_dims,
263264
0,
264265
mat_ptr,
@@ -268,15 +269,16 @@ void fp_qmv(
268269
N,
269270
K);
270271
} else {
271-
auto kernel = fp_qmv_batched<T, rows_per_block, n(), 4, 32, true>;
272+
auto kernel =
273+
fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;
272274
if (bits == 8) {
273-
kernel = fp_qmv_batched<T, rows_per_block, n(), 8, 32, true>;
275+
kernel = fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;
274276
} else if (group_size == 16) {
275-
kernel = fp_qmv_batched<T, rows_per_block, n(), 4, 16, false>;
277+
kernel = fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;
276278
}
277279
encoder.add_kernel_node(
278280
kernel,
279-
{static_cast<uint>(M), blocks_y, B},
281+
{static_cast<uint32_t>(M), blocks_y, B},
280282
block_dims,
281283
0,
282284
mat_ptr,

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ DnnGraph build_sdpa_graph(
140140
const std::optional<array>& mask_arr,
141141
bool output_logsumexp,
142142
const array& o,
143-
const array& stats) {
143+
const std::optional<array>& stats) {
144144
DnnGraph graph(handle, q.dtype());
145145

146146
auto q_ = graph.tensor("Q", Q, q);
@@ -161,7 +161,7 @@ DnnGraph build_sdpa_graph(
161161
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
162162
graph.tensor(o_, O, o)->set_output(true);
163163
if (output_logsumexp) {
164-
graph.tensor(stats_, STATS, stats)->set_output(true);
164+
graph.tensor(stats_, STATS, *stats)->set_output(true);
165165
}
166166

167167
CHECK_CUDNN_FE_ERROR(graph.prepare());
@@ -239,6 +239,11 @@ bool supports_sdpa_cudnn(
239239
return false;
240240
}
241241

242+
// cuDNN does not support bottom right mask when T_q > T_kv.
243+
if (do_causal && (q.shape(2) > k.shape(2))) {
244+
return false;
245+
}
246+
242247
// D_qk and D_v must be a multiple of 8 with maximum value 128.
243248
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
244249
(v.shape(-1) > 128)) {
@@ -255,7 +260,7 @@ void sdpa_cudnn(
255260
const array& v,
256261
float scale,
257262
array& o,
258-
array& stats,
263+
std::optional<array>& stats,
259264
bool do_causal,
260265
const std::optional<array>& mask_arr,
261266
bool output_logsumexp,
@@ -273,8 +278,8 @@ void sdpa_cudnn(
273278
encoder.set_input_array(*mask_arr);
274279
}
275280
if (output_logsumexp) {
276-
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
277-
encoder.set_output_array(stats);
281+
stats->set_data(cu::malloc_async(stats->nbytes(), encoder));
282+
encoder.set_output_array(*stats);
278283
}
279284

280285
// Search cache.
@@ -298,7 +303,7 @@ void sdpa_cudnn(
298303
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
299304
}
300305
if (output_logsumexp) {
301-
variant_pack[STATS] = gpu_ptr<void>(stats);
306+
variant_pack[STATS] = gpu_ptr<void>(*stats);
302307
}
303308

304309
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
@@ -420,15 +425,18 @@ void ScaledDotProductAttention::eval_gpu(
420425
array q = prepare_sdpa_input(inputs[0], s);
421426
array k = prepare_sdpa_input(inputs[1], s);
422427
array v = prepare_sdpa_input(inputs[2], s);
423-
auto& out = outputs[0];
424-
auto& stats = outputs[1];
428+
array& out = outputs[0];
425429
bool has_mask = inputs.size() - has_sinks_ > 3;
426430
bool has_arr_mask = has_mask && !do_causal_;
427431

428432
std::optional<array> mask_arr;
429433
if (has_arr_mask) {
430434
mask_arr = prepare_sdpa_input(inputs[3], s);
431435
}
436+
std::optional<array> stats;
437+
if (output_logsumexp_) {
438+
stats = outputs[1];
439+
}
432440

433441
if (supports_sdpa_vector(
434442
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {

python/tests/test_fast_sdpa.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -771,20 +771,6 @@ def test_grad(slow, fast, args):
771771

772772
self.assertTrue(mx.allclose(g1, g2, **tolerance))
773773

774-
sdpa_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
775-
q, k, v, scale=scale, mask=mask
776-
)
777-
sdpa_mask_fast = lambda q, k, v, mask: mx.fast.scaled_dot_product_attention(
778-
q, k, v, scale=scale, mask=mask
779-
)
780-
781-
loss_mask_slow = lambda q, k, v, mask: mlx_ref_attn(
782-
q, k, v, scale=scale, mask=mask
783-
).sum()
784-
loss_mask_fast = lambda q, k, v, mask: (
785-
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
786-
).sum()
787-
788774
B, N_kv, T, D = (2, 8, 128, 64)
789775
scale = D**-0.5
790776

@@ -796,11 +782,7 @@ def test_grad(slow, fast, args):
796782
mask_additive = mx.random.normal((B, N_q, T, T), dtype=mx.float16)
797783
mask_bool = mx.random.uniform(0, 1, (B, N_q, T, T), dtype=mx.float16) < 0.5
798784

799-
for mask in (mask_additive, mask_bool):
800-
test_vjp(sdpa_mask_slow, sdpa_mask_fast, [q, k, v, mask])
801-
test_grad(loss_mask_slow, loss_mask_fast, [q, k, v, mask])
802-
803-
for mask in (None, "causal"):
785+
for mask in (None, "causal", mask_additive, mask_bool):
804786
sdpa_slow = lambda q, k, v: mlx_ref_attn(
805787
q, k, v, scale=scale, mask=mask
806788
)

tests/linalg_tests.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ TEST_CASE("test SVD factorization") {
350350
const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);
351351

352352
CHECK(
353-
allclose(A_again, A, /* rtol = */ 1e-4, /* atol = */ 1e-4).item<bool>());
353+
allclose(A_again, A, /* rtol = */ 1e-3, /* atol = */ 1e-3).item<bool>());
354354
CHECK_EQ(U.dtype(), float32);
355355
CHECK_EQ(S.dtype(), float32);
356356
CHECK_EQ(Vt.dtype(), float32);

0 commit comments

Comments
 (0)