Skip to content

Commit 8ed2edf

Browse files
committed
[df] allow for alternative int types for cardinality columns
1 parent 13fae1a commit 8ed2edf

2 files changed

Lines changed: 82 additions & 29 deletions

File tree

tree/dataframe/src/RNTupleDS.cxx

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,49 @@
5050
// clang-format on
5151

5252
namespace ROOT::Internal::RDF {
53+
/// Non-templated base class of RRDFCardinalityField<T>. It collects the parts of the cardinality field that do not
/// depend on the integer type T: the (empty) on-disk reconciliation, the read-only column generation and the list of
/// accepted column representations. RNTupleDS::GetFieldWithTypeChecks uses a dynamic_cast to this class to recognize
/// proto fields that are cardinality fields.
class RRDFCardinalityFieldBase : public ROOT::RFieldBase {
54+
protected:
55+
// We construct these fields and know that they match the page source,
// hence the reconciliation with the on-disk field is intentionally a no-op.
56+
void ReconcileOnDiskField(const RNTupleDescriptor &) final {}
57+
58+
/// Protected constructor: forwards the field name and the type name to RFieldBase, together with the
/// kPlain structure and isSimple = false.
RRDFCardinalityFieldBase(std::string_view name, std::string_view type)
59+
: ROOT::RFieldBase(name, type, ROOT::ENTupleStructure::kPlain, false /* isSimple */)
60+
{
61+
}
62+
63+
// Field is only used for reading: the overload without a descriptor always throws an RException.
64+
void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); }
// The overload taking a descriptor delegates to GenerateColumnsImpl, instantiated with ROOT::Internal::RColumnIndex.
65+
void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final
66+
{
67+
GenerateColumnsImpl<ROOT::Internal::RColumnIndex>(desc);
68+
}
69+
70+
public:
// Copying is disabled; move construction, move assignment and destruction use the defaults.
71+
RRDFCardinalityFieldBase(const RRDFCardinalityFieldBase &other) = delete;
72+
RRDFCardinalityFieldBase &operator=(const RRDFCardinalityFieldBase &other) = delete;
73+
RRDFCardinalityFieldBase(RRDFCardinalityFieldBase &&other) = default;
74+
RRDFCardinalityFieldBase &operator=(RRDFCardinalityFieldBase &&other) = default;
75+
~RRDFCardinalityFieldBase() override = default;
76+
77+
/// Returns the column representations of this field: a single column of type kSplitIndex64, kIndex64,
/// kSplitIndex32 or kIndex32. The returned object is a function-local static, so every call yields a
/// reference to the same instance.
const RColumnRepresentations &GetColumnRepresentations() const final
78+
{
79+
static RColumnRepresentations representations({{ENTupleColumnType::kSplitIndex64},
80+
{ENTupleColumnType::kIndex64},
81+
{ENTupleColumnType::kSplitIndex32},
82+
{ENTupleColumnType::kIndex32}},
83+
// NOTE(review): second constructor argument is left empty; presumably the list of
// additional deserialization-only representations -- confirm against RColumnRepresentations.
{});
84+
return representations;
85+
}
86+
};
87+
5388
/// An artificial field that transforms an RNTuple column that contains the offset of collections into
5489
/// collection sizes. It is used to provide the "number of" RDF columns for collections, e.g.
5590
/// `R_rdf_sizeof_jets` for a collection named `jets`.
5691
///
5792
/// This is similar to the RCardinalityField but it presents itself as an integer type.
5893
/// The template argument T must be an integral type.
5994
template <typename T>
60-
class RRDFCardinalityField final : public ROOT::RFieldBase {
95+
class RRDFCardinalityField final : public RRDFCardinalityFieldBase {
6196
static_assert(std::is_integral_v<T>, "T must be an integral type");
6297

6398
protected:
@@ -67,13 +102,9 @@ class RRDFCardinalityField final : public ROOT::RFieldBase {
67102
}
68103
void ConstructValue(void *where) const final { *static_cast<T *>(where) = 0; }
69104

70-
// We construct these fields and know that they match the page source
71-
void ReconcileOnDiskField(const RNTupleDescriptor &) final {}
72-
73105
public:
74106
RRDFCardinalityField(std::string_view name)
75-
: ROOT::RFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(T)), ROOT::ENTupleStructure::kPlain,
76-
false /* isSimple */)
107+
: RRDFCardinalityFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(T)))
77108
{
78109
}
79110
RRDFCardinalityField(const RRDFCardinalityField &other) = delete;
@@ -82,22 +113,6 @@ class RRDFCardinalityField final : public ROOT::RFieldBase {
82113
RRDFCardinalityField &operator=(RRDFCardinalityField &&other) = default;
83114
~RRDFCardinalityField() override = default;
84115

85-
const RColumnRepresentations &GetColumnRepresentations() const final
86-
{
87-
static RColumnRepresentations representations({{ENTupleColumnType::kSplitIndex64},
88-
{ENTupleColumnType::kIndex64},
89-
{ENTupleColumnType::kSplitIndex32},
90-
{ENTupleColumnType::kIndex32}},
91-
{});
92-
return representations;
93-
}
94-
// Field is only used for reading
95-
void GenerateColumns() final { throw RException(R__FAIL("Cardinality fields must only be used for reading")); }
96-
void GenerateColumns(const ROOT::RNTupleDescriptor &desc) final
97-
{
98-
GenerateColumnsImpl<ROOT::Internal::RColumnIndex>(desc);
99-
}
100-
101116
std::size_t GetValueSize() const final { return sizeof(T); }
102117
std::size_t GetAlignment() const final { return alignof(T); }
103118

@@ -150,7 +165,8 @@ class RArraySizeField final : public ROOT::RFieldBase {
150165

151166
public:
152167
RArraySizeField(std::string_view name, std::size_t arrayLength)
153-
: ROOT::RFieldBase(name, "std::size_t", ROOT::ENTupleStructure::kPlain, false /* isSimple */),
168+
: ROOT::RFieldBase(name, ROOT::Internal::GetRenormalizedTypeName(typeid(std::size_t)),
169+
ROOT::ENTupleStructure::kPlain, false /* isSimple */),
154170
fArrayLength(arrayLength)
155171
{
156172
}
@@ -493,7 +509,7 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view
493509
// If the requested type is different from the proto field that was created when the data source was constructed,
494510
// we first have to create an alternative proto field for the column reader (this includes cardinality columns, which
495511
// get an RRDFCardinalityField of the requested integer type). Otherwise, we can directly use the existing proto field.
496-
if (fieldName.substr(0, 13) != "R_rdf_sizeof_" && requestedType != fColumnTypes[index]) {
512+
if (requestedType != fColumnTypes[index]) {
497513
auto &altProtoFields = fAlternativeProtoFields[index];
498514

499515
// If we can find the requested type in the registered alternative protofields, return the corresponding field
@@ -506,12 +522,41 @@ ROOT::RFieldBase *ROOT::RDF::RNTupleDS::GetFieldWithTypeChecks(std::string_view
506522
}
507523

508524
// Otherwise, create a new protofield and register it in the alternatives before returning
509-
auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(std::string(fieldName), requestedType);
510-
if (!newAltProtoFieldOrException) {
511-
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
512-
"\" for column \"" + std::string(fieldName) + "\"");
525+
std::unique_ptr<RFieldBase> newAltProtoField;
526+
const std::string strName = std::string(fieldName);
527+
if (dynamic_cast<ROOT::Internal::RDF::RRDFCardinalityFieldBase *>(fProtoFields[index].get())) {
528+
if (requestedType == "bool") {
529+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<bool>>(strName);
530+
} else if (requestedType == "char") {
531+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<char>>(strName);
532+
} else if (requestedType == "std::int8_t") {
533+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int8_t>>(strName);
534+
} else if (requestedType == "std::uint8_t") {
535+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint8_t>>(strName);
536+
} else if (requestedType == "std::int16_t") {
537+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int16_t>>(strName);
538+
} else if (requestedType == "std::uint16_t") {
539+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint16_t>>(strName);
540+
} else if (requestedType == "std::int32_t") {
541+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int32_t>>(strName);
542+
} else if (requestedType == "std::uint32_t") {
543+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint32_t>>(strName);
544+
} else if (requestedType == "std::int64_t") {
545+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::int64_t>>(strName);
546+
} else if (requestedType == "std::uint64_t") {
547+
newAltProtoField = std::make_unique<ROOT::Internal::RDF::RRDFCardinalityField<std::uint64_t>>(strName);
548+
} else {
549+
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
550+
"\" for column \"" + std::string(fieldName) + "\"");
551+
}
552+
} else {
553+
auto newAltProtoFieldOrException = ROOT::RFieldBase::Create(strName, requestedType);
554+
if (!newAltProtoFieldOrException) {
555+
throw std::runtime_error("RNTupleDS: Could not create field with type \"" + requestedType +
556+
"\" for column \"" + std::string(fieldName) + "\"");
557+
}
558+
newAltProtoField = newAltProtoFieldOrException.Unwrap();
513559
}
514-
auto newAltProtoField = newAltProtoFieldOrException.Unwrap();
515560
newAltProtoField->SetOnDiskId(fProtoFields[index]->GetOnDiskId());
516561
auto *newField = newAltProtoField.get();
517562
altProtoFields.emplace_back(std::move(newAltProtoField));

tree/dataframe/test/datasource_ntuple.cxx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ TEST_F(RNTupleDSTest, ProjectedCardinalityColumn)
155155
EXPECT_EQ(2u, *df.Filter("nElectron == 2").Max("nElectron"));
156156

157157
EXPECT_EQ(2u, *df.Filter([](std::uint64_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
158+
EXPECT_EQ(2u, *df.Filter([](std::int32_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
159+
EXPECT_EQ(2u, *df.Filter([](std::uint32_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
160+
EXPECT_EQ(2u, *df.Filter([](std::int16_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
161+
EXPECT_EQ(2u, *df.Filter([](std::uint16_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
162+
EXPECT_EQ(2u, *df.Filter([](std::int8_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
163+
EXPECT_EQ(2u, *df.Filter([](std::uint8_t x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
164+
EXPECT_EQ(2u, *df.Filter([](char x) { return x == 2; }, {"nElectron"}).Max("nElectron"));
165+
EXPECT_EQ(2u, *df.Filter([](bool x) { return x; }, {"nElectron"}).Max("nElectron"));
158166
}
159167

160168
static void ReadTest(const std::string &name, const std::string &fname)

0 commit comments

Comments
 (0)