Skip to content

Commit e2cfc9d

Browse files
committed
[df] Prepare for bulk filters
Nodes of the computation graph that may select which entries are valid are modified to accommodate multiple entries being evaluated at the same time. For now, the size of the bulk is set to one. Note that this commit does not touch the reading of data in any way.
1 parent cc9b516 commit e2cfc9d

43 files changed

Lines changed: 289 additions & 195 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

roottest/root/dataframe/testIMT.cxx

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,18 @@ void getTracks(unsigned int mu, FourVectors& tracks) {
5555
// This makes the example stand-alone
5656
void FillTree(const char* filename, const char* treeName) {
5757
if (!gSystem->AccessPathName(filename)) return;
58-
TFile f(filename,"RECREATE");
59-
TTree t(treeName,treeName);
58+
auto f = std::make_unique<TFile>(filename, "RECREATE");
59+
auto t = std::make_unique<TTree>(treeName, treeName);
6060
double b1;
6161
int b2;
6262
std::vector<FourVector> tracks;
6363
std::vector<double> dv {-1,2,3,4};
6464
std::list<int> sl {1,2,3,4};
65-
t.Branch("b1", &b1);
66-
t.Branch("b2", &b2);
67-
t.Branch("tracks", &tracks);
68-
t.Branch("dv", &dv);
69-
t.Branch("sl", &sl);
65+
t->Branch("b1", &b1);
66+
t->Branch("b2", &b2);
67+
t->Branch("tracks", &tracks);
68+
t->Branch("dv", &dv);
69+
t->Branch("sl", &sl);
7070

7171
int nevts = 16000;
7272
for(int i = 0; i < nevts; ++i) {
@@ -77,11 +77,9 @@ void FillTree(const char* filename, const char* treeName) {
7777

7878
dv.emplace_back(i);
7979
sl.emplace_back(i);
80-
t.Fill();
80+
t->Fill();
8181
}
82-
t.Write();
83-
f.Close();
84-
return;
82+
f->Write();
8583
}
8684

8785
auto fileName = "testIMT.root";

tree/dataframe/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame
8282
ROOT/RDF/RJittedVariation.hxx
8383
ROOT/RDF/RLazyDSImpl.hxx
8484
ROOT/RDF/RLoopManager.hxx
85+
ROOT/RDF/RMaskedEntryRange.hxx
8586
ROOT/RDF/RMergeableValue.hxx
8687
ROOT/RDF/RMetaData.hxx
8788
ROOT/RDF/RNodeBase.hxx

tree/dataframe/inc/ROOT/RCsvDS.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace ROOT::Internal::RDF {
2828
class R__CLING_PTRCHECK(off) RCsvDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
2929
void *fValuePtr;
3030
void *GetImpl(std::size_t) final { return fValuePtr; }
31-
void LoadImpl(Long64_t, bool) final {}
31+
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}
3232

3333
public:
3434
RCsvDSColumnReader(void *valuePtr) : fValuePtr(valuePtr) {}

tree/dataframe/inc/ROOT/RDF/RAction.hxx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,10 @@ public:
122122
void Run(unsigned int slot, Long64_t entry) final
123123
{
124124
const auto mask = fPrevNode.CheckFilters(slot, entry);
125-
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); });
125+
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
126126

127-
if (mask)
127+
// Assume 1-size bulk for now
128+
if (mask[0])
128129
CallExec(slot, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{});
129130
}
130131

tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,28 +195,33 @@ public:
195195
{
196196
if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
197197
// check if entry passes all filters
198-
std::vector<bool> filterPassed(fPrevNodes.size(), false);
198+
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> filterPassed(fPrevNodes.size(), 1ul);
199199
for (unsigned int variation = 0; variation < fPrevNodes.size(); ++variation) {
200200
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, entry);
201201
}
202202

203203
// Currently, every event where any of nominal or variations pass gets written to the output.
204204
// This logic could be extended for different use cases if the need arises.
205-
if (std::any_of(filterPassed.begin(), filterPassed.end(), [](bool val) { return val; })) {
205+
// Assume 1-size bulk for now
206+
if (std::any_of(filterPassed.begin(), filterPassed.end(),
207+
[](const ROOT::Internal::RDF::RMaskedEntryRange &val) { return val[0]; })) {
206208
// TODO: Don't allocate
207209
std::vector<void *> untypedValues;
208210
auto nReaders = fValues[slot].size();
209211
untypedValues.reserve(nReaders);
210-
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry](auto *v) { v->Load(entry, true); });
212+
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry](auto *v) {
213+
v->Load(ROOT::Internal::RDF::RMaskedEntryRange{1ul, true, static_cast<std::uint64_t>(entry)});
214+
});
211215
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
212216
untypedValues.push_back(GetValue(slot, readerIdx, /*idx=*/0u));
213217

214218
fHelper.Exec(slot, untypedValues, filterPassed);
215219
}
216220
} else {
217221
const auto mask = fPrevNodes.front()->CheckFilters(slot, entry);
218-
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); });
219-
if (mask)
222+
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
223+
// Assume 1-size bulk for now
224+
if (mask[0])
220225
CallExec(slot, /*idx=*/0u);
221226
}
222227
}

tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define ROOT_INTERNAL_RDF_RCOLUMNREADERBASE
1313

1414
#include <Rtypes.h>
15+
#include <ROOT/RDF/RMaskedEntryRange.hxx>
1516

1617
namespace ROOT {
1718
namespace Detail {
@@ -26,22 +27,14 @@ This pure virtual class provides a common base class for the different column re
2627
RDSColumnReader.
2728
**/
2829
class R__CLING_PTRCHECK(off) RColumnReaderBase {
29-
Long64_t fLoadedEntry = -1;
3030

3131
public:
3232
virtual ~RColumnReaderBase() = default;
3333

3434
/// Load the column value for the given entry.
3535
/// \param mask The masked entry range: it carries the first entry of the range to load and, for each entry
3636
/// in the range, whether it is valid. Values will be loaded only for entries for which the mask equals true.
37-
void Load(Long64_t entry, bool mask)
38-
{
39-
// For now, as `mask` is just a single boolean, as an optimization we can return early here if `mask == false`.
40-
if (mask) {
41-
fLoadedEntry = entry;
42-
this->LoadImpl(entry, mask);
43-
}
44-
}
37+
void Load(const ROOT::Internal::RDF::RMaskedEntryRange &mask) { LoadImpl(mask); }
4538

4639
/// Return the column value for the given entry.
4740
/// \tparam T The column type
@@ -57,7 +50,7 @@ public:
5750

5851
private:
5952
virtual void *GetImpl(std::size_t idx) = 0;
60-
virtual void LoadImpl(Long64_t /*entry*/, bool /*mask*/) = 0;
53+
virtual void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) = 0;
6154
};
6255

6356
} // namespace RDF

tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class R__CLING_PTRCHECK(off) RDSColumnReader final : public ROOT::Detail::RDF::R
2424
T **fDSValuePtr = nullptr;
2525

2626
void *GetImpl(std::size_t) final { return *fDSValuePtr; }
27-
void LoadImpl(Long64_t, bool) final {}
27+
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}
2828

2929
public:
3030
RDSColumnReader(void *DSValuePtr) : fDSValuePtr(static_cast<T **>(DSValuePtr)) {}

tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,20 @@ public:
104104
}
105105

106106
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
107-
void Update(unsigned int slot, Long64_t entry, bool mask) final
107+
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
108108
{
109-
if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
110-
// evaluate this define expression, cache the result
111-
fValues[slot]->Load(entry, mask);
112-
if (mask) {
113-
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, /*idx=*/0u);
114-
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
115-
}
109+
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
110+
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()])
111+
return;
112+
113+
// Assume 1-size bulk for now
114+
fValues[slot]->Load(mask);
115+
const std::size_t bulkSize = 1;
116+
for (std::size_t i = 0; i < bulkSize; ++i) {
117+
if (mask[i])
118+
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, i);
116119
}
120+
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
117121
}
118122

119123
void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo & /*id*/) final {}

tree/dataframe/inc/ROOT/RDF/RDefine.hxx

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,31 +83,28 @@ class R__CLING_PTRCHECK(off) RDefine final : public RDefineBase {
8383
}
8484

8585
template <typename... ColTypes, std::size_t... S>
86-
void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
86+
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
8787
std::index_sequence<S...>, NoneTag)
8888
{
89-
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
90-
fExpression(GetValueChecked<ColTypes>(slot, S, idx)...);
89+
return fExpression(GetValueChecked<ColTypes>(slot, S, idx)...);
9190
(void)idx; // avoid unused parameter warning (gcc 12.1)
9291
}
9392

9493
template <typename... ColTypes, std::size_t... S>
95-
void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
94+
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
9695
std::index_sequence<S...>, SlotTag)
9796
{
98-
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
99-
fExpression(slot, GetValueChecked<ColTypes>(slot, S, idx)...);
97+
return fExpression(slot, GetValueChecked<ColTypes>(slot, S, idx)...);
10098
(void)idx; // avoid unused parameter warning (gcc 12.1)
10199
}
102100

103101
template <typename... ColTypes, std::size_t... S>
104-
void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t batchFirstEntry, TypeList<ColTypes...>,
102+
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t entryInBatch, TypeList<ColTypes...>,
105103
std::index_sequence<S...>, SlotAndEntryTag)
106104
{
107-
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
108-
fExpression(slot, batchFirstEntry + idx, GetValueChecked<ColTypes>(slot, S, idx)...);
105+
return fExpression(slot, entryInBatch, GetValueChecked<ColTypes>(slot, S, idx)...);
109106
(void)idx; // avoid unused parameter warning (gcc 12.1)
110-
(void)batchFirstEntry; // avoid unused parameter warning (gcc 12.1)
107+
(void)entryInBatch; // avoid unused parameter warning (gcc 12.1)
111108
}
112109

113110
public:
@@ -138,13 +135,20 @@ public:
138135
}
139136

140137
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
141-
void Update(unsigned int slot, Long64_t entry, bool mask) final
138+
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
142139
{
143-
if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
144-
std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); });
145-
if (mask) {
146-
UpdateHelper(slot, /*idx=*/0u, entry, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
147-
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
140+
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
141+
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()])
142+
return;
143+
144+
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
145+
// Assume 1-size bulk for now
146+
const std::size_t bulkSize = 1;
147+
auto &result = fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()];
148+
for (std::size_t i = 0; i < bulkSize; ++i) {
149+
if (mask[i]) {
150+
result = UpdateHelper(slot, i, mask.GetFirstEntry() + i, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
151+
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
148152
}
149153
}
150154
}

tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "ROOT/RDF/RSampleInfo.hxx"
1717
#include "ROOT/RDF/Utils.hxx"
1818
#include "ROOT/RVec.hxx"
19+
#include <ROOT/RDF/RMaskedEntryRange.hxx>
1920

2021
#include <deque>
2122
#include <map>
@@ -40,7 +41,9 @@ class RDefineBase {
4041
protected:
4142
const std::string fName; ///< The name of the custom column
4243
const std::string fType; ///< The type of the custom column as a text string
43-
std::vector<Long64_t> fLastCheckedEntry;
44+
std::vector<Long64_t> fLastCheckedEntry; ///< Starting entry of the last bulk processed, per slot
45+
/// Which entries in the current bulk are valid, per slot
46+
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> fMaskPerSlot;
4447
RDFInternal::RColumnRegister fColRegister;
4548
RLoopManager *fLoopManager; // non-owning pointer to the RLoopManager
4649
const ROOT::RDF::ColumnNames_t fColumnNames;
@@ -63,7 +66,7 @@ public:
6366
std::string GetName() const;
6467
std::string GetTypeName() const;
6568
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
66-
virtual void Update(unsigned int slot, Long64_t entry, bool mask) = 0;
69+
virtual void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) = 0;
6770
/// Update function to be called once per sample, used if the derived type is a RDefinePerSample
6871
virtual void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &/*id*/) {}
6972
/// Clean-up operations to be performed at the end of a task.

0 commit comments

Comments
 (0)