Skip to content

Commit 30b78d4

Browse files
Phmonskiguitargeek
authored andcommitted
[HS3] Add parameter errors in HS3 JSON import/export
This PR adds support for exporting and importing RooRealVar parameter errors through HS3 JSON. Parameter uncertainties are written to misc.minimization.parameter_stepwidths and restored during import after the default parameter snapshot is loaded. The export logic collects relevant parameters from ModelConfig POI/nuisance sets and from PDF/data fallback discovery, while excluding observables and data axes. The PR also adds regression coverage for parameter-error round-tripping, data-axis exclusion, and preserving imported errors together with default snapshot values. (cherry picked from commit 283b12f)
1 parent 60e002e commit 30b78d4

2 files changed

Lines changed: 255 additions & 5 deletions

File tree

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 119 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,114 @@ void importAttributes(RooAbsArg *arg, JSONNode const &node)
307307
}
308308
}
309309

310+
void addIfPresent(RooArgSet &out, RooArgSet const *args)
311+
{
312+
if (args) {
313+
out.add(*args, true);
314+
}
315+
}
316+
317+
void collectParameterStepWidthCandidatesFromModelConfigs(RooWorkspace const &workspace, RooArgSet &candidates,
318+
RooArgSet &excluded)
319+
{
320+
for (TObject *obj : workspace.allGenericObjects()) {
321+
auto const *mc = dynamic_cast<RooFit::ModelConfig const *>(obj);
322+
if (!mc) {
323+
continue;
324+
}
325+
326+
addIfPresent(candidates, mc->GetParametersOfInterest());
327+
addIfPresent(candidates, mc->GetNuisanceParameters());
328+
329+
addIfPresent(excluded, mc->GetObservables());
330+
addIfPresent(excluded, mc->GetGlobalObservables());
331+
addIfPresent(excluded, mc->GetConditionalObservables());
332+
}
333+
}
334+
335+
void collectParameterStepWidthCandidatesFromPdfs(std::vector<RooAbsPdf *> const &pdfs,
336+
std::vector<RooAbsData *> const &data, RooArgSet &candidates,
337+
RooArgSet &excluded)
338+
{
339+
for (RooAbsPdf const *pdf : pdfs) {
340+
RooArgSet observables;
341+
for (RooAbsData const *dataset : data) {
342+
std::unique_ptr<RooArgSet> pdfObs{pdf->getObservables(*dataset->get())};
343+
observables.add(*pdfObs, true);
344+
}
345+
346+
if (observables.empty()) {
347+
continue;
348+
}
349+
350+
RooArgSet params;
351+
pdf->getParameters(&observables, params);
352+
candidates.add(params, true);
353+
excluded.add(observables, true);
354+
}
355+
}
356+
357+
void exportParameterStepWidths(RooWorkspace const &workspace, std::vector<RooAbsPdf *> const &pdfs,
358+
std::vector<RooAbsData *> const &data, JSONNode &rootnode)
359+
{
360+
RooArgSet candidates;
361+
RooArgSet excluded;
362+
363+
collectParameterStepWidthCandidatesFromModelConfigs(workspace, candidates, excluded);
364+
collectParameterStepWidthCandidatesFromPdfs(pdfs, data, candidates, excluded);
365+
366+
candidates.sort();
367+
368+
JSONNode *parameterStepWidthsNode = nullptr;
369+
for (RooAbsArg *arg : candidates) {
370+
if (excluded.find(*arg)) {
371+
continue;
372+
}
373+
374+
auto *var = dynamic_cast<RooRealVar *>(arg);
375+
if (!var || !var->hasError()) {
376+
continue;
377+
}
378+
379+
if (!parameterStepWidthsNode) {
380+
parameterStepWidthsNode = &rootnode["misc"]["minimization"]["parameter_stepwidths"].set_seq();
381+
}
382+
383+
JSONNode &stepWidthNode = RooJSONFactoryWSTool::appendNamedChild(*parameterStepWidthsNode, var->GetName());
384+
stepWidthNode["step_width"] << var->getError();
385+
}
386+
}
387+
388+
void importParameterStepWidths(RooWorkspace &workspace, JSONNode const &rootnode)
389+
{
390+
auto const *parameterStepWidthsNode = rootnode.find("misc", "minimization", "parameter_stepwidths");
391+
if (!parameterStepWidthsNode) {
392+
return;
393+
}
394+
if (!parameterStepWidthsNode->is_seq()) {
395+
RooJSONFactoryWSTool::warning("RooFitHS3: misc.minimization.parameter_stepwidths is not a sequence, skipping.");
396+
return;
397+
}
398+
399+
for (JSONNode const &stepWidthNode : parameterStepWidthsNode->children()) {
400+
if (!stepWidthNode.is_map() || !stepWidthNode.has_child("name") || !stepWidthNode.has_child("step_width")) {
401+
RooJSONFactoryWSTool::warning("RooFitHS3: skipping malformed parameter_stepwidths entry.");
402+
continue;
403+
}
404+
405+
const std::string name = RooJSONFactoryWSTool::name(stepWidthNode);
406+
RooAbsArg *arg = workspace.arg(name);
407+
auto *var = dynamic_cast<RooRealVar *>(arg);
408+
if (!var) {
409+
RooJSONFactoryWSTool::warning(
410+
"RooFitHS3: skipping parameter_stepwidths entry for unknown or non-real variable '" + name + "'.");
411+
continue;
412+
}
413+
414+
var->setError(stepWidthNode.find("step_width")->val_double());
415+
}
416+
}
417+
310418
// RooWSFactoryTool expression handling
311419
std::string generate(const RooFit::JSONIO::ImportExpression &ex, const JSONNode &p, RooJSONFactoryWSTool *tool)
312420
{
@@ -594,7 +702,7 @@ void importAnalysis(const JSONNode &rootnode, const JSONNode &analysisNode, cons
594702
for (const auto &d : datasets) {
595703
if (d->GetName() == nameNode.val()) {
596704
found = true;
597-
observables.add(*d->get());
705+
observables.add(*d->get(), true);
598706
}
599707
}
600708
if (nameNode.val() != "0" && !found)
@@ -758,7 +866,7 @@ void combineDatasets(const JSONNode &rootnode, std::vector<std::unique_ptr<RooAb
758866
datasets.begin(), datasets.end(), [&](auto &d) { return d && d->GetName() == componentName; });
759867
if (!component)
760868
RooJSONFactoryWSTool::error("unable to obtain component matching component name '" + componentName + "'");
761-
allVars.add(*component->get());
869+
allVars.add(*component->get(), true);
762870
dsMap.insert({labels[iChannel], std::move(component)});
763871
indexCat.defineType(labels[iChannel], indices[iChannel]);
764872
}
@@ -1787,7 +1895,7 @@ void RooJSONFactoryWSTool::exportSingleModelConfig(JSONNode &rootnode, RooFit::M
17871895
nllNode["data"].set_seq();
17881896

17891897
if (dataComponents) {
1790-
auto simPdf = static_cast<RooSimultaneous const *>(pdf);
1898+
auto simPdf = dynamic_cast<RooSimultaneous const *>(pdf);
17911899
if (simPdf) {
17921900
for (auto const &item : simPdf->indexCat()) {
17931901
const auto &dataComp = dataComponents->find(item.first);
@@ -1924,6 +2032,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)
19242032
}
19252033
}
19262034

2035+
exportParameterStepWidths(_workspace, allpdfs, alldata, n);
2036+
19272037
for (auto *snsh : static_range_cast<RooArgSet const *>(_workspace.getSnapshots())) {
19282038
RooArgSet snapshotSorted;
19292039
// We only want to add the variables that actually got exported and skip
@@ -2276,10 +2386,12 @@ bool RooJSONFactoryWSTool::importJSON(std::istream &is)
22762386
{
22772387
// import a JSON file to the workspace
22782388
std::unique_ptr<JSONTree> tree = JSONTree::create(is);
2279-
this->importAllNodes(tree->rootnode());
2389+
JSONNode const &rootnode = tree->rootnode();
2390+
this->importAllNodes(rootnode);
22802391
if (this->workspace()->getSnapshot("default_values")) {
22812392
this->workspace()->loadSnapshot("default_values");
22822393
}
2394+
importParameterStepWidths(*this->workspace(), rootnode);
22832395
return true;
22842396
}
22852397

@@ -2312,7 +2424,9 @@ bool RooJSONFactoryWSTool::importYML(std::istream &is)
23122424
{
23132425
// import a YML file to the workspace
23142426
std::unique_ptr<JSONTree> tree = JSONTree::create(is);
2315-
this->importAllNodes(tree->rootnode());
2427+
JSONNode const &rootnode = tree->rootnode();
2428+
this->importAllNodes(rootnode);
2429+
importParameterStepWidths(*this->workspace(), rootnode);
23162430
return true;
23172431
}
23182432

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,20 @@ int validate(RooAbsArg const &arg, bool exact = true)
132132
return validate(ws, arg.GetName(), exact);
133133
}
134134

135+
std::string parameterStepWidthsNode(std::string const &json)
136+
{
137+
const std::string key = "\"parameter_stepwidths\":[";
138+
const auto begin = json.find(key);
139+
if (begin == std::string::npos) {
140+
return "";
141+
}
142+
const auto end = json.find("]", begin);
143+
if (end == std::string::npos) {
144+
return "";
145+
}
146+
return json.substr(begin, end - begin + 1);
147+
}
148+
135149
} // namespace
136150

137151
// Test that the IO of attributes and string attributes works.
@@ -165,6 +179,128 @@ TEST(RooFitHS3, AttributesIO)
165179
EXPECT_STREQ(pdf.getStringAttribute("key1"), nullptr) << "unexpected string attribute found!";
166180
}
167181

182+
TEST(RooFitHS3, ParameterStepWidthsModelConfigRoundTrip)
183+
{
184+
RooWorkspace ws1{"workspace"};
185+
ws1.factory("Gaussian::sig(x[-5, 5], mu[0, -10, 10], sigma[1, 0.1, 10])");
186+
ws1.factory("Polynomial::bkg(x, {theta[0, -1, 1]})");
187+
ws1.factory("SUM::model(fsig[0.5, 0, 1] * sig, bkg)");
188+
189+
RooRealVar &x = *ws1.var("x");
190+
RooDataSet data{"data", "data", RooArgSet{x}};
191+
for (double val : {-1.0, 0.5, 1.5}) {
192+
x.setVal(val);
193+
data.add(RooArgSet{x});
194+
}
195+
ws1.import(data);
196+
197+
ws1.var("x")->setError(9.0);
198+
ws1.var("mu")->setError(0.12);
199+
ws1.var("theta")->setError(0.33);
200+
ws1.var("sigma")->setError(0.20);
201+
ws1.var("sigma")->setAsymError(-0.18, 0.25);
202+
203+
RooFit::ModelConfig mc{"mc", &ws1};
204+
mc.SetPdf(*ws1.pdf("model"));
205+
mc.SetObservables("x");
206+
mc.SetParametersOfInterest("mu");
207+
mc.SetNuisanceParameters("sigma");
208+
ws1.import(mc);
209+
210+
const std::string json = RooJSONFactoryWSTool{ws1}.exportJSONtoString();
211+
const std::string parameterStepWidths = parameterStepWidthsNode(json);
212+
ASSERT_FALSE(parameterStepWidths.empty()) << json;
213+
EXPECT_NE(parameterStepWidths.find("\"name\":\"mu\""), std::string::npos) << parameterStepWidths;
214+
EXPECT_NE(parameterStepWidths.find("\"name\":\"sigma\""), std::string::npos) << parameterStepWidths;
215+
EXPECT_NE(parameterStepWidths.find("\"name\":\"theta\""), std::string::npos) << parameterStepWidths;
216+
EXPECT_NE(parameterStepWidths.find("\"step_width\":0.12"), std::string::npos) << parameterStepWidths;
217+
EXPECT_NE(parameterStepWidths.find("\"step_width\":0.2"), std::string::npos) << parameterStepWidths;
218+
EXPECT_EQ(parameterStepWidths.find("\"error_lo\""), std::string::npos) << parameterStepWidths;
219+
EXPECT_EQ(parameterStepWidths.find("\"error_hi\""), std::string::npos) << parameterStepWidths;
220+
EXPECT_EQ(parameterStepWidths.find("\"name\":\"x\""), std::string::npos) << parameterStepWidths;
221+
222+
RooWorkspace ws2{"workspace2"};
223+
ASSERT_TRUE(RooJSONFactoryWSTool{ws2}.importJSONfromString(json));
224+
225+
ASSERT_NE(ws2.var("mu"), nullptr);
226+
ASSERT_NE(ws2.var("theta"), nullptr);
227+
ASSERT_NE(ws2.var("sigma"), nullptr);
228+
ASSERT_NE(ws2.var("x"), nullptr);
229+
EXPECT_TRUE(ws2.var("mu")->hasError());
230+
EXPECT_DOUBLE_EQ(ws2.var("mu")->getError(), 0.12);
231+
EXPECT_TRUE(ws2.var("theta")->hasError());
232+
EXPECT_DOUBLE_EQ(ws2.var("theta")->getError(), 0.33);
233+
EXPECT_TRUE(ws2.var("sigma")->hasError());
234+
EXPECT_DOUBLE_EQ(ws2.var("sigma")->getError(), 0.20);
235+
EXPECT_FALSE(ws2.var("sigma")->hasAsymError());
236+
EXPECT_FALSE(ws2.var("x")->hasError());
237+
}
238+
239+
TEST(RooFitHS3, ParameterStepWidthsFallbackExcludesDataAxes)
240+
{
241+
RooWorkspace ws1{"workspace"};
242+
ws1.factory("Gaussian::model(x[-5, 5], mu[0, -10, 10], sigma[1, 0.1, 10])");
243+
244+
RooRealVar &x = *ws1.var("x");
245+
RooDataSet data{"data", "data", RooArgSet{x}};
246+
for (double val : {-1.0, 0.5, 1.5}) {
247+
x.setVal(val);
248+
data.add(RooArgSet{x});
249+
}
250+
ws1.import(data);
251+
252+
ws1.var("x")->setError(9.0);
253+
ws1.var("mu")->setError(0.12);
254+
ws1.var("sigma")->setError(0.20);
255+
256+
const std::string json = RooJSONFactoryWSTool{ws1}.exportJSONtoString();
257+
const std::string parameterStepWidths = parameterStepWidthsNode(json);
258+
ASSERT_FALSE(parameterStepWidths.empty()) << json;
259+
EXPECT_NE(parameterStepWidths.find("\"name\":\"mu\""), std::string::npos) << parameterStepWidths;
260+
EXPECT_NE(parameterStepWidths.find("\"name\":\"sigma\""), std::string::npos) << parameterStepWidths;
261+
EXPECT_EQ(parameterStepWidths.find("\"name\":\"x\""), std::string::npos) << parameterStepWidths;
262+
263+
RooWorkspace ws2{"workspace2"};
264+
ASSERT_TRUE(RooJSONFactoryWSTool{ws2}.importJSONfromString(json));
265+
266+
ASSERT_NE(ws2.var("mu"), nullptr);
267+
ASSERT_NE(ws2.var("sigma"), nullptr);
268+
ASSERT_NE(ws2.var("x"), nullptr);
269+
EXPECT_DOUBLE_EQ(ws2.var("mu")->getError(), 0.12);
270+
EXPECT_DOUBLE_EQ(ws2.var("sigma")->getError(), 0.20);
271+
EXPECT_FALSE(ws2.var("x")->hasError());
272+
}
273+
274+
TEST(RooFitHS3, ParameterStepWidthsImportAfterDefaultSnapshot)
275+
{
276+
const std::string json = R"({
277+
"metadata": {"hs3_version": "0.1.90"},
278+
"parameter_points": [
279+
{
280+
"name": "default_values",
281+
"parameters": [
282+
{"name": "mu", "value": 0.0, "err": 0.01}
283+
]
284+
}
285+
],
286+
"misc": {
287+
"minimization": {
288+
"parameter_stepwidths": [
289+
{"name": "mu", "step_width": 0.42},
290+
{"name": "missing", "step_width": 1.0}
291+
]
292+
}
293+
}
294+
})";
295+
296+
RooWorkspace ws{"workspace"};
297+
ASSERT_TRUE(RooJSONFactoryWSTool{ws}.importJSONfromString(json));
298+
299+
ASSERT_NE(ws.var("mu"), nullptr);
300+
EXPECT_TRUE(ws.var("mu")->hasError());
301+
EXPECT_DOUBLE_EQ(ws.var("mu")->getError(), 0.42);
302+
}
303+
168304
TEST(RooFitHS3, RooAddPdf)
169305
{
170306
int status = validate({"Gaussian::sig(x[5.20, 5.30], sigmean[5.28, 5.20, 5.30], sigwidth[0.0027, 0.001, 1.])",

0 commit comments

Comments
 (0)