Skip to content

Commit 4c38490

Browse files
authored
Update to and adjustment to (with comments added) JaxProcessesMixin.py (#154)
Merging in JaxProcessMixin adjustment where comments were added as well as small naming update to the flag (which locally stores state to avoid confusion with the global state). This is a clean-up commit with minor adjustment.
2 parents 5647c68 + df08087 commit 4c38490

1 file changed

Lines changed: 48 additions & 2 deletions

File tree

ngclearn/utils/JaxProcessesMixin.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,30 @@
99
from ngcsimlib._src.process.baseProcess import BaseProcess
1010

1111
class JaxCompiledMethod(CompiledMethod):
12+
"""
13+
A wrapper for a compiled method that includes jax's jit wrapped. Used
14+
exclusively by the mixin and shouldn't be used elsewhere.
15+
"""
1216
def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
1317
super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals)
1418
self._fn = jax.jit(fn)
1519
self._fn_source = fn
1620

1721
@property
1822
def source_fn(self):
23+
"""
24+
The source method not wrapped in jit
25+
"""
1926
return self._fn_source
2027

2128
@classmethod
2229
def wrap(cls, compiledMethod: CompiledMethod):
30+
"""
31+
Helper method to expand on a base compiled method
32+
Args:
33+
compiledMethod: The method to be expanded upon
34+
Returns: the JaxCompiledMethod based on the input
35+
"""
2336
return cls(compiledMethod._fn,
2437
compiledMethod.ast,
2538
compiledMethod.auxiliary_ast,
@@ -28,35 +41,68 @@ def wrap(cls, compiledMethod: CompiledMethod):
2841

2942

3043
class JaxProcessesMixin:
44+
"""
45+
A mixin for the base Process that adds JAX functionality such as scan and
46+
implicit jit wrapping
47+
"""
3148
def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
49+
"""
50+
Look at the BaseProcess class for information about other arguments
51+
Args:
52+
use_jit: a flag for if the process should implicitly jit wrap
53+
"""
3254
super().__init__(name, *args, **kwargs)
3355
self._previous_result = None
3456
self._previous_state = None
3557
self._use_jit = use_jit
3658

3759
@property
3860
def previous_result(self):
61+
"""
62+
Stores and returns the last result of scan (the second returned value)
63+
"""
3964
return self._previous_result
4065

4166
@property
4267
def previous_state(self):
68+
"""
69+
Stores and returns the last returned state of scan (the first returned
70+
value)
71+
"""
4372
return self._previous_state
4473

4574
def clear(self):
75+
"""
76+
Clears out the previous result and state from scan
77+
"""
4678
self._previous_result = None
4779
self._previous_state = None
4880

4981

50-
def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True):
82+
def scan(self: "BaseProcess", inputs, current_state=None, store_state: bool = True, store_results: bool = True):
83+
"""
84+
Runs the process through jax's scan method
85+
Args:
86+
inputs: The inputs for scan (use pack rows to generate), must be a jax array
87+
current_state: Optional, the current state of the model, if none uses current global state
88+
store_state: Optional flag, should the final state be stored in the process
89+
store_results: Optional flag, should the final result be stored in the process
90+
91+
Returns: the final state, the final result
92+
93+
"""
5194
state = current_state or stateManager.state
5295
final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
53-
if save_state:
96+
if store_state:
5497
self._previous_state = final_state
5598
if store_results:
5699
self._previous_result = result
57100
return final_state, result
58101

59102
def compile(self: "baseProcess"):
103+
"""
104+
For use by the compiler
105+
"""
60106
super().compile()
61107
if self._use_jit:
62108
self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled)

0 commit comments

Comments
 (0)