Skip to content

Commit 35ae479

Browse files
refactor of datatypes typing:
- Fixed all places where the datatype typing was too narrow. In most cases a `str` is accepted wherever a sqltype is expected; the one odd exception seems to be the `map` method on Relation.
- Use `Self` for annotations on arguments where pertinent.
1 parent b8da15c commit 35ae479

File tree

3 files changed

+45
-49
lines changed

3 files changed

+45
-49
lines changed

_duckdb-stubs/__init__.pyi

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ if typing.TYPE_CHECKING:
1212
import pyarrow.lib
1313
from builtins import list as lst
1414
from collections.abc import Callable, Iterable, Sequence, Mapping
15-
from ._typing import ParquetFieldIdsType, IntoExpr, IntoExprColumn, PythonLiteral, IntoValues
15+
from ._typing import (
16+
ParquetFieldIdsType,
17+
IntoExpr,
18+
IntoExprColumn,
19+
PythonLiteral,
20+
IntoValues,
21+
IntoDType,
22+
IntoNestedDType,
23+
)
1624
from duckdb import sqltypes, func
1725

1826
__all__: lst[str] = [
@@ -193,7 +201,7 @@ class DuckDBPyConnection:
193201
def __enter__(self) -> Self: ...
194202
def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ...
195203
def append(self, table_name: str, df: pandas.DataFrame, *, by_name: bool = False) -> DuckDBPyConnection: ...
196-
def array_type(self, type: sqltypes.DuckDBPyType, size: typing.SupportsInt) -> sqltypes.DuckDBPyType: ...
204+
def array_type(self, type: IntoDType, size: typing.SupportsInt) -> sqltypes.DuckDBPyType: ...
197205
def arrow(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader:
198206
"""Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead."""
199207
...
@@ -207,8 +215,8 @@ class DuckDBPyConnection:
207215
self,
208216
name: str,
209217
function: Callable[..., typing.Any],
210-
parameters: lst[sqltypes.DuckDBPyType] | None = None,
211-
return_type: sqltypes.DuckDBPyType | None = None,
218+
parameters: lst[IntoDType] | None = None,
219+
return_type: IntoDType | None = None,
212220
*,
213221
type: func.PythonUDFType = ...,
214222
null_handling: func.FunctionNullHandling = ...,
@@ -327,9 +335,9 @@ class DuckDBPyConnection:
327335
def disable_profiling(self) -> None: ...
328336
def interrupt(self) -> None: ...
329337
def list_filesystems(self) -> lst[str]: ...
330-
def list_type(self, type: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ...
338+
def list_type(self, type: IntoDType) -> sqltypes.DuckDBPyType: ...
331339
def load_extension(self, extension: str) -> None: ...
332-
def map_type(self, key: sqltypes.DuckDBPyType, value: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ...
340+
def map_type(self, key: IntoDType, value: IntoDType) -> sqltypes.DuckDBPyType: ...
333341
@typing.overload
334342
def pl(
335343
self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: typing.Literal[False] = ...
@@ -439,23 +447,17 @@ class DuckDBPyConnection:
439447
def register_filesystem(self, filesystem: fsspec.AbstractFileSystem) -> None: ...
440448
def remove_function(self, name: str) -> DuckDBPyConnection: ...
441449
def rollback(self) -> DuckDBPyConnection: ...
442-
def row_type(
443-
self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType]
444-
) -> sqltypes.DuckDBPyType: ...
450+
def row_type(self, fields: IntoNestedDType) -> sqltypes.DuckDBPyType: ...
445451
def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ...
446452
def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ...
447453
def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ...
448-
def struct_type(
449-
self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType]
450-
) -> sqltypes.DuckDBPyType: ...
454+
def struct_type(self, fields: IntoNestedDType) -> sqltypes.DuckDBPyType: ...
451455
def table(self, table_name: str) -> DuckDBPyRelation: ...
452456
def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ...
453457
def tf(self) -> dict[str, typing.Any]: ...
454458
def torch(self) -> dict[str, typing.Any]: ...
455459
def type(self, type_str: str) -> sqltypes.DuckDBPyType: ...
456-
def union_type(
457-
self, members: lst[sqltypes.DuckDBPyType] | dict[str, sqltypes.DuckDBPyType]
458-
) -> sqltypes.DuckDBPyType: ...
460+
def union_type(self, members: IntoNestedDType) -> sqltypes.DuckDBPyType: ...
459461
def unregister(self, view_name: str) -> DuckDBPyConnection: ...
460462
def unregister_filesystem(self, name: str) -> None: ...
461463
def values(self, *args: IntoValues) -> DuckDBPyRelation: ...
@@ -527,13 +529,13 @@ class DuckDBPyRelation:
527529
) -> DuckDBPyRelation: ...
528530
def create(self, table_name: str) -> None: ...
529531
def create_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ...
530-
def cross(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ...
532+
def cross(self, other_rel: Self) -> DuckDBPyRelation: ...
531533
def cume_dist(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ...
532534
def dense_rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ...
533535
def describe(self) -> DuckDBPyRelation: ...
534536
def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ...
535537
def distinct(self) -> DuckDBPyRelation: ...
536-
def except_(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ...
538+
def except_(self, other_rel: Self) -> DuckDBPyRelation: ...
537539
def execute(self) -> DuckDBPyRelation: ...
538540
def explain(self, type: ExplainType = ExplainType.STANDARD) -> str: ...
539541
def favg(
@@ -568,10 +570,10 @@ class DuckDBPyRelation:
568570
) -> DuckDBPyRelation: ...
569571
def insert(self, values: lst[object]) -> None: ...
570572
def insert_into(self, table_name: str) -> None: ...
571-
def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ...
573+
def intersect(self, other_rel: Self) -> DuckDBPyRelation: ...
572574
def join(
573575
self,
574-
other_rel: DuckDBPyRelation,
576+
other_rel: Self,
575577
condition: IntoExprColumn,
576578
how: typing.Literal["inner", "left", "right", "outer", "semi", "anti"] = "inner",
577579
) -> DuckDBPyRelation: ...
@@ -747,7 +749,7 @@ class DuckDBPyRelation:
747749
def to_table(self, table_name: str) -> None: ...
748750
def to_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ...
749751
def torch(self) -> dict[str, typing.Any]: ...
750-
def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ...
752+
def union(self, union_rel: Self) -> DuckDBPyRelation: ...
751753
def unique(self, unique_aggr: str) -> DuckDBPyRelation: ...
752754
def update(self, set: Mapping[str, IntoExpr], *, condition: IntoExpr = None) -> None: ...
753755
def value_counts(self, expression: str, groups: str = "") -> DuckDBPyRelation: ...
@@ -1027,7 +1029,7 @@ def append(
10271029
table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection | None = None
10281030
) -> DuckDBPyConnection: ...
10291031
def array_type(
1030-
type: sqltypes.DuckDBPyType, size: typing.SupportsInt, *, connection: DuckDBPyConnection | None = None
1032+
type: IntoDType, size: typing.SupportsInt, *, connection: DuckDBPyConnection | None = None
10311033
) -> sqltypes.DuckDBPyType: ...
10321034
@typing.overload
10331035
def arrow(
@@ -1056,8 +1058,8 @@ def connect(
10561058
def create_function(
10571059
name: str,
10581060
function: Callable[..., typing.Any],
1059-
parameters: lst[sqltypes.DuckDBPyType] | None = None,
1060-
return_type: sqltypes.DuckDBPyType | None = None,
1061+
parameters: lst[IntoDType] | None = None,
1062+
return_type: IntoDType | None = None,
10611063
*,
10621064
type: func.PythonUDFType = ...,
10631065
null_handling: func.FunctionNullHandling = ...,
@@ -1240,15 +1242,10 @@ def get_profiling_information(*, connection: DuckDBPyConnection | None = None, f
12401242
def enable_profiling(*, connection: DuckDBPyConnection | None = None) -> None: ...
12411243
def disable_profiling(*, connection: DuckDBPyConnection | None = None) -> None: ...
12421244
def list_filesystems(*, connection: DuckDBPyConnection | None = None) -> lst[str]: ...
1243-
def list_type(
1244-
type: sqltypes.DuckDBPyType, *, connection: DuckDBPyConnection | None = None
1245-
) -> sqltypes.DuckDBPyType: ...
1245+
def list_type(type: IntoDType, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
12461246
def load_extension(extension: str, *, connection: DuckDBPyConnection | None = None) -> None: ...
12471247
def map_type(
1248-
key: sqltypes.DuckDBPyType,
1249-
value: sqltypes.DuckDBPyType,
1250-
*,
1251-
connection: DuckDBPyConnection | None = None,
1248+
key: IntoDType, value: IntoDType, *, connection: DuckDBPyConnection | None = None
12521249
) -> sqltypes.DuckDBPyType: ...
12531250
def order(
12541251
df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection | None = None
@@ -1394,11 +1391,7 @@ def register_filesystem(
13941391
) -> None: ...
13951392
def remove_function(name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ...
13961393
def rollback(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ...
1397-
def row_type(
1398-
fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType],
1399-
*,
1400-
connection: DuckDBPyConnection | None = None,
1401-
) -> sqltypes.DuckDBPyType: ...
1394+
def row_type(fields: IntoNestedDType, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
14021395
def rowcount(*, connection: DuckDBPyConnection | None = None) -> int: ...
14031396
def set_default_connection(connection: DuckDBPyConnection) -> None: ...
14041397
def sql(
@@ -1410,11 +1403,7 @@ def sql(
14101403
) -> DuckDBPyRelation: ...
14111404
def sqltype(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
14121405
def string_type(collation: str = "", *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
1413-
def struct_type(
1414-
fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType],
1415-
*,
1416-
connection: DuckDBPyConnection | None = None,
1417-
) -> sqltypes.DuckDBPyType: ...
1406+
def struct_type(fields: IntoNestedDType, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
14181407
def table(table_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ...
14191408
def table_function(
14201409
name: str,
@@ -1426,11 +1415,7 @@ def tf(*, connection: DuckDBPyConnection | None = None) -> dict[str, typing.Any]
14261415
def tokenize(query: str) -> lst[tuple[int, token_type]]: ...
14271416
def torch(*, connection: DuckDBPyConnection | None = None) -> dict[str, typing.Any]: ...
14281417
def type(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
1429-
def union_type(
1430-
members: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType],
1431-
*,
1432-
connection: DuckDBPyConnection | None = None,
1433-
) -> sqltypes.DuckDBPyType: ...
1418+
def union_type(members: IntoNestedDType, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ...
14341419
def unregister(view_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ...
14351420
def unregister_filesystem(name: str, *, connection: DuckDBPyConnection | None = None) -> None: ...
14361421
def values(*args: IntoValues, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ...

_duckdb-stubs/_expression.pyi

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import TYPE_CHECKING, Any, overload
2-
from duckdb import sqltypes
32

43
if TYPE_CHECKING:
5-
from ._typing import IntoExpr
4+
from ._typing import IntoExpr, IntoDType
65

76
class Expression:
87
def __add__(self, other: IntoExpr) -> Expression: ...
@@ -40,7 +39,7 @@ class Expression:
4039
def alias(self, name: str) -> Expression: ...
4140
def asc(self) -> Expression: ...
4241
def between(self, lower: IntoExpr, upper: IntoExpr) -> Expression: ...
43-
def cast(self, type: sqltypes.DuckDBPyType) -> Expression: ...
42+
def cast(self, type: IntoDType) -> Expression: ...
4443
def collate(self, collation: str) -> Expression: ...
4544
def desc(self) -> Expression: ...
4645
def get_name(self) -> str: ...

_duckdb-stubs/_typing.pyi

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
23
from typing import TypeAlias, TYPE_CHECKING, Protocol, Any, TypeVar, Generic
34
from datetime import date, datetime, time, timedelta
45
from decimal import Decimal
@@ -7,6 +8,7 @@ from collections.abc import Mapping, Iterator
78

89
if TYPE_CHECKING:
910
from ._expression import Expression
11+
from ._sqltypes import DuckDBPyType
1012

1113
_T_co = TypeVar("_T_co", covariant=True)
1214
_S_co = TypeVar("_S_co", bound=tuple[Any, ...], covariant=True)
@@ -54,7 +56,17 @@ PythonLiteral: TypeAlias = (
5456
# the field_ids argument to to_parquet and write_parquet has a recursive structure
5557
ParquetFieldIdsType: TypeAlias = Mapping[str, int | ParquetFieldIdsType]
5658
IntoValues: TypeAlias = list[PythonLiteral] | tuple[Expression, ...] | Expression
57-
"""Types that can be converted to a table of values."""
59+
"""Types that can be converted to a table."""
60+
IntoDType: TypeAlias = DuckDBPyType | str
61+
"""Types that can be converted to a `DuckDBPyType`.
62+
63+
Passing the string `"INTEGER"` is equivalent to passing `DuckDBPyType("INTEGER")` or `DuckDBPyType.INTEGER`.
64+
65+
Note:
66+
A `StrEnum` will be handled the same way as a `str`.
67+
"""
68+
IntoNestedDType: TypeAlias = dict[str, IntoDType] | list[IntoDType]
69+
"""Types that can be converted to a nested `DuckDBPyType` (e.g. for struct or union types)."""
5870
IntoExprColumn: TypeAlias = Expression | str
5971
"""Types that are, or can be used as a `ColumnExpression`."""
6072
IntoExpr: TypeAlias = IntoExprColumn | PythonLiteral

0 commit comments

Comments
 (0)