@@ -431,62 +431,60 @@ __global__ void kQuantizeBlockwise32(
431431 float * code, T* __restrict__ const A, float * absmax, unsigned char * out, float * __restrict__ const rand,
432432 const int rand_offset, const int n
433433) {
434- // Fixed parameters for blocksize=32 with 4-bit
435434 constexpr int BLOCK_SIZE = 32 ; // Size of each quantization block
436435 constexpr int NUM_PER_TH = 2 ; // Values per thread (for 4-bit packing)
437436 constexpr int THREADS = 32 ; // Total threads (full warp)
438437 constexpr int THREADS_PER_BLOCK = 16 ; // Threads handling each quantization block
439438
440- // Each CUDA thread block processes 2 quantization blocks of 32 values each
441439 const int base_idx = blockIdx .x * BLOCK_SIZE * 2 ; // 2 blocks per CUDA block
442440
443441 T vals[NUM_PER_TH];
444442 unsigned char qvals[NUM_PER_TH / 2 ]; // For 4-bit: 2 values per byte
445443 float local_abs_max = 0 .0f ;
446444
447- // Determine which quantization block this thread belongs to (0 or 1)
448445 const int block_id = threadIdx .x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
449446 const int local_thread_id = threadIdx .x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
450447
451448 typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
452449 typedef cub::BlockStore<unsigned char , THREADS, NUM_PER_TH / 2 , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
453- typedef cub::WarpReduce<float > WarpReduce;
450+ typedef cub::WarpReduce<float , 16 >
451+ WarpReduce; // Logical warp size of 16: threads 0-15 and 16-31 reduce independently
454452
455453 __shared__ typename LoadT::TempStorage loadt;
456454 __shared__ typename StoreChar::TempStorage storec;
457- __shared__ typename WarpReduce::TempStorage warp_reduce[2 ]; // One for each warp half
458- __shared__ float smem_absmax_value[2 ]; // Store 2 absmax values
455+ __shared__ typename WarpReduce::TempStorage warp_reduce[2 ]; // One per logical warp
456+ __shared__ float smem_absmax_value[2 ];
459457
460458 const int i = base_idx + block_id * BLOCK_SIZE;
459+ // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
460+ // operations that require ALL 32 threads to participate
461+ const bool block_valid = (i < n);
461462
462- // Early exit if this quantization block is out of bounds
463- if (i >= n)
464- return ;
465-
466- // Load 64 values total (32 threads × 2 values each)
463+ // All 32 threads participate in the load (elements past the end of the input are filled with 0.0f)
467464 __syncthreads ();
468465 LoadT (loadt).Load (&(A[base_idx]), vals, min (BLOCK_SIZE * 2 , n - base_idx), (T)0 .0f );
469466
470- // Each thread computes max of its NUM_PER_TH values
467+ // Each thread computes the max absolute value of its NUM_PER_TH values
471468 local_abs_max = -FLT_MAX;
472469#pragma unroll NUM_PER_TH
473470 for (int j = 0 ; j < NUM_PER_TH; j++)
474471 local_abs_max = fmaxf (local_abs_max, fabsf ((float )vals[j]));
475472
476- // Warp-level reduction within each half (threads 0-15 and 16-31 separately)
473+ // Reduce within each logical warp of 16 threads independently
477474 local_abs_max = WarpReduce (warp_reduce[block_id]).Reduce (local_abs_max, CUB_REDUCTIONOP_MAX);
478475
479- // First thread of each warp half stores the absmax
480476 if (local_thread_id == 0 ) {
481- smem_absmax_value[block_id] = 1 .0f / local_abs_max;
482- absmax[blockIdx .x * 2 + block_id] = local_abs_max;
477+ if (block_valid) {
478+ smem_absmax_value[block_id] = 1 .0f / local_abs_max;
479+ absmax[blockIdx .x * 2 + block_id] = local_abs_max;
480+ } else {
481+ smem_absmax_value[block_id] = 0 .0f ;
482+ }
483483 }
484484 __syncthreads ();
485485
486- // Broadcast absmax to all threads in each half
487486 local_abs_max = smem_absmax_value[block_id];
488487
489- // Quantize values based on data type
490488 switch (DATA_TYPE) {
491489 case FP4:
492490#pragma unroll NUM_PER_TH
@@ -504,7 +502,7 @@ __global__ void kQuantizeBlockwise32(
504502 break ;
505503 }
506504
507- // Store quantized values (all 32 threads write their outputs )
505+ // All 32 threads participate in the store (valid_items limits the actual writes)
508506 __syncthreads ();
509507 StoreChar (storec).Store (&(out[base_idx / 2 ]), qvals, min ((BLOCK_SIZE * 2 + 1 ) / 2 , (n - base_idx + 1 ) / 2 ));
510508}
0 commit comments