Skip to content

Commit 52dfba6

Browse files
Merge pull request #681 from pyathena-dev/fix/680-mypy-sqlalchemy-2.0.46
Fix mypy errors with SQLAlchemy 2.0.46+
2 parents d876865 + b9c326a commit 52dfba6

File tree

4 files changed

+119
-110
lines changed

4 files changed

+119
-110
lines changed

pyathena/aio/sqlalchemy/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def fetchall(self) -> Any:
9393
def setinputsizes(self, sizes: Any) -> None:
9494
self._cursor.setinputsizes(sizes)
9595

96+
async def _async_soft_close(self) -> None:
97+
return
98+
9699
# PyAthena-specific methods used by AthenaDialect reflection
97100
def list_databases(self, *args: Any, **kwargs: Any) -> Any:
98101
return await_only(self._cursor.list_databases(*args, **kwargs))
@@ -122,11 +125,11 @@ class AsyncAdapt_pyathena_connection(AdaptedConnection): # noqa: N801 - follows
122125

123126
def __init__(self, dbapi: "AsyncAdapt_pyathena_dbapi", connection: AioConnection) -> None:
124127
self.dbapi = dbapi
125-
self._connection = connection
128+
self._connection = connection # type: ignore[assignment]
126129

127130
@property
128131
def driver_connection(self) -> AioConnection:
129-
return self._connection # type: ignore[no-any-return]
132+
return self._connection # type: ignore[return-value]
130133

131134
@property
132135
def catalog_name(self) -> Optional[str]:
@@ -144,7 +147,7 @@ def close(self) -> None:
144147
self._connection.close()
145148

146149
def commit(self) -> None:
147-
self._connection.commit()
150+
self._connection.commit() # type: ignore[unused-coroutine]
148151

149152
def rollback(self) -> None:
150153
pass

pyathena/sqlalchemy/compiler.py

Lines changed: 65 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
4+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
55

66
from sqlalchemy import exc, types, util
77
from sqlalchemy.sql.compiler import (
@@ -31,12 +31,12 @@
3131
UniqueConstraint,
3232
)
3333
from sqlalchemy.sql.ddl import CreateTable
34-
from sqlalchemy.sql.elements import FunctionElement
34+
from sqlalchemy.sql.functions import Function
3535
from sqlalchemy.sql.selectable import GenerativeSelect
3636

3737
from pyathena.sqlalchemy.base import AthenaDialect
3838

39-
_DialectArgDict = Dict[str, Any]
39+
_DialectArgDict = Mapping[str, Any]
4040
CreateColumn = Any
4141

4242

@@ -61,10 +61,10 @@ class AthenaTypeCompiler(GenericTypeCompiler):
6161
https://docs.aws.amazon.com/athena/latest/ug/data-types.html
6262
"""
6363

64-
def visit_FLOAT(self, type_: Type[Any], **kw) -> str: # noqa: N802
65-
return self.visit_REAL(type_, **kw)
64+
def visit_FLOAT(self, type_: types.Float[Any], **kw: Any) -> str: # noqa: N802
65+
return self.visit_REAL(type_, **kw) # type: ignore[arg-type]
6666

67-
def visit_REAL(self, type_: Type[Any], **kw) -> str: # noqa: N802
67+
def visit_REAL(self, type_: types.REAL[Any], **kw: Any) -> str: # noqa: N802
6868
return "FLOAT"
6969

7070
def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802
@@ -73,78 +73,78 @@ def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802
7373
def visit_DOUBLE_PRECISION(self, type_, **kw) -> str: # noqa: N802
7474
return "DOUBLE"
7575

76-
def visit_NUMERIC(self, type_: Type[Any], **kw) -> str: # noqa: N802
77-
return self.visit_DECIMAL(type_, **kw)
76+
def visit_NUMERIC(self, type_: types.Numeric[Any], **kw: Any) -> str: # noqa: N802
77+
return self.visit_DECIMAL(type_, **kw) # type: ignore[arg-type]
7878

79-
def visit_DECIMAL(self, type_: Type[Any], **kw) -> str: # noqa: N802
79+
def visit_DECIMAL(self, type_: types.DECIMAL[Any], **kw: Any) -> str: # noqa: N802
8080
if type_.precision is None:
8181
return "DECIMAL"
8282
if type_.scale is None:
8383
return f"DECIMAL({type_.precision})"
8484
return f"DECIMAL({type_.precision}, {type_.scale})"
8585

86-
def visit_TINYINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
86+
def visit_TINYINT(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802
8787
return "TINYINT"
8888

89-
def visit_INTEGER(self, type_: Type[Any], **kw) -> str: # noqa: N802
89+
def visit_INTEGER(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802
9090
return "INTEGER"
9191

92-
def visit_SMALLINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
92+
def visit_SMALLINT(self, type_: types.SmallInteger, **kw: Any) -> str: # noqa: N802
9393
return "SMALLINT"
9494

95-
def visit_BIGINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
95+
def visit_BIGINT(self, type_: types.BigInteger, **kw: Any) -> str: # noqa: N802
9696
return "BIGINT"
9797

98-
def visit_TIMESTAMP(self, type_: Type[Any], **kw) -> str: # noqa: N802
98+
def visit_TIMESTAMP(self, type_: types.TIMESTAMP, **kw: Any) -> str: # noqa: N802
9999
return "TIMESTAMP"
100100

101-
def visit_DATETIME(self, type_: Type[Any], **kw) -> str: # noqa: N802
102-
return self.visit_TIMESTAMP(type_, **kw)
101+
def visit_DATETIME(self, type_: types.DateTime, **kw: Any) -> str: # noqa: N802
102+
return self.visit_TIMESTAMP(type_, **kw) # type: ignore[arg-type]
103103

104-
def visit_DATE(self, type_: Type[Any], **kw) -> str: # noqa: N802
104+
def visit_DATE(self, type_: types.Date, **kw: Any) -> str: # noqa: N802
105105
return "DATE"
106106

107-
def visit_TIME(self, type_: Type[Any], **kw) -> str: # noqa: N802
107+
def visit_TIME(self, type_: types.Time, **kw: Any) -> str: # noqa: N802
108108
raise exc.CompileError(f"Data type `{type_}` is not supported")
109109

110-
def visit_CLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
111-
return self.visit_BINARY(type_, **kw)
110+
def visit_CLOB(self, type_: types.CLOB, **kw: Any) -> str: # noqa: N802
111+
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]
112112

113-
def visit_NCLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
114-
return self.visit_BINARY(type_, **kw)
113+
def visit_NCLOB(self, type_: types.Text, **kw: Any) -> str: # noqa: N802
114+
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]
115115

116-
def visit_CHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
116+
def visit_CHAR(self, type_: types.CHAR, **kw: Any) -> str: # noqa: N802
117117
if type_.length:
118-
return cast(str, self._render_string_type(type_, "CHAR"))
118+
return self._render_string_type("CHAR", type_.length, type_.collation)
119119
return "STRING"
120120

121-
def visit_NCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
122-
return self.visit_CHAR(type_, **kw)
121+
def visit_NCHAR(self, type_: types.NCHAR, **kw: Any) -> str: # noqa: N802
122+
return self.visit_CHAR(type_, **kw) # type: ignore[arg-type]
123123

124-
def visit_VARCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
124+
def visit_VARCHAR(self, type_: types.String, **kw: Any) -> str: # noqa: N802
125125
if type_.length:
126-
return cast(str, self._render_string_type(type_, "VARCHAR"))
126+
return self._render_string_type("VARCHAR", type_.length, type_.collation)
127127
return "STRING"
128128

129-
def visit_NVARCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
130-
return self.visit_VARCHAR(type_, **kw)
129+
def visit_NVARCHAR(self, type_: types.NVARCHAR, **kw: Any) -> str: # noqa: N802
130+
return self.visit_VARCHAR(type_, **kw) # type: ignore[arg-type]
131131

132-
def visit_TEXT(self, type_: Type[Any], **kw) -> str: # noqa: N802
132+
def visit_TEXT(self, type_: types.Text, **kw: Any) -> str: # noqa: N802
133133
return "STRING"
134134

135-
def visit_BLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
136-
return self.visit_BINARY(type_, **kw)
135+
def visit_BLOB(self, type_: types.LargeBinary, **kw: Any) -> str: # noqa: N802
136+
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]
137137

138-
def visit_BINARY(self, type_: Type[Any], **kw) -> str: # noqa: N802
138+
def visit_BINARY(self, type_: types.BINARY, **kw: Any) -> str: # noqa: N802
139139
return "BINARY"
140140

141-
def visit_VARBINARY(self, type_: Type[Any], **kw) -> str: # noqa: N802
142-
return self.visit_BINARY(type_, **kw)
141+
def visit_VARBINARY(self, type_: types.VARBINARY, **kw: Any) -> str: # noqa: N802
142+
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]
143143

144-
def visit_BOOLEAN(self, type_: Type[Any], **kw) -> str: # noqa: N802
144+
def visit_BOOLEAN(self, type_: types.Boolean, **kw: Any) -> str: # noqa: N802
145145
return "BOOLEAN"
146146

147-
def visit_JSON(self, type_: Type[Any], **kw) -> str: # noqa: N802
147+
def visit_JSON(self, type_: types.JSON, **kw: Any) -> str: # noqa: N802
148148
return "JSON"
149149

150150
def visit_string(self, type_, **kw): # noqa: N802
@@ -219,10 +219,10 @@ class AthenaStatementCompiler(SQLCompiler):
219219
https://docs.aws.amazon.com/athena/latest/ug/ddl-sql-reference.html
220220
"""
221221

222-
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):
222+
def visit_char_length_func(self, fn: "Function[Any]", **kw: Any) -> str:
223223
return f"length{self.function_argspec(fn, **kw)}"
224224

225-
def visit_filter_func(self, fn: "FunctionElement[Any]", **kw) -> str:
225+
def visit_filter_func(self, fn: "Function[Any]", **kw: Any) -> str:
226226
"""Compile Athena filter() function with lambda expressions.
227227
228228
Supports syntax: filter(array_expr, lambda_expr)
@@ -370,7 +370,7 @@ def _get_comment_specification(self, comment: str) -> str:
370370
return f"COMMENT {self._escape_comment(comment)}"
371371

372372
def _get_bucket_count(
373-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
373+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
374374
) -> Optional[str]:
375375
if dialect_opts["bucket_count"]:
376376
bucket_count = dialect_opts["bucket_count"]
@@ -381,7 +381,7 @@ def _get_bucket_count(
381381
return cast(str, bucket_count) if bucket_count is not None else None
382382

383383
def _get_file_format(
384-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
384+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
385385
) -> Optional[str]:
386386
if dialect_opts["file_format"]:
387387
file_format = dialect_opts["file_format"]
@@ -392,7 +392,7 @@ def _get_file_format(
392392
return cast(Optional[str], file_format)
393393

394394
def _get_file_format_specification(
395-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
395+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
396396
) -> str:
397397
file_format = self._get_file_format(dialect_opts, connect_opts)
398398
text = []
@@ -401,7 +401,7 @@ def _get_file_format_specification(
401401
return "\n".join(text)
402402

403403
def _get_row_format(
404-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
404+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
405405
) -> Optional[str]:
406406
if dialect_opts["row_format"]:
407407
row_format = dialect_opts["row_format"]
@@ -412,7 +412,7 @@ def _get_row_format(
412412
return cast(Optional[str], row_format)
413413

414414
def _get_row_format_specification(
415-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
415+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
416416
) -> str:
417417
row_format = self._get_row_format(dialect_opts, connect_opts)
418418
text = []
@@ -421,7 +421,7 @@ def _get_row_format_specification(
421421
return "\n".join(text)
422422

423423
def _get_serde_properties(
424-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
424+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
425425
) -> Optional[Union[str, Dict[str, Any]]]:
426426
if dialect_opts["serdeproperties"]:
427427
serde_properties = dialect_opts["serdeproperties"]
@@ -432,7 +432,7 @@ def _get_serde_properties(
432432
return cast(Optional[str], serde_properties)
433433

434434
def _get_serde_properties_specification(
435-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
435+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
436436
) -> str:
437437
serde_properties = self._get_serde_properties(dialect_opts, connect_opts)
438438
text = []
@@ -446,7 +446,7 @@ def _get_serde_properties_specification(
446446
return "\n".join(text)
447447

448448
def _get_table_location(
449-
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
449+
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
450450
) -> Optional[str]:
451451
if dialect_opts["location"]:
452452
location = cast(str, dialect_opts["location"])
@@ -464,7 +464,7 @@ def _get_table_location(
464464
return location
465465

466466
def _get_table_location_specification(
467-
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
467+
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
468468
) -> str:
469469
location = self._get_table_location(table, dialect_opts, connect_opts)
470470
text = []
@@ -482,7 +482,7 @@ def _get_table_location_specification(
482482
return "\n".join(text)
483483

484484
def _get_table_properties(
485-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
485+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
486486
) -> Optional[Union[Dict[str, str], str]]:
487487
if dialect_opts["tblproperties"]:
488488
table_properties = cast(str, dialect_opts["tblproperties"])
@@ -493,7 +493,7 @@ def _get_table_properties(
493493
return table_properties
494494

495495
def _get_compression(
496-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
496+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
497497
) -> Optional[str]:
498498
if dialect_opts["compression"]:
499499
compression = cast(str, dialect_opts["compression"])
@@ -504,7 +504,7 @@ def _get_compression(
504504
return compression
505505

506506
def _get_table_properties_specification(
507-
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
507+
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
508508
) -> str:
509509
properties = self._get_table_properties(dialect_opts, connect_opts)
510510
if properties:
@@ -554,34 +554,30 @@ def get_column_specification(self, column: "Column[Any]", **kwargs) -> str:
554554
text.append(f"{self._get_comment_specification(column.comment)}")
555555
return " ".join(text)
556556

557-
def visit_check_constraint(self, constraint: "CheckConstraint", **kw) -> Optional[str]:
558-
return None
557+
def visit_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str:
558+
return ""
559559

560-
def visit_column_check_constraint(self, constraint: "CheckConstraint", **kw) -> Optional[str]:
561-
return None
560+
def visit_column_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str:
561+
return ""
562562

563-
def visit_foreign_key_constraint(
564-
self, constraint: "ForeignKeyConstraint", **kw
565-
) -> Optional[str]:
566-
return None
563+
def visit_foreign_key_constraint(self, constraint: "ForeignKeyConstraint", **kw: Any) -> str:
564+
return ""
567565

568-
def visit_primary_key_constraint(
569-
self, constraint: "PrimaryKeyConstraint", **kw
570-
) -> Optional[str]:
571-
return None
566+
def visit_primary_key_constraint(self, constraint: "PrimaryKeyConstraint", **kw: Any) -> str:
567+
return ""
572568

573-
def visit_unique_constraint(self, constraint: "UniqueConstraint", **kw) -> Optional[str]:
574-
return None
569+
def visit_unique_constraint(self, constraint: "UniqueConstraint", **kw: Any) -> str:
570+
return ""
575571

576-
def _get_connect_option_partitions(self, connect_opts: Dict[str, Any]) -> List[str]:
572+
def _get_connect_option_partitions(self, connect_opts: Mapping[str, Any]) -> List[str]:
577573
if connect_opts:
578574
partition = cast(str, connect_opts.get("partition"))
579575
partitions = partition.split(",") if partition else []
580576
else:
581577
partitions = []
582578
return partitions
583579

584-
def _get_connect_option_buckets(self, connect_opts: Dict[str, Any]) -> List[str]:
580+
def _get_connect_option_buckets(self, connect_opts: Mapping[str, Any]) -> List[str]:
585581
if connect_opts:
586582
bucket = cast(str, connect_opts.get("cluster"))
587583
buckets = bucket.split(",") if bucket else []
@@ -624,7 +620,7 @@ def _prepared_columns(
624620
table: "Table",
625621
is_iceberg: bool,
626622
create_columns: List["CreateColumn"],
627-
connect_opts: Dict[str, Any],
623+
connect_opts: Mapping[str, Any],
628624
) -> Tuple[List[str], List[str], List[str]]:
629625
columns, partitions, buckets = [], [], []
630626
conn_partitions = self._get_connect_option_partitions(connect_opts)

tests/sqlalchemy/test_suite.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest
1010

1111
del BinaryTest # noqa
12-
del BizarroCharacterFKResolutionTest # noqa
1312
del ComponentReflectionTest # noqa
1413
del ComponentReflectionTestExtra # noqa
1514
del CompositeKeyReflectionTest # noqa

0 commit comments

Comments
 (0)