Skip to content

Commit 37adbbe

Browse files
committed
[RF] Implement RooMomentMorphFuncND::compileForNormSet()
Transform the RooMomentMorphFuncND instance to an expended computation graph that can be evaluated by the `RooFit::Evaluator` without missing some internal RooAbsArgs that could be cached. (cherry picked from commit d663e30)
1 parent 0fb869f commit 37adbbe

3 files changed

Lines changed: 146 additions & 0 deletions

File tree

roofit/roofit/inc/LinkDef1.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#pragma link C++ class RooMomentMorphFunc + ;
5757
#pragma link C++ class RooMomentMorphFuncND + ;
5858
#pragma link C++ class RooMomentMorphFuncND::Grid2 + ;
59+
#pragma link C++ class RooFit::Detail::RooMomentMorphFraction + ;
5960
#pragma link C++ class RooSpline+ ;
6061
#pragma link C++ class RooStepFunction+ ;
6162
#pragma link C++ class RooMultiBinomial+ ;

roofit/roofit/inc/RooMomentMorphFuncND.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "RooListProxy.h"
2121
#include "RooArgList.h"
2222
#include "RooBinning.h"
23+
#include "RooFit/Detail/NormalizationHelpers.h"
2324

2425
#include "TMatrixD.h"
2526
#include "TMap.h"
@@ -30,6 +31,12 @@
3031
class RooChangeTracker;
3132
class RooRealSumFunc;
3233

34+
namespace RooFit {
35+
namespace Detail {
36+
class RooMomentMorphFraction;
37+
}
38+
} // namespace RooFit
39+
3340
class RooMomentMorphFuncND : public RooAbsReal {
3441

3542
public:
@@ -129,6 +136,9 @@ class RooMomentMorphFuncND : public RooAbsReal {
129136
double evaluate() const override;
130137
double getValV(const RooArgSet *set = nullptr) const override;
131138

139+
std::unique_ptr<RooAbsArg>
140+
compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext &ctx) const override;
141+
132142
protected:
133143
void initialize();
134144

@@ -139,6 +149,7 @@ class RooMomentMorphFuncND : public RooAbsReal {
139149

140150
friend class CacheElem;
141151
friend class Grid2;
152+
friend class RooFit::Detail::RooMomentMorphFraction;
142153

143154
mutable RooObjCacheManager _cacheMgr; ///<! Transient cache manager
144155
mutable RooArgSet *_curNormSet = nullptr; ///<! Transient cache manager
@@ -162,4 +173,32 @@ class RooMomentMorphFuncND : public RooAbsReal {
162173
ClassDefOverride(RooMomentMorphFuncND, 4);
163174
};
164175

176+
namespace RooFit {
177+
namespace Detail {
178+
179+
/// Helper compute-graph node that exposes one of the morph mixing fractions to
180+
/// the RooFit::Evaluator. It re-runs RooMomentMorphFuncND::CacheElem::calculateFractions
181+
/// only when the morph parameters change, then returns the cached fraction value
182+
/// for this index.
183+
class RooMomentMorphFraction : public RooAbsReal {
184+
public:
185+
RooMomentMorphFraction() {}
186+
RooMomentMorphFraction(const char *name, const char *title, RooMomentMorphFuncND const &parent, int index);
187+
RooMomentMorphFraction(RooMomentMorphFraction const &other, const char *name = nullptr);
188+
TObject *clone(const char *newname) const override { return new RooMomentMorphFraction(*this, newname); }
189+
190+
protected:
191+
double evaluate() const override;
192+
193+
private:
194+
RooListProxy _parList;
195+
const RooMomentMorphFuncND *_parent = nullptr; ///<! morph that owns the cache (not owned)
196+
int _index = 0;
197+
198+
ClassDefOverride(RooFit::Detail::RooMomentMorphFraction, 0);
199+
};
200+
201+
} // namespace Detail
202+
} // namespace RooFit
203+
165204
#endif

roofit/roofit/src/RooMomentMorphFuncND.cxx

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,112 @@ void RooMomentMorphFuncND::Grid2::addPdf(const RooMomentMorphFuncND::Base_t &pdf
345345
_nref.push_back(thisBoundaryCoordinates);
346346
}
347347

348+
//_____________________________________________________________________________
349+
std::unique_ptr<RooAbsArg>
350+
RooMomentMorphFuncND::compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext &ctx) const
351+
{
352+
// Build (or fetch) the cache that holds the morph's internal compute graph:
353+
// moment integrals, slope/offset formulas, RooLinearVar transforms, the per-pdf
354+
// RooHistPdf clones, and the final RooAddPdf/RooRealSumFunc sum.
355+
CacheElem *cache = getCache(&normSet);
356+
357+
// Make sure fractions hold sensible initial values (the replacement nodes
358+
// below will keep them in sync going forward).
359+
cache->calculateFractions(*this, false);
360+
361+
// The cache's subtree carries ORIGNAME: attributes left over from the
362+
// RooCustomizer that built the per-pdf transformed RooHistPdf clones (it
363+
// tagged each transVar with ORIGNAME:<obs>). RooCustomizer::build calls
364+
// redirectServers with nameChange=true and will throw if it sees several
365+
// candidates with the same ORIGNAME:* attribute, which is exactly what
366+
// happens here once we ask it to clone the subtree. Strip those stale
367+
// markers before cloning.
368+
{
369+
RooArgSet branches;
370+
cache->_sum->branchNodeServerList(&branches);
371+
branches.add(*cache->_sum);
372+
for (auto *b : branches) {
373+
std::vector<std::string> toRemove;
374+
for (auto const &attr : b->attributes()) {
375+
if (attr.rfind("ORIGNAME:", 0) == 0)
376+
toRemove.push_back(attr);
377+
}
378+
for (auto const &attr : toRemove)
379+
b->setAttribute(attr.c_str(), false);
380+
}
381+
}
382+
383+
// Replace each of the imperatively-updated fraction RooRealVars with a
384+
// RooMomentMorphFraction node. This puts the fraction-recomputation inside
385+
// the Evaluator's compute graph, so the moment integrals (which the
386+
// fractions depend on transitively via the slope/offset formulas) become
387+
// sibling nodes that the Evaluator caches once per minimization step.
388+
RooArgList newFractions;
389+
const int nFrac = cache->_frac.size();
390+
for (int i = 0; i < nFrac; ++i) {
391+
auto frac = static_cast<RooRealVar *>(cache->_frac.at(i));
392+
std::string newName = std::string{frac->GetName()} + "_compiled";
393+
newFractions.addOwned(
394+
std::make_unique<RooFit::Detail::RooMomentMorphFraction>(newName.c_str(), frac->GetTitle(), *this, i));
395+
}
396+
397+
RooArgSet clonedBranches;
398+
RooCustomizer cust(*cache->_sum, "compiled");
399+
cust.setCloneBranchSet(clonedBranches);
400+
for (int i = 0; i < nFrac; ++i) {
401+
cust.replaceArg(*cache->_frac.at(i), newFractions[i]);
402+
}
403+
404+
// RooCustomizer::build() already transfers ownership of the cloned branches
405+
// (everything in `clonedBranches` except the returned top node) to the new
406+
// top node's owned-components list, so we only have to attach the
407+
// newly-created fraction nodes here.
408+
std::unique_ptr<RooAbsReal> newSum{static_cast<RooAbsReal *>(cust.build())};
409+
newSum->addOwnedComponents(std::move(newFractions));
410+
411+
// Mark every node in the freshly-cloned subtree as already compiled, so
412+
// the recursive compileServers call below doesn't try to re-clone any of
413+
// them. The leaves we still want compiled (the morph parameters and the
414+
// observables) are reachable from these nodes' own server lists and will
415+
// be visited.
416+
ctx.markAsCompiled(*newSum);
417+
RooArgSet allBranches;
418+
newSum->branchNodeServerList(&allBranches);
419+
for (auto *b : allBranches) {
420+
ctx.markAsCompiled(*b);
421+
}
422+
ctx.compileServers(*newSum, normSet);
423+
424+
return newSum;
425+
}
426+
427+
namespace RooFit {
428+
namespace Detail {
429+
430+
RooMomentMorphFraction::RooMomentMorphFraction(const char *name, const char *title, RooMomentMorphFuncND const &parent,
431+
int index)
432+
: RooAbsReal(name, title), _parList("parList", "parList", this), _parent(&parent), _index(index)
433+
{
434+
_parList.add(parent._parList);
435+
}
436+
437+
RooMomentMorphFraction::RooMomentMorphFraction(RooMomentMorphFraction const &other, const char *name)
438+
: RooAbsReal(other, name), _parList("parList", this, other._parList), _parent(other._parent), _index(other._index)
439+
{
440+
}
441+
442+
double RooMomentMorphFraction::evaluate() const
443+
{
444+
auto *cache = _parent->getCache(nullptr);
445+
if (cache->_tracker->hasChanged(true)) {
446+
cache->calculateFractions(*_parent, false);
447+
}
448+
return cache->frac(_index)->getVal();
449+
}
450+
451+
} // namespace Detail
452+
} // namespace RooFit
453+
348454
//_____________________________________________________________________________
349455
RooMomentMorphFuncND::CacheElem *RooMomentMorphFuncND::getCache(const RooArgSet * /*nset*/) const
350456
{

0 commit comments

Comments
 (0)