Skip to content

Commit 18fbd3f

Browse files
committed
[RF][HF] Support setting ShapeFactor value and range
Adds a Sample::AddShapeFactor overload that takes the initial value and the range of the ShapeFactor gammas, analogous to AddNormFactor. This is important e.g. for ABCD estimates, where the hard-coded default range can cause convergence problems. The value and range are persisted to and read back from both XML and ROOT files, and a dedicated test covers the full round trips as well as the resulting workspace parameters. The ShapeFactor element in the XML schema (HistFactorySchema.dtd) is updated accordingly. The value and range use the same 'Val' / 'Low' / 'High' attribute and accessor names as NormFactor for consistency. Closes #20697. 🤖 Done with the help of AI for writing the tests. (cherry picked from commit ecb0adc)
1 parent 8b2ca6f commit 18fbd3f

7 files changed

Lines changed: 220 additions & 47 deletions

File tree

roofit/etc/HistFactorySchema.dtd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,5 +156,12 @@ For this element there is no sublemenents so the setting will only have local ef
156156
<!-- ShapeFactor: an empty element that is configured only through attributes.
     Name is the single required attribute; all others are optional.
     Val / Low / High carry the initial value and the range of the ShapeFactor
     gammas (same attribute names as NormFactor). Const, InputFile, HistoName
     and HistoPath are the remaining attributes that the ShapeFactor branch of
     the XML ConfigParser recognises. -->
<!ELEMENT ShapeFactor EMPTY>
157157
<!ATTLIST ShapeFactor
158158
Name CDATA #REQUIRED
159+
Val CDATA #IMPLIED
160+
Low CDATA #IMPLIED
161+
High CDATA #IMPLIED
162+
Const CDATA #IMPLIED
163+
InputFile CDATA #IMPLIED
164+
HistoName CDATA #IMPLIED
165+
HistoPath CDATA #IMPLIED
159166
>
160167

roofit/histfactory/inc/RooStats/HistFactory/Detail/HistFactoryImpl.h

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,24 @@
1313
#ifndef HistFactoryImplHelpers_h
1414
#define HistFactoryImplHelpers_h
1515

16-
#include <RooStats/HistFactory/Measurement.h>
17-
1816
#include <RooGlobalFunc.h>
1917
#include <RooWorkspace.h>
2018

2119
#include <ROOT/RSpan.hxx>
2220

23-
namespace RooStats {
24-
namespace HistFactory {
21+
namespace RooStats::HistFactory {
22+
23+
namespace Constraint {
24+
25+
/// Kind of constraint term. Used as the `type` argument of
/// Detail::createGammaConstraints() declared further down in this header.
/// Deliberately left as an unscoped enum: callers spell the values as
/// Constraint::Gaussian / Constraint::Poisson.
enum Type {
26+
Gaussian,
27+
Poisson
28+
};
29+
/// Maps a constraint Type to a string. The definition is not part of this
/// header, so the exact strings are not visible here.
std::string Name(Type type);
30+
/// Maps a string back to a constraint Type (counterpart of Name()).
/// NOTE(review): behaviour for unknown strings is not visible here - confirm
/// against the definition.
Type GetType(const std::string &Name);
31+
32+
} // namespace Constraint
33+
2534
namespace Detail {
2635

2736
namespace MagicConstants {
@@ -49,15 +58,13 @@ void configureConstrainedGammas(RooArgList const &gammas, std::span<const double
4958

5059
struct CreateGammaConstraintsOutput {
5160
std::vector<std::unique_ptr<RooAbsPdf>> constraints;
52-
std::vector<RooRealVar*> globalObservables;
61+
std::vector<RooRealVar *> globalObservables;
5362
};
5463

55-
CreateGammaConstraintsOutput createGammaConstraints(RooArgList const &paramList,
56-
std::span<const double> relSigmas, double minSigma,
57-
Constraint::Type type);
64+
CreateGammaConstraintsOutput createGammaConstraints(RooArgList const &paramList, std::span<const double> relSigmas,
65+
double minSigma, Constraint::Type type);
5866

5967
} // namespace Detail
60-
} // namespace HistFactory
61-
} // namespace RooStats
68+
} // namespace RooStats::HistFactory
6269

6370
#endif

roofit/histfactory/inc/RooStats/HistFactory/Measurement.h

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <TNamed.h>
1515

16+
#include <RooStats/HistFactory/Detail/HistFactoryImpl.h>
17+
1618
#include <fstream>
1719
#include <iostream>
1820
#include <map>
@@ -27,17 +29,6 @@ class RooWorkspace;
2729

2830
namespace RooStats::HistFactory {
2931

30-
namespace Constraint {
31-
32-
enum Type {
33-
Gaussian,
34-
Poisson
35-
};
36-
std::string Name(Type type);
37-
Type GetType(const std::string &Name);
38-
39-
} // namespace Constraint
40-
4132
/** \class OverallSys
4233
* \ingroup HistFactory
4334
* Configuration for a constrained overall systematic to scale sample normalisations.
@@ -257,12 +248,24 @@ class ShapeFactor : public HistogramUncertaintyBase {
257248
}
258249
const std::string &GetHistoPath() const { return fHistoPathHigh; }
259250

251+
/// Initial value stored for this ShapeFactor.
double GetVal() const { return fVal; }
252+
253+
/// Lower edge of the stored range.
double GetLow() const { return fLow; }
254+
/// Upper edge of the stored range.
double GetHigh() const { return fHigh; }
255+
256+
/// Store the initial value. No check against the stored range is done here.
void SetVal(double Val) { fVal = Val; }
257+
258+
/// Store the lower edge of the range. Not validated against the upper edge.
void SetLow(double Low) { fLow = Low; }
259+
/// Store the upper edge of the range. Not validated against the lower edge.
void SetHigh(double High) { fHigh = High; }
260+
260261
protected:
261262
bool fConstant = false;
262-
263-
// A histogram representing
264-
// the initial shape
265263
bool fHasInitialShape = false;
264+
double fVal = 1.0;
265+
// GHL: Again, we are putting hard ranges on the gammas by default.
266+
// We should change this to range from 0 to /inf.
267+
double fLow = Detail::MagicConstants::defaultGammaMin;
268+
double fHigh = Detail::MagicConstants::defaultShapeFactorGammaMax;
266269
};
267270

268271
/** \class StatError
@@ -484,6 +487,7 @@ class Sample {
484487
void AddHistoFactor(const HistoFactor &Factor);
485488

486489
void AddShapeFactor(std::string Name);
490+
void AddShapeFactor(std::string Name, double Val, double Low, double High);
487491
void AddShapeFactor(const ShapeFactor &Factor);
488492

489493
void AddShapeSys(std::string Name, Constraint::Type ConstraintType, std::string HistoName, std::string HistoFile,

roofit/histfactory/src/ConfigParser.cxx

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,29 +1252,32 @@ HistFactory::ShapeFactor ConfigParser::MakeShapeFactor( TXMLNode* node ) {
12521252

12531253
else if( attrName == TString( "Name" ) ) {
12541254
shapeFactor.SetName( attrVal );
1255-
}
1256-
else if( attrName == TString( "Const" ) ) {
1257-
shapeFactor.SetConstant( CheckTrueFalse(attrVal, "ShapeFactor" ) );
1255+
} else if (attrName == TString("Val")) {
1256+
shapeFactor.SetVal(toDouble(attrVal));
1257+
} else if (attrName == TString("Low")) {
1258+
shapeFactor.SetLow(toDouble(attrVal));
1259+
} else if (attrName == TString("High")) {
1260+
shapeFactor.SetHigh(toDouble(attrVal));
1261+
} else if (attrName == TString("Const")) {
1262+
shapeFactor.SetConstant(CheckTrueFalse(attrVal, "ShapeFactor"));
12581263
}
12591264

1260-
else if( attrName == TString( "HistoName" ) ) {
1261-
shapeFactor.SetHistoName( attrVal );
1265+
else if (attrName == TString("HistoName")) {
1266+
shapeFactor.SetHistoName(attrVal);
12621267
}
12631268

1264-
else if( attrName == TString( "InputFile" ) ) {
1265-
ShapeInputFile = attrVal;
1269+
else if (attrName == TString("InputFile")) {
1270+
ShapeInputFile = attrVal;
12661271
}
12671272

1268-
else if( attrName == TString( "HistoPath" ) ) {
1269-
ShapeInputPath = attrVal;
1273+
else if (attrName == TString("HistoPath")) {
1274+
ShapeInputPath = attrVal;
12701275
}
12711276

12721277
else {
1273-
cxcoutEHF << "Error: Encountered Element in ShapeFactor with unknown name: "
1274-
<< attrName << std::endl;
1275-
throw hf_exc();
1278+
cxcoutEHF << "Error: Encountered Element in ShapeFactor with unknown name: " << attrName << std::endl;
1279+
throw hf_exc();
12761280
}
1277-
12781281
}
12791282

12801283
if( shapeFactor.GetName().empty() ) {

roofit/histfactory/src/HistoToWorkspaceFactoryFast.cxx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,13 +1101,17 @@ RooArgList HistoToWorkspaceFactoryFast::createObservables(const TH1 *hist, RooWo
11011101
}
11021102

11031103
// Create the Parameters
1104-
std::string funcParams = "gamma_" + shapeFactor.GetName();
1105-
1106-
// GHL: Again, we are putting hard ranges on the gamma's
1107-
// We should change this to range from 0 to /inf
1108-
RooArgList shapeFactorParams = ParamHistFunc::createParamSet(proto,
1109-
funcParams,
1110-
theObservables, defaultGammaMin, defaultShapeFactorGammaMax);
1104+
RooArgList shapeFactorParams =
1105+
ParamHistFunc::createParamSet(proto, "gamma_" + shapeFactor.GetName(), theObservables);
1106+
for (auto *comp : shapeFactorParams) {
1107+
// If the gamma is subject to a preprocess function, it is a RooAbsReal and
1108+
// we don't need to set the initial value.
1109+
if (auto var = dynamic_cast<RooRealVar *>(comp)) {
1110+
var->setVal(shapeFactor.GetVal());
1111+
var->setMin(shapeFactor.GetLow());
1112+
var->setMax(shapeFactor.GetHigh());
1113+
}
1114+
}
11111115

11121116
// Create the Function
11131117
ParamHistFunc shapeFactorFunc( funcName.c_str(), funcName.c_str(),

roofit/histfactory/src/Measurement.cxx

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -860,9 +860,18 @@ void Sample::AddHistoFactor(const HistoFactor &Factor)
860860
void Sample::AddShapeFactor(std::string SysName)
861861
{
862862

863-
ShapeFactor factor;
864-
factor.SetName(SysName);
865-
fShapeFactorList.push_back(factor);
863+
fShapeFactorList.emplace_back();
864+
fShapeFactorList.back().SetName(SysName);
865+
}
866+
867+
void Sample::AddShapeFactor(std::string SysName, double Val, double Low, double High)
868+
{
869+
870+
fShapeFactorList.emplace_back();
871+
fShapeFactorList.back().SetName(SysName);
872+
fShapeFactorList.back().SetVal(Val);
873+
fShapeFactorList.back().SetLow(Low);
874+
fShapeFactorList.back().SetHigh(High);
866875
}
867876

868877
void Sample::AddShapeFactor(const ShapeFactor &Factor)
@@ -1725,6 +1734,8 @@ void ShapeFactor::Print(std::ostream &stream) const
17251734
<< " Shape Hist Name: " << fHistoNameHigh << " Shape Hist Path Name: " << fHistoPathHigh
17261735
<< " Shape Hist FileName: " << fInputFileHigh << std::endl;
17271736
}
1737+
// Print value and range in RooRealVar style
1738+
stream << "\t \t Value: " << GetVal() << " L(" << GetLow() << " - " << GetHigh() << ")\n";
17281739

17291740
if (fConstant) {
17301741
stream << "\t \t ( Constant ): " << std::endl;
@@ -1756,6 +1767,9 @@ void ShapeFactor::PrintXML(std::ostream &xml) const
17561767
<< " HistoName=\"" << GetHistoName() << "\" "
17571768
<< " HistoPath=\"" << GetHistoPath() << "\" ";
17581769
}
1770+
xml << " Val=\"" << GetVal() << "\" "
1771+
<< " High=\"" << GetHigh() << "\" "
1772+
<< " Low=\"" << GetLow() << "\" ";
17591773
xml << " /> " << std::endl;
17601774
}
17611775

roofit/histfactory/test/testHistFactory.cxx

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <RooStats/HistFactory/Measurement.h>
66
#include <RooStats/HistFactory/MakeModelAndMeasurementsFast.h>
7+
#include <RooStats/HistFactory/ConfigParser.h>
78
#include <RooFit/ModelConfig.h>
89

910
#include <RooFitHS3/JSONIO.h>
@@ -24,6 +25,7 @@
2425
#include <TROOT.h>
2526
#include <TFile.h>
2627
#include <TCanvas.h>
28+
#include <TSystem.h>
2729
#include <gtest/gtest.h>
2830

2931
#include "../../roofitcore/test/gtest_wrapper.h"
@@ -794,3 +796,135 @@ TEST(HistFactory, HS3ImportShapeFactorModifier)
794796
const std::string js2 = RooJSONFactoryWSTool{wsFromJson}.exportJSONtoString();
795797
EXPECT_EQ(js, js2) << "JSON -> WS -> JSON roundtrip changed the JSON";
796798
}
799+
800+
// Issue #20697: Sample::AddShapeFactor() now allows to set the initial value
801+
// and the range of the ShapeFactor gammas (just like AddNormFactor() does for
802+
// the NormFactors). This is important e.g. for ABCD estimates, where the
803+
// hard-coded default range can cause convergence problems.
804+
//
805+
// These settings need to survive being persisted, so this test checks that the
806+
// value and range make it through:
807+
// 1. a ROOT file round trip (Measurement::writeToFile),
808+
// 2. an XML file round trip (Measurement::PrintXML / ConfigParser), and
809+
// 3. into the actual gamma parameters of the generated workspace.
810+
TEST(HistFactory, ShapeFactorValueAndRange)
811+
{
812+
using namespace RooStats::HistFactory;
813+
RooHelpers::LocalChangeMsgLevel changeMsgLvl(RooFit::WARNING);
814+
815+
// Deliberately use non-default values and a range that differs from the
816+
// hard-coded default of [0, 1000].
817+
const double sfVal = 2.0;
818+
const double sfLow = 0.1;
819+
const double sfHigh = 12.0;
820+
821+
const std::string inputFileName = "TestShapeFactorRange_input.root";
822+
{
823+
TFile f(inputFileName.c_str(), "RECREATE");
824+
auto *data = new TH1D("data", "data", 2, 1, 2);
825+
auto *signal = new TH1D("signal", "signal", 2, 1, 2);
826+
auto *bkg = new TH1D("background", "background", 2, 1, 2);
827+
data->SetBinContent(1, 220);
828+
data->SetBinContent(2, 230);
829+
signal->SetBinContent(1, 10);
830+
signal->SetBinContent(2, 20);
831+
bkg->SetBinContent(1, 200);
832+
bkg->SetBinContent(2, 200);
833+
for (auto *h : {data, signal, bkg})
834+
f.WriteTObject(h);
835+
}
836+
837+
auto makeMeasurement = [&]() {
838+
Measurement meas("meas", "meas");
839+
meas.SetOutputFilePrefix("TestShapeFactorRange");
840+
meas.SetPOI("SigXsecOverSM");
841+
meas.AddConstantParam("Lumi");
842+
meas.SetLumi(1.0);
843+
meas.SetLumiRelErr(0.10);
844+
845+
Channel chan("channel1");
846+
chan.SetData("data", inputFileName);
847+
848+
Sample sig("signal", "signal", inputFileName);
849+
sig.AddNormFactor("SigXsecOverSM", 1, 0, 3);
850+
chan.AddSample(sig);
851+
852+
// The new overload under test: ShapeFactor with custom value and range.
853+
Sample bkg("background", "background", inputFileName);
854+
bkg.AddShapeFactor("bkgShape", sfVal, sfLow, sfHigh);
855+
chan.AddSample(bkg);
856+
857+
meas.AddChannel(chan);
858+
meas.CollectHistograms();
859+
return meas;
860+
};
861+
862+
// Fetch the (single) ShapeFactor stored in a measurement.
863+
auto getShapeFactor = [](Measurement &meas) -> ShapeFactor & {
864+
Channel &chan = meas.GetChannel("channel1");
865+
for (Sample &sample : chan.GetSamples()) {
866+
if (!sample.GetShapeFactorList().empty())
867+
return sample.GetShapeFactorList().front();
868+
}
869+
throw std::runtime_error("ShapeFactor not found in measurement");
870+
};
871+
872+
auto checkShapeFactor = [&](Measurement &meas, const char *context) {
873+
ShapeFactor &sf = getShapeFactor(meas);
874+
EXPECT_DOUBLE_EQ(sf.GetVal(), sfVal) << context;
875+
EXPECT_DOUBLE_EQ(sf.GetLow(), sfLow) << context;
876+
EXPECT_DOUBLE_EQ(sf.GetHigh(), sfHigh) << context;
877+
};
878+
879+
// 0. Sanity check on the in-memory measurement.
880+
{
881+
Measurement meas = makeMeasurement();
882+
checkShapeFactor(meas, "in-memory measurement");
883+
}
884+
885+
// 1. ROOT file round trip.
886+
{
887+
Measurement meas = makeMeasurement();
888+
const std::string rootFileName = "TestShapeFactorRange_meas.root";
889+
{
890+
TFile outFile(rootFileName.c_str(), "RECREATE");
891+
meas.writeToFile(&outFile);
892+
}
893+
TFile inFile(rootFileName.c_str(), "READ");
894+
std::unique_ptr<Measurement> measFromFile{inFile.Get<Measurement>("meas")};
895+
ASSERT_NE(measFromFile, nullptr);
896+
checkShapeFactor(*measFromFile, "ROOT file round trip");
897+
}
898+
899+
// 2. XML file round trip.
900+
{
901+
Measurement meas = makeMeasurement();
902+
const std::string xmlDir = "TestShapeFactorRangeXML";
903+
meas.PrintXML(xmlDir);
904+
905+
// The generated XML files refer to the DTD by relative path, so it has to
906+
// be available next to them for the validating parser to find it.
907+
gSystem->CopyFile(TString::Format("%s/HistFactorySchema.dtd", TROOT::GetEtcDir().Data()),
908+
TString::Format("%s/HistFactorySchema.dtd", xmlDir.c_str()), true);
909+
910+
ConfigParser parser;
911+
std::vector<Measurement> measFromXML = parser.GetMeasurementsFromXML(xmlDir + "/meas.xml");
912+
ASSERT_EQ(measFromXML.size(), 1u);
913+
checkShapeFactor(measFromXML.front(), "XML file round trip");
914+
}
915+
916+
// 3. End to end: the gamma parameters of the workspace pick up the requested
917+
// value and range.
918+
{
919+
Measurement meas = makeMeasurement();
920+
std::unique_ptr<RooWorkspace> ws{MakeModelAndMeasurementFast(meas)};
921+
ASSERT_NE(ws, nullptr);
922+
for (const char *name : {"gamma_bkgShape_bin_0", "gamma_bkgShape_bin_1"}) {
923+
auto *gamma = ws->var(name);
924+
ASSERT_NE(gamma, nullptr) << name;
925+
EXPECT_DOUBLE_EQ(gamma->getVal(), sfVal) << name;
926+
EXPECT_DOUBLE_EQ(gamma->getMin(), sfLow) << name;
927+
EXPECT_DOUBLE_EQ(gamma->getMax(), sfHigh) << name;
928+
}
929+
}
930+
}

0 commit comments

Comments
 (0)