Skip to content

Commit ebca615

Browse files
authored
[Common] PDL for Blockwise Quantization (NVIDIA#2066)
* enable PDL for blockwise quantization kernels Signed-off-by: Xin Yao <xiny@nvidia.com> * add comment Signed-off-by: Xin Yao <xiny@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent ec65ba3 commit ebca615

2 files changed

Lines changed: 84 additions & 33 deletions

File tree

transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "common/common.h"
1616
#include "common/recipe/recipe_common.cuh"
17+
#include "common/util/cuda_runtime.h"
1718
#include "common/util/ptx.cuh"
1819
#include "common/utils.cuh"
1920

@@ -167,6 +168,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
167168
}
168169
}
169170

171+
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
172+
// store to global memory.
173+
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
174+
cudaTriggerProgrammaticLaunchCompletion();
175+
#endif
176+
170177
// Step 3: Store cast output, Step 4: do transpose within thread tile
171178
OVecCast tmp_output_c;
172179

@@ -390,6 +397,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
390397
}
391398
}
392399

400+
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
401+
// store to global memory.
402+
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
403+
cudaTriggerProgrammaticLaunchCompletion();
404+
#endif
405+
393406
// Step 3: Store cast output, Step 4: do transpose within thread tile
394407
// Edge case: in the non-full tile case, there are three subcases
395408
// for full thread tile, it's the same thing here
@@ -511,6 +524,15 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
511524

512525
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
513526
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
527+
dim3 grid(num_blocks_x, num_blocks_y, 1);
528+
cudaLaunchAttribute attribute[1];
529+
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
530+
attribute[0].val.programmaticStreamSerializationAllowed = 1;
531+
cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
532+
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
533+
cfg.attrs = attribute;
534+
cfg.numAttrs = 1;
535+
}
514536

515537
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
516538
input.dtype, InputType,
@@ -521,7 +543,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
521543
TRANSFORMER_ENGINE_SWITCH_CONDITION(
522544
return_transpose, kReturnTranspose,
523545

524-
dim3 grid(num_blocks_x, num_blocks_y, 1);
525546
const bool full_tile =
526547
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
527548

@@ -531,26 +552,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
531552
tensor_map_output_trans =
532553
get_tensor_map<OutputType>(output_t, num_rows, row_length);
533554
}
534-
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
535-
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
536-
reinterpret_cast<const InputType*>(input.dptr),
537-
reinterpret_cast<OutputType*>(output.dptr),
538-
reinterpret_cast<OutputType*>(output_t.dptr),
539-
reinterpret_cast<float*>(scale_inv.dptr),
540-
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
541-
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
542-
tensor_map_output_trans, pow_2_scale);
555+
cudaLaunchKernelEx(&cfg,
556+
block_scaled_cast_transpose_kernel<kReturnTranspose, float,
557+
InputType, OutputType>,
558+
reinterpret_cast<const InputType*>(input.dptr),
559+
reinterpret_cast<OutputType*>(output.dptr),
560+
reinterpret_cast<OutputType*>(output_t.dptr),
561+
reinterpret_cast<float*>(scale_inv.dptr),
562+
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
563+
scale_stride_x, scale_stride_y, scale_t_stride_x,
564+
scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
543565
} else {
544-
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
545-
OutputType>
546-
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
547-
reinterpret_cast<const InputType*>(input.dptr),
548-
reinterpret_cast<OutputType*>(output.dptr),
549-
reinterpret_cast<OutputType*>(output_t.dptr),
550-
reinterpret_cast<float*>(scale_inv.dptr),
551-
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
552-
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
553-
pow_2_scale);
566+
cudaLaunchKernelEx(
567+
&cfg,
568+
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
569+
InputType, OutputType>,
570+
reinterpret_cast<const InputType*>(input.dptr),
571+
reinterpret_cast<OutputType*>(output.dptr),
572+
reinterpret_cast<OutputType*>(output_t.dptr),
573+
reinterpret_cast<float*>(scale_inv.dptr),
574+
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
575+
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
576+
pow_2_scale);
554577
} // full-tile
555578
) // return_transpose
556579
) // OutputType

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "common/common.h"
1818
#include "common/recipe/recipe_common.cuh"
1919
#include "common/transpose/cast_transpose.h"
20+
#include "common/util/cuda_runtime.h"
2021
#include "common/utils.cuh"
2122

2223
namespace transformer_engine {
@@ -234,6 +235,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
234235

235236
__syncthreads();
236237

238+
// If not returning columnwise, we trigger the next kernel here so that its load from global memory
239+
// can overlap with this kernel's return rowwise.
240+
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
241+
if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
242+
cudaTriggerProgrammaticLaunchCompletion();
243+
}
244+
#endif
245+
237246
// Step 2: Cast and store to output_c
238247
if (return_rowwise) {
239248
constexpr int r_stride =
@@ -325,6 +334,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
325334
}
326335
}
327336

337+
// If returning columnwise, we trigger the next kernel here so that its load from global memory
338+
// can overlap with this kernel's return columnwise.
339+
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
340+
if (return_columnwise_gemm_ready || return_columnwise_compact) {
341+
cudaTriggerProgrammaticLaunchCompletion();
342+
}
343+
#endif
344+
328345
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
329346
if (return_columnwise_gemm_ready) {
330347
constexpr int c_stride =
@@ -584,38 +601,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
584601

585602
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
586603
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
604+
dim3 grid(num_blocks_x, num_blocks_y, 1);
605+
cudaLaunchAttribute attribute[1];
606+
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
607+
attribute[0].val.programmaticStreamSerializationAllowed = 1;
587608

588609
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
589610
input.dtype, InputType,
590611

591612
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
592613
output.dtype, OutputType,
593614

594-
dim3 grid(num_blocks_x, num_blocks_y, 1);
595-
596615
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
597616

598617
TRANSFORMER_ENGINE_SWITCH_CONDITION(
599618
full_tile, kAligned,
600619

601620
size_t smem_bytes = kSMemSize * sizeof(InputType);
621+
622+
cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
623+
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
624+
90) {
625+
cfg.attrs = attribute;
626+
cfg.numAttrs = 1;
627+
}
602628
// shared memory must be requested up
603629
if (smem_bytes >= 48 * 1024) {
604630
cudaError_t err = cudaFuncSetAttribute(
605631
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
606632
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
607633
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
608-
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
609-
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
610-
reinterpret_cast<const InputType*>(input.dptr),
611-
reinterpret_cast<OutputType*>(output.dptr),
612-
reinterpret_cast<OutputType*>(output_t.dptr),
613-
reinterpret_cast<float*>(scale_inv.dptr),
614-
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
615-
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
616-
columnwise_option, pow2_scale);) // kAligned
617-
) // OutputType
618-
) // InputType
634+
} cudaLaunchKernelEx(&cfg,
635+
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
636+
OutputType>,
637+
reinterpret_cast<const InputType*>(input.dptr),
638+
reinterpret_cast<OutputType*>(output.dptr),
639+
reinterpret_cast<OutputType*>(output_t.dptr),
640+
reinterpret_cast<float*>(scale_inv.dptr),
641+
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
642+
scale_stride_x, scale_stride_y, scale_t_stride_x,
643+
scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
644+
pow2_scale);) // kAligned
645+
) // OutputType
646+
) // InputType
619647
NVTE_CHECK_CUDA(cudaGetLastError());
620648
}
621649

0 commit comments

Comments
 (0)