Skip to content

Commit 1a594a3

Browse files
committed
[RF][HS3] Add name-keyed JSONNode cache for RooAbsArgs
Looking up by name in the JSON tree is quite slow and happens often during model import from JSON. Caching the pointers to the JSON nodes corresponding to the RooAbsArgs (distributions and functions) speeds up JSON import significantly. (cherry picked from commit 149cca4)
1 parent 52e1a2b commit 1a594a3

2 files changed

Lines changed: 45 additions & 12 deletions

File tree

roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <map>
2424
#include <stdexcept>
2525
#include <set>
26+
#include <unordered_map>
2627

2728
namespace RooFit {
2829
namespace JSONIO {
@@ -253,5 +254,12 @@ class RooJSONFactoryWSTool {
253254
std::unique_ptr<RooFit::JSONIO::Detail::Domains> _domains;
254255
std::vector<RooAbsArg const *> _serversToExport;
255256
std::vector<RooAbsArg const *> _serversToDelete;
257+
258+
// Name-keyed indices over the top-level "functions" and "distributions"
259+
// sequences of the input JSON. Built once at the start of importAllNodes()
260+
// so that requestImpl() lookups become O(1) instead of an O(N) scan over
261+
// every sibling node.
262+
std::unordered_map<std::string, RooFit::Detail::JSONNode const *> _functionsByName;
263+
std::unordered_map<std::string, RooFit::Detail::JSONNode const *> _distributionsByName;
256264
};
257265
#endif

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -924,12 +924,11 @@ RooAbsPdf *RooJSONFactoryWSTool::requestImpl<RooAbsPdf>(const std::string &objna
924924
{
925925
if (RooAbsPdf *retval = _workspace.pdf(objname))
926926
return retval;
927-
if (const auto &distributionsNode = _rootnodeInput->find("distributions")) {
928-
if (const auto &child = findNamedChild(*distributionsNode, objname)) {
929-
this->importFunction(*child, true);
930-
if (RooAbsPdf *retval = _workspace.pdf(objname))
931-
return retval;
932-
}
927+
auto it = _distributionsByName.find(objname);
928+
if (it != _distributionsByName.end()) {
929+
this->importFunction(*it->second, true);
930+
if (RooAbsPdf *retval = _workspace.pdf(objname))
931+
return retval;
933932
}
934933
return nullptr;
935934
}
@@ -945,12 +944,11 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl<RooAbsReal>(const std::string &obj
945944
return pdf;
946945
if (RooRealVar *var = requestImpl<RooRealVar>(objname))
947946
return var;
948-
if (const auto &functionNode = _rootnodeInput->find("functions")) {
949-
if (const auto &child = findNamedChild(*functionNode, objname)) {
950-
this->importFunction(*child, true);
951-
if (RooAbsReal *retval = _workspace.function(objname))
952-
return retval;
953-
}
947+
auto it = _functionsByName.find(objname);
948+
if (it != _functionsByName.end()) {
949+
this->importFunction(*it->second, true);
950+
if (RooAbsReal *retval = _workspace.function(objname))
951+
return retval;
954952
}
955953
return nullptr;
956954
}
@@ -2175,6 +2173,31 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
21752173

21762174
_attributesNode = findRooFitInternal(*_rootnodeInput, "attributes");
21772175

2176+
// Build name-keyed indices over the "functions" and "distributions"
2177+
// sequences. Without these, every cross-reference resolved during import
2178+
// (e.g. dependencies of a PiecewiseInterpolation, or factory-expression
2179+
// arguments) triggers a linear scan over all sibling nodes via
2180+
// findNamedChild(), which becomes O(N^2) on workspaces with thousands of
2181+
// entries. Populating the maps up-front turns each lookup into O(1).
2182+
_functionsByName.clear();
2183+
_distributionsByName.clear();
2184+
if (auto seq = n.find("functions")) {
2185+
if (seq->is_seq()) {
2186+
_functionsByName.reserve(seq->num_children());
2187+
for (const auto &p : seq->children()) {
2188+
_functionsByName.emplace(RooJSONFactoryWSTool::name(p), &p);
2189+
}
2190+
}
2191+
}
2192+
if (auto seq = n.find("distributions")) {
2193+
if (seq->is_seq()) {
2194+
_distributionsByName.reserve(seq->num_children());
2195+
for (const auto &p : seq->children()) {
2196+
_distributionsByName.emplace(RooJSONFactoryWSTool::name(p), &p);
2197+
}
2198+
}
2199+
}
2200+
21782201
this->importDependants(n);
21792202

21802203
if (auto paramPointsNode = n.find("parameter_points")) {
@@ -2239,6 +2262,8 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
22392262

22402263
_rootnodeInput = nullptr;
22412264
_domains.reset();
2265+
_functionsByName.clear();
2266+
_distributionsByName.clear();
22422267
}
22432268

22442269
/**

0 commit comments

Comments
 (0)