Skip to content

Commit b6da3e0

Browse files
committed
Add test for passing multiple percentiles to calculate_percentile_cut
1 parent fa65d16 commit b6da3e0

2 files changed

Lines changed: 37 additions & 3 deletions

File tree

pyirf/cut_optimization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def optimize_cuts(
7575
fill_value = signal['gh_score'].max()
7676

7777
sensitivities = []
78-
cut_indicies = []
78+
cut_indices = []
7979
n_theta_cuts = len(theta_cut_efficiencies)
8080
n_gh_cuts = len(gh_cut_efficiencies)
8181
n_cuts = len(multiplicity_cuts) * n_theta_cuts * n_gh_cuts
@@ -156,7 +156,7 @@ def optimize_cuts(
156156
signal_hist, background_hist, alpha=alpha,
157157
**kwargs,
158158
)
159-
cut_indicies.append((multiplicity_index, theta_index, gh_index))
159+
cut_indices.append((multiplicity_index, theta_index, gh_index))
160160
sensitivities.append(sensitivity)
161161
bar.update(1)
162162

@@ -192,7 +192,7 @@ def optimize_cuts(
192192
# if all are invalid, just use the first one
193193
best = 0
194194

195-
multiplicity_index, theta_index, gh_index = cut_indicies[best]
195+
multiplicity_index, theta_index, gh_index = cut_indices[best]
196196

197197
best_sensitivity[bin_id] = sensitivities[best][bin_id]
198198

pyirf/tests/test_cuts.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,37 @@ def test_calculate_percentile_cuts_table():
178178
[dist1.ppf(0.68), dist2.ppf(0.68)],
179179
rtol=0.1,
180180
)
181+
182+
183+
184+
185+
def test_calculate_percentile_cuts_multiple():
186+
from pyirf.cuts import calculate_percentile_cut
187+
188+
np.random.seed(0)
189+
190+
dist1 = norm(0, 1)
191+
dist2 = norm(10, 1)
192+
N = int(1e4)
193+
194+
values = np.append(dist1.rvs(size=N), dist2.rvs(size=N)) * u.deg
195+
bin_values = np.append(np.zeros(N), np.ones(N)) * u.m
196+
# add some values outside of binning to test that under/overflow are ignored
197+
bin_values[10] = 5 * u.m
198+
bin_values[30] = -1 * u.m
199+
200+
bins = [-0.5, 0.5, 1.5] * u.m
201+
202+
cuts = calculate_percentile_cut(values, bin_values, bins, fill_value=np.nan * u.deg, percentile=[50, 68, 95])
203+
assert np.all(cuts["low"] == bins[:-1])
204+
assert np.all(cuts["high"] == bins[1:])
205+
206+
np.testing.assert_allclose(
207+
cuts["cut"].to_value(u.deg),
208+
[
209+
[dist1.ppf(0.5), dist1.ppf(0.68), dist1.ppf(0.95)],
210+
[dist2.ppf(0.5), dist2.ppf(0.68), dist2.ppf(0.95)],
211+
],
212+
rtol=0.1,
213+
atol=0.1, # dist1.ppf(0.5) == 0, so we also need a non-zero atol
214+
)

0 commit comments

Comments
 (0)