Skip to content

Commit 96cf21a

Browse files
committed
Update: use Charbonnier transform
1 parent 2ae8d8a commit 96cf21a

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

uncertaintyx/b/jax.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,22 +198,27 @@ def nnls(c: Array, rhs: Array):
198198
Non-negative least-squares solver.
199199
200200
Uses QR factorization to compute a stable unconstrained
201-
solution. Applies a softplus transformation and an L-BFGS
201+
solution. Applies a positive transformation and an L-BFGS
202202
optimizer to ensure non-negativity.
203203
"""
204204

205205
def forward(u: Array) -> Array:
206-
"""The forward softplus transformation."""
207-
return jnp.log(1.0 + jnp.exp(u))
206+
r"""
207+
The forward (Charbonnier) transformation.
208+
209+
Asymptotic limits are :math:`2u` for :math:`u \to \infty` and
210+
zero for :math:`u \to -\infty`.
211+
"""
212+
return 0.25 * jnp.square(u + jnp.sqrt(jnp.square(u) + 4.0))
208213

209-
def inverse(c_: Array) -> Array:
210-
"""The inverse softplus transformation."""
211-
return jnp.log(jnp.expm1(c_))
214+
def inverse(c: Array) -> Array:
215+
"""The inverse (Charbonnier) transformation."""
216+
return (c - 1.0) / jnp.sqrt(c)
212217

213218
def misfit(u: Array, _: None = None) -> Array:
214219
"""The misfit function with quadratic transformation."""
215220
c_ = forward(u)
216-
return 0.5 * jnp.sum(c_ * hvp(c_)) - jnp.sum(c_ * rhs)
221+
return 0.5 * jnp.vdot(c_, hvp(c_)) - jnp.vdot(c_, rhs)
217222

218223
def make_minimizer():
219224
"""Returns the minimizer."""

0 commit comments

Comments
 (0)