@@ -162,11 +162,8 @@ __device__ inline void bitunpack(const T *__restrict packed,
162162 uint64_t chunk_start,
163163 uint32_t chunk_len,
164164 const struct SourceOp &src) {
165- constexpr uint32_t T_BITS = sizeof (T) * 8 ;
166- constexpr uint32_t FL_CHUNK = 1024 ;
167- constexpr uint32_t LANES = FL_CHUNK / T_BITS;
168165 const uint32_t bw = src.params .bitunpack .bit_width ;
169- const uint32_t words_per_block = LANES * bw;
166+ const uint32_t words_per_block = FL_LANES<T> * bw;
170167 const uint32_t elem_off = src.params .bitunpack .element_offset ;
171168 const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
172169 const uint64_t first_block = (chunk_start + elem_off) / FL_CHUNK;
@@ -177,12 +174,77 @@ __device__ inline void bitunpack(const T *__restrict packed,
177174 for (uint32_t c = 0 ; c < n_chunks; ++c) {
178175 const T *src_chunk = packed + (first_block + c) * words_per_block;
179176 T *chunk_dst = dst + c * FL_CHUNK;
180- for (uint32_t lane = threadIdx .x ; lane < LANES ; lane += blockDim .x ) {
177+ for (uint32_t lane = threadIdx .x ; lane < FL_LANES<T> ; lane += blockDim .x ) {
181178 bit_unpack_lane<T>(src_chunk, chunk_dst, 0 , lane, bw);
182179 }
183180 }
184181}
185182
// ═══════════════════════════════════════════════════════════════════════════
// Patches
// ═══════════════════════════════════════════════════════════════════════════

/// Parsed view into a packed patches buffer (the fused-dispatch counterpart
/// of GPUPatches, which is used by the standalone per-bitwidth kernels).
/// Each op with patches gets its own contiguous device allocation holding
/// chunk_offsets, indices, and values, referenced by a single uint64_t
/// pointer (patches_ptr in BitunpackParams); see PackedPatchesHeader in
/// patches.h for the layout.
///
/// All pointers alias into that single allocation; the view owns nothing.
/// Field order matters: parse_patches() aggregate-initializes this struct.
template <typename T>
struct PackedPatchesView {
  const uint32_t *chunk_offsets; // n_chunks+1 entries; entry [c] is the start of
                                 // chunk c's patch run, [c+1] is its end (sentinel)
  uint32_t n_chunks;             // number of chunks described by chunk_offsets
  const uint16_t *indices;       // within-chunk positions (0–1023)
  const T *values;               // exception values, parallel to `indices`
};
200+
/// Parse a packed patches buffer into its component arrays.
///
/// `patches_ptr` is the device address of a PackedPatchesHeader followed by
/// the chunk_offsets / indices / values arrays at the byte offsets the header
/// records (see patches.h). Returns a non-owning view into that allocation.
template <typename T>
__device__ inline PackedPatchesView<T> parse_patches(uint64_t patches_ptr) {
  const auto *bytes = reinterpret_cast<const uint8_t *>(patches_ptr);
  const auto *hdr = reinterpret_cast<const PackedPatchesHeader *>(bytes);

  PackedPatchesView<T> view;
  // chunk_offsets sits immediately after the header; the other two arrays
  // are located via byte offsets stored in the header itself.
  view.chunk_offsets =
      reinterpret_cast<const uint32_t *>(bytes + sizeof(PackedPatchesHeader));
  view.n_chunks = hdr->n_chunks;
  view.indices = reinterpret_cast<const uint16_t *>(bytes + hdr->indices_byte_offset);
  view.values = reinterpret_cast<const T *>(bytes + hdr->values_byte_offset);
  return view;
}
213+
/// Overwrite exception positions in `out` for a single chunk.
/// All threads in the block cooperate (each thread strides the chunk's patch
/// run by blockDim.x). Caller must issue __syncthreads() afterward if other
/// threads read from `out`.
///
/// \param patches_ptr  device address of a packed patches buffer
///                     (PackedPatchesHeader layout; see patches.h)
/// \param out          chunk-local output base; `indices` are relative to it
/// \param chunk        chunk ordinal within the patches buffer
template <typename T>
__device__ __noinline__ void apply_patches(uint64_t patches_ptr, T *__restrict out, uint32_t chunk) {
  const auto patches = parse_patches<T>(patches_ptr);
  // `chunk < n_chunks` rather than `chunk + 1 <= n_chunks`: the latter is
  // defeated by unsigned wraparound (chunk == UINT32_MAX makes chunk + 1 == 0,
  // so the bound check would falsely pass and chunk_offsets would be read
  // out of bounds below).
  assert(chunk < patches.n_chunks);
  const uint32_t start = patches.chunk_offsets[chunk];
  const uint32_t end = patches.chunk_offsets[chunk + 1]; // sentinel entry
  for (uint32_t i = start + threadIdx.x; i < end; i += blockDim.x) {
    out[patches.indices[i]] = patches.values[i];
  }
}
227+
/// Overwrite exception positions in `out` for a range of chunks.
/// All threads in the block cooperate. Caller must issue __syncthreads()
/// afterward if other threads read from `out`.
///
/// NOTE(review): chunk c is addressed as `out + c * FL_CHUNK` (the same
/// convention bitunpack uses for its destination), and first-chunk indices
/// are not adjusted by `element_offset % FL_CHUNK` — presumably callers
/// decompress whole chunks into `out`; confirm against the dispatch loop.
template <typename T>
__device__ __noinline__ void
apply_patches_range(uint64_t patches_ptr, T *__restrict out, uint32_t stage_len, uint32_t element_offset) {
  const auto view = parse_patches<T>(patches_ptr);
  const uint32_t base_chunk = element_offset / FL_CHUNK;
  const uint32_t misalign = element_offset % FL_CHUNK;
  // Ceil-divide so a stage that straddles chunk boundaries covers them all.
  const uint32_t span = (stage_len + misalign + FL_CHUNK - 1) / FL_CHUNK;
  assert(base_chunk + span <= view.n_chunks);

  for (uint32_t rel = 0; rel < span; ++rel) {
    // chunk_offsets carries a trailing sentinel, so offs[1] is always valid.
    const uint32_t *offs = view.chunk_offsets + base_chunk + rel;
    T *dst = out + rel * FL_CHUNK;
    for (uint32_t p = offs[0] + threadIdx.x; p < offs[1]; p += blockDim.x) {
      dst[view.indices[p]] = view.values[p];
    }
  }
}
247+
/// Read N values from a source op into `out`.
///
/// Dispatches on `src.op_code` to handle each encoding:
0 commit comments