Commit 4932bfd

goutamvenkat-anyscale and claude authored and committed
[data] DataSourceV2: support _block_udf + tensor_column_schema; per-fragment reads for variable-shape tensors (ray-project#63174)
``read_parquet`` was raising ``NotImplementedError`` for ``_block_udf`` and ``tensor_column_schema`` on the V2 path, skipping a batch of V1 tests. Wire them through:

- ``ReadFiles`` gets an optional ``block_udf: Callable[[Block], Block]`` field. ``plan_read_files_op`` applies it after ``reader.read(manifest)`` and before column renames so the UDF sees on-disk column names (V1 ``ParquetDatasource`` semantics).
- ``_read_datasource_v2`` accepts a ``block_udf`` kwarg and stores it on the logical op.
- ``ReadFiles.infer_schema`` probes the UDF's schema effect via a dummy empty table (mirrors V1's ``dummy_table`` trick) so ``ds.schema()`` reflects post-transform types before materialization. The scanner keeps the *pre-UDF* schema so pyarrow sees the raw on-disk types.
- ``read_parquet`` drops the two ``NotImplementedError`` raises; ``tensor_column_schema`` is already folded into ``_block_udf`` by ``_resolve_parquet_args`` so no extra handling is needed.

While un-skipping V1 tests, a second issue surfaced: ``test_multiple_files_with_ragged_arrays`` was failing because ``pds.dataset(paths).scanner().scan_batches()`` forces a cross-fragment schema unification inside pyarrow. That unification casts per-file ``ArrowTensorTypeV2(shape=X)`` to the unified type and pyarrow refuses extension-to-extension casts: "One can first cast to the storage type, then to the extension type". V1 avoids this by iterating ``fragment.to_batches`` per fragment. Port the pattern: ``FileReader._read_fragment_batches`` builds a per-fragment scanner with that fragment's ``physical_schema`` so pyarrow keeps the native per-file type. Downstream concat handles heterogeneous block schemas, same as V1.

The caller-supplied ``file_dataset_schema`` still applies for the common all-null first-column case, and steps aside when any extension column is present.

Tests: V2 unit 64/64, parquet broad slice 103 pass / 1 skip / 0 fail, checkpoint suite 63 pass / 3 pre-existing ``ModuleNotFoundError`` failures (reproduced on master).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Signed-off-by: Goutam <goutam@anyscale.com>
Co-authored-by: Goutam V. <>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 60ecbed commit 4932bfd

10 files changed

Lines changed: 221 additions & 47 deletions

python/ray/data/_internal/datasource_v2/parquet_datasource_v2.py

Lines changed: 13 additions & 0 deletions
@@ -68,6 +68,7 @@ def __init__(
         file_extensions: Optional[List[str]] = None,
         ignore_missing_paths: bool = False,
         include_paths: bool = False,
+        include_row_hash: bool = False,
         shuffle: Optional[Union[Literal["files"], "FileShuffleConfig"]] = None,
         arrow_parquet_args: Optional[dict] = None,
         schema: Optional[pa.Schema] = None,
@@ -89,6 +90,7 @@ def __init__(
         self._file_extensions = file_extensions or ParquetDatasource._FILE_EXTENSIONS
         self._ignore_missing_paths = ignore_missing_paths
         self._include_paths = include_paths
+        self._include_row_hash = include_row_hash
         self._shuffle = shuffle
         self._arrow_parquet_args = arrow_parquet_args or {}
         # User-supplied schema override. When set, ``infer_schema`` returns
@@ -245,6 +247,16 @@ def _read_schema(path: str):
         if self._include_paths and schema.get_field_index("path") == -1:
             schema = schema.append(pa.field("path", pa.string()))

+        if self._include_row_hash:
+            # ``row_hash`` is synthesized post-read as ``uint64``. Replace
+            # the field type when the file already has a ``row_hash``
+            # column (matches V1 ``_derive_schema``); otherwise append.
+            idx = schema.get_field_index("row_hash")
+            if idx == -1:
+                schema = schema.append(pa.field("row_hash", pa.uint64()))
+            elif schema.field(idx).type != pa.uint64():
+                schema = schema.set(idx, pa.field("row_hash", pa.uint64()))
+
         check_for_legacy_tensor_type(schema)
         return schema

@@ -269,6 +281,7 @@ def create_scanner(
             filesystem=filesystem or self._filesystem,
             partitioning=partitioning,
             include_paths=self._include_paths,
+            include_row_hash=self._include_row_hash,
             shuffle=self._shuffle,
             ignore_prefixes=options.get("ignore_prefixes"),
             target_block_size=DataContext.get_current().target_max_block_size,

python/ray/data/_internal/datasource_v2/readers/file_reader.py

Lines changed: 117 additions & 23 deletions
@@ -8,6 +8,7 @@
 from pyarrow.fs import FileSystem, LocalFileSystem

 from ray.data._internal.arrow_block import _BATCH_SIZE_PRESERVING_STUB_COL_NAME
+from ray.data._internal.datasource.parquet_datasource import _compute_row_hashes
 from ray.data._internal.datasource_v2.listing.file_manifest import FileManifest
 from ray.data._internal.datasource_v2.readers.base_reader import Reader
 from ray.data._internal.util import iterate_with_retry
@@ -19,6 +20,14 @@
 # Default is specified by PyArrow.
 _ARROW_DEFAULT_BATCH_SIZE = 131_072

+# Small fixed readahead keeps driver memory bounded when scanning
+# uncompressed batches (jumbo tensor columns can run to multi-GB per
+# batch, and pyarrow's default 16-batch readahead would retain all of
+# them).
+_ARROW_SCANNER_BATCH_READAHEAD = 1
+
+_ROW_HASH_COLUMN_NAME = "row_hash"
+

 class FileFormat(str, Enum):
     PARQUET = "parquet"
@@ -50,6 +59,7 @@ def __init__(
         partitioning: Optional[Partitioning] = None,
         ignore_prefixes: Optional[List[str]] = None,
         include_paths: bool = False,
+        include_row_hash: bool = False,
         schema: Optional[pa.Schema] = None,
     ):
         """Initialize the reader.
@@ -68,6 +78,12 @@ def __init__(
             ignore_prefixes: Prefixes to ignore when reading files. Default is ['.', '_'] set by PyArrow.
             include_paths: If True, include the source file path in a
                 ``'path'`` column for each row.
+            include_row_hash: If True, include a deterministic uint64 hash
+                per row in a ``'row_hash'`` column. The hash is derived from
+                the source file path and the row's post-filter output
+                position within the fragment, matching V1 semantics. If a
+                ``'row_hash'`` column already exists in the file, it is
+                overwritten.
             schema: Caller-supplied unified schema used both to override
                 pyarrow's per-fragment inference (so a file whose column
                 is all-null doesn't pin the type to ``null``) and to cast
@@ -86,27 +102,53 @@ def __init__(
         )
         self._ignore_prefixes = ignore_prefixes
         self._include_paths = include_paths
+        self._include_row_hash = include_row_hash
         self._schema = schema

     @cached_property
     def _file_dataset_schema(self) -> Optional[pa.Schema]:
         """Schema passed to ``pds.dataset`` — partition keys and ``path``
         stripped out since those are synthesized post-read.

-        A caller-supplied schema overrides pyarrow's per-fragment
-        inference — without it, a file with all-null values in column X
-        pins X to ``null`` type and pyarrow can't cast string → null in
-        later files.
+        Pinning the caller-supplied schema at the pyarrow layer is how
+        we cover the "first file has an all-null column, later files
+        have the real type" case (e.g.
+        ``test_read_null_data_in_first_file``): without the pin,
+        pyarrow locks column X to ``null`` across the fragment group
+        and the later string-typed file fails the cast.
+
+        But pyarrow refuses extension-to-extension casts (e.g.
+        ``ArrowTensorTypeV2(shape=X)`` → ``ArrowVariableShapedTensor``),
+        and files with different per-file tensor shapes only unify
+        through ``ArrowVariableShapedTensor``. When the caller schema
+        contains *any* extension column we skip the pin entirely and
+        let pyarrow infer per-file — downstream concat handles the
+        heterogeneous blocks. Losing the all-null promotion in this
+        narrow case is acceptable; the combination of an all-null
+        first file *and* an extension column is uncommon, whereas
+        reading multiple files with variable-shape tensors is a
+        supported V1 feature.
         """
         if self._schema is None:
             return None
+        if any(isinstance(f.type, pa.ExtensionType) for f in self._schema):
+            return None
         partition_keys = (
             set(self._partition_parser._scheme.field_names or [])
             if self._partition_parser is not None
             else set()
         )
+        synthesized = {"path"}
+        if self._include_row_hash:
+            # ``row_hash`` is synthesized post-read, and the schema's type
+            # (``uint64``) may not match the on-disk column's type when a
+            # file already carries a ``row_hash`` column. Strip it from the
+            # dataset schema so pyarrow doesn't try to cast.
+            synthesized.add(_ROW_HASH_COLUMN_NAME)
         fields = [
-            f for f in self._schema if f.name not in partition_keys and f.name != "path"
+            f
+            for f in self._schema
+            if f.name not in partition_keys and f.name not in synthesized
         ]
         return pa.schema(fields) if fields else None

@@ -146,6 +188,14 @@ def read(self, input_split: FileManifest) -> Iterator[pa.Table]:

         paths = list(input_split.paths)
         filesystem = self._filesystem or LocalFileSystem()
+        # Build a ``pds.Dataset`` over *all* manifest paths so pyarrow's
+        # listing + column metadata is shared, but then iterate its
+        # fragments one at a time. ``dataset.scanner(fragments=...)``
+        # at the aggregate level would force a cross-fragment cast —
+        # which breaks variable-shape tensor extensions where each
+        # file has its own ``ArrowTensorTypeV2(shape=...)``. Per-
+        # fragment scanners let pyarrow use the native per-file type,
+        # and downstream concat handles unification.
         dataset = pds.dataset(
             source=paths,
             format=self._format.value,
@@ -169,19 +219,18 @@ def read(self, input_split: FileManifest) -> Iterator[pa.Table]:
             ]
             columns_to_synthesize = set(self._columns) - on_disk_column_names

-        scanner_kwargs = dict(
-            columns=columns_to_read_from_file,
-            filter=self._predicate,
-            batch_size=self._resolve_batch_size(dataset),
-            batch_readahead=1,
-        )
+        scanner_kwargs = {
+            "columns": columns_to_read_from_file,
+            "filter": self._predicate,
+            "batch_size": self._resolve_batch_size(dataset),
+            "batch_readahead": _ARROW_SCANNER_BATCH_READAHEAD,
+        }
         scanner_kwargs.update(self._arrow_scanner_kwargs())
-        scanner = dataset.scanner(**scanner_kwargs)

         ctx = DataContext.get_current()
         rows_read = 0
-        for table, fragment_path in iterate_with_retry(
-            lambda: self._read_batches(scanner),
+        for table, fragment_path, fragment_row_offset in iterate_with_retry(
+            lambda: self._read_fragment_batches(dataset, scanner_kwargs),
             "read batches",
             match=ctx.retried_io_errors,
         ):
@@ -216,6 +265,21 @@ def read(self, input_split: FileManifest) -> Iterator[pa.Table]:
                         self._broadcast_partition_value(name, value, table.num_rows),
                     )

+            # Skip when projection pushdown has narrowed ``columns`` to
+            # exclude ``row_hash`` — the projection below would just drop it.
+            if self._include_row_hash and (
+                columns_to_synthesize is None
+                or _ROW_HASH_COLUMN_NAME in columns_to_synthesize
+            ):
+                hashes = _compute_row_hashes(
+                    fragment_path, fragment_row_offset, table.num_rows
+                )
+                if _ROW_HASH_COLUMN_NAME in table.column_names:
+                    table = table.drop([_ROW_HASH_COLUMN_NAME])
+                table = table.append_column(
+                    _ROW_HASH_COLUMN_NAME, pa.array(hashes, type=pa.uint64())
+                )
+
             if self._columns is not None:
                 # Project/reorder to the caller's requested column order;
                 # drop any that weren't produced (matches V1's lenient
@@ -262,12 +326,42 @@ def _arrow_scanner_kwargs(self) -> dict:
         """
         return {}

-    @staticmethod
-    def _read_batches(
-        scanner: pds.Scanner,
-    ) -> Iterator[tuple[pa.Table, str]]:
-        """Yield non-empty (table, fragment_path) pairs from scanner batches."""
-        for tagged in scanner.scan_batches():
-            table = pa.Table.from_batches(batches=[tagged.record_batch])
-            if table.num_rows > 0:
-                yield table, tagged.fragment.path
+    def _read_fragment_batches(
+        self,
+        dataset: pds.Dataset,
+        scanner_kwargs: dict,
+    ) -> Iterator[Tuple[pa.Table, str, int]]:
+        """Yield non-empty (table, fragment_path, fragment_row_offset) triples
+        one fragment at a time.
+
+        ``fragment_row_offset`` is the post-filter row position of the first
+        row of ``table`` within the current fragment. Tracking it inside the
+        generator means it resets correctly whenever ``iterate_with_retry``
+        recreates the generator on a retry — outer-loop state would otherwise
+        carry stale values from the failed attempt and corrupt row hashes.
+
+        Each fragment gets its own scanner so pyarrow uses the native
+        per-file schema. A cross-fragment scanner would force a unified
+        schema cast, which refuses extension-to-extension conversion
+        (e.g. variable-shape tensors). V1 ``ParquetDatasource`` follows
+        the same per-fragment pattern via ``fragment.to_batches``.
+
+        When a non-extension caller schema is available we pin it at the
+        scanner so pyarrow null-fills any column the unified schema names
+        but the fragment lacks (V1 parity). Falling back to the
+        per-fragment ``physical_schema`` preserves the variable-shape
+        tensor escape hatch already encoded in ``_file_dataset_schema``.
+        """
+        for fragment in dataset.get_fragments():
+            fragment_schema = (
+                self._file_dataset_schema
+                if self._file_dataset_schema is not None
+                else fragment.physical_schema
+            )
+            scanner = fragment.scanner(**scanner_kwargs, schema=fragment_schema)
+            offset = 0
+            for tagged in scanner.scan_batches():
+                table = pa.Table.from_batches(batches=[tagged.record_batch])
+                if table.num_rows > 0:
+                    yield table, fragment.path, offset
+                    offset += table.num_rows
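
The docstring of ``_read_fragment_batches`` argues that the row offset must live inside the generator so that a retry starts from clean state. The following standard-library sketch illustrates that argument under simplifying assumptions: fragments are modelled as a dict of batch sizes, ``replay_with_retry`` is a hypothetical stand-in for ``iterate_with_retry`` (assumed here to recreate the iterator and skip items it has already yielded), and ``flaky_factory`` injects one transient error.

```python
from typing import Callable, Dict, Iterator, List, Tuple

Item = Tuple[str, int, int]  # (fragment_path, row_offset, num_rows)


def read_fragment_batches(fragments: Dict[str, List[int]]) -> Iterator[Item]:
    """Yield one item per non-empty batch, tracking the offset per fragment."""
    for path, batch_sizes in fragments.items():
        offset = 0  # local state: reset per fragment and per generator
        for num_rows in batch_sizes:
            if num_rows > 0:
                yield path, offset, num_rows
                offset += num_rows


def flaky_factory(
    fragments: Dict[str, List[int]], fail_at: int
) -> Callable[[], Iterator[Item]]:
    """Build a factory whose first generator fails before item ``fail_at``."""
    attempts = {"count": 0}

    def factory() -> Iterator[Item]:
        attempts["count"] += 1
        first_attempt = attempts["count"] == 1
        for index, item in enumerate(read_fragment_batches(fragments)):
            if first_attempt and index == fail_at:
                raise IOError("transient read error")
            yield item

    return factory


def replay_with_retry(
    factory: Callable[[], Iterator[Item]], max_attempts: int = 3
) -> Iterator[Item]:
    """Hypothetical retry loop: rebuild the iterator, skip what was yielded."""
    yielded = 0
    for attempt in range(max_attempts):
        try:
            for index, item in enumerate(factory()):
                if index < yielded:
                    continue
                yielded += 1
                yield item
            return
        except IOError:
            if attempt == max_attempts - 1:
                raise


fragments = {"a.parquet": [2, 0, 3], "b.parquet": [4]}
clean = list(read_fragment_batches(fragments))
retried = list(replay_with_retry(flaky_factory(fragments, fail_at=1)))
```

Because the offset is recomputed from zero by every fresh generator, the retried run produces the same triples as the clean run. A counter kept in the caller's loop would instead keep whatever value it had reached when the first attempt failed.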

python/ray/data/_internal/datasource_v2/readers/parquet_file_reader.py

Lines changed: 4 additions & 0 deletions
@@ -142,6 +142,7 @@ def __init__(
         ignore_prefixes: Optional[List[str]] = None,
         target_block_size: Optional[int] = None,
         include_paths: bool = False,
+        include_row_hash: bool = False,
         schema: Optional[pa.Schema] = None,
     ):
         """Initialize the Parquet reader.
@@ -160,6 +161,8 @@ def __init__(
                 Used for adaptive batch sizing when ``batch_size`` is not set.
             include_paths: If True, include the source file path in a
                 ``'path'`` column for each row.
+            include_row_hash: If True, include a deterministic uint64 hash
+                per row in a ``'row_hash'`` column.
             schema: Caller-supplied unified schema forwarded to the base
                 :class:`FileReader` for per-fragment inference override
                 and partition-column type casting.
@@ -174,6 +177,7 @@ def __init__(
             partitioning=partitioning,
             ignore_prefixes=ignore_prefixes,
             include_paths=include_paths,
+            include_row_hash=include_row_hash,
             schema=schema,
         )
         self._explicit_batch_size = batch_size

python/ray/data/_internal/datasource_v2/scanners/parquet_scanner.py

Lines changed: 4 additions & 0 deletions
@@ -28,12 +28,15 @@ class ParquetScanner(ArrowFileScanner):

     target_block_size: Optional[int] = None
     include_paths: bool = False
+    include_row_hash: bool = False

     def read_schema(self) -> pa.Schema:
         """Return schema after column pruning and tensor check."""
         schema = super().read_schema()
         if self.include_paths and schema.get_field_index("path") == -1:
             schema = schema.append(pa.field("path", pa.string()))
+        if self.include_row_hash and schema.get_field_index("row_hash") == -1:
+            schema = schema.append(pa.field("row_hash", pa.uint64()))

         check_for_legacy_tensor_type(schema)
         return schema
@@ -54,5 +57,6 @@ def create_reader(self) -> ParquetFileReader:
             ignore_prefixes=self.ignore_prefixes,
             target_block_size=self.target_block_size,
             include_paths=self.include_paths,
+            include_row_hash=self.include_row_hash,
             schema=self.schema,
         )

python/ray/data/_internal/datasource_v2/tests/test_parquet_datasource_v2.py

Lines changed: 34 additions & 0 deletions
@@ -105,3 +105,37 @@ def test_paths_and_filesystem_resolved(tmp_path):
     # the caller passed None.
     assert datasource.filesystem is not None
     assert len(datasource.paths) == 1
+
+
+def test_infer_schema_with_include_row_hash(tmp_path):
+    file_path = tmp_path / "data.parquet"
+    _write_parquet(str(file_path), pa.table({"a": [1, 2]}))
+
+    datasource = ParquetDatasourceV2([str(file_path)], include_row_hash=True)
+    schema = datasource.infer_schema(_manifest_of([str(file_path)]))
+
+    assert "row_hash" in schema.names
+    assert schema.field("row_hash").type == pa.uint64()
+
+
+def test_infer_schema_with_include_row_hash_existing_column_promoted_to_uint64(
+    tmp_path,
+):
+    file_path = tmp_path / "data.parquet"
+    _write_parquet(str(file_path), pa.table({"val": [1, 2], "row_hash": [10, 20]}))
+
+    datasource = ParquetDatasourceV2([str(file_path)], include_row_hash=True)
+    schema = datasource.infer_schema(_manifest_of([str(file_path)]))
+
+    assert schema.field("row_hash").type == pa.uint64()
+
+
+def test_create_scanner_propagates_include_row_hash(tmp_path):
+    file_path = tmp_path / "data.parquet"
+    _write_parquet(str(file_path), pa.table({"a": [1]}))
+
+    datasource = ParquetDatasourceV2([str(file_path)], include_row_hash=True)
+    schema = datasource.infer_schema(_manifest_of([str(file_path)]))
+    scanner = datasource.create_scanner(schema)
+
+    assert scanner.include_row_hash is True

python/ray/data/_internal/logical/operators/read_operator.py

Lines changed: 19 additions & 0 deletions
@@ -12,6 +12,7 @@
 )
 from ray.data._internal.logical.operators.map_operator import AbstractMap
 from ray.data.block import (
+    Block,
     BlockMetadata,
     BlockMetadataWithSchema,
 )
@@ -266,6 +267,12 @@ class ReadFiles(
     # renamed. The scanner only knows original names; renames are applied
     # in ``plan_read_files_op`` after each block is read.
     column_renames: Optional[Dict[str, str]] = None
+    # Optional post-read block transform. Used by ``read_parquet``'s
+    # ``_block_udf`` and ``tensor_column_schema`` (the latter is folded
+    # into a ``_block_udf`` by ``_resolve_parquet_args`` before it gets
+    # here). Applied in ``plan_read_files_op.do_read`` after each
+    # table is read and before column renames.
+    block_udf: Optional[Callable[[Block], Block]] = None
     can_modify_num_rows: bool = field(init=False, default=True)
     min_rows_per_bundled_input: Optional[int] = field(init=False, default=None)
     ray_remote_args_fn: None = field(init=False, default=None)
@@ -314,6 +321,18 @@ def infer_schema(self) -> "pa.Schema":
         # ``select_columns([])``); the stored ``self.schema`` is the
         # unprojected one and only used for construction.
         schema = self.scanner.read_schema()
+        # When a ``block_udf`` is attached (e.g. ``read_parquet`` was
+        # called with ``tensor_column_schema`` or ``_block_udf``), probe
+        # its effect on the schema so downstream consumers see the
+        # post-transform column types. Mirrors V1 ``ParquetDatasource``'s
+        # dummy-table trick. Falls back to the scanner schema if the
+        # probe fails — the UDF may require a non-empty input.
+        if self.block_udf is not None:
+            try:
+                transformed = self.block_udf(schema.empty_table()).schema
+                schema = transformed.with_metadata(schema.metadata)
+            except Exception:
+                pass
         if self.column_renames:
             import pyarrow as pa
319338
