Skip to content

Commit a1cf5d6

Browse files
committed
[HS3] Fix HistFactory export and RooConstVar roundtripping
1 parent e217071 commit a1cf5d6

4 files changed

Lines changed: 184 additions & 18 deletions

File tree

roofit/hs3/src/Domains.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ void Domains::ProductDomain::populate(RooWorkspace &ws) const
162162
{
163163
for (auto const &item : _map) {
164164
const auto &name = item.first;
165-
if (!ws.var(name)) {
165+
if (!ws.arg(name)) {
166166
const auto &elem = item.second;
167167
const double vMin = elem.hasMin ? elem.min : -RooNumber::infinity();
168168
const double vMax = elem.hasMax ? elem.max : RooNumber::infinity();

roofit/hs3/src/JSONFactories_HistFactory.cxx

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa
283283
"'");
284284
}
285285
}
286+
double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma)
287+
{
288+
auto const *mean = dynamic_cast<RooProduct const *>(&constraint.getMean());
289+
if (!mean) {
290+
RooJSONFactoryWSTool::error("Poisson gamma constraint mean is not a RooProduct: " +
291+
std::string(constraint.GetName()));
292+
}
293+
294+
for (RooAbsArg *arg : mean->servers()) {
295+
if (arg == &gamma) {
296+
continue;
297+
}
298+
299+
if (auto const *tau = dynamic_cast<RooConstVar const *>(arg)) {
300+
return tau->getVal();
301+
}
302+
303+
// Imported workspaces can sometimes represent
304+
// constants as constant RooRealVars.
305+
if (auto const *real = dynamic_cast<RooAbsReal const *>(arg)) {
306+
if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) {
307+
return real->getVal();
308+
}
309+
}
310+
}
311+
312+
RooJSONFactoryWSTool::error("Could not find tau component in Poisson gamma constraint mean: " +
313+
std::string(constraint.GetName()));
314+
return std::numeric_limits<double>::quiet_NaN();
315+
}
286316

287317
bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist,
288318
RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p,
@@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
334364
// this is dealt with at a different place, ignore it for now
335365
} else if (modtype == "normfactor") {
336366
RooRealVar &constrParam = getOrCreate<RooRealVar>(ws, sysname, 1., -3, 5);
367+
constrParam.setError(0.0);
337368
normElems.add(constrParam);
338369
if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) {
339370
// for norm factors, constraints are optional
@@ -1060,7 +1091,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10601091
if (constraint) {
10611092
sample.barlowBeestonLightConstraintType = constraint->IsA();
10621093
if (RooPoisson *constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1063-
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
1094+
double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g));
10641095
channel.rel_errors[idx] = erel;
10651096
} else if (RooGaussian *constraint_g = dynamic_cast<RooGaussian *>(constraint)) {
10661097
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
@@ -1094,7 +1125,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10941125
if (!constraint) {
10951126
sys.constraints.push_back(0.0);
10961127
} else if (auto constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1097-
sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal()));
1128+
sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g)));
10981129
if (!sys.constraint) {
10991130
sys.constraintType = RooPoisson::Class();
11001131
}

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,41 @@ JSONNode const *findRooFitInternal(JSONNode const &node, Keys_t const &...keys)
371371
return node.find("misc", "ROOT_internal", keys...);
372372
}
373373

374+
bool isMarkedConstVar(JSONNode const *attributesNode, std::string const &name)
375+
{
376+
if (!attributesNode) {
377+
return false;
378+
}
379+
if (auto *attrNode = attributesNode->find(name)) {
380+
return attrNode->has_child("is_const_var") && (*attrNode)["is_const_var"].val_int() == 1;
381+
}
382+
return false;
383+
}
384+
385+
void importMarkedConstVars(RooJSONFactoryWSTool &tool, JSONNode const &rootNode, JSONNode const *attributesNode)
386+
{
387+
if (!attributesNode) {
388+
return;
389+
}
390+
391+
JSONNode const *varsNode = getVariablesNode(rootNode);
392+
if (!varsNode) {
393+
return;
394+
}
395+
396+
RooWorkspace *workspace = tool.workspace();
397+
for (const auto &varNode : varsNode->children()) {
398+
std::string name = RooJSONFactoryWSTool::name(varNode);
399+
if (!isMarkedConstVar(attributesNode, name) || workspace->arg(name)) {
400+
continue;
401+
}
402+
if (!varNode.has_child("value")) {
403+
RooJSONFactoryWSTool::error("cannot instantiate RooConstVar '" + name + "' without \"value\"!");
404+
}
405+
tool.wsEmplace<RooConstVar>(name, varNode["value"].val_double());
406+
}
407+
}
408+
374409
/**
375410
* @brief Check if a RooAbsArg is a literal constant variable.
376411
*
@@ -411,10 +446,9 @@ void exportAttributes(const RooAbsArg *arg, JSONNode &rootnode)
411446
node = &RooJSONFactoryWSTool::getRooFitInternal(rootnode, "attributes").set_map()[arg->GetName()].set_map();
412447
};
413448

414-
// RooConstVars are not a thing in HS3, and also for RooFit they are not
415-
// that important: they are just constants. So we don't need to remember
416-
// any information about them.
417449
if (dynamic_cast<RooConstVar const *>(arg)) {
450+
initializeNode();
451+
(*node)["is_const_var"] << 1;
418452
return;
419453
}
420454

@@ -1688,22 +1722,20 @@ void RooJSONFactoryWSTool::importVariable(const JSONNode &p)
16881722
std::string name(RooJSONFactoryWSTool::name(p));
16891723
RooJSONFactoryWSTool::testValidName(name, true);
16901724

1691-
if (_workspace.var(name))
1725+
if (_workspace.var(name) || (isMarkedConstVar(_attributesNode, name) && _workspace.arg(name)))
16921726
return;
16931727
if (!p.is_map()) {
16941728
std::stringstream ss;
16951729
ss << "RooJSONFactoryWSTool() node '" << name << "' is not a map, skipping.";
16961730
oocoutE(nullptr, InputArguments) << ss.str() << std::endl;
16971731
return;
16981732
}
1699-
if (_attributesNode) {
1700-
if (auto *attrNode = _attributesNode->find(name)) {
1701-
// We should not create RooRealVar objects for RooConstVars!
1702-
if (attrNode->has_child("is_const_var") && (*attrNode)["is_const_var"].val_int() == 1) {
1703-
wsEmplace<RooConstVar>(name, p["value"].val_double());
1704-
return;
1705-
}
1733+
if (isMarkedConstVar(_attributesNode, name)) {
1734+
if (!p.has_child("value")) {
1735+
RooJSONFactoryWSTool::error("cannot instantiate RooConstVar '" + name + "' without \"value\"!");
17061736
}
1737+
wsEmplace<RooConstVar>(name, p["value"].val_double());
1738+
return;
17071739
}
17081740
configureVariable(*_domains, p, wsEmplace<RooRealVar>(name, 1.));
17091741
}
@@ -2165,16 +2197,17 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
21652197
error(ss.str());
21662198
}
21672199

2200+
_rootnodeInput = &n;
2201+
2202+
_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");
2203+
21682204
_domains = std::make_unique<RooFit::JSONIO::Detail::Domains>();
21692205
if (auto domains = n.find("domains")) {
21702206
_domains->readJSON(*domains);
21712207
}
2208+
importMarkedConstVars(*this, n, _attributesNode);
21722209
_domains->populate(_workspace);
21732210

2174-
_rootnodeInput = &n;
2175-
2176-
_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");
2177-
21782211
this->importDependants(n);
21792212

21802213
if (auto paramPointsNode = n.find("parameter_points")) {

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,108 @@ TEST(RooFitHS3, RooGaussian)
259259
EXPECT_EQ(status, 0);
260260
}
261261

262+
TEST(RooFitHS3, RooGaussianConstVarSigmaExport)
263+
{
264+
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
265+
RooRealVar mean{"mean", "mean", 0.0};
266+
mean.setConstant(true);
267+
268+
RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0};
269+
RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst};
270+
271+
RooGaussian gaussLiteral{"gauss_literal", "gauss_literal", x, mean, RooFit::RooConst(2.0)};
272+
273+
RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0};
274+
sigmaReal.setConstant(true);
275+
RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal};
276+
277+
RooWorkspace ws;
278+
ws.import(gaussConst, RooFit::Silence());
279+
ws.import(gaussLiteral, RooFit::RecycleConflictNodes(), RooFit::Silence());
280+
ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence());
281+
282+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
283+
284+
EXPECT_NE(json.find("\"sigma\":\"sigma_const\""), std::string::npos);
285+
EXPECT_NE(json.find("\"name\":\"sigma_const\""), std::string::npos);
286+
EXPECT_NE(json.find("\"is_const_var\":1"), std::string::npos);
287+
EXPECT_EQ(json.find("\"sigma\":1.0"), std::string::npos);
288+
EXPECT_NE(json.find("\"sigma\":2.0"), std::string::npos);
289+
290+
EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos);
291+
EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos);
292+
293+
RooWorkspace imported;
294+
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));
295+
EXPECT_NE(dynamic_cast<RooConstVar *>(imported.obj("sigma_const")), nullptr);
296+
EXPECT_EQ(imported.var("sigma_const"), nullptr);
297+
EXPECT_NE(dynamic_cast<RooRealVar *>(imported.obj("sigma_real")), nullptr);
298+
299+
const std::string roundTripJson = RooJSONFactoryWSTool{imported}.exportJSONtoString();
300+
EXPECT_NE(roundTripJson.find("\"sigma\":\"sigma_const\""), std::string::npos);
301+
EXPECT_NE(roundTripJson.find("\"is_const_var\":1"), std::string::npos);
302+
303+
const std::string legacyJson = R"({
304+
"metadata":{"hs3_version":"0.2"},
305+
"parameter_points":[{"name":"default_values","parameters":[
306+
{"name":"x","value":0.0},
307+
{"name":"mean","value":0.0},
308+
{"name":"sigma_const","value":1.0,"const":true}
309+
]}],
310+
"distributions":[{"name":"gauss","type":"gaussian_dist","x":"x","mean":"mean","sigma":"sigma_const"}]
311+
})";
312+
RooWorkspace legacyImport;
313+
ASSERT_TRUE(RooJSONFactoryWSTool{legacyImport}.importJSONfromString(legacyJson));
314+
EXPECT_NE(dynamic_cast<RooRealVar *>(legacyImport.obj("sigma_const")), nullptr);
315+
EXPECT_EQ(dynamic_cast<RooConstVar *>(legacyImport.obj("sigma_const")), nullptr);
316+
317+
const std::string markedConstWithDomainJson = R"({
318+
"metadata":{"hs3_version":"0.2"},
319+
"domains":[{"name":"default_domain","type":"product_domain","axes":[
320+
{"name":"x","min":-10.0,"max":10.0},
321+
{"name":"mean","min":-10.0,"max":10.0},
322+
{"name":"sigma_const","min":0.0,"max":10.0}
323+
]}],
324+
"parameter_points":[{"name":"default_values","parameters":[
325+
{"name":"x","value":0.0},
326+
{"name":"mean","value":0.0},
327+
{"name":"sigma_const","value":1.0,"const":true}
328+
]}],
329+
"distributions":[{"name":"gauss","type":"gaussian_dist","x":"x","mean":"mean","sigma":"sigma_const"}],
330+
"misc":{"ROOT_internal":{"attributes":{"sigma_const":{"is_const_var":1}}}}
331+
})";
332+
RooWorkspace markedImport;
333+
ASSERT_TRUE(RooJSONFactoryWSTool{markedImport}.importJSONfromString(markedConstWithDomainJson));
334+
EXPECT_NE(dynamic_cast<RooConstVar *>(markedImport.obj("sigma_const")), nullptr);
335+
EXPECT_EQ(markedImport.var("sigma_const"), nullptr);
336+
}
337+
338+
TEST(RooFitHS3, RooConstVarCollectionProxyExport)
339+
{
340+
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
341+
RooRealVar mean1{"mean1", "mean1", -1.0};
342+
RooRealVar mean2{"mean2", "mean2", 1.0};
343+
RooRealVar sigma{"sigma", "sigma", 1.0, 0.1, 10.0};
344+
345+
RooGaussian g1{"g1", "g1", x, mean1, sigma};
346+
RooGaussian g2{"g2", "g2", x, mean2, sigma};
347+
RooConstVar frac{"frac_const", "frac_const", 0.25};
348+
RooAddPdf model{"model", "model", RooArgList{g1, g2}, RooArgList{frac}};
349+
350+
RooWorkspace ws;
351+
ws.import(model, RooFit::Silence());
352+
353+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
354+
EXPECT_NE(json.find("\"coefficients\":[\"frac_const\"]"), std::string::npos);
355+
EXPECT_NE(json.find("\"name\":\"frac_const\""), std::string::npos);
356+
EXPECT_NE(json.find("\"is_const_var\":1"), std::string::npos);
357+
358+
RooWorkspace imported;
359+
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));
360+
EXPECT_NE(dynamic_cast<RooConstVar *>(imported.obj("frac_const")), nullptr);
361+
EXPECT_EQ(imported.var("frac_const"), nullptr);
362+
}
363+
262364
TEST(RooFitHS3, RooBernstein)
263365
{
264366
int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"});

0 commit comments

Comments
 (0)