Skip to content

Commit 879a200

Browse files
committed
Merge branch 'main' of github.com:ml-explore/mlx into rocm-support-fixes
2 parents 3ca29dc + 1ce7118 commit 879a200

13 files changed

Lines changed: 362 additions & 39 deletions

File tree

.github/actions/build-cuda-release/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ runs:
2020
run: |
2121
pip install auditwheel build patchelf setuptools
2222
python setup.py clean --all
23-
MLX_BUILD_STAGE=2 python -m build -w
23+
MLX_DISABLE_SM90A_KERNELS=1 MLX_BUILD_STAGE=2 python -m build -w
2424
2525
auditwheel repair dist/mlx_cuda*.whl \
2626
--plat manylinux_2_35_${{ inputs.arch }} \

docs/src/python/nn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ In detail:
175175
value_and_grad
176176
quantize
177177
average_gradients
178+
fsdp_apply_gradients
178179

179180
.. toctree::
180181

mlx/backend/cuda/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,10 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
158158
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
159159
"${MLX_CUDA_ARCHITECTURES}")
160160

161-
if(("90a" IN_LIST MLX_CUDA_ARCHITECTURES) OR ("90a-real" IN_LIST
162-
MLX_CUDA_ARCHITECTURES))
161+
# Skip Hopper-only kernels when not building for sm90a, or when the
# MLX_DISABLE_SM90A_KERNELS environment variable is defined (any value).
162+
if(NOT DEFINED ENV{MLX_DISABLE_SM90A_KERNELS}
163+
AND (("90a" IN_LIST MLX_CUDA_ARCHITECTURES) OR ("90a-real" IN_LIST
164+
MLX_CUDA_ARCHITECTURES)))
163165
target_compile_definitions(mlx PRIVATE MLX_CUDA_SM90A_ENABLED)
164166
endif()
165167

mlx/backend/metal/device.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,18 @@ Device::Device() {
323323
auto pool = new_scoped_memory_pool();
324324
device_ = load_device();
325325
default_library_ = load_default_library(device_);
326-
arch_ = std::string(device_->architecture()->name()->utf8String());
327-
int ag_tens = arch_[arch_.size() - 3] - '0';
328-
int ag_ones = arch_[arch_.size() - 2] - '0';
326+
arch_ = env::metal_gpu_arch();
327+
if (arch_.empty()) {
328+
arch_ = std::string(device_->architecture()->name()->utf8String());
329+
}
330+
int ag_tens = 0;
331+
int ag_ones = 0;
332+
if (arch_.size() >= 3) {
333+
ag_tens = arch_[arch_.size() - 3] - '0';
334+
ag_ones = arch_[arch_.size() - 2] - '0';
335+
ag_tens = (ag_tens < 10 && ag_tens >= 0) ? ag_tens : 0;
336+
ag_ones = (ag_ones < 10 && ag_ones >= 0) ? ag_ones : 0;
337+
}
329338
arch_gen_ = ag_tens * 10 + ag_ones;
330339
auto arch = arch_.back();
331340
switch (arch) {

mlx/backend/metal/device_info.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ device_info(int device_index) {
2121
auto init_device_info = []()
2222
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
2323
auto pool = metal::new_scoped_memory_pool();
24-
auto raw_device = metal::device(mlx::core::Device::gpu).mtl_device();
24+
auto& device = metal::device(mlx::core::Device::gpu);
25+
auto raw_device = device.mtl_device();
2526
auto name = std::string(raw_device->name()->utf8String());
26-
auto arch = std::string(raw_device->architecture()->name()->utf8String());
27+
auto arch = device.get_architecture();
2728

2829
size_t memsize = 0;
2930
size_t length = sizeof(memsize);

mlx/backend/metal/quantized.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,9 @@ inline array ensure_row_contiguous_matrix(
8282
}
8383

8484
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
85-
auto arch = d.get_architecture();
86-
auto arch_size = arch.back();
87-
auto arch_gen = arch.substr(arch.size() - 3, 2);
88-
if (arch_gen == "13" || arch_gen == "14") {
85+
auto arch_size = d.get_architecture().back();
86+
auto arch_gen = d.get_architecture_gen();
87+
if (arch_gen == 13 || arch_gen == 14) {
8988
switch (arch_size) {
9089
case 'd':
9190
if (D <= 2048 && O <= 2048) {

mlx/utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,14 @@ int get_var(const char* name, int default_value) {
258258
}
259259
}
260260

261+
// Read the environment variable `name` as a string.
//
// Returns the variable's value when it is set in the process environment
// (including when it is set to the empty string); otherwise returns
// `default_value`. A null `default_value` is treated as the empty string,
// because constructing a std::string from a null pointer is undefined
// behavior.
std::string get_var(const char* name, const char* default_value) {
  if (const char* buff_str = std::getenv(name)) {
    return buff_str;
  }
  // Variable is unset: fall back to the caller-supplied default, guarding
  // against a null pointer.
  return default_value != nullptr ? std::string(default_value) : std::string();
}
268+
261269
} // namespace env
262270

263271
template <typename T>

mlx/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ inline int next_power_of_2(int n) {
136136
namespace env {
137137

138138
int get_var(const char* name, int default_value);
139+
std::string get_var(const char* name, const char* default_value);
139140

140141
inline int bfs_max_width() {
141142
static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20);
@@ -169,6 +170,11 @@ inline int nccl_timeout(int default_value) {
169170
return nccl_timeout;
170171
}
171172

173+
inline const std::string& metal_gpu_arch() {
174+
static std::string gpu_arch_ = get_var("MLX_METAL_GPU_ARCH", "");
175+
return gpu_arch_;
176+
}
177+
172178
} // namespace env
173179

174180
} // namespace mlx::core

mlx/version.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include "mlx/api.h"
66

77
#define MLX_VERSION_MAJOR 0
8-
#define MLX_VERSION_MINOR 30
9-
#define MLX_VERSION_PATCH 7
8+
#define MLX_VERSION_MINOR 31
9+
#define MLX_VERSION_PATCH 1
1010
#define MLX_VERSION_NUMERIC \
1111
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
1212

python/mlx/nn/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@
22

33
from mlx.nn import init, losses
44
from mlx.nn.layers import *
5-
from mlx.nn.utils import average_gradients, value_and_grad
5+
from mlx.nn.utils import (
6+
average_gradients,
7+
fsdp_apply_gradients,
8+
value_and_grad,
9+
)

0 commit comments

Comments
 (0)