diff --git a/dataframely/collection.py b/dataframely/collection.py
index 70d6595..6e85040 100644
--- a/dataframely/collection.py
+++ b/dataframely/collection.py
@@ -4,7 +4,7 @@
 import sys
 import warnings
 from abc import ABC
-from collections.abc import Mapping, MutableMapping, Sequence
+from collections.abc import Mapping, Sequence
 from pathlib import Path
 from typing import Any, Generic, Self, TypeVar, cast
 
@@ -20,10 +20,10 @@
 
 if sys.version_info >= (3, 13):
     SamplingType = TypeVar(
-        "SamplingType", bound=MutableMapping[str, Any], default=dict[str, Any]
+        "SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any]
     )
 else:  # pragma: no cover
-    SamplingType = TypeVar("SamplingType", bound=MutableMapping[str, Any])
+    SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any])
 
 
 class Collection(BaseCollection, ABC, Generic[SamplingType]):
diff --git a/dataframely/functional.py b/dataframely/functional.py
index eb42e7e..b2afc56 100644
--- a/dataframely/functional.py
+++ b/dataframely/functional.py
@@ -5,13 +5,16 @@
 
 import polars as pl
 
+from ._base_collection import BaseCollection
 from ._typing import LazyFrame
-from .collection import Collection
 from .schema import Schema
 
 S = TypeVar("S", bound=Schema)
 T = TypeVar("T", bound=Schema)
-C = TypeVar("C", bound=Collection)
+
+# NOTE: `C` is bound to `BaseCollection` instead of the generic `Collection` because
+# the TypeVar default of the sampling type otherwise causes issues on Python 3.13.
+C = TypeVar("C", bound=BaseCollection)
 
 # ------------------------------------------------------------------------------------ #
 # FILTER #
diff --git a/tests/test_typing.py b/tests/test_typing.py
index 3fae814..ce403bf 100644
--- a/tests/test_typing.py
+++ b/tests/test_typing.py
@@ -9,7 +9,7 @@
 import datetime
 import decimal
 import functools
-from typing import Any
+from typing import Any, NotRequired, TypedDict
 
 import polars as pl
 import pytest
@@ -61,7 +61,21 @@ class MySecondSchema(dy.Schema):
     b = dy.Integer()
 
 
-class MyCollection(dy.Collection):
+class SamplingTypeFirst(TypedDict):
+    a: NotRequired[int]
+
+
+class SamplingTypeSecond(TypedDict):
+    a: NotRequired[int]
+    b: NotRequired[int]
+
+
+class SamplingType(TypedDict):
+    first: NotRequired[SamplingTypeFirst]
+    second: NotRequired[SamplingTypeSecond]
+
+
+class MyCollection(dy.Collection[SamplingType]):
     first: dy.LazyFrame[MyFirstSchema]
     second: dy.LazyFrame[MySecondSchema]
 
@@ -73,6 +87,12 @@ def test_collection_filter_return_value() -> None:
     assert len(failure["third"]) == 0  # type: ignore[misc]
 
 
+def test_collection_concat() -> None:
+    c1 = MyCollection.create_empty()
+    c2 = MyCollection.create_empty()
+    dy.concat_collection_members([c1, c2])
+
+
 # ------------------------------------------------------------------------------------ #
 # ITER ROWS #
 # ------------------------------------------------------------------------------------ #