Skip to content

Commit 2b27016

Browse files
independent logical warp
1 parent 4fd8d6e commit 2b27016

1 file changed

Lines changed: 17 additions & 19 deletions

File tree

csrc/kernels.cu

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -431,62 +431,60 @@ __global__ void kQuantizeBlockwise32(
431431
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
432432
const int rand_offset, const int n
433433
) {
434-
// Fixed parameters for blocksize=32 with 4-bit
435434
constexpr int BLOCK_SIZE = 32; // Size of each quantization block
436435
constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing)
437436
constexpr int THREADS = 32; // Total threads (full warp)
438437
constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block
439438

440-
// Each CUDA thread block processes 2 quantization blocks of 32 values each
441439
const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per CUDA block
442440

443441
T vals[NUM_PER_TH];
444442
unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
445443
float local_abs_max = 0.0f;
446444

447-
// Determine which quantization block this thread belongs to (0 or 1)
448445
const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31
449446
const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)
450447

451448
typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
452449
typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
453-
typedef cub::WarpReduce<float> WarpReduce;
450+
typedef cub::WarpReduce<float, 16>
451+
WarpReduce; // Logical warp size of 16: threads 0-15 and 16-31 reduce independently
454452

455453
__shared__ typename LoadT::TempStorage loadt;
456454
__shared__ typename StoreChar::TempStorage storec;
457-
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One for each warp half
458-
__shared__ float smem_absmax_value[2]; // Store 2 absmax values
455+
__shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp
456+
__shared__ float smem_absmax_value[2];
459457

460458
const int i = base_idx + block_id * BLOCK_SIZE;
459+
// Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative
460+
// operations that require ALL 32 threads to participate
461+
const bool block_valid = (i < n);
461462

462-
// Early exit if this quantization block is out of bounds
463-
if (i >= n)
464-
return;
465-
466-
// Load 64 values total (32 threads × 2 values each)
463+
// All 32 threads participate in the load (out-of-bounds threads get 0.0f)
467464
__syncthreads();
468465
LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);
469466

470-
// Each thread computes max of its NUM_PER_TH values
467+
// Each thread computes max of its values
471468
local_abs_max = -FLT_MAX;
472469
#pragma unroll NUM_PER_TH
473470
for (int j = 0; j < NUM_PER_TH; j++)
474471
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
475472

476-
// Warp-level reduction within each half (threads 0-15 and 16-31 separately)
473+
// Reduce within each logical warp of 16 threads independently
477474
local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX);
478475

479-
// First thread of each warp half stores the absmax
480476
if (local_thread_id == 0) {
481-
smem_absmax_value[block_id] = 1.0f / local_abs_max;
482-
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
477+
if (block_valid) {
478+
smem_absmax_value[block_id] = 1.0f / local_abs_max;
479+
absmax[blockIdx.x * 2 + block_id] = local_abs_max;
480+
} else {
481+
smem_absmax_value[block_id] = 0.0f;
482+
}
483483
}
484484
__syncthreads();
485485

486-
// Broadcast absmax to all threads in each half
487486
local_abs_max = smem_absmax_value[block_id];
488487

489-
// Quantize values based on data type
490488
switch (DATA_TYPE) {
491489
case FP4:
492490
#pragma unroll NUM_PER_TH
@@ -504,7 +502,7 @@ __global__ void kQuantizeBlockwise32(
504502
break;
505503
}
506504

507-
// Store quantized values (all 32 threads write their outputs)
505+
// All 32 threads participate in the store (valid_items limits the actual writes)
508506
__syncthreads();
509507
StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
510508
}

0 commit comments

Comments
 (0)