Skip to content

Commit cc68da8

Browse files
dsharletg authored and xnnpack-bot committed
Add sigmoid_fp64 kernels
And rewrite the sigmoid_fp32 kernel using the same technique. It turns out that this kernel is faster, which is a little surprising. It does have less "overhead" (special cases, piecewise branches, etc.) in exchange for more polynomial arithmetic. Change in performance for `sigmoid_fp32`: ``` name time/op time/op vs base bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1.759µ ± 3% 1.771µ ± 21% ~ (p=0.818 n=6) bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:4/n:1024/real_time 1.740µ ± 3% 1.774µ ± 3% ~ (p=0.065 n=6) bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:16/n:256/real_time 1.773µ ± 2% 1.783µ ± 1% ~ (p=0.699 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time 4.909µ ± 1% 3.670µ ± 2% -25.24% (p=0.002 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:4/n:1024/real_time 4.830µ ± 2% 3.706µ ± 4% -23.28% (p=0.002 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:16/n:256/real_time 4.912µ ± 1% 3.740µ ± 2% -23.87% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time 6.632µ ± 3% 5.437µ ± 2% -18.02% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:4/n:1024/real_time 6.637µ ± 3% 5.524µ ± 3% -16.77% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:16/n:256/real_time 6.692µ ± 4% 5.493µ ± 1% -17.92% (p=0.002 n=6) geomean 3.851µ 3.305µ -14.19% ``` `sigmoid_fp64` compared to other kernels: ``` ---------------------------------------------------------------------------------------------------------------------------- Benchmark Time CPU Iterations UserCounters... 
---------------------------------------------------------------------------------------------------------------------------- bench_reference/sigmoid_float/m:1/n:4096/real_time 38579 ns 38571 ns 7348 Bytes=849.382M/s Op=106.173M/s bench_reference/sigmoid_float/m:4/n:1024/real_time 38345 ns 38338 ns 7440 Bytes=854.556M/s Op=106.819M/s bench_reference/sigmoid_float/m:16/n:256/real_time 39192 ns 39187 ns 7190 Bytes=836.095M/s Op=104.512M/s bench_reference/sigmoid_double/m:1/n:4096/real_time 91326 ns 91313 ns 3045 Bytes=717.606M/s Op=44.8504M/s bench_reference/sigmoid_double/m:4/n:1024/real_time 91307 ns 91290 ns 3043 Bytes=717.757M/s Op=44.8598M/s bench_reference/sigmoid_double/m:16/n:256/real_time 93505 ns 93486 ns 3018 Bytes=700.885M/s Op=43.8053M/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1786 ns 1786 ns 155658 Bytes=18.3422G/s Op=2.29277G/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:4/n:1024/real_time 1802 ns 1802 ns 157599 Bytes=18.1847G/s Op=2.27309G/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:16/n:256/real_time 1791 ns 1791 ns 156134 Bytes=18.2963G/s Op=2.28704G/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time 4475 ns 4475 ns 60425 Bytes=14.6433G/s Op=915.207M/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:4/n:1024/real_time 4822 ns 4821 ns 59593 Bytes=13.5913G/s Op=849.459M/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:16/n:256/real_time 4842 ns 4840 ns 56596 Bytes=13.5363G/s Op=846.016M/s bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time 3789 ns 3788 ns 69486 Bytes=8.64752G/s Op=1.08094G/s bench/sigmoid_fp32_1x16_x86_avx2/m:4/n:1024/real_time 3892 ns 3892 ns 74142 Bytes=8.41825G/s Op=1.05228G/s bench/sigmoid_fp32_1x16_x86_avx2/m:16/n:256/real_time 3757 ns 3756 ns 72827 Bytes=8.72073G/s Op=1.09009G/s bench/sigmoid_fp64_1x8_x86_avx2/m:1/n:4096/real_time 10451 ns 10450 ns 26516 Bytes=6.27103G/s Op=391.939M/s bench/sigmoid_fp64_1x8_x86_avx2/m:4/n:1024/real_time 11010 ns 11007 ns 24451 Bytes=5.95261G/s 
Op=372.038M/s bench/sigmoid_fp64_1x8_x86_avx2/m:16/n:256/real_time 10475 ns 10472 ns 26374 Bytes=6.2567G/s Op=391.044M/s bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time 5649 ns 5648 ns 49675 Bytes=5.80048G/s Op=725.06M/s bench/sigmoid_fp32_1x32_x86_sse2/m:4/n:1024/real_time 5646 ns 5645 ns 50916 Bytes=5.80353G/s Op=725.441M/s bench/sigmoid_fp32_1x32_x86_sse2/m:16/n:256/real_time 5571 ns 5571 ns 48792 Bytes=5.88151G/s Op=735.188M/s bench/sigmoid_fp64_1x8_x86_sse2/m:1/n:4096/real_time 15957 ns 15952 ns 17116 Bytes=4.10712G/s Op=256.695M/s bench/sigmoid_fp64_1x8_x86_sse2/m:4/n:1024/real_time 15657 ns 15654 ns 17451 Bytes=4.18581G/s Op=261.613M/s bench/sigmoid_fp64_1x8_x86_sse2/m:16/n:256/real_time 15748 ns 15744 ns 17804 Bytes=4.16163G/s Op=260.102M/s ``` PiperOrigin-RevId: 918081537
1 parent 4b53a43 commit cc68da8

3 files changed

Lines changed: 88 additions & 49 deletions

File tree

ynnpack/kernels/unary/generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def main(argv: Sequence[str]) -> None:
4848
(square_root_fp32, (8, 1)),
4949
(square_root_fp64, (4, 1)),
5050
(sigmoid_fp32, (32, 1)),
51+
(sigmoid_fp64, (8, 1)),
5152
(tanh_fp32, (16, 1)),
5253
(tanh_fp64, (8, 1)),
5354
],
@@ -99,6 +100,7 @@ def main(argv: Sequence[str]) -> None:
99100
(log_fp64, (16, 1)),
100101
(round_to_bf16_fp32, (16, 1)),
101102
(sigmoid_fp32, (16, 1)),
103+
(sigmoid_fp64, (8, 1)),
102104
],
103105
"x86_fma3": [
104106
(cosine_fp32, (16, 1)),
@@ -146,6 +148,7 @@ def main(argv: Sequence[str]) -> None:
146148
(square_root_fp32, (32, 1)),
147149
(square_root_fp64, (16, 1)),
148150
(sigmoid_fp32, (32, 1)),
151+
(sigmoid_fp64, (16, 1)),
149152
(tanh_fp32, (32, 1)),
150153
(tanh_fp64, (16, 1)),
151154
(convert_int2_to_int8, (64, 1)),
@@ -188,6 +191,7 @@ def main(argv: Sequence[str]) -> None:
188191
(poly3_fp64, (16, 1)),
189192
(reciprocal_square_root_fp64, (4, 1)),
190193
(round_fp64, (4, 1)),
194+
(sigmoid_fp64, (4, 1)),
191195
(square_fp64, (4, 1)),
192196
(square_root_fp64, (4, 1)),
193197
(tanh_fp64, (8, 1)),

ynnpack/kernels/unary/reference.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,7 @@ struct sigmoid : public unary_op_info {
184184
}
185185

186186
tolerance_spec tolerance(ynn_type /*type*/) const override {
187-
return tolerance_spec{/*relative=*/1.0f, /*absolute=*/1.0f};
188-
}
189-
190-
interval domain(ynn_type type) const override {
191-
switch (type) {
192-
case ynn_type_fp16:
193-
return {-25.0f, 25.0f};
194-
default:
195-
return {-125.0f, 125.0f};
196-
}
187+
return tolerance_spec{/*relative=*/2.0f, /*absolute=*/1e-2f};
197188
}
198189
};
199190

ynnpack/kernels/unary/sigmoid.py

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,93 @@
11
"""Definition of sigmoid kernel."""
22

3+
import math
4+
35
# pylint: disable=undefined-variable
6+
# pylint: disable=missing-function-docstring
47
from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import
8+
from ynnpack.kernels.unary.util import * # pylint: disable=wildcard-import
59

610

711
@const_buffer('a', Float(32))
812
@buffer('x', Float(32))
913
@operator_name('sigmoid')
1014
def sigmoid_fp32(a, x):
11-
"""Polynomial approximation of sigmoid."""
12-
vmagic_bias = float.fromhex('0x1.8000FEp23')
13-
vminus_log2e = float.fromhex('-0x1.715476p0')
14-
vln2_hi = float.fromhex('0x1.62E400p-1')
15-
vln2_lo = float.fromhex('0x1.7F7D1Cp-20')
16-
vc5 = float.fromhex('-0x1.0F9F9Cp-7')
17-
vc4 = float.fromhex('0x1.573A1Ap-5')
18-
vc3 = float.fromhex('-0x1.555A80p-3')
19-
vc2 = float.fromhex('0x1.FFFDC6p-2')
20-
vc1 = float.fromhex('-0x1.FFFFF6p-1')
21-
vdenorm_cutoff = float.fromhex('0x1.5D589Ep+6')
22-
23-
vx = load(a)
24-
vz = abs(vx)
25-
26-
vn = multiply_add(vz, vminus_log2e, vmagic_bias)
27-
28-
vs = reinterpret_cast(
29-
Float(32), logical_shift_left(reinterpret_cast(Int(32), vn), i32(23))
30-
)
31-
vn = vn - vmagic_bias
32-
33-
vt = multiply_add(vn, vln2_hi, vz)
34-
vt = multiply_add(vn, vln2_lo, vt)
35-
36-
vp = multiply_add(vt, vc5, vc4)
37-
vp = multiply_add(vt, vp, vc3)
38-
vp = multiply_add(vt, vp, vc2)
39-
vp = multiply_add(vt, vp, vc1)
40-
41-
vt = vt * vs
42-
ve = multiply_add(vt, vp, vs)
43-
vd = ve + 1.0
44-
45-
vf = ve / vd
46-
vf = select(vz > vdenorm_cutoff, f32(0.0), vf)
47-
vf = select(vx > f32(0.0), 1.0 - vf, vf)
48-
49-
return store(vf, x)
15+
# We want to compute 1/(1 + e^-x)
16+
# Polynomial coefficients for 2^r. These coefficients have been optimized for
17+
# the purposes of computing sigmoid.
18+
p = [
19+
1.3182145543e-02,
20+
1.3825711608e-01,
21+
6.0665678978e-01,
22+
1.0000000000e+00
23+
]
24+
q = [
25+
7.5802938081e-03,
26+
-4.2016692460e-02,
27+
-8.6490385234e-02,
28+
1.0000000000e+00
29+
]
30+
log2_e = math.log2(math.e)
31+
32+
va = load(a) * -log2_e
33+
vz_prime = min(max(va, -127.0), 128.0)
34+
35+
# Decompose the clamped -x * log2e into `z` (integer part) and `r` (remainder).
36+
vz = round_small_fp32(vz_prime)
37+
vr = vz_prime - vz
38+
39+
# Compute 2^z.
40+
v2z = exp2_round(vz)
41+
v2z = copynan(v2z, va)
42+
43+
vp = eval_polynomial(vr, p)
44+
vq = eval_polynomial(vr, q)
45+
46+
# This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division.
47+
vx = vq / multiply_add(v2z, vp, vq)
48+
49+
return store(vx, x)
50+
51+
52+
@const_buffer('a', Float(64))
53+
@buffer('x', Float(64))
54+
@operator_name('sigmoid')
55+
def sigmoid_fp64(a, x):
56+
# We want to compute 1/(1 + e^-x)
57+
# Polynomial coefficients for 2^r
58+
p = [
59+
f64(3.430671987749682348e-06),
60+
f64(1.754214714900316551e-04),
61+
f64(3.930681642138933278e-03),
62+
f64(4.871246780757146344e-02),
63+
f64(3.331084219217221309e-01),
64+
f64(9.999999999999998890e-01),
65+
]
66+
q = [
67+
f64(-7.126417699553038831e-06),
68+
f64(2.821496207245147419e-04),
69+
f64(-5.316864104706260294e-03),
70+
f64(5.804581129084830649e-02),
71+
f64(-3.600387586382234328e-01),
72+
f64(1.000000000000000000e00),
73+
]
74+
log2_e = f64(math.log2(math.e))
75+
76+
va = load(a) * -log2_e
77+
vz_prime = min(max(va, f64(-1023.0)), f64(1024.0))
78+
79+
# Decompose the clamped -x * log2e into `z` (integer part) and `r` (remainder).
80+
vz = round_small_fp64(vz_prime)
81+
vr = vz_prime - vz
82+
83+
# Compute 2^z.
84+
v2z = exp2_round(vz)
85+
v2z = copynan(v2z, va)
86+
87+
vp = eval_polynomial(vr, p)
88+
vq = eval_polynomial(vr, q)
89+
90+
# This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division.
91+
vx = vq / multiply_add(v2z, vp, vq)
92+
93+
return store(vx, x)

0 commit comments

Comments
 (0)