Skip to content

Commit 7dff3c6

Browse files
committed
hs3: preserve binning metadata in domains
1 parent 0c635d5 commit 7dff3c6

3 files changed

Lines changed: 162 additions & 9 deletions

File tree

roofit/hs3/src/Domains.cxx

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "Domains.h"
1414

15+
#include <RooAbsBinning.h>
16+
#include <RooBinning.h>
1517
#include <RooFitHS3/RooJSONFactoryWSTool.h>
1618
#include <RooNumber.h>
1719
#include <RooRealVar.h>
@@ -71,12 +73,12 @@ void Domains::readVariable(const char *name, double min, double max, const char
7173

7274
void Domains::readVariable(RooRealVar const &var)
7375
{
74-
readVariable(var.GetName(), var.getMin(), var.getMax(), defaultDomainName);
76+
_map[defaultDomainName].readVariable(var.GetName(), var.getBinning());
7577
for (const auto &bname : var.getBinningNames()) {
7678
if (bname.empty())
7779
continue;
7880
auto &binning = var.getBinning(bname.c_str());
79-
readVariable(var.GetName(), binning.lowBound(), binning.highBound(), bname.c_str());
81+
_map[bname].readVariable(var.GetName(), binning);
8082
}
8183
}
8284

@@ -125,7 +127,34 @@ bool Domains::hasVariable(const char *name) const
125127

126128
void Domains::ProductDomain::readVariable(RooRealVar const &var)
127129
{
128-
readVariable(var.GetName(), var.getMin(), var.getMax());
130+
readVariable(var.GetName(), var.getBinning());
131+
}
132+
133+
void Domains::ProductDomain::readBinning(ProductDomainElement &elem, RooAbsBinning const &binning)
134+
{
135+
if (binning.isUniform()) {
136+
elem.hasNBins = true;
137+
elem.nBins = binning.numBins();
138+
elem.edges.clear();
139+
} else {
140+
elem.hasNBins = false;
141+
elem.edges.clear();
142+
elem.edges.push_back(binning.binLow(0));
143+
for (int i = 0; i < binning.numBins(); ++i) {
144+
elem.edges.push_back(binning.binHigh(i));
145+
}
146+
}
147+
}
148+
149+
void Domains::ProductDomain::readVariable(const char *name, RooAbsBinning const &binning)
150+
{
151+
auto &elem = _map[name];
152+
153+
elem.hasMin = true;
154+
elem.min = binning.lowBound();
155+
elem.hasMax = true;
156+
elem.max = binning.highBound();
157+
readBinning(elem, binning);
129158
}
130159

131160
void Domains::ProductDomain::readVariable(const char *name, double min, double max)
@@ -136,6 +165,34 @@ void Domains::ProductDomain::readVariable(const char *name, double min, double m
136165
elem.min = min;
137166
elem.hasMax = true;
138167
elem.max = max;
168+
elem.hasNBins = false;
169+
elem.nBins = 0;
170+
elem.edges.clear();
171+
}
172+
173+
void Domains::ProductDomain::applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name)
174+
{
175+
if (!elem.edges.empty()) {
176+
RooBinning binning(elem.edges.front(), elem.edges.back());
177+
for (double edge : elem.edges) {
178+
binning.addBoundary(edge);
179+
}
180+
var.setBinning(binning, name);
181+
} else if (elem.hasNBins) {
182+
var.setBins(elem.nBins, name);
183+
}
184+
}
185+
186+
void Domains::ProductDomain::writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem)
187+
{
188+
if (!elem.edges.empty()) {
189+
auto &edges = node["edges"].set_seq();
190+
for (double edge : elem.edges) {
191+
edges.append_child() << edge;
192+
}
193+
} else if (elem.hasNBins) {
194+
node["nbins"] << elem.nBins;
195+
}
139196
}
140197
void Domains::ProductDomain::writeVariable(RooRealVar &var) const
141198
{
@@ -156,6 +213,7 @@ void Domains::ProductDomain::writeVariable(RooRealVar &var) const
156213
var.setMax(elem.max);
157214
}
158215
}
216+
applyBinning(var, elem);
159217
}
160218
}
161219

@@ -180,6 +238,27 @@ void Domains::ProductDomain::readJSON(RooFit::Detail::JSONNode const &node)
180238
elem.max = readBound(varNode, "max", RooNumber::infinity());
181239
elem.hasMax = true;
182240
}
241+
if (varNode.has_child("edges")) {
242+
elem.hasNBins = false;
243+
elem.edges.clear();
244+
for (auto const &edge : varNode["edges"].children()) {
245+
elem.edges.push_back(edge.val_double());
246+
}
247+
if (!elem.edges.empty()) {
248+
if (!elem.hasMin) {
249+
elem.min = elem.edges.front();
250+
elem.hasMin = true;
251+
}
252+
if (!elem.hasMax) {
253+
elem.max = elem.edges.back();
254+
elem.hasMax = true;
255+
}
256+
}
257+
} else if (varNode.has_child("nbins")) {
258+
elem.hasNBins = true;
259+
elem.nBins = varNode["nbins"].val_int();
260+
elem.edges.clear();
261+
}
183262
}
184263
}
185264
void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const
@@ -195,6 +274,7 @@ void Domains::ProductDomain::writeJSON(RooFit::Detail::JSONNode &node) const
195274
RooFit::Detail::JSONNode &varnode = RooJSONFactoryWSTool::appendNamedChild(variablesNode, item.first);
196275
writeBound(varnode["min"], elem.hasMin ? elem.min : -RooNumber::infinity());
197276
writeBound(varnode["max"], elem.hasMax ? elem.max : RooNumber::infinity());
277+
writeBinning(varnode, elem);
198278
}
199279
}
200280
void Domains::ProductDomain::populate(RooWorkspace &ws) const
@@ -205,7 +285,9 @@ void Domains::ProductDomain::populate(RooWorkspace &ws) const
205285
const auto &elem = item.second;
206286
const double vMin = elem.hasMin ? elem.min : -RooNumber::infinity();
207287
const double vMax = elem.hasMax ? elem.max : RooNumber::infinity();
208-
ws.import(RooRealVar{name.c_str(), name.c_str(), vMin, vMax});
288+
RooRealVar var{name.c_str(), name.c_str(), vMin, vMax};
289+
applyBinning(var, elem);
290+
ws.import(var);
209291
}
210292
}
211293
}
@@ -218,6 +300,7 @@ void Domains::ProductDomain::registerBinnings(const char *name, RooWorkspace &ws
218300
const double vMin = item.second.hasMin ? item.second.min : -RooNumber::infinity();
219301
const double vMax = item.second.hasMax ? item.second.max : RooNumber::infinity();
220302
var->setRange(name, vMin, vMax);
303+
applyBinning(*var, item.second, name);
221304
}
222305
}
223306

roofit/hs3/src/Domains.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <map>
1818
#include <vector>
1919

20+
class RooAbsBinning;
2021
class RooRealVar;
2122
class RooWorkspace;
2223

@@ -49,6 +50,7 @@ class Domains {
4950
class ProductDomain {
5051
public:
5152
void readVariable(const RooRealVar &);
53+
void readVariable(const char *name, RooAbsBinning const &binning);
5254
void readVariable(const char *name, double min, double max);
5355
void writeVariable(RooRealVar &) const;
5456

@@ -66,8 +68,15 @@ class Domains {
6668
bool hasMax = false;
6769
double min = 0.0;
6870
double max = 0.0;
71+
bool hasNBins = false;
72+
int nBins = 0;
73+
std::vector<double> edges;
6974
};
7075

76+
static void applyBinning(RooRealVar &var, ProductDomainElement const &elem, const char *name = nullptr);
77+
static void readBinning(ProductDomainElement &elem, RooAbsBinning const &binning);
78+
static void writeBinning(RooFit::Detail::JSONNode &node, ProductDomainElement const &elem);
79+
7180
std::map<std::string, ProductDomainElement> _map;
7281
};
7382

roofit/hs3/test/testRooFitHS3.cxx

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <RooFitHS3/RooJSONFactoryWSTool.h>
77

88
#include <RooAddPdf.h>
9+
#include <RooAddition.h>
10+
#include <RooBinning.h>
911
#include <RooCategory.h>
1012
#include <RooConstVar.h>
1113
#include <RooDataSet.h>
@@ -268,6 +270,61 @@ TEST(RooFitHS3, ProductDomainEntriesExportExplicitBounds)
268270
EXPECT_GT(importedMean->getMax(), 0.0);
269271
}
270272

273+
TEST(RooFitHS3, ProductDomainEntriesExportBinning)
274+
{
275+
RooRealVar uniform{"uniform", "uniform", 0.0, 1.0};
276+
uniform.setBins(7);
277+
278+
RooRealVar nonuniform{"nonuniform", "nonuniform", 0.0, 3.0};
279+
RooBinning nonuniformBinning{0.0, 3.0};
280+
nonuniformBinning.addBoundary(1.0);
281+
nonuniformBinning.addBoundary(1.5);
282+
nonuniform.setBinning(nonuniformBinning);
283+
284+
RooAddition sum{"sum", "sum", RooArgList{uniform, nonuniform}};
285+
286+
RooWorkspace ws{"workspace"};
287+
ws.import(sum, RooFit::Silence());
288+
289+
const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();
290+
auto tree = RooFit::Detail::JSONTree::create(json);
291+
auto const *defaultDomain = RooJSONFactoryWSTool::findNamedChild(tree->rootnode()["domains"], "default_domain");
292+
ASSERT_NE(defaultDomain, nullptr);
293+
294+
auto const *uniformAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "uniform");
295+
ASSERT_NE(uniformAxis, nullptr);
296+
ASSERT_TRUE(uniformAxis->has_child("nbins"));
297+
EXPECT_EQ((*uniformAxis)["nbins"].val_int(), 7);
298+
EXPECT_FALSE(uniformAxis->has_child("edges"));
299+
300+
auto const *nonuniformAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "nonuniform");
301+
ASSERT_NE(nonuniformAxis, nullptr);
302+
ASSERT_TRUE(nonuniformAxis->has_child("edges"));
303+
EXPECT_FALSE(nonuniformAxis->has_child("nbins"));
304+
auto const &edges = (*nonuniformAxis)["edges"];
305+
ASSERT_EQ(edges.num_children(), 4u);
306+
EXPECT_DOUBLE_EQ(edges.child(0).val_double(), 0.0);
307+
EXPECT_DOUBLE_EQ(edges.child(1).val_double(), 1.0);
308+
EXPECT_DOUBLE_EQ(edges.child(2).val_double(), 1.5);
309+
EXPECT_DOUBLE_EQ(edges.child(3).val_double(), 3.0);
310+
311+
RooWorkspace imported;
312+
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));
313+
auto *importedUniform = imported.var("uniform");
314+
ASSERT_NE(importedUniform, nullptr);
315+
EXPECT_EQ(importedUniform->getBins(), 7);
316+
317+
auto *importedNonuniform = imported.var("nonuniform");
318+
ASSERT_NE(importedNonuniform, nullptr);
319+
auto const &importedBinning = importedNonuniform->getBinning();
320+
EXPECT_FALSE(importedBinning.isUniform());
321+
ASSERT_EQ(importedBinning.numBins(), 3);
322+
EXPECT_DOUBLE_EQ(importedBinning.binLow(0), 0.0);
323+
EXPECT_DOUBLE_EQ(importedBinning.binHigh(0), 1.0);
324+
EXPECT_DOUBLE_EQ(importedBinning.binHigh(1), 1.5);
325+
EXPECT_DOUBLE_EQ(importedBinning.binHigh(2), 3.0);
326+
}
327+
271328
TEST(RooFitHS3, ParameterStepWidthsModelConfigRoundTrip)
272329
{
273330
RooWorkspace ws1{"workspace"};
@@ -522,11 +579,15 @@ TEST(RooFitHS3, RooGaussianConstVarSigmaExport)
522579
EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos);
523580
EXPECT_NE(domainAxes.find("\"name\":\"sigma_real\""), std::string::npos) << domainAxes;
524581

525-
// The unbounded constant RooRealVar is still a RooRealVar, so it gets an
526-
// empty domain axis that distinguishes it from a RooConstVar.
527-
EXPECT_NE(domainAxes.find("\"name\":\"mean\""), std::string::npos) << domainAxes;
528-
EXPECT_EQ(domainAxes.find("\"name\":\"mean\",\"min\""), std::string::npos) << domainAxes;
529-
EXPECT_EQ(domainAxes.find("\"name\":\"mean\",\"max\""), std::string::npos) << domainAxes;
582+
// The unbounded constant RooRealVar is still a RooRealVar, so it gets a
583+
// domain axis with explicit null bounds that distinguishes it from a RooConstVar.
584+
auto tree = RooFit::Detail::JSONTree::create(json);
585+
auto const *defaultDomain = RooJSONFactoryWSTool::findNamedChild(tree->rootnode()["domains"], "default_domain");
586+
ASSERT_NE(defaultDomain, nullptr);
587+
auto const *meanAxis = RooJSONFactoryWSTool::findNamedChild((*defaultDomain)["axes"], "mean");
588+
ASSERT_NE(meanAxis, nullptr);
589+
EXPECT_TRUE((*meanAxis)["min"].is_null());
590+
EXPECT_TRUE((*meanAxis)["max"].is_null());
530591

531592
RooWorkspace imported;
532593
ASSERT_TRUE(RooJSONFactoryWSTool{imported}.importJSONfromString(json));

0 commit comments

Comments
 (0)