@@ -31,6 +31,49 @@ class ReplicaGenerationError(Exception):
3131 pass
3232
3333
34+ def fit_diagonal_basis_rotation (fit ):
35+ """Rotation matrix taking pseudodata from the original to the diagonal basis,
36+ or ``None`` if ``fit`` was not run in diagonal basis.
37+
38+ Sources the matrix from the same eigensystem table that
39+ ``_inv_covmat_prepared`` loads when running the fit, so the rotation
40+ applied here is bit-identical to the one used at generation time.
41+ """
42+ runcard = fit .as_input ()
43+ if not runcard .get ("diagonal_basis" , True ):
44+ return None
45+ use_thcovmat = runcard .get ("theorycovmatconfig" , {}).get ("use_thcovmat_in_fitting" , False )
46+ fname = (
47+ "datacuts_theory_theorycovmatconfig_fitting_covmat_table.csv"
48+ if use_thcovmat
49+ else "datacuts_theory_fitting_covmat_table.csv"
50+ )
51+ eigensystem = pd .read_csv (
52+ fit .path / "tables" / fname , index_col = [0 ], header = [0 ], sep = "\t |," , engine = "python"
53+ )
54+ return eigensystem .iloc [:, 1 :].values
55+
56+
57+ def diagonal_indexed_recreate_pseudodata (indexed_make_replica , fit_diagonal_basis_rotation ):
58+ """Recreation-time analogue of
59+ :py:func:`validphys.n3fit_data.diagonal_indexed_make_replica`, but doesn't
60+ need :py:func:`_inv_covmat_prepared` (through :py:func:`fitting_data_dict`)
61+ which requires `output_path` to be available.
62+
63+ Returns the pseudodata in the diagonal basis (eigenmode-indexed) when the
64+ fit was run in diagonal basis, otherwise returns ``indexed_make_replica``
65+ untouched (preserving the original ``(group, dataset, id)`` MultiIndex).
66+ """
67+ diag_rot = fit_diagonal_basis_rotation
68+ if diag_rot is None :
69+ return indexed_make_replica
70+ values = indexed_make_replica .iloc [:, 0 ].to_numpy ()
71+ rotated = diag_rot @ values
72+ return pd .DataFrame (
73+ rotated , index = pd .Index ([f"eigenmode { i } " for i in range (len (rotated ))]), columns = ["data" ]
74+ )
75+
76+
3477def read_replica_pseudodata (fit , context_index , replica ):
3578 """Function to handle the reading of training and validation splits for a fit that has been
3679 produced with the ``savepseudodata`` flag set to ``True``.
@@ -70,10 +113,8 @@ def read_replica_pseudodata(fit, context_index, replica):
70113 5 3.117819
71114 6 0.771079
72115 """
73- # List of length 1 due to the collect
74- context_index = context_index [0 ]
75- # The [0] is because of how pandas handles sorting a MultiIndex
76- sorted_index = context_index .sortlevel (level = range (1 , 3 ))[0 ]
116+ # Detect whether fit performed in diagonal basis
117+ diagonal_basis = fit .as_input ().get ("diagonal_basis" , True )
77118
78119 log .debug (f"Reading pseudodata & training/validation splits from { fit .name } ." )
79120 replica_path = fit .path / "nnfit" / f"replica_{ replica } "
@@ -87,32 +128,37 @@ def read_replica_pseudodata(fit, context_index, replica):
87128 tr_pseudodatafile = "datacuts_theory_fitting_training_pseudodata.csv"
88129 vl_pseudodatafile = "datacuts_theory_fitting_validation_pseudodata.csv"
89130
131+ index_col = [0 ] if diagonal_basis else [0 , 1 , 2 ]
132+
90133 try :
91- tr = pd .read_csv (replica_path / tr_pseudodatafile , index_col = [ 0 , 1 , 2 ] , sep = "\t " , header = 0 )
92- val = pd .read_csv (replica_path / vl_pseudodatafile , index_col = [ 0 , 1 , 2 ] , sep = "\t " , header = 0 )
134+ tr = pd .read_csv (replica_path / tr_pseudodatafile , index_col = index_col , sep = "\t " , header = 0 )
135+ val = pd .read_csv (replica_path / vl_pseudodatafile , index_col = index_col , sep = "\t " , header = 0 )
93136 except FileNotFoundError as e :
94137 raise FileNotFoundError (
95138 "Could not find saved training and validation data files. "
96139 f"Please ensure { fit } was generated with the savepseudodata flag set to true"
97140 ) from e
98141
99142 tr ["type" ], val ["type" ] = "training" , "validation"
100-
101143 pseudodata = pd .concat ((tr , val ))
102144
103- # In order for this function to work also with old fit, it is necessary to remap the names
104- # being read (since the names in the context have already been remapped)
105- # The following checks whether a given name is in both the context and the fit, and if not
106- # tries to get it from the old_to_new mapping.
107- mapping = {}
108- context_datasets = context_index .get_level_values ("dataset" ).unique ()
109- for dsname in pseudodata .index .get_level_values ("dataset" ).unique ():
110- if dsname not in context_datasets :
111- new_name , _ = legacy_to_new_map (dsname )
112- mapping [dsname ] = new_name
145+ if not diagonal_basis :
146+ # In order for this function to work also with old fit, it is necessary to remap the names
147+ # being read (since the names in the context have already been remapped)
148+ # The following checks whether a given name is in both the context and the fit, and if not
149+ # tries to get it from the old_to_new mapping.
150+ ctx = context_index [0 ]
151+ sorted_index = ctx .sortlevel (level = range (1 , 3 ))[0 ]
113152
114- pseudodata = pseudodata .rename (mapping , level = 1 ).sort_index (level = range (1 , 3 ))
115- pseudodata .index = sorted_index
153+ mapping = {}
154+ context_datasets = ctx .get_level_values ("dataset" ).unique ()
155+ for dsname in pseudodata .index .get_level_values ("dataset" ).unique ():
156+ if dsname not in context_datasets :
157+ new_name , _ = legacy_to_new_map (dsname )
158+ mapping [dsname ] = new_name
159+
160+ pseudodata = pseudodata .rename (mapping , level = 1 ).sort_index (level = range (1 , 3 ))
161+ pseudodata .index = sorted_index
116162
117163 tr = pseudodata [pseudodata ["type" ] == "training" ]
118164 val = pseudodata [pseudodata ["type" ] == "validation" ]
@@ -462,7 +508,9 @@ def make_level1_data(data, level0_commondata_wc, filterseed, data_index, sep_mul
462508 return level1_commondata_instances_wc
463509
464510
465- _group_recreate_pseudodata = collect ('indexed_make_replica' , ('group_dataset_inputs_by_metadata' ,))
511+ _group_recreate_pseudodata = collect (
512+ 'diagonal_indexed_recreate_pseudodata' , ('group_dataset_inputs_by_metadata' ,)
513+ )
466514_recreate_fit_pseudodata = collect ('_group_recreate_pseudodata' , ('fitreplicas' , 'fitenvironment' ))
467515_recreate_pdf_pseudodata = collect ('_group_recreate_pseudodata' , ('pdfreplicas' , 'fitenvironment' ))
468516
0 commit comments