|
6 | 6 | import pandas as pd |
7 | 7 | import pyarrow as pa |
8 | 8 | import pytest |
| 9 | +from packaging.version import parse as parse_version |
9 | 10 |
|
10 | 11 | from ray.data._internal.arrow_ops.transform_pyarrow import ( |
11 | 12 | MIN_PYARROW_VERSION_TYPE_PROMOTION, |
12 | 13 | _align_struct_fields, |
| 14 | + _has_unhashable_pandas_types, |
13 | 15 | concat, |
14 | 16 | hash_partition, |
15 | 17 | shuffle, |
@@ -144,6 +146,78 @@ def _concat_and_sort_partitions(parts: Iterable[pa.Table]) -> pa.Table: |
144 | 146 | assert t == _concat_and_sort_partitions(_structs_partition_dict.values()) |
145 | 147 |
|
146 | 148 |
|
@pytest.mark.parametrize(
    ["pa_type", "expected"],
    [
        # Expected True: nested types (these convert to dict/list in pandas)
        # followed by the Ray extension types (numpy arrays / arbitrary
        # objects in pandas).
        *(
            (unhashable_type, True)
            for unhashable_type in (
                pa.struct([("a", pa.int32())]),
                pa.list_(pa.int32()),
                pa.large_list(pa.int32()),
                pa.list_(pa.int32(), 3),  # fixed_size_list
                pa.map_(pa.string(), pa.int32()),
                pa.dense_union([pa.field("x", pa.int32())]),
                ArrowTensorTypeV2((2, 2), pa.int64()),
                ArrowPythonObjectType(),
            )
        ),
        # Expected False: hashable primitives, which must not be flagged so
        # that the fast path is kept.
        *(
            (hashable_type, False)
            for hashable_type in (
                pa.int32(),
                pa.float64(),
                pa.bool_(),
                pa.string(),
                pa.large_string(),
                pa.binary(),
                pa.decimal128(10, 2),
                pa.date32(),
                pa.timestamp("ns"),
                pa.dictionary(pa.int32(), pa.string()),
            )
        ),
    ],
)
def test_has_unhashable_pandas_types(pa_type, expected):
    """Check the unhashable-type flag for a one-column schema of ``pa_type``.

    Args:
        pa_type: Arrow data type used for the single column ``c``.
        expected: The exact boolean ``_has_unhashable_pandas_types`` must
            return for that schema.
    """
    single_column_schema = pa.schema([pa.field("c", pa_type)])
    flagged = _has_unhashable_pandas_types(single_column_schema)
    assert flagged is expected
| 178 | + |
| 179 | + |
@pytest.mark.skipif(
    get_pyarrow_version() < parse_version("16.0.0"),
    reason="list_view / large_list_view require pyarrow 16+",
)
def test_has_unhashable_pandas_types_list_views():
    """Regression test: list-view types must be flagged as unhashable.

    list_view / large_list_view also convert to Python lists in pandas, so
    they have to be treated the same way as list / large_list.
    """
    # Build the types inside the test body (not in a parametrize list) so
    # that they are only constructed when the skipif guard lets the test run.
    view_types = [
        make_view(pa.int32()) for make_view in (pa.list_view, pa.large_list_view)
    ]
    for view_type in view_types:
        single_column_schema = pa.schema([pa.field("c", view_type)])
        assert _has_unhashable_pandas_types(single_column_schema) is True
| 190 | + |
| 191 | + |
def test_hash_partition_null_struct_consistent_across_blocks():
    """Null struct keys must land in the same partition for every block.

    Hash-partitions an all-null block and a mixed null / non-null block on
    the struct column ``k`` and checks that the partition id holding the
    null-key rows is the same for both.
    """
    key_type = pa.struct([("v", pa.int32())])
    num_partitions = 8

    def _partition_block(keys, indices):
        # Build a two-column block (struct key + row index) and partition it
        # on the key column.
        block = pa.Table.from_pydict(
            {"k": pa.array(keys, type=key_type), "idx": indices}
        )
        return hash_partition(block, hash_cols=["k"], num_partitions=num_partitions)

    def _sole_null_partition(partitions):
        # Collect every partition id that contains at least one null key;
        # identical null keys must co-locate, so exactly one id is expected.
        holders = set()
        for pid, part in partitions.items():
            if any(part["k"].is_null().to_pylist()):
                holders.add(pid)
        assert len(holders) == 1, holders
        (only_pid,) = holders
        return only_pid

    all_null_parts = _partition_block([None, None, None], [0, 1, 2])
    mixed_parts = _partition_block([None, {"v": 1}, None], [10, 11, 12])

    assert _sole_null_partition(all_null_parts) == _sole_null_partition(mixed_parts)
| 219 | + |
| 220 | + |
147 | 221 | def test_shuffle(): |
148 | 222 | t = pa.Table.from_pydict( |
149 | 223 | { |
|
0 commit comments