Skip to content

Commit 469957e

Browse files
committed
MAINT: Bump mininum dynesty version to 3.0.0
1 parent cde65c5 commit 469957e

10 files changed

Lines changed: 685 additions & 519 deletions

bilby/core/sampler/dynesty.py

Lines changed: 52 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
logger,
1818
safe_file_dump,
1919
)
20+
from . import dynesty_utils
2021
from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper
2122

2223

@@ -196,15 +197,6 @@ def default_kwargs(self):
196197
kwargs["seed"] = None
197198
return kwargs
198199

199-
@property
200-
def new_dynesty_api(self):
201-
try:
202-
import dynesty.internal_samplers # noqa
203-
204-
return True
205-
except ImportError:
206-
return False
207-
208200
def __init__(
209201
self,
210202
likelihood,
@@ -281,61 +273,54 @@ def sampler_init_kwargs(self):
281273
# if we're using a Bilby implemented sampling method we need to register the
282274
# method. If we aren't we need to make sure the default "live" isn't set as
283275
# the bounding method
284-
if self.new_dynesty_api:
285-
internal_kwargs = dict(
286-
ndim=self.ndim,
287-
nonbounded=self.kwargs.get("nonbounded", None),
288-
periodic=self.kwargs.get("periodic", None),
289-
reflective=self.kwargs.get("reflective", None),
290-
maxmcmc=self.maxmcmc,
291-
)
292-
293-
from . import dynesty3_utils as dynesty_utils
276+
internal_kwargs = dict(
277+
ndim=self.ndim,
278+
nonbounded=self.kwargs.get("nonbounded", None),
279+
periodic=self.kwargs.get("periodic", None),
280+
reflective=self.kwargs.get("reflective", None),
281+
maxmcmc=self.maxmcmc,
282+
)
294283

295-
if kwargs["sample"] == "act-walk":
296-
internal_kwargs["nact"] = self.nact
297-
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(
298-
**internal_kwargs
299-
)
300-
bound = "none"
301-
logger.info(
302-
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
303-
f"autocorrelation function and thinning by {internal_sampler.thin} with "
304-
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
305-
)
306-
elif kwargs["sample"] == "acceptance-walk":
307-
internal_kwargs["naccept"] = self.naccept
308-
internal_kwargs["walks"] = self.kwargs["walks"]
309-
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
310-
bound = "none"
311-
logger.info(
312-
f"Using the bilby-implemented ensemble rwalk sampling method with an "
313-
f"average of {internal_sampler.naccept} accepted steps up to chain "
314-
f"length {internal_sampler.maxmcmc}."
315-
)
316-
elif kwargs["sample"] == "rwalk":
317-
internal_kwargs["nact"] = self.nact
318-
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(
319-
**internal_kwargs
320-
)
321-
bound = "none"
322-
logger.info(
323-
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
324-
f"estimated chain length. An average of {2 * internal_sampler.nact} "
325-
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
326-
)
327-
elif kwargs["bound"] == "live":
328-
logger.info(
329-
"Live-point based bound method requested with dynesty sample "
330-
f"'{kwargs['sample']}', overwriting to 'multi'"
331-
)
332-
internal_sampler = kwargs["sample"]
333-
bound = "multi"
334-
else:
335-
internal_sampler = kwargs["sample"]
336-
bound = kwargs["bound"]
337-
kwargs["sample"] = internal_sampler
338-
kwargs["bound"] = bound
284+
if kwargs["sample"] == "act-walk":
285+
internal_kwargs["nact"] = self.nact
286+
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(**internal_kwargs)
287+
bound = "none"
288+
logger.info(
289+
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
290+
f"autocorrelation function and thinning by {internal_sampler.thin} with "
291+
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
292+
)
293+
elif kwargs["sample"] == "acceptance-walk":
294+
internal_kwargs["naccept"] = self.naccept
295+
internal_kwargs["walks"] = self.kwargs["walks"]
296+
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
297+
bound = "none"
298+
logger.info(
299+
f"Using the bilby-implemented ensemble rwalk sampling method with an "
300+
f"average of {internal_sampler.naccept} accepted steps up to chain "
301+
f"length {internal_sampler.maxmcmc}."
302+
)
303+
elif kwargs["sample"] == "rwalk":
304+
internal_kwargs["nact"] = self.nact
305+
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(**internal_kwargs)
306+
bound = "none"
307+
logger.info(
308+
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
309+
f"estimated chain length. An average of {2 * internal_sampler.nact} "
310+
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
311+
)
312+
elif kwargs["bound"] == "live":
313+
logger.info(
314+
"Live-point based bound method requested with dynesty sample "
315+
f"'{kwargs['sample']}', overwriting to 'multi'"
316+
)
317+
internal_sampler = kwargs["sample"]
318+
bound = "multi"
319+
else:
320+
internal_sampler = kwargs["sample"]
321+
bound = kwargs["bound"]
322+
kwargs["sample"] = internal_sampler
323+
kwargs["bound"] = bound
339324
return kwargs
340325

341326
def _translate_kwargs(self, kwargs):
@@ -514,107 +499,12 @@ def sampler_class(self):
514499

515500
return Sampler
516501

517-
def _set_sampling_method(self):
518-
"""
519-
Resolve the sampling method and sampler to use from the provided
520-
:code:`bound` and :code:`sample` arguments.
521-
522-
This requires registering the :code:`bilby` specific methods in the
523-
appropriate locations within :code:`dynesty`.
524-
525-
Additionally, some combinations of bound/sample/proposals are not
526-
compatible and so we either warn the user or raise an error.
527-
"""
528-
if self.new_dynesty_api:
529-
return
530-
531-
import dynesty
532-
533-
_set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))
534-
535-
sample = self.kwargs["sample"]
536-
bound = self.kwargs["bound"]
537-
538-
if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
539-
"live",
540-
"live-multi",
541-
]:
542-
logger.info(
543-
"Live-point based bound method requested with dynesty sample "
544-
f"'{sample}', overwriting to 'multi'"
545-
)
546-
self.kwargs["bound"] = "multi"
547-
elif bound == "live":
548-
from .dynesty_utils import LivePointSampler
549-
550-
dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
551-
elif bound == "live-multi":
552-
from .dynesty_utils import MultiEllipsoidLivePointSampler
553-
554-
dynesty.dynamicsampler._SAMPLERS[
555-
"live-multi"
556-
] = MultiEllipsoidLivePointSampler
557-
elif sample == "acceptance-walk":
558-
raise DynestySetupError(
559-
"bound must be set to live or live-multi for sample=acceptance-walk"
560-
)
561-
elif self.proposals is None:
562-
logger.warning(
563-
"No proposals specified using dynesty sampling, defaulting "
564-
"to 'volumetric'."
565-
)
566-
self.proposals = ["volumetric"]
567-
_SamplingContainer.proposals = self.proposals
568-
elif "diff" in self.proposals:
569-
raise DynestySetupError(
570-
"bound must be set to live or live-multi to use differential "
571-
"evolution proposals"
572-
)
573-
574-
if sample == "rwalk":
575-
logger.info(
576-
f"Using the bilby-implemented {sample} sample method with ACT estimated walks. "
577-
f"An average of {2 * self.nact} steps will be accepted up to chain length "
578-
f"{self.maxmcmc}."
579-
)
580-
from .dynesty_utils import AcceptanceTrackingRWalk
581-
582-
if self.kwargs["walks"] > self.maxmcmc:
583-
raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
584-
if self.nact < 1:
585-
raise DynestySetupError("Unable to run with nact < 1")
586-
AcceptanceTrackingRWalk.old_act = None
587-
dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
588-
elif sample == "acceptance-walk":
589-
logger.info(
590-
f"Using the bilby-implemented {sample} sampling with an average of "
591-
f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
592-
)
593-
from .dynesty_utils import FixedRWalk
594-
595-
dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
596-
elif sample == "act-walk":
597-
logger.info(
598-
f"Using the bilby-implemented {sample} sampling tracking the "
599-
f"autocorrelation function and thinning by "
600-
f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
601-
)
602-
from .dynesty_utils import ACTTrackingRWalk
603-
604-
ACTTrackingRWalk._cache = list()
605-
dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
606-
elif sample == "rwalk_dynesty":
607-
sample = sample.strip("_dynesty")
608-
self.kwargs["sample"] = sample
609-
logger.info(f"Using the dynesty-implemented {sample} sample method")
610-
611502
@signal_wrapper
612503
def run_sampler(self):
613504
import dynesty
614505

615506
logger.info(f"Using dynesty version {dynesty.__version__}")
616507

617-
self._set_sampling_method()
618508
self._setup_pool()
619509

620510
if self.resume:
@@ -666,25 +556,6 @@ def run_sampler(self):
666556

667557
return self.result
668558

669-
def _setup_pool(self):
670-
"""
671-
In addition to the usual steps, we need to set the sampling kwargs on
672-
every process. To make sure we get every process, run the kwarg setting
673-
more times than we have processes.
674-
"""
675-
super(Dynesty, self)._setup_pool()
676-
677-
if self.new_dynesty_api:
678-
return
679-
680-
if self.pool is not None:
681-
args = (
682-
[(self.nact, self.maxmcmc, self.proposals, self.naccept)]
683-
* self.npool
684-
* 10
685-
)
686-
self.pool.map(_set_sampling_kwargs, args)
687-
688559
def _generate_result(self, out):
689560
"""
690561
Extract the information we need from the dynesty output. This includes
@@ -895,10 +766,7 @@ def read_saved_state(self, continuing=False):
895766
mapper = self.pool.map
896767
else:
897768
mapper = map
898-
if self.new_dynesty_api:
899-
self.sampler.mapper = mapper
900-
else:
901-
self.sampler.M = mapper
769+
self.sampler.mapper = mapper
902770
return True
903771
else:
904772
logger.info(f"Resume file {self.resume_file} does not exist.")
@@ -941,10 +809,7 @@ def write_current_state(self):
941809
metadata = dict()
942810
versions = dict(bilby=bilby_version, dynesty=dynesty_version)
943811
self.sampler.pool = None
944-
if self.new_dynesty_api:
945-
self.sampler.mapper = map
946-
else:
947-
self.sampler.M = map
812+
self.sampler.mapper = map
948813
if dill.pickles(self.sampler):
949814
safe_file_dump((self.sampler, versions, metadata), self.resume_file, dill)
950815
logger.info(f"Written checkpoint file {self.resume_file}")
@@ -955,10 +820,7 @@ def write_current_state(self):
955820
)
956821
self.sampler.pool = self.pool
957822
if self.sampler.pool is not None:
958-
if self.new_dynesty_api:
959-
self.sampler.mapper = self.sampler.pool.map
960-
else:
961-
self.sampler.M = self.sampler.pool.map
823+
self.sampler.mapper = self.sampler.pool.map
962824

963825
def dump_samples_to_dat(self):
964826
"""
@@ -1090,7 +952,6 @@ def _run_test(self):
1090952
"""Run the sampler very briefly as a sanity test that it works."""
1091953
import pandas as pd
1092954

1093-
self._set_sampling_method()
1094955
self._setup_pool()
1095956
self.sampler = self.sampler_init(
1096957
loglikelihood=_log_likelihood_wrapper,
@@ -1228,4 +1089,4 @@ def dynesty_stats_plot(sampler):
12281089

12291090

12301091
class DynestySetupError(Exception):
1231-
pass
1092+
pass

0 commit comments

Comments
 (0)