Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ private:
public:
//////////////////////////////////////////////////////////////////////////
/// \brief Get the entry indexes for this entry mapping.
const std::vector<ROOT::NTupleSize_t> *GetEntryIndexes(std::vector<void *> valuePtrs) const;
const std::vector<ROOT::NTupleSize_t> *GetEntryIndexes(const std::vector<JoinValue_t> &joinValues) const;

//////////////////////////////////////////////////////////////////////////
/// \brief Create a new entry mapping.
Expand Down Expand Up @@ -208,14 +208,14 @@ public:
/////////////////////////////////////////////////////////////////////////////
/// \brief Get an entry index (if it exists) for the given join field value(s), from any partition.
///
/// \param[in] valuePtrs A vector of pointers to the join field values to look up.
/// \param[in] joinValues The join field values to look up.
///
/// \note If one or more corresponding entries exist for the given value(s), the first entry index found in the join
/// table is returned.
///
/// \return An entry number that corresponds to `valuePtrs`. When there are no corresponding entries,
/// \return An entry number that corresponds to `joinValues`. When there are no corresponding entries,
/// `kInvalidNTupleIndex` is returned.
ROOT::NTupleSize_t GetEntryIndex(const std::vector<void *> &valuePtrs) const;
ROOT::NTupleSize_t GetEntryIndex(const std::vector<JoinValue_t> &joinValues) const;
};
} // namespace Internal
} // namespace Experimental
Expand Down
43 changes: 9 additions & 34 deletions tree/ntuple/src/RNTupleJoinTable.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,6 @@

#include <ROOT/RNTupleJoinTable.hxx>

namespace {
ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t CastValuePtr(void *valuePtr, std::size_t fieldValueSize)
{
ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t value;

switch (fieldValueSize) {
case 1: value = *reinterpret_cast<std::uint8_t *>(valuePtr); break;
case 2: value = *reinterpret_cast<std::uint16_t *>(valuePtr); break;
case 4: value = *reinterpret_cast<std::uint32_t *>(valuePtr); break;
case 8: value = *reinterpret_cast<std::uint64_t *>(valuePtr); break;
default: throw ROOT::RException(R__FAIL("value size not supported"));
}

return value;
}
} // anonymous namespace

ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
ROOT::Internal::RPageSource &pageSource, const std::vector<std::string> &joinFieldNames,
ROOT::NTupleSize_t entryOffset)
Expand Down Expand Up @@ -61,7 +44,8 @@ ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
"\" in join table: only integral types are allowed"));
}

auto field = fieldDesc.CreateField(desc.GetRef());
auto field = std::make_unique<ROOT::RField<JoinValue_t>>(fieldDesc.GetFieldName());
field->SetOnDiskId(fieldDesc.GetId());
ROOT::Internal::CallConnectPageSourceOnField(*field, pageSource);

fieldValues.emplace_back(field->CreateValue());
Expand All @@ -78,29 +62,20 @@ ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
for (auto &fieldValue : fieldValues) {
// TODO(fdegeus): use bulk reading
fieldValue.Read(i);

auto valuePtr = fieldValue.GetPtr<void>();
castJoinValues.push_back(CastValuePtr(valuePtr.get(), fieldValue.GetField().GetValueSize()));
castJoinValues.push_back(fieldValue.GetRef<JoinValue_t>());
}

fMapping[RCombinedJoinFieldValue(castJoinValues)].push_back(i + entryOffset);
}
}

const std::vector<ROOT::NTupleSize_t> *
ROOT ::Experimental::Internal::RNTupleJoinTable::REntryMapping::GetEntryIndexes(std::vector<void *> valuePtrs) const
const std::vector<ROOT::NTupleSize_t> *ROOT ::Experimental::Internal::RNTupleJoinTable::REntryMapping::GetEntryIndexes(
const std::vector<JoinValue_t> &joinValues) const
{
if (valuePtrs.size() != fJoinFieldNames.size())
if (joinValues.size() != fJoinFieldNames.size())
throw RException(R__FAIL("number of value pointers must match number of join fields"));

std::vector<JoinValue_t> castJoinValues;
castJoinValues.reserve(valuePtrs.size());

for (unsigned i = 0; i < valuePtrs.size(); ++i) {
castJoinValues.push_back(CastValuePtr(valuePtrs[i], fJoinFieldValueSizes[i]));
}

if (const auto &entries = fMapping.find(RCombinedJoinFieldValue(castJoinValues)); entries != fMapping.end()) {
if (const auto &entries = fMapping.find(RCombinedJoinFieldValue(joinValues)); entries != fMapping.end()) {
return &entries->second;
}

Expand All @@ -126,11 +101,11 @@ ROOT::Experimental::Internal::RNTupleJoinTable::Add(ROOT::Internal::RPageSource
}

ROOT::NTupleSize_t
ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector<void *> &valuePtrs) const
ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector<JoinValue_t> &joinValues) const
{
for (const auto &partition : fPartitions) {
for (const auto &joinMapping : partition.second) {
auto entriesForMapping = joinMapping->GetEntryIndexes(valuePtrs);
auto entriesForMapping = joinMapping->GetEntryIndexes(joinValues);
if (entriesForMapping) {
return (*entriesForMapping)[0];
}
Expand Down
10 changes: 5 additions & 5 deletions tree/ntuple/src/RNTupleProcessor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,16 @@ ROOT::NTupleSize_t ROOT::Experimental::RNTupleJoinProcessor::LoadEntry(ROOT::NTu
}

// Collect the values of the join fields for this entry.
std::vector<void *> valPtrs;
valPtrs.reserve(fJoinFieldIdxs.size());
std::vector<ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t> values;
values.reserve(fJoinFieldIdxs.size());
for (const auto &fieldIdx : fJoinFieldIdxs) {
auto ptr = fEntry->GetValue(fieldIdx).GetPtr<void>();
valPtrs.push_back(ptr.get());
auto val = fEntry->GetValue(fieldIdx).GetRef<ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t>();
values.push_back(val);
}

// Find the entry index corresponding to the join field values for each auxiliary processor and load the
// corresponding entry.
const auto entryIdx = fJoinTable->GetEntryIndex(valPtrs);
const auto entryIdx = fJoinTable->GetEntryIndex(values);

if (entryIdx == kInvalidNTupleIndex) {
SetAuxiliaryFieldValidity(false);
Expand Down
18 changes: 9 additions & 9 deletions tree/ntuple/test/ntuple_join_table.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ TEST(RNTupleJoinTable, Basic)
std::uint64_t fldValue = 0;

// No entry mappings have been added to the join table yet
EXPECT_EQ(ROOT::kInvalidNTupleIndex, joinTable->GetEntryIndex({&fldValue}));
EXPECT_EQ(ROOT::kInvalidNTupleIndex, joinTable->GetEntryIndex({fldValue}));

// Now add the entry mapping for the page source
joinTable->Add(*pageSource);
Expand All @@ -33,7 +33,7 @@ TEST(RNTupleJoinTable, Basic)
for (unsigned i = 0; i < ntuple->GetNEntries(); ++i) {
fldValue = fld(i);
EXPECT_EQ(fldValue, i * 2);
EXPECT_EQ(joinTable->GetEntryIndex({&fldValue}), i);
EXPECT_EQ(joinTable->GetEntryIndex({fldValue}), i);
}
}

Expand Down Expand Up @@ -127,10 +127,10 @@ TEST(RNTupleJoinTable, SparseSecondary)
auto event = fldEvent(i);

if (i % 2 == 1) {
EXPECT_EQ(joinTable->GetEntryIndex({&event}), ROOT::kInvalidNTupleIndex)
EXPECT_EQ(joinTable->GetEntryIndex({event}), ROOT::kInvalidNTupleIndex)
<< "entry should not be present in the join table";
} else {
auto entryIdx = joinTable->GetEntryIndex({&event});
auto entryIdx = joinTable->GetEntryIndex({event});
EXPECT_EQ(entryIdx, i / 2);
EXPECT_FLOAT_EQ(fldX(entryIdx), static_cast<float>(entryIdx) / 3.14);
}
Expand Down Expand Up @@ -170,25 +170,25 @@ TEST(RNTupleJoinTable, MultipleFields)
for (std::uint64_t i = 0; i < pageSource->GetNEntries(); ++i) {
run = i / 5;
event = i % 5;
auto entryIdx = joinTable->GetEntryIndex({&run, &event});
auto entryIdx = joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event});
EXPECT_EQ(fld(entryIdx), fld(i));
}

run = 1;
event = 2;
auto idx1 = joinTable->GetEntryIndex({&run, &event});
auto idx2 = joinTable->GetEntryIndex({&event, &run});
auto idx1 = joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event});
auto idx2 = joinTable->GetEntryIndex({event, static_cast<RNTupleJoinTable::JoinValue_t>(run)});
EXPECT_NE(idx1, idx2);

try {
joinTable->GetEntryIndex({&run, &event, &event});
joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event, event});
FAIL() << "querying the join table with more values than join field values should not be possible";
} catch (const ROOT::RException &err) {
EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields"));
}

try {
joinTable->GetEntryIndex({&run});
joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run)});
FAIL() << "querying the join table with fewer values than join field values should not be possible";
} catch (const ROOT::RException &err) {
EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields"));
Expand Down
Loading