@@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
6161#endif // INIT_SRC1_SHMEM_FLOAT
6262#endif
6363
#ifdef INIT_SRC0_SHMEM_Q1_0
// Q1_0 block layout, as read by init_shmem_src0 below: each block covers
// BLOCK_SIZE weights and occupies BLOCK_SIZE_BYTES bytes -- a 2-byte scale `d`
// at offset 0, followed by one sign bit per weight (BLOCK_SIZE / 8 = 16 bytes).
const BLOCK_SIZE = 128u;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration

// Dequantizes this workgroup's tile of src0 into shmem.
//
// Each loop iteration expands one sign byte into NQ consecutive shmem slots
// starting at index i: bit b of the byte (least-significant bit first) selects
// +d when set and -d when clear.
//
//   thread_id    - index of the calling thread; selects the first group of NQ slots
//   batch_offset - offset (in blocks) of the current batch within src0
//   offset_m     - first src0 row covered by this tile
//   k_outer      - first k position covered by this tile
//
// NOTE(review): assumes TILE_K and k_outer are multiples of NQ so that a group
// of 8 weights never straddles a byte or a tile row -- confirm against the host
// side that picks the tile sizes.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        let tile_m = i / TILE_K;
        let tile_k_start = i % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k_start = k_outer + tile_k_start;

        // tile_m never decreases as i grows, so once one row is past params.m
        // every later iteration of this thread is as well.
        if (global_m >= params.m) {
            break;
        }

        // The whole group lies past the end of the row: do not touch src0 at
        // all (the block index would point past the row's last block) and
        // clear the slots so values left over from an earlier tile cannot
        // take part in the reduction over k.
        if (global_k_start >= params.k) {
            for (var bit = 0u; bit < NQ; bit++) {
                shmem[i + bit] = 0.0;
            }
            continue;
        }

        let block_k = global_k_start / BLOCK_SIZE;
        let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
        let d = load_f16_at_src0(block_byte_base);
        // Sign bits start right after the 2-byte scale; keep only the low byte
        // of the loaded word.
        let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;

        for (var bit = 0u; bit < NQ; bit++) {
            let signed_d = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
            // Positions past params.k inside a partially valid group are zeroed
            // rather than left holding stale data.
            shmem[i + bit] = select(0.0, signed_d, global_k_start + bit < params.k);
        }
    }
}
#endif // INIT_SRC0_SHMEM_Q1_0
96+
6497#ifdef INIT_SRC0_SHMEM_Q4_0
6598const BLOCK_SIZE = 32u;
6699const BLOCK_SIZE_BYTES = 18u;
0 commit comments