11import math
2- from typing import Literal , Optional , TypeVar , Union
3- from typing_extensions import TypeAlias
2+ from typing import Literal , TypeAlias , TypeVar
43
54import equinox as eqx
65import equinox .internal as eqxi
7069class _LevyVal (eqx .Module ):
7170 dt : RealScalarLike
7271 W : Array
73- H : Optional [ Array ]
74- bar_H : Optional [ Array ]
75- K : Optional [ Array ]
76- bar_K : Optional [ Array ]
72+ H : Array | None
73+ bar_H : Array | None
74+ K : Array | None
75+ bar_K : Array | None
7776
7877 def __check_init__ (self ):
7978 if self .H is None :
@@ -88,8 +87,8 @@ class _State(eqx.Module):
8887 s : RealScalarLike # starting time of the interval
8988 w_s_u_su : FloatTriple # W_s, W_u, W_{s,u}
9089 key : PRNGKeyArray
91- bhh_s_u_su : Optional [ FloatTriple ] # \bar{H}_s, _u, _{s,u}
92- bkk_s_u_su : Optional [ FloatTriple ] # \bar{K}_s, _u, _{s,u}
90+ bhh_s_u_su : FloatTriple | None # \bar{H}_s, _u, _{s,u}
91+ bkk_s_u_su : FloatTriple | None # \bar{K}_s, _u, _{s,u}
9392
9493
9594def _levy_diff (_ , x0 : _LevyVal , x1 : _LevyVal ) -> AbstractBrownianIncrement :
@@ -161,8 +160,8 @@ def _make_levy_val(_, x: _LevyVal) -> AbstractBrownianIncrement:
161160
162161
163162def _split_interval (
164- pred : BoolScalarLike , x_stu : Optional [ FloatTriple ] , x_st_tu : Optional [ FloatDouble ]
165- ) -> Optional [ FloatTriple ] :
163+ pred : BoolScalarLike , x_stu : FloatTriple | None , x_st_tu : FloatDouble | None
164+ ) -> FloatTriple | None :
166165 if x_stu is None :
167166 assert x_st_tu is None
168167 return None
@@ -237,7 +236,7 @@ class VirtualBrownianTree(AbstractBrownianPath):
237236 tol : RealScalarLike
238237 shape : PyTree [jax .ShapeDtypeStruct ] = eqx .field (static = True )
239238 levy_area : type [
240- Union [ BrownianIncrement , SpaceTimeLevyArea , SpaceTimeTimeLevyArea ]
239+ BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
241240 ] = eqx .field (static = True )
242241 key : PyTree [PRNGKeyArray ]
243242 _spline : _Spline = eqx .field (static = True )
@@ -248,10 +247,10 @@ def __init__(
248247 t0 : RealScalarLike ,
249248 t1 : RealScalarLike ,
250249 tol : RealScalarLike ,
251- shape : Union [ tuple [int , ...], PyTree [jax .ShapeDtypeStruct ] ],
250+ shape : tuple [int , ...] | PyTree [jax .ShapeDtypeStruct ],
252251 key : PRNGKeyArray ,
253252 levy_area : type [
254- Union [ BrownianIncrement , SpaceTimeLevyArea , SpaceTimeTimeLevyArea ]
253+ BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
255254 ] = BrownianIncrement ,
256255 _spline : _Spline = "sqrt" ,
257256 ):
@@ -327,10 +326,10 @@ def is_dt(z):
327326 def evaluate (
328327 self ,
329328 t0 : RealScalarLike ,
330- t1 : Optional [ RealScalarLike ] = None ,
329+ t1 : RealScalarLike | None = None ,
331330 left : bool = True ,
332331 use_levy : bool = False ,
333- ) -> Union [ PyTree [Array ], AbstractBrownianIncrement ] :
332+ ) -> PyTree [Array ] | AbstractBrownianIncrement :
334333 """Implements [`diffrax.AbstractBrownianPath.evaluate`][]."""
335334 del left
336335 t0 = eqxi .nondifferentiable (t0 , name = "t0" )
@@ -532,9 +531,9 @@ def _body_fun(_state: _State):
532531 f" 'zero' splines are permitted, got { self ._spline } ."
533532 )
534533
535- hat_w_sr , hat_hh_sr , hat_kk_sr = [
534+ hat_w_sr , hat_hh_sr , hat_kk_sr = (
536535 x .squeeze (axis = - 1 ) for x in jnp .split (hat_y , 3 , axis = - 1 )
537- ]
536+ )
538537 assert hat_w_sr .shape == hat_hh_sr .shape == hat_kk_sr .shape == shape
539538
540539 w_sr = w_mean + hat_w_sr
@@ -631,10 +630,10 @@ def _brownian_arch(
631630 FloatTriple ,
632631 FloatDouble ,
633632 tuple [PRNGKeyArray , PRNGKeyArray ],
634- Optional [ FloatTriple ] ,
635- Optional [ FloatDouble ] ,
636- Optional [ FloatTriple ] ,
637- Optional [ FloatDouble ] ,
633+ FloatTriple | None ,
634+ FloatDouble | None ,
635+ FloatTriple | None ,
636+ FloatDouble | None ,
638637 ]:
639638 r"""For `t = (s+u)/2` evaluates `w_t` and (optionally) `bhh_t`
640639 conditioned on `w_s`, `w_u`, `bhh_s`, `bhh_u`, where
0 commit comments