Skip to content

Commit 2c538af

Browse files
Merge pull request #149 from jacobusmmsmit/fix-jtu-warning
jax -> jtu for relevant functions
2 parents 8f18b55 + fc785a5 commit 2c538af

6 files changed

Lines changed: 19 additions & 21 deletions

File tree

diffrax/adjoint.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Dict
33

44
import equinox as eqx
5-
import jax
65
import jax.lax as lax
76
import jax.numpy as jnp
87
import jax.tree_util as jtu
@@ -171,7 +170,7 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
171170
del throw
172171
y, args, terms = y__args__terms
173172
init_state = eqx.tree_at(
174-
lambda s: jax.tree_leaves(s.y), init_state, jax.tree_leaves(y)
173+
lambda s: jtu.tree_leaves(s.y), init_state, jtu.tree_leaves(y)
175174
)
176175
del y
177176
return self._loop_fn(
@@ -409,7 +408,7 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
409408
y = init_state.y
410409
sentinel = object()
411410
init_state = eqx.tree_at(
412-
lambda s: jax.tree_leaves(s.y), init_state, replace_fn=lambda _: sentinel
411+
lambda s: jtu.tree_leaves(s.y), init_state, replace_fn=lambda _: sentinel
413412
)
414413

415414
final_state, aux_stats = _loop_backsolve(
@@ -421,7 +420,7 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
421420
ys = final_state.ys
422421
final_state = jtu.tree_map(nondifferentiable_output, final_state)
423422
final_state = eqx.tree_at(
424-
lambda s: jax.tree_leaves(s.ys), final_state, jax.tree_leaves(ys)
423+
lambda s: jtu.tree_leaves(s.ys), final_state, jtu.tree_leaves(ys)
425424
)
426425

427426
return final_state, aux_stats

diffrax/global_interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def rectilinear_interpolation(
547547
out = jtu.tree_map(fn, replace_nans_at_start, ys)
548548
ys_treedef = jtu.tree_structure(ys)
549549
interp_treedef = jtu.tree_structure((0, 0))
550-
return jax.tree_transpose(ys_treedef, interp_treedef, out)
550+
return jtu.tree_transpose(ys_treedef, interp_treedef, out)
551551

552552

553553
def _hermite_forward(
@@ -725,4 +725,4 @@ def backward_hermite_coefficients(
725725
coeffs = jtu.tree_map(fn, ys, deriv0, replace_nans_at_start)
726726
ys_treedef = jtu.tree_structure(ys)
727727
coeffs_treedef = jtu.tree_structure((0, 0, 0, 0))
728-
return jax.tree_transpose(ys_treedef, coeffs_treedef, coeffs)
728+
return jtu.tree_transpose(ys_treedef, coeffs_treedef, coeffs)

diffrax/integrate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ def _save(state: _State, t: Scalar) -> _State:
7878
save_index = save_index + 1
7979

8080
return eqx.tree_at(
81-
lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),
81+
lambda s: [s.ts, s.save_index] + jtu.tree_leaves(s.ys),
8282
state,
83-
[ts, save_index] + jax.tree_leaves(ys),
83+
[ts, save_index] + jtu.tree_leaves(ys),
8484
)
8585

8686

diffrax/solver/milstein.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def step(
183183
leaf = leaf - Δt * eye
184184
leaves_ΔwΔw.append(leaf)
185185
tree_ΔwΔw = tree_Δw.compose(tree_Δw)
186-
ΔwΔw = jax.tree_unflatten(tree_ΔwΔw, leaves_ΔwΔw)
186+
ΔwΔw = jtu.tree_unflatten(tree_ΔwΔw, leaves_ΔwΔw)
187187
# ΔwΔw has structure (tree(Δw), tree(Δw), leaf(Δw), leaf(Δw))
188188

189189
#
@@ -250,11 +250,11 @@ def _to_treemap(_Δw, _g0):
250250
# Which we now transform into its isomorphic matrix form, as above.
251251
g0_matrix = jax.jacfwd(lambda _Δw: diffusion.prod(g0, _Δw))(Δw)
252252
# g0_matrix has structure (tree(y0), tree(Δw), leaf(y0), leaf(Δw))
253-
g0_matrix = jax.tree_transpose(y_treedef, Δw_treedef, g0_matrix)
253+
g0_matrix = jtu.tree_transpose(y_treedef, Δw_treedef, g0_matrix)
254254
# g0_matrix has structure (tree(Δw), tree(y0), leaf(y0), leaf(Δw))
255255
v0_matrix = jtu.tree_map(_to_treemap, Δw, g0_matrix)
256256
# v0_matrix has structure (tree(Δw), tree(y0), tree(Δw), leaf(y0), leaf(Δw), leaf(Δw)) # noqa: E501
257-
v0_matrix = jax.tree_transpose(
257+
v0_matrix = jtu.tree_transpose(
258258
Δw_treedef, y_treedef.compose(Δw_treedef), v0_matrix
259259
)
260260
# v0_matrix has structure (tree(y0), tree(Δw), tree(Δw), leaf(y0), leaf(Δw), leaf(Δw)) # noqa: E501
@@ -275,7 +275,7 @@ def _dot(_, _v0):
275275
# ΔwΔw has structure (tree(Δw), tree(Δw), leaf(Δw), leaf(Δw))
276276
_dotted = jtu.tree_map(__dot, _v0, ΔwΔw)
277277
# _dotted has structure (tree(Δw), tree(Δw), leaf(y0))
278-
_out = sum(jax.tree_leaves(_dotted))
278+
_out = sum(jtu.tree_leaves(_dotted))
279279
# _out has structure (leaf(y0),)
280280
return _out
281281

diffrax/term.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def is_vf_expensive(
383383
args: PyTree,
384384
) -> bool:
385385
control = self.contr(t0, t1)
386-
if sum(c.size for c in jax.tree_leaves(control)) in (0, 1):
386+
if sum(c.size for c in jtu.tree_leaves(control)) in (0, 1):
387387
return False
388388
else:
389389
return True
@@ -412,8 +412,8 @@ def vf(
412412
# PyTree structure. (This is because `self.vf_prod` is linear in `control`.)
413413
control = self.contr(t, t)
414414

415-
y_size = sum(np.size(yi) for yi in jax.tree_leaves(y))
416-
control_size = sum(np.size(ci) for ci in jax.tree_leaves(control))
415+
y_size = sum(np.size(yi) for yi in jtu.tree_leaves(y))
416+
control_size = sum(np.size(ci) for ci in jtu.tree_leaves(control))
417417
if y_size > control_size:
418418
make_jac = jax.jacfwd
419419
else:
@@ -441,7 +441,7 @@ def _fn(_control):
441441
raise NotImplementedError(
442442
"`AdjointTerm.vf` not implemented for `None` controls or states."
443443
)
444-
return jax.tree_transpose(vf_prod_tree, control_tree, jac)
444+
return jtu.tree_transpose(vf_prod_tree, control_tree, jac)
445445

446446
def contr(self, t0: Scalar, t1: Scalar) -> PyTree:
447447
return self.term.contr(t0, t1)
@@ -467,16 +467,16 @@ def _get_vf_tree(_, tree):
467467
jtu.tree_map(_get_vf_tree, control, vf)
468468
assert vf_prod_tree is not sentinel
469469

470-
vf = jax.tree_transpose(control_tree, vf_prod_tree, vf)
470+
vf = jtu.tree_transpose(control_tree, vf_prod_tree, vf)
471471

472-
example_vf_prod = jax.tree_unflatten(
472+
example_vf_prod = jtu.tree_unflatten(
473473
vf_prod_tree, [0 for _ in range(vf_prod_tree.num_leaves)]
474474
)
475475

476476
def _contract(_, vf_piece):
477477
assert jtu.tree_structure(vf_piece) == control_tree
478478
_contracted = jtu.tree_map(_prod, vf_piece, control)
479-
return sum(jax.tree_leaves(_contracted), 0)
479+
return sum(jtu.tree_leaves(_contracted), 0)
480480

481481
return jtu.tree_map(_contract, example_vf_prod, vf)
482482

test/helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import diffrax
77
import equinox as eqx
8-
import jax
98
import jax.numpy as jnp
109
import jax.random as jrandom
1110
import jax.tree_util as jtu
@@ -59,7 +58,7 @@ def random_pytree(key, treedef):
5958
dim_sizes = jrandom.randint(sizekey, (num_dims,), 0, 5)
6059
value = jrandom.normal(valuekey, dim_sizes)
6160
leaves.append(value)
62-
return jax.tree_unflatten(treedef, leaves)
61+
return jtu.tree_unflatten(treedef, leaves)
6362

6463

6564
treedefs = [

0 commit comments

Comments
 (0)