diff --git a/ynnpack/kernels/unary/generator.py b/ynnpack/kernels/unary/generator.py index e7d2fe665ff..158ab25414d 100644 --- a/ynnpack/kernels/unary/generator.py +++ b/ynnpack/kernels/unary/generator.py @@ -48,6 +48,7 @@ def main(argv: Sequence[str]) -> None: (square_root_fp32, (8, 1)), (square_root_fp64, (4, 1)), (sigmoid_fp32, (32, 1)), + (sigmoid_fp64, (8, 1)), (tanh_fp32, (16, 1)), (tanh_fp64, (8, 1)), ], @@ -99,6 +100,7 @@ def main(argv: Sequence[str]) -> None: (log_fp64, (16, 1)), (round_to_bf16_fp32, (16, 1)), (sigmoid_fp32, (16, 1)), + (sigmoid_fp64, (8, 1)), ], "x86_fma3": [ (cosine_fp32, (16, 1)), @@ -146,6 +148,7 @@ def main(argv: Sequence[str]) -> None: (square_root_fp32, (32, 1)), (square_root_fp64, (16, 1)), (sigmoid_fp32, (32, 1)), + (sigmoid_fp64, (16, 1)), (tanh_fp32, (32, 1)), (tanh_fp64, (16, 1)), (convert_int2_to_int8, (64, 1)), @@ -188,6 +191,7 @@ def main(argv: Sequence[str]) -> None: (poly3_fp64, (16, 1)), (reciprocal_square_root_fp64, (4, 1)), (round_fp64, (4, 1)), + (sigmoid_fp64, (4, 1)), (square_fp64, (4, 1)), (square_root_fp64, (4, 1)), (tanh_fp64, (8, 1)), diff --git a/ynnpack/kernels/unary/reference.h b/ynnpack/kernels/unary/reference.h index 78e7bfd9b7c..42fb55f85f1 100644 --- a/ynnpack/kernels/unary/reference.h +++ b/ynnpack/kernels/unary/reference.h @@ -184,16 +184,7 @@ struct sigmoid : public unary_op_info { } tolerance_spec tolerance(ynn_type /*type*/) const override { - return tolerance_spec{/*relative=*/1.0f, /*absolute=*/1.0f}; - } - - interval domain(ynn_type type) const override { - switch (type) { - case ynn_type_fp16: - return {-25.0f, 25.0f}; - default: - return {-125.0f, 125.0f}; - } + return tolerance_spec{/*relative=*/2.0f, /*absolute=*/1e-2f}; } }; diff --git a/ynnpack/kernels/unary/sigmoid.py b/ynnpack/kernels/unary/sigmoid.py index ff600eb32a7..cf38418d874 100644 --- a/ynnpack/kernels/unary/sigmoid.py +++ b/ynnpack/kernels/unary/sigmoid.py @@ -1,49 +1,93 @@ """Definition of sigmoid kernel.""" +import math + # pylint: disable=undefined-variable +# pylint: disable=missing-function-docstring from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import +from ynnpack.kernels.unary.util import * # pylint: disable=wildcard-import @const_buffer('a', Float(32)) @buffer('x', Float(32)) @operator_name('sigmoid') def sigmoid_fp32(a, x): - """Polynomial approximation of sigmoid.""" - vmagic_bias = float.fromhex('0x1.8000FEp23') - vminus_log2e = float.fromhex('-0x1.715476p0') - vln2_hi = float.fromhex('0x1.62E400p-1') - vln2_lo = float.fromhex('0x1.7F7D1Cp-20') - vc5 = float.fromhex('-0x1.0F9F9Cp-7') - vc4 = float.fromhex('0x1.573A1Ap-5') - vc3 = float.fromhex('-0x1.555A80p-3') - vc2 = float.fromhex('0x1.FFFDC6p-2') - vc1 = float.fromhex('-0x1.FFFFF6p-1') - vdenorm_cutoff = float.fromhex('0x1.5D589Ep+6') - - vx = load(a) - vz = abs(vx) - - vn = multiply_add(vz, vminus_log2e, vmagic_bias) - - vs = reinterpret_cast( - Float(32), logical_shift_left(reinterpret_cast(Int(32), vn), i32(23)) - ) - vn = vn - vmagic_bias - - vt = multiply_add(vn, vln2_hi, vz) - vt = multiply_add(vn, vln2_lo, vt) - - vp = multiply_add(vt, vc5, vc4) - vp = multiply_add(vt, vp, vc3) - vp = multiply_add(vt, vp, vc2) - vp = multiply_add(vt, vp, vc1) - - vt = vt * vs - ve = multiply_add(vt, vp, vs) - vd = ve + 1.0 - - vf = ve / vd - vf = select(vz > vdenorm_cutoff, f32(0.0), vf) - vf = select(vx > f32(0.0), 1.0 - vf, vf) - - return store(vf, x) + # We want to compute 1/(1 + e^-x) + # Polynomial coefficients for 2^r. These coefficients have been optimized for + # the purposes of computing sigmoid. + p = [ + 1.3182145543e-02, + 1.3825711608e-01, + 6.0665678978e-01, + 1.0000000000e+00 + ] + q = [ + 7.5802938081e-03, + -4.2016692460e-02, + -8.6490385234e-02, + 1.0000000000e+00 + ] + log2_e = math.log2(math.e) + + va = load(a) * -log2_e + vz_prime = min(max(va, -127.0), 128.0) + + # Decompose x * log2e into `z` (integer part) and `r` (remainder). + vz = round_small_fp32(vz_prime) + vr = vz_prime - vz + + # Compute 2^z. + v2z = exp2_round(vz) + v2z = copynan(v2z, va) + + vp = eval_polynomial(vr, p) + vq = eval_polynomial(vr, q) + + # This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division. + vx = vq / multiply_add(v2z, vp, vq) + + return store(vx, x) + + +@const_buffer('a', Float(64)) +@buffer('x', Float(64)) +@operator_name('sigmoid') +def sigmoid_fp64(a, x): + # We want to compute 1/(1 + e^-x) + # Polynomial coefficients for 2^r + p = [ + f64(3.430671987749682348e-06), + f64(1.754214714900316551e-04), + f64(3.930681642138933278e-03), + f64(4.871246780757146344e-02), + f64(3.331084219217221309e-01), + f64(9.999999999999998890e-01), + ] + q = [ + f64(-7.126417699553038831e-06), + f64(2.821496207245147419e-04), + f64(-5.316864104706260294e-03), + f64(5.804581129084830649e-02), + f64(-3.600387586382234328e-01), + f64(1.000000000000000000e00), + ] + log2_e = f64(math.log2(math.e)) + + va = load(a) * -log2_e + vz_prime = min(max(va, f64(-1023.0)), f64(1024.0)) + + # Decompose x * log2e into `z` (integer part) and `r` (remainder). + vz = round_small_fp64(vz_prime) + vr = vz_prime - vz + + # Compute 2^z. + v2z = exp2_round(vz) + v2z = copynan(v2z, va) + + vp = eval_polynomial(vr, p) + vq = eval_polynomial(vr, q) + + # This is 1 / (1 + v2z * vp / vq), rearranged to avoid the extra division. + vx = vq / multiply_add(v2z, vp, vq) + + return store(vx, x)