Skip to content

Commit ae92c2e

Browse files
committed
[HS3] Fix HistFactory export and RooConstVar roundtripping
1 parent 9916067 commit ae92c2e

4 files changed

Lines changed: 184 additions & 18 deletions

File tree

roofit/hs3/src/Domains.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ void Domains::ProductDomain::populate(RooWorkspace &ws) const
162162
{
163163
for (auto const &item : _map) {
164164
const auto &name = item.first;
165-
if (!ws.var(name)) {
165+
if (!ws.arg(name)) {
166166
const auto &elem = item.second;
167167
const double vMin = elem.hasMin ? elem.min : -RooNumber::infinity();
168168
const double vMax = elem.hasMax ? elem.max : RooNumber::infinity();

roofit/hs3/src/JSONFactories_HistFactory.cxx

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa
283283
"'");
284284
}
285285
}
286+
double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma)
287+
{
288+
auto const *mean = dynamic_cast<RooProduct const *>(&constraint.getMean());
289+
if (!mean) {
290+
RooJSONFactoryWSTool::error("Poisson gamma constraint mean is not a RooProduct: " +
291+
std::string(constraint.GetName()));
292+
}
293+
294+
for (RooAbsArg *arg : mean->servers()) {
295+
if (arg == &gamma) {
296+
continue;
297+
}
298+
299+
if (auto const *tau = dynamic_cast<RooConstVar const *>(arg)) {
300+
return tau->getVal();
301+
}
302+
303+
// Imported workspaces can sometimes represent
304+
// constants as constant RooRealVars.
305+
if (auto const *real = dynamic_cast<RooAbsReal const *>(arg)) {
306+
if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) {
307+
return real->getVal();
308+
}
309+
}
310+
}
311+
312+
RooJSONFactoryWSTool::error("Could not find tau component in Poisson gamma constraint mean: " +
313+
std::string(constraint.GetName()));
314+
return std::numeric_limits<double>::quiet_NaN();
315+
}
286316

287317
bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist,
288318
RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p,
@@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
334364
// this is dealt with at a different place, ignore it for now
335365
} else if (modtype == "normfactor") {
336366
RooRealVar &constrParam = getOrCreate<RooRealVar>(ws, sysname, 1., -3, 5);
367+
constrParam.setError(0.0);
337368
normElems.add(constrParam);
338369
if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) {
339370
// for norm factors, constraints are optional
@@ -1060,7 +1091,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10601091
if (constraint) {
10611092
sample.barlowBeestonLightConstraintType = constraint->IsA();
10621093
if (RooPoisson *constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1063-
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
1094+
double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g));
10641095
channel.rel_errors[idx] = erel;
10651096
} else if (RooGaussian *constraint_g = dynamic_cast<RooGaussian *>(constraint)) {
10661097
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
@@ -1094,7 +1125,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10941125
if (!constraint) {
10951126
sys.constraints.push_back(0.0);
10961127
} else if (auto constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1097-
sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal()));
1128+
sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g)));
10981129
if (!sys.constraint) {
10991130
sys.constraintType = RooPoisson::Class();
11001131
}

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,41 @@ JSONNode const *findRooFitInternal(JSONNode const &node, Keys_t const &...keys)
479479
return node.find("misc", "ROOT_internal", keys...);
480480
}
481481

482+
bool isMarkedConstVar(JSONNode const *attributesNode, std::string const &name)
483+
{
484+
if (!attributesNode) {
485+
return false;
486+
}
487+
if (auto *attrNode = attributesNode->find(name)) {
488+
return attrNode->has_child("is_const_var") && (*attrNode)["is_const_var"].val_int() == 1;
489+
}
490+
return false;
491+
}
492+
493+
void importMarkedConstVars(RooJSONFactoryWSTool &tool, JSONNode const &rootNode, JSONNode const *attributesNode)
494+
{
495+
if (!attributesNode) {
496+
return;
497+
}
498+
499+
JSONNode const *varsNode = getVariablesNode(rootNode);
500+
if (!varsNode) {
501+
return;
502+
}
503+
504+
RooWorkspace *workspace = tool.workspace();
505+
for (const auto &varNode : varsNode->children()) {
506+
std::string name = RooJSONFactoryWSTool::name(varNode);
507+
if (!isMarkedConstVar(attributesNode, name) || workspace->arg(name)) {
508+
continue;
509+
}
510+
if (!varNode.has_child("value")) {
511+
RooJSONFactoryWSTool::error("cannot instantiate RooConstVar '" + name + "' without \"value\"!");
512+
}
513+
tool.wsEmplace<RooConstVar>(name, varNode["value"].val_double());
514+
}
515+
}
516+
482517
/**
483518
* @brief Check if a RooAbsArg is a literal constant variable.
484519
*
@@ -519,10 +554,9 @@ void exportAttributes(const RooAbsArg *arg, JSONNode &rootnode)
519554
node = &RooJSONFactoryWSTool::getRooFitInternal(rootnode, "attributes").set_map()[arg->GetName()].set_map();
520555
};
521556

522-
// RooConstVars are not a thing in HS3, and also for RooFit they are not
523-
// that important: they are just constants. So we don't need to remember
524-
// any information about them.
525557
if (dynamic_cast<RooConstVar const *>(arg)) {
558+
initializeNode();
559+
(*node)["is_const_var"] << 1;
526560
return;
527561
}
528562

@@ -1794,22 +1828,20 @@ void RooJSONFactoryWSTool::importVariable(const JSONNode &p)
17941828
std::string name(RooJSONFactoryWSTool::name(p));
17951829
RooJSONFactoryWSTool::testValidName(name, true);
17961830

1797-
if (_workspace.var(name))
1831+
if (_workspace.var(name) || (isMarkedConstVar(_attributesNode, name) && _workspace.arg(name)))
17981832
return;
17991833
if (!p.is_map()) {
18001834
std::stringstream ss;
18011835
ss << "RooJSONFactoryWSTool() node '" << name << "' is not a map, skipping.";
18021836
oocoutE(nullptr, InputArguments) << ss.str() << std::endl;
18031837
return;
18041838
}
1805-
if (_attributesNode) {
1806-
if (auto *attrNode = _attributesNode->find(name)) {
1807-
// We should not create RooRealVar objects for RooConstVars!
1808-
if (attrNode->has_child("is_const_var") && (*attrNode)["is_const_var"].val_int() == 1) {
1809-
wsEmplace<RooConstVar>(name, p["value"].val_double());
1810-
return;
1811-
}
1839+
if (isMarkedConstVar(_attributesNode, name)) {
1840+
if (!p.has_child("value")) {
1841+
RooJSONFactoryWSTool::error("cannot instantiate RooConstVar '" + name + "' without \"value\"!");
18121842
}
1843+
wsEmplace<RooConstVar>(name, p["value"].val_double());
1844+
return;
18131845
}
18141846
configureVariable(*_domains, p, wsEmplace<RooRealVar>(name, 1.));
18151847
}
@@ -2273,16 +2305,17 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
22732305
error(ss.str());
22742306
}
22752307

2308+
_rootnodeInput = &n;
2309+
2310+
_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");
2311+
22762312
_domains = std::make_unique<RooFit::JSONIO::Detail::Domains>();
22772313
if (auto domains = n.find("domains")) {
22782314
_domains->readJSON(*domains);
22792315
}
2316+
importMarkedConstVars(*this, n, _attributesNode);
22802317
_domains->populate(_workspace);
22812318

2282-
_rootnodeInput = &n;
2283-
2284-
_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");
2285-
22862319
// Build name-keyed indices over the "functions" and "distributions"
22872320
// sequences. Without these, every cross-reference resolved during import
22882321
// (e.g. dependencies of a PiecewiseInterpolation, or factory-expression

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,108 @@ TEST(RooFitHS3, RooGaussian)
395395
EXPECT_EQ(status, 0);
396396
}
397397

398+
TEST(RooFitHS3, RooGaussianConstVarSigmaExport)
399+
{
400+
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
401+
RooRealVar mean{"mean", "mean", 0.0};
402+
mean.setConstant(true);
403+
404+
RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0};
405+
RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst};
406+
407+
RooGaussian gaussLiteral{"gauss_literal", "gauss_literal", x, mean, RooFit::RooConst(2.0)};
408+
409+
RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0};
410+
sigmaReal.setConstant(true);
411+
RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal};
412+
413+
RooWorkspace ws;
414+
ws.import(gaussConst, RooFit::Silence());
415+
ws.import(gaussLiteral, RooFit::RecycleConflictNodes(), RooFit::Silence());
416+
ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence());
417+
418+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
419+
420+
EXPECT_NE(json.find("\"sigma\":\"sigma_const\""), std::string::npos);
421+
EXPECT_NE(json.find("\"name\":\"sigma_const\""), std::string::npos);
422+
EXPECT_NE(json.find("\"is_const_var\":1"), std::string::npos);
423+
EXPECT_EQ(json.find("\"sigma\":1.0"), std::string::npos);
424+
EXPECT_NE(json.find("\"sigma\":2.0"), std::string::npos);
425+
426+
EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos);
427+
EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos);
428+
429+
RooWorkspace imported;
430+
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));
431+
EXPECT_NE(dynamic_cast<RooConstVar *>(imported.obj("sigma_const")), nullptr);
432+
EXPECT_EQ(imported.var("sigma_const"), nullptr);
433+
EXPECT_NE(dynamic_cast<RooRealVar *>(imported.obj("sigma_real")), nullptr);
434+
435+
const std::string roundTripJson = RooJSONFactoryWSTool{imported}.exportJSONtoString();
436+
EXPECT_NE(roundTripJson.find("\"sigma\":\"sigma_const\""), std::string::npos);
437+
EXPECT_NE(roundTripJson.find("\"is_const_var\":1"), std::string::npos);
438+
439+
const std::string legacyJson = R"({
440+
"metadata":{"hs3_version":"0.2"},
441+
"parameter_points":[{"name":"default_values","parameters":[
442+
{"name":"x","value":0.0},
443+
{"name":"mean","value":0.0},
444+
{"name":"sigma_const","value":1.0,"const":true}
445+
]}],
446+
"distributions":[{"name":"gauss","type":"gaussian_dist","x":"x","mean":"mean","sigma":"sigma_const"}]
447+
})";
448+
RooWorkspace legacyImport;
449+
ASSERT_TRUE(RooJSONFactoryWSTool{legacyImport}.importJSONfromString(legacyJson));
450+
EXPECT_NE(dynamic_cast<RooRealVar *>(legacyImport.obj("sigma_const")), nullptr);
451+
EXPECT_EQ(dynamic_cast<RooConstVar *>(legacyImport.obj("sigma_const")), nullptr);
452+
453+
const std::string markedConstWithDomainJson = R"({
454+
"metadata":{"hs3_version":"0.2"},
455+
"domains":[{"name":"default_domain","type":"product_domain","axes":[
456+
{"name":"x","min":-10.0,"max":10.0},
457+
{"name":"mean","min":-10.0,"max":10.0},
458+
{"name":"sigma_const","min":0.0,"max":10.0}
459+
]}],
460+
"parameter_points":[{"name":"default_values","parameters":[
461+
{"name":"x","value":0.0},
462+
{"name":"mean","value":0.0},
463+
{"name":"sigma_const","value":1.0,"const":true}
464+
]}],
465+
"distributions":[{"name":"gauss","type":"gaussian_dist","x":"x","mean":"mean","sigma":"sigma_const"}],
466+
"misc":{"ROOT_internal":{"attributes":{"sigma_const":{"is_const_var":1}}}}
467+
})";
468+
RooWorkspace markedImport;
469+
ASSERT_TRUE(RooJSONFactoryWSTool{markedImport}.importJSONfromString(markedConstWithDomainJson));
470+
EXPECT_NE(dynamic_cast<RooConstVar *>(markedImport.obj("sigma_const")), nullptr);
471+
EXPECT_EQ(markedImport.var("sigma_const"), nullptr);
472+
}
473+
474+
TEST(RooFitHS3, RooConstVarCollectionProxyExport)
475+
{
476+
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
477+
RooRealVar mean1{"mean1", "mean1", -1.0};
478+
RooRealVar mean2{"mean2", "mean2", 1.0};
479+
RooRealVar sigma{"sigma", "sigma", 1.0, 0.1, 10.0};
480+
481+
RooGaussian g1{"g1", "g1", x, mean1, sigma};
482+
RooGaussian g2{"g2", "g2", x, mean2, sigma};
483+
RooConstVar frac{"frac_const", "frac_const", 0.25};
484+
RooAddPdf model{"model", "model", RooArgList{g1, g2}, RooArgList{frac}};
485+
486+
RooWorkspace ws;
487+
ws.import(model, RooFit::Silence());
488+
489+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
490+
EXPECT_NE(json.find("\"coefficients\":[\"frac_const\"]"), std::string::npos);
491+
EXPECT_NE(json.find("\"name\":\"frac_const\""), std::string::npos);
492+
EXPECT_NE(json.find("\"is_const_var\":1"), std::string::npos);
493+
494+
RooWorkspace imported;
495+
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));
496+
EXPECT_NE(dynamic_cast<RooConstVar *>(imported.obj("frac_const")), nullptr);
497+
EXPECT_EQ(imported.var("frac_const"), nullptr);
498+
}
499+
398500
TEST(RooFitHS3, RooBernstein)
399501
{
400502
int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"});

0 commit comments

Comments
 (0)