Skip to content

Commit 917daa1

Browse files
committed
Fix in diagonal basis + update tests
1 parent ca1f704 commit 917daa1

3 files changed

Lines changed: 84 additions & 34 deletions

File tree

validphys2/src/validphys/pseudodata.py

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,49 @@ class ReplicaGenerationError(Exception):
3131
pass
3232

3333

34+
def fit_diagonal_basis_rotation(fit):
    """Return the rotation matrix taking pseudodata to the diagonal basis.

    The matrix is read from the fitting covariance-matrix table stored under
    ``fit.path / "tables"``. Which table is read depends on whether the
    runcard enables ``theorycovmatconfig: use_thcovmat_in_fitting``.
    According to the original author this is the same eigensystem table that
    ``_inv_covmat_prepared`` loads when running the fit, so the rotation
    matches the one used at generation time.

    Parameters
    ----------
    fit
        Object exposing ``as_input()`` (the runcard as a mapping) and
        ``path`` (the fit folder).

    Returns
    -------
    numpy.ndarray or None
        Every data column of the table except the first one, or ``None`` if
        the runcard sets ``diagonal_basis`` to a false value.
    """
    runcard = fit.as_input()
    if not runcard.get("diagonal_basis", True):
        return None

    # ``theorycovmatconfig`` may be missing or explicitly null in the runcard;
    # both cases mean that no theory covmat settings were given.
    thcovmat_config = runcard.get("theorycovmatconfig") or {}
    if thcovmat_config.get("use_thcovmat_in_fitting", False):
        fname = "datacuts_theory_theorycovmatconfig_fitting_covmat_table.csv"
    else:
        fname = "datacuts_theory_fitting_covmat_table.csv"

    eigensystem = pd.read_csv(
        fit.path / "tables" / fname, index_col=[0], header=[0], sep="\t|,", engine="python"
    )
    # The first data column is dropped (presumably the eigenvalues -- TODO confirm);
    # the remaining columns form the rotation matrix.
    return eigensystem.iloc[:, 1:].to_numpy()
55+
56+
57+
def diagonal_indexed_recreate_pseudodata(indexed_make_replica, fit_diagonal_basis_rotation):
    """Rotate recreated pseudodata into the diagonal basis when one is given.

    Recreation-time counterpart of
    :py:func:`validphys.n3fit_data.diagonal_indexed_make_replica` that takes
    the rotation directly as an argument instead of going through
    :py:func:`_inv_covmat_prepared` (via :py:func:`fitting_data_dict`), which
    requires `output_path` to be available.

    If ``fit_diagonal_basis_rotation`` is ``None`` the input
    ``indexed_make_replica`` is returned as is, with its original index.
    Otherwise the first column of ``indexed_make_replica`` is multiplied by
    the rotation and returned as a single-column (``"data"``) dataframe
    indexed by ``"eigenmode <i>"`` labels.
    """
    if fit_diagonal_basis_rotation is None:
        return indexed_make_replica

    original_basis = indexed_make_replica.iloc[:, 0].to_numpy()
    rotated = fit_diagonal_basis_rotation @ original_basis
    eigenmode_labels = pd.Index([f"eigenmode {i}" for i, _ in enumerate(rotated)])
    return pd.DataFrame(rotated, index=eigenmode_labels, columns=["data"])
75+
76+
3477
def read_replica_pseudodata(fit, context_index, replica):
3578
"""Function to handle the reading of training and validation splits for a fit that has been
3679
produced with the ``savepseudodata`` flag set to ``True``.
@@ -70,10 +113,8 @@ def read_replica_pseudodata(fit, context_index, replica):
70113
5 3.117819
71114
6 0.771079
72115
"""
73-
# List of length 1 due to the collect
74-
context_index = context_index[0]
75-
# The [0] is because of how pandas handles sorting a MultiIndex
76-
sorted_index = context_index.sortlevel(level=range(1, 3))[0]
116+
# Detect whether the fit was performed in the diagonal basis
117+
diagonal_basis = fit.as_input().get("diagonal_basis", True)
77118

78119
log.debug(f"Reading pseudodata & training/validation splits from {fit.name}.")
79120
replica_path = fit.path / "nnfit" / f"replica_{replica}"
@@ -87,32 +128,37 @@ def read_replica_pseudodata(fit, context_index, replica):
87128
tr_pseudodatafile = "datacuts_theory_fitting_training_pseudodata.csv"
88129
vl_pseudodatafile = "datacuts_theory_fitting_validation_pseudodata.csv"
89130

131+
index_col = [0] if diagonal_basis else [0, 1, 2]
132+
90133
try:
91-
tr = pd.read_csv(replica_path / tr_pseudodatafile, index_col=[0, 1, 2], sep="\t", header=0)
92-
val = pd.read_csv(replica_path / vl_pseudodatafile, index_col=[0, 1, 2], sep="\t", header=0)
134+
tr = pd.read_csv(replica_path / tr_pseudodatafile, index_col=index_col, sep="\t", header=0)
135+
val = pd.read_csv(replica_path / vl_pseudodatafile, index_col=index_col, sep="\t", header=0)
93136
except FileNotFoundError as e:
94137
raise FileNotFoundError(
95138
"Could not find saved training and validation data files. "
96139
f"Please ensure {fit} was generated with the savepseudodata flag set to true"
97140
) from e
98141

99142
tr["type"], val["type"] = "training", "validation"
100-
101143
pseudodata = pd.concat((tr, val))
102144

103-
# In order for this function to work also with old fit, it is necessary to remap the names
104-
# being read (since the names in the context have already been remapped)
105-
# The following checks whether a given name is in both the context and the fit, and if not
106-
# tries to get it from the old_to_new mapping.
107-
mapping = {}
108-
context_datasets = context_index.get_level_values("dataset").unique()
109-
for dsname in pseudodata.index.get_level_values("dataset").unique():
110-
if dsname not in context_datasets:
111-
new_name, _ = legacy_to_new_map(dsname)
112-
mapping[dsname] = new_name
145+
if not diagonal_basis:
146+
# For this function to also work with old fits, the dataset names being read must be remapped
147+
# (the names in the context have already been remapped).
148+
# The following checks whether each name read from the fit is also in the context and, if not,
149+
# gets the new name from the old_to_new mapping.
150+
ctx = context_index[0]
151+
sorted_index = ctx.sortlevel(level=range(1, 3))[0]
113152

114-
pseudodata = pseudodata.rename(mapping, level=1).sort_index(level=range(1, 3))
115-
pseudodata.index = sorted_index
153+
mapping = {}
154+
context_datasets = ctx.get_level_values("dataset").unique()
155+
for dsname in pseudodata.index.get_level_values("dataset").unique():
156+
if dsname not in context_datasets:
157+
new_name, _ = legacy_to_new_map(dsname)
158+
mapping[dsname] = new_name
159+
160+
pseudodata = pseudodata.rename(mapping, level=1).sort_index(level=range(1, 3))
161+
pseudodata.index = sorted_index
116162

117163
tr = pseudodata[pseudodata["type"] == "training"]
118164
val = pseudodata[pseudodata["type"] == "validation"]
@@ -462,7 +508,9 @@ def make_level1_data(data, level0_commondata_wc, filterseed, data_index, sep_mul
462508
return level1_commondata_instances_wc
463509

464510

465-
_group_recreate_pseudodata = collect('indexed_make_replica', ('group_dataset_inputs_by_metadata',))
511+
_group_recreate_pseudodata = collect(
512+
'diagonal_indexed_recreate_pseudodata', ('group_dataset_inputs_by_metadata',)
513+
)
466514
_recreate_fit_pseudodata = collect('_group_recreate_pseudodata', ('fitreplicas', 'fitenvironment'))
467515
_recreate_pdf_pseudodata = collect('_group_recreate_pseudodata', ('pdfreplicas', 'fitenvironment'))
468516

validphys2/src/validphys/tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ def tmp(tmpdir):
7474
FIT_3REPLICAS_DCUTS = "FIT_3REPLICAS_250616_diffcuts"
7575
FIT = "NNPDF40_nnlo_like_CI_testing_250616"
7676
FIT_ITERATED = "NNPDF40_nnlo_like_CI_testing_250616_iterated"
77-
PSEUDODATA_FIT = "pseudodata_test_fit_n3fit_250616"
78-
PSEUDODATA_FIT_DIAG = "pseudodata_test_fit_n3fit_251104"
77+
PSEUDODATA_FIT = "pseudodata_test_fit_n3fit_260518"
78+
PSEUDODATA_FIT_DIAG = "pseudodata_test_fit_diag_n3fit_260518"
7979
# These fits contain _only_ data
8080
MULTICLOSURE_FITS = ["250618-test-multiclosure-001", "250618-test-multiclosure-002"]
8181

validphys2/src/validphys/tests/test_pseudodata.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,20 @@ def test_no_savepseudodata():
8181
func(fit=FIT)
8282

8383

84-
def test_read_matches_recreate():
85-
86-
for fit in [PSEUDODATA_FIT, PSEUDODATA_FIT_DIAG]:
87-
diagonal_basis = True if fit == PSEUDODATA_FIT_DIAG else False
88-
reads = API.read_fit_pseudodata(fit=fit, diagonal_basis=diagonal_basis)
89-
recreates = API.recreate_fit_pseudodata(fit=fit, diagonal_basis=diagonal_basis)
90-
for read, recreate in zip(reads, recreates):
91-
# We ignore the absolute ordering of the dataframes and just check
92-
# that they contain identical elements.
93-
pd.testing.assert_frame_equal(read.pseudodata, recreate.pseudodata, check_like=True)
94-
pd.testing.assert_index_equal(read.tr_idx, recreate.tr_idx, check_order=False)
95-
pd.testing.assert_index_equal(read.val_idx, recreate.val_idx, check_order=False)
84+
@pytest.mark.parametrize(
    "fit, diagonal_basis",
    [(PSEUDODATA_FIT, False), (PSEUDODATA_FIT_DIAG, True)],
    ids=["standard", "diagonal"],
)
def test_read_matches_recreate(fit, diagonal_basis):
    """Check that pseudodata read from a fit equals the recreated pseudodata.

    Runs once for the standard-basis fit and once for the diagonal-basis fit,
    comparing the pseudodata and the training/validation indices replica by
    replica.
    """
    reads = list(API.read_fit_pseudodata(fit=fit, diagonal_basis=diagonal_basis))
    recreates = list(API.recreate_fit_pseudodata(fit=fit, diagonal_basis=diagonal_basis))
    # zip() stops at the shorter input, so without this check the test would
    # pass vacuously if one of the two returned fewer (or no) replicas.
    assert len(reads) == len(recreates)
    for read, recreate in zip(reads, recreates):
        # We ignore the absolute ordering of the dataframes and just check
        # that they contain identical elements.
        pd.testing.assert_frame_equal(read.pseudodata, recreate.pseudodata, check_like=True)
        pd.testing.assert_index_equal(read.tr_idx, recreate.tr_idx, check_order=False)
        pd.testing.assert_index_equal(read.val_idx, recreate.val_idx, check_order=False)
9698

9799

98100
def test_level0_commondata_wc():

0 commit comments

Comments
 (0)