Skip to content

Commit e1b0bdf

Browse files
ClementLinCFMHYangAMDillsilinThomasNing
authored
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
1 parent fc2a121 commit e1b0bdf

7 files changed

Lines changed: 217 additions & 67 deletions

File tree

example/ck_tile/02_layernorm2d/generate.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,39 @@ class layernorm_fwd_codegen:
7575
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
7676
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
7777
78+
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
79+
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
80+
static constexpr ck_tile::index_t total_warps =
81+
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
82+
83+
// num of warps along m
84+
static constexpr ck_tile::index_t BlockWarps_M = []() {
85+
if constexpr(is_warp_per_row)
86+
{
87+
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
88+
return total_warps;
89+
}
90+
else
91+
{
92+
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
93+
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
94+
}
95+
}();
96+
97+
// num of warps along n
98+
static constexpr ck_tile::index_t BlockWarps_N = []() {
99+
if constexpr(is_warp_per_row)
100+
{
101+
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
102+
return 1;
103+
}
104+
else
105+
{
106+
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
107+
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
108+
}
109+
}();
110+
78111
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
79112
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
80113

example/ck_tile/10_rmsnorm2d/generate.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,39 @@ class rmsnorm_fwd_codegen:
7575
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
7676
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
7777
78+
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
79+
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
80+
static constexpr ck_tile::index_t total_warps =
81+
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
82+
83+
// num of warps along m
84+
static constexpr ck_tile::index_t BlockWarps_M = []() {
85+
if constexpr(is_warp_per_row)
86+
{
87+
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
88+
return total_warps;
89+
}
90+
else
91+
{
92+
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
93+
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
94+
}
95+
}();
96+
97+
// num of warps along n
98+
static constexpr ck_tile::index_t BlockWarps_N = []() {
99+
if constexpr(is_warp_per_row)
100+
{
101+
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
102+
return 1;
103+
}
104+
else
105+
{
106+
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
107+
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
108+
}
109+
}();
110+
78111
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
79112
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
80113
@@ -605,15 +638,15 @@ def get_blobs(self):
605638
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
606639
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
607640
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
608-
}
641+
}
609642
}
610-
643+
611644
total_blob = list()
612645

613646
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
614647
current_trait_dict = h_trait_dicts[model_sensitive_flag]
615648
for hs_key in current_trait_dict:
616-
hs = current_trait_dict[hs_key]
649+
hs = current_trait_dict[hs_key]
617650
current_n = hs_key
618651
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
619652
prec_i, prec_o = dtype.split(',')

example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,16 @@ template <typename InDataType,
7070
bool SaveUnquant>
7171
bool run(const ck_tile::ArgParser& arg_parser)
7272
{
73-
ck_tile::index_t m = arg_parser.get_int("m");
74-
ck_tile::index_t n = arg_parser.get_int("n");
75-
float epsilon = arg_parser.get_float("e");
76-
int kname = arg_parser.get_int("kname");
77-
int do_validation = arg_parser.get_int("v");
78-
int fused_add = arg_parser.get_int("fadd");
79-
int fused_quant = arg_parser.get_int("fquant");
80-
int warmup = arg_parser.get_int("warmup");
81-
int repeat = arg_parser.get_int("repeat");
82-
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
73+
ck_tile::index_t m = arg_parser.get_int("m");
74+
ck_tile::index_t n = arg_parser.get_int("n");
75+
float epsilon = arg_parser.get_float("e");
76+
int kname = arg_parser.get_int("kname");
77+
int do_validation = arg_parser.get_int("v");
78+
int fused_add = arg_parser.get_int("fadd");
79+
int fused_quant = arg_parser.get_int("fquant");
80+
int warmup = arg_parser.get_int("warmup");
81+
int repeat = arg_parser.get_int("repeat");
82+
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
8383

8484
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
8585
if(x_stride < 0)
@@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
196196
return base_str;
197197
}();
198198

199+
if(n > 8192)
200+
{
201+
use_model_sensitive_rmsnorm = 0;
202+
}
203+
199204
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
200205
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
201206
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
@@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
297302
const int N = acc_.mDesc.get_lengths()[1];
298303
for(int n_ = 0; n_ < N; ++n_)
299304
{
300-
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
305+
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
301306
}
302307

303308
dquant_functor(m_, o_, acc_);
@@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
316321
invRms_host_ref,
317322
unquant_y_host_ref,
318323
epsilon,
319-
default_and_dquant_functor);
324+
default_and_dquant_functor,
325+
use_model_sensitive_rmsnorm);
320326
}
321327
else
322328
{
@@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
331337
invRms_host_ref,
332338
unquant_y_host_ref,
333339
epsilon,
334-
dquant_functor);
340+
dquant_functor,
341+
use_model_sensitive_rmsnorm);
335342
}
336343
}
337344
else
@@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
343350
YDataType,
344351
InvRmsDataType,
345352
ck_tile::null_type>(
346-
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
353+
x_host,
354+
gamma_host,
355+
y_host_ref,
356+
invRms_host_ref,
357+
unquant_y_null,
358+
epsilon,
359+
ck_tile::reference_rmsnorm2d_default_epilogue{},
360+
use_model_sensitive_rmsnorm);
347361
}
348362

349363
y_buf.FromDevice(y_host_dev.data());
@@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
354368
y_residual_buf.FromDevice(y_residual_host_dev.data());
355369
}
356370

371+
if constexpr(SaveUnquant)
372+
{
373+
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
374+
}
375+
357376
auto [rtol, atol] = get_elimit<YDataType>();
358377
if(x_stride == n)
359378
{
Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,85 @@
1-
#!/bin/sh
1+
#!/bin/bash
2+
23
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
34

4-
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
5-
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
6-
for pr_i in "fp16" "bf16" ; do
7-
for fadd in "0" "1"; do
8-
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
9-
for s in "0" "1"; do
10-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
11-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
12-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
13-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
14-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
15-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
16-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
17-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
18-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
19-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
20-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
21-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
22-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
23-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
24-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
25-
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
26-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
27-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
28-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
29-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
30-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
31-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
32-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
33-
done
34-
done
35-
done
36-
done
5+
total=0
6+
valid=0
377

38-
# The following cases uses two pass pipeline which doesn't support quant epilogue.
39-
for fquant in ""
40-
for pr_i in "fp16" "bf16" ; do
41-
for fadd in "0" "1"; do
42-
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
43-
for s in "0" "1"; do
44-
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
45-
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
46-
done
47-
done
8+
run_case() {
9+
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
10+
echo "[CMD] $cmd"
11+
output=$($cmd 2>&1)
12+
echo "$output"
13+
if echo "$output" | grep -q "valid:y"; then
14+
valid=$((valid + 1))
15+
fi
16+
total=$((total + 1))
17+
}
18+
19+
fquant_list=(
20+
""
21+
"-fquant=1 -prec_o=int8"
22+
"-fquant=2 -prec_o=int8"
23+
"-fquant=1 -prec_o=fp8"
24+
"-fquant=2 -prec_o=fp8"
25+
"-fquant=1 -prec_o=int8 -save_unquant=1"
26+
"-fquant=2 -prec_o=int8 -save_unquant=1"
27+
"-fquant=1 -prec_o=fp8 -save_unquant=1"
28+
"-fquant=2 -prec_o=fp8 -save_unquant=1"
29+
)
30+
31+
m_n_list=(
32+
"99 13" "17 16" "1 100" "4 128" "80 127"
33+
"7 599" "19 512" "11 510" "91 636"
34+
"31 1024" "8 1501" "3 1826" "5 2040"
35+
"7 2734" "1 3182" "9 4096" "3 8192"
36+
)
37+
38+
### Add special stride test ###
39+
m_n_stride_list=(
40+
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
41+
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
42+
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
43+
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
44+
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
45+
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
46+
)
47+
48+
for fquant in "${fquant_list[@]}"; do
49+
for pr_i in "fp16" "bf16"; do
50+
for fadd in "0" "1"; do
51+
for s in "0" "1"; do
52+
for pair in "${m_n_list[@]}"; do
53+
m=$(echo $pair | cut -d ' ' -f1)
54+
n=$(echo $pair | cut -d ' ' -f2)
55+
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
56+
done
57+
58+
### Running tests with stride ###
59+
for triple in "${m_n_stride_list[@]}"; do
60+
m=$(echo $triple | cut -d ' ' -f1)
61+
n=$(echo $triple | cut -d ' ' -f2)
62+
stride_args=$(echo $triple | cut -d ' ' -f3-)
63+
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
64+
done
65+
done
66+
done
67+
done
4868
done
69+
70+
# Special two-pass only
71+
for pr_i in "fp16" "bf16"; do
72+
for fadd in "0" "1"; do
73+
for s in "0" "1"; do
74+
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
75+
done
76+
done
4977
done
78+
79+
# Summary
80+
echo "=============================="
81+
echo "Total cases: $total"
82+
echo "Valid cases: $valid"
83+
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
84+
echo "Accuracy: $accuracy%"
85+
echo "=============================="

include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "ck_tile/core.hpp"
77
#include "ck_tile/host/host_tensor.hpp"
8+
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
89

910
namespace ck_tile {
1011

@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
4344
HostTensor<InvRmsDataType>& invRms_m,
4445
HostTensor<UnquantYDataType>& unquant_y_m_n,
4546
ComputeDataType epsilon,
46-
Epilogue epilogue_functor = {})
47+
Epilogue epilogue_functor = {},
48+
const int use_model_sensitive_rmsnorm =
49+
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
4750
{
4851
auto rmsnorm2d_fwd_func = [&](auto m) {
4952
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
6871
{
6972
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
7073
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
71-
acc(m, n) = x * divisor * gamma;
74+
if(use_model_sensitive_rmsnorm ==
75+
static_cast<int>(
76+
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
77+
{
78+
acc(m, n) = x * divisor * gamma;
79+
}
80+
else if(use_model_sensitive_rmsnorm ==
81+
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
82+
{
83+
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
84+
{
85+
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
86+
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
87+
type_convert<ComputeDataType>(tmp0) * gamma);
88+
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
89+
acc(m, n) = rmsn_;
90+
}
91+
else
92+
{
93+
const auto tmp = type_convert<XDataType>(x * divisor);
94+
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
95+
acc(m, n) = rmsn_;
96+
}
97+
}
7298
}
7399

74100
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
84110
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
85111
std::thread::hardware_concurrency());
86112
}
113+
87114
} // namespace ck_tile

include/ck_tile/ops/reduce/block/block_reduce2d.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
400400
block_sync_lds();
401401

402402
// We let each warp holds a duplication to do reduction.
403+
const index_t local_warp_id = warp_id / num_reduce_warps;
404+
const index_t local_smem_os = local_warp_id * num_reduce_warps;
403405
static_for<0, thread_buf_size, 1>{}([&](auto i) {
404406
DataType v = 0;
405407
if(lane_id < num_reduce_warps)
406408
{
407-
v = smem_ptr[lane_id + i * num_warps];
409+
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
408410
}
409411

410412
// cross-lane reduce for replication

0 commit comments

Comments
 (0)