@@ -162,11 +162,8 @@ __device__ inline void bitunpack(const T *__restrict packed,
162162 uint64_t chunk_start,
163163 uint32_t chunk_len,
164164 const struct SourceOp &src) {
165- constexpr uint32_t T_BITS = sizeof (T) * 8 ;
166- constexpr uint32_t FL_CHUNK = 1024 ;
167- constexpr uint32_t LANES = FL_CHUNK / T_BITS;
168165 const uint32_t bw = src.params .bitunpack .bit_width ;
169- const uint32_t words_per_block = LANES * bw;
166+ const uint32_t words_per_block = FL_LANES<T> * bw;
170167 const uint32_t elem_off = src.params .bitunpack .element_offset ;
171168 const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
172169 const uint64_t first_block = (chunk_start + elem_off) / FL_CHUNK;
@@ -177,12 +174,86 @@ __device__ inline void bitunpack(const T *__restrict packed,
177174 for (uint32_t c = 0 ; c < n_chunks; ++c) {
178175 const T *src_chunk = packed + (first_block + c) * words_per_block;
179176 T *chunk_dst = dst + c * FL_CHUNK;
180- for (uint32_t lane = threadIdx .x ; lane < LANES ; lane += blockDim .x ) {
177+ for (uint32_t lane = threadIdx .x ; lane < FL_LANES<T> ; lane += blockDim .x ) {
181178 bit_unpack_lane<T>(src_chunk, chunk_dst, 0 , lane, bw);
182179 }
183180 }
184181}
185182
183+ // ═══════════════════════════════════════════════════════════════════════════
184+ // Patches
185+ // ═══════════════════════════════════════════════════════════════════════════
186+
/// Parsed view into a packed patches buffer (the fused-dispatch counterpart
/// of GPUPatches, which is used by the standalone per-bitwidth kernels).
/// Each op with patches gets its own contiguous device allocation holding
/// lane_offsets, indices, and values, referenced by a single uint64_t pointer
/// (patches_ptr in BitunpackParams / AlpParams); see PackedPatchesHeader in
/// patches.h for the layout.
template <typename T>
struct PackedPatchesView {
  // CSR-style offsets into indices/values; slot (chunk * FL_LANES + lane)
  // owns the half-open range [lane_offsets[slot], lane_offsets[slot + 1]).
  const uint32_t *lane_offsets;
  // Entry count of lane_offsets (used only for bounds asserts).
  uint32_t num_lane_offsets;
  // Chunk-relative element positions to overwrite.
  const uint16_t *indices;
  // Replacement values, parallel to indices.
  const T *values;
};
200+
/// Parse a packed patches buffer into its component arrays.
///
/// `patches_ptr` is the device address of a PackedPatchesHeader immediately
/// followed by the lane_offsets array; indices and values live at the byte
/// offsets recorded in the header. num_lane_offsets is derived from the gap
/// between the header and indices_byte_offset — NOTE(review): this assumes
/// the packer places lane_offsets flush against the header with no padding;
/// confirm against the layout in patches.h.
template <typename T>
__device__ inline PackedPatchesView<T> parse_patches(uint64_t patches_ptr) {
  const auto *bytes = reinterpret_cast<const uint8_t *>(patches_ptr);
  const auto *hdr = reinterpret_cast<const PackedPatchesHeader *>(bytes);

  PackedPatchesView<T> view;
  view.lane_offsets =
      reinterpret_cast<const uint32_t *>(bytes + sizeof(PackedPatchesHeader));
  view.num_lane_offsets = static_cast<uint32_t>(
      (hdr->indices_byte_offset - sizeof(PackedPatchesHeader)) / sizeof(uint32_t));
  view.indices = reinterpret_cast<const uint16_t *>(bytes + hdr->indices_byte_offset);
  view.values = reinterpret_cast<const T *>(bytes + hdr->values_byte_offset);
  return view;
}
213+
/// Apply source patches for a single FL chunk.
///
/// Lanes are distributed across the block's threads; each lane reads its
/// half-open exception range from lane_offsets for slot
/// `fl_chunk * FL_LANES<T> + lane` and overwrites the chunk-relative
/// positions in `out` with the patch values. Ends with a block-wide
/// __syncthreads() so callers may read `out` immediately after; must
/// therefore be reached by every thread of the block.
template <typename T>
__device__ inline void patch_fl_chunk(uint64_t patches_ptr, T *__restrict out, uint32_t fl_chunk) {
  const PackedPatchesView<T> view = parse_patches<T>(patches_ptr);
  const uint32_t slot_base = fl_chunk * FL_LANES<T>;

  for (uint32_t lane = threadIdx.x; lane < FL_LANES<T>; lane += blockDim.x) {
    const uint32_t slot = slot_base + lane;
    assert(slot + 1 < view.num_lane_offsets);
    const uint32_t first = view.lane_offsets[slot];
    const uint32_t last = view.lane_offsets[slot + 1];
    for (uint32_t i = first; i < last; ++i) {
      out[view.indices[i]] = view.values[i];
    }
  }
  // Write barrier: all patched positions visible block-wide.
  __syncthreads();
}
232+
/// Apply source patches for all FL chunks in a contiguous region.
///
/// `element_offset` locates the region within the encoded stream: slot
/// lookup starts at chunk `element_offset / FL_CHUNK`, and the chunk count
/// is rounded up so a region of `stage_len` elements beginning
/// `element_offset % FL_CHUNK` into its first chunk is fully covered.
/// Patch indices are chunk-relative and applied at `out + c * FL_CHUNK`.
/// Ends with a block-wide __syncthreads(); must be reached by every thread
/// of the block.
template <typename T>
__device__ inline void
patch_all_fl_chunks(uint64_t patches_ptr, T *__restrict out, uint32_t stage_len, uint32_t element_offset) {
  const PackedPatchesView<T> view = parse_patches<T>(patches_ptr);

  const uint32_t base_chunk = element_offset / FL_CHUNK;
  const uint32_t lead_in = element_offset % FL_CHUNK;
  const uint32_t chunk_count = (stage_len + lead_in + FL_CHUNK - 1) / FL_CHUNK;

  for (uint32_t c = 0; c < chunk_count; ++c) {
    T *chunk_out = out + c * FL_CHUNK;
    const uint32_t slot_base = (base_chunk + c) * FL_LANES<T>;
    for (uint32_t lane = threadIdx.x; lane < FL_LANES<T>; lane += blockDim.x) {
      const uint32_t slot = slot_base + lane;
      assert(slot + 1 < view.num_lane_offsets);
      for (uint32_t i = view.lane_offsets[slot]; i < view.lane_offsets[slot + 1]; ++i) {
        chunk_out[view.indices[i]] = view.values[i];
      }
    }
  }
  // Write barrier: all patched positions visible block-wide.
  __syncthreads();
}
256+
186257// / Read N values from a source op into `out`.
187258// /
188259// / Dispatches on `src.op_code` to handle each encoding:
@@ -313,11 +384,17 @@ __device__ void execute_output_stage(T *__restrict output,
313384 block_start + elem_idx,
314385 chunk_len,
315386 src);
316- constexpr uint32_t FL_CHUNK = 1024 ; // FastLanes chunk size
317387 const uint32_t align = (block_start + elem_idx + src.params .bitunpack .element_offset ) % FL_CHUNK;
318388 smem_src = scratch + align;
319389 // Write barrier: all threads finished bitunpack, safe to read from scratch.
320390 __syncthreads ();
391+
392+ // Overwrite patched positions in the decoded scratch buffer.
393+ if (src.params .bitunpack .patches_ptr != 0 ) {
394+ const uint32_t fl_chunk = static_cast <uint32_t >(
395+ (block_start + elem_idx + src.params .bitunpack .element_offset ) / FL_CHUNK);
396+ patch_fl_chunk<T>(src.params .bitunpack .patches_ptr , scratch, fl_chunk);
397+ }
321398 } else {
322399 chunk_len = block_len;
323400 }
@@ -392,12 +469,22 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
392469 const auto &src = stage.source ;
393470
394471 if (src.op_code == SourceOp::BITUNPACK) {
472+ T *raw_smem = smem_out;
395473 bitunpack<T>(reinterpret_cast <const T *>(stage.input_ptr ), smem_out, 0 , stage.len , src);
396- smem_out += src.params .bitunpack .element_offset % SMEM_TILE_SIZE;
397474 // Write barrier: cooperative bitunpack finished, safe to read
398475 // decoded elements in the scalar-op loop below.
399476 __syncthreads ();
400477
478+ // Overwrite exception positions in the decoded buffer with patch values.
479+ if (src.params .bitunpack .patches_ptr != 0 ) {
480+ patch_all_fl_chunks<T>(src.params .bitunpack .patches_ptr ,
481+ raw_smem,
482+ stage.len ,
483+ src.params .bitunpack .element_offset );
484+ }
485+
486+ smem_out += src.params .bitunpack .element_offset % SMEM_TILE_SIZE;
487+
401488 if (stage.num_scalar_ops > 0 ) {
402489 for (uint32_t i = threadIdx .x ; i < stage.len ; i += blockDim .x ) {
403490 T val = smem_out[i];
0 commit comments