Skip to content

Commit ba1ca31

Browse files
committed
[HS3] Patch RooGaussian handling and histfactory_bugfix
1 parent e217071 commit ba1ca31

3 files changed

Lines changed: 95 additions & 2 deletions

File tree

roofit/hs3/src/JSONFactories_HistFactory.cxx

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa
283283
"'");
284284
}
285285
}
286+
double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma)
287+
{
288+
auto const *mean = dynamic_cast<RooProduct const *>(&constraint.getMean());
289+
if (!mean) {
290+
RooJSONFactoryWSTool::error("Poisson gamma constraint mean is not a RooProduct: " +
291+
std::string(constraint.GetName()));
292+
}
293+
294+
for (RooAbsArg *arg : mean->servers()) {
295+
if (arg == &gamma) {
296+
continue;
297+
}
298+
299+
if (auto const *tau = dynamic_cast<RooConstVar const *>(arg)) {
300+
return tau->getVal();
301+
}
302+
303+
// Imported workspaces can sometimes represent
304+
// constants as constant RooRealVars.
305+
if (auto const *real = dynamic_cast<RooAbsReal const *>(arg)) {
306+
if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) {
307+
return real->getVal();
308+
}
309+
}
310+
}
311+
312+
RooJSONFactoryWSTool::error("Could not find tau component in Poisson gamma constraint mean: " +
313+
std::string(constraint.GetName()));
314+
return std::numeric_limits<double>::quiet_NaN();
315+
}
286316

287317
bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist,
288318
RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p,
@@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
334364
// this is dealt with at a different place, ignore it for now
335365
} else if (modtype == "normfactor") {
336366
RooRealVar &constrParam = getOrCreate<RooRealVar>(ws, sysname, 1., -3, 5);
367+
constrParam.setError(0.0);
337368
normElems.add(constrParam);
338369
if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) {
339370
// for norm factors, constraints are optional
@@ -1060,7 +1091,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10601091
if (constraint) {
10611092
sample.barlowBeestonLightConstraintType = constraint->IsA();
10621093
if (RooPoisson *constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1063-
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
1094+
double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g));
10641095
channel.rel_errors[idx] = erel;
10651096
} else if (RooGaussian *constraint_g = dynamic_cast<RooGaussian *>(constraint)) {
10661097
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
@@ -1094,7 +1125,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
10941125
if (!constraint) {
10951126
sys.constraints.push_back(0.0);
10961127
} else if (auto constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
1097-
sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal()));
1128+
sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g)));
10981129
if (!sys.constraint) {
10991130
sys.constraintType = RooPoisson::Class();
11001131
}

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <RooBinSamplingPdf.h>
2020
#include <RooBinWidthFunction.h>
2121
#include <RooCategory.h>
22+
#include <RooConstVar.h>
2223
#include <RooDataHist.h>
2324
#include <RooDecay.h>
2425
#include <RooDerivative.h>
@@ -29,6 +30,7 @@
2930
#include <RooFitHS3/JSONIO.h>
3031
#include <RooFormulaVar.h>
3132
#include <RooGenericPdf.h>
33+
#include <RooGaussian.h>
3234
#include <RooHistFunc.h>
3335
#include <RooHistPdf.h>
3436
#include <RooLegacyExpPoly.h>
@@ -878,6 +880,32 @@ class RooPoissonStreamer : public RooFit::JSONIO::Exporter {
878880
}
879881
};
880882

883+
class RooGaussianStreamer : public RooFit::JSONIO::Exporter {
884+
public:
885+
std::string const &key() const override;
886+
bool autoExportDependants() const override { return false; }
887+
bool exportObject(RooJSONFactoryWSTool *tool, const RooAbsArg *func, JSONNode &elem) const override
888+
{
889+
auto *pdf = static_cast<const RooGaussian *>(func);
890+
elem["type"] << key();
891+
writeArg(tool, elem["x"], pdf->getX());
892+
writeArg(tool, elem["mean"], pdf->getMean());
893+
writeArg(tool, elem["sigma"], pdf->getSigma());
894+
return true;
895+
}
896+
897+
private:
898+
static void writeArg(RooJSONFactoryWSTool *tool, JSONNode &node, RooAbsReal const &arg)
899+
{
900+
if (auto const *constant = dynamic_cast<RooConstVar const *>(&arg)) {
901+
node << constant->getVal();
902+
} else {
903+
node << arg.GetName();
904+
tool->queueExport(arg);
905+
}
906+
}
907+
};
908+
881909
class RooDecayStreamer : public RooFit::JSONIO::Exporter {
882910
public:
883911
std::string const &key() const override;
@@ -1171,6 +1199,7 @@ DEFINE_EXPORTER_KEY(RooHistPdfStreamer, "histogram_dist");
11711199
DEFINE_EXPORTER_KEY(RooLogNormalStreamer, "lognormal_dist");
11721200
DEFINE_EXPORTER_KEY(RooMultiVarGaussianStreamer, "multivariate_normal_dist");
11731201
DEFINE_EXPORTER_KEY(RooPoissonStreamer, "poisson_dist");
1202+
DEFINE_EXPORTER_KEY(RooGaussianStreamer, "gaussian_dist");
11741203
DEFINE_EXPORTER_KEY(RooDecayStreamer, "decay_dist");
11751204
DEFINE_EXPORTER_KEY(RooTruthModelStreamer, "truth_model_function");
11761205
DEFINE_EXPORTER_KEY(RooGaussModelStreamer, "gauss_model_function");
@@ -1235,6 +1264,7 @@ STATIC_EXECUTE([]() {
12351264
registerExporter<RooLogNormalStreamer>(RooLognormal::Class(), false);
12361265
registerExporter<RooMultiVarGaussianStreamer>(RooMultiVarGaussian::Class(), false);
12371266
registerExporter<RooPoissonStreamer>(RooPoisson::Class(), false);
1267+
registerExporter<RooGaussianStreamer>(RooGaussian::Class(), false);
12381268
registerExporter<RooDecayStreamer>(RooDecay::Class(), false);
12391269
registerExporter<RooTruthModelStreamer>(RooTruthModel::Class(), false);
12401270
registerExporter<RooGaussModelStreamer>(RooGaussModel::Class(), false);

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,38 @@ TEST(RooFitHS3, RooGaussian)
259259
EXPECT_EQ(status, 0);
260260
}
261261

262+
TEST(RooFitHS3, RooGaussianConstVarSigmaExport)
263+
{
264+
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
265+
RooRealVar mean{"mean", "mean", 0.0};
266+
mean.setConstant(true);
267+
268+
RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0};
269+
RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst};
270+
271+
RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0};
272+
sigmaReal.setConstant(true);
273+
RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal};
274+
275+
RooWorkspace ws;
276+
ws.import(gaussConst, RooFit::Silence());
277+
ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence());
278+
279+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
280+
281+
EXPECT_EQ(json.find("\"sigma\":\"sigma_const\""), std::string::npos);
282+
EXPECT_EQ(json.find("\"name\":\"sigma_const\""), std::string::npos);
283+
EXPECT_NE(json.find("\"sigma\":1.0"), std::string::npos);
284+
285+
EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos);
286+
EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos);
287+
288+
RooWorkspace imported;
289+
RooJSONFactoryWSTool{imported}.importJSONfromString(json);
290+
EXPECT_EQ(imported.obj("sigma_const"), nullptr);
291+
EXPECT_NE(dynamic_cast<RooRealVar *>(imported.obj("sigma_real")), nullptr);
292+
}
293+
262294
TEST(RooFitHS3, RooBernstein)
263295
{
264296
int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"});

0 commit comments

Comments
 (0)