1414
1515#include " common/common.h"
1616#include " common/recipe/recipe_common.cuh"
17+ #include " common/util/cuda_runtime.h"
1718#include " common/util/ptx.cuh"
1819#include " common/utils.cuh"
1920
@@ -167,6 +168,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
167168 }
168169 }
169170
171+ // Trigger the next kernel here so that its load from global memory can overlap with this kernel's
172+ // store to global memory.
173+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
174+ cudaTriggerProgrammaticLaunchCompletion ();
175+ #endif
176+
170177 // Step 3: Store cast output, Step 4: do transpose within thread tile
171178 OVecCast tmp_output_c;
172179
@@ -390,6 +397,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
390397 }
391398 }
392399
400+ // Trigger the next kernel here so that its load from global memory can overlap with this kernel's
401+ // store to global memory.
402+ #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
403+ cudaTriggerProgrammaticLaunchCompletion ();
404+ #endif
405+
393406 // Step 3: Store cast output, Step 4: do transpose within thread tile
394407 // Edge case: in the non-full tile case, there are three subcases
395408 // for full thread tile, it's the same thing here
@@ -511,6 +524,15 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
511524
512525 const size_t num_blocks_x = DIVUP (row_length, BLOCK_TILE_DIM);
513526 const size_t num_blocks_y = DIVUP (num_rows, BLOCK_TILE_DIM);
527+ dim3 grid (num_blocks_x, num_blocks_y, 1 );
528+ cudaLaunchAttribute attribute[1 ];
529+ attribute[0 ].id = cudaLaunchAttributeProgrammaticStreamSerialization;
530+ attribute[0 ].val .programmaticStreamSerializationAllowed = 1 ;
531+ cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0 , stream, NULL , 0 };
532+ if (transformer_engine::cuda::sm_arch (transformer_engine::cuda::current_device ()) >= 90 ) {
533+ cfg.attrs = attribute;
534+ cfg.numAttrs = 1 ;
535+ }
514536
515537 TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT (
516538 input.dtype , InputType,
@@ -521,7 +543,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
521543 TRANSFORMER_ENGINE_SWITCH_CONDITION (
522544 return_transpose, kReturnTranspose ,
523545
524- dim3 grid (num_blocks_x, num_blocks_y, 1 );
525546 const bool full_tile =
526547 row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0 ;
527548
@@ -531,26 +552,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
531552 tensor_map_output_trans =
532553 get_tensor_map<OutputType>(output_t , num_rows, row_length);
533554 }
534- block_scaled_cast_transpose_kernel<kReturnTranspose , float , InputType, OutputType>
535- <<<grid, THREADS_PER_BLOCK, 0 , stream>>> (
536- reinterpret_cast <const InputType*>(input.dptr ),
537- reinterpret_cast <OutputType*>(output.dptr ),
538- reinterpret_cast <OutputType*>(output_t .dptr ),
539- reinterpret_cast <float *>(scale_inv.dptr ),
540- reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
541- scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
542- tensor_map_output_trans, pow_2_scale);
555+ cudaLaunchKernelEx (&cfg,
556+ block_scaled_cast_transpose_kernel<kReturnTranspose , float ,
557+ InputType, OutputType>,
558+ reinterpret_cast <const InputType*>(input.dptr ),
559+ reinterpret_cast <OutputType*>(output.dptr ),
560+ reinterpret_cast <OutputType*>(output_t .dptr ),
561+ reinterpret_cast <float *>(scale_inv.dptr ),
562+ reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
563+ scale_stride_x, scale_stride_y, scale_t_stride_x,
564+ scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
543565 } else {
544- block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose , float , InputType,
545- OutputType>
546- <<<grid, THREADS_PER_BLOCK, 0 , stream>>> (
547- reinterpret_cast <const InputType*>(input.dptr ),
548- reinterpret_cast <OutputType*>(output.dptr ),
549- reinterpret_cast <OutputType*>(output_t .dptr ),
550- reinterpret_cast <float *>(scale_inv.dptr ),
551- reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
552- scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
553- pow_2_scale);
566+ cudaLaunchKernelEx (
567+ &cfg,
568+ block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose , float ,
569+ InputType, OutputType>,
570+ reinterpret_cast <const InputType*>(input.dptr ),
571+ reinterpret_cast <OutputType*>(output.dptr ),
572+ reinterpret_cast <OutputType*>(output_t .dptr ),
573+ reinterpret_cast <float *>(scale_inv.dptr ),
574+ reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
575+ scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
576+ pow_2_scale);
554577 } // full-tile
555578 ) // return_transpose
556579 ) // OutputType
0 commit comments