Skip to content

Commit 197e6e0

Browse files
authored
Merge pull request #47 from CovertLab/serialize_rng_state
Serialize PRNG state
2 parents c91d6fe + 8ab110d commit 197e6e0

9 files changed

Lines changed: 85 additions & 9 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ given duration. `evolve` returns a dictionary with five keys:
7070
* outcome - the final state of the system
7171

7272
```python
73-
result = system.evolve(state, duration, rates)
73+
result = system.evolve(duration, state, rates)
7474
```
7575

7676
If you are interested in the history of states for plotting or otherwise, these can be

arrow/arrow.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def flat_indexes(assorted_lists):
2323
conjunction with the flat array to recover the original list of lists.
2424
2525
Args:
26-
assorted_lists (List[List]): A list of variable length lists.
26+
assorted_lists (List[List]): A list of variable length lists.
2727
2828
Returns numpy arrays:
2929
flat: The flattened data.
@@ -78,7 +78,7 @@ class StochasticSystem(object):
7878
7979
The stoichiometric matrix has a reaction for each row, with the values in that row
8080
encoding how many of each substrate are either consumed or produced by the reaction
81-
(and zero everywhere else).
81+
(and zero everywhere else).
8282
'''
8383

8484
def __init__(self, stoichiometry, random_seed=0):
@@ -134,6 +134,7 @@ def __getstate__(self):
134134

135135
return (
136136
self.random_seed,
137+
self.obsidian.get_random_state(),
137138
self.stoichiometry,
138139
self.reactants_lengths,
139140
self.reactants_indexes,
@@ -151,7 +152,7 @@ def __setstate__(self, state):
151152
Import from pickled state.
152153
'''
153154

154-
self.random_seed, self.stoichiometry, self.reactants_lengths, self.reactants_indexes, self.reactants_flat, self.reactions_flat, self.dependencies_lengths, self.dependencies_indexes, self.dependencies_flat, self.substrates_lengths, self.substrates_indexes, self.substrates_flat = state
155+
self.random_seed, random_state, self.stoichiometry, self.reactants_lengths, self.reactants_indexes, self.reactants_flat, self.reactions_flat, self.dependencies_lengths, self.dependencies_indexes, self.dependencies_flat, self.substrates_lengths, self.substrates_indexes, self.substrates_flat = state
155156

156157
self.obsidian = Arrowhead(
157158
self.random_seed,
@@ -166,6 +167,7 @@ def __setstate__(self, state):
166167
self.substrates_lengths,
167168
self.substrates_indexes,
168169
self.substrates_flat)
170+
self.obsidian.set_random_state(*random_state)
169171

170172
def evolve(self, duration, state, rates):
171173
status, steps, time, events, outcome = self.obsidian.evolve(duration, state, rates)

arrow/arrowhead.pyx

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from __future__ import absolute_import, division, print_function
44

55
cimport cython
66
from cpython.mem cimport PyMem_Malloc, PyMem_Free
7-
from libc.stdint cimport int64_t
7+
from libc.stdint cimport int64_t, uint32_t
88
from libc.string cimport memset, memcpy
99
from libc.stdlib cimport free
1010

@@ -132,6 +132,32 @@ cdef class Arrowhead:
132132
"""Returns the number of substrates this system operates on."""
133133
return self.info.substrates_count
134134

135+
def get_random_state(self):
136+
"""Returns the state of the pseudorandom number generator."""
137+
cdef mersenne.MTState state
138+
obsidian.get_random_state(&self.info, &state)
139+
140+
mt = copy_c_array(
141+
&state.MT[0], mersenne.TWISTER_SIZE, sizeof(uint32_t),
142+
np.NPY_UINT32)
143+
mt_tempered = copy_c_array(
144+
&state.MT_TEMPERED[0], mersenne.TWISTER_SIZE, sizeof(uint32_t),
145+
np.NPY_UINT32)
146+
index = state.index
147+
148+
return mt, mt_tempered, index
149+
150+
def set_random_state(
151+
self, uint32_t[::1] mt, uint32_t[::1] mt_tempered,
152+
size_t index):
153+
cdef mersenne.MTState state
154+
memcpy(&state.MT[0], &mt[0], sizeof(state.MT))
155+
memcpy(
156+
&state.MT_TEMPERED[0], &mt_tempered[0],
157+
sizeof(state.MT_TEMPERED))
158+
state.index = index
159+
obsidian.set_random_state(&self.info, &state)
160+
135161

136162
cdef np.ndarray copy_c_array(
137163
void *source, np.npy_intp element_count, size_t element_size, int np_typenum):

arrow/mersenne.pxd

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@ from libc.stdint cimport uint32_t
55

66
cdef extern from "mersenne.h":
77

8+
size_t TWISTER_SIZE
9+
810
ctypedef struct MTState:
9-
pass
11+
# mersenne.h defines TWISTER_SIZE to be 624.
12+
uint32_t MT[624]
13+
uint32_t MT_TEMPERED[624]
14+
size_t index
15+
1016

1117
void seed(MTState *state, uint32_t seed_value)
1218

arrow/obsidian.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)
200200
choice = -1;
201201
break;
202202

203-
// Otherwise we need to find the next reaction to perform.
203+
// Otherwise we need to find the next reaction to perform.
204204
} else {
205205

206206
// First, sample two random values, `point` from a linear distribution and
@@ -310,6 +310,14 @@ evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)
310310
return result;
311311
}
312312

313+
void get_random_state(const Info *info, MTState *exported_random_state) {
314+
memcpy(exported_random_state, info->random_state, sizeof(MTState));
315+
}
316+
317+
void set_random_state(Info *info, const MTState *state) {
318+
memcpy(info->random_state, state, sizeof(MTState));
319+
}
320+
313321
// Print an array of doubles
314322
int
315323
print_array(double *array, int length) {

arrow/obsidian.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ typedef struct Info {
3838
// arrays that the caller must free().
3939
evolve_result evolve(Info *info, double duration, int64_t *state, double *rates);
4040

41+
void get_random_state(const Info *info, MTState *exported_random_state);
42+
43+
void set_random_state(Info *info, const MTState *state);
44+
4145
// Supporting print utilities
4246
int print_array(double *array, int length);
4347

arrow/obsidian.pxd

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# cython: language_level=3str
22

3-
from libc.stdint cimport int64_t
3+
from libc.stdint cimport int64_t, uint32_t
44

55
from mersenne cimport MTState
66

@@ -37,4 +37,8 @@ cdef extern from "obsidian.h":
3737

3838
evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)
3939

40+
void get_random_state(Info *info, MTState *exported_random_state)
41+
42+
void set_random_state(Info *info, MTState *state)
43+
4044
int print_array(double *array, int length)

arrow/test/test_arrow.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,32 @@ def test_flagella():
273273

274274
print('flagella result: {}'.format(result))
275275

276+
def test_get_set_random_state():
277+
stoich = np.array([[1, 1, -1, 0], [-2, 0, 0, 1], [-1, -1, 1, 0]])
278+
system = StochasticSystem(stoich)
279+
280+
state = np.array([1000, 1000, 0, 0])
281+
rates = np.array([3.0, 1.0, 1.0])
282+
283+
system.evolve(1, state, rates)
284+
285+
rand_state = system.obsidian.get_random_state()
286+
287+
result_1 = system.evolve(1, state, rates)
288+
result_2 = system.evolve(1, state, rates)
289+
290+
with np.testing.assert_raises(AssertionError):
291+
for key in ('time', 'events', 'occurrences', 'outcome'):
292+
np.testing.assert_array_equal(
293+
result_1[key], result_2[key])
294+
295+
system.obsidian.set_random_state(*rand_state)
296+
result_1_again = system.evolve(1, state, rates)
297+
298+
for key in ('time', 'events', 'occurrences', 'outcome'):
299+
np.testing.assert_array_equal(
300+
result_1[key], result_1_again[key])
301+
276302

277303
def main(args):
278304
systems = (

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
setup(
4141
name='stochastic-arrow',
42-
version='0.4.3',
42+
version='0.4.4',
4343
packages=['arrow'],
4444
author='Ryan Spangler, John Mason, Jerry Morrison',
4545
author_email='spanglry@stanford.edu',

0 commit comments

Comments
 (0)