Skip to content

Commit 634634f

Browse files
aledudek and aosewski authored
[CK_TILE] Blockwise GEMM pipeline v6 - port of v5 from old CK (#2955)
* First checkpoint
* Second checkpoint - hot loop scheduler
* Third checkpoint - init main operator
* Fourth checkpoint - main loop ready
* Fifth checkpoint - main loop fix
* Sixth checkpoint - ReadWritecompFunc
* Seventh checkpoint - Tail finished
* [CK_TILE] Blockwise gemm pipeline v5 complete
* Working
* Working fixes 2
* Rename v5 to v77 temporarily
* Data type adjustment
* Data type adjustment 2
* [CK_TILE] Blockwise Gemm pipeline v5 add tests
* [CK_TILE] Fix calculation error
* TEMP: check pipeline
* Fix name to V6
* naming and documentation changes
* WIP dump
* Try fixing v1
* Failing tests v5
* Debugging
* Changes v2
* F16 tests working great
* Working BlockwiseGemmPipelineV5 as V6
* Cleanup and format
* Merging changes part1
* [CK_TILE] Blockwise Gemm Pipeline Comp V5/V6
* Remove commented code
* Fix gfx950 build issues
* Fix file formatting
* Review changes, more concat info, add bf16 bf8 tests
* Fix formatting
* Add bf16 and bf8 tests

---------

Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
1 parent 3021604 commit 634634f

8 files changed

Lines changed: 924 additions & 12 deletions

File tree

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
#define CK_TILE_PIPELINE_MEMORY 2
1717
#define CK_TILE_PIPELINE_COMPUTE_V4 3
1818
#define CK_TILE_PIPELINE_COMPUTE_V5 4
19-
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5
20-
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
19+
#define CK_TILE_PIPELINE_COMPUTE_V6 5
20+
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6
21+
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7
2122

2223
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
2324
constexpr ck_tile::index_t get_k_warp_tile()
@@ -251,9 +252,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase
251252
static constexpr ck_tile::index_t N_Warp_Tile = 32;
252253
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
253254

254-
static constexpr bool DoubleSmemBuffer = false;
255-
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
256-
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
255+
static constexpr bool DoubleSmemBuffer = false;
256+
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
257+
static constexpr ck_tile::index_t NumWaveGroups = 2;
258+
};
259+
260+
// Example GEMM configuration whose Pipeline member selects
// CK_TILE_PIPELINE_COMPUTE_V6. It derives from GemmConfigBase and overrides the
// tile, warp-layout and warp-tile constants below.
// NOTE(review): unlike GemmConfigComputeV5 above, K_Warp_Tile is a fixed 16
// instead of get_k_warp_tile<PrecType, M_Warp_Tile>(), so PrecType is not
// referenced by any member visible here - confirm this is intended for every
// precision this config is instantiated with.
template <typename PrecType>
261+
struct GemmConfigComputeV6 : public GemmConfigBase
262+
{
263+
// Tile extents along M, N and K (presumably per workgroup - confirm).
static constexpr ck_tile::index_t M_Tile = 256;
264+
static constexpr ck_tile::index_t N_Tile = 256;
265+
static constexpr ck_tile::index_t K_Tile = 32;
266+
267+
// Warp counts along M, N and K.
static constexpr ck_tile::index_t M_Warp = 2;
268+
static constexpr ck_tile::index_t N_Warp = 2;
269+
static constexpr ck_tile::index_t K_Warp = 1;
270+
271+
// Per-warp tile extents; K is hard-coded (see the note at the top).
static constexpr ck_tile::index_t M_Warp_Tile = 32;
272+
static constexpr ck_tile::index_t N_Warp_Tile = 32;
273+
static constexpr ck_tile::index_t K_Warp_Tile = 16;
274+
275+
static constexpr bool DoubleSmemBuffer = false;
276+
// Pipeline id; matched by the PipelineTypeTraits specialization for this id.
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6;
277+
// One wave group here, whereas GemmConfigComputeV5 uses 2.
static constexpr ck_tile::index_t NumWaveGroups = 1;
257278
};
258279

259280
template <typename PrecType>
@@ -484,6 +505,15 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
484505
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
485506
};
486507

508+
// Specialization of PipelineTypeTraits for the CK_TILE_PIPELINE_COMPUTE_V6 id.
// It exposes two alias templates, each parameterized on PipelineProblem:
//   GemmPipeline          -> ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>
//   UniversalGemmPipeline -> ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>
// This mirrors the shape of the COMPUTE_V5 specialization directly above.
template <>
509+
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
510+
{
511+
template <typename PipelineProblem>
512+
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
513+
template <typename PipelineProblem>
514+
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
515+
};
516+
487517
template <>
488518
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
489519
{

include/ck_tile/ops/gemm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
4545
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
4646
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
47+
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp"
48+
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
4749
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
4850
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
4951
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"

0 commit comments

Comments
 (0)