11# -*- coding: utf-8 -*-
22from __future__ import annotations
33
4- import itertools
54import logging
65import mimetypes
76import os .path
87import re
98from concurrent .futures import Future , as_completed
10- from concurrent .futures .thread import ThreadPoolExecutor
119from copy import deepcopy
1210from datetime import datetime
1311from multiprocessing import cpu_count
2321from fsspec .utils import tokenize
2422
2523import pyathena
24+ from pyathena .filesystem .s3_executor import S3Executor , S3ThreadPoolExecutor
2625from pyathena .filesystem .s3_object import (
2726 S3CompleteMultipartUpload ,
2827 S3MultipartUpload ,
@@ -686,6 +685,20 @@ def _delete_object(
686685 ** request ,
687686 )
688687
688+ def _create_executor (self , max_workers : int ) -> S3Executor :
689+ """Create the executor strategy used for parallel operations.
690+
691+ Returns an S3ThreadPoolExecutor by default; subclasses can override this
692+ to provide an alternative execution strategy (e.g., asyncio-based execution).
693+
694+ Args:
695+ max_workers: Maximum number of parallel workers, forwarded to the executor.
696+
697+ Returns:
698+ An S3Executor instance (an S3ThreadPoolExecutor unless overridden).
699+ """
700+ return S3ThreadPoolExecutor (max_workers = max_workers )
701+
689702 def _delete_objects (
690703 self , bucket : str , paths : List [str ], max_workers : Optional [int ] = None , ** kwargs
691704 ) -> None :
@@ -703,7 +716,7 @@ def _delete_objects(
703716 object_ .update ({"VersionId" : version_id })
704717 delete_objects .append (object_ )
705718
706- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
719+ with self . _create_executor (max_workers = max_workers ) as executor :
707720 fs = []
708721 for delete in [
709722 delete_objects [i : i + self .DELETE_OBJECTS_MAX_KEYS ]
@@ -861,7 +874,7 @@ def _copy_object_with_multipart_upload(
861874 ** kwargs ,
862875 )
863876 parts = []
864- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
877+ with self . _create_executor (max_workers = max_workers ) as executor :
865878 fs = [
866879 executor .submit (
867880 self ._upload_part_copy ,
@@ -1106,6 +1119,7 @@ def _open(
11061119 mode ,
11071120 version_id = None ,
11081121 max_workers = max_workers ,
1122+ executor = self ._create_executor (max_workers = max_workers ),
11091123 block_size = block_size ,
11101124 cache_type = cache_type ,
11111125 autocommit = autocommit ,
@@ -1256,6 +1270,7 @@ def __init__(
12561270 mode : str = "rb" ,
12571271 version_id : Optional [str ] = None ,
12581272 max_workers : int = (cpu_count () or 1 ) * 5 ,
1273+ executor : Optional [S3Executor ] = None ,
12591274 block_size : int = S3FileSystem .DEFAULT_BLOCK_SIZE ,
12601275 cache_type : str = "bytes" ,
12611276 autocommit : bool = True ,
@@ -1265,7 +1280,7 @@ def __init__(
12651280 ** kwargs ,
12661281 ) -> None :
12671282 self .max_workers = max_workers
1268- self ._executor = ThreadPoolExecutor (max_workers = max_workers )
1283+ self ._executor : S3Executor = executor or S3ThreadPoolExecutor (max_workers = max_workers )
12691284 self .s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}
12701285
12711286 super ().__init__ (
@@ -1481,24 +1496,18 @@ def _fetch_range(self, start: int, end: int) -> bytes:
14811496 start , end , max_workers = self .max_workers , worker_block_size = self .blocksize
14821497 )
14831498 if len (ranges ) > 1 :
1484- object_ = self ._merge_objects (
1485- list (
1486- self ._executor .map (
1487- lambda bucket , key , ranges , version_id , kwargs : self .fs ._get_object (
1488- bucket = bucket ,
1489- key = key ,
1490- ranges = ranges ,
1491- version_id = version_id ,
1492- ** kwargs ,
1493- ),
1494- itertools .repeat (self .bucket ),
1495- itertools .repeat (self .key ),
1496- ranges ,
1497- itertools .repeat (self .version_id ),
1498- itertools .repeat (self .s3_additional_kwargs ),
1499- )
1499+ futures = [
1500+ self ._executor .submit (
1501+ self .fs ._get_object ,
1502+ bucket = self .bucket ,
1503+ key = self .key ,
1504+ ranges = r ,
1505+ version_id = self .version_id ,
1506+ ** self .s3_additional_kwargs ,
15001507 )
1501- )
1508+ for r in ranges
1509+ ]
1510+ object_ = self ._merge_objects ([f .result () for f in as_completed (futures )])
15021511 else :
15031512 object_ = self .fs ._get_object (
15041513 self .bucket ,
0 commit comments