Skip to content

Commit d9df110

Browse files
committed
added all mutliclosure nsigma modules
1 parent 164577c commit d9df110

3 files changed

Lines changed: 551 additions & 0 deletions

File tree

Lines changed: 396 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,396 @@
1+
"""
2+
This module contains the functions to compute the consistency / inconsistency sets.
3+
4+
Assuming that we have two datasets A and B, and that we are investigating whether A is consistent or not
5+
we can define the following sets:
6+
7+
8+
1⍺ = {i | nσi > Z⍺} ...
9+
10+
"""
11+
12+
import dataclasses
13+
import pandas as pd
14+
import numpy as np
15+
import logging
16+
from matplotlib.figure import Figure
17+
from scipy.stats import norm
18+
from typing import Any, Union, Generator
19+
20+
import reportengine
21+
from reportengine.figure import figuregen
22+
from reportengine import collect
23+
24+
from validphys.core import DataSetSpec, PDF
25+
from validphys.coredata import CommonData
26+
from validphys.calcutils import calc_chi2
27+
from validphys import convolution
28+
from validphys.closuretest.closure_checks import (
29+
check_fits_areclosures,
30+
check_fits_underlying_law_match,
31+
)
32+
from validphys import plotutils
33+
from validphys.api import API
34+
35+
import sys
36+
37+
sys.path.insert(0, "./")
38+
from multiclosure_nsigma_helpers import CentralChi2Data
39+
40+
41+
log = logging.getLogger(__name__)
42+
43+
"""
44+
Quantile range for computing the true positive rate and true negative rate.
45+
"""
46+
ALPHA_RANGE = np.linspace(0, 1.0, 100)
47+
48+
49+
def chi2_nsigma_deviation(central_member_chi2: CentralChi2Data) -> float:
50+
"""
51+
Computes n_sigma as: (chi2 - ndata) / sqrt(2 * ndata)
52+
53+
Parameters
54+
----------
55+
central_member_chi2: CentralChi2Data
56+
57+
Returns
58+
-------
59+
float
60+
The deviation in units of sigma.
61+
"""
62+
diff = central_member_chi2.value - central_member_chi2.ndata
63+
return diff / np.sqrt(2 * central_member_chi2.ndata)
64+
65+
66+
"""
67+
Collect the n_sigma values over list of ``dataset_input``.
68+
"""
69+
datasets_chi2_nsigma_deviation = collect("chi2_nsigma_deviation", ("data_input",))
70+
71+
72+
"""
73+
Collects over fits and for all datasets the n_sigma values.
74+
"""
75+
fits_datasets_chi2_nsigma_deviation = collect(
76+
"datasets_chi2_nsigma_deviation", ("fits", "fitcontext")
77+
)
78+
79+
80+
"""
81+
Collects the data for each fit.
82+
"""
83+
fits_data = collect("data", ("fits", "fitinputcontext"))
84+
85+
86+
def is_weighted(fits_data: list) -> bool:
87+
"""
88+
Returns whether the considered multiclosure tests has been weighted or not.
89+
If the weighted datasets are not the same for all fits,
90+
or there is more than one weighted dataset, an error is raised.
91+
92+
Parameters
93+
----------
94+
fits_data: list
95+
List of data for each fit.
96+
97+
Returns
98+
-------
99+
str or None
100+
Name of the weighted dataset.
101+
"""
102+
# Extract the set of unique weighted dataset names from all fits
103+
weighted_ds_sets = [{ds.name for ds in data.datasets if ds.weight != 1} for data in fits_data]
104+
105+
# Ensure all fits have the same set of weighted datasets
106+
if len(set(frozenset(ds_set) for ds_set in weighted_ds_sets)) > 1:
107+
error_msg = "Weighted datasets are not the same for all fits in the same multiclosure test (dataspec)."
108+
log.error(error_msg)
109+
raise ValueError(error_msg)
110+
111+
# Extract the single weighted dataset set (all should be identical)
112+
weighted_ds = next(iter(weighted_ds_sets))
113+
114+
# Ensure there is exactly one weighted dataset
115+
if len(weighted_ds) > 1:
116+
error_msg = "Only one dataset can be weighted in a multiclosure test."
117+
log.error(error_msg)
118+
raise ValueError(error_msg)
119+
120+
return bool(weighted_ds)
121+
122+
123+
@dataclasses.dataclass
124+
class MulticlosureNsigma:
125+
"""
126+
Dataclass containing nsigma values for all datasets and fits,
127+
also used to keep track on whether the multiclosure fit is weighted or not.
128+
129+
Attributes
130+
----------
131+
nsigma_table: pd.DataFrame
132+
A table containing n_sigma values.
133+
is_weighted: bool
134+
Whether the fit was weighted.
135+
"""
136+
137+
nsigma_table: pd.DataFrame
138+
is_weighted: bool
139+
140+
141+
@check_fits_areclosures
142+
@check_fits_underlying_law_match
143+
def multiclosurefits_nsigma(
144+
fits: reportengine.namespaces.NSList,
145+
fits_data: list,
146+
fits_datasets_chi2_nsigma_deviation: list,
147+
is_weighted: bool,
148+
) -> MulticlosureNsigma:
149+
"""
150+
Returns a table (dataframe) containing n_sigma values.
151+
Index: dataset names, Columns: Level 1 seeds (filterseed).
152+
153+
Parameters
154+
----------
155+
fits: NSList
156+
List of fits.
157+
fits_data: list
158+
List of data for each fit.
159+
fits_datasets_chi2_nsigma_deviation: list
160+
List of n_sigma values for each dataset for each fit.
161+
is_weighted: bool
162+
Used to keep track of whether the fit was weighted.
163+
164+
Returns
165+
-------
166+
MulticlosureNsigma
167+
"""
168+
res_dict = {}
169+
for fit, input_data, nsigma_data in zip(fits, fits_data, fits_datasets_chi2_nsigma_deviation):
170+
res_dict[fit.as_input()["closuretest"]["filterseed"]] = d = {}
171+
for ds, nsigma in zip(input_data.datasets, nsigma_data):
172+
d[ds.name] = nsigma
173+
return MulticlosureNsigma(is_weighted=is_weighted, nsigma_table=pd.DataFrame(res_dict))
174+
175+
176+
"""
177+
Collect the multiclosurefits_nsigma over dataspecs.
178+
"""
179+
dataspecs_multiclosurefits_nsigma = collect("multiclosurefits_nsigma", ("dataspecs",))
180+
181+
182+
def n_fits(dataspecs):
183+
"""
184+
Computes the total number of fits in the multiclosure test.
185+
If the number of fits is not the same across dataspecs it raises an error.
186+
"""
187+
n_fits = set()
188+
for dataspec in dataspecs:
189+
n_fits.add(len(dataspec['fits']))
190+
191+
if len(n_fits) > 1:
192+
error_msg = "The number of fits is not the same across dataspecs."
193+
log.error(error_msg)
194+
raise ValueError(error_msg)
195+
196+
return next(iter(n_fits))
197+
198+
199+
@dataclasses.dataclass
200+
class NsigmaAlpha:
201+
"""
202+
Dataclass storing the set 1 alpha values (can be used both for the set 1 and its complement).
203+
204+
Attributes
205+
----------
206+
alpha_dict: dict
207+
A dictionary containing the set 1 alpha values.
208+
is_weighted: bool
209+
Whether the fit was weighted.
210+
"""
211+
212+
alpha_dict: dict
213+
is_weighted: bool
214+
215+
216+
def def_of_nsigma_alpha(
217+
multiclosurefits_nsigma: pd.DataFrame, weighted_dataset: str, complement: bool = False
218+
) -> NsigmaAlpha:
219+
"""
220+
Defines how the set 1 alpha values are computed.
221+
It allows to compute both the set 1 and its complement.
222+
223+
Parameters
224+
----------
225+
multiclosurefits_nsigma: pd.DataFrame
226+
The nsigma table.
227+
weighted_dataset: str
228+
The name of the weighted dataset.
229+
complement: bool, default=False
230+
Whether to compute the complement set 1 alpha values.
231+
232+
Returns
233+
-------
234+
NsigmaAlpha
235+
"""
236+
df = multiclosurefits_nsigma.nsigma_table
237+
nsigma_values = df[df.index == weighted_dataset].values.flatten()
238+
set1_alpha = {}
239+
for alpha in ALPHA_RANGE:
240+
z_alpha = norm.ppf(1 - alpha)
241+
if complement:
242+
fit_idxs = np.where(nsigma_values < z_alpha)[0]
243+
else:
244+
fit_idxs = np.where(nsigma_values > z_alpha)[0]
245+
set1_alpha[alpha] = df.columns[fit_idxs].tolist()
246+
247+
return NsigmaAlpha(alpha_dict=set1_alpha, is_weighted=multiclosurefits_nsigma.is_weighted)
248+
249+
250+
def nsigma_alpha(multiclosurefits_nsigma: pd.DataFrame, weighted_dataset: str) -> NsigmaAlpha:
251+
"""
252+
Computes the set 1 alpha values.
253+
"""
254+
return def_of_nsigma_alpha(multiclosurefits_nsigma, weighted_dataset, complement=False)
255+
256+
257+
"""
258+
Collect set 1 alpha over dataspecs.
259+
"""
260+
dataspecs_nsigma_alpha = collect("nsigma_alpha", ("dataspecs",))
261+
262+
263+
def comp_nsigma_alpha(multiclosurefits_nsigma: pd.DataFrame, weighted_dataset: str) -> NsigmaAlpha:
264+
"""
265+
Computes the complement set 1 alpha values.
266+
"""
267+
return def_of_nsigma_alpha(multiclosurefits_nsigma, weighted_dataset, complement=True)
268+
269+
270+
"""
271+
Collect complement set 1 alpha over dataspecs.
272+
"""
273+
dataspecs_comp_nsigma_alpha = collect("comp_nsigma_alpha", ("dataspecs",))
274+
275+
276+
def set_1_alpha(dataspecs_nsigma_alpha: list) -> dict:
277+
"""
278+
Returns the set 1 alpha values.
279+
280+
Parameters
281+
----------
282+
dataspecs_nsigma_alpha: list
283+
List of NsigmaAlpha dataclasses.
284+
285+
Returns
286+
-------
287+
dict
288+
"""
289+
for dataspec_nsigma in dataspecs_nsigma_alpha:
290+
if not dataspec_nsigma.is_weighted:
291+
return dataspec_nsigma.alpha_dict
292+
293+
294+
def set_3_alpha(dataspecs_nsigma_alpha: list) -> dict:
295+
"""
296+
Same as the set 1 alpha values, but for the weighted datasets.
297+
298+
Parameters
299+
----------
300+
dataspecs_nsigma_alpha: list
301+
List of NsigmaAlpha dataclasses.
302+
303+
Returns
304+
-------
305+
dict
306+
"""
307+
for dataspec_nsigma in dataspecs_nsigma_alpha:
308+
if dataspec_nsigma.is_weighted:
309+
return dataspec_nsigma.alpha_dict
310+
311+
312+
def comp_set_1_alpha(dataspecs_comp_nsigma_alpha: list) -> dict:
313+
"""
314+
Returns the complement set 1 alpha values.
315+
"""
316+
for dataspec_nsigma in dataspecs_comp_nsigma_alpha:
317+
if not dataspec_nsigma.is_weighted:
318+
return dataspec_nsigma.alpha_dict
319+
320+
321+
def comp_set_3_alpha(dataspecs_comp_nsigma_alpha: list) -> dict:
322+
"""
323+
Same as the complement set 1 alpha values, but for the weighted datasets.
324+
"""
325+
for dataspec_nsigma in dataspecs_comp_nsigma_alpha:
326+
if dataspec_nsigma.is_weighted:
327+
return dataspec_nsigma.alpha_dict
328+
329+
330+
def def_set_2(
331+
dataspecs_multiclosurefits_nsigma: list, weighted_dataset: str, complement: bool = False
332+
) -> dict:
333+
"""
334+
Defines how the set 2 alpha values are computed.
335+
It allows to compute both the set 2 and its complement.
336+
337+
Parameters
338+
----------
339+
dataspecs_multiclosurefits_nsigma: list
340+
List of MulticlosureNsigma dataclasses.
341+
weighted_dataset: str
342+
The name of the weighted dataset.
343+
complement: bool, default=False
344+
Whether to compute the complement set 2 alpha values.
345+
346+
Returns
347+
-------
348+
dict
349+
"""
350+
# Order the dataspecs so that the weighted dataset is the first one
351+
dataspecs_mct = []
352+
for mct_nsigma in dataspecs_multiclosurefits_nsigma:
353+
if mct_nsigma.is_weighted:
354+
dataspecs_mct.insert(0, mct_nsigma)
355+
else:
356+
dataspecs_mct.append(mct_nsigma)
357+
358+
df_weight = dataspecs_mct[0].nsigma_table
359+
df_weight = df_weight[df_weight.index != weighted_dataset]
360+
361+
df_ref = dataspecs_mct[1].nsigma_table
362+
df_ref = df_ref[df_ref.index != weighted_dataset]
363+
364+
# ensure that weighted and reference dfs have the columns in the same order
365+
# (needed to properly compare fits)
366+
df_ref = df_ref[df_weight.columns]
367+
368+
set2_alpha = {}
369+
370+
for alpha in ALPHA_RANGE:
371+
z_alpha = norm.ppf(1 - alpha)
372+
373+
if complement:
374+
columns_bools = np.any((df_weight - df_ref).values < z_alpha, axis=0)
375+
else:
376+
columns_bools = np.any((df_weight - df_ref).values > z_alpha, axis=0)
377+
378+
columns = df_weight.columns[columns_bools].to_list()
379+
380+
set2_alpha[alpha] = columns
381+
382+
return set2_alpha
383+
384+
385+
def set_2_alpha(dataspecs_multiclosurefits_nsigma: list, weighted_dataset: str) -> dict:
386+
"""
387+
Computes the set 2 alpha values.
388+
"""
389+
return def_set_2(dataspecs_multiclosurefits_nsigma, weighted_dataset, complement=False)
390+
391+
392+
def comp_set_2_alpha(dataspecs_multiclosurefits_nsigma: list, weighted_dataset: str) -> dict:
393+
"""
394+
Computes the complement set 2 alpha values.
395+
"""
396+
return def_set_2(dataspecs_multiclosurefits_nsigma, weighted_dataset, complement=True)

0 commit comments

Comments
 (0)