Skip to content

Commit 9c32aa3

Browse files
committed
Update unit tests
1 parent fc940b5 commit 9c32aa3

3 files changed

Lines changed: 272 additions & 0 deletions

File tree

tests/test_blend.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from esigmapy.blend import (
55
align_in_phase,
6+
blend_modes,
67
blend_series,
78
compute_amplitude,
89
compute_frequency,
@@ -363,3 +364,55 @@ def test_align_inspiral_to_merger_order(self):
363364
# Merger-ringdown is unchanged; inspiral is shifted
364365
assert np.allclose(mr_out, mr)
365366
assert not np.allclose(insp_out, wave)
367+
368+
369+
# ---------------------------------------------------------------------------
370+
# blend_modes — input validation guards (pure Python, no LAL required)
371+
# ---------------------------------------------------------------------------
372+
373+
class TestBlendModesValidation:
374+
@staticmethod
375+
def _modes(n=500, freq_hz=50.0, dt=1.0 / 4096):
376+
t = np.arange(n) * dt
377+
mode = np.exp(-2j * np.pi * freq_hz * t)
378+
return {(2, 2): mode, (2, -2): np.conj(mode)}
379+
380+
def test_negative_frq_width_raises(self):
381+
modes = self._modes()
382+
with pytest.raises(IOError, match="negative"):
383+
blend_modes(modes, modes, np.ones(500), 50.0, frq_width=-5.0)
384+
385+
def test_zero_frq_width_raises(self):
386+
modes = self._modes()
387+
with pytest.raises(IOError):
388+
blend_modes(modes, modes, np.ones(500), 50.0, frq_width=0.0)
389+
390+
def test_mismatched_orbital_freq_length_raises(self):
391+
modes = self._modes(n=500)
392+
with pytest.raises(IOError):
393+
blend_modes(
394+
modes, modes,
395+
inspiral_orbital_frequency=np.ones(100), # wrong length
396+
frq_attach=50.0, frq_width=5.0,
397+
blend_using_avg_orbital_frequency=True,
398+
)
399+
400+
def test_mode_missing_from_inspiral_raises(self):
401+
insp = {(2, 2): np.ones(500, dtype=complex)}
402+
mr = {(2, 2): np.ones(500, dtype=complex), (3, 3): np.ones(500, dtype=complex)}
403+
with pytest.raises(IOError):
404+
blend_modes(
405+
insp, mr, np.ones(500), 50.0, frq_width=5.0,
406+
modes_to_blend=[(2, 2), (3, 3)],
407+
include_conjugate_modes=False,
408+
)
409+
410+
def test_mode_missing_from_mr_raises(self):
411+
insp = {(2, 2): np.ones(500, dtype=complex), (3, 3): np.ones(500, dtype=complex)}
412+
mr = {(2, 2): np.ones(500, dtype=complex)}
413+
with pytest.raises(IOError):
414+
blend_modes(
415+
insp, mr, np.ones(500), 50.0, frq_width=5.0,
416+
modes_to_blend=[(2, 2), (3, 3)],
417+
include_conjugate_modes=False,
418+
)

tests/test_generator.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import numpy as np
2+
import pytest
3+
from scipy import integrate
4+
5+
from esigmapy.generator import (
6+
_get_window_start,
7+
_get_transition_frequency_window,
8+
get_imr_esigma_modes,
9+
)
10+
11+
12+
# ---------------------------------------------------------------------------
13+
# _get_window_start
14+
# ---------------------------------------------------------------------------
15+
16+
class TestGetWindowStart:
17+
"""Tests for the private helper that integrates a frequency series and
18+
returns the first index where the cumulative phase exceeds a threshold."""
19+
20+
_DT = 1.0 / 4096
21+
_F0 = 100.0 # Hz — constant frequency used in several tests
22+
23+
def test_forward_meets_threshold(self):
24+
freq = np.ones(500) * self._F0
25+
idx = _get_window_start(freq, self._DT, 1.0, direction="forward")
26+
assert idx is not None
27+
assert abs(integrate.trapezoid(freq[: idx + 1], dx=self._DT)) >= 1.0
28+
29+
def test_forward_is_first_crossing(self):
30+
# The index immediately before should NOT yet meet the threshold
31+
freq = np.ones(500) * self._F0
32+
idx = _get_window_start(freq, self._DT, 1.0, direction="forward")
33+
assert abs(integrate.trapezoid(freq[:idx], dx=self._DT)) < 1.0
34+
35+
def test_forward_returns_none_when_unreachable(self):
36+
# Total integral of [1 Hz × 5 samples × dt] is far below 1000 rad
37+
freq = np.ones(5) * 1.0
38+
assert _get_window_start(freq, 0.001, 1000.0, direction="forward") is None
39+
40+
def test_backward_meets_threshold(self):
41+
freq = np.ones(500) * self._F0
42+
idx = _get_window_start(freq, self._DT, 1.0, direction="backward")
43+
assert idx is not None
44+
assert abs(integrate.trapezoid(freq[idx:], dx=self._DT)) >= 1.0
45+
46+
def test_backward_is_last_crossing(self):
47+
# The index one step to the right should NOT meet the threshold alone
48+
freq = np.ones(500) * self._F0
49+
idx = _get_window_start(freq, self._DT, 1.0, direction="backward")
50+
assert abs(integrate.trapezoid(freq[idx + 1 :], dx=self._DT)) < 1.0
51+
52+
def test_backward_returns_none_when_unreachable(self):
53+
freq = np.ones(5) * 1.0
54+
assert _get_window_start(freq, 0.001, 1000.0, direction="backward") is None
55+
56+
def test_forward_index_in_valid_range(self):
57+
freq = np.linspace(10.0, 200.0, 500)
58+
idx = _get_window_start(freq, self._DT, 1.0, direction="forward")
59+
assert idx is not None
60+
assert 0 < idx < len(freq)
61+
62+
63+
# ---------------------------------------------------------------------------
64+
# _get_transition_frequency_window
65+
# ---------------------------------------------------------------------------
66+
67+
class TestGetTransitionFrequencyWindow:
68+
"""Tests for the private helper that converts num_hyb_orbits into a
69+
hybridization frequency window width."""
70+
71+
@staticmethod
72+
def _setup(n=2000, dt=1.0 / 4096, f_start=10.0, f_end=200.0):
73+
freq = np.linspace(f_start, f_end, n)
74+
phase = np.cumsum(freq) * dt # monotonically increasing
75+
return freq, phase, dt
76+
77+
def test_end_mode_returns_positive_width(self):
78+
freq, phase, dt = self._setup()
79+
result = _get_transition_frequency_window(
80+
phase, freq, dt,
81+
f_mr_transition=freq[1000],
82+
num_hyb_orbits=0.1,
83+
keep_f_mr_transition_at_center=False,
84+
blend_using_avg_orbital_frequency=False,
85+
failsafe=True,
86+
)
87+
assert result > 0
88+
89+
def test_center_mode_returns_positive_width(self):
90+
freq, phase, dt = self._setup()
91+
result = _get_transition_frequency_window(
92+
phase, freq, dt,
93+
f_mr_transition=freq[1000],
94+
num_hyb_orbits=0.1,
95+
keep_f_mr_transition_at_center=True,
96+
blend_using_avg_orbital_frequency=False,
97+
failsafe=True,
98+
)
99+
assert result > 0
100+
101+
def test_avg_orbital_frequency_mode_returns_positive_width(self):
102+
freq, phase, dt = self._setup()
103+
result = _get_transition_frequency_window(
104+
phase, freq, dt,
105+
f_mr_transition=freq[1000],
106+
num_hyb_orbits=0.1,
107+
keep_f_mr_transition_at_center=False,
108+
blend_using_avg_orbital_frequency=True,
109+
failsafe=True,
110+
)
111+
assert result > 0
112+
113+
def test_more_orbits_wider_or_equal_window(self):
114+
freq, phase, dt = self._setup()
115+
f_tr = freq[1000]
116+
w_narrow = _get_transition_frequency_window(
117+
phase, freq, dt, f_tr,
118+
num_hyb_orbits=0.1,
119+
keep_f_mr_transition_at_center=False,
120+
blend_using_avg_orbital_frequency=False,
121+
failsafe=True,
122+
)
123+
w_wide = _get_transition_frequency_window(
124+
phase, freq, dt, f_tr,
125+
num_hyb_orbits=0.5,
126+
keep_f_mr_transition_at_center=False,
127+
blend_using_avg_orbital_frequency=False,
128+
failsafe=True,
129+
)
130+
assert w_wide >= w_narrow
131+
132+
133+
# ---------------------------------------------------------------------------
134+
# get_imr_esigma_modes — input validation guards (fire before any LAL call)
135+
# ---------------------------------------------------------------------------
136+
137+
class TestGetImrEsigmaModesValidation:
138+
_BASE = dict(
139+
mass1=20.0,
140+
mass2=20.0,
141+
f_lower=20.0,
142+
delta_t=1.0 / 2048,
143+
merger_ringdown_approximant="NRSur7dq4",
144+
)
145+
146+
def test_invalid_approximant_raises_before_lal(self):
147+
kwargs = {**self._BASE, "merger_ringdown_approximant": "IMRPhenomD"}
148+
with pytest.raises(IOError):
149+
get_imr_esigma_modes(**kwargs, mean_anomaly=0.0)
150+
151+
def test_both_phase_angles_none_raises(self):
152+
with pytest.raises(IOError):
153+
get_imr_esigma_modes(**self._BASE, mean_anomaly=None, coa_phase=None)
154+
155+
def test_align_merger_without_mean_anomaly_raises(self):
156+
# blend_aligning_merger_to_inspiral=True (default) requires mean_anomaly
157+
with pytest.raises(IOError):
158+
get_imr_esigma_modes(
159+
**self._BASE,
160+
blend_aligning_merger_to_inspiral=True,
161+
mean_anomaly=None,
162+
coa_phase=0.0,
163+
)
164+
165+
def test_align_inspiral_without_coa_phase_raises(self):
166+
# blend_aligning_merger_to_inspiral=False requires coa_phase
167+
with pytest.raises(IOError):
168+
get_imr_esigma_modes(
169+
**self._BASE,
170+
blend_aligning_merger_to_inspiral=False,
171+
mean_anomaly=0.0,
172+
coa_phase=None,
173+
)

tests/test_mr_generator.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
3+
from esigmapy.mr_generator import (
4+
check_available_mr_approximants,
5+
LALSIM_APPROXIMANTS,
6+
PYSEOBNR_APPROXIMANTS,
7+
SUPPORTED_MR_APPROXIMANTS,
8+
)
9+
10+
11+
class TestCheckAvailableMrApproximants:
12+
def test_lalsim_approximants_valid(self):
13+
for approx in LALSIM_APPROXIMANTS:
14+
check_available_mr_approximants(approx) # must not raise
15+
16+
def test_pyseobnr_approximants_valid(self):
17+
for approx in PYSEOBNR_APPROXIMANTS:
18+
check_available_mr_approximants(approx) # must not raise
19+
20+
def test_unsupported_approximant_raises(self):
21+
with pytest.raises(IOError, match="cannot generate"):
22+
check_available_mr_approximants("IMRPhenomD")
23+
24+
def test_empty_string_raises(self):
25+
with pytest.raises(IOError):
26+
check_available_mr_approximants("")
27+
28+
def test_case_sensitive(self):
29+
with pytest.raises(IOError):
30+
check_available_mr_approximants("nrsur7dq4")
31+
32+
33+
class TestApproximantConstants:
34+
def test_supported_is_union_of_lalsim_and_pyseobnr(self):
35+
assert set(SUPPORTED_MR_APPROXIMANTS) == set(LALSIM_APPROXIMANTS) | set(
36+
PYSEOBNR_APPROXIMANTS
37+
)
38+
39+
def test_lalsim_and_pyseobnr_disjoint(self):
40+
assert not set(LALSIM_APPROXIMANTS) & set(PYSEOBNR_APPROXIMANTS)
41+
42+
def test_lalsim_contains_nrsur7dq4(self):
43+
assert "NRSur7dq4" in LALSIM_APPROXIMANTS
44+
45+
def test_lalsim_contains_seobnrv4phm(self):
46+
assert "SEOBNRv4PHM" in LALSIM_APPROXIMANTS

0 commit comments

Comments
 (0)