1010import lineax .internal as lxi
1111from jaxtyping import Array , PRNGKeyArray , PyTree
1212
13- from .._custom_types import levy_tree_transpose , LevyArea , LevyVal , RealScalarLike
13+ from .._custom_types import (
14+ AbstractBrownianIncrement ,
15+ BrownianIncrement ,
16+ levy_tree_transpose ,
17+ RealScalarLike ,
18+ SpaceTimeLevyArea ,
19+ )
1420from .._misc import (
1521 force_bitcast_convert_type ,
1622 is_tuple_of_ints ,
@@ -42,25 +48,23 @@ class UnsafeBrownianPath(AbstractBrownianPath):
4248 """
4349
4450 shape : PyTree [jax .ShapeDtypeStruct ] = eqx .field (static = True )
45- levy_area : LevyArea = eqx .field (static = True )
51+ levy_area : type [Union [BrownianIncrement , SpaceTimeLevyArea ]] = eqx .field (
52+ static = True
53+ )
4654 key : PRNGKeyArray
4755
4856 def __init__ (
4957 self ,
5058 shape : Union [tuple [int , ...], PyTree [jax .ShapeDtypeStruct ]],
5159 key : PRNGKeyArray ,
52- levy_area : LevyArea = "" ,
60+ levy_area : type [ Union [ BrownianIncrement , SpaceTimeLevyArea ]] ,
5361 ):
5462 self .shape = (
5563 jax .ShapeDtypeStruct (shape , lxi .default_floating_dtype ())
5664 if is_tuple_of_ints (shape )
5765 else shape
5866 )
5967 self .key = key
60- if levy_area not in ["" , "space-time" ]:
61- raise ValueError (
62- f"levy_area must be one of '', 'space-time', but got { levy_area } ."
63- )
6468 self .levy_area = levy_area
6569
6670 if any (
@@ -84,7 +88,7 @@ def evaluate(
8488 t1 : Optional [RealScalarLike ] = None ,
8589 left : bool = True ,
8690 use_levy : bool = False ,
87- ) -> Union [PyTree [Array ], LevyVal ]:
91+ ) -> Union [PyTree [Array ], AbstractBrownianIncrement ]:
8892 del left
8993 if t1 is None :
9094 dtype = jnp .result_type (t0 )
@@ -111,8 +115,8 @@ def evaluate(
111115 self .shape ,
112116 )
113117 if use_levy :
114- out = levy_tree_transpose (self .shape , self . levy_area , out )
115- assert isinstance (out , LevyVal )
118+ out = levy_tree_transpose (self .shape , out )
119+ assert isinstance (out , ( BrownianIncrement , SpaceTimeLevyArea ) )
116120 return out
117121
118122 @staticmethod
@@ -121,25 +125,26 @@ def _evaluate_leaf(
121125 t1 : RealScalarLike ,
122126 key ,
123127 shape : jax .ShapeDtypeStruct ,
124- levy_area : str ,
128+ levy_area : type [ Union [ BrownianIncrement , SpaceTimeLevyArea ]] ,
125129 use_levy : bool ,
126130 ):
127131 w_std = jnp .sqrt (t1 - t0 ).astype (shape .dtype )
132+ w = jr .normal (key , shape .shape , shape .dtype ) * w_std
133+ dt = jnp .asarray (t1 - t0 , dtype = shape .dtype )
128134
129- if levy_area == "space-time" :
135+ if levy_area is SpaceTimeLevyArea :
130136 key , key_hh = jr .split (key , 2 )
131137 hh_std = w_std / math .sqrt (12 )
132138 hh = jr .normal (key_hh , shape .shape , shape .dtype ) * hh_std
133- elif levy_area == "" :
134- hh = None
139+ levy_val = SpaceTimeLevyArea (dt = dt , W = w , H = hh )
140+ elif levy_area is BrownianIncrement :
141+ levy_val = BrownianIncrement (dt = dt , W = w )
135142 else :
136143 assert False
137- w = jr .normal (key , shape .shape , shape .dtype ) * w_std
138144
139145 if use_levy :
140- return LevyVal (dt = t1 - t0 , W = w , H = hh , bar_H = None , K = None , bar_K = None )
141- else :
142- return w
146+ return levy_val
147+ return w
143148
144149
145150UnsafeBrownianPath .__init__ .__doc__ = """
0 commit comments