@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
4747#endif
4848#elif defined(DATA_A_Q4_0)
4949 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
50- const uint buf_idx = col * SHMEM_STRIDE + 2 * row ;
50+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
5151
5252 const uint ib = idx / 4 ;
5353 const uint iqs = idx & 0x03;
@@ -63,24 +63,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
6363 buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
6464#elif defined(DATA_A_Q4_1)
6565 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
66- const uint buf_idx = col * SHMEM_STRIDE + 2 * row ;
66+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
6767
6868 const uint ib = idx / 4 ;
6969 const uint iqs = idx & 0x03;
7070
71- const float d = float (data_a_packed16[ib].d);
72- const float m = float (data_a_packed16[ib].m);
73- const uint vui = uint (data_a_packed16[ib].qs[2 * iqs]) | (uint (data_a_packed16[ib].qs[2 * iqs + 1 ]) << 16 );
74- const vec4 v0 = vec4 (unpack8(vui & 0x0F0F0F0F)) * d + m;
75- const vec4 v1 = vec4 (unpack8((vui >> 4 ) & 0x0F0F0F0F)) * d + m;
71+ const vec2 dm = vec2 (data_a_packed32[ib].dm);
72+ const uint vui = data_a_packed32[ib].qs[iqs];
73+ const vec4 v0 = vec4 (unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
74+ const vec4 v1 = vec4 (unpack8((vui >> 4 ) & 0x0F0F0F0F)) * dm.x + dm.y;
7675
7776 buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
7877 buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
7978 buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
8079 buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
8180#elif defined(DATA_A_Q5_0)
8281 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
83- const uint buf_idx = col * SHMEM_STRIDE + row;
82+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
8483
8584 const uint ib = idx / 8 ;
8685 const uint iqs = idx & 0x07;
@@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
9796 buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v.yw);
9897#elif defined(DATA_A_Q5_1)
9998 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
100- const uint buf_idx = col * SHMEM_STRIDE + row;
101-
102- const uint ib = idx / 8 ;
103- const uint iqs = idx & 0x07;
104-
105- const float d = float (data_a_packed16[ib].d);
106- const float m = float (data_a_packed16[ib].m);
107- const uint uint_qh = data_a_packed16[ib].qh;
108- const ivec2 qh0 = ivec2 (((uint_qh >> 2 * iqs) << 4 ) & 0x10, (uint_qh >> (2 * iqs + 12 )) & 0x10);
109- const ivec2 qh1 = ivec2 (((uint_qh >> (2 * iqs + 1 )) << 4 ) & 0x10, (uint_qh >> (2 * iqs + 13 )) & 0x10);
99+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
110100
111- const uint vui = uint (data_a_packed16[ib].qs[iqs]) ;
112- const vec4 v = vec4 ((vui & 0xF) | qh0.x, ((vui >> 4 ) & 0xF) | qh0.y, ((vui >> 8 ) & 0xF) | qh1.x, (vui >> 12 ) | qh1.y) * d + m ;
101+ const uint ib = idx / 4 ;
102+ const uint iqs = idx & 0x03 ;
113103
114- buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
115- buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v.yw);
104+ const vec2 dm = vec2 (data_a_packed32[ib].dm);
105+ const uint uint_qh = data_a_packed32[ib].qh;
106+ const uvec2 qh0 = uvec2 (((uint_qh >> 4 * iqs) << 4 ) & 0x10, (uint_qh >> (4 * iqs + 12 )) & 0x10);
107+ const uvec2 qh1 = uvec2 (((uint_qh >> (4 * iqs + 1 )) << 4 ) & 0x10, (uint_qh >> (4 * iqs + 13 )) & 0x10);
108+ const uvec2 qh2 = uvec2 (((uint_qh >> (4 * iqs + 2 )) << 4 ) & 0x10, (uint_qh >> (4 * iqs + 14 )) & 0x10);
109+ const uvec2 qh3 = uvec2 (((uint_qh >> (4 * iqs + 3 )) << 4 ) & 0x10, (uint_qh >> (4 * iqs + 15 )) & 0x10);
110+
111+ const uint vui = data_a_packed32[ib].qs[iqs];
112+ const vec4 v0 = vec4 ((vui & 0xF) | qh0.x, ((vui >> 4 ) & 0xF) | qh0.y, ((vui >> 8 ) & 0xF) | qh1.x, ((vui >> 12 ) & 0xF) | qh1.y) * dm.x + dm.y;
113+ const vec4 v1 = vec4 (((vui >> 16 ) & 0xF) | qh2.x, ((vui >> 20 ) & 0xF) | qh2.y, ((vui >> 24 ) & 0xF) | qh3.x, ((vui >> 28 ) & 0xF) | qh3.y) * dm.x + dm.y;
114+
115+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
116+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v1.xz);
117+ buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v0.yw);
118+ buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.yw);
116119#elif defined(DATA_A_Q8_0)
117120 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
118121 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
@@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
131134 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
132135 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
133136
134- const uint ib = idx / 128 ; // 2 values per idx
135- const uint iqs = idx % 128 ; // 0..127
137+ const uint ib = idx / 64 ; // 4 values per idx
138+ const uint iqs = ( idx % 64 ) * 2 ; // 0,2,4..126
136139
137140 const uint qsi = (iqs / 64 ) * 16 + (iqs % 16 ); // 0..15
138141 const uint scalesi = iqs / 8 ; // 0..15
139142 const uint qsshift = ((iqs % 64 ) / 16 ) * 2 ; // 0,2,4,6
140143
141- const uvec2 qs = uvec2 (unpack8(data_a_packed16 [ib].qs[qsi] ));
144+ const vec4 qs = vec4 (unpack8((data_a_packed32 [ib].qs[qsi / 2 ] >> qsshift) & 0x03030303 ));
142145 const uint scales = data_a[ib].scales[scalesi];
143146 const vec2 dm = vec2 (data_a[ib].dm);
144147
145- const vec2 v = dm.x * float (scales & 0xF) * vec2 ((qs >> qsshift) & 3 ) - dm.y * float (scales >> 4 );
148+ const vec4 v = dm.x * float (scales & 0xF) * qs - dm.y * float (scales >> 4 );
146149
147- buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
150+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
151+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v.zw);
148152#elif defined(DATA_A_Q3_K)
149153 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
150154 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
@@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
173177 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
174178 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
175179
176- const uint ib = idx / 128 ; // 2 values per idx
177- const uint iqs = idx % 128 ; // 0..127
180+ const uint ib = idx / 64 ; // 4 values per idx
181+ const uint iqs = ( idx % 64 ) * 2 ; // 0,2,4..126
178182
179183 const uint n = iqs / 32 ; // 0,1,2,3
180184 const uint b = (iqs % 32 ) / 16 ; // 0,1
@@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
200204 const float d = loadd.x * sc;
201205 const float m = - loadd.y * mbyte;
202206
203- const vec2 q = vec2 (unpack8((uint (data_a_packed16 [ib].qs[qsi / 2 ]) >> (b * 4 )) & 0x0F0F).xy );
207+ const vec4 q = vec4 (unpack8((data_a_packed32 [ib].qs[qsi / 4 ] >> (b * 4 )) & 0x0F0F0F0F) );
204208
205- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
206- fma(d, q.y , m));
209+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
210+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w , m));
207211#elif defined(DATA_A_Q5_K)
208212 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
209213 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
210214
211- const uint ib = idx / 128 ; // 2 values per idx
212- const uint iqs = idx % 128 ; // 0..127
215+ const uint ib = idx / 64 ; // 4 values per idx
216+ const uint iqs = ( idx % 64 ) * 2 ; // 0,2,4..126
213217
214218 const uint n = iqs / 32 ; // 0,1,2,3
215219 const uint b = (iqs % 32 ) / 16 ; // 0,1
@@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
236240 const float d = loadd.x * sc;
237241 const float m = - loadd.y * mbyte;
238242
239- const uint qs = (uint (data_a_packed16 [ib].qs[qsi / 2 ]) >> (b * 4 )) & 0x0F0F ;
240- const uint qh = ((uint (data_a_packed16 [ib].qh[qhi / 2 ]) >> (iqs / 16 )) & 0x0101 ) << 4 ;
241- const vec2 q = vec2 (unpack8(qs | qh).xy );
243+ const uint qs = (data_a_packed32 [ib].qs[qsi / 4 ] >> (b * 4 )) & 0x0F0F0F0F ;
244+ const uint qh = ((data_a_packed32 [ib].qh[qhi / 4 ] >> (iqs / 16 )) & 0x01010101 ) << 4 ;
245+ const vec4 q = vec4 (unpack8(qs | qh));
242246
243- buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
244- fma(d, q.y , m));
247+ buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
248+ buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w , m));
245249#elif defined(DATA_A_Q6_K)
246250 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
247251 const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2 ;
@@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
455459 buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
456460#elif defined(DATA_A_IQ4_NL)
457461 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
458- const uint buf_idx = col * SHMEM_STRIDE + row;
462+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
459463
460464 const uint ib = idx / 8 ;
461465 const uint iqs = idx & 0x07;
@@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
469473 kvalues_iq4nl[vui >> 12 ]);
470474#elif defined(DATA_A_MXFP4)
471475 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
472- const uint buf_idx = col * SHMEM_STRIDE + row;
476+ const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4 ;
473477
474478 const uint ib = idx / 8 ;
475479 const uint iqs = (idx & 0x07) * 2 ;
0 commit comments