Skip to content

Commit 5a94df4

Browse files
committed
add simpler kernel
1 parent 21adee8 commit 5a94df4

5 files changed

Lines changed: 371 additions & 85 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ elseif(BUILD_MPS)
239239
add_compile_definitions(BUILD_MPS)
240240
file(MAKE_DIRECTORY "build")
241241
add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib"
242-
COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES}
242+
COMMAND xcrun metal -c -g -frecord-sources -gline-tables-only -o "build/bitsandbytes.air" ${METAL_FILES}
243243
COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib"
244244
DEPENDS "${METAL_FILES}"
245245
COMMENT "Compiling Metal kernels"

bitsandbytes/backends/mps/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ def _(
173173
) -> torch.Tensor:
174174
_check_mps_device(A, "A")
175175
_check_mps_device(absmax, "absmax")
176-
177176
out = torch.empty(shape, dtype=dtype, device=A.device)
178177
if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
179178
return out
179+
else:
180+
raise RuntimeError("Failed to dequantize 4bit on MPS")
180181
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
181182

182183

csrc/mps_kernels.metal

Lines changed: 214 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,28 @@ inline void dequantize_block(
112112
uint block_index,
113113
uint thread_idx,
114114
uint threadgroup_size,
115-
constant float* code_table,
116-
threadgroup float& shared_scale
115+
constant float* code_table
117116
) {
118117
uint block_start = block_index * blocksize;
119118
if (block_start >= n) {
120119
return;
121120
}
122-
uint block_end = min(block_start + blocksize, n);
121+
uint block_end;
122+
if (block_start + blocksize < n) {
123+
block_end = block_start + blocksize;
124+
} else {
125+
block_end = n;
126+
}
123127
uint pairs_in_block = (block_end - block_start + 1) >> 1;
124128

125-
if (thread_idx == 0) {
126-
shared_scale = absmax[block_index];
129+
float scale = absmax[block_index];
130+
131+
// Precompute scaled table in registers - avoids threadgroup bank conflicts
132+
// and constant memory is broadcast-optimized so initial loads are fast
133+
float scaled_table[16];
134+
for (uint i = 0; i < 16; i++) {
135+
scaled_table[i] = code_table[i] * scale;
127136
}
128-
threadgroup_barrier(mem_flags::mem_threadgroup);
129-
float scale = shared_scale;
130137

131138
for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) {
132139
uint value_index0 = block_start + pair * 2;
@@ -136,23 +143,204 @@ inline void dequantize_block(
136143

137144
uint byte_index0 = value_index0 >> 1;
138145
uchar byte_val0 = packed[byte_index0];
139-
bool upper0 = ((value_index0 & 1) == 0);
140-
uchar nibble0 = upper0 ? ((byte_val0 >> 4) & 0xF) : (byte_val0 & 0xF);
141-
float decoded0 = code_table[nibble0] * scale;
146+
// High nibble -> even index, low nibble -> odd index (matches Python ref)
147+
uchar nibble0 = (byte_val0 >> 4) & 0xF;
148+
uchar nibble1 = byte_val0 & 0xF;
149+
float decoded0 = scaled_table[nibble0];
150+
float decoded1 = scaled_table[nibble1];
151+
// value_index0 is already the output index (block_start + pair*2)
142152
output[value_index0] = scalar_t(decoded0);
143-
144-
uint value_index1 = value_index0 + 1;
145-
if (value_index1 < block_end) {
146-
uint byte_index1 = value_index1 >> 1;
147-
uchar byte_val1 = (byte_index1 == byte_index0) ? byte_val0 : packed[byte_index1];
148-
bool upper1 = ((value_index1 & 1) == 0);
149-
uchar nibble1 = upper1 ? ((byte_val1 >> 4) & 0xF) : (byte_val1 & 0xF);
150-
float decoded1 = code_table[nibble1] * scale;
151-
output[value_index1] = scalar_t(decoded1);
153+
154+
// Bounds check for odd-length blocks
155+
if (value_index0 + 1 < block_end) {
156+
output[value_index0 + 1] = scalar_t(decoded1);
152157
}
153158
}
154159
}
155160

161+
// template <typename scalar_t>
162+
// inline void dequantize_block(
163+
// device const uchar* packed,
164+
// device const float* absmax,
165+
// device scalar_t* output,
166+
// uint n,
167+
// uint blocksize,
168+
// uint block_index,
169+
// uint thread_idx,
170+
// uint threadgroup_size,
171+
// constant float* code_table
172+
// ) {
173+
// const uint block_start = block_index * blocksize;
174+
// if (block_start >= n) return;
175+
176+
// const uint block_end = min(block_start + blocksize, n);
177+
// const uint num_values = block_end - block_start;
178+
179+
// const float scale = absmax[block_index];
180+
181+
// // Precompute scaled code table
182+
// float scaled_table[16];
183+
// for (uint i = 0; i < 16; ++i)
184+
// scaled_table[i] = code_table[i] * scale;
185+
186+
// device const uchar* packed_ptr = packed + (block_start >> 1);
187+
// device scalar_t* output_ptr = output + block_start;
188+
189+
// // Each thread processes multiple *bytes* at a stride
190+
// const uint bytes_in_block = (num_values + 1) >> 1;
191+
192+
// for (uint byte_idx = thread_idx; byte_idx < bytes_in_block; byte_idx += threadgroup_size) {
193+
// uchar byte_val = packed_ptr[byte_idx];
194+
195+
// // Decode upper and lower nibbles
196+
// uchar upper_nib = (byte_val >> 4) & 0xF;
197+
// uchar lower_nib = byte_val & 0xF;
198+
199+
// // Compute global value index
200+
// uint val_idx = byte_idx << 1; // byte_idx * 2
201+
202+
// // Write both values if in bounds
203+
// if (val_idx < num_values) output_ptr[val_idx] = scalar_t(scaled_table[upper_nib]);
204+
// if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[lower_nib]);
205+
// }
206+
// }
207+
208+
// template <typename scalar_t>
209+
// inline void dequantize_block(
210+
// device const uchar* packed,
211+
// device const float* absmax,
212+
// device scalar_t* output,
213+
// uint n,
214+
// uint blocksize,
215+
// uint block_index,
216+
// uint thread_idx,
217+
// uint threadgroup_size,
218+
// constant float* code_table
219+
// ) {
220+
// const uint block_start = block_index * blocksize;
221+
// if (block_start >= n) return;
222+
223+
// const uint block_end = min(block_start + blocksize, n);
224+
// const uint num_values = block_end - block_start;
225+
226+
// const float scale = absmax[block_index];
227+
228+
// // Precompute scaled code table
229+
// float scaled_table[16];
230+
// for (uint i = 0; i < 16; ++i)
231+
// scaled_table[i] = code_table[i] * scale;
232+
233+
// device const uchar* packed_ptr = packed + (block_start >> 1);
234+
// device scalar_t* output_ptr = output + block_start;
235+
236+
// // Each thread processes multiple uchar4 (4 bytes = 8 values)
237+
// const uint num_bytes = (num_values + 1) >> 1; // total bytes in block
238+
// const uint num_blocks = (num_bytes + 3) >> 2; // number of uchar4 blocks
239+
240+
// for (uint block_idx = thread_idx; block_idx < num_blocks; block_idx += threadgroup_size) {
241+
// uint byte_offset = block_idx * 4; // starting byte in packed array
242+
// uchar4 b = uchar4(0); // default zero
243+
244+
// // Load safely (handle tail)
245+
// if (byte_offset + 3 < num_bytes) {
246+
// b = *((device uchar4*)(packed_ptr + byte_offset));
247+
// } else {
248+
// // Tail case: read remaining bytes safely
249+
// uchar temp[4] = {0, 0, 0, 0};
250+
// for (uint i = 0; i < num_bytes - byte_offset; ++i) {
251+
// temp[i] = packed_ptr[byte_offset + i];
252+
// }
253+
// b = uchar4(temp[0], temp[1], temp[2], temp[3]);
254+
// }
255+
256+
// // Decode 8 nibbles into 8 values
257+
// uchar nibbles[8] = {
258+
// uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF),
259+
// uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF),
260+
// uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF),
261+
// uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF)
262+
// };
263+
264+
// // Compute global value indices and write outputs
265+
// uint val_idx = byte_offset << 1; // byte_offset * 2
266+
// for (uint i = 0; i < 8; ++i) {
267+
// if (val_idx + i < num_values)
268+
// output_ptr[val_idx + i] = scalar_t(scaled_table[nibbles[i]]);
269+
// }
270+
// }
271+
// }
272+
273+
// template <typename scalar_t>
274+
// inline void dequantize_block(
275+
// device const uchar* packed,
276+
// device const float* absmax,
277+
// device scalar_t* output,
278+
// uint n,
279+
// uint blocksize,
280+
// uint block_index,
281+
// uint thread_idx,
282+
// uint threadgroup_size,
283+
// constant float* code_table
284+
// ) {
285+
// const uint block_start = block_index * blocksize;
286+
// if (block_start >= n) return;
287+
288+
// const uint block_end = min(block_start + blocksize, n);
289+
// const uint num_values = block_end - block_start;
290+
291+
// const float scale = absmax[block_index];
292+
293+
// // Precompute scaled code table
294+
// float scaled_table[16];
295+
// for (uint i = 0; i < 16; ++i)
296+
// scaled_table[i] = code_table[i] * scale;
297+
298+
// device const uchar* packed_ptr = packed + (block_start >> 1);
299+
// device scalar_t* output_ptr = output + block_start;
300+
301+
// const uint num_bytes = (num_values + 1) >> 1; // total bytes in block
302+
// const uint num_uchar4 = (num_bytes + 3) >> 2; // total uchar4 blocks
303+
304+
// // Each thread handles one or two uchar4 blocks
305+
// uint block_pos = thread_idx;
306+
// if (block_pos >= num_uchar4) return;
307+
308+
// // Compute byte offset
309+
// uint byte_offset = block_pos * 4;
310+
// uchar4 b = uchar4(0, 0, 0, 0);
311+
312+
// // Safe load
313+
// if (byte_offset + 3 < num_bytes) {
314+
// b = *((device uchar4*)(packed_ptr + byte_offset));
315+
// } else {
316+
// uchar temp[4] = {0, 0, 0, 0};
317+
// for (uint i = 0; i < num_bytes - byte_offset; ++i)
318+
// temp[i] = packed_ptr[byte_offset + i];
319+
// b = uchar4(temp[0], temp[1], temp[2], temp[3]);
320+
// }
321+
322+
// // Decode 8 nibbles
323+
// uchar nibbles[8] = {
324+
// uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF),
325+
// uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF),
326+
// uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF),
327+
// uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF)
328+
// };
329+
330+
// // Compute global value index
331+
// uint val_idx = byte_offset << 1; // byte_offset * 2
332+
333+
// // Fully unrolled writes (branch-free for main values)
334+
// if (val_idx + 0 < num_values) output_ptr[val_idx + 0] = scalar_t(scaled_table[nibbles[0]]);
335+
// if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[nibbles[1]]);
336+
// if (val_idx + 2 < num_values) output_ptr[val_idx + 2] = scalar_t(scaled_table[nibbles[2]]);
337+
// if (val_idx + 3 < num_values) output_ptr[val_idx + 3] = scalar_t(scaled_table[nibbles[3]]);
338+
// if (val_idx + 4 < num_values) output_ptr[val_idx + 4] = scalar_t(scaled_table[nibbles[4]]);
339+
// if (val_idx + 5 < num_values) output_ptr[val_idx + 5] = scalar_t(scaled_table[nibbles[5]]);
340+
// if (val_idx + 6 < num_values) output_ptr[val_idx + 6] = scalar_t(scaled_table[nibbles[6]]);
341+
// if (val_idx + 7 < num_values) output_ptr[val_idx + 7] = scalar_t(scaled_table[nibbles[7]]);
342+
// }
343+
156344
} // namespace
157345

158346
// Quantization kernels
@@ -255,8 +443,7 @@ kernel void dequantize_4bit_fp16_fp4(
255443
if (tgid >= blocks) {
256444
return;
257445
}
258-
threadgroup float shared_scale;
259-
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
446+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE);
260447
}
261448

262449
kernel void dequantize_4bit_fp16_nf4(
@@ -273,8 +460,7 @@ kernel void dequantize_4bit_fp16_nf4(
273460
if (tgid >= blocks) {
274461
return;
275462
}
276-
threadgroup float shared_scale;
277-
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
463+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE);
278464
}
279465

280466
kernel void dequantize_4bit_fp32_fp4(
@@ -291,8 +477,7 @@ kernel void dequantize_4bit_fp32_fp4(
291477
if (tgid >= blocks) {
292478
return;
293479
}
294-
threadgroup float shared_scale;
295-
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
480+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE);
296481
}
297482

298483
kernel void dequantize_4bit_fp32_nf4(
@@ -306,9 +491,8 @@ kernel void dequantize_4bit_fp32_nf4(
306491
uint tid [[thread_index_in_threadgroup]],
307492
uint threadgroup_size [[threads_per_threadgroup]]
308493
) {
309-
if (tgid >= blocks) {
310-
return;
311-
}
312-
threadgroup float shared_scale;
313-
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
494+
// if (tgid >= blocks) {
495+
// return;
496+
// }
497+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE);
314498
}

0 commit comments

Comments
 (0)