Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion bilby/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def default(self, obj):
return {"__series__": True, "content": obj.to_dict()}
if isinstance(obj, np.random.Generator):
return encode_numpy_random_generator(obj)
if isinstance(obj, np.random.SeedSequence):
return encode_numpy_seed_sequence(obj)
if inspect.isfunction(obj):
return {
"__function__": True,
Expand Down Expand Up @@ -162,11 +164,29 @@ def encode_numpy_random_generator(generator):
}


def encode_numpy_seed_sequence(seed_sequence):
"""Encode a numpy SeedSequence to a dictionary.

Adds the key :code:`__numpy_seed_sequence__` to the dictionary to indicate
that the object is a numpy SeedSequence.

The :code:`state` key contains the state of the seed sequence.

.. versionadded:: 3.0.0
"""
state = dict(seed_sequence.state)
state["spawn_key"] = list(state["spawn_key"])
return {
"__numpy_seed_sequence__": True,
"state": state,
}


def decode_astropy_cosmology(dct):
"""Decode an astropy cosmology from a dictionary.

The dictionary should have been encoded using
:py:func:`~bibly.core.utils.io.encode_astropy_cosmology` and should have the
:py:func:`~bilby.core.utils.io.encode_astropy_cosmology` and should have the
key :code:`__cosmology__`.

.. versionchange:: 2.5.0
Expand Down Expand Up @@ -270,6 +290,16 @@ def decode_numpy_random_generator(dct):
return generator


def decode_numpy_seed_sequence(dct):
"""Decode a numpy SeedSequence from a dictionary.

.. versionadded:: 3.0.0
"""
state = dict(dct["state"])
state["spawn_key"] = tuple(int(idx) for idx in state["spawn_key"])
return np.random.SeedSequence(**state)


def load_json(filename, gzip):
if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz":
import gzip
Expand Down Expand Up @@ -313,6 +343,8 @@ def decode_bilby_json(dct):
return obj
if dct.get("__numpy_random_generator__", False):
return decode_numpy_random_generator(dct)
if dct.get("__numpy_seed_sequence__", False):
return decode_numpy_seed_sequence(dct)
if dct.get("__cosmology__", False):
return decode_astropy_cosmology(dct)
if dct.get("__astropy_quantity__", False):
Expand Down Expand Up @@ -442,6 +474,8 @@ def encode_for_hdf5(key, item):
output = item
elif isinstance(item, np.random.Generator):
output = encode_numpy_random_generator(item)
elif isinstance(item, np.random.SeedSequence):
output = encode_numpy_seed_sequence(item)
elif item is None:
output = "__none__"
elif isinstance(item, list):
Expand Down Expand Up @@ -508,6 +542,8 @@ def decode_hdf5_dict(output):
output = decode_astropy_unit(output)
elif "__numpy_random_generator__" in output:
output = decode_numpy_random_generator(output)
elif "__numpy_seed_sequence__" in output:
output = decode_numpy_seed_sequence(output)
return output


Expand Down
32 changes: 32 additions & 0 deletions test/core/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,38 @@ def test_json(self):
b = data["rng"].random()
self.assertEqual(a, b)

def test_hdf5_seed_sequence(self):
seed_sequence = np.random.SeedSequence(1234).spawn(1)[0]
data = {"seed_sequence": seed_sequence}

with h5py.File(self.outdir / "test_seed_sequence.h5", "w") as f:
bilby.core.utils.recursively_save_dict_contents_to_group(
f, "/", data
)

with h5py.File(self.outdir / "test_seed_sequence.h5", "r") as f:
loaded = bilby.core.utils.recursively_load_dict_contents_from_group(
f, "/"
)

self.assertIsInstance(loaded["seed_sequence"], np.random.SeedSequence)
self.assertEqual(loaded["seed_sequence"].state, seed_sequence.state)

def test_json_seed_sequence(self):
seed_sequence = np.random.SeedSequence(1234).spawn(1)[0]
data = {"seed_sequence": seed_sequence}

with open(self.outdir / "test_seed_sequence.json", "w") as file:
json.dump(data, file, indent=2, cls=bilby.core.utils.BilbyJsonEncoder)

with open(self.outdir / "test_seed_sequence.json", "r") as file:
loaded = json.load(
file, object_hook=bilby.core.utils.decode_bilby_json
)

self.assertIsInstance(loaded["seed_sequence"], np.random.SeedSequence)
self.assertEqual(loaded["seed_sequence"].state, seed_sequence.state)

def test_pickle(self):
with open(self.outdir / "test.pkl", 'wb') as file:
dill.dump(self.data, file)
Expand Down
Loading