Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <map>
#include <stdexcept>
#include <set>
#include <unordered_map>

namespace RooFit {
namespace JSONIO {
Expand Down Expand Up @@ -253,5 +254,12 @@ class RooJSONFactoryWSTool {
std::unique_ptr<RooFit::JSONIO::Detail::Domains> _domains;
std::vector<RooAbsArg const *> _serversToExport;
std::vector<RooAbsArg const *> _serversToDelete;

// Name-keyed indices over the top-level "functions" and "distributions"
// sequences of the input JSON. Built once at the start of importAllNodes()
// so that requestImpl() lookups become O(1) on average instead of an O(N)
// scan over every sibling node.
//
// The mapped pointers are non-owning: they point at child nodes of the JSON
// tree handed to importAllNodes(), and both maps are cleared again before
// importAllNodes() returns. They must therefore not be consulted outside of
// a running import. If the same name occurs more than once in a sequence,
// the first occurrence is the one that is indexed.
std::unordered_map<std::string, RooFit::Detail::JSONNode const *> _functionsByName;
std::unordered_map<std::string, RooFit::Detail::JSONNode const *> _distributionsByName;
};
#endif
43 changes: 43 additions & 0 deletions roofit/hs3/src/JSONFactories_RooFitCore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
#include <RooLognormal.h>
#include <RooMultiVarGaussian.h>
#include <RooStats/HistFactory/ParamHistFunc.h>
#include <RooAddition.h>
#include <RooProduct.h>
#include <RooProdPdf.h>
#include <RooPoisson.h>
#include <RooPolynomial.h>
#include <RooPolyVar.h>
Expand Down Expand Up @@ -135,6 +138,43 @@ class RooFormulaArgFactory : public RooFit::JSONIO::Importer {
}
};

// Fast-path importers for RooProduct, RooAddition, and RooProdPdf that
// bypass the generic factory-expression mechanism. The default path
// generates a string expression and passes it to gROOT->ProcessLineFast(),
// which invokes the Cling JIT for every single call. For workspaces with
// thousands of product/sum nodes (a common shape for HistFactory models)
// that JIT cost dominates JSON import time. Constructing the RooFit object
// directly here keeps the work O(N) of cheap C++ calls.
class RooProductFactory : public RooFit::JSONIO::Importer {
public:
   /// Create a RooProduct in the workspace directly from a JSON node.
   ///
   /// \param tool  Import tool used to resolve the factors and to emplace the
   ///              new object into the workspace.
   /// \param p     JSON node describing the product; its "factors" entry
   ///              lists the RooAbsReal terms to multiply.
   /// \return Always `true`.
   bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
   {
      // Resolve (and, if necessary, import) every factor first ...
      auto factorList = tool->requestArgList<RooAbsReal>(p, "factors");
      // ... then build the product under the name stored in the node.
      const std::string productName(RooJSONFactoryWSTool::name(p));
      tool->wsEmplace<RooProduct>(productName, factorList);
      return true;
   }
};

class RooProdPdfFactory : public RooFit::JSONIO::Importer {
public:
   /// Create a RooProdPdf in the workspace directly from a JSON node.
   ///
   /// \param tool  Import tool used to resolve the factor pdfs and to emplace
   ///              the new object into the workspace.
   /// \param p     JSON node describing the product pdf; its "factors" entry
   ///              lists the RooAbsPdf terms to multiply.
   /// \return Always `true`.
   bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
   {
      // Resolve (and, if necessary, import) every factor pdf first ...
      auto pdfList = tool->requestArgList<RooAbsPdf>(p, "factors");
      // ... then build the product pdf under the name stored in the node.
      const std::string pdfName(RooJSONFactoryWSTool::name(p));
      tool->wsEmplace<RooProdPdf>(pdfName, pdfList);
      return true;
   }
};

class RooAdditionFactory : public RooFit::JSONIO::Importer {
public:
   /// Create a RooAddition in the workspace directly from a JSON node.
   ///
   /// \param tool  Import tool used to resolve the summands and to emplace
   ///              the new object into the workspace.
   /// \param p     JSON node describing the sum; its "summands" entry lists
   ///              the RooAbsReal terms to add up.
   /// \return Always `true`.
   bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
   {
      // Resolve (and, if necessary, import) every summand first ...
      auto summandList = tool->requestArgList<RooAbsReal>(p, "summands");
      // ... then build the sum under the name stored in the node.
      const std::string sumName(RooJSONFactoryWSTool::name(p));
      tool->wsEmplace<RooAddition>(sumName, summandList);
      return true;
   }
};

class RooAddPdfFactory : public RooFit::JSONIO::Importer {
public:
bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
Expand Down Expand Up @@ -1195,6 +1235,9 @@ DEFINE_EXPORTER_KEY(RooSplineStreamer, "spline");
STATIC_EXECUTE([]() {
using namespace RooFit::JSONIO;

registerImporter<RooProductFactory>("product", false);
registerImporter<RooProdPdfFactory>("product_dist", false);
registerImporter<RooAdditionFactory>("sum", false);
registerImporter<RooAddPdfFactory>("mixture_dist", false);
registerImporter<RooAddModelFactory>("mixture_model", false);
registerImporter<RooBinSamplingPdfFactory>("binsampling_dist", false);
Expand Down
49 changes: 37 additions & 12 deletions roofit/hs3/src/RooJSONFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -924,12 +924,11 @@ RooAbsPdf *RooJSONFactoryWSTool::requestImpl<RooAbsPdf>(const std::string &objna
{
if (RooAbsPdf *retval = _workspace.pdf(objname))
return retval;
if (const auto &distributionsNode = _rootnodeInput->find("distributions")) {
if (const auto &child = findNamedChild(*distributionsNode, objname)) {
this->importFunction(*child, true);
if (RooAbsPdf *retval = _workspace.pdf(objname))
return retval;
}
auto it = _distributionsByName.find(objname);
if (it != _distributionsByName.end()) {
this->importFunction(*it->second, true);
if (RooAbsPdf *retval = _workspace.pdf(objname))
return retval;
}
return nullptr;
}
Expand All @@ -945,12 +944,11 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl<RooAbsReal>(const std::string &obj
return pdf;
if (RooRealVar *var = requestImpl<RooRealVar>(objname))
return var;
if (const auto &functionNode = _rootnodeInput->find("functions")) {
if (const auto &child = findNamedChild(*functionNode, objname)) {
this->importFunction(*child, true);
if (RooAbsReal *retval = _workspace.function(objname))
return retval;
}
auto it = _functionsByName.find(objname);
if (it != _functionsByName.end()) {
this->importFunction(*it->second, true);
if (RooAbsReal *retval = _workspace.function(objname))
return retval;
}
return nullptr;
}
Expand Down Expand Up @@ -2175,6 +2173,31 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)

_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");

// Build name-keyed indices over the "functions" and "distributions"
// sequences. Without these, every cross-reference resolved during import
// (e.g. dependencies of a PiecewiseInterpolation, or factory-expression
// arguments) triggers a linear scan over all sibling nodes via
// findNamedChild(), which becomes O(N^2) on workspaces with thousands of
// entries. Populating the maps up-front turns each lookup into O(1).
_functionsByName.clear();
_distributionsByName.clear();
if (auto seq = n.find("functions")) {
if (seq->is_seq()) {
_functionsByName.reserve(seq->num_children());
for (const auto &p : seq->children()) {
_functionsByName.emplace(RooJSONFactoryWSTool::name(p), &p);
}
}
}
if (auto seq = n.find("distributions")) {
if (seq->is_seq()) {
_distributionsByName.reserve(seq->num_children());
for (const auto &p : seq->children()) {
_distributionsByName.emplace(RooJSONFactoryWSTool::name(p), &p);
}
}
}

this->importDependants(n);

if (auto paramPointsNode = n.find("parameter_points")) {
Expand Down Expand Up @@ -2239,6 +2262,8 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)

_rootnodeInput = nullptr;
_domains.reset();
_functionsByName.clear();
_distributionsByName.clear();
}

/**
Expand Down
22 changes: 19 additions & 3 deletions roofit/roofitcore/src/RooFactoryWSTool.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ namespace {
return false;
}

pair<list<string>,unsigned int> ctorArgs(const char* classname, std::size_t nPassedArgs) {
pair<list<string>,unsigned int> ctorArgsImpl(const char* classname, std::size_t nPassedArgs) {
// Utility function for RooFactoryWSTool. Return arguments of 'first' non-default, non-copy constructor of any RooAbsArg
// derived class. Only constructors that start with two `const char*` arguments (for name and title) are considered
// The returned object contains
Expand Down Expand Up @@ -269,6 +269,22 @@ namespace {
gInterpreter->ClassInfo_Delete(cls);
return pair<list<string>,unsigned int>(ret,nreq);
}

pair<list<string>,unsigned int> const & ctorArgs(const char* classname, std::size_t nPassedArgs) {
  // Memoizing front-end for ctorArgsImpl().
  //
  // ctorArgsImpl() drives the Cling interpreter to enumerate every constructor
  // of the class. When the factory is invoked thousands of times (e.g. during
  // HS3 JSON import of a large workspace), repeating that lookup dominates the
  // import time, so successful lookups are cached per (classname, nPassedArgs).
  //
  // A result with an empty argument list means that no suitable constructor
  // was found. Such a failure is deliberately NOT treated as final: the class
  // may only become known to the interpreter later on (presumably after a
  // library is loaded or the class is declared interactively), so an empty
  // entry is recomputed on every call instead of being served from the cache.
  //
  // The returned reference points into the cache. std::map never relocates
  // its elements, so the reference stays valid across later insertions.
  // NOTE(review): the cache is not synchronized; this assumes the factory is
  // not used from several threads at once - confirm against callers.
  static std::map<pair<string, std::size_t>, pair<list<string>, unsigned int>> cache;
  auto key = std::make_pair(string(classname), nPassedArgs);
  auto it = cache.find(key);
  if (it == cache.end()) {
    // First request for this (class, arity): query the interpreter once.
    it = cache.emplace(std::move(key), ctorArgsImpl(classname, nPassedArgs)).first;
  } else if (it->second.first.empty()) {
    // Previous attempt failed: retry so a stale negative result cannot
    // permanently hide a class that has become available in the meantime.
    it->second = ctorArgsImpl(classname, nPassedArgs);
  }
  return it->second;
}
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -321,7 +337,7 @@ RooAbsArg* RooFactoryWSTool::createArg(const char* className, const char* objNam
_args.push_back(tmp.substr(start_tok, end_tok));

// Try Cling interface
pair<list<string>,unsigned int> ca = ctorArgs(className,_args.size()+2) ;
pair<list<string>,unsigned int> const & ca = ctorArgs(className,_args.size()+2) ;
if (ca.first.empty()) {
coutE(ObjectHandling) << "RooFactoryWSTool::createArg() ERROR no suitable constructor found for class " << className << std::endl ;
logError() ;
Expand Down Expand Up @@ -352,7 +368,7 @@ RooAbsArg* RooFactoryWSTool::createArg(const char* className, const char* objNam

try {
Int_t i(0) ;
list<string>::iterator ti = ca.first.begin() ; ++ti ; ++ti ;
list<string>::const_iterator ti = ca.first.begin() ; ++ti ; ++ti ;
for (vector<string>::iterator ai = _args.begin() ; ai != _args.end() ; ++ai,++ti,++i) {
if ((*ti)=="RooAbsReal&" || (*ti)=="const RooAbsReal&" || (*ti)=="RooAbsReal::Ref") {
RooFactoryWSTool::as_FUNC(i) ;
Expand Down
Loading