|
| 1 | +""" |
| 2 | +v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is |
| 3 | +additionally capable of computing Levy area. |
| 4 | +
|
| 5 | +Here we check the speed of the new implementation against the old implementation, to be |
| 6 | +sure that it is still fast. |
| 7 | +""" |
| 8 | + |
| 9 | +import timeit |
| 10 | +from typing import cast, Optional, Union |
| 11 | +from typing_extensions import TypeAlias |
| 12 | + |
| 13 | +import equinox as eqx |
| 14 | +import equinox.internal as eqxi |
| 15 | +import jax |
| 16 | +import jax.lax as lax |
| 17 | +import jax.numpy as jnp |
| 18 | +import jax.random as jr |
| 19 | +import jax.tree_util as jtu |
| 20 | +import lineax.internal as lxi |
| 21 | +import numpy as np |
| 22 | +from diffrax import AbstractBrownianPath, VirtualBrownianTree |
| 23 | +from jaxtyping import Array, Float, PRNGKeyArray, PyTree, Real |
| 24 | + |
| 25 | + |
| 26 | +RealScalarLike: TypeAlias = Real[Union[int, float, Array, np.ndarray], ""] |
| 27 | + |
| 28 | + |
| 29 | +class _State(eqx.Module): |
| 30 | + s: RealScalarLike |
| 31 | + t: RealScalarLike |
| 32 | + u: RealScalarLike |
| 33 | + w_s: Float[Array, " *shape"] |
| 34 | + w_t: Float[Array, " *shape"] |
| 35 | + w_u: Float[Array, " *shape"] |
| 36 | + key: PRNGKeyArray |
| 37 | + |
| 38 | + |
| 39 | +class OldVBT(AbstractBrownianPath): |
| 40 | + t0: RealScalarLike |
| 41 | + t1: RealScalarLike |
| 42 | + tol: RealScalarLike |
| 43 | + shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) |
| 44 | + key: PRNGKeyArray |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + t0: RealScalarLike, |
| 49 | + t1: RealScalarLike, |
| 50 | + tol: RealScalarLike, |
| 51 | + shape: tuple[int, ...], |
| 52 | + key: PRNGKeyArray, |
| 53 | + levy_area: str, |
| 54 | + ): |
| 55 | + assert levy_area == "" |
| 56 | + self.t0 = t0 |
| 57 | + self.t1 = t1 |
| 58 | + self.tol = tol |
| 59 | + self.shape = jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) |
| 60 | + self.key = key |
| 61 | + |
| 62 | + @property |
| 63 | + def levy_area(self): |
| 64 | + assert False |
| 65 | + |
| 66 | + @eqx.filter_jit |
| 67 | + def evaluate( |
| 68 | + self, |
| 69 | + t0: RealScalarLike, |
| 70 | + t1: Optional[RealScalarLike] = None, |
| 71 | + left: bool = True, |
| 72 | + use_levy: bool = False, |
| 73 | + ) -> PyTree[Array]: |
| 74 | + del left, use_levy |
| 75 | + t0 = eqxi.nondifferentiable(t0, name="t0") |
| 76 | + if t1 is None: |
| 77 | + return self._evaluate(t0) |
| 78 | + else: |
| 79 | + t1 = cast(RealScalarLike, eqxi.nondifferentiable(t1, name="t1")) |
| 80 | + return jtu.tree_map( |
| 81 | + lambda x, y: x - y, |
| 82 | + self._evaluate(t1), |
| 83 | + self._evaluate(t0), |
| 84 | + ) |
| 85 | + |
| 86 | + def _evaluate(self, τ: RealScalarLike) -> PyTree[Array]: |
| 87 | + map_func = lambda key, struct: self._evaluate_leaf(key, τ, struct) |
| 88 | + return jtu.tree_map(map_func, self.key, self.shape) |
| 89 | + |
| 90 | + def _brownian_bridge(self, s, t, u, w_s, w_u, key, shape, dtype): |
| 91 | + mean = w_s + (w_u - w_s) * ((t - s) / (u - s)) |
| 92 | + var = (u - t) * (t - s) / (u - s) |
| 93 | + std = jnp.sqrt(var) |
| 94 | + return mean + std * jr.normal(key, shape, dtype) |
| 95 | + |
| 96 | + def _evaluate_leaf( |
| 97 | + self, |
| 98 | + key, |
| 99 | + τ: RealScalarLike, |
| 100 | + struct: jax.ShapeDtypeStruct, |
| 101 | + ) -> Array: |
| 102 | + shape, dtype = struct.shape, struct.dtype |
| 103 | + |
| 104 | + cond = self.t0 < self.t1 |
| 105 | + t0 = jnp.where(cond, self.t0, self.t1).astype(dtype) |
| 106 | + t1 = jnp.where(cond, self.t1, self.t0).astype(dtype) |
| 107 | + |
| 108 | + t0 = eqxi.error_if( |
| 109 | + t0, |
| 110 | + τ < t0, |
| 111 | + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", |
| 112 | + ) |
| 113 | + t1 = eqxi.error_if( |
| 114 | + t1, |
| 115 | + τ > t1, |
| 116 | + "Cannot evaluate VirtualBrownianTree outside of its range [t0, t1].", |
| 117 | + ) |
| 118 | + τ = jnp.clip(τ, t0, t1).astype(dtype) |
| 119 | + |
| 120 | + key, init_key = jr.split(key, 2) |
| 121 | + thalf = t0 + 0.5 * (t1 - t0) |
| 122 | + w_t1 = jr.normal(init_key, shape, dtype) * jnp.sqrt(t1 - t0) |
| 123 | + w_thalf = self._brownian_bridge(t0, thalf, t1, 0, w_t1, key, shape, dtype) |
| 124 | + init_state = _State( |
| 125 | + s=t0, |
| 126 | + t=thalf, |
| 127 | + u=t1, |
| 128 | + w_s=jnp.zeros_like(w_t1), |
| 129 | + w_t=w_thalf, |
| 130 | + w_u=w_t1, |
| 131 | + key=key, |
| 132 | + ) |
| 133 | + |
| 134 | + def _cond_fun(_state): |
| 135 | + return (_state.u - _state.s) > self.tol |
| 136 | + |
| 137 | + def _body_fun(_state): |
| 138 | + _key1, _key2 = jr.split(_state.key, 2) |
| 139 | + _cond = τ > _state.t |
| 140 | + _s = jnp.where(_cond, _state.t, _state.s) |
| 141 | + _u = jnp.where(_cond, _state.u, _state.t) |
| 142 | + _w_s = jnp.where(_cond, _state.w_t, _state.w_s) |
| 143 | + _w_u = jnp.where(_cond, _state.w_u, _state.w_t) |
| 144 | + _key = jnp.where(_cond, _key1, _key2) |
| 145 | + _t = _s + 0.5 * (_u - _s) |
| 146 | + _w_t = self._brownian_bridge(_s, _t, _u, _w_s, _w_u, _key, shape, dtype) |
| 147 | + return _State(s=_s, t=_t, u=_u, w_s=_w_s, w_t=_w_t, w_u=_w_u, key=_key) |
| 148 | + |
| 149 | + final_state = lax.while_loop(_cond_fun, _body_fun, init_state) |
| 150 | + |
| 151 | + s = final_state.s |
| 152 | + u = final_state.u |
| 153 | + w_s = final_state.w_s |
| 154 | + w_t = final_state.w_t |
| 155 | + w_u = final_state.w_u |
| 156 | + rescaled_τ = (τ - s) / (u - s) |
| 157 | + |
| 158 | + A = jnp.array([[2, -4, 2], [-3, 4, -1], [1, 0, 0]]) |
| 159 | + coeffs = jnp.tensordot(A, jnp.stack([w_s, w_t, w_u]), axes=1) |
| 160 | + return jnp.polyval(coeffs, rescaled_τ) |
| 161 | + |
| 162 | + |
| 163 | +key = jr.PRNGKey(0) |
| 164 | +t0, t1 = 0.3, 20.3 |
| 165 | + |
| 166 | + |
| 167 | +def time_tree(tree_cls, num_ts, tol, levy_area): |
| 168 | + tree = tree_cls(t0=t0, t1=t1, tol=tol, shape=(10,), key=key, levy_area=levy_area) |
| 169 | + |
| 170 | + if num_ts == 1: |
| 171 | + ts = 11.2 |
| 172 | + |
| 173 | + @jax.jit |
| 174 | + @eqx.debug.assert_max_traces(max_traces=1) |
| 175 | + def run(_ts): |
| 176 | + return tree.evaluate(_ts, use_levy=True) |
| 177 | + else: |
| 178 | + ts = jnp.linspace(t0, t1, num_ts) |
| 179 | + |
| 180 | + @jax.jit |
| 181 | + @eqx.debug.assert_max_traces(max_traces=1) |
| 182 | + def run(_ts): |
| 183 | + return jax.vmap(lambda _t: tree.evaluate(_t, use_levy=True))(_ts) |
| 184 | + |
| 185 | + return min( |
| 186 | + timeit.repeat(lambda: jax.block_until_ready(run(ts)), number=1, repeat=100) |
| 187 | + ) |
| 188 | + |
| 189 | + |
| 190 | +for levy_area in ("", "space-time"): |
| 191 | + print(f"- {levy_area=}") |
| 192 | + for tol in (2**-3, 2**-12): |
| 193 | + print(f"-- {tol=}") |
| 194 | + for num_ts in (1, 100): |
| 195 | + print(f"--- {num_ts=}") |
| 196 | + if levy_area == "": |
| 197 | + print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}") |
| 198 | + print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}") |
| 199 | + print("") |
0 commit comments