diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 8299d6816..52184b3b9 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -69,6 +69,8 @@ def default(self, obj): return {"__series__": True, "content": obj.to_dict()} if isinstance(obj, np.random.Generator): return encode_numpy_random_generator(obj) + if isinstance(obj, np.random.SeedSequence): + return encode_numpy_seed_sequence(obj) if inspect.isfunction(obj): return { "__function__": True, @@ -162,11 +164,29 @@ def encode_numpy_random_generator(generator): } +def encode_numpy_seed_sequence(seed_sequence): + """Encode a numpy SeedSequence to a dictionary. + + Adds the key :code:`__numpy_seed_sequence__` to the dictionary to indicate + that the object is a numpy SeedSequence. + + The :code:`state` key contains the state of the seed sequence. + + .. versionadded:: 3.0.0 + """ + state = dict(seed_sequence.state) + state["spawn_key"] = list(state["spawn_key"]) + return { + "__numpy_seed_sequence__": True, + "state": state, + } + + def decode_astropy_cosmology(dct): """Decode an astropy cosmology from a dictionary. The dictionary should have been encoded using - :py:func:`~bibly.core.utils.io.encode_astropy_cosmology` and should have the + :py:func:`~bilby.core.utils.io.encode_astropy_cosmology` and should have the key :code:`__cosmology__`. .. versionchange:: 2.5.0 @@ -270,6 +290,16 @@ def decode_numpy_random_generator(dct): return generator +def decode_numpy_seed_sequence(dct): + """Decode a numpy SeedSequence from a dictionary. + + .. versionadded:: 3.0.0 + """ + state = dict(dct["state"]) + state["spawn_key"] = tuple(int(idx) for idx in state["spawn_key"]) + return np.random.SeedSequence(**state) + + def load_json(filename, gzip): if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz": import gzip @@ -313,6 +343,8 @@ def decode_bilby_json(dct): return obj if dct.get("__numpy_random_generator__", False): return decode_numpy_random_generator(dct) + if dct.get("__numpy_seed_sequence__", False): + return decode_numpy_seed_sequence(dct) if dct.get("__cosmology__", False): return decode_astropy_cosmology(dct) if dct.get("__astropy_quantity__", False): @@ -442,6 +474,8 @@ def encode_for_hdf5(key, item): output = item elif isinstance(item, np.random.Generator): output = encode_numpy_random_generator(item) + elif isinstance(item, np.random.SeedSequence): + output = encode_numpy_seed_sequence(item) elif item is None: output = "__none__" elif isinstance(item, list): @@ -508,6 +542,8 @@ def decode_hdf5_dict(output): output = decode_astropy_unit(output) elif "__numpy_random_generator__" in output: output = decode_numpy_random_generator(output) + elif "__numpy_seed_sequence__" in output: + output = decode_numpy_seed_sequence(output) return output diff --git a/test/core/utils_test.py b/test/core/utils_test.py index df46d6bb3..b1dff0313 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -462,6 +462,38 @@ def test_json(self): b = data["rng"].random() self.assertEqual(a, b) + def test_hdf5_seed_sequence(self): + seed_sequence = np.random.SeedSequence(1234).spawn(1)[0] + data = {"seed_sequence": seed_sequence} + + with h5py.File(self.outdir / "test_seed_sequence.h5", "w") as f: + bilby.core.utils.recursively_save_dict_contents_to_group( + f, "/", data + ) + + with h5py.File(self.outdir / "test_seed_sequence.h5", "r") as f: + loaded = bilby.core.utils.recursively_load_dict_contents_from_group( + f, "/" + ) + + self.assertIsInstance(loaded["seed_sequence"], np.random.SeedSequence) + self.assertEqual(loaded["seed_sequence"].state, seed_sequence.state) + + def test_json_seed_sequence(self): + seed_sequence = np.random.SeedSequence(1234).spawn(1)[0] + data = {"seed_sequence": seed_sequence} + + with open(self.outdir / "test_seed_sequence.json", "w") as file: + json.dump(data, file, indent=2, cls=bilby.core.utils.BilbyJsonEncoder) + + with open(self.outdir / "test_seed_sequence.json", "r") as file: + loaded = json.load( + file, object_hook=bilby.core.utils.decode_bilby_json + ) + + self.assertIsInstance(loaded["seed_sequence"], np.random.SeedSequence) + self.assertEqual(loaded["seed_sequence"].state, seed_sequence.state) + def test_pickle(self): with open(self.outdir / "test.pkl", 'wb') as file: dill.dump(self.data, file)