Skip to content

Commit a38aece

Browse files
author
Thomas Ning
authored
Fix and improve the gemm quant pipeline infrastructure (#3245)
1 parent 79aae7c commit a38aece

11 files changed

Lines changed: 96 additions & 272 deletions

include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
214214

215215
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
216216

217+
// ---- Lower 4 int4 values (even positions) ----
218+
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
217219
uint32_t dict_sel = a & 0x07070707;
218-
uint32_t sign = a >> 1;
219-
asm volatile("v_and_or_b32 %0, %1, %2, %3"
220-
: "=v"(final_sel)
221-
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
222-
223-
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
224-
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
220+
// The sign flag (after bias) is bit[3] of each byte's low nibble; shifting right by 1 moves it to bit[2] of each byte.
221+
uint32_t sign = a >> 1;
222+
// Build final selector:
223+
// - bit 2 of each byte (0x04) selects negative vs positive table
224+
// - 0x03020100 selects byte lanes [0,1,2,3] in order
225+
final_sel = (sign & 0x04040404) | 0x03020100;
226+
// Look up positive and negative fp8 codes from the small register tables.
227+
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
228+
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
229+
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
225230
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
226231

232+
// ---- Upper 4 int4 values (odd positions) ----
233+
// Shift to bring the high-nibble int4s into place and repeat the process.
227234
a >>= 4;
228-
dict_sel = a & 0x07070707;
229-
sign = a >> 1;
230-
asm volatile("v_and_or_b32 %0, %1, %2, %3"
231-
: "=v"(final_sel)
232-
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
235+
dict_sel = a & 0x07070707;
236+
sign = a >> 1;
237+
final_sel = (sign & 0x04040404) | 0x03020100;
233238

234239
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
235240
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
@@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
306311

307312
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
308313

314+
// ---- Lower 4 int4 values (even positions) ----
315+
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
309316
uint32_t dict_sel = a & 0x07070707;
310-
uint32_t sign = a >> 1;
311-
asm volatile("v_and_or_b32 %0, %1, %2, %3"
312-
: "=v"(final_sel)
313-
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
314317

315-
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
316-
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
318+
// The sign flag (after bias) is bit[3] of each byte's low nibble; shifting right by 1 moves it to bit[2] of each byte.
319+
uint32_t sign = a >> 1;
320+
// Build final selector:
321+
// - bit 2 of each byte (0x04) selects negative vs positive table
322+
// - 0x03020100 selects byte lanes [0,1,2,3] in order
323+
final_sel = (sign & 0x04040404) | 0x03020100;
324+
325+
// Look up positive and negative bf8 codes from the small register tables.
326+
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
327+
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
328+
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
317329
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
318330

331+
// ---- Upper 4 int4 values (odd positions) ----
332+
// Shift to bring the high-nibble int4s into place and repeat the process.
319333
a >>= 4;
320-
dict_sel = a & 0x07070707;
321-
sign = a >> 1;
322-
asm volatile("v_and_or_b32 %0, %1, %2, %3"
323-
: "=v"(final_sel)
324-
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
334+
dict_sel = a & 0x07070707;
335+
sign = a >> 1;
336+
final_sel = (sign & 0x04040404) | 0x03020100;
325337

326338
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
327339
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3
3030
{
3131
if(BlockHasHotloop(num_loop))
3232
{
33-
return TailNumber::Full;
33+
return TailNumber::Odd;
3434
}
3535
else
3636
{
@@ -52,40 +52,36 @@ struct BaseGemmPipelineAgBgCrCompV3
5252
// Handle all the valid cases.
5353
if(has_hot_loop)
5454
{
55-
if(tail_number == TailNumber::Full)
55+
if(tail_number == ck_tile::TailNumber::Odd)
5656
{
57-
return run_func(bool_constant<true>{},
58-
integral_constant<TailNumber, TailNumber::Full>{});
57+
return run_func(
58+
ck_tile::bool_constant<true>{},
59+
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
5960
}
6061
}
6162
else
6263
{
63-
if(tail_number == TailNumber::Odd)
64+
65+
if(tail_number == ck_tile::TailNumber::Odd)
6466
{
65-
return run_func(bool_constant<false>{},
66-
integral_constant<TailNumber, TailNumber::Odd>{});
67+
return run_func(
68+
ck_tile::bool_constant<false>{},
69+
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
6770
}
68-
else if(tail_number == TailNumber::Even)
71+
else if(tail_number == ck_tile::TailNumber::Even)
6972
{
70-
return run_func(bool_constant<false>{},
71-
integral_constant<TailNumber, TailNumber::Even>{});
73+
return run_func(
74+
ck_tile::bool_constant<false>{},
75+
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
7276
}
7377
}
7478
#if defined(__HIP_DEVICE_COMPILE__)
7579
// This path should be unreachable in device code if tail_number is valid.
7680
__builtin_unreachable();
7781
#else
7882
// If execution reaches here, it's an invalid combination of arguments.
79-
if(has_hot_loop)
80-
{
81-
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
82-
"be TailNumber::Full.");
83-
}
84-
else
85-
{
86-
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
87-
"be TailNumber::Odd or TailNumber::Even.");
88-
}
83+
throw std::logic_error("Invalid TailNumber value: must be "
84+
"TailNumber::Odd or TailNumber::Even");
8985
#endif
9086
}
9187
};
@@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
588584
} while(i < (num_loop - 1));
589585
}
590586
// tail
591-
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
587+
if constexpr(TailNum == TailNumber::Odd)
592588
{
593589
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
594590
// latency

include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,8 @@ struct QuantGemmKernel
786786
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
787787
return make_naive_tensor_view<address_space_enum::global>(
788788
bq_ptr,
789-
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
790-
make_tuple(1, kargs.stride_BQ),
789+
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
790+
make_tuple(kargs.stride_BQ, 1),
791791
number<GemmPipeline::GetVectorSizeBQ()>{},
792792
number<1>{});
793793
}
@@ -1030,9 +1030,9 @@ struct QuantGemmKernel
10301030
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
10311031
return make_tile_window(
10321032
bq_pad_view,
1033-
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
1034-
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
1035-
{0, i_n / QuantGroupSize::kN});
1033+
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
1034+
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
1035+
{i_n / QuantGroupSize::kN, 0});
10361036
}
10371037
}
10381038
else

include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,68 +15,9 @@
1515

1616
namespace ck_tile {
1717

18-
template <typename Problem>
19-
struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
20-
{
21-
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
22-
{
23-
if(num_loop % BaseGemmPipelineAgBgCrCompV3<Problem>::PrefetchStages == 0)
24-
{
25-
return TailNumber::Even;
26-
}
27-
else
28-
{
29-
return TailNumber::Odd;
30-
}
31-
}
32-
template <typename RunFunction>
33-
CK_TILE_HOST_DEVICE static auto
34-
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
35-
{
36-
if(has_hot_loop)
37-
{
38-
if(tail_number == ck_tile::TailNumber::Odd)
39-
{
40-
return run_func(
41-
ck_tile::bool_constant<true>{},
42-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
43-
}
44-
else if(tail_number == ck_tile::TailNumber::Even)
45-
{
46-
return run_func(
47-
ck_tile::bool_constant<true>{},
48-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
49-
}
50-
else
51-
{
52-
throw std::runtime_error("Unsupported tail number for this operation !!!");
53-
}
54-
}
55-
else
56-
{
57-
58-
if(tail_number == ck_tile::TailNumber::Odd)
59-
{
60-
return run_func(
61-
ck_tile::bool_constant<false>{},
62-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
63-
}
64-
else if(tail_number == ck_tile::TailNumber::Even)
65-
{
66-
return run_func(
67-
ck_tile::bool_constant<false>{},
68-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
69-
}
70-
else
71-
{
72-
throw std::runtime_error("Unsupported tail number for this operation !!!");
73-
}
74-
}
75-
}
76-
};
77-
18+
// TODO: Change this pipeline to an actual memory pipeline.
7819
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
79-
struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Problem>
20+
struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
8021
{
8122
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
8223
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;

include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,74 +14,8 @@
1414

1515
namespace ck_tile {
1616

17-
// Compute optimized pipeline
18-
// GlobalPrefetchStages: 2
19-
// LocalPreFillStages: 1
20-
// LocalPreFetchStages: 1
21-
// LocalSharedMemoryBuffer: 1
22-
23-
template <typename Problem>
24-
struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
25-
{
26-
template <typename RunFunction>
27-
CK_TILE_HOST_DEVICE static auto
28-
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
29-
{
30-
if(has_hot_loop)
31-
{
32-
if(tail_number == ck_tile::TailNumber::Full)
33-
{
34-
return run_func(
35-
ck_tile::bool_constant<true>{},
36-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
37-
}
38-
else if(tail_number == ck_tile::TailNumber::Odd)
39-
{
40-
return run_func(
41-
ck_tile::bool_constant<true>{},
42-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
43-
}
44-
else if(tail_number == ck_tile::TailNumber::Even)
45-
{
46-
return run_func(
47-
ck_tile::bool_constant<true>{},
48-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
49-
}
50-
else
51-
{
52-
throw std::runtime_error("Unsupported tail number for this operation !!!");
53-
}
54-
}
55-
else
56-
{
57-
if(tail_number == ck_tile::TailNumber::Full)
58-
{
59-
return run_func(
60-
ck_tile::bool_constant<false>{},
61-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
62-
}
63-
else if(tail_number == ck_tile::TailNumber::Odd)
64-
{
65-
return run_func(
66-
ck_tile::bool_constant<false>{},
67-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
68-
}
69-
else if(tail_number == ck_tile::TailNumber::Even)
70-
{
71-
return run_func(
72-
ck_tile::bool_constant<false>{},
73-
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
74-
}
75-
else
76-
{
77-
throw std::runtime_error("Unsupported tail number for this operation !!!");
78-
}
79-
}
80-
}
81-
};
82-
8317
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
84-
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3<Problem>
18+
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
8519
{
8620
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
8721
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;

include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
7171
tile_distribution_encoding_pattern_bq<BlockGemmShape,
7272
WarpGemm,
7373
BlockSize,
74-
KPerBlockBQ,
7574
NPerBlockBQ,
75+
KPerBlockBQ,
7676
Problem::QuantGroupSize::kN>;
7777

7878
return TileEncodingPattern::make_2d_static_tile_distribution();

0 commit comments

Comments
 (0)