-
Notifications
You must be signed in to change notification settings - Fork 6
Implement fast tanh #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
10d5e85
d7347f5
62aad3d
bd477b6
aa77262
fe5211a
f0f0549
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| use ieee754::Ieee754; | ||
|
|
||
/// Numerator of the rational `tanh` approximation.
///
/// Evaluates `x * (x^6 + 378 x^4 + 17325 x^2 + 135135)` with Horner's scheme
/// in `x^2`.
fn a(x: f32) -> f32 {
    let sq = x * x;
    // Horner steps, innermost coefficient first.
    let mut acc = sq + 378.;
    acc = acc * sq + 17325.;
    acc = acc * sq + 135135.;
    acc * x
}
|
|
||
/// Denominator of the rational `tanh` approximation.
///
/// Evaluates `28 x^6 + 3150 x^4 + 62370 x^2 + 135135` with Horner's scheme
/// in `x^2`.
fn b(x: f32) -> f32 {
    let sq = x * x;
    // Horner steps, highest-order coefficient first.
    let mut acc = 28. * sq + 3150.;
    acc = acc * sq + 62370.;
    acc * sq + 135135.
}
|
|
||
| /// Compute a fast approximation of the hyperbolic tangent of `x`. | ||
| /// | ||
| /// For large |x|, the output may be outside of [-1, 1]. | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer this to just make no guarantees about the behaviour at all, e.g. /// This will return unspecified nonsense if `x` doesn't
/// satisfy those constraints. Use `tanh` if correct handling is
/// required (at the expense of some speed).
|
||
| #[inline] | ||
| pub fn tanh_raw(x: f32) -> f32 { | ||
| // Implementation based on | ||
| // https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction | ||
| a(x) / b(x) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of interest, did you consider other approaches? E.g.
My suspicion is that 2. will be a good balance of speed and accuracy, but 3. could surprise me. Do you know any details about the above?
|
||
| } | ||
|
|
||
| /// Compute a fast approximation of the hyperbolic tangent of `x`. | ||
| /// | ||
| /// See `tanh_raw` for a faster version that may return incorrect results for | ||
| /// large `|x|` and `nan`. | ||
| #[inline] | ||
| pub fn tanh(x: f32) -> f32 { | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this could be something like: pub fn tanh(x: f32) -> f32 {
if x < -4.97 {
-1.
} else if x > 4.97 {
1.
} else {
// if x is NaN, it will propagate through the arithmetic
tanh_raw(x)
}
}This is likely to be easier to vectorize, and does fewer operations. If you rebase/merge this PR onto the latest master (and so can use pub fn tanh(x: f32) -> f32 {
if x.abs() > 4.97 {
// the true value |tanh(x)| > 0.9999 when |x| > 4.97, so
// rounding to ±1 is close enough
1_f32.copy_sign(x)
} else {
// |tanh_raw(x)| < 1 when |x| <= 4.97, so no post-processing is needed,
// and x being NaN is handled by propagating through the arithmetic
tanh_raw(x)
}
}With this adjustment,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't you think clipping like this might be problematic, because it results in discontinuities?
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a potential problem. An alternative would be to find when the approximation is exactly +/-1, and clip there instead of 4.97 (I think it should be symmetric?), so that the tanh approximation is continuous (although its derivative won't be). |
||
| if x.is_nan() { | ||
| return x; | ||
| } | ||
|
|
||
| let a = a(x); | ||
| if !a.is_finite() { | ||
| return 1_f32.copy_sign(a); | ||
| } | ||
|
|
||
| let result = a / b(x); | ||
| if result.abs() > 1. { | ||
| return 1_f32.copy_sign(result); | ||
| } | ||
| result | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use quickcheck as qc; | ||
| use std::f32 as f; | ||
| use ieee754::Ieee754; | ||
|
|
||
| /// Maximal absolute error. | ||
|
vks marked this conversation as resolved.
|
||
| const TOL_ABS: f32 = 0.0001; | ||
|
|
||
| /// Maximal relative error. | ||
| const TOL_REL: f32 = 0.0001; | ||
|
|
||
/// Property test: for `f32` inputs generated by quickcheck, `tanh` must agree
/// with the standard library's `f32::tanh` within both `TOL_ABS` and
/// `TOL_REL`.
#[test]
fn tanh_err_qc() {
    fn prop(x: f32) -> qc::TestResult {
        let approx = tanh(x);
        let exact = x.tanh();

        let abs_ok = (approx - exact).abs() < TOL_ABS;
        let rel_ok = approx.rel_error(exact).abs() < TOL_REL;

        qc::TestResult::from_bool(abs_ok && rel_ok)
    }
    qc::quickcheck(prop as fn(f32) -> qc::TestResult)
}
|
|
||
/// Number of evenly spaced steps taken across each sampled binade.
const PREC: u32 = 1 << 20;

/// Sweep test: checks `tanh` against `f32::tanh` on a dense grid.
///
/// For every `i` in `0..=PREC` and `j` in `-5..=5` the input is
/// `(1 + i / PREC) * 2^(20 * j)`; both that value and its negation are
/// checked against `TOL_ABS` and `TOL_REL`.
#[test]
fn tanh_err_exhaustive() {
    /// Assert that `tanh(x)` is within both tolerances of `x.tanh()`.
    ///
    /// The failure message always reports the input that was actually
    /// evaluated, followed by the approximation, the reference value and the
    /// offending error.
    fn check(x: f32) {
        let e = tanh(x);
        let t = x.tanh();
        let abs = (e - t).abs();
        let rel = e.rel_error(t).abs();

        assert!(abs < TOL_ABS, "{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, abs);
        assert!(rel < TOL_REL, "{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, rel);
    }

    for i in 0..PREC + 1 {
        for j in -5..6 {
            let x = (1.0 + i as f32 / PREC as f32) * 2f32.powi(j * 20);
            check(x);
            check(-x);
        }
    }
}
|
|
||
/// Non-finite inputs: NaN must stay NaN, and the infinities must map to the
/// matching limit of the hyperbolic tangent, `-1` and `1`.
#[test]
fn tanh_edge_cases() {
    let nan_out = tanh(f::NAN);
    assert!(nan_out.is_nan());

    let limits = [(f::NEG_INFINITY, -1_f32), (f::INFINITY, 1_f32)];
    for &(input, expected) in &limits {
        assert_eq!(tanh(input), expected);
    }
}
|
|
||
/// Property test for `tanh` on very small magnitudes.
///
/// Each case builds a positive `f32` with raw exponent 1 and a significand
/// assembled from the generated `x` and `y`, then halves it 23 times. At every
/// step both the value and its negation must agree with `f32::tanh` within
/// `TOL_ABS` *and* `TOL_REL`, the same criterion the other tests use.
#[test]
fn tanh_denormals() {
    /// `true` when `tanh(v)` is within both tolerances of `v.tanh()`.
    fn within_tol(v: f32) -> bool {
        let e = tanh(v);
        let t = v.tanh();
        let abs = (e - t).abs();
        let rel = e.rel_error(t).abs();
        // Phrased positively so that a NaN error counts as a failure. The
        // absolute error alone cannot reject anything at these magnitudes,
        // so the relative check must not be skippable.
        abs < TOL_ABS && rel < TOL_REL
    }

    fn prop(x: u8, y: u16) -> bool {
        let signif = ((x as u32) << 16) | (y as u32);
        let mut v = f32::recompose_raw(false, 1, signif);

        for _ in 0..23 {
            if !within_tol(v) || !within_tol(-v) {
                return false;
            }

            v /= 2.0;
        }
        true
    }
    qc::quickcheck(prop as fn(u8, u16) -> bool)
}
|
|
||
/// Property test for `tanh_raw` on very small positive magnitudes.
///
/// Each case builds a positive `f32` with raw exponent 1 and a significand
/// assembled from the generated `x` and `y`, then halves it 23 times. At every
/// step the result must agree with `f32::tanh` within `TOL_ABS` *and*
/// `TOL_REL`, the same criterion the other tests use.
#[test]
fn tanh_raw_denormals() {
    fn prop(x: u8, y: u16) -> bool {
        let signif = ((x as u32) << 16) | (y as u32);
        let mut v = f32::recompose_raw(false, 1, signif);

        for _ in 0..23 {
            let e = tanh_raw(v);
            let t = v.tanh();
            let abs = (e - t).abs();
            let rel = e.rel_error(t).abs();
            // Phrased positively so that a NaN error counts as a failure. The
            // absolute error alone cannot reject anything at these
            // magnitudes, so the relative check must not be skippable.
            if !(abs < TOL_ABS && rel < TOL_REL) {
                return false;
            }

            v /= 2.0;
        }
        true
    }
    qc::quickcheck(prop as fn(u8, u16) -> bool)
}
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.