@@ -85,30 +85,46 @@ inline void dequantize_block(
8585 uint n,
8686 uint blocksize,
8787 uint block_index,
88- constant float * code_table
88+ uint thread_idx,
89+ uint threadgroup_size,
90+ constant float * code_table,
91+ threadgroup float & shared_scale
8992) {
90- uint start = block_index * blocksize;
91- if (start >= n) {
93+ uint block_start = block_index * blocksize;
94+ if (block_start >= n) {
9295 return ;
9396 }
97+ uint block_end = min (block_start + blocksize, n);
98+ uint pairs_in_block = (block_end - block_start + 1 ) >> 1 ;
9499
95- uint end = min (start + blocksize, n);
96- float scale = absmax[block_index];
97- if (scale == 0 .0f ) {
98- for (uint i = start; i < end; ++i) {
99- output[i] = scalar_t (0 .0f );
100- }
101- return ;
100+ if (thread_idx == 0 ) {
101+ shared_scale = absmax[block_index];
102102 }
103+ threadgroup_barrier (mem_flags::mem_threadgroup);
104+ float scale = shared_scale;
105+
106+ for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) {
107+ uint value_index0 = block_start + pair * 2 ;
108+ if (value_index0 >= block_end) {
109+ break ;
110+ }
111+
112+ uint byte_index0 = value_index0 >> 1 ;
113+ uchar byte_val0 = packed[byte_index0];
114+ bool upper0 = ((value_index0 & 1 ) == 0 );
115+ uchar nibble0 = upper0 ? ((byte_val0 >> 4 ) & 0xF ) : (byte_val0 & 0xF );
116+ float decoded0 = code_table[nibble0] * scale;
117+ output[value_index0] = scalar_t (decoded0);
103118
104- uint base_byte = start >> 1 ;
105- for (uint offset = 0 ; offset < end - start; ++offset) {
106- uint global_index = start + offset;
107- uint byte_index = base_byte + (offset >> 1 );
108- uchar byte_val = packed[byte_index];
109- uchar nibble = (offset & 1 ) == 0 ? (byte_val >> 4 ) & 0xF : byte_val & 0xF ;
110- float decoded = code_table[nibble] * scale;
111- output[global_index] = scalar_t (decoded);
119+ uint value_index1 = value_index0 + 1 ;
120+ if (value_index1 < block_end) {
121+ uint byte_index1 = value_index1 >> 1 ;
122+ uchar byte_val1 = (byte_index1 == byte_index0) ? byte_val0 : packed[byte_index1];
123+ bool upper1 = ((value_index1 & 1 ) == 0 );
124+ uchar nibble1 = upper1 ? ((byte_val1 >> 4 ) & 0xF ) : (byte_val1 & 0xF );
125+ float decoded1 = code_table[nibble1] * scale;
126+ output[value_index1] = scalar_t (decoded1);
127+ }
112128 }
113129}
114130
@@ -183,13 +199,15 @@ kernel void dequantize_4bit_fp16_fp4(
183199 constant uint& n [[buffer(3 )]],
184200 constant uint& blocksize [[buffer(4 )]],
185201 constant uint& blocks [[buffer(5 )]],
186- uint gid [[thread_position_in_grid]],
187-
202+ uint tgid [[threadgroup_position_in_grid]],
203+ uint tid [[thread_index_in_threadgroup]],
204+ uint threadgroup_size [[threads_per_threadgroup]]
188205) {
189- if (gid >= blocks) {
206+ if (tgid >= blocks) {
190207 return ;
191208 }
192- dequantize_block (packed, absmax, output, n, blocksize, gid, FP4_CODE);
209+ threadgroup float shared_scale;
210+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
193211}
194212
195213kernel void dequantize_4bit_fp16_nf4 (
@@ -199,12 +217,15 @@ kernel void dequantize_4bit_fp16_nf4(
199217 constant uint& n [[buffer(3 )]],
200218 constant uint& blocksize [[buffer(4 )]],
201219 constant uint& blocks [[buffer(5 )]],
202- uint gid [[thread_position_in_grid]]
220+ uint tgid [[threadgroup_position_in_grid]],
221+ uint tid [[thread_index_in_threadgroup]],
222+ uint threadgroup_size [[threads_per_threadgroup]]
203223) {
204- if (gid >= blocks) {
224+ if (tgid >= blocks) {
205225 return ;
206226 }
207- dequantize_block (packed, absmax, output, n, blocksize, gid, NF4_CODE);
227+ threadgroup float shared_scale;
228+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
208229}
209230
210231kernel void dequantize_4bit_fp32_fp4 (
@@ -214,12 +235,15 @@ kernel void dequantize_4bit_fp32_fp4(
214235 constant uint& n [[buffer(3 )]],
215236 constant uint& blocksize [[buffer(4 )]],
216237 constant uint& blocks [[buffer(5 )]],
217- uint gid [[thread_position_in_grid]]
238+ uint tgid [[threadgroup_position_in_grid]],
239+ uint tid [[thread_index_in_threadgroup]],
240+ uint threadgroup_size [[threads_per_threadgroup]]
218241) {
219- if (gid >= blocks) {
242+ if (tgid >= blocks) {
220243 return ;
221244 }
222- dequantize_block (packed, absmax, output, n, blocksize, gid, FP4_CODE);
245+ threadgroup float shared_scale;
246+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale);
223247}
224248
225249kernel void dequantize_4bit_fp32_nf4 (
@@ -229,10 +253,13 @@ kernel void dequantize_4bit_fp32_nf4(
229253 constant uint& n [[buffer(3 )]],
230254 constant uint& blocksize [[buffer(4 )]],
231255 constant uint& blocks [[buffer(5 )]],
232- uint gid [[thread_position_in_grid]]
256+ uint tgid [[threadgroup_position_in_grid]],
257+ uint tid [[thread_index_in_threadgroup]],
258+ uint threadgroup_size [[threads_per_threadgroup]]
233259) {
234- if (gid >= blocks) {
260+ if (tgid >= blocks) {
235261 return ;
236262 }
237- dequantize_block (packed, absmax, output, n, blocksize, gid, NF4_CODE);
263+ threadgroup float shared_scale;
264+ dequantize_block (packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
238265}
0 commit comments