Skip to content

Commit d7d0aaf

Browse files
authored
DEP: remove support for dynesty v2 (#1049)
* MAINT: Bump mininum dynesty version to 3.0.0 * BLD: reinstate build changes * BLD: remove unused files * MAINT: reremove unneeded file * FMT: fix precommit failures
1 parent 2817c2d commit d7d0aaf

6 files changed

Lines changed: 480 additions & 1460 deletions

File tree

bilby/core/sampler/dynesty.py

Lines changed: 51 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
logger,
1717
safe_file_dump,
1818
)
19+
from . import dynesty_utils
1920
from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper
2021

2122

@@ -183,15 +184,6 @@ def default_kwargs(self):
183184
kwargs["seed"] = None
184185
return kwargs
185186

186-
@property
187-
def new_dynesty_api(self):
188-
try:
189-
import dynesty.internal_samplers # noqa
190-
191-
return True
192-
except ImportError:
193-
return False
194-
195187
def __init__(
196188
self,
197189
likelihood,
@@ -268,61 +260,54 @@ def sampler_init_kwargs(self):
268260
# if we're using a Bilby implemented sampling method we need to register the
269261
# method. If we aren't we need to make sure the default "live" isn't set as
270262
# the bounding method
271-
if self.new_dynesty_api:
272-
internal_kwargs = dict(
273-
ndim=self.ndim,
274-
nonbounded=self.kwargs.get("nonbounded", None),
275-
periodic=self.kwargs.get("periodic", None),
276-
reflective=self.kwargs.get("reflective", None),
277-
maxmcmc=self.maxmcmc,
278-
)
279-
280-
from . import dynesty3_utils as dynesty_utils
263+
internal_kwargs = dict(
264+
ndim=self.ndim,
265+
nonbounded=self.kwargs.get("nonbounded", None),
266+
periodic=self.kwargs.get("periodic", None),
267+
reflective=self.kwargs.get("reflective", None),
268+
maxmcmc=self.maxmcmc,
269+
)
281270

282-
if kwargs["sample"] == "act-walk":
283-
internal_kwargs["nact"] = self.nact
284-
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(
285-
**internal_kwargs
286-
)
287-
bound = "none"
288-
logger.info(
289-
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
290-
f"autocorrelation function and thinning by {internal_sampler.thin} with "
291-
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
292-
)
293-
elif kwargs["sample"] == "acceptance-walk":
294-
internal_kwargs["naccept"] = self.naccept
295-
internal_kwargs["walks"] = self.kwargs["walks"]
296-
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
297-
bound = "none"
298-
logger.info(
299-
f"Using the bilby-implemented ensemble rwalk sampling method with an "
300-
f"average of {internal_sampler.naccept} accepted steps up to chain "
301-
f"length {internal_sampler.maxmcmc}."
302-
)
303-
elif kwargs["sample"] == "rwalk":
304-
internal_kwargs["nact"] = self.nact
305-
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(
306-
**internal_kwargs
307-
)
308-
bound = "none"
309-
logger.info(
310-
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
311-
f"estimated chain length. An average of {2 * internal_sampler.nact} "
312-
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
313-
)
314-
elif kwargs["bound"] == "live":
315-
logger.info(
316-
"Live-point based bound method requested with dynesty sample "
317-
f"'{kwargs['sample']}', overwriting to 'multi'"
318-
)
319-
internal_sampler = kwargs["sample"]
320-
bound = "multi"
321-
else:
322-
internal_sampler = kwargs["sample"]
323-
bound = kwargs["bound"]
324-
kwargs["sample"] = internal_sampler
325-
kwargs["bound"] = bound
271+
if kwargs["sample"] == "act-walk":
272+
internal_kwargs["nact"] = self.nact
273+
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(**internal_kwargs)
274+
bound = "none"
275+
logger.info(
276+
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
277+
f"autocorrelation function and thinning by {internal_sampler.thin} with "
278+
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
279+
)
280+
elif kwargs["sample"] == "acceptance-walk":
281+
internal_kwargs["naccept"] = self.naccept
282+
internal_kwargs["walks"] = self.kwargs["walks"]
283+
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
284+
bound = "none"
285+
logger.info(
286+
f"Using the bilby-implemented ensemble rwalk sampling method with an "
287+
f"average of {internal_sampler.naccept} accepted steps up to chain "
288+
f"length {internal_sampler.maxmcmc}."
289+
)
290+
elif kwargs["sample"] == "rwalk":
291+
internal_kwargs["nact"] = self.nact
292+
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(**internal_kwargs)
293+
bound = "none"
294+
logger.info(
295+
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
296+
f"estimated chain length. An average of {2 * internal_sampler.nact} "
297+
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
298+
)
299+
elif kwargs["bound"] == "live":
300+
logger.info(
301+
"Live-point based bound method requested with dynesty sample "
302+
f"'{kwargs['sample']}', overwriting to 'multi'"
303+
)
304+
internal_sampler = kwargs["sample"]
305+
bound = "multi"
306+
else:
307+
internal_sampler = kwargs["sample"]
308+
bound = kwargs["bound"]
309+
kwargs["sample"] = internal_sampler
310+
kwargs["bound"] = bound
326311
return kwargs
327312

328313
def _translate_kwargs(self, kwargs):
@@ -501,107 +486,12 @@ def sampler_class(self):
501486

502487
return Sampler
503488

504-
def _set_sampling_method(self):
505-
"""
506-
Resolve the sampling method and sampler to use from the provided
507-
:code:`bound` and :code:`sample` arguments.
508-
509-
This requires registering the :code:`bilby` specific methods in the
510-
appropriate locations within :code:`dynesty`.
511-
512-
Additionally, some combinations of bound/sample/proposals are not
513-
compatible and so we either warn the user or raise an error.
514-
"""
515-
if self.new_dynesty_api:
516-
return
517-
518-
import dynesty
519-
520-
_set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))
521-
522-
sample = self.kwargs["sample"]
523-
bound = self.kwargs["bound"]
524-
525-
if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
526-
"live",
527-
"live-multi",
528-
]:
529-
logger.info(
530-
"Live-point based bound method requested with dynesty sample "
531-
f"'{sample}', overwriting to 'multi'"
532-
)
533-
self.kwargs["bound"] = "multi"
534-
elif bound == "live":
535-
from .dynesty_utils import LivePointSampler
536-
537-
dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
538-
elif bound == "live-multi":
539-
from .dynesty_utils import MultiEllipsoidLivePointSampler
540-
541-
dynesty.dynamicsampler._SAMPLERS[
542-
"live-multi"
543-
] = MultiEllipsoidLivePointSampler
544-
elif sample == "acceptance-walk":
545-
raise DynestySetupError(
546-
"bound must be set to live or live-multi for sample=acceptance-walk"
547-
)
548-
elif self.proposals is None:
549-
logger.warning(
550-
"No proposals specified using dynesty sampling, defaulting "
551-
"to 'volumetric'."
552-
)
553-
self.proposals = ["volumetric"]
554-
_SamplingContainer.proposals = self.proposals
555-
elif "diff" in self.proposals:
556-
raise DynestySetupError(
557-
"bound must be set to live or live-multi to use differential "
558-
"evolution proposals"
559-
)
560-
561-
if sample == "rwalk":
562-
logger.info(
563-
f"Using the bilby-implemented {sample} sample method with ACT estimated walks. "
564-
f"An average of {2 * self.nact} steps will be accepted up to chain length "
565-
f"{self.maxmcmc}."
566-
)
567-
from .dynesty_utils import AcceptanceTrackingRWalk
568-
569-
if self.kwargs["walks"] > self.maxmcmc:
570-
raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
571-
if self.nact < 1:
572-
raise DynestySetupError("Unable to run with nact < 1")
573-
AcceptanceTrackingRWalk.old_act = None
574-
dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
575-
elif sample == "acceptance-walk":
576-
logger.info(
577-
f"Using the bilby-implemented {sample} sampling with an average of "
578-
f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
579-
)
580-
from .dynesty_utils import FixedRWalk
581-
582-
dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
583-
elif sample == "act-walk":
584-
logger.info(
585-
f"Using the bilby-implemented {sample} sampling tracking the "
586-
f"autocorrelation function and thinning by "
587-
f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
588-
)
589-
from .dynesty_utils import ACTTrackingRWalk
590-
591-
ACTTrackingRWalk._cache = list()
592-
dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
593-
elif sample == "rwalk_dynesty":
594-
sample = sample.strip("_dynesty")
595-
self.kwargs["sample"] = sample
596-
logger.info(f"Using the dynesty-implemented {sample} sample method")
597-
598489
@signal_wrapper
599490
def run_sampler(self):
600491
import dynesty
601492

602493
logger.info(f"Using dynesty version {dynesty.__version__}")
603494

604-
self._set_sampling_method()
605495
self._setup_pool()
606496

607497
if self.resume:
@@ -653,25 +543,6 @@ def run_sampler(self):
653543

654544
return self.result
655545

656-
def _setup_pool(self):
657-
"""
658-
In addition to the usual steps, we need to set the sampling kwargs on
659-
every process. To make sure we get every process, run the kwarg setting
660-
more times than we have processes.
661-
"""
662-
super(Dynesty, self)._setup_pool()
663-
664-
if self.new_dynesty_api:
665-
return
666-
667-
if self.pool is not None:
668-
args = (
669-
[(self.nact, self.maxmcmc, self.proposals, self.naccept)]
670-
* self.npool
671-
* 10
672-
)
673-
self.pool.map(_set_sampling_kwargs, args)
674-
675546
def _generate_result(self, out):
676547
"""
677548
Extract the information we need from the dynesty output. This includes
@@ -882,10 +753,7 @@ def read_saved_state(self, continuing=False):
882753
mapper = self.pool.map
883754
else:
884755
mapper = map
885-
if self.new_dynesty_api:
886-
self.sampler.mapper = mapper
887-
else:
888-
self.sampler.M = mapper
756+
self.sampler.mapper = mapper
889757
return True
890758
else:
891759
logger.info(f"Resume file {self.resume_file} does not exist.")
@@ -928,10 +796,7 @@ def write_current_state(self):
928796
metadata = dict()
929797
versions = dict(bilby=bilby_version, dynesty=dynesty_version)
930798
self.sampler.pool = None
931-
if self.new_dynesty_api:
932-
self.sampler.mapper = map
933-
else:
934-
self.sampler.M = map
799+
self.sampler.mapper = map
935800
if dill.pickles(self.sampler):
936801
safe_file_dump((self.sampler, versions, metadata), self.resume_file, dill)
937802
logger.info(f"Written checkpoint file {self.resume_file}")
@@ -942,10 +807,7 @@ def write_current_state(self):
942807
)
943808
self.sampler.pool = self.pool
944809
if self.sampler.pool is not None:
945-
if self.new_dynesty_api:
946-
self.sampler.mapper = self.sampler.pool.map
947-
else:
948-
self.sampler.M = self.sampler.pool.map
810+
self.sampler.mapper = self.sampler.pool.map
949811

950812
def dump_samples_to_dat(self):
951813
"""
@@ -1077,7 +939,6 @@ def _run_test(self):
1077939
"""Run the sampler very briefly as a sanity test that it works."""
1078940
import pandas as pd
1079941

1080-
self._set_sampling_method()
1081942
self._setup_pool()
1082943
self.sampler = self.sampler_init(
1083944
loglikelihood=_log_likelihood_wrapper,

0 commit comments

Comments
 (0)