Skip to content

Commit 10a782d

Browse files
authored
Fix template parameter macros (#3305)
Some of the device implementation templates have macros like GridwiseGemmMultiABDTemplateParameters that can cause build errors if multiple files are included together. This error comes up with our builder code. To clean up the macros and make them safer, we follow these rules: * Use more specific names to avoid duplication. * Undefine the macro after it is used to avoid leaking out of the file scope. * Use a prefix CK_ on the macro to avoid conflicting with other libraries. * Use all caps with underscores for preprocessor macro names.
1 parent 35a4b26 commit 10a782d

4 files changed

Lines changed: 32 additions & 22 deletions

include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
446446
using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
447447
using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
448448

449-
#define GridwiseGemmMultiABDTemplateParameters \
449+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \
450450
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
451451
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
452452
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -462,7 +462,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
462462
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
463463
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
464464

465-
#define GridwiseGemmTemplateParameters \
465+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \
466466
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
467467
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
468468
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -480,8 +480,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
480480
template <index_t NXdlPerWave_>
481481
using GridwiseGemmBase = ck::conditional_t<
482482
isMultiA || isMultiB,
483-
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
484-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
483+
GridwiseGemmMultipleABD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS>,
484+
GridwiseGemmMultipleD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS>>;
485+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS
486+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS
485487
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
486488
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
487489

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
439439
}
440440

441441
// GridwiseGemm
442-
#define GridwiseGemmMultiDTemplateParams \
442+
#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \
443443
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
444444
AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
445445
MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
@@ -454,7 +454,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
454454
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
455455
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
456456

457-
#define GridwiseGemmCTransposeTemplateParameters \
457+
#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \
458458
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
459459
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
460460
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \
@@ -470,10 +470,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
470470
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
471471

472472
template <index_t NXdlPerWave_>
473-
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
473+
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
474+
CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>;
474475
template <index_t NXdlPerWave_>
475-
using GridwiseGemmCTransposeBase =
476-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
476+
using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
477+
CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>;
478+
#undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS
479+
#undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS
477480
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
478481
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
479482

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
485485
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
486486
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
487487

488-
#define GridwiseGemmMultiABDTemplateParameters \
488+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
489489
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
490490
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
491491
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -502,7 +502,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
502502
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
503503
BComputeDataType
504504

505-
#define GridwiseGemmTemplateParameters \
505+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
506506
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
507507
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
508508
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -518,7 +518,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
518518
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
519519
BComputeDataType, DoElementwiseBeforeCShuffle
520520

521-
#define GridwiseGemmCTransposeTemplateParameters \
521+
#define CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
522522
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
523523
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
524524
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
@@ -536,14 +536,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
536536

537537
// Use appropriate gridwise gemm
538538
template <index_t NXdlPerWave_>
539-
using GridwiseGemmMultipleABDBase =
540-
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>;
539+
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
540+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
541541
template <index_t NXdlPerWave_>
542-
using GridwiseGemmMultipleDBase =
543-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
542+
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
543+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
544544
template <index_t NXdlPerWave_>
545-
using GridwiseGemmMultipleDCTransposeBase =
546-
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
545+
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
546+
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
547+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
548+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
549+
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
547550

548551
using GridwiseGemm64 =
549552
std::conditional_t<isMultiA || isMultiB,

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
405405
is_split_valid);
406406
}
407407

408-
#define GridwiseGemmTemplateParameters \
408+
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
409409
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
410410
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
411411
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -422,9 +422,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
422422
AComputeDataType, DoElementwiseBeforeCShuffle
423423
// Use appropriate gridwise gemm
424424
template <index_t NXdlPerWave_>
425-
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
426-
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
427-
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
425+
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
426+
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
427+
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
428+
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
429+
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
428430

429431
// desc for blockwise copy
430432
using AGridDesc_AK0_M_AK1 =

0 commit comments

Comments
 (0)