Skip to content

Commit 94bf371

Browse files
committed
Fix pickling support
__reduce__ methods were not added in subclasses, only in superclass (_MKLRandomState), so the proper constructors would never be called, thus pickling would produce different objects
1 parent 0b19fe4 commit 94bf371

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

mkl_random/interfaces/_numpy_random.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, seed=None):
6060

6161
def seed(self, seed=None):
6262
"""
63-
seed(seed=Nonee)
63+
seed(seed=None)
6464
6565
Seed the generator.
6666
@@ -104,6 +104,11 @@ def set_state(self, state):
104104
"""
105105
return super().set_state(state=state)
106106

107+
# pickling support
108+
def __reduce__(self):
109+
global __NPRandomState_ctor
110+
return (__NPRandomState_ctor, (), self.get_state())
111+
107112
def random_sample(self, size=None):
108113
"""
109114
random_sample(size=None)
@@ -223,7 +228,7 @@ def beta(self, a, b, size=None):
223228
For full documentation refer to `numpy.random.beta`.
224229
225230
"""
226-
return super().beta(a=a, b=b, size=size,)
231+
return super().beta(a=a, b=b, size=size)
227232

228233
def exponential(self, scale=1.0, size=None):
229234
"""
@@ -591,6 +596,19 @@ def permutation(self, x):
591596
return super().permutation(x=x)
592597

593598

599+
def __NPRandomState_ctor():
600+
"""
601+
Return a RandomState instance.
602+
This function exists solely to assist (un)pickling.
603+
Note that the state of the RandomState returned here is irrelevant, as this function's
604+
entire purpose is to return a newly allocated RandomState whose state pickle can set.
605+
Consequently the RandomState returned by this function is a freshly allocated copy
606+
with a seed=0.
607+
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
608+
"""
609+
return RandomState(seed=0)
610+
611+
594612
# instantiate a default RandomState object to be used by module-level functions
595613
_rand = RandomState()
596614
# define module-level functions using methods of a default RandomState object

mkl_random/mklrand.pyx

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,10 +1261,6 @@ cdef class _MKLRandomState:
12611261
def __setstate__(self, state):
12621262
self.set_state(state)
12631263

1264-
def __reduce__(self):
1265-
global __RandomState_ctor
1266-
return (__RandomState_ctor, (), self.get_state())
1267-
12681264
# Basic distributions:
12691265
def random_sample(self, size=None):
12701266
"""
@@ -5605,6 +5601,11 @@ cdef class MKLRandomState(_MKLRandomState):
56055601
56065602
"""
56075603

5604+
# pickling support
5605+
def __reduce__(self):
5606+
global __MKLRandomState_ctor
5607+
return (__MKLRandomState_ctor, (), self.get_state())
5608+
56085609
def leapfrog(self, int k, int nstreams):
56095610
"""
56105611
leapfrog(k, nstreams)
@@ -5966,6 +5967,11 @@ class RandomState(MKLRandomState):
59665967
)
59675968
super().__init__(seed=seed, brng=brng)
59685969

5970+
# pickling support
5971+
def __reduce__(self):
5972+
global __RandomState_ctor
5973+
return (__RandomState_ctor, (), self.get_state())
5974+
59695975

59705976
def __MKLRandomState_ctor():
59715977
"""

0 commit comments

Comments
 (0)