Skip to content

Commit bbc68d9

Browse files
committed
Merge pull request #9963 from velonica0:rvv-elementwise
PiperOrigin-RevId: 907660441
2 parents 56496fd + 0b6f61a commit bbc68d9

32 files changed

Lines changed: 2444 additions & 3 deletions

cmake/gen/rvv_microkernels.cmake

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ SET(PROD_RVV_MICROKERNEL_SRCS
4343
src/f32-spmm/gen/f32-spmm-4vx4-minmax-rvv.c
4444
src/f32-spmm/gen/f32-spmm-8vx1-minmax-rvv.c
4545
src/f32-spmm/gen/f32-spmm-8vx2-minmax-rvv.c
46+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u4v.c
4647
src/f32-vbinary/gen/f32-vadd-rvv-u8v.c
4748
src/f32-vbinary/gen/f32-vaddc-rvv-u8v.c
4849
src/f32-vbinary/gen/f32-vdiv-rvv-u8v.c
@@ -69,6 +70,8 @@ SET(PROD_RVV_MICROKERNEL_SRCS
6970
src/f32-vcopysign/gen/f32-vrcopysignc-rvv-u8v.c
7071
src/f32-vcos/gen/f32-vcos-rvv-rational-5-4-div-u8v.c
7172
src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u8v.c
73+
src/f32-velu/gen/f32-velu-rvv-rr2-p6-u4v.c
74+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u4v.c
7275
src/f32-vhswish/gen/f32-vhswish-rvv-u4v.c
7376
src/f32-vlog/gen/f32-vlog-rvv-rational-3-3-div-u8v.c
7477
src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c
@@ -198,6 +201,13 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS
198201
src/f32-spmm/gen/f32-spmm-4vx1-minmax-rvv.c
199202
src/f32-spmm/gen/f32-spmm-4vx2-minmax-rvv.c
200203
src/f32-spmm/gen/f32-spmm-8vx4-minmax-rvv.c
204+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u1v.c
205+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u2v.c
206+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u8v.c
207+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u1v.c
208+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u2v.c
209+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u4v.c
210+
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u8v.c
201211
src/f32-vbinary/gen/f32-vadd-rvv-u4v.c
202212
src/f32-vbinary/gen/f32-vaddc-rvv-u4v.c
203213
src/f32-vbinary/gen/f32-vdiv-rvv-u4v.c
@@ -231,6 +241,16 @@ SET(NON_PROD_RVV_MICROKERNEL_SRCS
231241
src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u1v.c
232242
src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u2v.c
233243
src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u4v.c
244+
src/f32-velu/gen/f32-velu-rvv-rr2-p6-u1v.c
245+
src/f32-velu/gen/f32-velu-rvv-rr2-p6-u2v.c
246+
src/f32-velu/gen/f32-velu-rvv-rr2-p6-u8v.c
247+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u1v.c
248+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u2v.c
249+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u8v.c
250+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u1v.c
251+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u2v.c
252+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u4v.c
253+
src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u8v.c
234254
src/f32-vhswish/gen/f32-vhswish-rvv-u1v.c
235255
src/f32-vhswish/gen/f32-vhswish-rvv-u2v.c
236256
src/f32-vhswish/gen/f32-vhswish-rvv-u8v.c

gen/rvv_microkernels.bzl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ PROD_RVV_MICROKERNEL_SRCS = [
3939
"src/f32-spmm/gen/f32-spmm-4vx4-minmax-rvv.c",
4040
"src/f32-spmm/gen/f32-spmm-8vx1-minmax-rvv.c",
4141
"src/f32-spmm/gen/f32-spmm-8vx2-minmax-rvv.c",
42+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u4v.c",
4243
"src/f32-vbinary/gen/f32-vadd-rvv-u8v.c",
4344
"src/f32-vbinary/gen/f32-vaddc-rvv-u8v.c",
4445
"src/f32-vbinary/gen/f32-vdiv-rvv-u8v.c",
@@ -65,6 +66,8 @@ PROD_RVV_MICROKERNEL_SRCS = [
6566
"src/f32-vcopysign/gen/f32-vrcopysignc-rvv-u8v.c",
6667
"src/f32-vcos/gen/f32-vcos-rvv-rational-5-4-div-u8v.c",
6768
"src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u8v.c",
69+
"src/f32-velu/gen/f32-velu-rvv-rr2-p6-u4v.c",
70+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u4v.c",
6871
"src/f32-vhswish/gen/f32-vhswish-rvv-u4v.c",
6972
"src/f32-vlog/gen/f32-vlog-rvv-rational-3-3-div-u8v.c",
7073
"src/f32-vlrelu/gen/f32-vlrelu-rvv-u4v.c",
@@ -195,6 +198,13 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [
195198
"src/f32-spmm/gen/f32-spmm-4vx1-minmax-rvv.c",
196199
"src/f32-spmm/gen/f32-spmm-4vx2-minmax-rvv.c",
197200
"src/f32-spmm/gen/f32-spmm-8vx4-minmax-rvv.c",
201+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u1v.c",
202+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u2v.c",
203+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u8v.c",
204+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u1v.c",
205+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u2v.c",
206+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u4v.c",
207+
"src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u8v.c",
198208
"src/f32-vbinary/gen/f32-vadd-rvv-u4v.c",
199209
"src/f32-vbinary/gen/f32-vaddc-rvv-u4v.c",
200210
"src/f32-vbinary/gen/f32-vdiv-rvv-u4v.c",
@@ -228,6 +238,16 @@ NON_PROD_RVV_MICROKERNEL_SRCS = [
228238
"src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u1v.c",
229239
"src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u2v.c",
230240
"src/f32-vexp/gen/f32-vexp-rvv-rational-3-2-div-u4v.c",
241+
"src/f32-velu/gen/f32-velu-rvv-rr2-p6-u1v.c",
242+
"src/f32-velu/gen/f32-velu-rvv-rr2-p6-u2v.c",
243+
"src/f32-velu/gen/f32-velu-rvv-rr2-p6-u8v.c",
244+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u1v.c",
245+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u2v.c",
246+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u8v.c",
247+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u1v.c",
248+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u2v.c",
249+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u4v.c",
250+
"src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u8v.c",
231251
"src/f32-vhswish/gen/f32-vhswish-rvv-u1v.c",
232252
"src/f32-vhswish/gen/f32-vhswish-rvv-u2v.c",
233253
"src/f32-vhswish/gen/f32-vhswish-rvv-u8v.c",

scripts/generate-f32-vapproxgelu.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,14 @@ tools/xngen src/f32-vapproxgelu/rational-12-10.c.in -D ARCH=hvx -D BATCH_TILES=3
2020

2121
tools/xngen src/f32-vapproxgelu/rational-12-10.c.in -D ARCH=avx512f -D BATCH_TILES=16,32,48,64 -D DIV=NR -o src/f32-vapproxgelu/gen/f32-vapproxgelu-avx512f-rational-12-10-nr.c &
2222

23+
################################## RISC-V RVV #################################
24+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=1 -D DIV=DIV -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u1v.c &
25+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=2 -D DIV=DIV -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u2v.c &
26+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=4 -D DIV=DIV -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u4v.c &
27+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=8 -D DIV=DIV -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u8v.c &
28+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=1 -D DIV=NR -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u1v.c &
29+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=2 -D DIV=NR -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u2v.c &
30+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=4 -D DIV=NR -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u4v.c &
31+
tools/xngen src/f32-vapproxgelu/rvv-rational-12-10.c.in -D LMUL=8 -D DIV=NR -o src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-nr-u8v.c &
32+
2333
wait

scripts/generate-f32-velu.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,10 @@ tools/xngen src/f32-velu/avx512f-rr1-p6.c.in -D BATCH_TILE=32 -o src/f32-velu/g
149149
tools/xngen src/f32-velu/avx512f-rr1-p6.c.in -D BATCH_TILE=48 -o src/f32-velu/gen/f32-velu-avx512f-rr1-p6-u48.c &
150150
tools/xngen src/f32-velu/avx512f-rr1-p6.c.in -D BATCH_TILE=64 -o src/f32-velu/gen/f32-velu-avx512f-rr1-p6-u64.c &
151151

152+
################################## RISC-V RVV #################################
153+
tools/xngen src/f32-velu/rvv-rr2-p6.c.in -D LMUL=1 -o src/f32-velu/gen/f32-velu-rvv-rr2-p6-u1v.c &
154+
tools/xngen src/f32-velu/rvv-rr2-p6.c.in -D LMUL=2 -o src/f32-velu/gen/f32-velu-rvv-rr2-p6-u2v.c &
155+
tools/xngen src/f32-velu/rvv-rr2-p6.c.in -D LMUL=4 -o src/f32-velu/gen/f32-velu-rvv-rr2-p6-u4v.c &
156+
tools/xngen src/f32-velu/rvv-rr2-p6.c.in -D LMUL=8 -o src/f32-velu/gen/f32-velu-rvv-rr2-p6-u8v.c &
157+
152158
wait

scripts/generate-f32-vgelu.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,14 @@ tools/xngen src/f32-vgelu/rational-12-10.c.in -D ARCH=hvx -D BATCH_TILES=32
2121
tools/xngen src/f32-vgelu/rational-12-10.c.in -D ARCH=avx512f -D BATCH_TILES=16,32,48,64 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-avx512f-rational-12-10-nr.c &
2222
tools/xngen src/f32-vgelu/rational-12-10.c.in -D ARCH=hvx -D BATCH_TILES=32,64,128 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-hvx-rational-12-10-nr.c &
2323

24+
################################## RISC-V RVV #################################
25+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=1 -D DIV=DIV -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u1v.c &
26+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=2 -D DIV=DIV -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u2v.c &
27+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=4 -D DIV=DIV -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u4v.c &
28+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=8 -D DIV=DIV -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-div-u8v.c &
29+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=1 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u1v.c &
30+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=2 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u2v.c &
31+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=4 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u4v.c &
32+
tools/xngen src/f32-vgelu/rvv-rational-12-10.c.in -D LMUL=8 -D DIV=NR -o src/f32-vgelu/gen/f32-vgelu-rvv-rational-12-10-nr-u8v.c &
33+
2434
wait

src/configs/unary-elementwise-config.c

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,19 @@ static void init_f32_approxgelu_config_impl(struct xnn_unary_elementwise_config*
10901090
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vapproxgelu_ukernel__hvx_rational_12_10_div_u128);
10911091
config->element_tile = 128;
10921092
}
1093+
#elif XNN_ARCH_RISCV
1094+
#if XNN_ENABLE_RISCV_VECTOR
1095+
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
1096+
assert(hardware_config != NULL);
1097+
if (hardware_config->arch_flags & xnn_arch_riscv_vector) {
1098+
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u4v);
1099+
config->element_tile = 4 * hardware_config->vlenb / sizeof(float);
1100+
} else
1101+
#endif
1102+
{
1103+
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vapproxgelu_ukernel__scalar_rational_12_10_div_u1);
1104+
config->element_tile = 1;
1105+
}
10931106
#else
10941107
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vapproxgelu_ukernel__scalar_rational_12_10_div_u1);
10951108
config->element_tile = 1;
@@ -1348,9 +1361,20 @@ static void init_f32_elu_config(void) {
13481361
}
13491362
#endif
13501363
#elif XNN_ARCH_RISCV
1351-
f32_elu_config.ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_velu_ukernel__scalar_rr2_lut16_p3_u4);
1352-
f32_elu_config.element_tile = 4;
1353-
f32_elu_config.init = (xnn_init_unary_uparams_fn) xnn_init_f32_elu_scalar_params;
1364+
#if XNN_ENABLE_RISCV_VECTOR
1365+
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
1366+
assert(hardware_config != NULL);
1367+
if (hardware_config->arch_flags & xnn_arch_riscv_vector) {
1368+
f32_elu_config.ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_velu_ukernel__rvv_rr2_p6_u4v);
1369+
f32_elu_config.element_tile = 4 * hardware_config->vlenb / sizeof(float);
1370+
f32_elu_config.init = (xnn_init_unary_uparams_fn) xnn_init_f32_elu_scalar_params;
1371+
} else
1372+
#endif
1373+
{
1374+
f32_elu_config.ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_velu_ukernel__scalar_rr2_lut16_p3_u4);
1375+
f32_elu_config.element_tile = 4;
1376+
f32_elu_config.init = (xnn_init_unary_uparams_fn) xnn_init_f32_elu_scalar_params;
1377+
}
13541378
#else
13551379
f32_elu_config.ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_velu_ukernel__scalar_rr2_lut16_p3_u4);
13561380
f32_elu_config.element_tile = 4;
@@ -1423,6 +1447,19 @@ static void init_f32_gelu_config_impl(struct xnn_unary_elementwise_config* confi
14231447
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vgelu_ukernel__hvx_rational_12_10_div_u128);
14241448
config->element_tile = 128;
14251449
}
1450+
#elif XNN_ARCH_RISCV
1451+
#if XNN_ENABLE_RISCV_VECTOR
1452+
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
1453+
assert(hardware_config != NULL);
1454+
if (hardware_config->arch_flags & xnn_arch_riscv_vector) {
1455+
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vgelu_ukernel__rvv_rational_12_10_div_u4v);
1456+
config->element_tile = 4 * hardware_config->vlenb / sizeof(float);
1457+
} else
1458+
#endif
1459+
{
1460+
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vgelu_ukernel__scalar_rational_12_10_div_u1);
1461+
config->element_tile = 1;
1462+
}
14261463
#else
14271464
config->ukernel = XNN_INIT_UNARY_UKERNEL(xnn_f32_vgelu_ukernel__scalar_rational_12_10_div_u1);
14281465
config->element_tile = 1;

src/f32-vapproxgelu/f32-vapproxgelu.inc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,10 @@ XNN_UKERNEL(xnn_arch_none, xnn_f32_vapproxgelu_ukernel__wasmsimd_rational_12_10_
6565
XNN_UKERNEL(xnn_arch_none, xnn_f32_vapproxgelu_ukernel__wasmsimd_rational_12_10_div_u12, 12, false, float, struct xnn_f32_default_params, NULL)
6666
XNN_UKERNEL(xnn_arch_none, xnn_f32_vapproxgelu_ukernel__wasmsimd_rational_12_10_div_u16, 16, false, float, struct xnn_f32_default_params, NULL)
6767
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
68+
69+
#if XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
70+
XNN_UKERNEL(xnn_arch_riscv_vector, xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u1v, 1, true, float, struct xnn_f32_default_params, NULL)
71+
XNN_UKERNEL(xnn_arch_riscv_vector, xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u2v, 2, true, float, struct xnn_f32_default_params, NULL)
72+
XNN_UKERNEL(xnn_arch_riscv_vector, xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u4v, 4, true, float, struct xnn_f32_default_params, NULL)
73+
XNN_UKERNEL(xnn_arch_riscv_vector, xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u8v, 8, true, float, struct xnn_f32_default_params, NULL)
74+
#endif // XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR
src/f32-vapproxgelu/gen/f32-vapproxgelu-rvv-rational-12-10-div-u1v.c

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// clang-format off
2+
// Auto-generated file. Do not edit!
3+
// Template: src/f32-vapproxgelu/rvv-rational-12-10.c.in
4+
// Generator: tools/xngen
5+
//
6+
// Copyright 2024 Google LLC
7+
//
8+
// This source code is licensed under the BSD-style license found in the
9+
// LICENSE file in the root directory of this source tree.
10+
11+
#include <assert.h>
12+
13+
#include <riscv_vector.h>
14+
15+
#include "src/xnnpack/common.h"
16+
#include "src/xnnpack/vunary.h"
17+
18+
19+
void xnn_f32_vapproxgelu_ukernel__rvv_rational_12_10_div_u1v(
20+
size_t batch,
21+
const float* input,
22+
float* output,
23+
const struct xnn_f32_default_params* unused_params)
24+
{
25+
assert(batch != 0);
26+
assert(batch % sizeof(float) == 0);
27+
assert(input != NULL);
28+
assert(output != NULL);
29+
30+
// Cap the inputs to this value as `erf(x/sqrt(2))` will always be `+/-1.0f`
31+
// beyond this point. This value is chosen as the first floating point
32+
// number as of which the interpolation returns +/-1.0f.
33+
const float vmax_x = 4.84974098e+00f;
34+
const float vmin_x = -4.84974098e+00f;
35+
36+
// The monomial coefficients of the numerator polynomial (odd).
37+
const float valpha_1 = 7.9788458347e-01f;
38+
const float valpha_3 = 6.0803253204e-02f;
39+
const float valpha_5 = 7.2898347862e-03f;
40+
const float valpha_7 = 2.6887017884e-04f;
41+
const float valpha_9 = 1.4302649106e-05f;
42+
const float valpha_11 = 4.9544411240e-08f;
43+
44+
// The monomial coefficients of the denominator polynomial (even).
45+
const float vbeta_2 = 2.4369759858e-01f;
46+
const float vbeta_4 = 2.4381054565e-02f;
47+
const float vbeta_6 = 1.3060354395e-03f;
48+
const float vbeta_8 = 7.6477612311e-05f;
49+
const float vbeta_10 = 1.3433452750e-06f;
50+
51+
batch >>= XNN_LOG2_SIZEOF_FLOAT;
52+
do {
53+
const size_t n = __riscv_vsetvl_e32m1(batch);
54+
55+
vfloat32m1_t vx_orig = __riscv_vle32_v_f32m1(input, n);
56+
input += n;
57+
58+
// Clamp the inputs to the interpolation range.
59+
vfloat32m1_t vx = __riscv_vfmin_vf_f32m1(vx_orig, vmax_x, n);
60+
vx = __riscv_vfmax_vf_f32m1(vx, vmin_x, n);
61+
62+
// Since the polynomials are odd/even, we need x^2.
63+
vfloat32m1_t vx2 = __riscv_vfmul_vv_f32m1(vx, vx, n);
64+
65+
// Evaluate the numerator polynomial p.
66+
vfloat32m1_t vp = __riscv_vfmv_v_f_f32m1(valpha_9, n);
67+
vp = __riscv_vfmacc_vf_f32m1(vp, valpha_11, vx2, n);
68+
vp = __riscv_vfmadd_vv_f32m1(vp, vx2, __riscv_vfmv_v_f_f32m1(valpha_7, n), n);
69+
vp = __riscv_vfmadd_vv_f32m1(vp, vx2, __riscv_vfmv_v_f_f32m1(valpha_5, n), n);
70+
vp = __riscv_vfmadd_vv_f32m1(vp, vx2, __riscv_vfmv_v_f_f32m1(valpha_3, n), n);
71+
vp = __riscv_vfmadd_vv_f32m1(vp, vx2, __riscv_vfmv_v_f_f32m1(valpha_1, n), n);
72+
vp = __riscv_vfmul_vv_f32m1(vp, vx, n);
73+
74+
// Evaluate the denominator polynomial q.
75+
vfloat32m1_t vq = __riscv_vfmv_v_f_f32m1(vbeta_8, n);
76+
vq = __riscv_vfmacc_vf_f32m1(vq, vbeta_10, vx2, n);
77+
vq = __riscv_vfmadd_vv_f32m1(vq, vx2, __riscv_vfmv_v_f_f32m1(vbeta_6, n), n);
78+
vq = __riscv_vfmadd_vv_f32m1(vq, vx2, __riscv_vfmv_v_f_f32m1(vbeta_4, n), n);
79+
vq = __riscv_vfmadd_vv_f32m1(vq, vx2, __riscv_vfmv_v_f_f32m1(vbeta_2, n), n);
80+
vq = __riscv_vfmadd_vv_f32m1(vq, vx2, __riscv_vfmv_v_f_f32m1(1.0f, n), n);
81+
82+
// Divide the numerator by the denominator.
83+
vfloat32m1_t verf = __riscv_vfdiv_vv_f32m1(vp, vq, n);
84+
85+
// Add one to the rational interpolant, and multiply by 0.5 times the
86+
// original input.
87+
vfloat32m1_t vy = __riscv_vfadd_vf_f32m1(verf, 1.0f, n);
88+
vy = __riscv_vfmul_vf_f32m1(vy, 0.5f, n);
89+
vy = __riscv_vfmul_vv_f32m1(vy, vx_orig, n);
90+
91+
__riscv_vse32_v_f32m1(output, vy, n);
92+
output += n;
93+
94+
batch -= n;
95+
} while (batch != 0);
96+
}

0 commit comments

Comments
 (0)