Skip to content

Commit 104e6a8

Browse files
committed
second kernel
1 parent 14c2169 commit 104e6a8

6 files changed

Lines changed: 105 additions & 61 deletions

File tree

bitsandbytes/backends/mps/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _dequantize_4bit_native(
131131
ct.c_int32(blocksize),
132132
ct.c_int32(out.numel()),
133133
)
134+
134135
return True
135136

136137

@@ -163,7 +164,7 @@ def _(
163164
out = torch.empty(shape, dtype=dtype, device=A.device)
164165
if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
165166
return out
166-
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
167+
# return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
167168

168169

169170
@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
@@ -182,7 +183,6 @@ def _(
182183
torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
183184
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
184185

185-
if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
186-
result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
187-
out.copy_(result)
188-
186+
_dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out)
187+
# result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
188+
# out.copy_(result)

bitsandbytes/test_bnb_mac.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

csrc/mps_kernels.metal

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -85,30 +85,46 @@ inline void dequantize_block(
8585
uint n,
8686
uint blocksize,
8787
uint block_index,
88-
constant float* code_table
88+
uint thread_idx,
89+
uint threadgroup_size,
90+
constant float* code_table,
91+
threadgroup float& shared_scale
8992
) {
90-
uint start = block_index * blocksize;
91-
if (start >= n) {
93+
uint block_start = block_index * blocksize;
94+
if (block_start >= n) {
9295
return;
9396
}
97+
uint block_end = min(block_start + blocksize, n);
98+
uint pairs_in_block = (block_end - block_start + 1) >> 1;
9499

95-
uint end = min(start + blocksize, n);
96-
float scale = absmax[block_index];
97-
if (scale == 0.0f) {
98-
for (uint i = start; i < end; ++i) {
99-
output[i] = scalar_t(0.0f);
100-
}
101-
return;
100+
if (thread_idx == 0) {
101+
shared_scale = absmax[block_index];
102102
}
103+
threadgroup_barrier(mem_flags::mem_threadgroup);
104+
float scale = shared_scale;
105+
106+
for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) {
107+
uint value_index0 = block_start + pair * 2;
108+
if (value_index0 >= block_end) {
109+
break;
110+
}
111+
112+
uint byte_index0 = value_index0 >> 1;
113+
uchar byte_val0 = packed[byte_index0];
114+
bool upper0 = ((value_index0 & 1) == 0);
115+
uchar nibble0 = upper0 ? ((byte_val0 >> 4) & 0xF) : (byte_val0 & 0xF);
116+
float decoded0 = code_table[nibble0] * scale;
117+
output[value_index0] = scalar_t(decoded0);
103118

104-
uint base_byte = start >> 1;
105-
for (uint offset = 0; offset < end - start; ++offset) {
106-
uint global_index = start + offset;
107-
uint byte_index = base_byte + (offset >> 1);
108-
uchar byte_val = packed[byte_index];
109-
uchar nibble = (offset & 1) == 0 ? (byte_val >> 4) & 0xF : byte_val & 0xF;
110-
float decoded = code_table[nibble] * scale;
111-
output[global_index] = scalar_t(decoded);
119+
uint value_index1 = value_index0 + 1;
120+
if (value_index1 < block_end) {
121+
uint byte_index1 = value_index1 >> 1;
122+
uchar byte_val1 = (byte_index1 == byte_index0) ? byte_val0 : packed[byte_index1];
123+
bool upper1 = ((value_index1 & 1) == 0);
124+
uchar nibble1 = upper1 ? ((byte_val1 >> 4) & 0xF) : (byte_val1 & 0xF);
125+
float decoded1 = code_table[nibble1] * scale;
126+
output[value_index1] = scalar_t(decoded1);
127+
}
112128
}
113129
}
114130

@@ -183,13 +199,15 @@ kernel void dequantize_4bit_fp16_fp4(
183199
constant uint& n [[buffer(3)]],
184200
constant uint& blocksize [[buffer(4)]],
185201
constant uint& blocks [[buffer(5)]],
186-
uint gid [[thread_position_in_grid]],
187-
202+
uint tgid [[threadgroup_position_in_grid]],
203+
uint tid [[thread_index_in_threadgroup]],
204+
uint threadgroup_size [[threads_per_threadgroup]]
188205
) {
189-
if (gid >= blocks) {
206+
if (tgid >= blocks) {
190207
return;
191208
}
192-
dequantize_block(packed, absmax, output, n, blocksize, gid, FP4_CODE);
209+
threadgroup float shared_scale;
210+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
193211
}
194212

195213
kernel void dequantize_4bit_fp16_nf4(
@@ -199,12 +217,15 @@ kernel void dequantize_4bit_fp16_nf4(
199217
constant uint& n [[buffer(3)]],
200218
constant uint& blocksize [[buffer(4)]],
201219
constant uint& blocks [[buffer(5)]],
202-
uint gid [[thread_position_in_grid]]
220+
uint tgid [[threadgroup_position_in_grid]],
221+
uint tid [[thread_index_in_threadgroup]],
222+
uint threadgroup_size [[threads_per_threadgroup]]
203223
) {
204-
if (gid >= blocks) {
224+
if (tgid >= blocks) {
205225
return;
206226
}
207-
dequantize_block(packed, absmax, output, n, blocksize, gid, NF4_CODE);
227+
threadgroup float shared_scale;
228+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
208229
}
209230

210231
kernel void dequantize_4bit_fp32_fp4(
@@ -214,12 +235,15 @@ kernel void dequantize_4bit_fp32_fp4(
214235
constant uint& n [[buffer(3)]],
215236
constant uint& blocksize [[buffer(4)]],
216237
constant uint& blocks [[buffer(5)]],
217-
uint gid [[thread_position_in_grid]]
238+
uint tgid [[threadgroup_position_in_grid]],
239+
uint tid [[thread_index_in_threadgroup]],
240+
uint threadgroup_size [[threads_per_threadgroup]]
218241
) {
219-
if (gid >= blocks) {
242+
if (tgid >= blocks) {
220243
return;
221244
}
222-
dequantize_block(packed, absmax, output, n, blocksize, gid, FP4_CODE);
245+
threadgroup float shared_scale;
246+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
223247
}
224248

225249
kernel void dequantize_4bit_fp32_nf4(
@@ -229,10 +253,13 @@ kernel void dequantize_4bit_fp32_nf4(
229253
constant uint& n [[buffer(3)]],
230254
constant uint& blocksize [[buffer(4)]],
231255
constant uint& blocks [[buffer(5)]],
232-
uint gid [[thread_position_in_grid]]
256+
uint tgid [[threadgroup_position_in_grid]],
257+
uint tid [[thread_index_in_threadgroup]],
258+
uint threadgroup_size [[threads_per_threadgroup]]
233259
) {
234-
if (gid >= blocks) {
260+
if (tgid >= blocks) {
235261
return;
236262
}
237-
dequantize_block(packed, absmax, output, n, blocksize, gid, NF4_CODE);
263+
threadgroup float shared_scale;
264+
dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
238265
}

csrc/mps_ops.mm

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stddef.h>
66
#include <stdint.h>
77
#include <string.h>
8+
#include <algorithm>
89

910
namespace {
1011

@@ -167,7 +168,6 @@ static inline void dispatch_dequant_kernel(
167168
if (n == 0) {
168169
return;
169170
}
170-
171171
uint32_t blocks = (n + blocksize - 1) / blocksize;
172172
TensorView packedView = make_tensor_view(packed, "packed");
173173
TensorView absmaxView = make_tensor_view(absmax, "absmax");
@@ -184,17 +184,25 @@ static inline void dispatch_dequant_kernel(
184184
[encoder setBytes:&n length:sizeof(uint32_t) atIndex:3];
185185
[encoder setBytes:&blocksize length:sizeof(uint32_t) atIndex:4];
186186
[encoder setBytes:&blocks length:sizeof(uint32_t) atIndex:5];
187-
NSUInteger threadsPerThreadgroup = pipeline.threadExecutionWidth;
188-
if (threadsPerThreadgroup == 0) {
189-
threadsPerThreadgroup = 1;
187+
188+
NSUInteger maxThreadsPerTG = pipeline.maxTotalThreadsPerThreadgroup;
189+
NSUInteger desiredThreads = (blocksize + 1) / 2;
190+
if (desiredThreads == 0) {
191+
desiredThreads = 1;
190192
}
193+
NSUInteger threadsPerThreadgroup = std::min(maxThreadsPerTG, std::max<NSUInteger>(1, desiredThreads));
194+
if (threadsPerThreadgroup < pipeline.threadExecutionWidth) {
195+
threadsPerThreadgroup = std::min(pipeline.threadExecutionWidth, maxThreadsPerTG);
196+
}
197+
198+
NSUInteger totalThreads = threadsPerThreadgroup * blocks;
191199
MTLSize threads = MTLSizeMake(threadsPerThreadgroup, 1, 1);
192-
MTLSize grid = MTLSizeMake(blocks, 1, 1);
200+
MTLSize grid = MTLSizeMake(totalThreads, 1, 1);
193201
[encoder dispatchThreads:grid threadsPerThreadgroup:threads];
194202
[encoder endEncoding];
195203

196204
[commandBuffer commit];
197-
[commandBuffer waitUntilCompleted];
205+
// [commandBuffer waitUntilCompleted];
198206
}
199207

200208
} // namespace

script.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
3+
PYTHON_PATH=/Users/medmekk/miniforge3/envs/gpt/bin/python
4+
$PYTHON_PATH ./test_bnb_mac.py

test_bnb_mac.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
outputs = model.generate(**inputs, max_new_tokens=20)
1010
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) # or whatever entry function you have
1111

12-
import torch
13-
import bitsandbytes as bnb
14-
A = torch.randn(2048, device='mps', dtype=torch.float16)
15-
q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, 'nf4', torch.uint8)
16-
print('q.shape:', q.shape, q.dtype)
17-
print('absmax.shape:', absmax.shape, absmax.dtype)
18-
B = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, 'nf4', A.shape, A.dtype)
19-
print('ok', float((A-B).abs().max()))
12+
# import torch
13+
# import bitsandbytes as bnb
14+
# A = torch.randn(2048, device='mps', dtype=torch.float16)
15+
# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, 'nf4', torch.uint8)
16+
# print('q.shape:', q.shape, q.dtype)
17+
# print('absmax.shape:', absmax.shape, absmax.dtype)
18+
# B = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, 'nf4', A.shape, A.dtype)
19+
# print('ok', float((A-B).abs().max()))
2020

2121
# import torch, bitsandbytes as bnb
2222

@@ -52,4 +52,19 @@
5252
# print("q_mps[:8]:", q.view(-1)[:8].cpu())
5353
# print("q_cpu[:8]:", q_cpu.view(-1)[:8])
5454
# print("absmax_mps[:4]:", absmax[:4].cpu())
55-
# print("absmax_cpu[:4]:", absmax_cpu[:4])
55+
# print("absmax_cpu[:4]:", absmax_cpu[:4])
56+
57+
# import torch, bitsandbytes as bnb, time
58+
59+
# torch.manual_seed(0)
60+
# A = torch.randn(4096 * 4096, device="mps", dtype=torch.float16)
61+
# blocksize = 64
62+
63+
# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, "nf4", torch.uint8)
64+
65+
# torch.mps.synchronize()
66+
# t0 = time.perf_counter()
67+
# torch.ops.bitsandbytes.dequantize_4bit(q, absmax, blocksize, "nf4", A.shape, A.dtype)
68+
# torch.mps.synchronize()
69+
# dt = time.perf_counter() - t0
70+
# print(f"Dequant time: {dt*1000:.2f} ms for {A.numel()/1e6:.1f}M elements")

0 commit comments

Comments
 (0)