@@ -112,21 +112,28 @@ inline void dequantize_block(
112112 uint block_index,
113113 uint thread_idx,
114114 uint threadgroup_size,
115- constant float * code_table,
116- threadgroup float & shared_scale
115+ constant float * code_table
117116) {
118117 uint block_start = block_index * blocksize;
119118 if (block_start >= n) {
120119 return ;
121120 }
122- uint block_end = min (block_start + blocksize, n);
121+ uint block_end;
122+ if (block_start + blocksize < n) {
123+ block_end = block_start + blocksize;
124+ } else {
125+ block_end = n;
126+ }
123127 uint pairs_in_block = (block_end - block_start + 1 ) >> 1 ;
124128
125- if (thread_idx == 0 ) {
126- shared_scale = absmax[block_index];
129+ float scale = absmax[block_index];
130+
131+ // Precompute the scaled table in a per-thread array (registers if the compiler can keep it there);
132+ // each thread loads scale and the constant code table itself, so no threadgroup barrier is needed
133+ float scaled_table[16 ];
134+ for (uint i = 0 ; i < 16 ; i++) {
135+ scaled_table[i] = code_table[i] * scale;
127136 }
128- threadgroup_barrier (mem_flags::mem_threadgroup);
129- float scale = shared_scale;
130137
131138 for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) {
132139 uint value_index0 = block_start + pair * 2 ;
@@ -136,23 +143,204 @@ inline void dequantize_block(
136143
137144 uint byte_index0 = value_index0 >> 1 ;
138145 uchar byte_val0 = packed[byte_index0];
139- bool upper0 = ((value_index0 & 1 ) == 0 );
140- uchar nibble0 = upper0 ? ((byte_val0 >> 4 ) & 0xF ) : (byte_val0 & 0xF );
141- float decoded0 = code_table[nibble0] * scale;
146+ // High nibble -> even index, low nibble -> odd index (matches Python ref); assumes value_index0 is even, i.e. blocksize is even
147+ uchar nibble0 = (byte_val0 >> 4 ) & 0xF ;
148+ uchar nibble1 = byte_val0 & 0xF ;
149+ float decoded0 = scaled_table[nibble0];
150+ float decoded1 = scaled_table[nibble1];
151+ // value_index0 is already the output index (block_start + pair*2)
142152 output[value_index0] = scalar_t (decoded0);
143-
144- uint value_index1 = value_index0 + 1 ;
145- if (value_index1 < block_end) {
146- uint byte_index1 = value_index1 >> 1 ;
147- uchar byte_val1 = (byte_index1 == byte_index0) ? byte_val0 : packed[byte_index1];
148- bool upper1 = ((value_index1 & 1 ) == 0 );
149- uchar nibble1 = upper1 ? ((byte_val1 >> 4 ) & 0xF ) : (byte_val1 & 0xF );
150- float decoded1 = code_table[nibble1] * scale;
151- output[value_index1] = scalar_t (decoded1);
153+
154+ // Bounds check for odd-length blocks
155+ if (value_index0 + 1 < block_end) {
156+ output[value_index0 + 1 ] = scalar_t (decoded1);
152157 }
153158 }
154159}
155160
161+ // template <typename scalar_t>
162+ // inline void dequantize_block(
163+ // device const uchar* packed,
164+ // device const float* absmax,
165+ // device scalar_t* output,
166+ // uint n,
167+ // uint blocksize,
168+ // uint block_index,
169+ // uint thread_idx,
170+ // uint threadgroup_size,
171+ // constant float* code_table
172+ // ) {
173+ // const uint block_start = block_index * blocksize;
174+ // if (block_start >= n) return;
175+
176+ // const uint block_end = min(block_start + blocksize, n);
177+ // const uint num_values = block_end - block_start;
178+
179+ // const float scale = absmax[block_index];
180+
181+ // // Precompute scaled code table
182+ // float scaled_table[16];
183+ // for (uint i = 0; i < 16; ++i)
184+ // scaled_table[i] = code_table[i] * scale;
185+
186+ // device const uchar* packed_ptr = packed + (block_start >> 1);
187+ // device scalar_t* output_ptr = output + block_start;
188+
189+ // // Each thread processes multiple *bytes* at a stride
190+ // const uint bytes_in_block = (num_values + 1) >> 1;
191+
192+ // for (uint byte_idx = thread_idx; byte_idx < bytes_in_block; byte_idx += threadgroup_size) {
193+ // uchar byte_val = packed_ptr[byte_idx];
194+
195+ // // Decode upper and lower nibbles
196+ // uchar upper_nib = (byte_val >> 4) & 0xF;
197+ // uchar lower_nib = byte_val & 0xF;
198+
199+ // // Compute block-relative value index (output_ptr already points at block_start)
200+ // uint val_idx = byte_idx << 1; // byte_idx * 2
201+
202+ // // Write both values if in bounds
203+ // if (val_idx < num_values) output_ptr[val_idx] = scalar_t(scaled_table[upper_nib]);
204+ // if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[lower_nib]);
205+ // }
206+ // }
207+
208+ // template <typename scalar_t>
209+ // inline void dequantize_block(
210+ // device const uchar* packed,
211+ // device const float* absmax,
212+ // device scalar_t* output,
213+ // uint n,
214+ // uint blocksize,
215+ // uint block_index,
216+ // uint thread_idx,
217+ // uint threadgroup_size,
218+ // constant float* code_table
219+ // ) {
220+ // const uint block_start = block_index * blocksize;
221+ // if (block_start >= n) return;
222+
223+ // const uint block_end = min(block_start + blocksize, n);
224+ // const uint num_values = block_end - block_start;
225+
226+ // const float scale = absmax[block_index];
227+
228+ // // Precompute scaled code table
229+ // float scaled_table[16];
230+ // for (uint i = 0; i < 16; ++i)
231+ // scaled_table[i] = code_table[i] * scale;
232+
233+ // device const uchar* packed_ptr = packed + (block_start >> 1);
234+ // device scalar_t* output_ptr = output + block_start;
235+
236+ // // Each thread processes multiple uchar4 (4 bytes = 8 values)
237+ // const uint num_bytes = (num_values + 1) >> 1; // total bytes in block
238+ // const uint num_blocks = (num_bytes + 3) >> 2; // number of uchar4 blocks
239+
240+ // for (uint block_idx = thread_idx; block_idx < num_blocks; block_idx += threadgroup_size) {
241+ // uint byte_offset = block_idx * 4; // starting byte in packed array
242+ // uchar4 b = uchar4(0); // default zero
243+
244+ // // Load safely (handle tail)
245+ // if (byte_offset + 3 < num_bytes) {
246+ // b = *((device uchar4*)(packed_ptr + byte_offset));
247+ // } else {
248+ // // Tail case: read remaining bytes safely
249+ // uchar temp[4] = {0, 0, 0, 0};
250+ // for (uint i = 0; i < num_bytes - byte_offset; ++i) {
251+ // temp[i] = packed_ptr[byte_offset + i];
252+ // }
253+ // b = uchar4(temp[0], temp[1], temp[2], temp[3]);
254+ // }
255+
256+ // // Decode 8 nibbles into 8 values
257+ // uchar nibbles[8] = {
258+ // uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF),
259+ // uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF),
260+ // uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF),
261+ // uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF)
262+ // };
263+
264+ // // Compute block-relative value indices (output_ptr already points at block_start) and write outputs
265+ // uint val_idx = byte_offset << 1; // byte_offset * 2
266+ // for (uint i = 0; i < 8; ++i) {
267+ // if (val_idx + i < num_values)
268+ // output_ptr[val_idx + i] = scalar_t(scaled_table[nibbles[i]]);
269+ // }
270+ // }
271+ // }
272+
273+ // template <typename scalar_t>
274+ // inline void dequantize_block(
275+ // device const uchar* packed,
276+ // device const float* absmax,
277+ // device scalar_t* output,
278+ // uint n,
279+ // uint blocksize,
280+ // uint block_index,
281+ // uint thread_idx,
282+ // uint threadgroup_size,
283+ // constant float* code_table
284+ // ) {
285+ // const uint block_start = block_index * blocksize;
286+ // if (block_start >= n) return;
287+
288+ // const uint block_end = min(block_start + blocksize, n);
289+ // const uint num_values = block_end - block_start;
290+
291+ // const float scale = absmax[block_index];
292+
293+ // // Precompute scaled code table
294+ // float scaled_table[16];
295+ // for (uint i = 0; i < 16; ++i)
296+ // scaled_table[i] = code_table[i] * scale;
297+
298+ // device const uchar* packed_ptr = packed + (block_start >> 1);
299+ // device scalar_t* output_ptr = output + block_start;
300+
301+ // const uint num_bytes = (num_values + 1) >> 1; // total bytes in block
302+ // const uint num_uchar4 = (num_bytes + 3) >> 2; // total uchar4 blocks
303+
304+ // // Each thread handles exactly one uchar4 block (no stride loop), so this variant needs threadgroup_size >= num_uchar4
305+ // uint block_pos = thread_idx;
306+ // if (block_pos >= num_uchar4) return;
307+
308+ // // Compute byte offset
309+ // uint byte_offset = block_pos * 4;
310+ // uchar4 b = uchar4(0, 0, 0, 0);
311+
312+ // // Safe load
313+ // if (byte_offset + 3 < num_bytes) {
314+ // b = *((device uchar4*)(packed_ptr + byte_offset));
315+ // } else {
316+ // uchar temp[4] = {0, 0, 0, 0};
317+ // for (uint i = 0; i < num_bytes - byte_offset; ++i)
318+ // temp[i] = packed_ptr[byte_offset + i];
319+ // b = uchar4(temp[0], temp[1], temp[2], temp[3]);
320+ // }
321+
322+ // // Decode 8 nibbles
323+ // uchar nibbles[8] = {
324+ // uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF),
325+ // uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF),
326+ // uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF),
327+ // uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF)
328+ // };
329+
330+ // // Compute block-relative value index (output_ptr already points at block_start)
331+ // uint val_idx = byte_offset << 1; // byte_offset * 2
332+
333+ // // Fully unrolled writes; every store is still individually bounds-checked against num_values
334+ // if (val_idx + 0 < num_values) output_ptr[val_idx + 0] = scalar_t(scaled_table[nibbles[0]]);
335+ // if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[nibbles[1]]);
336+ // if (val_idx + 2 < num_values) output_ptr[val_idx + 2] = scalar_t(scaled_table[nibbles[2]]);
337+ // if (val_idx + 3 < num_values) output_ptr[val_idx + 3] = scalar_t(scaled_table[nibbles[3]]);
338+ // if (val_idx + 4 < num_values) output_ptr[val_idx + 4] = scalar_t(scaled_table[nibbles[4]]);
339+ // if (val_idx + 5 < num_values) output_ptr[val_idx + 5] = scalar_t(scaled_table[nibbles[5]]);
340+ // if (val_idx + 6 < num_values) output_ptr[val_idx + 6] = scalar_t(scaled_table[nibbles[6]]);
341+ // if (val_idx + 7 < num_values) output_ptr[val_idx + 7] = scalar_t(scaled_table[nibbles[7]]);
342+ // }
343+
156344} // namespace
157345
158346// Quantization kernels
@@ -255,8 +443,7 @@ kernel void dequantize_4bit_fp16_fp4(
255443 if (tgid >= blocks) {
256444 return ;
257445 }
258- threadgroup float shared_scale;
259- dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE , shared_scale);
446+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE );
260447}
261448
262449kernel void dequantize_4bit_fp16_nf4 (
@@ -273,8 +460,7 @@ kernel void dequantize_4bit_fp16_nf4(
273460 if (tgid >= blocks) {
274461 return ;
275462 }
276- threadgroup float shared_scale;
277- dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE , shared_scale);
463+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE );
278464}
279465
280466kernel void dequantize_4bit_fp32_fp4 (
@@ -291,8 +477,7 @@ kernel void dequantize_4bit_fp32_fp4(
291477 if (tgid >= blocks) {
292478 return ;
293479 }
294- threadgroup float shared_scale;
295- dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE , shared_scale);
480+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE );
296481}
297482
298483kernel void dequantize_4bit_fp32_nf4 (
@@ -306,9 +491,8 @@ kernel void dequantize_4bit_fp32_nf4(
306491 uint tid [[thread_index_in_threadgroup]] ,
307492 uint threadgroup_size [[threads_per_threadgroup]]
308493) {
309- if (tgid >= blocks) {
310- return ;
311- }
312- threadgroup float shared_scale;
313- dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE , shared_scale);
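494+ // if (tgid >= blocks) { // NOTE(review): guard disabled only in this kernel; the other three dequantize_4bit_* kernels keep it - confirm intent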
495+ // return;
496+ // }
497+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE );
314498}
0 commit comments