Skip to content

Commit df08087

Browse files
committed
Update JaxProcessesMixin.py
Added comments, and small naming update to the flag to locally store the state to avoid confusion with the global state.
1 parent 5647c68 commit df08087

1 file changed

Lines changed: 48 additions & 2 deletions

File tree

ngclearn/utils/JaxProcessesMixin.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,30 @@
99
from ngcsimlib._src.process.baseProcess import BaseProcess
1010

1111
class JaxCompiledMethod(CompiledMethod):
12+
"""
13+
A wrapper for a compiled method that includes jax's jit wrapped. Used
14+
exclusively by the mixin and shouldn't be used elsewhere.
15+
"""
1216
def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
1317
super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals)
1418
self._fn = jax.jit(fn)
1519
self._fn_source = fn
1620

1721
@property
1822
def source_fn(self):
23+
"""
24+
The source method not wrapped in jit
25+
"""
1926
return self._fn_source
2027

2128
@classmethod
2229
def wrap(cls, compiledMethod: CompiledMethod):
30+
"""
31+
Helper method to expand on a base compiled method
32+
Args:
33+
compiledMethod: The method to be expanded upon
34+
Returns: the JaxCompiledMethod based on the input
35+
"""
2336
return cls(compiledMethod._fn,
2437
compiledMethod.ast,
2538
compiledMethod.auxiliary_ast,
@@ -28,35 +41,68 @@ def wrap(cls, compiledMethod: CompiledMethod):
2841

2942

3043
class JaxProcessesMixin:
44+
"""
45+
A mixin for the base Process that adds JAX functionality such as scan and
46+
implicit jit wrapping
47+
"""
3148
def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
49+
"""
50+
Look at the BaseProcess class for information about other arguments
51+
Args:
52+
use_jit: a flag for if the process should implicitly jit wrap
53+
"""
3254
super().__init__(name, *args, **kwargs)
3355
self._previous_result = None
3456
self._previous_state = None
3557
self._use_jit = use_jit
3658

3759
@property
3860
def previous_result(self):
61+
"""
62+
Stores and returns the last result of scan (the second returned value)
63+
"""
3964
return self._previous_result
4065

4166
@property
4267
def previous_state(self):
68+
"""
69+
Stores and returns the last returned state of scan (the first returned
70+
value)
71+
"""
4372
return self._previous_state
4473

4574
def clear(self):
75+
"""
76+
Clears out the previous result and state from scan
77+
"""
4678
self._previous_result = None
4779
self._previous_state = None
4880

4981

50-
def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True):
82+
def scan(self: "BaseProcess", inputs, current_state=None, store_state: bool = True, store_results: bool = True):
83+
"""
84+
Runs the process through jax's scan method
85+
Args:
86+
inputs: The inputs for scan (use pack rows to generate), must be a jax array
87+
current_state: Optional, the current state of the model, if none uses current global state
88+
store_state: Optional flag, should the final state be stored in the process
89+
store_results: Optional flag, should the final result be stored in the process
90+
91+
Returns: the final state, the final result
92+
93+
"""
5194
state = current_state or stateManager.state
5295
final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
53-
if save_state:
96+
if store_state:
5497
self._previous_state = final_state
5598
if store_results:
5699
self._previous_result = result
57100
return final_state, result
58101

59102
def compile(self: "baseProcess"):
103+
"""
104+
For use by the compiler
105+
"""
60106
super().compile()
61107
if self._use_jit:
62108
self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled)

0 commit comments

Comments
 (0)