Skip to content

Commit 93ae890

Browse files
Merge pull request #684 from pyathena-dev/refactor/s3-executor-strategy
Add S3Executor strategy pattern for async S3 operations
2 parents 424d564 + c23d060 commit 93ae890

File tree

6 files changed

+1354
-28
lines changed

6 files changed

+1354
-28
lines changed

pyathena/aio/s3fs/cursor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pyathena.aio.common import WithAsyncFetch
99
from pyathena.common import CursorIterator
1010
from pyathena.error import OperationalError, ProgrammingError
11+
from pyathena.filesystem.s3_async import AioS3FileSystem
1112
from pyathena.model import AthenaQueryExecution
1213
from pyathena.s3fs.converter import DefaultS3FSTypeConverter
1314
from pyathena.s3fs.result_set import AthenaS3FSResultSet, CSVReaderType
@@ -16,11 +17,12 @@
1617

1718

1819
class AioS3FSCursor(WithAsyncFetch):
19-
"""Native asyncio cursor that reads CSV results via S3FileSystem.
20+
"""Native asyncio cursor that reads CSV results via AioS3FileSystem.
2021
21-
Uses ``asyncio.to_thread()`` for result set creation and fetch operations
22-
because ``AthenaS3FSResultSet`` lazily streams rows from S3 via a CSV
23-
reader, making fetch calls blocking I/O.
22+
Uses ``AioS3FileSystem`` for S3 operations, which replaces
23+
``ThreadPoolExecutor`` parallelism with ``asyncio.gather`` +
24+
``asyncio.to_thread``. Fetch operations are wrapped in
25+
``asyncio.to_thread()`` because CSV reading is blocking I/O.
2426
2527
Example:
2628
>>> async with await pyathena.aio_connect(...) as conn:
@@ -127,6 +129,7 @@ async def execute( # type: ignore[override]
127129
arraysize=self.arraysize,
128130
retry_config=self._retry_config,
129131
csv_reader=self._csv_reader,
132+
filesystem_class=AioS3FileSystem,
130133
**kwargs,
131134
)
132135
else:

pyathena/filesystem/s3.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
import itertools
54
import logging
65
import mimetypes
76
import os.path
87
import re
98
from concurrent.futures import Future, as_completed
10-
from concurrent.futures.thread import ThreadPoolExecutor
119
from copy import deepcopy
1210
from datetime import datetime
1311
from multiprocessing import cpu_count
@@ -23,6 +21,7 @@
2321
from fsspec.utils import tokenize
2422

2523
import pyathena
24+
from pyathena.filesystem.s3_executor import S3Executor, S3ThreadPoolExecutor
2625
from pyathena.filesystem.s3_object import (
2726
S3CompleteMultipartUpload,
2827
S3MultipartUpload,
@@ -686,6 +685,20 @@ def _delete_object(
686685
**request,
687686
)
688687

688+
def _create_executor(self, max_workers: int) -> S3Executor:
689+
"""Create an executor strategy for parallel operations.
690+
691+
Subclasses can override to provide alternative execution strategies
692+
(e.g., asyncio-based execution).
693+
694+
Args:
695+
max_workers: Maximum number of parallel workers.
696+
697+
Returns:
698+
An S3Executor instance.
699+
"""
700+
return S3ThreadPoolExecutor(max_workers=max_workers)
701+
689702
def _delete_objects(
690703
self, bucket: str, paths: List[str], max_workers: Optional[int] = None, **kwargs
691704
) -> None:
@@ -703,7 +716,7 @@ def _delete_objects(
703716
object_.update({"VersionId": version_id})
704717
delete_objects.append(object_)
705718

706-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
719+
with self._create_executor(max_workers=max_workers) as executor:
707720
fs = []
708721
for delete in [
709722
delete_objects[i : i + self.DELETE_OBJECTS_MAX_KEYS]
@@ -861,7 +874,7 @@ def _copy_object_with_multipart_upload(
861874
**kwargs,
862875
)
863876
parts = []
864-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
877+
with self._create_executor(max_workers=max_workers) as executor:
865878
fs = [
866879
executor.submit(
867880
self._upload_part_copy,
@@ -1106,6 +1119,7 @@ def _open(
11061119
mode,
11071120
version_id=None,
11081121
max_workers=max_workers,
1122+
executor=self._create_executor(max_workers=max_workers),
11091123
block_size=block_size,
11101124
cache_type=cache_type,
11111125
autocommit=autocommit,
@@ -1256,6 +1270,7 @@ def __init__(
12561270
mode: str = "rb",
12571271
version_id: Optional[str] = None,
12581272
max_workers: int = (cpu_count() or 1) * 5,
1273+
executor: Optional[S3Executor] = None,
12591274
block_size: int = S3FileSystem.DEFAULT_BLOCK_SIZE,
12601275
cache_type: str = "bytes",
12611276
autocommit: bool = True,
@@ -1265,7 +1280,7 @@ def __init__(
12651280
**kwargs,
12661281
) -> None:
12671282
self.max_workers = max_workers
1268-
self._executor = ThreadPoolExecutor(max_workers=max_workers)
1283+
self._executor: S3Executor = executor or S3ThreadPoolExecutor(max_workers=max_workers)
12691284
self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}
12701285

12711286
super().__init__(
@@ -1481,24 +1496,18 @@ def _fetch_range(self, start: int, end: int) -> bytes:
14811496
start, end, max_workers=self.max_workers, worker_block_size=self.blocksize
14821497
)
14831498
if len(ranges) > 1:
1484-
object_ = self._merge_objects(
1485-
list(
1486-
self._executor.map(
1487-
lambda bucket, key, ranges, version_id, kwargs: self.fs._get_object(
1488-
bucket=bucket,
1489-
key=key,
1490-
ranges=ranges,
1491-
version_id=version_id,
1492-
**kwargs,
1493-
),
1494-
itertools.repeat(self.bucket),
1495-
itertools.repeat(self.key),
1496-
ranges,
1497-
itertools.repeat(self.version_id),
1498-
itertools.repeat(self.s3_additional_kwargs),
1499-
)
1499+
futures = [
1500+
self._executor.submit(
1501+
self.fs._get_object,
1502+
bucket=self.bucket,
1503+
key=self.key,
1504+
ranges=r,
1505+
version_id=self.version_id,
1506+
**self.s3_additional_kwargs,
15001507
)
1501-
)
1508+
for r in ranges
1509+
]
1510+
object_ = self._merge_objects([f.result() for f in as_completed(futures)])
15021511
else:
15031512
object_ = self.fs._get_object(
15041513
self.bucket,

0 commit comments

Comments (0)