Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ where
});

c.bench_functions(&format!("scalar/{}", name),
vec![scalar_baseline, scalar_full, scalar_raw, scalar_std],
vec![scalar_baseline, scalar_raw, scalar_full, scalar_std],
values);

let vector_baseline = Fun::new(
Expand Down Expand Up @@ -84,7 +84,7 @@ where
});

c.bench_functions(&format!("vector/{}", name),
vec![vector_baseline, vector_full, vector_raw, vector_std],
vec![vector_baseline, vector_raw, vector_full, vector_std],
values);
}

Expand Down Expand Up @@ -158,5 +158,15 @@ fn bench_atan2(c: &mut Criterion) {
c.bench_functions("scalar/atan2", vec![baseline, full, std], values);
}

criterion_group!(benches, bench_log2, bench_exp, bench_exp2, bench_atan, bench_atan2);
fn bench_tanh(c: &mut Criterion) {
let values = &[
0.85708036, -2.43390621, 2.80163358, -2.55126348, 3.18046186,
-2.88689427, 0.32215155, -0.07701401, 1.22922506, -0.4580259,
0.01257442, -4.23107197, 0.89538113, -1.65219582, 0.14632742,
-1.68663984, 1.88125115, -2.16773942, 1.27461936, -1.03091265
];
bench(c, "tanh", values, &fast_math::tanh, &fast_math::tanh_raw, &f32::tanh)
}

criterion_group!(benches, bench_log2, bench_exp, bench_exp2, bench_atan, bench_atan2, bench_tanh);
criterion_main!(benches);
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ extern crate ieee754;
pub use log::{log2, log2_raw};
pub use atan::{atan_raw, atan, atan2};
pub use exp::{exp_raw, exp2_raw, exp, exp2};
pub use tanh::{tanh, tanh_raw};

mod log;
mod atan;
mod exp;
mod tanh;

#[doc(hidden)]
pub mod float;
166 changes: 166 additions & 0 deletions src/tanh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
use ieee754::Ieee754;

/// Calculate the numerator of the `tanh` approximation.
fn a(x: f32) -> f32 {
let x2 = x * x;
(((x2 + 378.) * x2 + 17325.) * x2 + 135135.) * x
}

/// Calculate the denominator of the `tanh` approximation.
fn b(x: f32) -> f32 {
let x2 = x * x;
((28. * x2 + 3150.) * x2 + 62370.) * x2 + 135135.
}

/// Compute a fast approximation of the hyperbolic tangent of `x`.
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Compute a fast approximation of the hyperbolic tangent of `x`.
/// Compute a fast approximation of the hyperbolic tangent of `x` for -4 < `x` < 4.

///
/// For large |x|, the output may be outside of [-1, 1].
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer this to just make no guarantees about the behaviour at all, e.g.

/// This will return unspecified nonsense if `x` is doesn't 
/// satisfy those constraints. Use `tanh` if correct handling is
/// required (at the expense of some speed).

#[inline]
pub fn tanh_raw(x: f32) -> f32 {
// Implementation based on
// https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction
a(x) / b(x)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of interest, did you consider other approaches? E.g.

  1. using a lower-degree continued fraction approximation instead of (7, 6), such as cutting it off at the level with the 5 (this seems to have maximum relative and absolute errors of about 0.02 if it's used on [-2.3, 2.3] and clipped to +/-1 outside that):

    (x2 + 15.) * x / (6. * x2 + 15.)
  2. optimize the parameters of the approximation (just truncating series like the Taylor series of continued fractions won't be the most accurate approximation); for the form (x2 + a) * x / (b * x2 + a) on some interval [-limit, limit] (like the above), I get a = 21.350693, b = 7.8355837, limit = 2.933833 as the best, with relative and absolute errors of approximately 0.0057, which is about as accurate as other functions in fast-math. (I used the approx.py script referenced at the end of this comment.) (I suspect this expensive form could benefit from optimizing its coefficients too.)

  3. Use the tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) form, with an approximate exp such as that described in https://stackoverflow.com/a/50379934/1256624, which could look something like the following in Rust (haven't tested):

    /// Computes an approximation to (exp(x), exp(-x))
    #[inline]
    fn pm_exp(x: f32) -> (f32, f32) {
        const A: f32 = (1 << 23) as f32 / LN_2;
        const B: u32 = 127 << 23;
        let r = (A * x) as i32 as u32;
        (f32::from_bits(B.wrapping_add(r)),
         f32::from_bits(B.wrapping_sub(r)))
    }
    pub fn tanh_raw(x: f32) -> f32 {
        let (plus, minus) = pm_exp(x);
        (plus - minus) / (plus + minus)
    }

    It could also use the exp now in the library, but it's more expensive (uses a quadratic approximation that requires pulling more info out of the floats), and I'm lead to believe that the above typically benefits from some cancellation of errors that doesn't occur for an isolated exp.

My suspicion is that 2. will be a good balance of speed and accuracy, but 3. could surprise me. Do you know any details about the above?

approx.py

Run like python approx.py, may require Python 3.

import numpy as np
from scipy import optimize

def rel_error(approx, true):
    # if the true value is 0, the approximate one must be too
    return np.where(true != 0,
                    np.abs((approx - true) / true),
                    np.where(approx == 0, 0, np.inf))
def abs_error(approx, true):
    return np.abs(approx - true)

f32 = np.float32
def approx(coeffs, points):
    a, b, limit = coeffs
    # approximate with (x^3 + a x) / (b x ^ 2 + a) on the interval
    # [-limit, limit].
    #
    # Why this form?  tanh is odd, so we should have odd / even (and
    # so the unlisted coeffs must be zero), tanh(x) ~= x for small x
    # (so we can share a in the top and bottom so it approximates a x
    # / a = x when x is small).
    points2 = points * points
    poly = (points2 + a) * points / (b * points2 + a)
    return np.where(np.abs(points) <= limit, poly, np.sign(points))

def evaluation(coeffs, points):
    a = approx(coeffs, points)
    t = np.tanh(points)
    rel = rel_error(a, t).max()
    abs = abs_error(a, t).max()
    return (rel, abs)

start = np.array([15, 6, 2.3])
opt_points = np.linspace(-5, 5, 100001)

# optimize on the relative error
result = optimize.fmin(lambda c: evaluation(f32(c), f32(opt_points))[0], start, maxiter=10000)

final_values = approx(result, opt_points)
assert np.all((final_values >= -1) & (final_values <= 1)), "not allowed to overshoot"

rel, abs = evaluation(result, opt_points)
print("a = %s, b = %s, limit = %s" % tuple(f32(result)))
print("on [-5, 5]: rel error = %.6f, abs error = %.6f" % (rel, abs))
for x in np.arange(-5, 5.01, 0.5):
    print("%7.4f: %f (%f)" % (x, approx(result, x), np.tanh(x)))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

/// Compute a fast approximation of the hyperbolic tangent of `x`.
///
/// See `tanh_raw` for a faster version that may return incorrect results for
/// large `|x|` and `nan`.
#[inline]
pub fn tanh(x: f32) -> f32 {
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this could be something like:

pub fn tanh(x: f32) -> f32 {
    if x < -4.97 {
        -1.
    } else if x > 4.97 {
        1.
    } else {
        // if x is NaN, it will propagate through the arithmetic
        tanh_raw(x)
    }
}

This is likely to be easier to vectorize, and does fewer operations. If you rebase/merge this PR onto the latest master (and so can use Ieee754::copy_sign), this could even be:

pub fn tanh(x: f32) -> f32 {
    if x.abs() > 4.97 {
         // the true value |tanh(x)| > 0.9999 when |x| > 4.97, so 
        // rounding to ±1 is close enough
        1_f32.copy_sign(x)
    } else {
        // |tanh_raw(x)| < 1 when |x| <= 4.97, so no post-processing is needed,
        // and x being NaN is handled by propagating through the arithmetic
        tanh_raw(x)
    }
}

With this adjustment, a and b no longer need to be separate functions and can be inlined straight into tanh_raw.

Copy link
Copy Markdown
Contributor Author

@vks vks Mar 6, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you think clipping like this might be problematic, because it results in discontinuities?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a potential problem. An alternative would be to find when the approximation is exactly +/-1, and clip there instead of 4.97 (I think it should be symmetric?), so that the tanh approximation is continuous (although it's derivative won't be).

if x.is_nan() {
return x;
}

let a = a(x);
if !a.is_finite() {
return 1_f32.copy_sign(a);
}

let result = a / b(x);
if result.abs() > 1. {
return 1_f32.copy_sign(result);
}
result
}

#[cfg(test)]
mod tests {
use super::*;
use quickcheck as qc;
use std::f32 as f;
use ieee754::Ieee754;

/// Maximal absolute error.
Comment thread
vks marked this conversation as resolved.
const TOL_ABS: f32 = 0.0001;

/// Maximal relative error.
const TOL_REL: f32 = 0.0001;

#[test]
fn tanh_err_qc() {
fn prop(x: f32) -> qc::TestResult {
let e = tanh(x);
let t = x.tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();

qc::TestResult::from_bool(abs < TOL_ABS && rel < TOL_REL)
}
qc::quickcheck(prop as fn(f32) -> qc::TestResult)
}

const PREC: u32 = 1 << 20;
#[test]
fn tanh_err_exhaustive() {
for i in 0..PREC + 1 {
for j in -5..6 {
let x = (1.0 + i as f32 / PREC as f32) * 2f32.powi(j * 20);
{
let e = tanh(x);
let t = x.tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();

assert!(abs < TOL_ABS,
"{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, abs);
assert!(rel < TOL_REL,
"{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, rel);
}
{
let e = tanh(-x);
let t = (-x).tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();

assert!(abs < TOL_ABS,
"{:.8}: {:.8}, {:.8}. {:.4}", -x, e, t, abs);
assert!(rel < TOL_REL,
"{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, rel);
}
}
}
}

#[test]
fn tanh_edge_cases() {
assert!(tanh(f::NAN).is_nan());
assert_eq!(tanh(f::NEG_INFINITY), -1.);
assert_eq!(tanh(f::INFINITY), 1.);
}

#[test]
fn tanh_denormals() {
fn prop(x: u8, y: u16) -> bool {
let signif = ((x as u32) << 16) | (y as u32);
let mut x = f32::recompose_raw(false, 1, signif);

for _ in 0..23 {
{
let e = tanh(x);
let t = x.tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();
if abs >= TOL_ABS && rel >= TOL_REL {
return false
}
}
{
let e = tanh(-x);
let t = (-x).tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();
if abs >= TOL_ABS && rel >= TOL_REL {
return false
}
}

x /= 2.0;
}
true
}
qc::quickcheck(prop as fn(u8, u16) -> bool)
}

#[test]
fn tanh_raw_denormals() {
fn prop(x: u8, y: u16) -> bool {
let signif = ((x as u32) << 16) | (y as u32);
let mut x = f32::recompose_raw(false, 1, signif);

for _ in 0..23 {
let e = tanh_raw(x);
let t = x.tanh();
let abs = (e - t).abs();
let rel = e.rel_error(t).abs();
if abs >= TOL_ABS && rel >= TOL_REL {
return false
}

x /= 2.0;
}
true
}
qc::quickcheck(prop as fn(u8, u16) -> bool)
}
}