Skip to content

Commit 5d947a0

Browse files
authored
Fix the race in the dbias computation in MXFP8 quantization and grouped quantization kernel (NVIDIA#2921)
Fix the race in the dbias computation

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
1 parent 0c2e7b0 commit 5d947a0

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -713,6 +713,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
713713
if constexpr (COLWISE_SCALING) {
714714
thread_partial_dbias = partial_dbias_colwise;
715715
} else {
716+
ptx::cp_async_bulk_wait_group_read<0>();
717+
__syncthreads();
716718
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
717719

718720
constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);

transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -498,6 +498,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
498498
if constexpr (COLWISE_SCALING) {
499499
thread_partial_dbias = partial_dbias_colwise;
500500
} else {
501+
ptx::cp_async_bulk_wait_group_read<0>();
502+
__syncthreads();
501503
// Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
502504
// HEIGHT = THREADS_Y
503505
// WIDTH = THREADS_X * (SCALE_DIM_X + 1)

0 commit comments

Comments
 (0)