Skip to content

Commit 4791449

Browse files
authored
Parquet readers support case-insensitive column names (rapidsai#21700)
Closes rapidsai#21631 This PR enables the Parquet and Hybrid scan readers to support case-insensitive column names (or paths if structs) for column selection and row filteration. Authors: - Muhammad Haseeb (https://github.com/mhaseeb123) - Vukasin Milovanovic (https://github.com/vuule) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Nghia Truong (https://github.com/ttnghia) - Matthew Murray (https://github.com/Matt711) - Lawrence Mitchell (https://github.com/wence-) - Vyas Ramasubramani (https://github.com/vyasr) URL: rapidsai#21700
1 parent a5f171b commit 4791449

23 files changed

Lines changed: 779 additions & 163 deletions

cpp/include/cudf/io/parquet.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class parquet_reader_options {
103103
type_id _decimal_width{type_id::EMPTY};
104104
// Whether to use JIT compilation for filtering
105105
bool _use_jit_filter = false;
106+
// Whether column name matching is case sensitive. In case of multiple
107+
// case-insensitive matches, the first matched column is selected
108+
bool _case_sensitive_names = true;
106109

107110
std::optional<std::vector<reader_column_schema>> _reader_column_schema;
108111

@@ -285,6 +288,16 @@ class parquet_reader_options {
285288
*/
286289
[[nodiscard]] bool is_enabled_use_jit_filter() const { return _use_jit_filter; }
287290

291+
/**
292+
* @brief Returns whether column name matching is case sensitive.
293+
*
294+
* @note When disabled, if there are multiple case-insensitive matches, the first
295+
* matched column is selected from the Parquet schema.
296+
*
297+
* @return `true` if column name matching is case sensitive (default)
298+
*/
299+
[[nodiscard]] bool is_enabled_case_sensitive_names() const { return _case_sensitive_names; }
300+
288301
/**
289302
* @brief Set a new source location
290303
*
@@ -510,6 +523,16 @@ class parquet_reader_options {
510523
* columns need to be cast. The scale of each column is preserved from the file.
511524
*/
512525
void set_decimal_width(type_id width) { _decimal_width = width; }
526+
527+
/**
528+
* @brief Sets whether column name matching is case sensitive.
529+
*
530+
* @note When disabled, if there are multiple case-insensitive matches, the first
531+
* matched column is selected from the Parquet schema.
532+
*
533+
* @param val Boolean indicating whether to enable case-sensitive matching.
534+
*/
535+
void enable_case_sensitive_names(bool val) { _case_sensitive_names = val; }
513536
};
514537

515538
/**
@@ -758,6 +781,21 @@ class parquet_reader_options_builder {
758781
return *this;
759782
}
760783

784+
/**
785+
* @brief Sets whether column name matching is case sensitive.
786+
*
787+
* @note When disabled, if there are multiple case-insensitive matches, the first
788+
* matched column is selected from the Parquet schema.
789+
*
790+
* @param val Boolean indicating whether to enable case-sensitive matching
791+
* @return this for chaining
792+
*/
793+
parquet_reader_options_builder& case_sensitive_names(bool val)
794+
{
795+
options._case_sensitive_names = val;
796+
return *this;
797+
}
798+
761799
/**
762800
* @brief move parquet_reader_options member once it's built.
763801
*/

cpp/src/io/parquet/experimental/hybrid_scan_helpers.cpp

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include "hybrid_scan_helpers.hpp"
77

88
#include "io/parquet/compact_protocol_reader.hpp"
9+
#include "io/parquet/expression_transform_helpers.hpp"
910
#include "io/parquet/reader_impl_helpers.hpp"
10-
#include "io/utilities/row_selection.hpp"
1111

1212
#include <cudf/detail/nvtx/ranges.hpp>
1313
#include <cudf/logger.hpp>
@@ -201,7 +201,8 @@ aggregate_reader_metadata::select_payload_columns(
201201
bool strings_to_categorical,
202202
bool ignore_missing_columns,
203203
type_id timestamp_type_id,
204-
type_id decimal_type_id)
204+
type_id decimal_type_id,
205+
bool case_sensitive_names)
205206
{
206207
// If neither payload nor filter columns are specified, select all columns
207208
if (not payload_column_names.has_value() and not filter_column_names.has_value()) {
@@ -212,27 +213,40 @@ aggregate_reader_metadata::select_payload_columns(
212213
strings_to_categorical,
213214
ignore_missing_columns,
214215
timestamp_type_id,
215-
decimal_type_id);
216+
decimal_type_id,
217+
case_sensitive_names);
216218
}
217219

218220
std::vector<std::string> valid_payload_columns;
219221

222+
using cudf::io::parquet::detail::normalize_column_path;
223+
224+
// Helper lambda to construct a set of normalized column names for O(1) lookup
225+
auto construct_filter_columns_set = [](auto const& names, bool case_sensitive_names) {
226+
std::unordered_set<std::string> filter_columns_set;
227+
for (auto const& name : names) {
228+
filter_columns_set.insert(normalize_column_path(name, case_sensitive_names));
229+
}
230+
return filter_columns_set;
231+
};
232+
220233
// If payload columns are specified, only select payload columns that do not appear in the filter
221234
// expression
222235
if (payload_column_names.has_value()) {
223236
valid_payload_columns = *payload_column_names;
224237
// Remove filter columns from the provided payload column names
225238
if (filter_column_names.has_value() and not filter_column_names->empty()) {
226-
// Add filter column names to a hash set for faster lookup
227-
std::unordered_set<std::string> filter_columns_set(filter_column_names->begin(),
228-
filter_column_names->end());
239+
auto const filter_columns_set =
240+
construct_filter_columns_set(*filter_column_names, case_sensitive_names);
229241
// Remove a payload column name if it is also present in the hash set
230-
valid_payload_columns.erase(std::remove_if(valid_payload_columns.begin(),
231-
valid_payload_columns.end(),
232-
[&filter_columns_set](auto const& col) {
233-
return filter_columns_set.count(col) > 0;
234-
}),
235-
valid_payload_columns.end());
242+
valid_payload_columns.erase(
243+
std::remove_if(valid_payload_columns.begin(),
244+
valid_payload_columns.end(),
245+
[&](auto const& col) {
246+
return filter_columns_set.count(
247+
normalize_column_path(col, case_sensitive_names)) > 0;
248+
}),
249+
valid_payload_columns.end());
236250
}
237251
// Call the base `select_columns()` method with valid payload columns
238252
return select_columns(valid_payload_columns,
@@ -241,40 +255,40 @@ aggregate_reader_metadata::select_payload_columns(
241255
strings_to_categorical,
242256
ignore_missing_columns,
243257
timestamp_type_id,
244-
decimal_type_id);
258+
decimal_type_id,
259+
case_sensitive_names);
245260
}
246261

247262
// Else if only filter columns are specified, select all columns that do not appear in the
248263
// filter expression
249-
250-
// Add filter column names to a hash set for faster lookup
251-
std::unordered_set<std::string> filter_columns_set(filter_column_names->begin(),
252-
filter_column_names->end());
264+
auto const filter_columns_set =
265+
construct_filter_columns_set(*filter_column_names, case_sensitive_names);
253266

254267
std::function<void(std::string, int)> add_column_path = [&](std::string path_till_now,
255268
int schema_idx) {
256269
auto const& schema_elem = get_schema(schema_idx);
257270
std::string const curr_path = path_till_now + schema_elem.name;
258-
// Add the current path to the list of valid payload columns if it is not a filter column
259271
// TODO: Add children when AST filter expressions start supporting nested struct columns
260-
if (filter_columns_set.count(curr_path) == 0) { valid_payload_columns.push_back(curr_path); }
272+
if (filter_columns_set.count(normalize_column_path(curr_path, case_sensitive_names)) == 0) {
273+
valid_payload_columns.push_back(curr_path);
274+
}
261275
};
262276

263-
// Add all but filter columns to valid payload columns
264277
if (not filter_column_names->empty()) {
265-
for (auto const& child_idx : get_schema(0).children_idx) {
278+
auto const& root = get_schema(0);
279+
for (auto const& child_idx : root.children_idx) {
266280
add_column_path("", child_idx);
267281
}
268282
}
269283

270-
// Call the base `select_columns()` method with all but filter columns
271284
return select_columns(valid_payload_columns,
272285
{},
273286
include_index,
274287
strings_to_categorical,
275288
ignore_missing_columns,
276289
timestamp_type_id,
277-
decimal_type_id);
290+
decimal_type_id,
291+
case_sensitive_names);
278292
}
279293

280294
std::vector<std::vector<cudf::size_type>>
@@ -589,19 +603,26 @@ named_to_reference_converter::named_to_reference_converter(
589603
std::optional<std::reference_wrapper<ast::expression const>> expr,
590604
table_metadata const& metadata,
591605
std::vector<SchemaElement> const& schema_tree,
592-
cudf::io::parquet_reader_options const& options)
606+
cudf::io::parquet_reader_options const& options,
607+
bool case_sensitive_names)
593608
{
594609
if (!expr.has_value()) { return; }
595610

596-
_column_indices_to_names =
597-
cudf::io::parquet::detail::map_column_indices_to_names(options, schema_tree);
611+
_case_sensitive_names = case_sensitive_names;
612+
613+
_column_indices_to_names = cudf::io::parquet::detail::map_column_indices_to_names(
614+
options, schema_tree, case_sensitive_names);
598615

599616
// Map column names to their indices
600-
std::transform(metadata.schema_info.cbegin(),
601-
metadata.schema_info.cend(),
602-
thrust::counting_iterator<size_t>(0),
603-
std::inserter(_column_name_to_index, _column_name_to_index.end()),
604-
[](auto const& sch, auto index) { return std::make_pair(sch.name, index); });
617+
std::transform(
618+
metadata.schema_info.cbegin(),
619+
metadata.schema_info.cend(),
620+
thrust::counting_iterator<size_t>(0),
621+
std::inserter(_column_name_to_index, _column_name_to_index.end()),
622+
[&](auto const& sch, auto index) {
623+
return std::make_pair(
624+
cudf::io::parquet::detail::normalize_column_path(sch.name, case_sensitive_names), index);
625+
});
605626

606627
expr.value().get().accept(*this);
607628
}
@@ -611,12 +632,17 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi
611632
{
612633
// Map the column index to its name
613634
auto const col_name_iter = _column_indices_to_names.find(expr.get_column_index());
614-
CUDF_EXPECTS(col_name_iter != _column_indices_to_names.end(),
615-
"Column index not found in column indices to names map");
635+
CUDF_EXPECTS(
636+
col_name_iter != _column_indices_to_names.end(),
637+
"Column index in the filter expression not found in the column indices to names map. Note that "
638+
"only top-level columns except structs and lists are supported in "
639+
"Parquet filter expression",
640+
std::invalid_argument);
616641
auto const col_name = col_name_iter->second;
617-
// Check if the column name exists in the metadata and map it to its new column index
618-
auto col_index_it = _column_name_to_index.find(col_name);
619-
CUDF_EXPECTS(col_index_it != _column_name_to_index.end(), "Column name not found in metadata");
642+
auto col_index_it = _column_name_to_index.find(col_name);
643+
CUDF_EXPECTS(col_index_it != _column_name_to_index.end(),
644+
"Column name mapped from its index in the filter expression "
645+
"not found in the metadata of selected columns");
620646
auto col_index = col_index_it->second;
621647
// Create a new column reference
622648
_col_ref.emplace_back(col_index);

cpp/src/io/parquet/experimental/hybrid_scan_helpers.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class aggregate_reader_metadata : public aggregate_reader_metadata_base {
146146
* @param ignore_missing_columns Whether to ignore non-existent columns
147147
* @param timestamp_type_id Type conversion parameter
148148
* @param decimal_type_id Type conversion parameter
149+
* @param case_sensitive_names Boolean indicating if column names are case sensitive
149150
*
150151
* @return input column information, output column buffers, list of output column schema
151152
* indices
@@ -158,7 +159,8 @@ class aggregate_reader_metadata : public aggregate_reader_metadata_base {
158159
bool strings_to_categorical,
159160
bool ignore_missing_columns,
160161
type_id timestamp_type_id,
161-
type_id decimal_type_id);
162+
type_id decimal_type_id,
163+
bool case_sensitive_names);
162164

163165
/**
164166
* @brief Filters row groups such that only the row groups that start within the byte range
@@ -369,7 +371,8 @@ class named_to_reference_converter : public parquet::detail::named_to_reference_
369371
named_to_reference_converter(std::optional<std::reference_wrapper<ast::expression const>> expr,
370372
table_metadata const& metadata,
371373
std::vector<SchemaElement> const& schema_tree,
372-
cudf::io::parquet_reader_options const& options);
374+
cudf::io::parquet_reader_options const& options,
375+
bool case_sensitive_names);
373376

374377
using parquet::detail::named_to_reference_converter::visit;
375378

0 commit comments

Comments
 (0)