Skip to content

Commit 251c8d8

Browse files
authored
Merge pull request #5 from lemonade-sdk/rocm-optimizations
ROCm: WMMA prefill, hipBLASLt, 4-bit MoE dispatch, QMV tuning, iGPU allocator
2 parents fe75135 + f26c802 commit 251c8d8

22 files changed

Lines changed: 3062 additions & 238 deletions

mlx/backend/rocm/CMakeLists.txt

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,20 @@ find_package(rocblas REQUIRED CONFIG)
1010
find_package(rocthrust REQUIRED CONFIG)
1111
find_package(rocprim REQUIRED CONFIG)
1212
find_package(hiprand REQUIRED CONFIG)
13+
find_package(rocwmma REQUIRED CONFIG)
1314

1415
# Ensure HIP architectures are set - respect user-provided value from command
1516
# line The user can set this via -DCMAKE_HIP_ARCHITECTURES=gfx1011
1617
#
17-
# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA:
18-
# gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) CDNA4: gfx950 (MI400 series)
19-
# RDNA2: gfx1030 (RX 6000 series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600)
20-
# RDNA4: gfx1200, gfx1201 (RX 8000 series)
18+
# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix:
19+
# CDNA: gfx908 (MI100), gfx90a (MI200), gfx942 (MI300)
20+
# RDNA2: gfx1030 (RX 6000 series)
21+
# RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600)
22+
# RDNA3.5: gfx1150, gfx1151, gfx1152 (Ryzen AI / Radeon 8060S)
23+
# RDNA4: gfx1200, gfx1201 (RX 9000 series)
2124
if(NOT CMAKE_HIP_ARCHITECTURES)
2225
set(CMAKE_HIP_ARCHITECTURES
23-
"gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102"
26+
"gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1152;gfx1200;gfx1201"
2427
CACHE STRING "HIP architectures" FORCE)
2528
endif()
2629
message(
@@ -39,6 +42,8 @@ get_target_property(ROCTHRUST_INCLUDES roc::rocthrust
3942
INTERFACE_INCLUDE_DIRECTORIES)
4043
get_target_property(ROCPRIM_INCLUDES roc::rocprim INTERFACE_INCLUDE_DIRECTORIES)
4144
get_target_property(HIPRAND_INCLUDES hip::hiprand INTERFACE_INCLUDE_DIRECTORIES)
45+
get_target_property(ROCWMMA_INCLUDES roc::rocwmma
46+
INTERFACE_INCLUDE_DIRECTORIES)
4247

4348
# Find GCC installation for C++ standard library headers ROCm's clang needs to
4449
# know where to find libstdc++ headers
@@ -101,6 +106,11 @@ foreach(inc ${HIPRAND_INCLUDES})
101106
list(APPEND HIP_INCLUDE_FLAGS "-I${inc}")
102107
endif()
103108
endforeach()
109+
foreach(inc ${ROCWMMA_INCLUDES})
110+
if(inc)
111+
list(APPEND HIP_INCLUDE_FLAGS "-I${inc}")
112+
endif()
113+
endforeach()
104114

105115
message(STATUS "HIP include flags: ${HIP_INCLUDE_FLAGS}")
106116

@@ -147,6 +157,20 @@ set(HIP_SOURCES
147157
set(HIP_OBJ_DIR "${CMAKE_CURRENT_BINARY_DIR}/hip_objs")
148158
file(MAKE_DIRECTORY ${HIP_OBJ_DIR})
149159

160+
# Detect CPU count for parallel HIP offload compilation
161+
# Use half of available CPUs for parallel HIP offload compilation per file
162+
# (Ninja already parallelizes across files, so this avoids oversubscription)
163+
include(ProcessorCount)
164+
ProcessorCount(NPROC)
165+
if(NPROC EQUAL 0)
166+
set(NPROC 4)
167+
else()
168+
math(EXPR NPROC "${NPROC} / 2")
169+
if(NPROC LESS 2)
170+
set(NPROC 2)
171+
endif()
172+
endif()
173+
150174
# Compile each HIP file to object file using custom commands Use -fno-gpu-rdc to
151175
# avoid needing device link step
152176
set(HIP_OBJECTS "")
@@ -168,6 +192,7 @@ foreach(hip_src ${HIP_SOURCES})
168192
OUTPUT ${hip_obj}
169193
COMMAND ${CMAKE_HIP_COMPILER} -c ${hip_src} -o ${hip_obj} -fPIC
170194
-DMLX_USE_ROCM ${HIP_ARCH_FLAGS} ${HIP_INCLUDE_FLAGS} -std=c++17
195+
-parallel-jobs=${NPROC}
171196
DEPENDS ${hip_src}
172197
COMMENT "Compiling HIP source ${hip_src}"
173198
VERBATIM)
@@ -211,7 +236,8 @@ target_sources(
211236
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
212237
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp
213238
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv.cpp
214-
${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp)
239+
${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp
240+
${CMAKE_CURRENT_SOURCE_DIR}/gemms/hipblaslt_gemm.cpp)
215241

216242
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)
217243

@@ -247,16 +273,21 @@ find_library(AMDHIP64_LIB amdhip64 PATHS ${ROCM_PATH}/lib /opt/rocm/lib
247273
find_library(HIPRTC_LIB hiprtc PATHS ${ROCM_PATH}/lib /opt/rocm/lib
248274
/opt/rocm-6.0.0/lib)
249275

276+
# Find hipBLASLt library (optimized GEMM for half-precision)
277+
find_library(HIPBLASLT_LIB hipblaslt PATHS ${ROCM_PATH}/lib /opt/rocm/lib
278+
/opt/rocm-6.0.0/lib)
279+
250280
message(
251281
STATUS
252-
"ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}"
282+
"ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}, hipblaslt=${HIPBLASLT_LIB}"
253283
)
254284

255285
# Link the static library and ROCm libraries to mlx We link directly to the .so
256286
# files instead of using CMake targets to avoid propagating compile options like
257287
# -x hip
258288
target_link_libraries(mlx PRIVATE ${HIP_STATIC_LIB} ${AMDHIP64_LIB}
259-
${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB})
289+
${ROCBLAS_LIB} ${HIPRAND_LIB} ${HIPRTC_LIB}
290+
${HIPBLASLT_LIB})
260291

261292
# Include ROCm headers for mlx C++ files Get the HIP include directory from the
262293
# hip package

mlx/backend/rocm/allocator.cpp

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,26 @@ static bool rocm_available() {
3535
return available == 1;
3636
}
3737

38-
// Check if managed memory is supported on this device
38+
// Check if managed memory (HMM) is supported on this device.
39+
// On integrated GPUs (Strix Halo), HMM is actually fast since there's no
40+
// discrete VRAM — managed memory avoids the overhead of hipExtMallocWithFlags.
3941
static bool managed_memory_supported() {
40-
// Always return false to force the use of hipHostMalloc (GTT RAM).
41-
// hipMallocManaged uses HMM, which causes implicit page migrations and
42-
// significant memory copying between host and device on access.
43-
// Using hipHostMalloc maps pinned host memory directly to the GPU's address space.
44-
return false;
42+
static int supported = -1;
43+
if (supported < 0) {
44+
if (!rocm_available()) {
45+
supported = 0;
46+
} else {
47+
void* test_ptr = nullptr;
48+
hipError_t err = hipMallocManaged(&test_ptr, 64);
49+
if (err == hipSuccess) {
50+
(void)hipFree(test_ptr);
51+
supported = 1;
52+
} else {
53+
supported = 0;
54+
}
55+
}
56+
}
57+
return supported == 1;
4558
}
4659

4760
static bool is_integrated() {
@@ -64,18 +77,18 @@ inline void* rocm_unified_malloc(size_t size, bool& is_managed) {
6477
void* data = nullptr;
6578
hipError_t err;
6679
if (is_integrated()) {
80+
// Unified memory device (iGPU/APU): CPU and GPU share system RAM.
81+
// Try hipExtMallocWithFlags first (fine-grained coherent, best GPU
82+
// bandwidth). Falls back to hipMallocManaged for large allocations
83+
// that exceed the small device-local VRAM (~2GB).
6784
err = hipExtMallocWithFlags(&data, size, hipDeviceMallocFinegrained);
68-
is_managed = true; // Use is_managed=true to signify hipFree should be used
85+
if (err != hipSuccess) {
86+
err = hipMallocManaged(&data, size);
87+
}
88+
is_managed = true;
6989
} else if (managed_memory_supported()) {
7090
err = hipMallocManaged(&data, size);
7191
is_managed = true;
72-
if (err == hipSuccess) {
73-
int device_count = 0;
74-
(void)hipGetDeviceCount(&device_count);
75-
for (int i = 0; i < device_count; ++i) {
76-
(void)hipMemAdvise(data, size, hipMemAdviseSetAccessedBy, i);
77-
}
78-
}
7992
} else {
8093
err = hipHostMalloc(&data, size, hipHostMallocDefault);
8194
is_managed = false;
@@ -193,6 +206,14 @@ Buffer RocmAllocator::malloc(size_t size) {
193206
}
194207

195208
// Find available buffer from cache.
209+
// Use aggressive size rounding to maximize cache hit rate:
210+
// - Small (<=8B): scalar pool
211+
// - Medium (<16KB): power-of-2
212+
// - Large (<1MB): 16KB page aligned
213+
// - Very large (>=1MB): power-of-2 (coarser buckets = more cache hits)
214+
// The power-of-2 rounding for large allocations is critical for decode —
215+
// without it, slightly different sizes (e.g., 1.01MB vs 1.02MB) miss the
216+
// cache and trigger hipExtMallocWithFlags at ~7ms each.
196217
auto orig_size = size;
197218
std::unique_lock lock(mutex_);
198219
if (size <= small_block_size) {
@@ -219,14 +240,11 @@ Buffer RocmAllocator::malloc(size_t size) {
219240
lock.unlock();
220241
if (!buf) {
221242
if (is_integrated()) {
222-
buf = new RocmBuffer{nullptr, size, false, -1};
223-
hipError_t err = hipExtMallocWithFlags(&buf->data, size, hipDeviceMallocFinegrained);
224-
if (err != hipSuccess) {
225-
delete buf;
226-
std::ostringstream oss;
227-
oss << "hipExtMallocWithFlags failed: " << hipGetErrorString(err) << ".";
228-
throw std::runtime_error(oss.str());
229-
}
243+
// Integrated GPU: allocate unified memory (CPU+GPU accessible).
244+
// device=-1 signals unified memory — no move_to_unified_memory needed.
245+
bool is_managed = false;
246+
void* data = rocm_unified_malloc(size, is_managed);
247+
buf = new RocmBuffer{data, size, is_managed, -1};
230248
} else {
231249
int device = 0;
232250
hipGetDevice(&device);
@@ -373,12 +391,18 @@ void* Buffer::raw_ptr() {
373391
if (!ptr_) {
374392
return nullptr;
375393
}
376-
// Synchronize all streams before accessing memory from CPU
377-
// This ensures all GPU operations have completed
378-
(void)hipDeviceSynchronize();
379-
380394
auto& cbuf = *static_cast<rocm::RocmBuffer*>(ptr_);
381-
rocm::allocator().move_to_unified_memory(cbuf);
395+
396+
if (cbuf.device == -1) {
397+
// Unified memory (integrated GPU or hipMallocManaged): CPU-accessible.
398+
// hipStreamSynchronize(nullptr) waits for the default stream — lighter
399+
// than hipDeviceSynchronize which waits for ALL streams.
400+
(void)hipStreamSynchronize(nullptr);
401+
} else {
402+
// Discrete GPU VRAM: full sync + migrate to host-accessible memory.
403+
(void)hipDeviceSynchronize();
404+
rocm::allocator().move_to_unified_memory(cbuf);
405+
}
382406
return cbuf.data;
383407
}
384408

0 commit comments

Comments
 (0)