Skip to content

Commit 20baff2

Browse files
committed
[df] integer overflow check for RNTuple cardinality columns
1 parent ef67e11 commit 20baff2

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

tree/dataframe/src/RNTupleDS.cxx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <TSystem.h>
2828

2929
#include <cassert>
30+
#include <limits>
3031
#include <memory>
3132
#include <mutex>
3233
#include <string>
@@ -95,6 +96,16 @@ template <typename T>
9596
class RRDFCardinalityField final : public RRDFCardinalityFieldBase {
9697
static_assert(std::is_integral_v<T>, "T must be an integral type");
9798

99+
inline void CheckSize(ROOT::NTupleSize_t size) const
100+
{
101+
if constexpr (std::is_same_v<T, bool> || std::is_same_v<T, std::uint64_t>)
102+
return;
103+
if (size > std::numeric_limits<T>::max()) {
104+
throw RException(R__FAIL(std::string("integer overflow in field ") + GetFieldName() +
105+
". Please read the column with a larger-sized integral type."));
106+
}
107+
}
108+
98109
protected:
99110
std::unique_ptr<ROOT::RFieldBase> CloneImpl(std::string_view newName) const final
100111
{
@@ -120,6 +131,7 @@ class RRDFCardinalityField final : public RRDFCardinalityFieldBase {
120131
RNTupleLocalIndex collectionStart;
121132
ROOT::NTupleSize_t size;
122133
fPrincipalColumn->GetCollectionInfo(globalIndex, &collectionStart, &size);
134+
CheckSize(size);
123135
*static_cast<T *>(to) = size;
124136
}
125137

@@ -129,6 +141,7 @@ class RRDFCardinalityField final : public RRDFCardinalityFieldBase {
129141
RNTupleLocalIndex collectionStart;
130142
ROOT::NTupleSize_t size;
131143
fPrincipalColumn->GetCollectionInfo(localIndex, &collectionStart, &size);
144+
CheckSize(size);
132145
*static_cast<T *>(to) = size;
133146
}
134147
};

tree/dataframe/test/datasource_ntuple.cxx

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class RNTupleDSTest : public ::testing::Test {
6868
auto fldElectron = model->MakeField<Electron>("electron");
6969
fldElectron->pt = 137.0;
7070
auto fldVecElectron = model->MakeField<std::vector<Electron>>("VecElectron");
71-
fldVecElectron->push_back(*fldElectron);
72-
fldVecElectron->push_back(*fldElectron);
71+
for (int i = 0; i < 128; ++i)
72+
fldVecElectron->push_back(*fldElectron);
7373
auto fldNElectron = std::make_unique<ROOT::RField<ROOT::RNTupleCardinality<std::uint64_t>>>("nElectron");
7474
model->AddProjectedField(std::move(fldNElectron), [](const std::string &) { return "VecElectron"; });
7575
{
@@ -152,17 +152,21 @@ TEST_F(RNTupleDSTest, ProjectedCardinalityColumn)
152152
{
153153
auto df = ROOT::RDF::FromRNTuple(fNtplName, fFileName);
154154

155-
EXPECT_EQ(2u, *df.Filter("nElectron == 2").Max("nElectron"));
156-
157-
EXPECT_EQ(2u, *df.Filter([](std::uint64_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
158-
EXPECT_EQ(2u, *df.Filter([](std::int32_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
159-
EXPECT_EQ(2u, *df.Filter([](std::uint32_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
160-
EXPECT_EQ(2u, *df.Filter([](std::int16_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
161-
EXPECT_EQ(2u, *df.Filter([](std::uint16_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
162-
EXPECT_EQ(2u, *df.Filter([](std::int8_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
163-
EXPECT_EQ(2u, *df.Filter([](std::uint8_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
164-
EXPECT_EQ(2u, *df.Filter([](char x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
165-
EXPECT_EQ(2u, *df.Filter([](bool x) { return x; }, {"nElectron"}).Max("nElectron"));
155+
EXPECT_EQ(128u, *df.Filter("nElectron == 128").Max("nElectron"));
156+
157+
EXPECT_EQ(128u, *df.Filter([](std::uint64_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
158+
EXPECT_EQ(128u, *df.Filter([](std::int32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
159+
EXPECT_EQ(128u, *df.Filter([](std::uint32_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
160+
EXPECT_EQ(128u, *df.Filter([](std::int16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
161+
EXPECT_EQ(128u, *df.Filter([](std::uint16_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
162+
EXPECT_EQ(128u, *df.Filter([](std::uint8_t x) { return x == 128; }, {"nElectron"}).Max("nElectron"));
163+
EXPECT_EQ(128u, *df.Filter([](bool x) { return x; }, {"nElectron"}).Max("nElectron"));
164+
try {
165+
*df.Filter([](std::int8_t x) { return x == 0; }, {"nElectron"}).Count();
166+
FAIL() << "integer overflow should fail";
167+
} catch (const ROOT::RException &e) {
168+
EXPECT_THAT(e.what(), ::testing::HasSubstr("integer overflow"));
169+
}
166170
}
167171

168172
static void ReadTest(const std::string &name, const std::string &fname)
@@ -206,7 +210,7 @@ static void ReadTest(const std::string &name, const std::string &fname)
206210
EXPECT_TRUE(All(rvec->at(0) == ROOT::RVecI{1, 2, 3}));
207211
EXPECT_TRUE(All(vectorasrvec->at(0) == ROOT::RVecF{1.f, 2.f}));
208212
EXPECT_FLOAT_EQ(137.0, sumElectronPt.GetValue());
209-
EXPECT_FLOAT_EQ(2. * 137.0, sumVecElectronPt.GetValue());
213+
EXPECT_FLOAT_EQ(128. * 137.0, sumVecElectronPt.GetValue());
210214
}
211215

212216
static void ChainTest(const std::string &name, const std::string &fname)

0 commit comments

Comments
 (0)