11# -*- coding: utf-8 -*-
22from __future__ import annotations
33
4- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type , Union , cast
4+ from typing import TYPE_CHECKING , Any , Dict , List , Mapping , Optional , Tuple , Union , cast
55
66from sqlalchemy import exc , types , util
77from sqlalchemy .sql .compiler import (
3131 UniqueConstraint ,
3232 )
3333 from sqlalchemy .sql .ddl import CreateTable
34- from sqlalchemy .sql .elements import FunctionElement
34+ from sqlalchemy .sql .functions import Function
3535 from sqlalchemy .sql .selectable import GenerativeSelect
3636
3737 from pyathena .sqlalchemy .base import AthenaDialect
3838
39- _DialectArgDict = Dict [str , Any ]
39+ _DialectArgDict = Mapping [str , Any ]
4040 CreateColumn = Any
4141
4242
@@ -61,10 +61,10 @@ class AthenaTypeCompiler(GenericTypeCompiler):
6161 https://docs.aws.amazon.com/athena/latest/ug/data-types.html
6262 """
6363
64- def visit_FLOAT (self , type_ : Type [Any ], ** kw ) -> str : # noqa: N802
65- return self .visit_REAL (type_ , ** kw )
64+ def visit_FLOAT (self , type_ : types . Float [Any ], ** kw : Any ) -> str : # noqa: N802
65+ return self .visit_REAL (type_ , ** kw ) # type: ignore[arg-type]
6666
67- def visit_REAL (self , type_ : Type [Any ], ** kw ) -> str : # noqa: N802
67+ def visit_REAL (self , type_ : types . REAL [Any ], ** kw : Any ) -> str : # noqa: N802
6868 return "FLOAT"
6969
7070 def visit_DOUBLE (self , type_ , ** kw ) -> str : # noqa: N802
@@ -73,78 +73,78 @@ def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802
7373 def visit_DOUBLE_PRECISION (self , type_ , ** kw ) -> str : # noqa: N802
7474 return "DOUBLE"
7575
76- def visit_NUMERIC (self , type_ : Type [Any ], ** kw ) -> str : # noqa: N802
77- return self .visit_DECIMAL (type_ , ** kw )
76+ def visit_NUMERIC (self , type_ : types . Numeric [Any ], ** kw : Any ) -> str : # noqa: N802
77+ return self .visit_DECIMAL (type_ , ** kw ) # type: ignore[arg-type]
7878
79- def visit_DECIMAL (self , type_ : Type [Any ], ** kw ) -> str : # noqa: N802
79+ def visit_DECIMAL (self , type_ : types . DECIMAL [Any ], ** kw : Any ) -> str : # noqa: N802
8080 if type_ .precision is None :
8181 return "DECIMAL"
8282 if type_ .scale is None :
8383 return f"DECIMAL({ type_ .precision } )"
8484 return f"DECIMAL({ type_ .precision } , { type_ .scale } )"
8585
86- def visit_TINYINT (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
86+ def visit_TINYINT (self , type_ : types . Integer , ** kw : Any ) -> str : # noqa: N802
8787 return "TINYINT"
8888
89- def visit_INTEGER (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
89+ def visit_INTEGER (self , type_ : types . Integer , ** kw : Any ) -> str : # noqa: N802
9090 return "INTEGER"
9191
92- def visit_SMALLINT (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
92+ def visit_SMALLINT (self , type_ : types . SmallInteger , ** kw : Any ) -> str : # noqa: N802
9393 return "SMALLINT"
9494
95- def visit_BIGINT (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
95+ def visit_BIGINT (self , type_ : types . BigInteger , ** kw : Any ) -> str : # noqa: N802
9696 return "BIGINT"
9797
98- def visit_TIMESTAMP (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
98+ def visit_TIMESTAMP (self , type_ : types . TIMESTAMP , ** kw : Any ) -> str : # noqa: N802
9999 return "TIMESTAMP"
100100
101- def visit_DATETIME (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
102- return self .visit_TIMESTAMP (type_ , ** kw )
101+ def visit_DATETIME (self , type_ : types . DateTime , ** kw : Any ) -> str : # noqa: N802
102+ return self .visit_TIMESTAMP (type_ , ** kw ) # type: ignore[arg-type]
103103
104- def visit_DATE (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
104+ def visit_DATE (self , type_ : types . Date , ** kw : Any ) -> str : # noqa: N802
105105 return "DATE"
106106
107- def visit_TIME (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
107+ def visit_TIME (self , type_ : types . Time , ** kw : Any ) -> str : # noqa: N802
108108 raise exc .CompileError (f"Data type `{ type_ } ` is not supported" )
109109
110- def visit_CLOB (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
111- return self .visit_BINARY (type_ , ** kw )
110+ def visit_CLOB (self , type_ : types . CLOB , ** kw : Any ) -> str : # noqa: N802
111+ return self .visit_BINARY (type_ , ** kw ) # type: ignore[arg-type]
112112
113- def visit_NCLOB (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
114- return self .visit_BINARY (type_ , ** kw )
113+ def visit_NCLOB (self , type_ : types . Text , ** kw : Any ) -> str : # noqa: N802
114+ return self .visit_BINARY (type_ , ** kw ) # type: ignore[arg-type]
115115
116- def visit_CHAR (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
116+ def visit_CHAR (self , type_ : types . CHAR , ** kw : Any ) -> str : # noqa: N802
117117 if type_ .length :
118- return cast ( str , self ._render_string_type (type_ , "CHAR" ) )
118+ return self ._render_string_type ("CHAR" , type_ . length , type_ . collation )
119119 return "STRING"
120120
121- def visit_NCHAR (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
122- return self .visit_CHAR (type_ , ** kw )
121+ def visit_NCHAR (self , type_ : types . NCHAR , ** kw : Any ) -> str : # noqa: N802
122+ return self .visit_CHAR (type_ , ** kw ) # type: ignore[arg-type]
123123
124- def visit_VARCHAR (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
124+ def visit_VARCHAR (self , type_ : types . String , ** kw : Any ) -> str : # noqa: N802
125125 if type_ .length :
126- return cast ( str , self ._render_string_type (type_ , "VARCHAR" ) )
126+ return self ._render_string_type ("VARCHAR" , type_ . length , type_ . collation )
127127 return "STRING"
128128
129- def visit_NVARCHAR (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
130- return self .visit_VARCHAR (type_ , ** kw )
129+ def visit_NVARCHAR (self , type_ : types . NVARCHAR , ** kw : Any ) -> str : # noqa: N802
130+ return self .visit_VARCHAR (type_ , ** kw ) # type: ignore[arg-type]
131131
132- def visit_TEXT (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
132+ def visit_TEXT (self , type_ : types . Text , ** kw : Any ) -> str : # noqa: N802
133133 return "STRING"
134134
135- def visit_BLOB (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
136- return self .visit_BINARY (type_ , ** kw )
135+ def visit_BLOB (self , type_ : types . LargeBinary , ** kw : Any ) -> str : # noqa: N802
136+ return self .visit_BINARY (type_ , ** kw ) # type: ignore[arg-type]
137137
138- def visit_BINARY (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
138+ def visit_BINARY (self , type_ : types . BINARY , ** kw : Any ) -> str : # noqa: N802
139139 return "BINARY"
140140
141- def visit_VARBINARY (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
142- return self .visit_BINARY (type_ , ** kw )
141+ def visit_VARBINARY (self , type_ : types . VARBINARY , ** kw : Any ) -> str : # noqa: N802
142+ return self .visit_BINARY (type_ , ** kw ) # type: ignore[arg-type]
143143
144- def visit_BOOLEAN (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
144+ def visit_BOOLEAN (self , type_ : types . Boolean , ** kw : Any ) -> str : # noqa: N802
145145 return "BOOLEAN"
146146
147- def visit_JSON (self , type_ : Type [ Any ] , ** kw ) -> str : # noqa: N802
147+ def visit_JSON (self , type_ : types . JSON , ** kw : Any ) -> str : # noqa: N802
148148 return "JSON"
149149
150150 def visit_string (self , type_ , ** kw ): # noqa: N802
@@ -219,10 +219,10 @@ class AthenaStatementCompiler(SQLCompiler):
219219 https://docs.aws.amazon.com/athena/latest/ug/ddl-sql-reference.html
220220 """
221221
222- def visit_char_length_func (self , fn : "FunctionElement [Any]" , ** kw ) :
222+ def visit_char_length_func (self , fn : "Function [Any]" , ** kw : Any ) -> str :
223223 return f"length{ self .function_argspec (fn , ** kw )} "
224224
225- def visit_filter_func (self , fn : "FunctionElement [Any]" , ** kw ) -> str :
225+ def visit_filter_func (self , fn : "Function [Any]" , ** kw : Any ) -> str :
226226 """Compile Athena filter() function with lambda expressions.
227227
228228 Supports syntax: filter(array_expr, lambda_expr)
@@ -370,7 +370,7 @@ def _get_comment_specification(self, comment: str) -> str:
370370 return f"COMMENT { self ._escape_comment (comment )} "
371371
372372 def _get_bucket_count (
373- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
373+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
374374 ) -> Optional [str ]:
375375 if dialect_opts ["bucket_count" ]:
376376 bucket_count = dialect_opts ["bucket_count" ]
@@ -381,7 +381,7 @@ def _get_bucket_count(
381381 return cast (str , bucket_count ) if bucket_count is not None else None
382382
383383 def _get_file_format (
384- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
384+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
385385 ) -> Optional [str ]:
386386 if dialect_opts ["file_format" ]:
387387 file_format = dialect_opts ["file_format" ]
@@ -392,7 +392,7 @@ def _get_file_format(
392392 return cast (Optional [str ], file_format )
393393
394394 def _get_file_format_specification (
395- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
395+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
396396 ) -> str :
397397 file_format = self ._get_file_format (dialect_opts , connect_opts )
398398 text = []
@@ -401,7 +401,7 @@ def _get_file_format_specification(
401401 return "\n " .join (text )
402402
403403 def _get_row_format (
404- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
404+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
405405 ) -> Optional [str ]:
406406 if dialect_opts ["row_format" ]:
407407 row_format = dialect_opts ["row_format" ]
@@ -412,7 +412,7 @@ def _get_row_format(
412412 return cast (Optional [str ], row_format )
413413
414414 def _get_row_format_specification (
415- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
415+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
416416 ) -> str :
417417 row_format = self ._get_row_format (dialect_opts , connect_opts )
418418 text = []
@@ -421,7 +421,7 @@ def _get_row_format_specification(
421421 return "\n " .join (text )
422422
423423 def _get_serde_properties (
424- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
424+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
425425 ) -> Optional [Union [str , Dict [str , Any ]]]:
426426 if dialect_opts ["serdeproperties" ]:
427427 serde_properties = dialect_opts ["serdeproperties" ]
@@ -432,7 +432,7 @@ def _get_serde_properties(
432432 return cast (Optional [str ], serde_properties )
433433
434434 def _get_serde_properties_specification (
435- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
435+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
436436 ) -> str :
437437 serde_properties = self ._get_serde_properties (dialect_opts , connect_opts )
438438 text = []
@@ -446,7 +446,7 @@ def _get_serde_properties_specification(
446446 return "\n " .join (text )
447447
448448 def _get_table_location (
449- self , table : "Table" , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
449+ self , table : "Table" , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
450450 ) -> Optional [str ]:
451451 if dialect_opts ["location" ]:
452452 location = cast (str , dialect_opts ["location" ])
@@ -464,7 +464,7 @@ def _get_table_location(
464464 return location
465465
466466 def _get_table_location_specification (
467- self , table : "Table" , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
467+ self , table : "Table" , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
468468 ) -> str :
469469 location = self ._get_table_location (table , dialect_opts , connect_opts )
470470 text = []
@@ -482,7 +482,7 @@ def _get_table_location_specification(
482482 return "\n " .join (text )
483483
484484 def _get_table_properties (
485- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
485+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
486486 ) -> Optional [Union [Dict [str , str ], str ]]:
487487 if dialect_opts ["tblproperties" ]:
488488 table_properties = cast (str , dialect_opts ["tblproperties" ])
@@ -493,7 +493,7 @@ def _get_table_properties(
493493 return table_properties
494494
495495 def _get_compression (
496- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
496+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
497497 ) -> Optional [str ]:
498498 if dialect_opts ["compression" ]:
499499 compression = cast (str , dialect_opts ["compression" ])
@@ -504,7 +504,7 @@ def _get_compression(
504504 return compression
505505
506506 def _get_table_properties_specification (
507- self , dialect_opts : "_DialectArgDict" , connect_opts : Dict [str , Any ]
507+ self , dialect_opts : "_DialectArgDict" , connect_opts : Mapping [str , Any ]
508508 ) -> str :
509509 properties = self ._get_table_properties (dialect_opts , connect_opts )
510510 if properties :
@@ -554,34 +554,30 @@ def get_column_specification(self, column: "Column[Any]", **kwargs) -> str:
554554 text .append (f"{ self ._get_comment_specification (column .comment )} " )
555555 return " " .join (text )
556556
557- def visit_check_constraint (self , constraint : "CheckConstraint" , ** kw ) -> Optional [ str ] :
558- return None
557+ def visit_check_constraint (self , constraint : "CheckConstraint" , ** kw : Any ) -> str :
558+ return ""
559559
560- def visit_column_check_constraint (self , constraint : "CheckConstraint" , ** kw ) -> Optional [ str ] :
561- return None
560+ def visit_column_check_constraint (self , constraint : "CheckConstraint" , ** kw : Any ) -> str :
561+ return ""
562562
563- def visit_foreign_key_constraint (
564- self , constraint : "ForeignKeyConstraint" , ** kw
565- ) -> Optional [str ]:
566- return None
563+ def visit_foreign_key_constraint (self , constraint : "ForeignKeyConstraint" , ** kw : Any ) -> str :
564+ return ""
567565
568- def visit_primary_key_constraint (
569- self , constraint : "PrimaryKeyConstraint" , ** kw
570- ) -> Optional [str ]:
571- return None
566+ def visit_primary_key_constraint (self , constraint : "PrimaryKeyConstraint" , ** kw : Any ) -> str :
567+ return ""
572568
573- def visit_unique_constraint (self , constraint : "UniqueConstraint" , ** kw ) -> Optional [ str ] :
574- return None
569+ def visit_unique_constraint (self , constraint : "UniqueConstraint" , ** kw : Any ) -> str :
570+ return ""
575571
576- def _get_connect_option_partitions (self , connect_opts : Dict [str , Any ]) -> List [str ]:
572+ def _get_connect_option_partitions (self , connect_opts : Mapping [str , Any ]) -> List [str ]:
577573 if connect_opts :
578574 partition = cast (str , connect_opts .get ("partition" ))
579575 partitions = partition .split ("," ) if partition else []
580576 else :
581577 partitions = []
582578 return partitions
583579
584- def _get_connect_option_buckets (self , connect_opts : Dict [str , Any ]) -> List [str ]:
580+ def _get_connect_option_buckets (self , connect_opts : Mapping [str , Any ]) -> List [str ]:
585581 if connect_opts :
586582 bucket = cast (str , connect_opts .get ("cluster" ))
587583 buckets = bucket .split ("," ) if bucket else []
@@ -624,7 +620,7 @@ def _prepared_columns(
624620 table : "Table" ,
625621 is_iceberg : bool ,
626622 create_columns : List ["CreateColumn" ],
627- connect_opts : Dict [str , Any ],
623+ connect_opts : Mapping [str , Any ],
628624 ) -> Tuple [List [str ], List [str ], List [str ]]:
629625 columns , partitions , buckets = [], [], []
630626 conn_partitions = self ._get_connect_option_partitions (connect_opts )
0 commit comments