|
| 1 | +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. |
| 2 | +// SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +#include <iostream> |
| 5 | +#include <numeric> |
| 6 | +#include <initializer_list> |
| 7 | +#include <cstdlib> |
| 8 | + |
| 9 | +#include "ck/ck.hpp" |
| 10 | +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" |
| 11 | +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" |
| 12 | +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp" |
| 13 | +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" |
| 14 | +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" |
| 15 | + |
| 16 | +#include <ck/utility/data_type.hpp> |
| 17 | +#include <ck/utility/tuple.hpp> |
| 18 | + |
| 19 | +#include "ck/library/utility/check_err.hpp" |
| 20 | +#include "ck/library/utility/device_memory.hpp" |
| 21 | +#include "ck/library/utility/host_tensor.hpp" |
| 22 | +#include "ck/library/utility/host_tensor_generator.hpp" |
| 23 | +#include "ck/library/utility/literals.hpp" |
| 24 | +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" |
| 25 | + |
| 26 | +using ::ck::DeviceMem; |
| 27 | +using ::ck::hip_check_error; |
| 28 | +using ::ck::HostTensorDescriptor; |
| 29 | +using ::ck::Tensor; |
| 30 | + |
| 31 | +template <ck::index_t... Is> |
| 32 | +using S = ck::Sequence<Is...>; |
| 33 | + |
| 34 | +using F16 = ck::half_t; |
| 35 | +using F32 = float; |
| 36 | + |
| 37 | +using Row = ck::tensor_layout::gemm::RowMajor; |
| 38 | +using Col = ck::tensor_layout::gemm::ColumnMajor; |
| 39 | + |
| 40 | +using PassThrough = ck::tensor_operation::element_wise::PassThrough; |
| 41 | +using AddAdd = ck::tensor_operation::element_wise::AddAdd; |
| 42 | + |
| 43 | +using ADataType = F16; |
| 44 | +using BDataType = F16; |
| 45 | +using AccDataType = F32; |
| 46 | +using CShuffleDataType = F32; |
| 47 | +using DDataType = F16; |
| 48 | +using DsDataType = ck::Tuple<DDataType, DDataType>; |
| 49 | +using EDataType = F16; |
| 50 | + |
| 51 | +using ALayout = Row; |
| 52 | +using BLayout = Col; |
| 53 | +using DLayout = Row; |
| 54 | +using DsLayout = ck::Tuple<DLayout, DLayout>; |
| 55 | +using ELayout = Row; |
| 56 | + |
| 57 | +using AElementOp = PassThrough; |
| 58 | +using BElementOp = PassThrough; |
| 59 | +using CDEElementOp = AddAdd; |
| 60 | + |
| 61 | +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; |
| 62 | +static constexpr int NumDs = 2; |
| 63 | + |
| 64 | +using DeviceGemmInstance = |
| 65 | + ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 |
| 66 | + // clang-format off |
| 67 | +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| |
| 68 | +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| |
| 69 | +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| |
| 70 | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| 71 | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>; |
| 72 | +// clang-format on |
| 73 | + |
| 74 | +#include "run_grouped_gemm_multiple_d_example.inc" |
| 75 | + |
| 76 | +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } |
0 commit comments