Skip to content

Commit d999ca6

Browse files
authored
Merge branch 'ml-explore:main' into rocm-support
2 parents 4f60779 + 211e57b commit d999ca6

131 files changed

Lines changed: 4559 additions & 1184 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/actions/build-cuda-release/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ runs:
1818
env:
1919
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
2020
run: |
21-
pip install auditwheel build patchelf setuptools
21+
pip install auditwheel "build<=1.4.2" patchelf setuptools
2222
python setup.py clean --all
2323
MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w
2424

.github/actions/build-linux-release/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ runs:
2525
- name: Build Python package
2626
shell: bash
2727
run: |
28-
pip install auditwheel patchelf build
28+
pip install auditwheel patchelf "build<=1.4.2"
2929
python setup.py clean --all
3030
MLX_BUILD_STAGE=1 python -m build -w
3131
auditwheel repair dist/mlx-*.whl \

.github/actions/setup-linux/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ runs:
5555
echo "::endgroup::"
5656
5757
- name: Set swap space
58-
if: ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }}
58+
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
5959
uses: pierotofy/set-swap-space@fc79b3f67fa8a838184ce84a674ca12238d2c761
6060
with:
6161
swap-size-gb: 16

CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,15 @@ FetchContent_MakeAvailable(json)
361361
target_include_directories(
362362
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
363363

364+
# Add standalone JACCL library (RDMA over Thunderbolt distributed backend)
365+
if(MLX_BUILD_CPU
366+
AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
367+
AND DEFINED MACOS_SDK_VERSION
368+
AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
369+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx/distributed/jaccl/lib
370+
${CMAKE_BINARY_DIR}/jaccl)
371+
endif()
372+
364373
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
365374

366375
target_include_directories(
@@ -388,7 +397,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
388397
FetchContent_Declare(
389398
nanobind
390399
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
391-
GIT_TAG v2.10.2
400+
GIT_TAG v2.12.0
392401
GIT_SHALLOW TRUE
393402
EXCLUDE_FROM_ALL)
394403
FetchContent_MakeAvailable(nanobind)

benchmarks/python/sdpa_bench.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,20 +176,26 @@ def get_gflop_count(B, M, N, K):
176176
( 1, 1024, 1024, 64, 32, 8),
177177
( 1, 2048, 2048, 64, 32, 8),
178178
( 1, 4096, 4096, 64, 32, 8),
179+
( 1, 4096, 5000, 64, 32, 8),
180+
( 1, 2048, 32121, 64, 32, 8),
179181
)
180182

181183
shapes_80 = (
182184
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
183185
( 1, 1024, 1024, 80, 32, 8),
184186
( 1, 2048, 2048, 80, 32, 8),
185187
( 1, 4096, 4096, 80, 32, 8),
188+
( 1, 4096, 5000, 80, 32, 8),
189+
( 1, 2048, 32121, 80, 32, 8),
186190
)
187191

188192
shapes_128 = (
189193
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
190194
( 1, 1024, 1024, 128, 32, 8),
191195
( 1, 2048, 2048, 128, 32, 8),
192196
( 1, 4096, 4096, 128, 32, 8),
197+
( 1, 4096, 5000, 128, 32, 8),
198+
( 1, 2048, 32121, 128, 32, 8),
193199
)
194200
# fmt: on
195201

docs/src/python/devices_and_streams.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ Devices and Streams
1414
set_default_device
1515
default_stream
1616
new_stream
17+
new_thread_local_stream
1718
set_default_stream
1819
stream
1920
synchronize
21+
clear_streams
2022
device_count
2123
device_info

examples/extensions/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ requires = [
33
"setuptools>=42",
44
"cmake>=3.25",
55
"mlx>=0.18.0",
6-
"nanobind==2.10.2",
6+
"nanobind==2.12.0",
77
]
88
build-backend = "setuptools.build_meta"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
setuptools>=42
22
cmake>=3.25
33
mlx>=0.21.0
4-
nanobind==2.10.2
4+
nanobind==2.12.0

mlx/backend/cpu/sort.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ struct StridedIterator {
107107
return *this;
108108
}
109109

110-
StridedIterator operator+(difference_type diff) {
110+
StridedIterator operator+(difference_type diff) const {
111111
return StridedIterator(ptr_, stride_, diff);
112112
}
113113

114-
StridedIterator operator-(difference_type diff) {
114+
StridedIterator operator-(difference_type diff) const {
115115
return StridedIterator(ptr_, stride_, -diff);
116116
}
117117

mlx/backend/cuda/CMakeLists.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ target_sources(
3131
${CMAKE_CURRENT_SOURCE_DIR}/gemms/block_mask.cu
3232
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
3333
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gather_gemm.cu
3435
${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
3536
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cu
3637
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
@@ -119,11 +120,11 @@ target_compile_options(mlx
119120
target_compile_options(
120121
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
121122

122-
if(MSVC)
123-
# Ignore warnings from CUTLASS.
124-
target_compile_options(
125-
mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=2908">)
126-
else()
123+
# Ignore warnings from CUTLASS.
124+
target_compile_options(
125+
mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=2908,2361">)
126+
127+
if(NOT MSVC)
127128
# Required for generating optimized CUTLASS code.
128129
target_compile_options(
129130
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fno-strict-aliasing>")
@@ -279,7 +280,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
279280
FetchContent_Declare(
280281
cutlass
281282
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
282-
GIT_TAG v4.3.5
283+
GIT_TAG v4.4.2
283284
GIT_SHALLOW TRUE
284285
SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
285286
FetchContent_MakeAvailable(cutlass)

0 commit comments

Comments
 (0)