Skip to content

Commit ec487a3

Browse files
author
Rafał Hibner
committed
Support any scalar type in PivotLonger features
1 parent 607be64 commit ec487a3

3 files changed

Lines changed: 61 additions & 19 deletions

File tree

cpp/src/arrow/acero/options.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,15 +780,15 @@ class ARROW_ACERO_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
780780

781781
/// \brief a row template that describes one row that will be generated for each input row
782782
struct ARROW_ACERO_EXPORT PivotLongerRowTemplate {
783-
PivotLongerRowTemplate(std::vector<std::string> feature_values,
783+
PivotLongerRowTemplate(std::vector<std::shared_ptr<Scalar>> feature_values,
784784
std::vector<std::optional<FieldRef>> measurement_values)
785785
: feature_values(std::move(feature_values)),
786786
measurement_values(std::move(measurement_values)) {}
787787
/// A (typically unique) set of feature values for the template, usually derived from a
788788
/// column name
789789
///
790790
/// These will be used to populate the feature columns
791-
std::vector<std::string> feature_values;
791+
std::vector<std::shared_ptr<Scalar>> feature_values;
792792
/// The fields containing the measurements to use for this row
793793
///
794794
/// These will be used to populate the measurement columns. If nullopt then nulls

cpp/src/arrow/acero/pivot_longer_node.cc

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace {
4242

4343
// A row template that's been bound to a schema
4444
struct BoundRowTemplate {
45-
std::vector<std::string> feature_values;
45+
std::vector<std::shared_ptr<Scalar>> feature_values;
4646
std::vector<std::optional<FieldPath>> measurement_paths;
4747

4848
static Result<BoundRowTemplate> Make(const PivotLongerRowTemplate& unbound,
@@ -65,7 +65,7 @@ struct BoundRowTemplate {
6565
}
6666

6767
private:
68-
BoundRowTemplate(std::vector<std::string> feature_values,
68+
BoundRowTemplate(std::vector<std::shared_ptr<Scalar>> feature_values,
6969
std::vector<std::optional<FieldPath>> measurement_paths)
7070
: feature_values(std::move(feature_values)),
7171
measurement_paths(std::move(measurement_paths)) {}
@@ -89,6 +89,8 @@ class PivotLongerNode : public ExecNode, public TracedNode {
8989
"have names");
9090
}
9191

92+
std::vector<std::shared_ptr<DataType>> feature_types(
93+
options.feature_field_names.size());
9294
for (const auto& row_template : options.row_templates) {
9395
if (row_template.feature_values.size() != options.feature_field_names.size()) {
9496
return Status::Invalid("There were names given for ",
@@ -103,11 +105,28 @@ class PivotLongerNode : public ExecNode, public TracedNode {
103105
" measurement columns but one of the row templates only had ",
104106
row_template.measurement_values.size(), " field references");
105107
}
108+
109+
for (std::size_t i = 0; i < row_template.feature_values.size(); i++) {
110+
if (feature_types[i]) {
111+
if (!feature_types[i]->Equals(row_template.feature_values[i]->type)) {
112+
return Status::Invalid(
113+
"Mixed feature types at column ", options.feature_field_names[i],
114+
". Some row templates had the type ", feature_types[i]->ToString(),
115+
" but later row templates had the type ",
116+
row_template.feature_values[i]->type->ToString(),
117+
". All row templates must have same type for each feature "
118+
"column.");
119+
}
120+
} else {
121+
feature_types[i] = row_template.feature_values[i]->type;
122+
}
123+
}
106124
}
107125

108126
std::vector<std::shared_ptr<Field>> fields(input_schema->fields());
109-
for (const auto& name : options.feature_field_names) {
110-
fields.push_back(field(name, utf8()));
127+
for (std::size_t i = 0; i < options.feature_field_names.size(); i++) {
128+
fields.push_back(
129+
field(options.feature_field_names[i], std::move(feature_types[i])));
111130
}
112131
std::vector<std::shared_ptr<DataType>> measurement_types(
113132
options.measurement_field_names.size());

cpp/src/arrow/acero/pivot_longer_node_test.cc

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,15 @@ TEST(PivotLongerNode, Basic) {
4343
->Table(kRowsPerBatch, kNumBatches);
4444

4545
PivotLongerNodeOptions options;
46-
options.feature_field_names = {"feature1", "feature2"};
46+
options.feature_field_names = {"feature1", "feature2", "feature3"};
4747
options.measurement_field_names = {"meas1", "meas2"};
48-
options.row_templates = {{{"a", "x"}, {{1}, {3}}}, {{"b", "y"}, {{2}, std::nullopt}}};
48+
options.row_templates = {
49+
{{std::make_shared<StringScalar>("a"), std::make_shared<StringScalar>("x"),
50+
std::make_shared<UInt32Scalar>(12)},
51+
{{1}, {3}}},
52+
{{std::make_shared<StringScalar>("b"), std::make_shared<StringScalar>("y"),
53+
std::make_shared<UInt32Scalar>(13)},
54+
{{2}, std::nullopt}}};
4955

5056
Declaration plan = Declaration::Sequence({
5157
{"table_source", TableSourceNodeOptions(std::move(input))},
@@ -62,6 +68,7 @@ TEST(PivotLongerNode, Basic) {
6268
field("f3", uint32()),
6369
field("feature1", utf8()),
6470
field("feature2", utf8()),
71+
field("feature3", uint32()),
6572
field("meas1", uint32()),
6673
field("meas2", uint32()),
6774
});
@@ -98,19 +105,27 @@ TEST(PivotLongerNode, Error) {
98105
"There were names given for 1 feature columns but one of the row templates "
99106
"only had 0 feature values");
100107

101-
options.row_templates = {{{"x"}, {}}};
108+
options.row_templates = {{{std::make_shared<StringScalar>("x")}, {}}};
102109
CheckError(
103110
options,
104111
"There were names given for 1 measurement columns but one of the row templates "
105112
"only had 0 field references");
106113

107-
options.row_templates = {{{"x"}, {{0}}}, {{"y"}, {{1}}}};
114+
options.row_templates = {{{std::make_shared<StringScalar>("x")}, {{0}}},
115+
{{std::make_shared<StringScalar>("y")}, {{1}}}};
108116
CheckError(
109117
options,
110118
"Some row templates had the type uint32 but later row templates had the type bool");
111119

112-
options.row_templates = {{{"x"}, {std::nullopt}}, {{"y"}, {std::nullopt}}};
120+
options.row_templates = {{{std::make_shared<StringScalar>("x")}, {std::nullopt}},
121+
{{std::make_shared<StringScalar>("y")}, {std::nullopt}}};
113122
CheckError(options, "All row templates had nullopt");
123+
124+
options.row_templates = {{{std::make_shared<StringScalar>("x")}, {{0}}},
125+
{{std::make_shared<UInt32Scalar>(1)}, {{0}}}};
126+
CheckError(options,
127+
"Some row templates had the type string but later row templates had the "
128+
"type uint32");
114129
}
115130

116131
// The following examples are smaller versions of examples taken from
@@ -130,11 +145,11 @@ TEST(PivotLongerNode, ExamplesFromTidyr1) {
130145
PivotLongerNodeOptions options;
131146
options.feature_field_names = {"income"};
132147
options.measurement_field_names = {"count"};
133-
options.row_templates = {{{"<$10k"}, {{1}}},
134-
{{"$10k-20k"}, {{2}}},
135-
{{"$20k-30k"}, {{3}}},
136-
{{"$30k-40k"}, {{4}}},
137-
{{"$40k-50k"}, {{5}}}};
148+
options.row_templates = {{{std::make_shared<StringScalar>("<$10k")}, {{1}}},
149+
{{std::make_shared<StringScalar>("$10k-20k")}, {{2}}},
150+
{{std::make_shared<StringScalar>("$20k-30k")}, {{3}}},
151+
{{std::make_shared<StringScalar>("$30k-40k")}, {{4}}},
152+
{{std::make_shared<StringScalar>("$40k-50k")}, {{5}}}};
138153

139154
Declaration plan = Declaration::Sequence(
140155
{{"table_source", TableSourceNodeOptions(std::move(input))},
@@ -183,7 +198,8 @@ TEST(PivotLongerNode, ExamplesFromTidyr2) {
183198
PivotLongerNodeOptions options;
184199
options.feature_field_names = {"week"};
185200
options.measurement_field_names = {"rank"};
186-
options.row_templates = {{{"1"}, {{2}}}, {{"2"}, {{3}}}};
201+
options.row_templates = {{{std::make_shared<StringScalar>("1")}, {{2}}},
202+
{{std::make_shared<StringScalar>("2")}, {{3}}}};
187203

188204
Declaration plan = Declaration::Sequence(
189205
{{"table_source", TableSourceNodeOptions(std::move(input))},
@@ -222,7 +238,13 @@ TEST(PivotLongerNode, ExamplesFromTidyr3) {
222238
PivotLongerNodeOptions options;
223239
options.feature_field_names = {"diagnosis", "gender", "age"};
224240
options.measurement_field_names = {"count"};
225-
options.row_templates = {{{"sp", "m", "014"}, {{1}}}, {{"sp", "m", "1524"}, {{2}}}};
241+
options.row_templates = {
242+
{{std::make_shared<StringScalar>("sp"), std::make_shared<StringScalar>("m"),
243+
std::make_shared<StringScalar>("014")},
244+
{{1}}},
245+
{{std::make_shared<StringScalar>("sp"), std::make_shared<StringScalar>("m"),
246+
std::make_shared<StringScalar>("1524")},
247+
{{2}}}};
226248

227249
Declaration plan = Declaration::Sequence(
228250
{{"table_source", TableSourceNodeOptions(std::move(input))},
@@ -261,7 +283,8 @@ TEST(PivotLongerNode, ExamplesFromTidyr4) {
261283
PivotLongerNodeOptions options;
262284
options.feature_field_names = {"set"};
263285
options.measurement_field_names = {"x", "y"};
264-
options.row_templates = {{{"1"}, {{0}, {2}}}, {{"2"}, {{1}, {3}}}};
286+
options.row_templates = {{{std::make_shared<StringScalar>("1")}, {{0}, {2}}},
287+
{{std::make_shared<StringScalar>("2")}, {{1}, {3}}}};
265288

266289
Declaration plan = Declaration::Sequence(
267290
{{"table_source", TableSourceNodeOptions(std::move(input))},

0 commit comments

Comments
 (0)