Skip to content

Commit 771a8a9

Browse files
committed
TST: refactor marginalization tests to be less restrictive
1 parent 667e549 commit 771a8a9

1 file changed

Lines changed: 34 additions & 25 deletions

File tree

test/gw/likelihood/marginalization_test.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import unittest
55
from copy import deepcopy
6+
from functools import cached_property
67
from itertools import product
78
from parameterized import parameterized
89

@@ -230,54 +231,63 @@ def setUp(self):
230231
maximum=self.parameters["geocent_time"] + 0.1
231232
)
232233

233-
trial_roq_paths = [
234-
"/roq_basis",
235-
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
236-
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
237-
]
238-
roq_dir = None
239-
for path in trial_roq_paths:
240-
if os.path.isdir(path):
241-
roq_dir = path
242-
break
243-
if roq_dir is None:
244-
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
245-
246-
self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
234+
self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
247235
duration=self.duration,
248236
sampling_frequency=self.sampling_frequency,
249-
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
237+
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning,
250238
start_time=1126259640,
251239
waveform_arguments=dict(
252240
reference_frequency=20.0,
241+
minimum_frequency=20.0,
253242
waveform_approximant="IMRPhenomPv2",
254-
frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"),
255-
frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"),
256243
)
257244
)
258-
self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy"
259-
self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy"
260245

261-
self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
246+
self.multiband_waveform_generator = bilby.gw.WaveformGenerator(
262247
duration=self.duration,
263248
sampling_frequency=self.sampling_frequency,
264-
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning,
249+
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
265250
start_time=1126259640,
266251
waveform_arguments=dict(
267252
reference_frequency=20.0,
268-
minimum_frequency=20.0,
269253
waveform_approximant="IMRPhenomPv2",
270254
)
271255
)
272256

273-
self.multiband_waveform_generator = bilby.gw.WaveformGenerator(
257+
@property
258+
def roq_dir(self):
259+
trial_roq_paths = [
260+
"/roq_basis",
261+
os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
262+
"/home/cbc/ROQ_data/IMRPhenomPv2/4s",
263+
]
264+
if "BILBY_TESTING_ROQ_DIR" in os.environ:
265+
trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"])
266+
for path in trial_roq_paths:
267+
if os.path.isdir(path):
268+
return path
269+
raise Exception("Unable to load ROQ basis: cannot proceed with tests")
270+
271+
@property
272+
def roq_linear_matrix_file(self):
273+
return f"{self.roq_dir}/B_linear.npy"
274+
275+
@property
276+
def roq_quadratic_matrix_file(self):
277+
return f"{self.roq_dir}/B_quadratic.npy"
278+
279+
@cached_property
280+
def roq_waveform_generator(self):
281+
return bilby.gw.waveform_generator.WaveformGenerator(
274282
duration=self.duration,
275283
sampling_frequency=self.sampling_frequency,
276-
frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
284+
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
277285
start_time=1126259640,
278286
waveform_arguments=dict(
279287
reference_frequency=20.0,
280288
waveform_approximant="IMRPhenomPv2",
289+
frequency_nodes_linear=np.load(f"{self.roq_dir}/fnodes_linear.npy"),
290+
frequency_nodes_quadratic=np.load(f"{self.roq_dir}/fnodes_quadratic.npy"),
281291
)
282292
)
283293

@@ -287,7 +297,6 @@ def tearDown(self):
287297
del self.parameters
288298
del self.interferometers
289299
del self.waveform_generator
290-
del self.roq_waveform_generator
291300
del self.priors
292301

293302
@classmethod

0 commit comments

Comments
 (0)