|
1 | | -#!/bin/sh |
| 1 | +#!/bin/bash |
| 2 | + |
2 | 3 | EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" |
3 | 4 |
|
4 | | -for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\ |
5 | | - "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do |
6 | | -for pr_i in "fp16" "bf16" ; do |
7 | | -for fadd in "0" "1"; do |
8 | | -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm |
9 | | -for s in "0" "1"; do |
10 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 |
11 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 |
12 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 |
13 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 |
14 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 |
15 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 |
16 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 |
17 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 |
18 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 |
19 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 |
20 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 |
21 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 |
22 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 |
23 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 |
24 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 |
25 | | -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 |
26 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 |
27 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 |
28 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 |
29 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 |
30 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 |
31 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 |
32 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 |
33 | | -done |
34 | | -done |
35 | | -done |
36 | | -done |
| 5 | +total=0 |
| 6 | +valid=0 |
37 | 7 |
|
38 | | -# The following cases uses two pass pipeline which doesn't support quant epilogue. |
39 | | -for fquant in "" |
40 | | -for pr_i in "fp16" "bf16" ; do |
41 | | -for fadd in "0" "1"; do |
42 | | -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm |
43 | | -for s in "0" "1"; do |
44 | | -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 |
45 | | -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 |
46 | | -done |
47 | | -done |
| 8 | +run_case() { |
| 9 | + cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7" |
| 10 | + echo "[CMD] $cmd" |
| 11 | + output=$($cmd 2>&1) |
| 12 | + echo "$output" |
| 13 | + if echo "$output" | grep -q "valid:y"; then |
| 14 | + valid=$((valid + 1)) |
| 15 | + fi |
| 16 | + total=$((total + 1)) |
| 17 | +} |
| 18 | + |
| 19 | +fquant_list=( |
| 20 | + "" |
| 21 | + "-fquant=1 -prec_o=int8" |
| 22 | + "-fquant=2 -prec_o=int8" |
| 23 | + "-fquant=1 -prec_o=fp8" |
| 24 | + "-fquant=2 -prec_o=fp8" |
| 25 | + "-fquant=1 -prec_o=int8 -save_unquant=1" |
| 26 | + "-fquant=2 -prec_o=int8 -save_unquant=1" |
| 27 | + "-fquant=1 -prec_o=fp8 -save_unquant=1" |
| 28 | + "-fquant=2 -prec_o=fp8 -save_unquant=1" |
| 29 | +) |
| 30 | + |
| 31 | +m_n_list=( |
| 32 | + "99 13" "17 16" "1 100" "4 128" "80 127" |
| 33 | + "7 599" "19 512" "11 510" "91 636" |
| 34 | + "31 1024" "8 1501" "3 1826" "5 2040" |
| 35 | + "7 2734" "1 3182" "9 4096" "3 8192" |
| 36 | +) |
| 37 | + |
| 38 | +### Add special stride test ### |
| 39 | +m_n_stride_list=( |
| 40 | + "22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256" |
| 41 | + "33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000" |
| 42 | + "171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818" |
| 43 | + "12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800" |
| 44 | + "100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812" |
| 45 | + "64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004" |
| 46 | +) |
| 47 | + |
| 48 | +for fquant in "${fquant_list[@]}"; do |
| 49 | + for pr_i in "fp16" "bf16"; do |
| 50 | + for fadd in "0" "1"; do |
| 51 | + for s in "0" "1"; do |
| 52 | + for pair in "${m_n_list[@]}"; do |
| 53 | + m=$(echo $pair | cut -d ' ' -f1) |
| 54 | + n=$(echo $pair | cut -d ' ' -f2) |
| 55 | + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "" |
| 56 | + done |
| 57 | + |
| 58 | + ### Running tests with stride ### |
| 59 | + for triple in "${m_n_stride_list[@]}"; do |
| 60 | + m=$(echo $triple | cut -d ' ' -f1) |
| 61 | + n=$(echo $triple | cut -d ' ' -f2) |
| 62 | + stride_args=$(echo $triple | cut -d ' ' -f3-) |
| 63 | + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args" |
| 64 | + done |
| 65 | + done |
| 66 | + done |
| 67 | + done |
48 | 68 | done |
| 69 | + |
| 70 | +# Special two-pass only |
| 71 | +for pr_i in "fp16" "bf16"; do |
| 72 | + for fadd in "0" "1"; do |
| 73 | + for s in "0" "1"; do |
| 74 | + run_case "$pr_i" "$fadd" "$s" "" "1" "10547" "" |
| 75 | + done |
| 76 | + done |
49 | 77 | done |
| 78 | + |
| 79 | +# Summary |
| 80 | +echo "==============================" |
| 81 | +echo "Total cases: $total" |
| 82 | +echo "Valid cases: $valid" |
| 83 | +accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}") |
| 84 | +echo "Accuracy: $accuracy%" |
| 85 | +echo "==============================" |
0 commit comments