Skip to content

Commit 1fd5697

Browse files
[data] DataSourceV2: V2 ARROW-5030 nested-type fallback (#63175)
## Why

Parquet files with nested columns (e.g. `list<struct<..., string>>`) whose row groups exceed Arrow's ~2 GB chunking threshold hit `ArrowNotImplementedError` at decode time (ARROW-5030). V1 already has a metadata-only fallback that detects this and switches to `pq.ParquetFile.iter_batches`. This PR ports it to V2 and makes the decision filter-aware.

## What

**Port V1's nested-type fallback to V2.** `FileReader` grows an `_iter_fragment_tables` hook; `ParquetFileReader` overrides it with V1's `_needs_nested_type_fallback` metadata check, falling back to `pq.ParquetFile.iter_batches` (with safe batch sizing, row-group pushdown via `fragment.subset`, and per-batch row-level filtering) when the check fires.

**Make the fallback decision filter-aware.** Previously the check looked only at projected columns. A filter that touches a large nested column *outside* the projection would still force the scanner to decode it for row-level evaluation — and hit ARROW-5030. The check now sees the union of projected + filter-referenced columns:

```python
ds.read_parquet(path).select_columns(["id"]).filter(col("nested_col").is_not_null())
#                     ^^^^ projection excludes nested_col
#                                            ^^^^ but filter references it
#                                            → fallback must trigger
```

**Carry the predicate as a Ray `Expr` instead of a pyarrow expression.** `pyarrow.compute.Expression` is opaque (no public visitor), so we can't extract filter columns from it after the fact. Keeping the Ray `Expr` as the source of truth — and converting to pyarrow once, at the scanner-kwargs boundary — lets the reader call `get_column_references` for the union above. Touches `ArrowFileScanner.predicate`, `FileReader.predicate`, and `push_filters` (now ANDs Ray `Expr`s).

**Drop the legacy `filter=` kwarg on V2.** `read_parquet(filter=pc.field("x") > 5)` is already deprecated. Since it carries a raw pyarrow expression that can't be introspected, it's silently stripped on the V2 path. Callers should use `read_parquet(path).filter(expr=...)`.
## Tests

- `test_read_parquet_nested_type_arrow_not_implemented_fallback` — V2 skip removed (regression for [#61675](#61675)).
- `test_read_parquet_nested_fallback_triggered_when_filter_references_nested_column` — new, V2-only. Projects a flat column and filters on the large nested column; asserts the fallback is invoked.

Signed-off-by: Goutam <goutam@anyscale.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/ray-project/ray/pull/63175).
* #63326
* __->__ #63175

Signed-off-by: Goutam <goutam@anyscale.com>
Co-authored-by: Goutam V. <>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9ddd07f commit 1fd5697

8 files changed

Lines changed: 340 additions & 53 deletions

File tree

python/ray/data/_internal/datasource_v2/parquet_datasource_v2.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,6 @@ def create_scanner(
272272
filesystem: Optional["FileSystem"] = None,
273273
**options: Any,
274274
) -> ParquetScanner:
275-
# ``filter=`` in V1 read_parquet() is the legacy pyarrow-compute
276-
# predicate. Stamp it on the scanner's ``predicate`` field so it's
277-
# honored at scan time (V2 does not yet dispatch Ray-level
278-
# predicate pushdown rules).
279-
predicate = self._arrow_parquet_args.get("filter")
280275
# Callers (``_read_datasource_v2``) supply the sample-resolved
281276
# ``Partitioning`` via ``options["partitioning"]`` so the
282277
# datasource itself stays immutable — fall back to the
@@ -291,5 +286,4 @@ def create_scanner(
291286
shuffle=self._shuffle,
292287
ignore_prefixes=options.get("ignore_prefixes"),
293288
target_block_size=DataContext.get_current().target_max_block_size,
294-
predicate=predicate,
295289
)

python/ray/data/_internal/datasource_v2/readers/file_reader.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from enum import Enum
2-
from functools import cached_property
2+
from functools import cached_property, partial
33
from typing import Any, Iterator, List, Optional, Set, Tuple
44

55
import pyarrow as pa
66
import pyarrow.dataset as pds
7-
from pyarrow import compute as pc
87
from pyarrow.fs import FileSystem, LocalFileSystem
98

109
from ray.data._internal.arrow_block import _BATCH_SIZE_PRESERVING_STUB_COL_NAME
@@ -14,6 +13,7 @@
1413
from ray.data._internal.util import iterate_with_retry
1514
from ray.data.context import DataContext
1615
from ray.data.datasource.partitioning import Partitioning, PathPartitionParser
16+
from ray.data.expressions import Expr
1717
from ray.util.annotations import DeveloperAPI
1818

1919
# Synthetic column name produced when ``include_paths=True``. Shared with
@@ -58,7 +58,7 @@ def __init__(
5858
format: FileFormat,
5959
batch_size: int = _ARROW_DEFAULT_BATCH_SIZE,
6060
columns: Optional[List[str]] = None,
61-
predicate: Optional[pc.Expression] = None,
61+
predicate: Optional[Expr] = None,
6262
limit: Optional[int] = None,
6363
filesystem: Optional[FileSystem] = None,
6464
partitioning: Optional[Partitioning] = None,
@@ -74,7 +74,8 @@ def __init__(
7474
format: Format of the files to read.
7575
batch_size: Number of rows per batch.
7676
columns: Columns to read. None means all columns.
77-
predicate: PyArrow compute expression for filtering.
77+
predicate: Ray Data expression for filtering. Converted to a
78+
PyArrow expression at the scanner-kwargs boundary.
7879
limit: Maximum number of rows to read.
7980
filesystem: Filesystem for reading files.
8081
partitioning: Ray ``Partitioning`` object. Partition columns are
@@ -226,18 +227,17 @@ def read(self, input_split: FileManifest) -> Iterator[pa.Table]:
226227

227228
scanner_kwargs = {
228229
"columns": columns_to_read_from_file,
229-
"filter": self._predicate,
230+
"filter": (
231+
self._predicate.to_pyarrow() if self._predicate is not None else None
232+
),
230233
"batch_size": self._resolve_batch_size(dataset),
231234
"batch_readahead": _ARROW_SCANNER_BATCH_READAHEAD,
232235
}
233236
scanner_kwargs.update(self._arrow_scanner_kwargs())
234237

235-
ctx = DataContext.get_current()
236238
rows_read = 0
237-
for table, fragment_path, fragment_row_offset in iterate_with_retry(
238-
lambda: self._read_fragment_batches(dataset, scanner_kwargs),
239-
"read batches",
240-
match=ctx.retried_io_errors,
239+
for table, fragment_path, fragment_row_offset in self._read_fragment_batches(
240+
dataset, scanner_kwargs
241241
):
242242
if self._limit is not None:
243243
if rows_read >= self._limit:
@@ -340,33 +340,57 @@ def _read_fragment_batches(
340340
one fragment at a time.
341341
342342
``fragment_row_offset`` is the post-filter row position of the first
343-
row of ``table`` within the current fragment. Tracking it inside the
344-
generator means it resets correctly whenever ``iterate_with_retry``
345-
recreates the generator on a retry — outer-loop state would otherwise
346-
carry stale values from the failed attempt and corrupt row hashes.
343+
row of ``table`` within the current fragment. ``iterate_with_retry``
344+
skips already-yielded items on retry, so ``offset`` reflects only the
345+
rows that actually surface to the caller — matching V1 row-hash
346+
semantics even when a fragment fails partway through.
347+
348+
Retry is scoped per-fragment: if a fragment fails mid-read, only
349+
that fragment is re-read (skipping batches already yielded).
350+
Wrapping the whole manifest in a single retry would re-iterate
351+
fragments that already succeeded and double-emit their batches.
347352
348353
Each fragment gets its own scanner so pyarrow uses the native
349354
per-file schema. A cross-fragment scanner would force a unified
350355
schema cast, which refuses extension-to-extension conversion
351356
(e.g. variable-shape tensors). V1 ``ParquetDatasource`` follows
352357
the same per-fragment pattern via ``fragment.to_batches``.
353-
354-
When a non-extension caller schema is available we pin it at the
355-
scanner so pyarrow null-fills any column the unified schema names
356-
but the fragment lacks (V1 parity). Falling back to the
357-
per-fragment ``physical_schema`` preserves the variable-shape
358-
tensor escape hatch already encoded in ``_file_dataset_schema``.
359358
"""
359+
ctx = DataContext.get_current()
360360
for fragment in dataset.get_fragments():
361-
fragment_schema = (
362-
self._file_dataset_schema
363-
if self._file_dataset_schema is not None
364-
else fragment.physical_schema
365-
)
366-
scanner = fragment.scanner(**scanner_kwargs, schema=fragment_schema)
367361
offset = 0
368-
for tagged in scanner.scan_batches():
369-
table = pa.Table.from_batches(batches=[tagged.record_batch])
362+
for table in iterate_with_retry(
363+
partial(self._iter_fragment_tables, fragment, scanner_kwargs),
364+
f"read fragment {fragment.path}",
365+
match=ctx.retried_io_errors,
366+
):
370367
if table.num_rows > 0:
371368
yield table, fragment.path, offset
372369
offset += table.num_rows
370+
371+
def _iter_fragment_tables(
    self,
    fragment: pds.Fragment,
    scanner_kwargs: dict,
) -> Iterator[pa.Table]:
    """Stream one fragment as a sequence of single-batch Arrow tables.

    This is the default, scanner-based read path. Format-specific readers
    override it when a fragment cannot go through the regular scanner
    (e.g. Parquet's ARROW-5030 nested-type fallback).

    Schema selection: if a non-extension caller schema is known
    (``_file_dataset_schema``), it is pinned on the scanner so pyarrow
    null-fills columns that the unified schema names but this fragment
    lacks (V1 parity — ``ParquetDatasource`` passes ``read_schema`` to
    ``fragment.to_batches``). Otherwise the fragment's own
    ``physical_schema`` is used, which keeps the variable-shape tensor
    escape hatch encoded in ``_file_dataset_schema`` intact.

    Args:
        fragment: The dataset fragment to scan.
        scanner_kwargs: Keyword arguments forwarded to
            ``fragment.scanner`` (``schema`` is supplied here).

    Yields:
        One ``pa.Table`` per scanned record batch.
    """
    pinned_schema = self._file_dataset_schema
    if pinned_schema is None:
        # No unified schema to pin: read with the file's native schema.
        pinned_schema = fragment.physical_schema

    fragment_scanner = fragment.scanner(**scanner_kwargs, schema=pinned_schema)
    for tagged_batch in fragment_scanner.scan_batches():
        yield pa.Table.from_batches(batches=[tagged_batch.record_batch])

python/ray/data/_internal/datasource_v2/readers/parquet_file_reader.py

Lines changed: 174 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
22
import math
3-
from typing import TYPE_CHECKING, List, Optional
3+
from typing import TYPE_CHECKING, Iterator, List, Optional
44

55
import pyarrow as pa
66
import pyarrow.dataset as pds
77
import pyarrow.parquet as pq
8-
from pyarrow import compute as pc
98
from pyarrow.fs import FileSystem
109
from typing_extensions import override
1110

@@ -20,7 +19,9 @@
2019
from ray.data._internal.datasource_v2.readers.in_memory_size_estimator import (
2120
PARQUET_ENCODING_RATIO_ESTIMATE_DEFAULT,
2221
)
22+
from ray.data.expressions import Expr
2323
from ray.util.annotations import DeveloperAPI
24+
from ray.util.debug import log_once
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -135,7 +136,7 @@ def __init__(
135136
self,
136137
batch_size: Optional[int] = None,
137138
columns: Optional[List[str]] = None,
138-
predicate: Optional[pc.Expression] = None,
139+
predicate: Optional[Expr] = None,
139140
limit: Optional[int] = None,
140141
filesystem: Optional[FileSystem] = None,
141142
partitioning: "Optional[Partitioning]" = None,
@@ -151,7 +152,7 @@ def __init__(
151152
batch_size: Explicit batch size override. If provided, disables
152153
adaptive batch sizing.
153154
columns: Columns to read. None means all columns.
154-
predicate: PyArrow compute expression for filtering.
155+
predicate: Ray Data expression for filtering.
155156
limit: Maximum number of rows to read.
156157
filesystem: Filesystem for reading files.
157158
partitioning: Ray ``Partitioning`` for synthesizing partition
@@ -229,6 +230,173 @@ def _on_batch_read(self, table: pa.Table) -> None:
229230
row_size = table.nbytes / table.num_rows
230231
self._sampled_batch_size = max(math.ceil(self._target_block_size / row_size), 1)
231232

233+
@override
def _iter_fragment_tables(
    self,
    fragment: pds.Fragment,
    scanner_kwargs: dict,
) -> "Iterator[pa.Table]":
    """Yield Arrow tables for one fragment, working around ARROW-5030.

    Uses V1's nested-type fallback path when the fragment has nested
    columns whose row-group size exceeds Arrow's ~2GB chunking limit
    (ARROW-5030). Fragments that do not need the fallback are delegated
    unchanged to the base-class scanner path.

    On the fallback path the file is read with
    ``pq.ParquetFile.iter_batches``. Because that API takes no filter
    expression, row groups are pruned up front via ``fragment.subset``
    and the row-level filter is applied to every batch. The
    ``ParquetFile`` is opened in a ``with`` block so the underlying file
    handle is released when the generator finishes, raises, or is closed
    early by the consumer.

    Args:
        fragment: The Parquet fragment to read.
        scanner_kwargs: Scanner keyword arguments built by the caller;
            ``columns``, ``filter`` and ``batch_size`` are read from it.

    Yields:
        Non-empty ``pa.Table`` objects on the fallback path; whatever the
        base implementation yields otherwise.
    """
    import pyarrow.compute as pc

    from ray.data._internal.arrow_ops.transform_pyarrow import (
        _align_struct_fields,
    )
    from ray.data._internal.datasource.parquet_datasource import (
        _get_safe_batch_size_for_nested_types,
        _needs_nested_type_fallback,
        _resolve_leaf_column_indices,
        _resolve_read_columns,
    )
    from ray.data._internal.planner.plan_expression.expression_visitors import (
        get_column_references,
    )

    columns = scanner_kwargs.get("columns")
    # ``filter`` is None when no predicate was pushed down.
    filter_expr: Optional[pc.Expression] = scanner_kwargs.get("filter")
    # Include filter-referenced columns in the fallback check: a filter
    # that touches a large nested column outside the projection still
    # forces row-level decoding of that column, which would otherwise
    # hit ARROW-5030 in the normal scanner path.
    filter_columns = (
        get_column_references(self._predicate)
        if self._predicate is not None
        else None
    )
    read_columns = _resolve_read_columns(columns, filter_expr, filter_columns)
    if not _needs_nested_type_fallback(fragment, read_columns):
        yield from super()._iter_fragment_tables(fragment, scanner_kwargs)
        return

    if log_once(f"parquet_nested_fallback_v2:{fragment.path}"):
        logger.warning(
            "Using pyarrow.parquet row-level batched reader for '%s' due "
            "to Arrow nested type chunking limitation (ARROW-5030). "
            "Consider writing Parquet files with smaller row group sizes "
            "to avoid this.",
            fragment.path,
        )

    # Apply row-group-level predicate pushdown via fragment.subset; the
    # row-level filter is applied per-batch below since iter_batches
    # doesn't accept a filter expression. Under schema evolution the
    # filter may reference a column absent from this fragment's
    # physical schema — fragment.subset uses that schema (not the
    # unified one) and raises ArrowInvalid, so skip row-group pruning
    # in that case and let the per-batch filter (post null-fill) do
    # all the row-dropping.
    #
    # This runs before the file is opened so that a fragment whose row
    # groups are all pruned never acquires a file handle at all.
    fragment_physical_columns = set(fragment.physical_schema.names)
    filter_touches_missing_column = filter_columns is not None and any(
        c not in fragment_physical_columns for c in filter_columns
    )
    if filter_expr is not None and not filter_touches_missing_column:
        subset = fragment.subset(filter=filter_expr)
    else:
        subset = fragment
    row_groups = (
        [rg.id for rg in subset.row_groups]
        if subset.row_groups is not None
        else None
    )
    if row_groups is not None and len(row_groups) == 0:
        return

    # ``pq.ParquetFile.iter_batches`` returns batches with the fragment's
    # physical schema, so the fallback path would otherwise emit tables
    # that differ from the scanner path (which pins
    # ``_file_dataset_schema``) in struct field order, integer width,
    # or missing columns. Align + cast to the same unified schema so
    # fallback and non-fallback fragments concat cleanly downstream.
    # Scoped to ``columns`` (not ``read_columns``) since filter-only
    # columns are projected away before alignment.
    file_dataset_schema = self._file_dataset_schema
    if file_dataset_schema is not None and columns is not None:
        align_schema = pa.schema(
            [
                file_dataset_schema.field(c)
                for c in columns
                if file_dataset_schema.get_field_index(c) != -1
            ]
        )
    else:
        align_schema = file_dataset_schema

    # Under schema evolution a filter-referenced column may live in
    # the unified dataset schema but be absent from this fragment.
    # The scanner path null-fills such columns via dataset-level
    # schema pinning; ``pq.ParquetFile.iter_batches`` silently drops
    # them and then ``table.filter(filter_expr)`` raises
    # ``ArrowInvalid: No match for FieldRef.Name``. Mirror the
    # scanner: append a null column of the unified type before the
    # filter evaluates, so ``null > 15`` resolves to false and the
    # fragment contributes 0 rows.
    columns_to_null_fill: List[str] = (
        [c for c in read_columns if c not in fragment_physical_columns]
        if read_columns is not None
        else []
    )
    null_fill_type_by_column = {
        column_name: (
            file_dataset_schema.field(column_name).type
            if file_dataset_schema is not None
            and file_dataset_schema.get_field_index(column_name) != -1
            else pa.null()
        )
        for column_name in columns_to_null_fill
    }

    batch_size = scanner_kwargs.get("batch_size")

    # Context-manage the ParquetFile so the handle it opens is closed on
    # exhaustion, on error, and when the generator is closed early.
    with pq.ParquetFile(
        fragment.path,
        filesystem=fragment.filesystem,  # pyrefly: ignore[unexpected-keyword]
    ) as pf:
        # Scope the safe batch-size calculation to the columns actually
        # being decoded so we don't shrink batches based on columns we
        # won't read.
        leaf_indices = (
            _resolve_leaf_column_indices(pf.metadata, read_columns)
            if read_columns is not None and pf.metadata.num_row_groups > 0
            else None
        )
        safe_batch_size = _get_safe_batch_size_for_nested_types(pf, leaf_indices)
        fallback_batch_size = (
            min(batch_size, safe_batch_size) if batch_size else safe_batch_size
        )

        for batch in pf.iter_batches(
            batch_size=fallback_batch_size,
            columns=read_columns,
            use_threads=False,
            row_groups=row_groups,
        ):
            table = pa.Table.from_batches([batch])
            for column_name in columns_to_null_fill:
                if column_name not in table.column_names:
                    table = table.append_column(
                        column_name,
                        pa.nulls(
                            table.num_rows,
                            type=null_fill_type_by_column[column_name],
                        ),
                    )
            if filter_expr is not None:
                table = table.filter(filter_expr)
            # Skip downstream select/align/cast on fully-filtered
            # batches — the caller discards empty tables anyway.
            if table.num_rows == 0:
                continue
            if columns is not None:
                table = table.select(
                    [c for c in columns if c in table.column_names]
                )
            if align_schema is not None:
                table = _align_struct_fields([table], align_schema)[0].cast(
                    align_schema
                )
            yield table
399+
232400
@override
233401
def _arrow_scanner_kwargs(self) -> dict:
234402
# pre_buffer=True (pyarrow default) holds a whole fragment's worth of
@@ -239,10 +407,11 @@ def _arrow_scanner_kwargs(self) -> dict:
239407
# while keeping throughput equal to the default. batch_readahead=1
240408
# (inherited from FileReader base kwargs) plus fragment_readahead=1
241409
# is enough to keep decode pipelined. See apache/arrow#39808.
242-
return {
410+
kwargs: dict = {
243411
"fragment_scan_options": pds.ParquetFragmentScanOptions(
244412
pre_buffer=False,
245413
use_buffered_stream=True,
246414
),
247415
"fragment_readahead": 1,
248416
}
417+
return kwargs

0 commit comments

Comments
 (0)