Skip to content

Commit e9d0b00

Browse files
committed
threads
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 80e3111 commit e9d0b00

2 files changed

Lines changed: 127 additions & 47 deletions

File tree

vortex-python/python/vortex/dataset.py

Lines changed: 74 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import warnings
77
from collections.abc import Iterator
8+
from contextlib import contextmanager
89
from functools import reduce
910
from typing import final
1011

@@ -17,13 +18,34 @@
1718
from .arrays import array
1819
from .arrow.expression import ensure_vortex_expression
1920
from .expr import Expr, and_
21+
from .runtime import set_worker_threads as _set_worker_threads
22+
from .runtime import set_worker_threads_to_available_parallelism as _set_worker_threads_to_available_parallelism
23+
from .runtime import worker_count as _worker_count
2024

2125

22-
def _warn_use_threads() -> None:
23-
warnings.warn(
24-
"Vortex threading is configured through vortex.runtime. Ignoring use_threads=True.",
25-
stacklevel=2,
26-
)
26+
@contextmanager
def _temporary_worker_threads(use_threads: bool | None) -> Iterator[None]:
    """Temporarily override the Vortex worker-thread count for the ``with`` body.

    Parameters
    ----------
    use_threads : bool | None
        ``None`` leaves the runtime configuration untouched. ``True`` switches
        the runtime to its available parallelism, and ``False`` sets the worker
        count to ``0`` for the duration of the block.

    Yields
    ------
    None
        Control is handed to the ``with`` body once the override is in place.

    Notes
    -----
    The worker count read before the override is written back on exit, even
    when the body raises.
    """
    # NOTE(review): the runtime setters look process-wide; overlapping or
    # concurrent uses could restore a stale count -- confirm against vortex.runtime.
    if use_threads is not None:
        # Capture the count before changing it so the ``finally`` can restore it.
        saved_worker_count = _worker_count()
        if not use_threads:
            _set_worker_threads(0)
        else:
            _set_worker_threads_to_available_parallelism()
        try:
            yield
        finally:
            _set_worker_threads(saved_worker_count)
    else:
        # Nothing was requested: run the body without touching the runtime.
        yield
42+
43+
44+
def _read_batches_with_temporary_worker_threads(
    reader: pyarrow.RecordBatchReader, use_threads: bool | None
) -> Iterator[pyarrow.RecordBatch]:
    """Yield every batch from ``reader`` while a worker-thread override is active.

    Parameters
    ----------
    reader : pyarrow.RecordBatchReader
        Source of the record batches; it is iterated until exhausted.
    use_threads : bool | None
        Forwarded unchanged to ``_temporary_worker_threads``.

    Yields
    ------
    pyarrow.RecordBatch
        The batches produced by ``reader``, in order.

    Notes
    -----
    This is a generator, so the override is entered on the first ``next()``
    call and stays in place across yields. It is left only when the reader is
    exhausted, an error propagates, or the generator is closed.
    """
    worker_override = _temporary_worker_threads(use_threads)
    with worker_override:
        # ``yield from`` is kept rather than a manual loop: it also forwards
        # close()/throw() to the underlying iterator when it supports them.
        yield from reader
2749

2850

2951
@final
@@ -72,14 +94,13 @@ def count_rows(
7294
raise ValueError("fragment_readahead not supported")
7395
if fragment_scan_options is not None:
7496
raise ValueError("fragment_scan_options not supported")
75-
if use_threads:
76-
_warn_use_threads()
7797
if cache_metadata is not None:
7898
warnings.warn("Vortex does not support cache_metadata. Ignoring cache_metadata setting.")
7999
del memory_pool
80-
return self._dataset.count_rows(
81-
row_filter=self._filter_expression(filter), split_by=batch_size, row_range=_row_range
82-
)
100+
with _temporary_worker_threads(use_threads):
101+
return self._dataset.count_rows(
102+
row_filter=self._filter_expression(filter), split_by=batch_size, row_range=_row_range
103+
)
83104

84105
def _filter_expression(self, expression: pyarrow.dataset.Expression | Expr | None) -> Expr | None:
85106
if expression is None:
@@ -140,7 +161,8 @@ def head(
140161
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
141162
Not implemented.
142163
use_threads : bool
143-
Not implemented.
164+
If ``True``, temporarily use available parallelism. If ``False``,
165+
temporarily disable Vortex background workers.
144166
memory_pool : :class:`.pyarrow.MemoryPool` | None
145167
Not implemented.
146168
@@ -157,23 +179,22 @@ def head(
157179
raise ValueError("fragment_readahead not supported")
158180
if fragment_scan_options is not None:
159181
raise ValueError("fragment_scan_options not supported")
160-
if use_threads:
161-
_warn_use_threads()
162182
if columns is not None and len(columns) == 0:
163183
raise ValueError("empty projections are not currently supported")
164184
if cache_metadata is not None:
165185
warnings.warn("Vortex does not support cache_metadata. Ignoring cache_metadata setting.")
166186
del memory_pool
167187

168-
return (
169-
self._dataset.to_array(
170-
columns=columns,
171-
row_filter=self._filter_expression(filter),
172-
row_range=_row_range,
188+
with _temporary_worker_threads(use_threads):
189+
return (
190+
self._dataset.to_array(
191+
columns=columns,
192+
row_filter=self._filter_expression(filter),
193+
row_range=_row_range,
194+
)
195+
.slice(0, num_rows)
196+
.to_arrow_table()
173197
)
174-
.slice(0, num_rows)
175-
.to_arrow_table()
176-
)
177198

178199
@override
179200
def join(
@@ -240,7 +261,8 @@ def scanner(
240261
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
241262
Not implemented.
242263
use_threads : bool
243-
Not implemented.
264+
If ``True``, temporarily use available parallelism. If ``False``,
265+
temporarily disable Vortex background workers.
244266
memory_pool : :class:`.pyarrow.MemoryPool` | None
245267
Not implemented.
246268
@@ -312,7 +334,8 @@ def take( # pyright: ignore[reportIncompatibleMethodOverride]
312334
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
313335
Not implemented.
314336
use_threads : bool
315-
Not implemented.
337+
If ``True``, temporarily use available parallelism. If ``False``,
338+
temporarily disable Vortex background workers.
316339
cache_metadata : bool
317340
Not implemented.
318341
memory_pool : :class:`.pyarrow.MemoryPool` | None
@@ -323,12 +346,13 @@ def take( # pyright: ignore[reportIncompatibleMethodOverride]
323346
table : :class:`.pyarrow.Table`
324347
325348
"""
326-
return self._dataset.to_array(
327-
columns=columns,
328-
row_filter=self._filter_expression(filter),
329-
indices=array(indices.cast(pa.uint64())),
330-
row_range=_row_range,
331-
).to_arrow_table()
349+
with _temporary_worker_threads(use_threads):
350+
return self._dataset.to_array(
351+
columns=columns,
352+
row_filter=self._filter_expression(filter),
353+
indices=array(indices.cast(pa.uint64())),
354+
row_range=_row_range,
355+
).to_arrow_table()
332356

333357
def to_record_batch_reader(
334358
self,
@@ -361,7 +385,8 @@ def to_record_batch_reader(
361385
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
362386
Not implemented.
363387
use_threads : bool
364-
Not implemented.
388+
If ``True``, temporarily use available parallelism. If ``False``,
389+
temporarily disable Vortex background workers.
365390
memory_pool : :class:`.pyarrow.MemoryPool` | None
366391
Not implemented.
367392
@@ -376,15 +401,19 @@ def to_record_batch_reader(
376401
raise ValueError("fragment_readahead not supported")
377402
if fragment_scan_options is not None:
378403
raise ValueError("fragment_scan_options not supported")
379-
if use_threads:
380-
_warn_use_threads()
381404
if cache_metadata is not None:
382405
warnings.warn("Vortex does not support cache_metadata. Ignoring cache_metadata setting.")
383406
if columns is not None and len(columns) == 0:
384407
raise ValueError("empty projections are not currently supported")
385408
del memory_pool
386-
return self._dataset.to_record_batch_reader(
387-
columns=columns, row_filter=self._filter_expression(filter), split_by=batch_size, row_range=_row_range
409+
with _temporary_worker_threads(use_threads):
410+
reader = self._dataset.to_record_batch_reader(
411+
columns=columns, row_filter=self._filter_expression(filter), split_by=batch_size, row_range=_row_range
412+
)
413+
if use_threads is None:
414+
return reader
415+
return pyarrow.RecordBatchReader.from_batches(
416+
reader.schema, _read_batches_with_temporary_worker_threads(reader, use_threads)
388417
)
389418

390419
@override
@@ -419,7 +448,8 @@ def to_batches(
419448
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
420449
Not implemented.
421450
use_threads : bool
422-
Not implemented.
451+
If ``True``, temporarily use available parallelism. If ``False``,
452+
temporarily disable Vortex background workers.
423453
cache_metadata : bool
424454
Not implemented.
425455
memory_pool : :class:`.pyarrow.MemoryPool` | None
@@ -442,11 +472,7 @@ def to_batches(
442472
memory_pool,
443473
_row_range,
444474
)
445-
while True:
446-
try:
447-
yield record_batch_reader.read_next_batch()
448-
except StopIteration:
449-
return
475+
yield from record_batch_reader
450476

451477
@override
452478
def to_table(
@@ -480,7 +506,8 @@ def to_table(
480506
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
481507
Not implemented.
482508
use_threads : bool
483-
Not implemented.
509+
If ``True``, temporarily use available parallelism. If ``False``,
510+
temporarily disable Vortex background workers.
484511
memory_pool : :class:`.pyarrow.MemoryPool` | None
485512
Not implemented.
486513
@@ -497,8 +524,6 @@ def to_table(
497524
raise ValueError("fragment_readahead not supported")
498525
if fragment_scan_options is not None:
499526
raise ValueError("fragment_scan_options not supported")
500-
if use_threads:
501-
_warn_use_threads()
502527
if cache_metadata is not None:
503528
warnings.warn("Vortex does not support cache_metadata. Ignoring cache_metadata setting.")
504529
if columns is not None and len(columns) == 0:
@@ -510,9 +535,10 @@ def to_table(
510535
"VortexDataset does not currently support a dict of expressions as the 'column' parameter."
511536
)
512537

513-
return self._dataset.to_array(
514-
columns=columns, row_filter=self._filter_expression(filter), row_range=_row_range
515-
).to_arrow_table()
538+
with _temporary_worker_threads(use_threads):
539+
return self._dataset.to_array(
540+
columns=columns, row_filter=self._filter_expression(filter), row_range=_row_range
541+
).to_arrow_table()
516542

517543

518544
def from_url(url: str) -> VortexDataset:
@@ -758,7 +784,8 @@ class VortexScanner(pyarrow.dataset.Scanner):
758784
fragment_scan_options : :class:`.pyarrow.dataset.FragmentScanOptions`
759785
Not implemented.
760786
use_threads : bool
761-
Not implemented.
787+
If ``True``, temporarily use available parallelism. If ``False``,
788+
temporarily disable Vortex background workers.
762789
memory_pool : :class:`.pyarrow.MemoryPool` | None
763790
Not implemented.
764791

vortex-python/test/test_dataset.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pyarrow.compute as pc
1212
import pyarrow.dataset as pd
1313
import pytest
14+
import vortex.dataset as vx_dataset
1415

1516
import vortex as vx
1617

@@ -71,6 +72,58 @@ def test_to_batches(ds: pd.Dataset):
7172
)
7273

7374

75+
def test_use_threads_configures_worker_pool(monkeypatch: pytest.MonkeyPatch):
    """Check that the worker-thread helpers override and then restore the pool size.

    The three runtime hooks in ``vortex.dataset`` are replaced with in-memory
    stubs that record every call, so the test observes the exact call sequence
    without touching the real runtime.
    """
    pool_size = 3
    events: list[tuple[str, int]] = []

    def stub_worker_count() -> int:
        return pool_size

    def stub_set_worker_threads(count: int) -> None:
        nonlocal pool_size
        events.append(("set", count))
        pool_size = count

    def stub_set_to_available_parallelism() -> None:
        nonlocal pool_size
        # Record the size seen *before* the switch, then pretend 11 cores exist.
        events.append(("available", pool_size))
        pool_size = 11

    replacements = (
        ("_worker_count", stub_worker_count),
        ("_set_worker_threads", stub_set_worker_threads),
        ("_set_worker_threads_to_available_parallelism", stub_set_to_available_parallelism),
    )
    for attribute, stub in replacements:
        monkeypatch.setattr(vx_dataset, attribute, stub)

    # use_threads=True: available parallelism inside the block, restored after it.
    with vx_dataset._temporary_worker_threads(True):  # pyright: ignore[reportPrivateUsage]
        assert pool_size == 11
    assert pool_size == 3

    # use_threads=False: zero workers inside the block, restored after it.
    with vx_dataset._temporary_worker_threads(False):  # pyright: ignore[reportPrivateUsage]
        assert pool_size == 0
    assert pool_size == 3

    assert events == [("available", 3), ("set", 3), ("set", 0), ("set", 3)]

    # Streaming helper: one override around the whole iteration, then a restore.
    events.clear()
    schema = pa.schema([("x", pa.int64())])
    source_batches = [pa.record_batch([pa.array([value])], names=["x"]) for value in (1, 2)]
    reader = pa.RecordBatchReader.from_batches(schema, source_batches)

    streamed = vx_dataset._read_batches_with_temporary_worker_threads(reader, True)  # pyright: ignore[reportPrivateUsage]
    batches = list(streamed)

    assert [batch.to_pylist() for batch in batches] == [[{"x": 1}], [{"x": 2}]]
    assert pool_size == 3
    assert events == [("available", 3), ("set", 3)]
125+
126+
74127
@pytest.mark.parametrize("batch_size", [1234, 8192, 1 << 31])
75128
def test_to_batch_size(ds: pd.Dataset, batch_size: int):
76129
batch_sizes = [len(x) for x in ds.to_batches(batch_size=batch_size)]

0 commit comments

Comments
 (0)