Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions roottest/root/dataframe/testIMT.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ void getTracks(unsigned int mu, FourVectors& tracks) {
// This makes the example stand-alone
void FillTree(const char* filename, const char* treeName) {
if (!gSystem->AccessPathName(filename)) return;
TFile f(filename,"RECREATE");
TTree t(treeName,treeName);
auto f = std::make_unique<TFile>(filename, "RECREATE");
auto t = std::make_unique<TTree>(treeName, treeName);
double b1;
int b2;
std::vector<FourVector> tracks;
std::vector<double> dv {-1,2,3,4};
std::list<int> sl {1,2,3,4};
t.Branch("b1", &b1);
t.Branch("b2", &b2);
t.Branch("tracks", &tracks);
t.Branch("dv", &dv);
t.Branch("sl", &sl);
t->Branch("b1", &b1);
t->Branch("b2", &b2);
t->Branch("tracks", &tracks);
t->Branch("dv", &dv);
t->Branch("sl", &sl);

int nevts = 16000;
for(int i = 0; i < nevts; ++i) {
Expand All @@ -77,11 +77,9 @@ void FillTree(const char* filename, const char* treeName) {

dv.emplace_back(i);
sl.emplace_back(i);
t.Fill();
t->Fill();
}
t.Write();
f.Close();
return;
f->Write();
}

auto fileName = "testIMT.root";
Expand Down
1 change: 1 addition & 0 deletions tree/dataframe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTDataFrame
ROOT/RDF/RJittedVariation.hxx
ROOT/RDF/RLazyDSImpl.hxx
ROOT/RDF/RLoopManager.hxx
ROOT/RDF/RMaskedEntryRange.hxx
ROOT/RDF/RMergeableValue.hxx
ROOT/RDF/RMetaData.hxx
ROOT/RDF/RNodeBase.hxx
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RCsvDS.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
namespace ROOT::Internal::RDF {
class R__CLING_PTRCHECK(off) RCsvDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
void *fValuePtr;
void *GetImpl(Long64_t) final { return fValuePtr; }
void *GetImpl(std::size_t) final { return fValuePtr; }
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}

public:
RCsvDSColumnReader(void *valuePtr) : fValuePtr(valuePtr) {}
Expand Down
23 changes: 13 additions & 10 deletions tree/dataframe/inc/ROOT/RDF/RAction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,34 @@ public:
}

template <typename ColType>
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType &
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType &
{
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(entry))
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(idx))
return *val;

throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() +
") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
std::to_string(entry) +
std::to_string(idx) +
". You can use the DefaultValueFor operation to provide a default value, or "
"FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
}

template <typename... ColTypes, std::size_t... S>
void CallExec(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>)
void CallExec(unsigned int slot, std::size_t idx, TypeList<ColTypes...>, std::index_sequence<S...>)
{
ROOT::Internal::RDF::CallGuaranteedOrder{[&](auto &&...args) { return fHelper.Exec(slot, args...); },
GetValueChecked<ColTypes>(slot, S, entry)...};
(void)entry; // avoid unused parameter warning (gcc 12.1)
GetValueChecked<ColTypes>(slot, S, idx)...};
(void)idx; // avoid unused parameter warning (gcc 12.1)
}

void Run(unsigned int slot, Long64_t entry) final
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
{
// check if entry passes all filters
if (fPrevNode.CheckFilters(slot, entry))
CallExec(slot, entry, ColumnTypes_t{}, TypeInd_t{});
const auto mask = fPrevNode.CheckFilters(slot, bulkBeginEntry, bulkSize);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });

// Assume 1-size bulk for now
if (mask[0])
CallExec(slot, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{});
}

void TriggerChildrenCount() final { fPrevNode.IncrChildrenCount(); }
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RActionBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public:
RColumnRegister &GetColRegister() { return fColRegister; }
RLoopManager *GetLoopManager() { return fLoopManager; }
unsigned int GetNSlots() const { return fNSlots; }
virtual void Run(unsigned int slot, Long64_t entry) = 0;
virtual void Initialize() = 0;
virtual void InitSlot(TTreeReader *r, unsigned int slot) = 0;
virtual void TriggerChildrenCount() = 0;
Expand Down Expand Up @@ -92,6 +91,8 @@ public:

virtual std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) = 0;
virtual std::unique_ptr<RActionBase> CloneAction(void *newResult) = 0;

virtual void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) = 0;
};
} // namespace RDF
} // namespace Internal
Expand Down
31 changes: 19 additions & 12 deletions tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -166,55 +166,62 @@ public:
fHelper.InitTask(r, slot);
}

void *GetValue(unsigned int slot, std::size_t readerIdx, Long64_t entry)
void *GetValue(unsigned int slot, std::size_t readerIdx, std::size_t idx)
{
assert(slot < fValues.size());
assert(readerIdx < fValues[slot].size());
if (auto *val = fValues[slot][readerIdx]->template TryGet<void>(entry))
if (auto *val = fValues[slot][readerIdx]->template TryGet<void>(idx))
return val;

throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() +
") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
std::to_string(entry) +
std::to_string(idx) +
". You can use the DefaultValueFor operation to provide a default value, or "
"FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
}

void CallExec(unsigned int slot, Long64_t entry)
void CallExec(unsigned int slot, std::size_t idx)
{
std::vector<void *> untypedValues;
auto nReaders = fValues[slot].size();
untypedValues.reserve(nReaders);
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
untypedValues.push_back(GetValue(slot, readerIdx, entry));
untypedValues.push_back(GetValue(slot, readerIdx, idx));

fHelper.Exec(slot, untypedValues);
}

void Run(unsigned int slot, Long64_t entry) final
void Run(unsigned int slot, Long64_t bulkBeginEntry, std::size_t bulkSize) final
{
if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
// check if entry passes all filters
std::vector<bool> filterPassed(fPrevNodes.size(), false);
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> filterPassed(fPrevNodes.size(), 1ul);
for (unsigned int variation = 0; variation < fPrevNodes.size(); ++variation) {
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, entry);
filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, bulkBeginEntry, bulkSize);
}

// Currently, every event where any of nominal or variations pass gets written to the output.
// This logic could be extended for different use cases if the need arises.
if (std::any_of(filterPassed.begin(), filterPassed.end(), [](bool val) { return val; })) {
if (std::any_of(filterPassed.begin(), filterPassed.end(),
[](const ROOT::Internal::RDF::RMaskedEntryRange &val) { return val[0]; })) {
// TODO: Don't allocate
std::vector<void *> untypedValues;
auto nReaders = fValues[slot].size();
untypedValues.reserve(nReaders);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [bulkBeginEntry, bulkSize](auto *v) {
v->Load(
ROOT::Internal::RDF::RMaskedEntryRange{bulkSize, true, static_cast<std::uint64_t>(bulkBeginEntry)});
});
for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
untypedValues.push_back(GetValue(slot, readerIdx, entry));
untypedValues.push_back(GetValue(slot, readerIdx, /*idx=*/0u));

fHelper.Exec(slot, untypedValues, filterPassed);
}
} else {
if (fPrevNodes.front()->CheckFilters(slot, entry))
CallExec(slot, entry);
const auto mask = fPrevNodes.front()->CheckFilters(slot, bulkBeginEntry, bulkSize);
std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
if (mask[0])
CallExec(slot, /*idx=*/0u);
}
}

Expand Down
14 changes: 11 additions & 3 deletions tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define ROOT_INTERNAL_RDF_RCOLUMNREADERBASE

#include <Rtypes.h>
#include <ROOT/RDF/RMaskedEntryRange.hxx>

namespace ROOT {
namespace Detail {
Expand All @@ -26,23 +27,30 @@ This pure virtual class provides a common base class for the different column re
RDSColumnReader.
**/
class R__CLING_PTRCHECK(off) RColumnReaderBase {

public:
virtual ~RColumnReaderBase() = default;

/// Load the column value for the given entry.
/// \param entry The entry number to load.
/// \param mask The entry mask. Values will be loaded only for entries for which the mask equals true.
void Load(const ROOT::Internal::RDF::RMaskedEntryRange &mask) { LoadImpl(mask); }

/// Return the column value for the given entry.
/// \tparam T The column type
/// \param entry The entry number
///
/// The caller is responsible for checking that the returned value actually
/// exists.
template <typename T>
T *TryGet(Long64_t entry)
T *TryGet(std::size_t idx)
{
return static_cast<T *>(GetImpl(entry));
return static_cast<T *>(GetImpl(idx));
}

private:
virtual void *GetImpl(Long64_t entry) = 0;
virtual void *GetImpl(std::size_t idx) = 0;
virtual void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) = 0;
};

} // namespace RDF
Expand Down
3 changes: 2 additions & 1 deletion tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ template <typename T>
class R__CLING_PTRCHECK(off) RDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
T **fDSValuePtr = nullptr;

void *GetImpl(Long64_t) final { return *fDSValuePtr; }
void *GetImpl(std::size_t) final { return *fDSValuePtr; }
void LoadImpl(const ROOT::Internal::RDF::RMaskedEntryRange &) final {}

public:
RDSColumnReader(void *DSValuePtr) : fDSValuePtr(static_cast<T **>(DSValuePtr)) {}
Expand Down
21 changes: 14 additions & 7 deletions tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class R__CLING_PTRCHECK(off) RDefaultValueFor final : public RDefineBase {
/// The map key is the full variation name, e.g. "pt:up".
std::unordered_map<std::string, std::unique_ptr<RDefineBase>> fVariedDefines;

T &GetValueOrDefault(unsigned int slot, Long64_t entry)
T &GetValueOrDefault(unsigned int slot, std::size_t idx)
{
if (auto *value = fValues[slot]->template TryGet<T>(entry))
if (auto *value = fValues[slot]->template TryGet<T>(idx))
return *value;
else
return fDefaultValue;
Expand Down Expand Up @@ -104,13 +104,20 @@ public:
}

/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
void Update(unsigned int slot, Long64_t entry) final
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
{
if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
// evaluate this define expression, cache the result
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, entry);
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()])
return;

// Assume 1-size bulk for now
fValues[slot]->Load(mask);
const std::size_t bulkSize = fLoopManager->GetCurrentBulkSize();
for (std::size_t i = 0; i < bulkSize; ++i) {
if (mask[i])
fLastResults[slot * RDFInternal::CacheLineStep<T>()] = GetValueOrDefault(slot, i);
}
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
}

void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo & /*id*/) final {}
Expand Down
53 changes: 33 additions & 20 deletions tree/dataframe/inc/ROOT/RDF/RDefine.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -71,39 +71,43 @@ class R__CLING_PTRCHECK(off) RDefine final : public RDefineBase {
std::unordered_map<std::string, std::unique_ptr<RDefineBase>> fVariedDefines;

template <typename ColType>
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType &
auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType &
{
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(entry))
if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(idx))
return *val;

throw std::out_of_range{"RDataFrame: Define could not retrieve value for column '" + fColumnNames[readerIdx] +
"' for entry " + std::to_string(entry) +
"' for entry " + std::to_string(idx) +
". You can use the DefaultValueFor operation to provide a default value, or "
"FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
}

template <typename... ColTypes, std::size_t... S>
void UpdateHelper(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>, NoneTag)
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
std::index_sequence<S...>, NoneTag)
{
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
fExpression(GetValueChecked<ColTypes>(slot, S, entry)...);
(void)entry; // avoid unused parameter warning (gcc 12.1)
return fExpression(GetValueChecked<ColTypes>(slot, S, idx)...);
(void)slot; // avoid unused parameter warning
(void)idx; // avoid unused parameter warning
}

template <typename... ColTypes, std::size_t... S>
void UpdateHelper(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>, SlotTag)
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList<ColTypes...>,
std::index_sequence<S...>, SlotTag)
{
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
fExpression(slot, GetValueChecked<ColTypes>(slot, S, entry)...);
(void)entry; // avoid unused parameter warning (gcc 12.1)
return fExpression(slot, GetValueChecked<ColTypes>(slot, S, idx)...);
(void)slot; // avoid unused parameter warning
(void)idx; // avoid unused parameter warning
}

template <typename... ColTypes, std::size_t... S>
void
UpdateHelper(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>, SlotAndEntryTag)
auto UpdateHelper(unsigned int slot, std::size_t idx, Long64_t entryInBatch, TypeList<ColTypes...>,
std::index_sequence<S...>, SlotAndEntryTag)
{
fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()] =
fExpression(slot, entry, GetValueChecked<ColTypes>(slot, S, entry)...);
return fExpression(slot, entryInBatch, GetValueChecked<ColTypes>(slot, S, idx)...);
(void)slot; // avoid unused parameter warning
(void)idx; // avoid unused parameter warning
(void)entryInBatch; // avoid unused parameter warning
}

public:
Expand Down Expand Up @@ -134,12 +138,21 @@ public:
}

/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
void Update(unsigned int slot, Long64_t entry) final
void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) final
{
if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
// evaluate this define expression, cache the result
UpdateHelper(slot, entry, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
if (static_cast<Long64_t>(mask.GetFirstEntry()) ==
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()])
return;

std::for_each(fValues[slot].begin(), fValues[slot].end(), [&mask](auto *v) { v->Load(mask); });
// Assume 1-size bulk for now
const std::size_t bulkSize = 1;
auto &result = fLastResults[slot * RDFInternal::CacheLineStep<ret_type>()];
for (std::size_t i = 0; i < bulkSize; ++i) {
if (mask[i]) {
result = UpdateHelper(slot, i, mask.GetFirstEntry() + i, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{});
fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = mask.GetFirstEntry();
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "ROOT/RDF/RSampleInfo.hxx"
#include "ROOT/RDF/Utils.hxx"
#include "ROOT/RVec.hxx"
#include <ROOT/RDF/RMaskedEntryRange.hxx>

#include <deque>
#include <map>
Expand All @@ -40,7 +41,9 @@ class RDefineBase {
protected:
const std::string fName; ///< The name of the custom column
const std::string fType; ///< The type of the custom column as a text string
std::vector<Long64_t> fLastCheckedEntry;
std::vector<Long64_t> fLastCheckedEntry; /// Starting entry of the last bulk processed, per slot
/// Which entries in the current bulk are valid, per slot
std::vector<ROOT::Internal::RDF::RMaskedEntryRange> fMaskPerSlot;
RDFInternal::RColumnRegister fColRegister;
RLoopManager *fLoopManager; // non-owning pointer to the RLoopManager
const ROOT::RDF::ColumnNames_t fColumnNames;
Expand All @@ -63,7 +66,7 @@ public:
std::string GetName() const;
std::string GetTypeName() const;
/// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry
virtual void Update(unsigned int slot, Long64_t entry) = 0;
virtual void Update(unsigned int slot, const ROOT::Internal::RDF::RMaskedEntryRange &mask) = 0;
/// Update function to be called once per sample, used if the derived type is a RDefinePerSample
virtual void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &/*id*/) {}
/// Clean-up operations to be performed at the end of a task.
Expand Down
Loading
Loading