99 from ngcsimlib ._src .process .baseProcess import BaseProcess
1010
1111class JaxCompiledMethod (CompiledMethod ):
12+ """
13+ A wrapper for a compiled method that includes jax's jit wrapped. Used
14+ exclusively by the mixin and shouldn't be used elsewhere.
15+ """
1216 def __init__ (self , fn , fn_ast , auxiliary_ast , namespace , extra_globals ):
1317 super ().__init__ (fn , fn_ast , auxiliary_ast , namespace , extra_globals )
1418 self ._fn = jax .jit (fn )
1519 self ._fn_source = fn
1620
1721 @property
1822 def source_fn (self ):
23+ """
24+ The source method not wrapped in jit
25+ """
1926 return self ._fn_source
2027
2128 @classmethod
2229 def wrap (cls , compiledMethod : CompiledMethod ):
30+ """
31+ Helper method to expand on a base compiled method
32+ Args:
33+ compiledMethod: The method to be expanded upon
34+ Returns: the JaxCompiledMethod based on the input
35+ """
2336 return cls (compiledMethod ._fn ,
2437 compiledMethod .ast ,
2538 compiledMethod .auxiliary_ast ,
@@ -28,35 +41,68 @@ def wrap(cls, compiledMethod: CompiledMethod):
2841
2942
3043class JaxProcessesMixin :
44+ """
45+ A mixin for the base Process that adds JAX functionality such as scan and
46+ implicit jit wrapping
47+ """
3148 def __init__ (self : "BaseProcess" , name , * args , use_jit = True , ** kwargs ):
49+ """
50+ Look at the BaseProcess class for information about other arguments
51+ Args:
52+ use_jit: a flag for if the process should implicitly jit wrap
53+ """
3254 super ().__init__ (name , * args , ** kwargs )
3355 self ._previous_result = None
3456 self ._previous_state = None
3557 self ._use_jit = use_jit
3658
3759 @property
3860 def previous_result (self ):
61+ """
62+ Stores and returns the last result of scan (the second returned value)
63+ """
3964 return self ._previous_result
4065
4166 @property
4267 def previous_state (self ):
68+ """
69+ Stores and returns the last returned state of scan (the first returned
70+ value)
71+ """
4372 return self ._previous_state
4473
4574 def clear (self ):
75+ """
76+ Clears out the previous result and state from scan
77+ """
4678 self ._previous_result = None
4779 self ._previous_state = None
4880
4981
50- def scan (self : "BaseProcess" , inputs , current_state = None , save_state : bool = True , store_results : bool = True ):
82+ def scan (self : "BaseProcess" , inputs , current_state = None , store_state : bool = True , store_results : bool = True ):
83+ """
84+ Runs the process through jax's scan method
85+ Args:
86+ inputs: The inputs for scan (use pack rows to generate), must be a jax array
87+ current_state: Optional, the current state of the model, if none uses current global state
88+ store_state: Optional flag, should the final state be stored in the process
89+ store_results: Optional flag, should the final result be stored in the process
90+
91+ Returns: the final state, the final result
92+
93+ """
5194 state = current_state or stateManager .state
5295 final_state , result = jax .lax .scan (self .run .compiled , state , inputs )
53- if save_state :
96+ if store_state :
5497 self ._previous_state = final_state
5598 if store_results :
5699 self ._previous_result = result
57100 return final_state , result
58101
59102 def compile (self : "baseProcess" ):
103+ """
104+ For use by the compiler
105+ """
60106 super ().compile ()
61107 if self ._use_jit :
62108 self .run .compiled = JaxCompiledMethod .wrap (self .run .compiled )
0 commit comments