From cc68da84bd4e062a79600b82109ea28bd1553e49 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Tue, 19 May 2026 16:14:22 -0700 Subject: [PATCH] Add sigmoid_fp64 kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And rewrite the sigmoid_fp32 kernel using the same technique. It turns out that this kernel is faster, which is a little surprising. It does have less "overhead" (special cases, piecewise branches, etc.) in exchange for more polynomial arithmetic. Change in performance for `sigmoid_fp32`: ``` name time/op time/op vs base bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1.759µ ± 3% 1.771µ ± 21% ~ (p=0.818 n=6) bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:4/n:1024/real_time 1.740µ ± 3% 1.774µ ± 3% ~ (p=0.065 n=6) bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:16/n:256/real_time 1.773µ ± 2% 1.783µ ± 1% ~ (p=0.699 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time 4.909µ ± 1% 3.670µ ± 2% -25.24% (p=0.002 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:4/n:1024/real_time 4.830µ ± 2% 3.706µ ± 4% -23.28% (p=0.002 n=6) bench/sigmoid_fp32_1x16_x86_avx2/m:16/n:256/real_time 4.912µ ± 1% 3.740µ ± 2% -23.87% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time 6.632µ ± 3% 5.437µ ± 2% -18.02% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:4/n:1024/real_time 6.637µ ± 3% 5.524µ ± 3% -16.77% (p=0.002 n=6) bench/sigmoid_fp32_1x32_x86_sse2/m:16/n:256/real_time 6.692µ ± 4% 5.493µ ± 1% -17.92% (p=0.002 n=6) geomean 3.851µ 3.305µ -14.19% ``` `sigmoid_fp64` compared to other kernels: ``` ---------------------------------------------------------------------------------------------------------------------------- Benchmark Time CPU Iterations UserCounters... 
---------------------------------------------------------------------------------------------------------------------------- bench_reference/sigmoid_float/m:1/n:4096/real_time 38579 ns 38571 ns 7348 Bytes=849.382M/s Op=106.173M/s bench_reference/sigmoid_float/m:4/n:1024/real_time 38345 ns 38338 ns 7440 Bytes=854.556M/s Op=106.819M/s bench_reference/sigmoid_float/m:16/n:256/real_time 39192 ns 39187 ns 7190 Bytes=836.095M/s Op=104.512M/s bench_reference/sigmoid_double/m:1/n:4096/real_time 91326 ns 91313 ns 3045 Bytes=717.606M/s Op=44.8504M/s bench_reference/sigmoid_double/m:4/n:1024/real_time 91307 ns 91290 ns 3043 Bytes=717.757M/s Op=44.8598M/s bench_reference/sigmoid_double/m:16/n:256/real_time 93505 ns 93486 ns 3018 Bytes=700.885M/s Op=43.8053M/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:1/n:4096/real_time 1786 ns 1786 ns 155658 Bytes=18.3422G/s Op=2.29277G/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:4/n:1024/real_time 1802 ns 1802 ns 157599 Bytes=18.1847G/s Op=2.27309G/s bench/sigmoid_fp32_1x32_x86_avx512f_avx512bw/m:16/n:256/real_time 1791 ns 1791 ns 156134 Bytes=18.2963G/s Op=2.28704G/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:1/n:4096/real_time 4475 ns 4475 ns 60425 Bytes=14.6433G/s Op=915.207M/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:4/n:1024/real_time 4822 ns 4821 ns 59593 Bytes=13.5913G/s Op=849.459M/s bench/sigmoid_fp64_1x16_x86_avx512f_avx512bw/m:16/n:256/real_time 4842 ns 4840 ns 56596 Bytes=13.5363G/s Op=846.016M/s bench/sigmoid_fp32_1x16_x86_avx2/m:1/n:4096/real_time 3789 ns 3788 ns 69486 Bytes=8.64752G/s Op=1.08094G/s bench/sigmoid_fp32_1x16_x86_avx2/m:4/n:1024/real_time 3892 ns 3892 ns 74142 Bytes=8.41825G/s Op=1.05228G/s bench/sigmoid_fp32_1x16_x86_avx2/m:16/n:256/real_time 3757 ns 3756 ns 72827 Bytes=8.72073G/s Op=1.09009G/s bench/sigmoid_fp64_1x8_x86_avx2/m:1/n:4096/real_time 10451 ns 10450 ns 26516 Bytes=6.27103G/s Op=391.939M/s bench/sigmoid_fp64_1x8_x86_avx2/m:4/n:1024/real_time 11010 ns 11007 ns 24451 Bytes=5.95261G/s 
Op=372.038M/s bench/sigmoid_fp64_1x8_x86_avx2/m:16/n:256/real_time 10475 ns 10472 ns 26374 Bytes=6.2567G/s Op=391.044M/s bench/sigmoid_fp32_1x32_x86_sse2/m:1/n:4096/real_time 5649 ns 5648 ns 49675 Bytes=5.80048G/s Op=725.06M/s bench/sigmoid_fp32_1x32_x86_sse2/m:4/n:1024/real_time 5646 ns 5645 ns 50916 Bytes=5.80353G/s Op=725.441M/s bench/sigmoid_fp32_1x32_x86_sse2/m:16/n:256/real_time 5571 ns 5571 ns 48792 Bytes=5.88151G/s Op=735.188M/s bench/sigmoid_fp64_1x8_x86_sse2/m:1/n:4096/real_time 15957 ns 15952 ns 17116 Bytes=4.10712G/s Op=256.695M/s bench/sigmoid_fp64_1x8_x86_sse2/m:4/n:1024/real_time 15657 ns 15654 ns 17451 Bytes=4.18581G/s Op=261.613M/s bench/sigmoid_fp64_1x8_x86_sse2/m:16/n:256/real_time 15748 ns 15744 ns 17804 Bytes=4.16163G/s Op=260.102M/s ``` PiperOrigin-RevId: 918081537 --- ynnpack/kernels/unary/generator.py | 4 + ynnpack/kernels/unary/reference.h | 11 +-- ynnpack/kernels/unary/sigmoid.py | 122 ++++++++++++++++++++--------- 3 files changed, 88 insertions(+), 49 deletions(-) diff --git a/ynnpack/kernels/unary/generator.py b/ynnpack/kernels/unary/generator.py index e7d2fe665ff..158ab25414d 100644 --- a/ynnpack/kernels/unary/generator.py +++ b/ynnpack/kernels/unary/generator.py @@ -48,6 +48,7 @@ def main(argv: Sequence[str]) -> None: (square_root_fp32, (8, 1)), (square_root_fp64, (4, 1)), (sigmoid_fp32, (32, 1)), + (sigmoid_fp64, (8, 1)), (tanh_fp32, (16, 1)), (tanh_fp64, (8, 1)), ], @@ -99,6 +100,7 @@ def main(argv: Sequence[str]) -> None: (log_fp64, (16, 1)), (round_to_bf16_fp32, (16, 1)), (sigmoid_fp32, (16, 1)), + (sigmoid_fp64, (8, 1)), ], "x86_fma3": [ (cosine_fp32, (16, 1)), @@ -146,6 +148,7 @@ def main(argv: Sequence[str]) -> None: (square_root_fp32, (32, 1)), (square_root_fp64, (16, 1)), (sigmoid_fp32, (32, 1)), + (sigmoid_fp64, (16, 1)), (tanh_fp32, (32, 1)), (tanh_fp64, (16, 1)), (convert_int2_to_int8, (64, 1)), @@ -188,6 +191,7 @@ def main(argv: Sequence[str]) -> None: (poly3_fp64, (16, 1)), (reciprocal_square_root_fp64, (4, 1)), 
(round_fp64, (4, 1)), + (sigmoid_fp64, (4, 1)), (square_fp64, (4, 1)), (square_root_fp64, (4, 1)), (tanh_fp64, (8, 1)), diff --git a/ynnpack/kernels/unary/reference.h b/ynnpack/kernels/unary/reference.h index 78e7bfd9b7c..42fb55f85f1 100644 --- a/ynnpack/kernels/unary/reference.h +++ b/ynnpack/kernels/unary/reference.h @@ -184,16 +184,7 @@ struct sigmoid : public unary_op_info { } tolerance_spec tolerance(ynn_type /*type*/) const override { - return tolerance_spec{/*relative=*/1.0f, /*absolute=*/1.0f}; - } - - interval domain(ynn_type type) const override { - switch (type) { - case ynn_type_fp16: - return {-25.0f, 25.0f}; - default: - return {-125.0f, 125.0f}; - } + return tolerance_spec{/*relative=*/2.0f, /*absolute=*/1e-2f}; } }; diff --git a/ynnpack/kernels/unary/sigmoid.py b/ynnpack/kernels/unary/sigmoid.py index ff600eb32a7..cf38418d874 100644 --- a/ynnpack/kernels/unary/sigmoid.py +++ b/ynnpack/kernels/unary/sigmoid.py @@ -1,49 +1,93 @@ """Definition of sigmoid kernel.""" +import math + # pylint: disable=undefined-variable +# pylint: disable=missing-function-docstring from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import +from ynnpack.kernels.unary.util import * # pylint: disable=wildcard-import @const_buffer('a', Float(32)) @buffer('x', Float(32)) @operator_name('sigmoid') def sigmoid_fp32(a, x): - """Polynomial approximation of sigmoid.""" - vmagic_bias = float.fromhex('0x1.8000FEp23') - vminus_log2e = float.fromhex('-0x1.715476p0') - vln2_hi = float.fromhex('0x1.62E400p-1') - vln2_lo = float.fromhex('0x1.7F7D1Cp-20') - vc5 = float.fromhex('-0x1.0F9F9Cp-7') - vc4 = float.fromhex('0x1.573A1Ap-5') - vc3 = float.fromhex('-0x1.555A80p-3') - vc2 = float.fromhex('0x1.FFFDC6p-2') - vc1 = float.fromhex('-0x1.FFFFF6p-1') - vdenorm_cutoff = float.fromhex('0x1.5D589Ep+6') - - vx = load(a) - vz = abs(vx) - - vn = multiply_add(vz, vminus_log2e, vmagic_bias) - - vs = reinterpret_cast( - Float(32), 
logical_shift_left(reinterpret_cast(Int(32), vn), i32(23)) - ) - vn = vn - vmagic_bias - - vt = multiply_add(vn, vln2_hi, vz) - vt = multiply_add(vn, vln2_lo, vt) - - vp = multiply_add(vt, vc5, vc4) - vp = multiply_add(vt, vp, vc3) - vp = multiply_add(vt, vp, vc2) - vp = multiply_add(vt, vp, vc1) - - vt = vt * vs - ve = multiply_add(vt, vp, vs) - vd = ve + 1.0 - - vf = ve / vd - vf = select(vz > vdenorm_cutoff, f32(0.0), vf) - vf = select(vx > f32(0.0), 1.0 - vf, vf) - - return store(vf, x) + # We want to compute 1/(1 + e^-x) + # Rational approximation 2^r ~= p(r) / q(r). These coefficients have been optimized for + # the purposes of computing sigmoid. + p = [ + 1.3182145543e-02, + 1.3825711608e-01, + 6.0665678978e-01, + 1.0000000000e+00 + ] + q = [ + 7.5802938081e-03, + -4.2016692460e-02, + -8.6490385234e-02, + 1.0000000000e+00 + ] + log2_e = math.log2(math.e) + + va = load(a) * -log2_e + vz_prime = min(max(va, -127.0), 128.0) + + # Decompose the clamped -x * log2e into `z` (integer part) and `r` (remainder). + vz = round_small_fp32(vz_prime) + vr = vz_prime - vz + + # Compute 2^z. + v2z = exp2_round(vz) + v2z = copynan(v2z, va) + + vp = eval_polynomial(vr, p) + vq = eval_polynomial(vr, q) + + # This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division. 
+ vx = vq / multiply_add(v2z, vp, vq) + + return store(vx, x) + + +@const_buffer('a', Float(64)) +@buffer('x', Float(64)) +@operator_name('sigmoid') +def sigmoid_fp64(a, x): + # We want to compute 1/(1 + e^-x) + # Rational approximation 2^r ~= p(r) / q(r) + p = [ + f64(3.430671987749682348e-06), + f64(1.754214714900316551e-04), + f64(3.930681642138933278e-03), + f64(4.871246780757146344e-02), + f64(3.331084219217221309e-01), + f64(9.999999999999998890e-01), + ] + q = [ + f64(-7.126417699553038831e-06), + f64(2.821496207245147419e-04), + f64(-5.316864104706260294e-03), + f64(5.804581129084830649e-02), + f64(-3.600387586382234328e-01), + f64(1.000000000000000000e00), + ] + log2_e = f64(math.log2(math.e)) + + va = load(a) * -log2_e + vz_prime = min(max(va, f64(-1023.0)), f64(1024.0)) + + # Decompose the clamped -x * log2e into `z` (integer part) and `r` (remainder). + vz = round_small_fp64(vz_prime) + vr = vz_prime - vz + + # Compute 2^z. + v2z = exp2_round(vz) + v2z = copynan(v2z, va) + + vp = eval_polynomial(vr, p) + vq = eval_polynomial(vr, q) + + # This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division. + vx = vq / multiply_add(v2z, vp, vq) + + return store(vx, x)