The TensorFlow Probability implementation of softplus leaks memory and appears to no longer be needed. That is, I think the standard `tf.nn.softplus` implementation can now be used, as the numerical stability issues appear to have been solved.
Currently the implementation of softplus is as follows (from here):

```python
# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:

  @tf.custom_gradient
  def _stable_grad_softplus(x):
    """A (more) numerically stable softplus than `tf.nn.softplus`."""
    x = tf.convert_to_tensor(x)
    if x.dtype == tf.float64:
      cutoff = -20
    else:
      cutoff = -9
    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

    def grad_fn(dy):
      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

    return y, grad_fn
```
This leaks memory (in non-JAX mode) due to a couple of issues:

- The `grad_fn` closure captures the tensor represented by `x`. This closure then ends up in the gradient registry, which is never cleared, so the tensor represented by `x` hangs around forever.
- For a similar reason, TensorFlow's `custom_gradient` implementation also leaks memory. See 97697 for more details.

Here is a Colab notebook to demonstrate the memory leak.
However, I believe that the numerical stability issues with `tf.nn.softplus` have been solved. Specifically:

- The `tf.nn.softplus` implementation now uses `log1p` as of this commit on May 1, 2020.
- The gradient computation for `tf.nn.softplus` now uses `math_ops.sigmoid` as of this commit on April 4, 2019.
- The Eigen implementation of sigmoid (which I think is here) computes it as `e^x / (1.0 + e^x)`, so the approximation of the gradient by `e^x` in `_stable_grad_softplus` seems unnecessary to me. If `e^x` is very small then `1.0 + e^x` will be exactly `1.0`, so the result is exactly `e^x`. If `e^x > 1.0` then the result of `e^x / (1.0 + e^x)` will be (I think) more accurate than just approximating the gradient as `e^x`. But I am not a numerical stability expert, so I may be wrong.
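The float64 behaviour described above can be checked in pure Python with the `math` module (no TensorFlow needed), since Python floats are IEEE 754 doubles:

```python
import math


def softplus_naive(x):
    # log(1 + e^x): for very negative x, 1.0 + e^x rounds to exactly 1.0,
    # so the result underflows to 0.0.
    return math.log(1.0 + math.exp(x))


def softplus_log1p(x):
    # log1p(e^x) keeps the tiny e^x contribution.
    return math.log1p(math.exp(x))


x = -40.0  # well below the float64 cutoff of -20

print(softplus_naive(x))   # 0.0: precision lost in 1.0 + e^x
print(softplus_log1p(x))   # ~4.25e-18, matching e^-40

# Sigmoid computed as e^x / (1.0 + e^x): the denominator rounds to exactly
# 1.0, so the result is bit-for-bit equal to e^x -- no separate e^x
# approximation of the gradient is needed.
sigmoid = math.exp(x) / (1.0 + math.exp(x))
print(sigmoid == math.exp(x))  # True
```

This matches the argument in the last bullet: in the regime where `_stable_grad_softplus` switches to `e^x`, the sigmoid formula already produces exactly `e^x`.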