Skip to content

Commit 4deba62

Browse files
Refactor(typing): add annotations in dialect/parser/schema modules (#7444)
* annotate parser module functions
* refactor (typing): annotate `dialects/dialect` module functions
* refactor (typing): annotate schema module functions and `MappingSchema` methods
* Refactor(typing): using Sequence to narrow the internal type of `build_timetostr_or_tochar`
* refactor (typing): widen the containers types of dialect functions arguments
* refactor (typing): widen the containers types of parser functions arguments + format dialect module
* refactor: Use `Sequence[Any]` for most builder functions, and define a centralized type alias for it.
* fix (typing): we need runtime values for type aliases for Python <3.10
1 parent cec9b5d commit 4deba62

4 files changed

Lines changed: 95 additions & 76 deletions

File tree

sqlglot/_typing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from collections.abc import Mapping, Sequence
45

56
if t.TYPE_CHECKING:
67
from typing_extensions import ParamSpec
7-
from collections.abc import Mapping
8+
89
import sqlglot
910
from sqlglot.dialects.dialect import DialectType
1011
from sqlglot.errors import ErrorLevel
@@ -16,6 +17,9 @@
1617
F = t.TypeVar("F", bound="sqlglot.exp.Func")
1718
T = t.TypeVar("T")
1819

20+
BuilderArgs = Sequence[t.Any]
21+
"""Sequence of arguments passed to builder functions."""
22+
1923

2024
class _DialectArg(t.TypedDict, total=False):
2125
dialect: DialectType

sqlglot/dialects/dialect.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import typing as t
66
import sys
7-
from collections.abc import Sequence
7+
from collections.abc import Iterable, MutableSequence
88
from enum import Enum, auto
99
from functools import reduce
1010
from builtins import type as Type
@@ -58,7 +58,7 @@
5858
DATETIME_ADD = (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.TimestampAdd)
5959

6060
if t.TYPE_CHECKING:
61-
from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs
61+
from sqlglot._typing import B, E, F, GeneratorArgs, ParserArgs, BuilderArgs
6262
from typing_extensions import Unpack
6363

6464
logger = logging.getLogger("sqlglot")
@@ -1398,7 +1398,7 @@ def array_concat_sql(
13981398
Dialects that propagate NULLs need to set `ARRAY_FUNCS_PROPAGATES_NULLS` to True.
13991399
"""
14001400

1401-
def _build_func_call(self: Generator, func_name: str, args: Sequence[exp.Expr]) -> str:
1401+
def _build_func_call(self: Generator, func_name: str, args: BuilderArgs) -> str:
14021402
"""Build ARRAY_CONCAT call from a list of arguments, handling variadic vs binary nesting."""
14031403
if self.ARRAY_CONCAT_IS_VAR_LEN:
14041404
return self.func(func_name, *args)
@@ -1535,8 +1535,8 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:
15351535

15361536

15371537
def build_formatted_time(
1538-
exp_class: Type[E], dialect: str, default: t.Optional[bool | str] = None
1539-
) -> t.Callable[[list], E]:
1538+
exp_class: Type[E], dialect: str, default: bool | str | None = None
1539+
) -> t.Callable[[BuilderArgs], E]:
15401540
"""Helper used for time expressions.
15411541
15421542
Args:
@@ -1548,7 +1548,7 @@ def build_formatted_time(
15481548
A callable that can be used to return the appropriately formatted time expression.
15491549
"""
15501550

1551-
def _builder(args: t.List):
1551+
def _builder(args: BuilderArgs) -> E:
15521552
return exp_class(
15531553
this=seq_get(args, 0),
15541554
format=Dialect[dialect].format_time(
@@ -1576,11 +1576,11 @@ def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) ->
15761576

15771577
def build_date_delta(
15781578
exp_class: Type[E],
1579-
unit_mapping: t.Optional[dict[str, str]] = None,
1580-
default_unit: t.Optional[str] = "DAY",
1579+
unit_mapping: dict[str, str] | None = None,
1580+
default_unit: str | None = "DAY",
15811581
supports_timezone: bool = False,
1582-
) -> t.Callable[[list], E]:
1583-
def _builder(args: list) -> E:
1582+
) -> t.Callable[[BuilderArgs], E]:
1583+
def _builder(args: BuilderArgs) -> E:
15841584
unit_based = len(args) >= 3
15851585
has_timezone = len(args) == 4
15861586
this = args[2] if unit_based else seq_get(args, 0)
@@ -1598,8 +1598,8 @@ def _builder(args: list) -> E:
15981598

15991599
def build_date_delta_with_interval(
16001600
expression_class: Type[E],
1601-
) -> t.Callable[[list], t.Optional[E]]:
1602-
def _builder(args: list) -> t.Optional[E]:
1601+
) -> t.Callable[[BuilderArgs], t.Optional[E]]:
1602+
def _builder(args: BuilderArgs) -> t.Optional[E]:
16031603
if len(args) < 2:
16041604
return None
16051605

@@ -1613,7 +1613,7 @@ def _builder(args: list) -> t.Optional[E]:
16131613
return _builder
16141614

16151615

1616-
def date_trunc_to_time(args: list) -> exp.DateTrunc | exp.TimestampTrunc:
1616+
def date_trunc_to_time(args: BuilderArgs) -> exp.DateTrunc | exp.TimestampTrunc:
16171617
unit = seq_get(args, 0)
16181618
this = seq_get(args, 1)
16191619

@@ -1808,8 +1808,8 @@ def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
18081808
)
18091809

18101810

1811-
def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) -> t.List[str]:
1812-
names = []
1811+
def pivot_column_names(aggregations: Iterable[exp.Expr], dialect: DialectType) -> list[str]:
1812+
names: list[str] = []
18131813
for agg in aggregations:
18141814
if isinstance(agg, exp.Alias):
18151815
names.append(agg.alias)
@@ -1832,17 +1832,17 @@ def pivot_column_names(aggregations: t.List[exp.Expr], dialect: DialectType) ->
18321832
return names
18331833

18341834

1835-
def binary_from_function(expr_type: Type[B]) -> t.Callable[[list], B]:
1835+
def binary_from_function(expr_type: Type[B]) -> t.Callable[[BuilderArgs], B]:
18361836
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
18371837

18381838

18391839
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
1840-
def build_timestamp_trunc(args: list) -> exp.TimestampTrunc:
1840+
def build_timestamp_trunc(args: BuilderArgs) -> exp.TimestampTrunc:
18411841
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
18421842

18431843

18441844
def build_trunc(
1845-
args: t.List,
1845+
args: BuilderArgs,
18461846
dialect: DialectType,
18471847
date_trunc_unabbreviate: bool = True,
18481848
default_date_trunc_unit: t.Optional[str] = None,
@@ -1898,7 +1898,7 @@ def is_parse_json(expression: exp.Expr) -> bool:
18981898
)
18991899

19001900

1901-
def isnull_to_is_null(args: t.List) -> exp.Expr:
1901+
def isnull_to_is_null(args: BuilderArgs) -> exp.Expr:
19021902
return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
19031903

19041904

@@ -2070,9 +2070,9 @@ def build_json_extract_path(
20702070
zero_based_indexing: bool = True,
20712071
arrow_req_json_type: bool = False,
20722072
json_type: t.Optional[str] = None,
2073-
) -> t.Callable[[t.List], F]:
2074-
def _builder(args: t.List) -> F:
2075-
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
2073+
) -> t.Callable[[MutableSequence[t.Any]], F]:
2074+
def _builder(args: MutableSequence[t.Any]) -> F:
2075+
segments: list[exp.JSONPathPart] = [exp.JSONPathRoot()]
20762076
for arg in args[1:]:
20772077
if not isinstance(arg, exp.Literal):
20782078
# We use the fallback parser because we can't really transpile non-literals safely
@@ -2236,7 +2236,7 @@ def _builder(dtype: exp.DataType) -> exp.DataType:
22362236
return _builder
22372237

22382238

2239-
def build_timestamp_from_parts(args: t.List) -> exp.Func:
2239+
def build_timestamp_from_parts(args: BuilderArgs) -> exp.Func:
22402240
if len(args) == 2:
22412241
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
22422242
# so we parse this into Anonymous for now instead of introducing complexity
@@ -2301,8 +2301,8 @@ def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateD
23012301
return self.func("SEQUENCE", start, end, step)
23022302

23032303

2304-
def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[list], exp.Expr]:
2305-
def _builder(args: t.List) -> exp.Expr:
2304+
def build_like(expr_type: Type[E], not_like: bool = False) -> t.Callable[[BuilderArgs], exp.Expr]:
2305+
def _builder(args: BuilderArgs) -> exp.Expr:
23062306
like_expr: exp.Expr = expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
23072307

23082308
if escape := seq_get(args, 2):
@@ -2316,8 +2316,8 @@ def _builder(args: t.List) -> exp.Expr:
23162316
return _builder
23172317

23182318

2319-
def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[list, Dialect], E]:
2320-
def _builder(args: t.List, dialect: Dialect) -> E:
2319+
def build_regexp_extract(expr_type: Type[E]) -> t.Callable[[BuilderArgs, Dialect], E]:
2320+
def _builder(args: BuilderArgs, dialect: Dialect) -> E:
23212321
# The "position" argument specifies the index of the string character to start matching from.
23222322
# `null_if_pos_overflow` reflects the dialect's behavior when position is greater than the string
23232323
# length. If true, returns NULL. If false, returns an empty string. `null_if_pos_overflow` is
@@ -2399,8 +2399,8 @@ def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str:
23992399
def groupconcat_sql(
24002400
self: Generator,
24012401
expression: exp.GroupConcat,
2402-
func_name="LISTAGG",
2403-
sep: t.Optional[str] = ",",
2402+
func_name: str = "LISTAGG",
2403+
sep: str | None = ",",
24042404
within_group: bool = True,
24052405
on_overflow: bool = False,
24062406
) -> str:
@@ -2443,7 +2443,9 @@ def groupconcat_sql(
24432443
return self.sql(listagg)
24442444

24452445

2446-
def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToStr | exp.ToChar:
2446+
def build_timetostr_or_tochar(
2447+
args: BuilderArgs, dialect: DialectType
2448+
) -> exp.TimeToStr | exp.ToChar:
24472449
if len(args) == 2:
24482450
this = args[0]
24492451
if not this.type:
@@ -2458,7 +2460,7 @@ def build_timetostr_or_tochar(args: t.List, dialect: DialectType) -> exp.TimeToS
24582460
return exp.ToChar.from_arg_list(args)
24592461

24602462

2461-
def build_replace_with_optional_replacement(args: t.List) -> exp.Replace:
2463+
def build_replace_with_optional_replacement(args: BuilderArgs) -> exp.Replace:
24622464
return exp.Replace(
24632465
this=seq_get(args, 0),
24642466
expression=seq_get(args, 1),

0 commit comments

Comments
 (0)