@@ -331,19 +331,16 @@ __global__ void kQuantizeBlockwise(
331331 // block_dim=32 (from BLOCK_SIZE=64/NUM_PER_TH=2) is too small. Fall back
332332 // to DIRECT load/store in that case.
333333 static constexpr int THREADS = BLOCK_SIZE / NUM_PER_TH;
334- static constexpr auto LOAD_ALGO = (THREADS >= BNB_WARP_SIZE)
335- ? bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE : bnb_cub::BLOCK_LOAD_DIRECT;
336- static constexpr auto STORE_ALGO = (THREADS >= BNB_WARP_SIZE)
337- ? bnb_cub::BLOCK_STORE_WARP_TRANSPOSE : bnb_cub::BLOCK_STORE_DIRECT;
334+ static constexpr auto LOAD_ALGO =
335+ (THREADS >= BNB_WARP_SIZE) ? bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE : bnb_cub::BLOCK_LOAD_DIRECT;
336+ static constexpr auto STORE_ALGO =
337+ (THREADS >= BNB_WARP_SIZE) ? bnb_cub::BLOCK_STORE_WARP_TRANSPOSE : bnb_cub::BLOCK_STORE_DIRECT;
338338
339339 typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, LOAD_ALGO> LoadT;
340- typedef bnb_cub::BlockStore<
341- unsigned char , THREADS, (DATA_TYPE > 0 ) ? NUM_PER_TH / 2 : NUM_PER_TH,
342- STORE_ALGO>
340+ typedef bnb_cub::BlockStore<unsigned char , THREADS, (DATA_TYPE > 0 ) ? NUM_PER_TH / 2 : NUM_PER_TH, STORE_ALGO>
343341 StoreChar;
344342 typedef bnb_cub::BlockReduce<float , THREADS> BlockReduce;
345- typedef bnb_cub::BlockLoad<float , THREADS, NUM_PER_TH, LOAD_ALGO>
346- LoadFloat;
343+ typedef bnb_cub::BlockLoad<float , THREADS, NUM_PER_TH, LOAD_ALGO> LoadFloat;
347344
348345 __shared__ typename LoadT::TempStorage loadt;
349346 __shared__ typename LoadFloat::TempStorage loadf;
0 commit comments