Skip to content

Commit 46e68f0

Browse files
committed
Merge remote-tracking branch 'origin/develop' into tenpercent/persistent_async_scheduler_for_args
2 parents af1265a + 00c4678 commit 46e68f0

173 files changed

Lines changed: 8016 additions & 2580 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

Jenkinsfile

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){
574574
def setup_cmd
575575
def build_cmd
576576
def execute_cmd = conf.get("execute_cmd", "")
577+
//check the node gpu architecture
578+
def arch_name = check_arch_name()
577579
if(!setup_args.contains("NO_CK_BUILD")){
578580
if (params.NINJA_BUILD_TRACE) {
579581
echo "running ninja build trace"
@@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){
646648

647649
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
648650
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
649-
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json"
650-
archiveArtifacts "ck_build_trace_${check_arch_name()}.json"
651-
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json"
651+
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
652+
archiveArtifacts "ck_build_trace_${arch_name}.json"
653+
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
652654
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
653655
if (params.NINJA_FTIME_TRACE) {
654656
echo "running ClangBuildAnalyzer"
655657
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
656-
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log"
657-
archiveArtifacts "clang_build_analysis_${check_arch_name()}.log"
658+
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
659+
archiveArtifacts "clang_build_analysis_${arch_name}.log"
658660
}
659661

660662

@@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){
672674
if(params.BUILD_PACKAGES){
673675
echo "Build ckProfiler packages"
674676
sh 'ninja -j64 package'
675-
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
676-
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
677+
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
678+
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
677679
}
678680
}
679681
if(params.BUILD_INSTANCES_ONLY){
@@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){
699701
if(params.BUILD_PACKAGES){
700702
echo "Build ckProfiler packages"
701703
sh 'ninja -j64 package'
702-
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
703-
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
704+
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
705+
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
704706
}
705707
}
706708
}
707709
}
708710
}
709711

710-
//check the node gpu architecture
711-
def arch_name = check_arch_name()
712712
if (params.RUN_CK_TILE_FMHA_TESTS){
713713
try{
714714
archiveArtifacts "perf_fmha_*.log"

example/15_grouped_gemm/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
4444
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
4545
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
4646

47+
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
48+
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
49+
4750
list(APPEND gpu_list_tf32 gfx942 gfx950)
4851
set(target 0)
4952
foreach(gpu IN LISTS GPU_TARGETS)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
//
// Example: grouped GEMM with two auxiliary D tensors, fp16 in/out, using the
// WMMA CShuffle tile-loop v3 device operation. This file only configures the
// type aliases and the device-instance tuning parameters; the actual driver
// comes from run_grouped_gemm_multiple_d_example.inc included near the bottom.
3+
4+
#include <iostream>
5+
#include <numeric>
6+
#include <initializer_list>
7+
#include <cstdlib>
8+
9+
#include "ck/ck.hpp"
10+
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
11+
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
12+
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
13+
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
14+
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
15+
16+
#include <ck/utility/data_type.hpp>
17+
#include <ck/utility/tuple.hpp>
18+
19+
#include "ck/library/utility/check_err.hpp"
20+
#include "ck/library/utility/device_memory.hpp"
21+
#include "ck/library/utility/host_tensor.hpp"
22+
#include "ck/library/utility/host_tensor_generator.hpp"
23+
#include "ck/library/utility/literals.hpp"
24+
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
25+
26+
// Names pulled in for the shared example driver (.inc) below.
using ::ck::DeviceMem;
27+
using ::ck::hip_check_error;
28+
using ::ck::HostTensorDescriptor;
29+
using ::ck::Tensor;
30+
31+
// Shorthand for compile-time integer sequences used in the tuning parameters.
template <ck::index_t... Is>
32+
using S = ck::Sequence<Is...>;
33+
34+
using F16 = ck::half_t;
35+
using F32 = float;
36+
37+
using Row = ck::tensor_layout::gemm::RowMajor;
38+
using Col = ck::tensor_layout::gemm::ColumnMajor;
39+
40+
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
41+
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
42+
43+
// Problem data types: fp16 operands and outputs, fp32 accumulation and
// C-shuffle staging.
using ADataType = F16;
44+
using BDataType = F16;
45+
using AccDataType = F32;
46+
using CShuffleDataType = F32;
47+
using DDataType = F16;
48+
using DsDataType = ck::Tuple<DDataType, DDataType>;
49+
using EDataType = F16;
50+
51+
// Memory layouts: A row-major, B column-major; both D tensors and the output
// E are row-major.
using ALayout = Row;
52+
using BLayout = Col;
53+
using DLayout = Row;
54+
using DsLayout = ck::Tuple<DLayout, DLayout>;
55+
using ELayout = Row;
56+
57+
// Elementwise ops: A and B pass through unchanged; the CDE epilogue is AddAdd
// (name suggests E = C + D0 + D1, consistent with NumDs = 2 below — confirm
// against the AddAdd definition in element_wise_operation.hpp).
using AElementOp = PassThrough;
58+
using BElementOp = PassThrough;
59+
using CDEElementOp = AddAdd;
60+
61+
// MNKPadding: pads all three GEMM dimensions, so problem sizes need not be
// multiples of the tile sizes chosen below.
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
62+
static constexpr int NumDs = 2;
63+
64+
// Device-instance tuning configuration; the commented table maps each template
// argument to its column.
using DeviceGemmInstance =
65+
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
66+
// clang-format off
67+
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
68+
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
69+
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
70+
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
71+
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
72+
// clang-format on
73+
74+
// Presumably defines run_grouped_gemm_example(argc, argv) in terms of the
// aliases configured above — see the shared .inc driver.
#include "run_grouped_gemm_multiple_d_example.inc"
75+
76+
// Inverts the driver's result into a process exit code: 0 when the driver
// returns a truthy value (assumed to mean success — confirm in the .inc file).
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

0 commit comments

Comments
 (0)