Skip to content

Commit 79ff25a

Browse files
committed
fixes
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent b412ab5 commit 79ff25a

14 files changed

Lines changed: 113 additions & 90 deletions

vortex-python/python/vortex/file.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -61,9 +61,7 @@ def open(
6161
See also: :class:`vortex.dataset.VortexDataset`
6262
"""
6363

64-
return VortexFile(
65-
_file.open(path, store=store, without_segment_cache=without_segment_cache, session=session)
66-
)
64+
return VortexFile(_file.open(path, store=store, without_segment_cache=without_segment_cache, session=session))
6765

6866

6967
@final

vortex-python/python/vortex/ray/datasource.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from .. import open as vx_open
1919
from ..arrow.expression import ensure_vortex_expression
2020
from ..expr import Expr as VortexExpr
21+
from ..session import Session
2122
from ..type_aliases import IntoProjection
2223

2324
if TYPE_CHECKING:
@@ -51,12 +52,14 @@ def __init__(
5152
self,
5253
*,
5354
url: str,
55+
session: Session,
5456
columns: IntoProjection = None,
5557
filter: pc.Expression | VortexExpr | None = None,
5658
batch_size: int | None = None,
5759
meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(), # pyright: ignore[reportCallInDefaultInitializer]
5860
):
5961
super().__init__()
62+
self._session = session
6063
self._columns = columns
6164
self._filter = filter
6265

@@ -101,6 +104,7 @@ def get_read_tasks(
101104
self._columns,
102105
self._filter,
103106
per_task_row_limit if per_task_row_limit is not None else self._batch_size,
107+
self._session,
104108
)
105109
for paths in partition(parallelism, self._paths)
106110
if len(paths) > 0
@@ -118,11 +122,12 @@ def _read_task(
118122
columns: IntoProjection,
119123
filter: pc.Expression | VortexExpr | None,
120124
batch_size: int | None,
125+
session: Session,
121126
) -> ReadTask:
122127
if not paths:
123128
raise ValueError("no paths specified")
124129

125-
files = [vx_open(path) for path in paths]
130+
files = [vx_open(path, session=session) for path in paths]
126131
schemas = [f.dtype.to_arrow_schema() for f in files]
127132
schema = schemas[0]
128133
assert all(s == schema for s in schemas[1:])
@@ -140,8 +145,11 @@ def read() -> Iterable[pandas.DataFrame]:
140145
# If we could serialize a PyVortexFile and a PyExpr, we could set those up earlier.
141146

142147
vx_filter = ensure_vortex_expression(filter, schema=schema)
148+
# Ray read functions may execute in worker processes, so use a worker-local session
149+
# instead of closing over the driver session.
150+
session = Session()
143151
for path in paths:
144-
f = vx_open(path)
152+
f = vx_open(path, session=session)
145153
for rb in f.to_arrow(columns, expr=vx_filter, batch_size=batch_size):
146154
# We would prefer to generate Arrow, but we run into this issue: https://github.com/apache/arrow/issues/47279
147155
#

vortex-python/test/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,4 +3,13 @@
33

44
import logging
55

6+
import pytest
7+
8+
import vortex as vx
9+
610
logging.basicConfig(level=logging.DEBUG)
11+
12+
13+
@pytest.fixture
14+
def session() -> vx.Session:
15+
return vx.Session()

vortex-python/test/test_array.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,39 +7,39 @@
77
import vortex
88

99

10-
def test_primitive_array_round_trip():
10+
def test_primitive_array_round_trip(session: vortex.Session):
1111
a = pa.array([0, 1, 2, 3])
1212
arr = vortex.array(a)
13-
assert arr.to_arrow_array() == a
13+
assert arr.to_arrow_array(session=session) == a
1414

1515

16-
def test_array_with_nulls():
16+
def test_array_with_nulls(session: vortex.Session):
1717
a = pa.array([b"123", None], type=pa.string_view())
1818
arr = vortex.array(a)
19-
assert arr.to_arrow_array() == a
19+
assert arr.to_arrow_array(session=session) == a
2020

2121

22-
def test_varbin_array_round_trip():
22+
def test_varbin_array_round_trip(session: vortex.Session):
2323
a = pa.array(["a", "b", "c"], type=pa.string_view())
2424
arr = vortex.array(a)
25-
assert arr.to_arrow_array() == a
25+
assert arr.to_arrow_array(session=session) == a
2626

2727

28-
def test_varbin_array_take():
28+
def test_varbin_array_take(session: vortex.Session):
2929
a = vortex.array(pa.array(["a", "b", "c", "d"], type=pa.string_view()))
30-
assert a.take(vortex.array(pa.array([0, 2]))).to_arrow_array() == pa.array(
30+
assert a.take(vortex.array(pa.array([0, 2]))).to_arrow_array(session=session) == pa.array(
3131
["a", "c"],
3232
type=pa.string_view(),
3333
)
3434

3535

36-
def test_empty_array():
36+
def test_empty_array(session: vortex.Session):
3737
a = pa.array([], type=pa.uint8())
3838
primitive = vortex.array(a)
39-
assert primitive.to_arrow_array().type == pa.uint8()
39+
assert primitive.to_arrow_array(session=session).type == pa.uint8()
4040

4141

4242
@pytest.mark.xfail(raises=IndexError)
43-
def test_scalar_at_out_of_bounds():
43+
def test_scalar_at_out_of_bounds(session: vortex.Session):
4444
a = vortex.array([10, 42, 999, 1992])
45-
_s = a.scalar_at(10)
45+
_s = a.scalar_at(10, session=session)

vortex-python/test/test_compress.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -12,42 +12,42 @@
1212
import vortex
1313

1414

15-
def test_primitive_compress():
15+
def test_primitive_compress(session: vortex.Session):
1616
a = pa.array([0, 0, 0, 0, 9, 9, 9, 9, 1, 5])
17-
arr_compressed = vortex.compress(vortex.array(a))
17+
arr_compressed = vortex.compress(vortex.array(a), session=session)
1818
assert not isinstance(arr_compressed, vortex.PrimitiveArray)
1919
assert arr_compressed.nbytes < a.nbytes
2020

2121

22-
def test_for_compress():
22+
def test_for_compress(session: vortex.Session):
2323
a = pa.array(np.arange(10_000) + 10_000_000)
24-
arr_compressed = vortex.compress(vortex.array(a))
24+
arr_compressed = vortex.compress(vortex.array(a), session=session)
2525
assert not isinstance(arr_compressed, vortex.PrimitiveArray)
2626

2727

28-
def test_arrange_encode():
28+
def test_arrange_encode(session: vortex.Session):
2929
a = vortex.array(pa.array(np.arange(10_000), type=pa.uint32()))
30-
compressed = vortex.compress(a)
30+
compressed = vortex.compress(a, session=session)
3131
assert isinstance(compressed, vortex.FastLanesDeltaArray | vortex.SequenceArray)
3232
assert compressed.nbytes < a.nbytes
3333

3434

35-
def test_zigzag_encode():
35+
def test_zigzag_encode(session: vortex.Session):
3636
a = vortex.array(pa.array([-1, -1, 0, -1, 1, -1]))
37-
zarr = vortex.ZigZagArray.encode(a)
37+
zarr = vortex.ZigZagArray.encode(a, session=session)
3838
assert isinstance(zarr, vortex.ZigZagArray)
3939
# TODO(ngates): support decoding once we have decompressor.
4040

4141

42-
def test_chunked_encode():
42+
def test_chunked_encode(session: vortex.Session):
4343
chunked = pa.chunked_array([pa.array([0, 1, 2]), pa.array([3, 4, 5])])
4444
encoded = vortex.array(chunked)
45-
arrow = encoded.to_arrow_array()
45+
arrow = encoded.to_arrow_array(session=session)
4646
assert isinstance(arrow, pa.ChunkedArray)
4747
assert arrow.combine_chunks() == pa.array([0, 1, 2, 3, 4, 5])
4848

4949

50-
def test_table_encode():
50+
def test_table_encode(session: vortex.Session):
5151
table = pa.table( # pyright: ignore[reportCallIssue, reportUnknownVariableType]
5252
{ # pyright: ignore[reportArgumentType]
5353
"number": pa.chunked_array([pa.array([0, 1, 2]), pa.array([3, 4, 5])]),
@@ -59,7 +59,7 @@ def test_table_encode():
5959
assert isinstance(table, pa.Table)
6060

6161
encoded = vortex.array(table)
62-
arrow = encoded.to_arrow_array()
62+
arrow = encoded.to_arrow_array(session=session)
6363
assert isinstance(arrow, pa.ChunkedArray)
6464
assert arrow.combine_chunks() == pa.StructArray.from_arrays( # pyright: ignore[reportUnknownMemberType]
6565
[pa.array([0, 1, 2, 3, 4, 5]), pa.array(["a", "b", "c", "d", "e", "f"], type=pa.string_view())],
@@ -68,11 +68,11 @@ def test_table_encode():
6868

6969

7070
@pytest.mark.skip(reason="We have no way to guarantee that the vortex-bench data has been downloaded.")
71-
def test_taxi():
71+
def test_taxi(session: vortex.Session):
7272
curdir = Path(os.path.dirname(__file__)).parent.parent
7373
table = pq.read_table(curdir / "vortex-bench/data/yellow-tripdata-2023-11.parquet") # pyright: ignore[reportUnknownMemberType]
74-
compressed = vortex.compress(vortex.array(table[:100]))
75-
decompressed = compressed.to_arrow_array()
74+
compressed = vortex.compress(vortex.array(table[:100]), session=session)
75+
decompressed = compressed.to_arrow_array(session=session)
7676
assert len(decompressed) == 100
7777
# hard to test because of string_view
7878
# assert pc.equal(decompressed, table[:100].to_struct_array()), (decompressed, table[:100].to_struct_array())

vortex-python/test/test_dataset.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,10 @@ def ds(tmpdir_factory) -> vx.dataset.VortexDataset: # pyright: ignore[reportUnk
2929

3030
assert not os.path.exists(fname) # pyright: ignore[reportUnknownArgumentType]
3131

32+
session = vx.Session()
3233
a = pa.array([record(x) for x in range(1_000_000)])
33-
vx.io.write(vx.array(a), str(fname)) # pyright: ignore[reportUnknownArgumentType]
34-
return vx.dataset.VortexDataset.from_path(str(fname)) # pyright: ignore[reportUnknownArgumentType]
34+
vx.io.write(vx.array(a), str(fname), session=session) # pyright: ignore[reportUnknownArgumentType]
35+
return vx.dataset.VortexDataset.from_path(str(fname), session=session) # pyright: ignore[reportUnknownArgumentType]
3536

3637

3738
def test_schema(ds: pd.Dataset):
@@ -129,7 +130,7 @@ def test_filter(ds: vx.dataset.VortexDataset):
129130
assert len(tbl) == 20
130131

131132

132-
def test_filter_with_nested_null_dtype(tmp_path: Path):
133+
def test_filter_with_nested_null_dtype(tmp_path: Path, session: vx.Session):
133134
path = tmp_path / "test.vortex"
134135

135136
batch = pa.RecordBatch.from_pylist(
@@ -140,9 +141,9 @@ def test_filter_with_nested_null_dtype(tmp_path: Path):
140141
)
141142

142143
arr = vx.array(batch.to_struct_array())
143-
vx.io.write(vx.ArrayIterator.from_iter(arr.dtype, iter([arr])), str(path))
144+
vx.io.write(vx.ArrayIterator.from_iter(arr.dtype, iter([arr]), session=session), str(path), session=session)
144145

145-
dataset = vx.open(str(path)).to_dataset()
146+
dataset = vx.open(str(path), session=session).to_dataset()
146147
actual = dataset.to_table(filter=pc.field("a") == 0)
147148

148149
assert actual.to_pylist() == [{"a": 0, "b": {"x": None}}]

vortex-python/test/test_datasource.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,16 +41,16 @@ def test_partition():
4141
assert partition(3, list(range(11))) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10]]
4242

4343

44-
def test_vortex_datasource(ray_init, tmpdir_factory): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType, reportUnusedParameter]
44+
def test_vortex_datasource(ray_init, tmpdir_factory, session: vx.Session): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType, reportUnusedParameter]
4545
folder = tmpdir_factory.mktemp("data") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
4646

4747
arr1 = vx.array([record(x) for x in range(5)])
48-
vx.io.write(arr1, str(folder / "01.vortex")) # pyright: ignore[reportUnknownArgumentType]
48+
vx.io.write(arr1, str(folder / "01.vortex"), session=session) # pyright: ignore[reportUnknownArgumentType]
4949

5050
arr2 = vx.array([record(x) for x in range(5, 10)])
51-
vx.io.write(arr2, str(folder / "02.vortex")) # pyright: ignore[reportUnknownArgumentType]
51+
vx.io.write(arr2, str(folder / "02.vortex"), session=session) # pyright: ignore[reportUnknownArgumentType]
5252

53-
ds = read_datasource(VortexDatasource(url=str(folder))) # pyright: ignore[reportUnknownArgumentType]
53+
ds = read_datasource(VortexDatasource(url=str(folder), session=session)) # pyright: ignore[reportUnknownArgumentType]
5454

5555
# Without an explicit sort, Ray may reorder rows *even within a single record batch*.
5656
ds = ds.sort("index")

vortex-python/test/test_duckdb.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,15 @@
1212
import vortex as vx
1313

1414

15-
def test_duckdb_via_substrait(tmp_path: Path) -> None:
15+
def test_duckdb_via_substrait(tmp_path: Path, session: vx.Session) -> None:
1616
con = duckdb.connect()
1717

1818
arr = pa.array([datetime(2024, 1, 1), datetime(2024, 6, 15), datetime(2024, 12, 31)])
1919
table = pa.table({"ts": arr})
2020
path = str(tmp_path / "test_timestamp.vortex")
21-
vx.io.write(table, path)
21+
vx.io.write(table, path, session=session)
2222

23-
ds = vx.open(path).to_dataset() # noqa: F841 # pyright: ignore[reportUnusedVariable] - used by duckdb via SQL
23+
ds = vx.open(path, session=session).to_dataset() # noqa: F841 # pyright: ignore[reportUnusedVariable] - used by duckdb via SQL
2424
result = con.execute("SELECT * FROM ds WHERE ts > '2024-06-01'").fetchall()
2525
assert len(result) == 2
2626
print(result)

vortex-python/test/test_file.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -23,11 +23,12 @@ def record(x: int, columns: list[str] | set[str] | None = None) -> dict[str, int
2323
def vxf(tmpdir_factory) -> vx.VortexFile: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
2424
fname = tmpdir_factory.mktemp("data") / "foo.vortex" # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
2525

26+
session = vx.Session()
2627
if not os.path.exists(fname): # pyright: ignore[reportUnknownArgumentType]
2728
a = pa.array([record(x) for x in range(1_000_000)])
28-
arr = vx.compress(vx.array(a))
29-
vx.io.write(arr, str(fname)) # pyright: ignore[reportUnknownArgumentType]
30-
return vx.open(str(fname), without_segment_cache=True) # pyright: ignore[reportUnknownArgumentType]
29+
arr = vx.compress(vx.array(a), session=session)
30+
vx.io.write(arr, str(fname), session=session) # pyright: ignore[reportUnknownArgumentType]
31+
return vx.open(str(fname), without_segment_cache=True, session=session) # pyright: ignore[reportUnknownArgumentType]
3132

3233

3334
def test_dtype(vxf: VortexFile):
@@ -62,7 +63,7 @@ def test_to_arrow_columns(vxf: VortexFile):
6263
assert rbr.schema == pa.schema([("string", pa.string_view()), ("bool", pa.bool_())])
6364

6465

65-
def test_empty_file(tmpdir_factory): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
66+
def test_empty_file(tmpdir_factory, session: vx.Session): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
6667
# test for writing empty files with null columns
6768
# create an empty table with schema `empty: null`
6869
table = pa.Table.from_pydict({"empty": []})
@@ -75,10 +76,10 @@ def test_empty_file(tmpdir_factory): # pyright: ignore[reportUnknownParameterTy
7576

7677
# writing file should succeed
7778
empty_file = tmpdir_factory.mktemp("data") / "empty.vortex" # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
78-
vx.io.write(empty, str(empty_file)) # pyright: ignore[reportUnknownArgumentType]
79+
vx.io.write(empty, str(empty_file), session=session) # pyright: ignore[reportUnknownArgumentType]
7980

8081

81-
def test_stream_pyarrow(tmpdir_factory): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
82+
def test_stream_pyarrow(tmpdir_factory, session: vx.Session): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
8283
import pyarrow.parquet as pq
8384

8485
data_dir = tmpdir_factory.mktemp("data") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
@@ -91,4 +92,4 @@ def test_stream_pyarrow(tmpdir_factory): # pyright: ignore[reportUnknownParamet
9192
pq.write_table(table, str(data_dir / "names.parquet")) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
9293

9394
df = pq.read_table(str(data_dir / "names.parquet")) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
94-
vx.io.write(df, str(data_dir / "names.vortex")) # pyright: ignore[reportUnknownArgumentType]
95+
vx.io.write(df, str(data_dir / "names.vortex"), session=session) # pyright: ignore[reportUnknownArgumentType]

0 commit comments

Comments
 (0)