Skip to content

Commit d46a25a

Browse files
committed
[ntuple] Pass join values by reference instead *void
... to avoid potentially reading invalid memory. This requires the join values to be cast to the appropriate type on the call side instead.
1 parent 7282024 commit d46a25a

4 files changed

Lines changed: 27 additions & 52 deletions

File tree

tree/ntuple/inc/ROOT/RNTupleJoinTable.hxx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ private:
149149
public:
150150
//////////////////////////////////////////////////////////////////////////
151151
/// \brief Get the entry indexes for this entry mapping.
152-
const std::vector<ROOT::NTupleSize_t> *GetEntryIndexes(std::vector<void *> valuePtrs) const;
152+
const std::vector<ROOT::NTupleSize_t> *GetEntryIndexes(const std::vector<JoinValue_t> &joinValues) const;
153153

154154
//////////////////////////////////////////////////////////////////////////
155155
/// \brief Create a new entry mapping.
@@ -207,14 +207,14 @@ public:
207207
/////////////////////////////////////////////////////////////////////////////
208208
/// \brief Get an entry index (if it exists) for the given join field value(s), from any partition.
209209
///
210-
/// \param[in] valuePtrs A vector of pointers to the join field values to look up.
210+
/// \param[in] joinValues The join field values to look up.
211211
///
212212
/// \note If one or more corresponding entries exist for the given value(s), the first entry index found in the join
213213
/// table is returned.
214214
///
215-
/// \return An entry number that corresponds to `valuePtrs`. When there are no corresponding entries,
215+
/// \return An entry number that corresponds to `joinValues`. When there are no corresponding entries,
216216
/// `kInvalidNTupleIndex` is returned.
217-
ROOT::NTupleSize_t GetEntryIndex(const std::vector<void *> &valuePtrs) const;
217+
ROOT::NTupleSize_t GetEntryIndex(const std::vector<JoinValue_t> &joinValues) const;
218218
};
219219
} // namespace Internal
220220
} // namespace Experimental

tree/ntuple/src/RNTupleJoinTable.cxx

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,6 @@
1414

1515
#include <ROOT/RNTupleJoinTable.hxx>
1616

17-
namespace {
18-
ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t CastValuePtr(void *valuePtr, std::size_t fieldValueSize)
19-
{
20-
ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t value;
21-
22-
switch (fieldValueSize) {
23-
case 1: value = *reinterpret_cast<std::uint8_t *>(valuePtr); break;
24-
case 2: value = *reinterpret_cast<std::uint16_t *>(valuePtr); break;
25-
case 4: value = *reinterpret_cast<std::uint32_t *>(valuePtr); break;
26-
case 8: value = *reinterpret_cast<std::uint64_t *>(valuePtr); break;
27-
default: throw ROOT::RException(R__FAIL("value size not supported"));
28-
}
29-
30-
return value;
31-
}
32-
} // anonymous namespace
33-
3417
ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
3518
ROOT::Internal::RPageSource &pageSource, const std::vector<std::string> &joinFieldNames,
3619
ROOT::NTupleSize_t entryOffset)
@@ -60,7 +43,8 @@ ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
6043
"\" in join table: only integral types are allowed"));
6144
}
6245

63-
auto field = fieldDesc.CreateField(desc.GetRef());
46+
auto field = std::make_unique<ROOT::RField<JoinValue_t>>(fieldDesc.GetFieldName());
47+
field->SetOnDiskId(fieldDesc.GetId());
6448
ROOT::Internal::CallConnectPageSourceOnField(*field, pageSource);
6549

6650
fieldValues.emplace_back(field->CreateValue());
@@ -77,29 +61,20 @@ ROOT::Experimental::Internal::RNTupleJoinTable::REntryMapping::REntryMapping(
7761
for (auto &fieldValue : fieldValues) {
7862
// TODO(fdegeus): use bulk reading
7963
fieldValue.Read(i);
80-
81-
auto valuePtr = fieldValue.GetPtr<void>();
82-
castJoinValues.push_back(CastValuePtr(valuePtr.get(), fieldValue.GetField().GetValueSize()));
64+
castJoinValues.push_back(fieldValue.GetRef<JoinValue_t>());
8365
}
8466

8567
fMapping[RCombinedJoinFieldValue(castJoinValues)].push_back(i + entryOffset);
8668
}
8769
}
8870

89-
const std::vector<ROOT::NTupleSize_t> *
90-
ROOT ::Experimental::Internal::RNTupleJoinTable::REntryMapping::GetEntryIndexes(std::vector<void *> valuePtrs) const
71+
const std::vector<ROOT::NTupleSize_t> *ROOT ::Experimental::Internal::RNTupleJoinTable::REntryMapping::GetEntryIndexes(
72+
const std::vector<JoinValue_t> &joinValues) const
9173
{
92-
if (valuePtrs.size() != fJoinFieldNames.size())
74+
if (joinValues.size() != fJoinFieldNames.size())
9375
throw RException(R__FAIL("number of value pointers must match number of join fields"));
9476

95-
std::vector<JoinValue_t> castJoinValues;
96-
castJoinValues.reserve(valuePtrs.size());
97-
98-
for (unsigned i = 0; i < valuePtrs.size(); ++i) {
99-
castJoinValues.push_back(CastValuePtr(valuePtrs[i], fJoinFieldValueSizes[i]));
100-
}
101-
102-
if (const auto &entries = fMapping.find(RCombinedJoinFieldValue(castJoinValues)); entries != fMapping.end()) {
77+
if (const auto &entries = fMapping.find(RCombinedJoinFieldValue(joinValues)); entries != fMapping.end()) {
10378
return &entries->second;
10479
}
10580

@@ -125,11 +100,11 @@ ROOT::Experimental::Internal::RNTupleJoinTable::Add(ROOT::Internal::RPageSource
125100
}
126101

127102
ROOT::NTupleSize_t
128-
ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector<void *> &valuePtrs) const
103+
ROOT::Experimental::Internal::RNTupleJoinTable::GetEntryIndex(const std::vector<JoinValue_t> &joinValues) const
129104
{
130105
for (const auto &partition : fPartitions) {
131106
for (const auto &joinMapping : partition.second) {
132-
auto entriesForMapping = joinMapping->GetEntryIndexes(valuePtrs);
107+
auto entriesForMapping = joinMapping->GetEntryIndexes(joinValues);
133108
if (entriesForMapping) {
134109
return (*entriesForMapping)[0];
135110
}

tree/ntuple/src/RNTupleProcessor.cxx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,16 +548,16 @@ ROOT::NTupleSize_t ROOT::Experimental::RNTupleJoinProcessor::LoadEntry(ROOT::NTu
548548
}
549549

550550
// Collect the values of the join fields for this entry.
551-
std::vector<void *> valPtrs;
552-
valPtrs.reserve(fJoinFieldIdxs.size());
551+
std::vector<ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t> values;
552+
values.reserve(fJoinFieldIdxs.size());
553553
for (const auto &fieldIdx : fJoinFieldIdxs) {
554-
auto ptr = fEntry->GetValue(fieldIdx).GetPtr<void>();
555-
valPtrs.push_back(ptr.get());
554+
auto val = fEntry->GetValue(fieldIdx).GetRef<ROOT::Experimental::Internal::RNTupleJoinTable::JoinValue_t>();
555+
values.push_back(val);
556556
}
557557

558558
// Find the entry index corresponding to the join field values for each auxiliary processor and load the
559559
// corresponding entry.
560-
const auto entryIdx = fJoinTable->GetEntryIndex(valPtrs);
560+
const auto entryIdx = fJoinTable->GetEntryIndex(values);
561561

562562
if (entryIdx == kInvalidNTupleIndex) {
563563
SetAuxiliaryFieldValidity(false);

tree/ntuple/test/ntuple_join_table.cxx

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ TEST(RNTupleJoinTable, Basic)
2222
std::uint64_t fldValue = 0;
2323

2424
// No entry mappings have been added to the join table yet
25-
EXPECT_EQ(ROOT::kInvalidNTupleIndex, joinTable->GetEntryIndex({&fldValue}));
25+
EXPECT_EQ(ROOT::kInvalidNTupleIndex, joinTable->GetEntryIndex({fldValue}));
2626

2727
// Now add the entry mapping for the page source
2828
joinTable->Add(*pageSource);
@@ -33,7 +33,7 @@ TEST(RNTupleJoinTable, Basic)
3333
for (unsigned i = 0; i < ntuple->GetNEntries(); ++i) {
3434
fldValue = fld(i);
3535
EXPECT_EQ(fldValue, i * 2);
36-
EXPECT_EQ(joinTable->GetEntryIndex({&fldValue}), i);
36+
EXPECT_EQ(joinTable->GetEntryIndex({fldValue}), i);
3737
}
3838
}
3939

@@ -127,10 +127,10 @@ TEST(RNTupleJoinTable, SparseSecondary)
127127
auto event = fldEvent(i);
128128

129129
if (i % 2 == 1) {
130-
EXPECT_EQ(joinTable->GetEntryIndex({&event}), ROOT::kInvalidNTupleIndex)
130+
EXPECT_EQ(joinTable->GetEntryIndex({event}), ROOT::kInvalidNTupleIndex)
131131
<< "entry should not be present in the join table";
132132
} else {
133-
auto entryIdx = joinTable->GetEntryIndex({&event});
133+
auto entryIdx = joinTable->GetEntryIndex({event});
134134
EXPECT_EQ(entryIdx, i / 2);
135135
EXPECT_FLOAT_EQ(fldX(entryIdx), static_cast<float>(entryIdx) / 3.14);
136136
}
@@ -170,25 +170,25 @@ TEST(RNTupleJoinTable, MultipleFields)
170170
for (std::uint64_t i = 0; i < pageSource->GetNEntries(); ++i) {
171171
run = i / 5;
172172
event = i % 5;
173-
auto entryIdx = joinTable->GetEntryIndex({&run, &event});
173+
auto entryIdx = joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event});
174174
EXPECT_EQ(fld(entryIdx), fld(i));
175175
}
176176

177177
run = 1;
178178
event = 2;
179-
auto idx1 = joinTable->GetEntryIndex({&run, &event});
180-
auto idx2 = joinTable->GetEntryIndex({&event, &run});
179+
auto idx1 = joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event});
180+
auto idx2 = joinTable->GetEntryIndex({event, static_cast<RNTupleJoinTable::JoinValue_t>(run)});
181181
EXPECT_NE(idx1, idx2);
182182

183183
try {
184-
joinTable->GetEntryIndex({&run, &event, &event});
184+
joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run), event, event});
185185
FAIL() << "querying the join table with more values than join field values should not be possible";
186186
} catch (const ROOT::RException &err) {
187187
EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields"));
188188
}
189189

190190
try {
191-
joinTable->GetEntryIndex({&run});
191+
joinTable->GetEntryIndex({static_cast<RNTupleJoinTable::JoinValue_t>(run)});
192192
FAIL() << "querying the join table with fewer values than join field values should not be possible";
193193
} catch (const ROOT::RException &err) {
194194
EXPECT_THAT(err.what(), testing::HasSubstr("number of value pointers must match number of join fields"));

0 commit comments

Comments
 (0)