44import logging
55import typing as t
66import sys
7- from collections .abc import Sequence
7+ from collections .abc import Iterable , MutableSequence
88from enum import Enum , auto
99from functools import reduce
1010from builtins import type as Type
5858DATETIME_ADD = (exp .DateAdd , exp .TimeAdd , exp .DatetimeAdd , exp .TsOrDsAdd , exp .TimestampAdd )
5959
6060if t .TYPE_CHECKING :
61- from sqlglot ._typing import B , E , F , GeneratorArgs , ParserArgs
61+ from sqlglot ._typing import B , E , F , GeneratorArgs , ParserArgs , BuilderArgs
6262 from typing_extensions import Unpack
6363
6464logger = logging .getLogger ("sqlglot" )
@@ -1398,7 +1398,7 @@ def array_concat_sql(
13981398 Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
13991399 """
14001400
1401- def _build_func_call (self : Generator , func_name : str , args : Sequence [ exp . Expr ] ) -> str :
1401+ def _build_func_call (self : Generator , func_name : str , args : BuilderArgs ) -> str :
14021402 """Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting."""
14031403 if self .ARRAY_CONCAT_IS_VAR_LEN :
14041404 return self .func (func_name , * args )
@@ -1535,8 +1535,8 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:
15351535
15361536
15371537def build_formatted_time (
1538- exp_class : Type [E ], dialect : str , default : t . Optional [ bool | str ] = None
1539- ) -> t .Callable [[list ], E ]:
1538+ exp_class : Type [E ], dialect : str , default : bool | str | None = None
1539+ ) -> t .Callable [[BuilderArgs ], E ]:
15401540 """Helper used for time expressions.
15411541
15421542 Args:
@@ -1548,7 +1548,7 @@ def build_formatted_time(
15481548 A callable that can be used to return the appropriately formatted time expression.
15491549 """
15501550
1551- def _builder (args : t . List ) :
1551+ def _builder (args : BuilderArgs ) -> E :
15521552 return exp_class (
15531553 this = seq_get (args , 0 ),
15541554 format = Dialect [dialect ].format_time (
@@ -1576,11 +1576,11 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) ->
15761576
15771577def build_date_delta (
15781578 exp_class : Type [E ],
1579- unit_mapping : t . Optional [ dict [str , str ]] = None ,
1580- default_unit : t . Optional [ str ] = "DAY" ,
1579+ unit_mapping : dict [str , str ] | None = None ,
1580+ default_unit : str | None = "DAY" ,
15811581 supports_timezone : bool = False ,
1582- ) -> t .Callable [[list ], E ]:
1583- def _builder (args : list ) -> E :
1582+ ) -> t .Callable [[BuilderArgs ], E ]:
1583+ def _builder (args : BuilderArgs ) -> E :
15841584 unit_based = len (args ) >= 3
15851585 has_timezone = len (args ) == 4
15861586 this = args [2 ] if unit_based else seq_get (args , 0 )
@@ -1598,8 +1598,8 @@ def _builder(args: list) -> E:
15981598
15991599def build_date_delta_with_interval (
16001600 expression_class : Type [E ],
1601- ) -> t .Callable [[list ], t .Optional [E ]]:
1602- def _builder (args : list ) -> t .Optional [E ]:
1601+ ) -> t .Callable [[BuilderArgs ], t .Optional [E ]]:
1602+ def _builder (args : BuilderArgs ) -> t .Optional [E ]:
16031603 if len (args ) < 2 :
16041604 return None
16051605
@@ -1613,7 +1613,7 @@ def _builder(args: list) -> t.Optional[E]:
16131613 return _builder
16141614
16151615
1616- def date_trunc_to_time (args : list ) -> exp .DateTrunc | exp .TimestampTrunc :
1616+ def date_trunc_to_time (args : BuilderArgs ) -> exp .DateTrunc | exp .TimestampTrunc :
16171617 unit = seq_get (args , 0 )
16181618 this = seq_get (args , 1 )
16191619
@@ -1808,8 +1808,8 @@ def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
18081808 )
18091809
18101810
1811- def pivot_column_names (aggregations : t . List [exp .Expr ], dialect : DialectType ) -> t . List [str ]:
1812- names = []
1811+ def pivot_column_names (aggregations : Iterable [exp .Expr ], dialect : DialectType ) -> list [str ]:
1812+ names : list [ str ] = []
18131813 for agg in aggregations :
18141814 if isinstance (agg , exp .Alias ):
18151815 names .append (agg .alias )
@@ -1832,17 +1832,17 @@ def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) ->
18321832 return names
18331833
18341834
1835- def binary_from_function (expr_type : Type [B ]) -> t .Callable [[list ], B ]:
1835+ def binary_from_function (expr_type : Type [B ]) -> t .Callable [[BuilderArgs ], B ]:
18361836 return lambda args : expr_type (this = seq_get (args , 0 ), expression = seq_get (args , 1 ))
18371837
18381838
18391839# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
1840- def build_timestamp_trunc (args : list ) -> exp .TimestampTrunc :
1840+ def build_timestamp_trunc (args : BuilderArgs ) -> exp .TimestampTrunc :
18411841 return exp .TimestampTrunc (this = seq_get (args , 1 ), unit = seq_get (args , 0 ))
18421842
18431843
18441844def build_trunc (
1845- args : t . List ,
1845+ args : BuilderArgs ,
18461846 dialect : DialectType ,
18471847 date_trunc_unabbreviate : bool = True ,
18481848 default_date_trunc_unit : t .Optional [str ] = None ,
@@ -1898,7 +1898,7 @@ def is_parse_json(expression: exp.Expr) -> bool:
18981898 )
18991899
19001900
1901- def isnull_to_is_null (args : t . List ) -> exp .Expr :
1901+ def isnull_to_is_null (args : BuilderArgs ) -> exp .Expr :
19021902 return exp .Paren (this = exp .Is (this = seq_get (args , 0 ), expression = exp .null ()))
19031903
19041904
@@ -2070,9 +2070,9 @@ def build_json_extract_path(
20702070 zero_based_indexing : bool = True ,
20712071 arrow_req_json_type : bool = False ,
20722072 json_type : t .Optional [str ] = None ,
2073- ) -> t .Callable [[t . List ], F ]:
2074- def _builder (args : t . List ) -> F :
2075- segments : t . List [exp .JSONPathPart ] = [exp .JSONPathRoot ()]
2073+ ) -> t .Callable [[MutableSequence [ t . Any ] ], F ]:
2074+ def _builder (args : MutableSequence [ t . Any ] ) -> F :
2075+ segments : list [exp .JSONPathPart ] = [exp .JSONPathRoot ()]
20762076 for arg in args [1 :]:
20772077 if not isinstance (arg , exp .Literal ):
20782078 # We use the fallback parser because we can't really transpile non-literals safely
@@ -2236,7 +2236,7 @@ def _builder(dtype: exp.DataType) -> exp.DataType:
22362236 return _builder
22372237
22382238
2239- def build_timestamp_from_parts (args : t . List ) -> exp .Func :
2239+ def build_timestamp_from_parts (args : BuilderArgs ) -> exp .Func :
22402240 if len (args ) == 2 :
22412241 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
22422242 # so we parse this into Anonymous for now instead of introducing complexity
@@ -2301,8 +2301,8 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateD
23012301 return self .func ("SEQUENCE" , start , end , step )
23022302
23032303
2304- def build_like (expr_type : Type [E ], not_like : bool = False ) -> t .Callable [[list ], exp .Expr ]:
2305- def _builder (args : t . List ) -> exp .Expr :
2304+ def build_like (expr_type : Type [E ], not_like : bool = False ) -> t .Callable [[BuilderArgs ], exp .Expr ]:
2305+ def _builder (args : BuilderArgs ) -> exp .Expr :
23062306 like_expr : exp .Expr = expr_type (this = seq_get (args , 0 ), expression = seq_get (args , 1 ))
23072307
23082308 if escape := seq_get (args , 2 ):
@@ -2316,8 +2316,8 @@ def _builder(args: t.List) -> exp.Expr:
23162316 return _builder
23172317
23182318
2319- def build_regexp_extract (expr_type : Type [E ]) -> t .Callable [[list , Dialect ], E ]:
2320- def _builder (args : t . List , dialect : Dialect ) -> E :
2319+ def build_regexp_extract (expr_type : Type [E ]) -> t .Callable [[BuilderArgs , Dialect ], E ]:
2320+ def _builder (args : BuilderArgs , dialect : Dialect ) -> E :
23212321 # The "position" argument specifies the index of the string character to start matching from.
23222322 # `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
23232323 # length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is
@@ -2399,8 +2399,8 @@ def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str:
23992399def groupconcat_sql (
24002400 self : Generator ,
24012401 expression : exp .GroupConcat ,
2402- func_name = "LISTAGG" ,
2403- sep : t . Optional [ str ] = "," ,
2402+ func_name : str = "LISTAGG" ,
2403+ sep : str | None = "," ,
24042404 within_group : bool = True ,
24052405 on_overflow : bool = False ,
24062406) -> str :
@@ -2443,7 +2443,9 @@ def groupconcat_sql(
24432443 return self .sql (listagg )
24442444
24452445
2446- def build_timetostr_or_tochar (args : t .List , dialect : DialectType ) -> exp .TimeToStr | exp .ToChar :
2446+ def build_timetostr_or_tochar (
2447+ args : BuilderArgs , dialect : DialectType
2448+ ) -> exp .TimeToStr | exp .ToChar :
24472449 if len (args ) == 2 :
24482450 this = args [0 ]
24492451 if not this .type :
@@ -2458,7 +2460,7 @@ def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToS
24582460 return exp .ToChar .from_arg_list (args )
24592461
24602462
2461- def build_replace_with_optional_replacement (args : t . List ) -> exp .Replace :
2463+ def build_replace_with_optional_replacement (args : BuilderArgs ) -> exp .Replace :
24622464 return exp .Replace (
24632465 this = seq_get (args , 0 ),
24642466 expression = seq_get (args , 1 ),
0 commit comments