Skip to content

Commit 698b701

Browse files
authored
fix: Use covariant type annotation for collection sampling type (#9)
1 parent 54cbc75, commit 698b701

3 files changed

Lines changed: 30 additions & 7 deletions

File tree

dataframely/collection.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import sys
55
import warnings
66
from abc import ABC
7-
from collections.abc import Mapping, MutableMapping, Sequence
7+
from collections.abc import Mapping, Sequence
88
from pathlib import Path
99
from typing import Any, Generic, Self, TypeVar, cast
1010

@@ -20,10 +20,10 @@
2020

2121
if sys.version_info >= (3, 13):
2222
SamplingType = TypeVar(
23-
"SamplingType", bound=MutableMapping[str, Any], default=dict[str, Any]
23+
"SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any]
2424
)
2525
else: # pragma: no cover
26-
SamplingType = TypeVar("SamplingType", bound=MutableMapping[str, Any])
26+
SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any])
2727

2828

2929
class Collection(BaseCollection, ABC, Generic[SamplingType]):

dataframely/functional.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55

66
import polars as pl
77

8+
from ._base_collection import BaseCollection
89
from ._typing import LazyFrame
9-
from .collection import Collection
1010
from .schema import Schema
1111

1212
S = TypeVar("S", bound=Schema)
1313
T = TypeVar("T", bound=Schema)
14-
C = TypeVar("C", bound=Collection)
14+
15+
# NOTE: Binding to `BaseCollection` is required here as the TypeVar default for the
16+
# sampling type otherwise causes issues for Python 3.13.
17+
C = TypeVar("C", bound=BaseCollection)
1518

1619
# ------------------------------------------------------------------------------------ #
1720
# FILTER #

tests/test_typing.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import datetime
1010
import decimal
1111
import functools
12-
from typing import Any
12+
from typing import Any, NotRequired, TypedDict
1313

1414
import polars as pl
1515
import pytest
@@ -61,7 +61,21 @@ class MySecondSchema(dy.Schema):
6161
b = dy.Integer()
6262

6363

64-
class MyCollection(dy.Collection):
64+
class SamplingTypeFirst(TypedDict):
65+
a: NotRequired[int]
66+
67+
68+
class SamplingTypeSecond(TypedDict):
69+
a: NotRequired[int]
70+
b: NotRequired[int]
71+
72+
73+
class SamplingType(TypedDict):
74+
first: NotRequired[SamplingTypeFirst]
75+
second: NotRequired[SamplingTypeSecond]
76+
77+
78+
class MyCollection(dy.Collection[SamplingType]):
6579
first: dy.LazyFrame[MyFirstSchema]
6680
second: dy.LazyFrame[MySecondSchema]
6781

@@ -73,6 +87,12 @@ def test_collection_filter_return_value() -> None:
7387
assert len(failure["third"]) == 0 # type: ignore[misc]
7488

7589

90+
def test_collection_concat() -> None:
91+
c1 = MyCollection.create_empty()
92+
c2 = MyCollection.create_empty()
93+
dy.concat_collection_members([c1, c2])
94+
95+
7696
# ------------------------------------------------------------------------------------ #
7797
# ITER ROWS #
7898
# ------------------------------------------------------------------------------------ #

0 commit comments

Comments
 (0)