@@ -155,6 +155,198 @@ fn main(
155155 }
156156#endif
157157
#ifdef MUL_ACC_Q4_1
// Q4_1 multiply-accumulate.
// Each BLOCK_SIZE_BYTES-byte src0 block is read as: f16 scale `d` at byte 0,
// f16 offset `m` at byte 2, then packed 4-bit quants starting at byte 4.
// A quant q contributes (q * d + m) * x. THREADS_PER_BLOCK threads cooperate on
// one block; each thread reads a single u32 of quants (4 bytes, 8 values).
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 20
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE / THREADS_PER_BLOCK)

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // Byte b of this thread's quant word carries two values: the low nibble
        // pairs with element b of the thread's slice, the high nibble with the
        // element half a block further on. Both strides are derived from the
        // block geometry instead of being spelled out as literals.
        let half_elems = ELEMS_PER_THREAD / 2u;
        let half_block = BLOCK_SIZE / 2u;
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * half_elems;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < half_elems; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + half_elems] = f32(src1[x_base + i + half_block]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            // Rows past params.m are skipped.
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_src0_f16_at(block_byte_base));
                let m = f32(load_src0_f16_at(block_byte_base + 2u));
                var row_sum = 0.0;

                // One u32 of packed quants per thread, after the 4-byte d/m header.
                let q_packed = load_src0_u32_at(block_byte_base + 4u + 4u * thread_within_block);
                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    let q_lo = f32(q_byte & 0xFu) * d + m;
                    let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + half_elems];
                }
                acc[row] += row_sum;
            }
        }
    }
#endif
195+
#ifdef MUL_ACC_Q5_0
// Q5_0 multiply-accumulate.
// Each BLOCK_SIZE_BYTES-byte src0 block is read as: f16 scale `d` at byte 0,
// a u32 of high bits `qh` at byte 2, then packed 4-bit low parts from byte 6.
// A 5-bit quant is (low nibble | high bit << 4) and contributes
// ((q - 16) * d) * x. THREADS_PER_BLOCK threads cooperate on one block; each
// thread reads a single u32 of low nibbles (4 bytes, 8 values).
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 22
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE / THREADS_PER_BLOCK)

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // Byte b of this thread's quant word carries two values: the low nibble
        // pairs with element b of the thread's slice, the high nibble with the
        // element half a block further on. Both strides are derived from the
        // block geometry instead of being spelled out as literals.
        let half_elems = ELEMS_PER_THREAD / 2u;
        let half_block = BLOCK_SIZE / 2u;
        // Index of the first qh bit this thread consumes. It depends only on
        // the thread, so it is computed once here rather than once per row.
        let qh_shift = thread_within_block * half_elems;
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * half_elems;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < half_elems; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + half_elems] = f32(src1[x_base + i + half_block]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            // Rows past params.m are skipped.
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_src0_f16_at(block_byte_base));
                let qh_packed = load_src0_u32_at(block_byte_base + 2u);
                let q_packed = load_src0_u32_at(block_byte_base + 6u + 4u * thread_within_block);
                var row_sum = 0.0;

                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    // Move qh bit j (j = qh_shift + byte_idx) into bit 4 for the low-nibble value.
                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
                    // Shifting by j + 12 lands qh bit j + 16 in bit 4 for the high-nibble value.
                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
                    let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d;
                    let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + half_elems];
                }
                acc[row] += row_sum;
            }
        }
    }
#endif
236+
#ifdef MUL_ACC_Q5_1
// Q5_1 multiply-accumulate.
// Each BLOCK_SIZE_BYTES-byte src0 block is read as: f16 scale `d` at byte 0,
// f16 offset `m` at byte 2, a u32 of high bits `qh` at byte 4, then packed
// 4-bit low parts from byte 8. A 5-bit quant is (low nibble | high bit << 4)
// and contributes (q * d + m) * x. THREADS_PER_BLOCK threads cooperate on one
// block; each thread reads a single u32 of low nibbles (4 bytes, 8 values).
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 24
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE / THREADS_PER_BLOCK)

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // Byte b of this thread's quant word carries two values: the low nibble
        // pairs with element b of the thread's slice, the high nibble with the
        // element half a block further on. Both strides are derived from the
        // block geometry instead of being spelled out as literals.
        let half_elems = ELEMS_PER_THREAD / 2u;
        let half_block = BLOCK_SIZE / 2u;
        // Index of the first qh bit this thread consumes. It depends only on
        // the thread, so it is computed once here rather than once per row.
        let qh_shift = thread_within_block * half_elems;
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * half_elems;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < half_elems; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + half_elems] = f32(src1[x_base + i + half_block]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            // Rows past params.m are skipped.
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_src0_f16_at(block_byte_base));
                let m = f32(load_src0_f16_at(block_byte_base + 2u));
                let qh_packed = load_src0_u32_at(block_byte_base + 4u);
                let q_packed = load_src0_u32_at(block_byte_base + 8u + 4u * thread_within_block);
                var row_sum = 0.0;

                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    // Move qh bit j (j = qh_shift + byte_idx) into bit 4 for the low-nibble value.
                    let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u;
                    // Shifting by j + 12 lands qh bit j + 16 in bit 4 for the high-nibble value.
                    let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u;
                    let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m;
                    let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + half_elems];
                }
                acc[row] += row_sum;
            }
        }
    }
#endif
277+ #endif
278+
#ifdef MUL_ACC_Q8_0
// Q8_0 multiply-accumulate.
// Each BLOCK_SIZE_BYTES-byte src0 block is read as an f16 scale at byte 0
// followed by signed 8-bit quants from byte 2. A quant q contributes
// (q * scale) * x. THREADS_PER_BLOCK threads cooperate on one block; each
// thread handles ELEMS_PER_THREAD consecutive values (two u32 words).
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 34
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE / THREADS_PER_BLOCK)

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var blk = thread_id / THREADS_PER_BLOCK; blk < num_blocks; blk += WG_SIZE / THREADS_PER_BLOCK) {
        // Stage this thread's contiguous slice of src1 as f32 once per block,
        // so it can be reused for every output row below.
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        let x_first = src1_idx_base + blk * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            x_vals[e] = f32(src1[x_first + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            // Rows past params.m contribute nothing.
            if (out_row >= params.m) {
                continue;
            }

            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_src0_f16_at(blk_bytes));
            // Index of this thread's first u32 of quants within the block.
            let first_word = thread_within_block * 2u;
            var partial = 0.0;

            for (var w = 0u; w < ELEMS_PER_THREAD / 4u; w++) {
                // Quant words start right after the 2-byte scale.
                let word = load_src0_u32_at(blk_bytes + 2u + 4u * (first_word + w));
                for (var b = 0u; b < 4u; b++) {
                    let q_val = f32(get_byte_i32(word, b)) * scale;
                    partial += q_val * x_vals[w * 4u + b];
                }
            }
            acc[r] += partial;
        }
    }
#endif
312+ #endif
313+
#ifdef MUL_ACC_Q8_1
// Q8_1 multiply-accumulate.
// Each BLOCK_SIZE_BYTES-byte src0 block is read as an f16 scale at byte 0, a
// second f16 at byte 2 that is applied as an additive offset, and signed 8-bit
// quants from byte 4. A quant q contributes (q * scale + offset) * x.
// THREADS_PER_BLOCK threads cooperate on one block; each thread handles
// ELEMS_PER_THREAD consecutive values (two u32 words).
// NOTE(review): some Q8_1 block definitions store a precomputed sum in the
// second f16 rather than an offset - confirm this matches the producer of src0.
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 36
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE / THREADS_PER_BLOCK)

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var blk = thread_id / THREADS_PER_BLOCK; blk < num_blocks; blk += WG_SIZE / THREADS_PER_BLOCK) {
        // Stage this thread's contiguous slice of src1 as f32 once per block,
        // so it can be reused for every output row below.
        var x_vals: array<f32, ELEMS_PER_THREAD>;
        let x_first = src1_idx_base + blk * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            x_vals[e] = f32(src1[x_first + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            // Rows past params.m contribute nothing.
            if (out_row >= params.m) {
                continue;
            }

            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_src0_f16_at(blk_bytes));
            let offset = f32(load_src0_f16_at(blk_bytes + 2u));
            // Index of this thread's first u32 of quants within the block.
            let first_word = thread_within_block * 2u;
            var partial = 0.0;

            for (var w = 0u; w < ELEMS_PER_THREAD / 4u; w++) {
                // Quant words start after the two 2-byte header fields.
                let word = load_src0_u32_at(blk_bytes + 4u + 4u * (first_word + w));
                for (var b = 0u; b < 4u; b++) {
                    let q_val = f32(get_byte_i32(word, b)) * scale + offset;
                    partial += q_val * x_vals[w * 4u + b];
                }
            }
            acc[r] += partial;
        }
    }
#endif
348+ #endif
349+
158350#ifdef USE_SUBGROUP_REDUCTION
159351 for (var row = 0u ; row < OUTPUTS_PER_WG ; row ++ ) {
160352 let subgroup_total = subgroupAdd (acc [row ]);
0 commit comments