Skip to content

Commit 986e600

Browse files
committed
[df] Centrally propagate bulk size through RLoopManager
The RLoopManager decides the current bulk size that all nodes of the computation graph must adhere to. Currently this is set to one; in the future it may vary.
1 parent 7b7700a commit 986e600

15 files changed

Lines changed: 90 additions & 66 deletions

tree/dataframe/inc/ROOT/RDF/RAction.hxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ public:
119119
(void)idx; // avoid unused parameter warning (gcc 12.1)
120120
}
121121

122-
void Run(unsigned int slot, Long64_t entry) final
122+
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
123123
{
124-
const auto mask = fPrevNode.CheckFilters(slot, entry);
124+
const auto mask = fPrevNode.CheckFilters(slot, bulkBeginEntry, bulkSize);
125125
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
126126

127127
// Assume 1-size bulk for now

tree/dataframe/inc/ROOT/RDF/RActionBase.hxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ public:
6363
RColumnRegister &GetColRegister() { return fColRegister; }
6464
RLoopManager *GetLoopManager() { return fLoopManager; }
6565
unsigned int GetNSlots() const { return fNSlots; }
66-
virtual void Run(unsigned int slot, Long64_t entry) = 0;
6766
virtual void Initialize() = 0;
6867
virtual void InitSlot(TTreeReader *r, unsigned int slot) = 0;
6968
virtual void TriggerChildrenCount() = 0;
@@ -92,6 +91,8 @@ public:
9291

9392
virtual std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) = 0;
9493
virtual std::unique_ptr<RActionBase> CloneAction(void *newResult) = 0;
94+
95+
virtual void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) = 0;
9596
};
9697
} // namespace RDF
9798
} // namespace Internal

tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,36 +191,35 @@ public:
191191
fHelper.Exec(slot, untypedValues);
192192
}
193193

194-
void Run(unsigned int slot, Long64_t entry) final
194+
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
195195
{
196196
if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
197197
// check if entry passes all filters
198198
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> filterPassed(fPrevNodes.size(), 1ul);
199199
for (unsigned int variation = 0; variation < fPrevNodes.size(); ++variation) {
200-
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, entry);
200+
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, bulkBeginEntry, bulkSize);
201201
}
202202

203203
// Currently, every event where any of nominal or variations pass gets written to the output.
204204
// This logic could be extended for different use cases if the need arises.
205-
// Assume 1-size bulk for now
206205
if (std::any_of(filterPassed.begin(), filterPassed.end(),
207206
[](const ROOT::Internal::RDF::RMaskedEntryRange &val) { return val[0]; })) {
208207
// TODO: Don't allocate
209208
std::vector<void *> untypedValues;
210209
auto nReaders = fValues[slot].size();
211210
untypedValues.reserve(nReaders);
212-
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry](auto *v) {
213-
v->Load(ROOT::Internal::RDF::RMaskedEntryRange{1ul, true, static_cast<std::uint64_t>(entry)});
211+
std::for_each(fValues[slot].begin(), fValues[slot].end(), [bulkBeginEntry, bulkSize](auto *v) {
212+
v->Load(
213+
ROOT::Internal::RDF::RMaskedEntryRange{bulkSize, true, static_cast<std::uint64_t>(bulkBeginEntry)});
214214
});
215215
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
216216
untypedValues.push_back(GetValue(slot, readerIdx, /*idx=*/0u));
217217

218218
fHelper.Exec(slot, untypedValues, filterPassed);
219219
}
220220
} else {
221-
const auto mask = fPrevNodes.front()->CheckFilters(slot, entry);
221+
const auto mask = fPrevNodes.front()->CheckFilters(slot, bulkBeginEntry, bulkSize);
222222
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
223-
// Assume 1-size bulk for now
224223
if (mask[0])
225224
CallExec(slot, /*idx=*/0u);
226225
}

tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public:
112112

113113
// Assume 1-size bulk for now
114114
fValues[slot]->Load(mask);
115-
const std::size_t bulkSize = 1;
115+
const std::size_t bulkSize = fLoopManager->GetCurrentBulkSize();
116116
for (std::size_t i = 0; i < bulkSize; ++i) {
117117
if (mask[i])
118118
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, i);

tree/dataframe/inc/ROOT/RDF/RFilter.hxx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,16 @@ public:
9393
fLoopManager->Deregister(this);
9494
}
9595

96-
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int slot, Long64_t entry) final
96+
ROOT::Internal::RDF::RMaskedEntryRange
97+
CheckFilters(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
9798
{
9899
auto &cachedResults = fCachedResults[slot * RDFInternal::CacheLineStep<ROOT::RVec<bool>>()];
99-
if (entry == fLastCheckedEntry[slot * ROOT::Internal::RDF::CacheLineStep<Long64_t>()])
100-
return {cachedResults, static_cast<std::uint64_t>(entry)};
100+
if (bulkBeginEntry == fLastCheckedEntry[slot * ROOT::Internal::RDF::CacheLineStep<Long64_t>()])
101+
return {cachedResults, static_cast<std::uint64_t>(bulkBeginEntry)};
101102

102-
auto mask = fPrevNode.CheckFilters(slot, entry);
103+
auto mask = fPrevNode.CheckFilters(slot, bulkBeginEntry, bulkSize);
103104
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
104-
// Assume 1-size bulk for now
105-
const std::size_t bulkSize{1};
105+
106106
std::size_t accepted{0};
107107
std::size_t rejected{0};
108108
cachedResults.clear();
@@ -116,11 +116,11 @@ public:
116116
++rejected;
117117
}
118118
}
119-
fLastCheckedEntry[slot * ROOT::Internal::RDF::CacheLineStep<Long64_t>()] = entry;
119+
fLastCheckedEntry[slot * ROOT::Internal::RDF::CacheLineStep<Long64_t>()] = bulkBeginEntry;
120120
fAccepted[slot * RDFInternal::CacheLineStep<ULong64_t>()] += accepted;
121121
fRejected[slot * RDFInternal::CacheLineStep<ULong64_t>()] += rejected;
122122

123-
return {cachedResults, static_cast<std::uint64_t>(entry)};
123+
return {cachedResults, static_cast<std::uint64_t>(bulkBeginEntry)};
124124
}
125125

126126
template <typename... ColTypes, std::size_t... S>

tree/dataframe/inc/ROOT/RDF/RFilterWithMissingValues.hxx

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,16 @@ public:
9696
fLoopManager->EraseSuppressErrorsForMissingBranch(fColumnNames[0]);
9797
}
9898

99-
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int slot, Long64_t entry) final
99+
ROOT::Internal::RDF::RMaskedEntryRange
100+
CheckFilters(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
100101
{
101102
constexpr static auto cacheLineStepLong64_t = RDFInternal::CacheLineStep<Long64_t>();
102103
constexpr static auto cacheLineStepULong64_t = RDFInternal::CacheLineStep<ULong64_t>();
103104

104-
if (entry == fLastCheckedEntry[slot * cacheLineStepLong64_t])
105-
return {fCachedResults[slot], static_cast<std::uint64_t>(entry)};
105+
if (bulkBeginEntry == fLastCheckedEntry[slot * cacheLineStepLong64_t])
106+
return {fCachedResults[slot], static_cast<std::uint64_t>(bulkBeginEntry)};
106107

107-
// Assume 1-size bulk for now
108-
const std::size_t bulkSize{1};
109-
auto mask = fPrevNodePtr->CheckFilters(slot, entry);
108+
auto mask = fPrevNodePtr->CheckFilters(slot, bulkBeginEntry, bulkSize);
110109

111110
fValues[slot]->Load(mask);
112111

@@ -127,9 +126,9 @@ public:
127126
}
128127
fAccepted[slot * cacheLineStepULong64_t] += accepted;
129128
fRejected[slot * cacheLineStepULong64_t] += rejected;
130-
fLastCheckedEntry[slot * cacheLineStepLong64_t] = entry;
129+
fLastCheckedEntry[slot * cacheLineStepLong64_t] = bulkBeginEntry;
131130

132-
return {fCachedResults[slot], static_cast<std::uint64_t>(entry)};
131+
return {fCachedResults[slot], static_cast<std::uint64_t>(bulkBeginEntry)};
133132
}
134133

135134
void InitSlot(TTreeReader *r, unsigned int slot) final

tree/dataframe/inc/ROOT/RDF/RJittedAction.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public:
5050

5151
void SetAction(std::unique_ptr<RActionBase> a) { fConcreteAction = std::move(a); }
5252

53-
void Run(unsigned int slot, Long64_t entry) final;
53+
void Run(unsigned int, Long64_t, std::size_t) final;
5454
void Initialize() final;
5555
void InitSlot(TTreeReader *r, unsigned int slot) final;
5656
void TriggerChildrenCount() final;

tree/dataframe/inc/ROOT/RDF/RJittedFilter.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public:
5555
void SetFilter(std::unique_ptr<RFilterBase> f);
5656

5757
void InitSlot(TTreeReader *r, unsigned int slot) final;
58-
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int, Long64_t) final;
58+
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int, Long64_t, std::size_t) final;
5959
void Report(ROOT::RDF::RCutFlowReport &) const final;
6060
void PartialReport(ROOT::RDF::RCutFlowReport &) const final;
6161
void FillReport(ROOT::RDF::RCutFlowReport &) const final;

tree/dataframe/inc/ROOT/RDF/RLoopManager.hxx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class RLoopManager : public RNodeBase {
181181
void RunEmptySource();
182182
void RunDataSourceMT();
183183
void RunDataSource();
184-
void RunAndCheckFilters(unsigned int slot, Long64_t entry);
184+
void RunAndCheckFilters(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize);
185185
void InitNodeSlots(TTreeReader *r, unsigned int slot);
186186
void InitNodes();
187187
void CleanUpNodes();
@@ -224,6 +224,9 @@ class RLoopManager : public RNodeBase {
224224
std::vector<DeferredJitCall> fJitHelperCalls{};
225225
std::hash<std::string> fStringHasher{};
226226

227+
// Assume 1-size bulk for now
228+
std::size_t fCurrentBulkSize{1};
229+
227230
public:
228231
RLoopManager(const ColumnNames_t &defaultColumns = {});
229232
RLoopManager(TTree *tree, const ColumnNames_t &defaultBranches);
@@ -256,7 +259,7 @@ public:
256259
void Deregister(RDefineBase *definePtr);
257260
void Register(RDFInternal::RVariationBase *varPtr);
258261
void Deregister(RDFInternal::RVariationBase *varPtr);
259-
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int, Long64_t) final;
262+
ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int, Long64_t, std::size_t) final;
260263
unsigned int GetNSlots() const { return fNSlots; }
261264
void Report(ROOT::RDF::RCutFlowReport &rep) const final;
262265
/// End of recursive chain of calls, does nothing
@@ -336,6 +339,8 @@ public:
336339
/// The task run by every thread on an entry range (known by the input TTreeReader), for the TTree data source.
337340
void
338341
TTreeThreadTask(TTreeReader &treeReader, ROOT::Internal::RSlotStack &slotStack, std::atomic<ULong64_t> &entryCount);
342+
343+
std::size_t GetCurrentBulkSize() { return fCurrentBulkSize; }
339344
};
340345

341346
/// \brief Create an RLoopManager that reads a TChain.

tree/dataframe/inc/ROOT/RDF/RNodeBase.hxx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ public:
6565
RNodeBase &operator=(RNodeBase &&) = delete;
6666
virtual ~RNodeBase() = default;
6767

68-
virtual ROOT::Internal::RDF::RMaskedEntryRange CheckFilters(unsigned int, Long64_t) = 0;
6968
virtual void Report(ROOT::RDF::RCutFlowReport &) const = 0;
7069
virtual void PartialReport(ROOT::RDF::RCutFlowReport &) const = 0;
7170
virtual void IncrChildrenCount() = 0;
@@ -92,6 +91,9 @@ public:
9291
"GetVariedFilter was called on a node type that does not implement it. This should never happen.");
9392
return nullptr;
9493
}
94+
95+
virtual ROOT::Internal::RDF::RMaskedEntryRange
96+
CheckFilters(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) = 0;
9597
};
9698
} // ns RDF
9799
} // ns Detail

0 commit comments

Comments
 (0)