Skip to content

Commit d33254e

Browse files
committed
[df] Prepare for bulk nodes that update values
define and variation nodes update their current available values via an Update method, this is now prepared to work in bulks, currently assuming always size one.
1 parent 986e600 commit d33254e

4 files changed

Lines changed: 67 additions & 38 deletions

File tree

tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ template <typename T>
4545
class R__CLING_PTRCHECK(off) RDefaultValueFor final : public RDefineBase {
4646
using ColumnTypes_t = ROOT::TypeTraits::TypeList<T>;
4747
using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
48-
// Avoid instantiating vector<bool> as `operator[]` returns temporaries in that case. Use std::deque instead.
49-
using ValuesPerSlot_t = std::conditional_t<std::is_same<T, bool>::value, std::deque<T>, std::vector<T>>;
48+
49+
using ValuesPerSlot_t = std::vector<ROOT::RVec<T>>;
5050

5151
T fDefaultValue;
52-
ValuesPerSlot_t fLastResults;
52+
// Each slot accesses a cache of values for the current bulk
53+
ValuesPerSlot_t fCachedResultsPerSlot;
5354
// One column reader per slot
5455
std::vector<RColumnReaderBase *> fValues;
5556

@@ -71,12 +72,16 @@ public:
7172
RLoopManager &lm, const std::string &variationName = "nominal")
7273
: RDefineBase(name, type, colRegister, lm, columns, variationName),
7374
fDefaultValue(defaultValue),
74-
fLastResults(lm.GetNSlots() * RDFInternal::CacheLineStep<T>()),
75+
fCachedResultsPerSlot(lm.GetNSlots() * RDFInternal::CacheLineStep<ROOT::RVec<T>>()),
7576
fValues(lm.GetNSlots())
7677
{
7778
fLoopManager->Register(this);
7879
// We suppress errors that TTreeReader prints regarding the missing branch
7980
fLoopManager->InsertSuppressErrorsForMissingBranch(fColumnNames[0]);
81+
// Assume 1-size bulk for now
82+
for (decltype(lm.GetNSlots()) i = 0; i < lm.GetNSlots(); ++i) {
83+
fCachedResultsPerSlot[i * RDFInternal::CacheLineStep<ROOT::RVec<T>>()].resize(1ul);
84+
}
8085
}
8186

8287
RDefaultValueFor(const RDefaultValueFor &) = delete;
@@ -97,10 +102,10 @@ public:
97102
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = -1;
98103
}
99104

100-
/// Return the (type-erased) address of the Define'd value for the given processing slot.
105+
/// Return the beginning of the cached results of the current bulk for the input processing slot
101106
void *GetValuePtr(unsigned int slot) final
102107
{
103-
return static_cast<void *>(&fLastResults[slot * RDFInternal::CacheLineStep<T>()]);
108+
return static_cast<void *>(fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()].data());
104109
}
105110

106111
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
@@ -113,9 +118,12 @@ public:
113118
// Assume 1-size bulk for now
114119
fValues[slot]->Load(mask);
115120
const std::size_t bulkSize = fLoopManager->GetCurrentBulkSize();
121+
auto &result = fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()];
122+
result.clear();
123+
result.resize(bulkSize);
116124
for (std::size_t i = 0; i < bulkSize; ++i) {
117125
if (mask[i])
118-
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, i);
126+
fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<T>>()][i] = GetValueOrDefault(slot, i);
119127
}
120128
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
121129
}

tree/dataframe/inc/ROOT/RDF/RDefine.hxx

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,11 @@ class R__CLING_PTRCHECK(off) RDefine final : public RDefineBase {
5656
RDFInternal::RemoveFirstTwoParametersIf_t<std::is_same<ExtraArgsTag, SlotAndEntryTag>::value, ColumnTypesTmp_t>;
5757
using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
5858
using ret_type = typename CallableTraits<F>::ret_type;
59-
// Avoid instantiating vector<bool> as `operator[]` returns temporaries in that case. Use std::deque instead.
60-
using ValuesPerSlot_t =
61-
std::conditional_t<std::is_same<ret_type, bool>::value, std::deque<ret_type>, std::vector<ret_type>>;
59+
using ValuesPerSlot_t = std::vector<ROOT::RVec<ret_type>>;
6260

6361
F fExpression;
64-
ValuesPerSlot_t fLastResults;
62+
// Each slot accesses a cache of values for the current bulk
63+
ValuesPerSlot_t fCachedResultsPerSlot;
6564

6665
/// Column readers per slot and per input column
6766
std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>> fValues;
@@ -114,10 +113,16 @@ public:
114113
RDefine(std::string_view name, std::string_view type, F expression, const ROOT::RDF::ColumnNames_t &columns,
115114
const RDFInternal::RColumnRegister &colRegister, RLoopManager &lm,
116115
const std::string &variationName = "nominal")
117-
: RDefineBase(name, type, colRegister, lm, columns, variationName), fExpression(std::move(expression)),
118-
fLastResults(lm.GetNSlots() * RDFInternal::CacheLineStep<ret_type>()), fValues(lm.GetNSlots())
116+
: RDefineBase(name, type, colRegister, lm, columns, variationName),
117+
fExpression(std::move(expression)),
118+
fCachedResultsPerSlot(lm.GetNSlots() * RDFInternal::CacheLineStep<ROOT::RVec<ret_type>>()),
119+
fValues(lm.GetNSlots())
119120
{
120121
fLoopManager->Register(this);
122+
// Assume 1-size bulk for now
123+
for (decltype(lm.GetNSlots()) i = 0; i < lm.GetNSlots(); ++i) {
124+
fCachedResultsPerSlot[i * RDFInternal::CacheLineStep<ROOT::RVec<ret_type>>()].resize(1ul);
125+
}
121126
}
122127

123128
RDefine(const RDefine &) = delete;
@@ -131,13 +136,14 @@ public:
131136
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = -1;
132137
}
133138

134-
/// Return the (type-erased) address of the Define'd value for the given processing slot.
139+
/// Return the beginning of the cached results of the current bulk for the input processing slot
135140
void *GetValuePtr(unsigned int slot) final
136141
{
137-
return static_cast<void *>(&fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()]);
142+
return static_cast<void *>(
143+
fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<ret_type>>()].data());
138144
}
139145

140-
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
146+
/// Update the values at the array returned by GetValuePtr with the content corresponding to the given mask
141147
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
142148
{
143149
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
@@ -147,10 +153,12 @@ public:
147153
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
148154
// Assume 1-size bulk for now
149155
const std::size_t bulkSize = 1;
150-
auto &result = fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()];
156+
auto &result = fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<ret_type>>()];
157+
result.clear();
158+
result.resize(bulkSize);
151159
for (std::size_t i = 0; i < bulkSize; ++i) {
152160
if (mask[i]) {
153-
result = UpdateHelper(slot, i, mask.GetFirstEntry() + i, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
161+
result[i] = UpdateHelper(slot, i, mask.GetFirstEntry() + i, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
154162
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
155163
}
156164
}

tree/dataframe/inc/ROOT/RDF/RDefinePerSample.hxx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,33 +30,37 @@ template <typename F>
3030
class R__CLING_PTRCHECK(off) RDefinePerSample final : public RDefineBase {
3131
using RetType_t = typename CallableTraits<F>::ret_type;
3232

33-
// Avoid instantiating vector<bool> as `operator[]` returns temporaries in that case. Use std::deque instead.
34-
using ValuesPerSlot_t =
35-
std::conditional_t<std::is_same<RetType_t, bool>::value, std::deque<RetType_t>, std::vector<RetType_t>>;
33+
using ValuesPerSlot_t = std::vector<ROOT::RVec<RetType_t>>;
3634

3735
F fExpression;
38-
ValuesPerSlot_t fLastResults;
36+
// Each slot accesses a cache of values for the current bulk
37+
ValuesPerSlot_t fCachedResultsPerSlot;
3938

4039
public:
4140
RDefinePerSample(std::string_view name, std::string_view type, F expression, RLoopManager &lm)
4241
: RDefineBase(name, type, RDFInternal::RColumnRegister{&lm}, lm, /*columnNames*/ {}),
4342
fExpression(std::move(expression)),
44-
fLastResults(lm.GetNSlots() * RDFInternal::CacheLineStep<RetType_t>())
43+
fCachedResultsPerSlot(lm.GetNSlots() * RDFInternal::CacheLineStep<ROOT::RVec<RetType_t>>())
4544
{
4645
fLoopManager->Register(this);
4746
auto callUpdate = [this](unsigned int slot, const ROOT::RDF::RSampleInfo &id) { this->Update(slot, id); };
4847
fLoopManager->AddSampleCallback(this, std::move(callUpdate));
48+
// Assume 1-size bulk for now
49+
for (decltype(lm.GetNSlots()) i = 0; i < lm.GetNSlots(); ++i) {
50+
fCachedResultsPerSlot[i * RDFInternal::CacheLineStep<ROOT::RVec<RetType_t>>()].resize(1ul);
51+
}
4952
}
5053

5154
RDefinePerSample(const RDefinePerSample &) = delete;
5255
RDefinePerSample &operator=(const RDefinePerSample &) = delete;
5356

5457
~RDefinePerSample() { fLoopManager->Deregister(this); }
5558

56-
/// Return the (type-erased) address of the Define'd value for the given processing slot.
59+
/// Return the beginning of the cached results of the current bulk for the input processing slot
5760
void *GetValuePtr(unsigned int slot) final
5861
{
59-
return static_cast<void *>(&fLastResults[slot * RDFInternal::CacheLineStep<RetType_t>()]);
62+
return static_cast<void *>(
63+
fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<RetType_t>>()].data());
6064
}
6165

6266
void Update(unsigned int, const ROOT::Internal::RDF::RMaskedEntryRange &) final
@@ -67,7 +71,8 @@ public:
6771
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
6872
void Update(unsigned int slot, const ROOT::RDF::RSampleInfo &id) final
6973
{
70-
fLastResults[slot * RDFInternal::CacheLineStep<RetType_t>()] = fExpression(slot, id);
74+
// Assume 1-size bulk for now
75+
fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<RetType_t>>()][0] = fExpression(slot, id);
7176
}
7277

7378
const std::type_info &GetTypeId() const final { return typeid(RetType_t); }

tree/dataframe/inc/ROOT/RDF/RVariation.hxx

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ void AssignResults(ROOT::RVec<T> &resStorage, ROOT::RVec<T> &&tmpResults)
7777
}
7878

7979
template <typename T>
80-
void *GetValuePtrHelper(ROOT::RVec<T> &v, std::size_t /*colIdx*/, std::size_t varIdx)
80+
void *GetValuePtrHelper(ROOT::RVec<ROOT::RVec<T>> &v, std::size_t /*colIdx*/, std::size_t varIdx)
8181
{
82-
return static_cast<void *>(&v[varIdx]);
82+
return static_cast<void *>(&v[0][varIdx]);
8383
}
8484
///@}
8585

@@ -119,9 +119,9 @@ void AssignResults(std::vector<ROOT::RVec<T>> &resStorage, ROOT::RVec<ROOT::RVec
119119
}
120120

121121
template <typename T>
122-
void *GetValuePtrHelper(std::vector<ROOT::RVec<T>> &v, std::size_t colIdx, std::size_t varIdx)
122+
void *GetValuePtrHelper(ROOT::RVec<std::vector<ROOT::RVec<T>>> &v, std::size_t colIdx, std::size_t varIdx)
123123
{
124-
return static_cast<void *>(&v[colIdx][varIdx]);
124+
return static_cast<void *>(&v[0][colIdx][varIdx]);
125125
}
126126
///@}
127127

@@ -151,10 +151,11 @@ class R__CLING_PTRCHECK(off) RVariation final : public RVariationBase {
151151
using Ret_t = typename CallableTraits<F>::ret_type;
152152
using VariedCol_t = ColumnType_t<IsSingleColumn, Ret_t>;
153153
using Result_t = std::conditional_t<IsSingleColumn, ROOT::RVec<VariedCol_t>, std::vector<ROOT::RVec<VariedCol_t>>>;
154+
using ValuesPerSlot_t = std::vector<ROOT::RVec<Result_t>>;
154155

155156
F fExpression;
156-
/// Per-slot storage for varied column values (for one or multiple columns depending on IsSingleColumn).
157-
std::vector<Result_t> fLastResults;
157+
// Each slot accesses a cache of values for the current bulk
158+
ValuesPerSlot_t fCachedResultsPerSlot;
158159

159160
/// Column readers per slot and per input column
160161
std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>> fValues;
@@ -186,7 +187,8 @@ class R__CLING_PTRCHECK(off) RVariation final : public RVariationBase {
186187
std::to_string(fVariationNames.size()) + " were expected.");
187188
}
188189

189-
AssignResults(fLastResults[slot * CacheLineStep<Result_t>()], std::move(results));
190+
AssignResults(fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<Result_t>>()][0],
191+
std::move(results));
190192
}
191193

192194
public:
@@ -195,13 +197,17 @@ public:
195197
RLoopManager &lm, const ColumnNames_t &inputColNames)
196198
: RVariationBase(colNames, variationName, variationTags, type, defines, lm, inputColNames),
197199
fExpression(std::move(expression)),
198-
fLastResults(lm.GetNSlots() * CacheLineStep<Result_t>()),
200+
fCachedResultsPerSlot(lm.GetNSlots() * RDFInternal::CacheLineStep<ROOT::RVec<Result_t>>()),
199201
fValues(lm.GetNSlots())
200202
{
201203
fLoopManager->Register(this);
202204

203-
for (auto i = 0u; i < lm.GetNSlots(); ++i)
204-
ResizeResults(fLastResults[i * CacheLineStep<Result_t>()], colNames.size(), variationTags.size());
205+
// Assume 1-size bulk for now
206+
for (decltype(lm.GetNSlots()) i = 0; i < lm.GetNSlots(); ++i) {
207+
auto &cachedResultsForThisSlot = fCachedResultsPerSlot[i * RDFInternal::CacheLineStep<ROOT::RVec<Result_t>>()];
208+
cachedResultsForThisSlot.resize(1ul);
209+
ResizeResults(cachedResultsForThisSlot[0], colNames.size(), variationTags.size());
210+
}
205211
}
206212

207213
RVariation(const RVariation &) = delete;
@@ -215,7 +221,8 @@ public:
215221
fLastCheckedEntry[slot * CacheLineStep<Long64_t>()] = -1;
216222
}
217223

218-
/// Return the (type-erased) address of the value for the given processing slot.
224+
/// Return the beginning of the cached results of the current bulk for the input processing slot, column and
225+
/// variation
219226
void *GetValuePtr(unsigned int slot, const std::string &column, const std::string &variation) final
220227
{
221228
const auto colIt = std::find(fColNames.begin(), fColNames.end(), column);
@@ -226,7 +233,8 @@ public:
226233
assert(varIt != fVariationNames.end());
227234
const auto varIdx = std::distance(fVariationNames.begin(), varIt);
228235

229-
return GetValuePtrHelper(fLastResults[slot * CacheLineStep<Result_t>()], colIdx, varIdx);
236+
return GetValuePtrHelper(fCachedResultsPerSlot[slot * RDFInternal::CacheLineStep<ROOT::RVec<Result_t>>()], colIdx,
237+
varIdx);
230238
}
231239

232240
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry

0 commit comments

Comments
 (0)