@@ -383,7 +383,7 @@ def is_vf_expensive(
383383 args : PyTree ,
384384 ) -> bool :
385385 control = self .contr (t0 , t1 )
386- if sum (c .size for c in jax .tree_leaves (control )) in (0 , 1 ):
386+ if sum (c .size for c in jtu .tree_leaves (control )) in (0 , 1 ):
387387 return False
388388 else :
389389 return True
@@ -412,8 +412,8 @@ def vf(
412412 # PyTree structure. (This is because `self.vf_prod` is linear in `control`.)
413413 control = self .contr (t , t )
414414
415- y_size = sum (np .size (yi ) for yi in jax .tree_leaves (y ))
416- control_size = sum (np .size (ci ) for ci in jax .tree_leaves (control ))
415+ y_size = sum (np .size (yi ) for yi in jtu .tree_leaves (y ))
416+ control_size = sum (np .size (ci ) for ci in jtu .tree_leaves (control ))
417417 if y_size > control_size :
418418 make_jac = jax .jacfwd
419419 else :
@@ -441,7 +441,7 @@ def _fn(_control):
441441 raise NotImplementedError (
442442 "`AdjointTerm.vf` not implemented for `None` controls or states."
443443 )
444- return jax .tree_transpose (vf_prod_tree , control_tree , jac )
444+ return jtu .tree_transpose (vf_prod_tree , control_tree , jac )
445445
446446 def contr (self , t0 : Scalar , t1 : Scalar ) -> PyTree :
447447 return self .term .contr (t0 , t1 )
@@ -467,16 +467,16 @@ def _get_vf_tree(_, tree):
467467 jtu .tree_map (_get_vf_tree , control , vf )
468468 assert vf_prod_tree is not sentinel
469469
470- vf = jax .tree_transpose (control_tree , vf_prod_tree , vf )
470+ vf = jtu .tree_transpose (control_tree , vf_prod_tree , vf )
471471
472- example_vf_prod = jax .tree_unflatten (
472+ example_vf_prod = jtu .tree_unflatten (
473473 vf_prod_tree , [0 for _ in range (vf_prod_tree .num_leaves )]
474474 )
475475
476476 def _contract (_ , vf_piece ):
477477 assert jtu .tree_structure (vf_piece ) == control_tree
478478 _contracted = jtu .tree_map (_prod , vf_piece , control )
479- return sum (jax .tree_leaves (_contracted ), 0 )
479+ return sum (jtu .tree_leaves (_contracted ), 0 )
480480
481481 return jtu .tree_map (_contract , example_vf_prod , vf )
482482
0 commit comments