Skip to content

Commit ef8190c

Browse files
committed
hipBLASLt auto-tune + eliminate hipMemcpyAsync in copy kernels
GEMM tuning: Request 8 algorithms from the hipBLASLt heuristic and benchmark each on the first call per (M,N,K) shape. Cache the winner for subsequent calls. This finds lower-VGPR kernels for better CU occupancy.

Copy reduction: Replace hipMemcpyAsync-based shape/stride passing in copy_general and copy_general_input with by-value hip_array kernel arguments. Eliminates 3 HIP API calls per general copy dispatch.

Results (Qwen3.5-35B-A3B-4bit):
- hipMemcpyAsync: 964 -> 77 (-92%)
- Gen tok/s: 25.1 -> 26.6 (+6%)
- Short gen: 21 -> 46 tok/s (+120%)
1 parent ce31887 commit ef8190c

3 files changed

Lines changed: 151 additions & 156 deletions

File tree

mlx/backend/rocm/copy/copy_general.hip

Lines changed: 33 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "mlx/backend/rocm/copy/copy.hpp"
44
#include "mlx/backend/rocm/device.h"
55
#include "mlx/backend/rocm/kernel_utils.hpp"
6+
#include "mlx/backend/rocm/device/config.h"
7+
#include "mlx/backend/rocm/device/utils.hpp"
68
#include "mlx/dtype_utils.h"
79

810
#include <hip/hip_runtime.h>
@@ -11,59 +13,28 @@ namespace mlx::core {
1113

1214
namespace rocm {
1315

14-
// Helper function to convert linear index to strided offset
15-
template <typename IdxT>
16-
__device__ IdxT linear_to_strided(
17-
IdxT elem,
18-
const int* shape,
19-
const int64_t* strides,
16+
// General copy kernel with by-value shape/strides (no hipMemcpyAsync needed)
17+
template <typename In, typename Out, typename IdxT>
18+
__global__ void copy_gg_byval(
19+
const In* in,
20+
Out* out,
21+
IdxT size,
22+
hip_array<int32_t, MAX_NDIM> shape,
23+
hip_array<int64_t, MAX_NDIM> strides_in,
24+
hip_array<int64_t, MAX_NDIM> strides_out,
2025
int ndim) {
21-
IdxT loc = 0;
22-
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
23-
loc += (elem % shape[i]) * IdxT(strides[i]);
24-
elem /= shape[i];
25-
}
26-
return loc;
27-
}
26+
IdxT index = blockIdx.x * blockDim.x + threadIdx.x;
27+
if (index >= size) return;
2828

29-
// Helper function to convert linear index to two strided offsets
30-
template <typename IdxT>
31-
__device__ void linear_to_strided_2(
32-
IdxT elem,
33-
const int* shape,
34-
const int64_t* strides_in,
35-
const int64_t* strides_out,
36-
int ndim,
37-
IdxT& loc_in,
38-
IdxT& loc_out) {
39-
loc_in = 0;
40-
loc_out = 0;
29+
IdxT loc_in = 0, loc_out = 0;
30+
IdxT elem = index;
4131
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
4232
IdxT dim_idx = elem % shape[i];
4333
loc_in += dim_idx * IdxT(strides_in[i]);
4434
loc_out += dim_idx * IdxT(strides_out[i]);
4535
elem /= shape[i];
4636
}
47-
}
48-
49-
// General copy kernel - strided input to strided output (dynamic ndim)
50-
template <typename In, typename Out, typename IdxT>
51-
__global__ void copy_gg_dynamic(
52-
const In* in,
53-
Out* out,
54-
IdxT size,
55-
const int* shape,
56-
const int64_t* strides_in,
57-
const int64_t* strides_out,
58-
int ndim) {
59-
IdxT index = blockIdx.x * blockDim.x + threadIdx.x;
60-
if (index >= size) {
61-
return;
62-
}
63-
64-
IdxT idx_in, idx_out;
65-
linear_to_strided_2(index, shape, strides_in, strides_out, ndim, idx_in, idx_out);
66-
out[idx_out] = cast_to<Out>(in[idx_in]);
37+
out[loc_out] = cast_to<Out>(in[loc_in]);
6738
}
6839

6940
} // namespace rocm
@@ -78,78 +49,48 @@ void copy_general(
7849
const Shape& shape,
7950
const Strides& strides_in,
8051
const Strides& strides_out) {
81-
52+
8253
int ndim = shape.size();
8354
size_t data_size = 1;
8455
for (auto& s : shape) {
8556
data_size *= s;
8657
}
87-
58+
8859
if (data_size == 0) {
8960
return;
9061
}
9162

92-
// Allocate device memory for shape and strides
93-
array shape_arr({ndim}, int32, nullptr, {});
94-
array strides_in_arr({ndim}, int64, nullptr, {});
95-
array strides_out_arr({ndim}, int64, nullptr, {});
96-
shape_arr.set_data(allocator::malloc(shape_arr.nbytes()));
97-
strides_in_arr.set_data(allocator::malloc(strides_in_arr.nbytes()));
98-
strides_out_arr.set_data(allocator::malloc(strides_out_arr.nbytes()));
99-
encoder.add_temporary(shape_arr);
100-
encoder.add_temporary(strides_in_arr);
101-
encoder.add_temporary(strides_out_arr);
102-
103-
void* shape_ptr = gpu_ptr<void>(shape_arr);
104-
void* strides_in_ptr = gpu_ptr<void>(strides_in_arr);
105-
void* strides_out_ptr = gpu_ptr<void>(strides_out_arr);
63+
// Pack shape/strides into by-value structs (no device allocation needed)
64+
rocm::hip_array<int32_t, MAX_NDIM> shape_arg = {};
65+
rocm::hip_array<int64_t, MAX_NDIM> strides_in_arg = {};
66+
rocm::hip_array<int64_t, MAX_NDIM> strides_out_arg = {};
67+
for (int i = 0; i < ndim; i++) {
68+
shape_arg.data_[i] = static_cast<int32_t>(shape[i]);
69+
strides_in_arg.data_[i] = strides_in[i];
70+
strides_out_arg.data_[i] = strides_out[i];
71+
}
72+
10673
const void* in_ptr = gpu_ptr<void>(in);
10774
void* out_ptr = gpu_ptr<void>(out);
10875

10976
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
11077
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
11178
using InType = hip_type_t<MLX_GET_TYPE(in_type_tag)>;
11279
using OutType = hip_type_t<MLX_GET_TYPE(out_type_tag)>;
113-
114-
encoder.launch_kernel([
115-
&,
116-
shape_ptr,
117-
strides_in_ptr,
118-
strides_out_ptr,
119-
in_ptr,
120-
out_ptr](hipStream_t stream) {
121-
// Copy shape and strides to device
122-
(void)hipMemcpyAsync(
123-
shape_ptr,
124-
shape.data(),
125-
ndim * sizeof(int32_t),
126-
hipMemcpyHostToDevice,
127-
stream);
128-
(void)hipMemcpyAsync(
129-
strides_in_ptr,
130-
strides_in.data(),
131-
ndim * sizeof(int64_t),
132-
hipMemcpyHostToDevice,
133-
stream);
134-
(void)hipMemcpyAsync(
135-
strides_out_ptr,
136-
strides_out.data(),
137-
ndim * sizeof(int64_t),
138-
hipMemcpyHostToDevice,
139-
stream);
14080

81+
encoder.launch_kernel([=](hipStream_t stream) {
14182
int block_size = 256;
14283
int num_blocks = (data_size + block_size - 1) / block_size;
14384

14485
hipLaunchKernelGGL(
145-
(rocm::copy_gg_dynamic<InType, OutType, int64_t>),
86+
(rocm::copy_gg_byval<InType, OutType, int64_t>),
14687
dim3(num_blocks), dim3(block_size), 0, stream,
14788
static_cast<const InType*>(in_ptr) + offset_in,
14889
static_cast<OutType*>(out_ptr) + offset_out,
14990
static_cast<int64_t>(data_size),
150-
static_cast<const int*>(shape_ptr),
151-
static_cast<const int64_t*>(strides_in_ptr),
152-
static_cast<const int64_t*>(strides_out_ptr),
91+
shape_arg,
92+
strides_in_arg,
93+
strides_out_arg,
15394
ndim);
15495
});
15596
});

mlx/backend/rocm/copy/copy_general_input.hip

Lines changed: 28 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "mlx/backend/rocm/copy/copy.hpp"
44
#include "mlx/backend/rocm/device.h"
55
#include "mlx/backend/rocm/kernel_utils.hpp"
6+
#include "mlx/backend/rocm/device/config.h"
7+
#include "mlx/backend/rocm/device/utils.hpp"
68
#include "mlx/dtype_utils.h"
79

810
#include <hip/hip_runtime.h>
@@ -13,37 +15,25 @@ static constexpr int TILE_SIZE = 16;
1315

1416
namespace rocm {
1517

16-
// Helper function to convert linear index to strided offset
17-
template <typename IdxT>
18-
__device__ IdxT linear_to_strided(
19-
IdxT elem,
20-
const int* shape,
21-
const int64_t* strides,
22-
int ndim) {
23-
IdxT loc = 0;
24-
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
25-
loc += (elem % shape[i]) * IdxT(strides[i]);
26-
elem /= shape[i];
27-
}
28-
return loc;
29-
}
30-
31-
// General copy kernel - strided input to contiguous output (dynamic ndim)
18+
// General copy kernel - strided input to contiguous output (by-value args)
3219
template <typename In, typename Out, typename IdxT>
33-
__global__ void copy_g_dynamic(
20+
__global__ void copy_g_byval(
3421
const In* in,
3522
Out* out,
3623
IdxT size,
37-
const int* shape,
38-
const int64_t* strides,
24+
hip_array<int32_t, MAX_NDIM> shape,
25+
hip_array<int64_t, MAX_NDIM> strides,
3926
int ndim) {
4027
IdxT index = blockIdx.x * blockDim.x + threadIdx.x;
41-
if (index >= size) {
42-
return;
43-
}
28+
if (index >= size) return;
4429

45-
IdxT idx = linear_to_strided(index, shape, strides, ndim);
46-
out[index] = cast_to<Out>(in[idx]);
30+
IdxT loc = 0;
31+
IdxT elem = index;
32+
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
33+
loc += (elem % shape[i]) * IdxT(strides[i]);
34+
elem /= shape[i];
35+
}
36+
out[index] = cast_to<Out>(in[loc]);
4737
}
4838

4939
// Column to row transpose kernel
@@ -53,15 +43,14 @@ __global__ void copy_col_row(
5343
T* out,
5444
int64_t rows,
5545
int64_t cols) {
56-
__shared__ T tile[TILE_SIZE][TILE_SIZE + 1]; // +1 to avoid bank conflicts
46+
__shared__ T tile[TILE_SIZE][TILE_SIZE + 1];
5747

5848
int tile_row = blockIdx.x * TILE_SIZE;
5949
int tile_col = blockIdx.y * TILE_SIZE;
6050

6151
int tidx = threadIdx.x;
6252
int tidy = threadIdx.y;
6353

64-
// Load from column-major input
6554
int in_row = tile_row + tidx;
6655
int in_col = tile_col + tidy;
6756
if (in_row < rows && in_col < cols) {
@@ -70,7 +59,6 @@ __global__ void copy_col_row(
7059

7160
__syncthreads();
7261

73-
// Store to row-major output
7462
int out_row = tile_row + tidy;
7563
int out_col = tile_col + tidx;
7664
if (out_row < rows && out_col < cols) {
@@ -89,10 +77,10 @@ void copy_general_input(
8977
int64_t offset_out,
9078
const Shape& shape,
9179
const Strides& strides_in) {
92-
80+
9381
int ndim = shape.size();
9482
size_t data_size = out.size();
95-
83+
9684
if (data_size == 0) {
9785
return;
9886
}
@@ -117,55 +105,34 @@ void copy_general_input(
117105
return;
118106
}
119107

120-
// Allocate device memory for shape and strides
121-
array shape_arr({ndim}, int32, nullptr, {});
122-
array strides_arr({ndim}, int64, nullptr, {});
123-
shape_arr.set_data(allocator::malloc(shape_arr.nbytes()));
124-
strides_arr.set_data(allocator::malloc(strides_arr.nbytes()));
125-
encoder.add_temporary(shape_arr);
126-
encoder.add_temporary(strides_arr);
108+
// Pack shape/strides into by-value structs (no device allocation or hipMemcpyAsync)
109+
rocm::hip_array<int32_t, MAX_NDIM> shape_arg = {};
110+
rocm::hip_array<int64_t, MAX_NDIM> strides_arg = {};
111+
for (int i = 0; i < ndim; i++) {
112+
shape_arg.data_[i] = static_cast<int32_t>(shape[i]);
113+
strides_arg.data_[i] = strides_in[i];
114+
}
127115

128-
void* shape_ptr = gpu_ptr<void>(shape_arr);
129-
void* strides_ptr = gpu_ptr<void>(strides_arr);
130116
const void* in_ptr = gpu_ptr<void>(in);
131117
void* out_ptr = gpu_ptr<void>(out);
132118

133119
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
134120
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
135121
using InType = hip_type_t<MLX_GET_TYPE(in_type_tag)>;
136122
using OutType = hip_type_t<MLX_GET_TYPE(out_type_tag)>;
137-
138-
encoder.launch_kernel([
139-
&,
140-
shape_ptr,
141-
strides_ptr,
142-
in_ptr,
143-
out_ptr](hipStream_t stream) {
144-
// Copy shape and strides to device
145-
(void)hipMemcpyAsync(
146-
shape_ptr,
147-
shape.data(),
148-
ndim * sizeof(int32_t),
149-
hipMemcpyHostToDevice,
150-
stream);
151-
(void)hipMemcpyAsync(
152-
strides_ptr,
153-
strides_in.data(),
154-
ndim * sizeof(int64_t),
155-
hipMemcpyHostToDevice,
156-
stream);
157123

124+
encoder.launch_kernel([=](hipStream_t stream) {
158125
int block_size = 256;
159126
int num_blocks = (data_size + block_size - 1) / block_size;
160127

161128
hipLaunchKernelGGL(
162-
(rocm::copy_g_dynamic<InType, OutType, int64_t>),
129+
(rocm::copy_g_byval<InType, OutType, int64_t>),
163130
dim3(num_blocks), dim3(block_size), 0, stream,
164131
static_cast<const InType*>(in_ptr) + offset_in,
165132
static_cast<OutType*>(out_ptr) + offset_out,
166133
static_cast<int64_t>(data_size),
167-
static_cast<const int*>(shape_ptr),
168-
static_cast<const int64_t*>(strides_ptr),
134+
shape_arg,
135+
strides_arg,
169136
ndim);
170137
});
171138
});

0 commit comments

Comments
 (0)