1- # -*- coding: utf-8 -*-
21from __future__ import annotations
32
43import asyncio
54import logging
6- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Union , cast
5+ from typing import TYPE_CHECKING , Any , cast
76
87from pyathena .aio .common import WithAsyncFetch
98from pyathena .arrow .converter import (
1918 import polars as pl
2019 from pyarrow import Table
2120
22- _logger = logging .getLogger (__name__ ) # type: ignore
21+ _logger = logging .getLogger (__name__ )
2322
2423
2524class AioArrowCursor (WithAsyncFetch ):
@@ -37,19 +36,19 @@ class AioArrowCursor(WithAsyncFetch):
3736
3837 def __init__ (
3938 self ,
40- s3_staging_dir : Optional [ str ] = None ,
41- schema_name : Optional [ str ] = None ,
42- catalog_name : Optional [ str ] = None ,
43- work_group : Optional [ str ] = None ,
39+ s3_staging_dir : str | None = None ,
40+ schema_name : str | None = None ,
41+ catalog_name : str | None = None ,
42+ work_group : str | None = None ,
4443 poll_interval : float = 1 ,
45- encryption_option : Optional [ str ] = None ,
46- kms_key : Optional [ str ] = None ,
44+ encryption_option : str | None = None ,
45+ kms_key : str | None = None ,
4746 kill_on_interrupt : bool = True ,
4847 unload : bool = False ,
4948 result_reuse_enable : bool = False ,
5049 result_reuse_minutes : int = CursorIterator .DEFAULT_RESULT_REUSE_MINUTES ,
51- connect_timeout : Optional [ float ] = None ,
52- request_timeout : Optional [ float ] = None ,
50+ connect_timeout : float | None = None ,
51+ request_timeout : float | None = None ,
5352 ** kwargs ,
5453 ) -> None :
5554 super ().__init__ (
@@ -68,29 +67,29 @@ def __init__(
6867 self ._unload = unload
6968 self ._connect_timeout = connect_timeout
7069 self ._request_timeout = request_timeout
71- self ._result_set : Optional [ AthenaArrowResultSet ] = None
70+ self ._result_set : AthenaArrowResultSet | None = None
7271
7372 @staticmethod
7473 def get_default_converter (
7574 unload : bool = False ,
76- ) -> Union [ DefaultArrowTypeConverter , DefaultArrowUnloadTypeConverter , Any ] :
75+ ) -> DefaultArrowTypeConverter | DefaultArrowUnloadTypeConverter | Any :
7776 if unload :
7877 return DefaultArrowUnloadTypeConverter ()
7978 return DefaultArrowTypeConverter ()
8079
8180 async def execute ( # type: ignore[override]
8281 self ,
8382 operation : str ,
84- parameters : Optional [ Union [ Dict [ str , Any ], List [str ]]] = None ,
85- work_group : Optional [ str ] = None ,
86- s3_staging_dir : Optional [ str ] = None ,
87- cache_size : Optional [ int ] = 0 ,
88- cache_expiration_time : Optional [ int ] = 0 ,
89- result_reuse_enable : Optional [ bool ] = None ,
90- result_reuse_minutes : Optional [ int ] = None ,
91- paramstyle : Optional [ str ] = None ,
83+ parameters : dict [ str , Any ] | list [str ] | None = None ,
84+ work_group : str | None = None ,
85+ s3_staging_dir : str | None = None ,
86+ cache_size : int | None = 0 ,
87+ cache_expiration_time : int | None = 0 ,
88+ result_reuse_enable : bool | None = None ,
89+ result_reuse_minutes : int | None = None ,
90+ paramstyle : str | None = None ,
9291 ** kwargs ,
93- ) -> " AioArrowCursor" :
92+ ) -> AioArrowCursor :
9493 """Execute a SQL query asynchronously and return results as Arrow Tables.
9594
9695 Args:
@@ -143,7 +142,7 @@ async def execute( # type: ignore[override]
143142
144143 async def fetchone ( # type: ignore[override]
145144 self ,
146- ) -> Optional [ Union [ Tuple [ Optional [ Any ] , ...], Dict [Any , Optional [ Any ]]]] :
145+ ) -> tuple [ Any | None , ...] | dict [Any , Any | None ] | None :
147146 """Fetch the next row of the result set.
148147
149148 Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
@@ -161,8 +160,8 @@ async def fetchone( # type: ignore[override]
161160 return await asyncio .to_thread (result_set .fetchone )
162161
163162 async def fetchmany ( # type: ignore[override]
164- self , size : Optional [ int ] = None
165- ) -> List [ Union [ Tuple [ Optional [ Any ] , ...], Dict [Any , Optional [ Any ]] ]]:
163+ self , size : int | None = None
164+ ) -> list [ tuple [ Any | None , ...] | dict [Any , Any | None ]]:
166165 """Fetch multiple rows from the result set.
167166
168167 Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
@@ -184,7 +183,7 @@ async def fetchmany( # type: ignore[override]
184183
185184 async def fetchall ( # type: ignore[override]
186185 self ,
187- ) -> List [ Union [ Tuple [ Optional [ Any ] , ...], Dict [Any , Optional [ Any ]] ]]:
186+ ) -> list [ tuple [ Any | None , ...] | dict [Any , Any | None ]]:
188187 """Fetch all remaining rows from the result set.
189188
190189 Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
@@ -207,7 +206,7 @@ async def __anext__(self):
207206 raise StopAsyncIteration
208207 return row
209208
210- def as_arrow (self ) -> " Table" :
209+ def as_arrow (self ) -> Table :
211210 """Return query results as an Apache Arrow Table.
212211
213212 Returns:
@@ -218,7 +217,7 @@ def as_arrow(self) -> "Table":
218217 result_set = cast (AthenaArrowResultSet , self .result_set )
219218 return result_set .as_arrow ()
220219
221- def as_polars (self ) -> " pl.DataFrame" :
220+ def as_polars (self ) -> pl .DataFrame :
222221 """Return query results as a Polars DataFrame.
223222
224223 Returns:
0 commit comments