From 9547d770f2101b406e06aff1bfb49e5b894b5adf Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 22 Apr 2025 20:30:34 +0200 Subject: [PATCH 1/3] fix: Use covariant type annotation for collection sampling type --- dataframely/collection.py | 6 +++--- tests/test_typing.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dataframely/collection.py b/dataframely/collection.py index 70d6595..6aba74f 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -4,7 +4,7 @@ import sys import warnings from abc import ABC -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any, Generic, Self, TypeVar, cast @@ -20,10 +20,10 @@ if sys.version_info >= (3, 13): SamplingType = TypeVar( - "SamplingType", bound=MutableMapping[str, Any], default=dict[str, Any] + "SamplingType", bound=Mapping[str, Any], default=dict[str, Any] ) else: # pragma: no cover - SamplingType = TypeVar("SamplingType", bound=MutableMapping[str, Any]) + SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any]) class Collection(BaseCollection, ABC, Generic[SamplingType]): diff --git a/tests/test_typing.py b/tests/test_typing.py index 3fae814..13986c0 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -9,7 +9,7 @@ import datetime import decimal import functools -from typing import Any +from typing import Any, NotRequired, TypedDict import polars as pl import pytest @@ -61,7 +61,21 @@ class MySecondSchema(dy.Schema): b = dy.Integer() -class MyCollection(dy.Collection): +class SamplingTypeFirst(TypedDict): + a: NotRequired[int] + + +class SamplingTypeSecond(TypedDict): + a: NotRequired[int] + b: NotRequired[int] + + +class SamplingType(TypedDict): + first: NotRequired[SamplingTypeFirst] + second: NotRequired[SamplingTypeSecond] + + +class MyCollection(dy.Collection[SamplingType]): first: dy.LazyFrame[MyFirstSchema] second: dy.LazyFrame[MySecondSchema] From da5715b4b08db10a68542aacb0546beddcd2a67c Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 22 Apr 2025 22:53:19 +0200 Subject: [PATCH 2/3] Update default --- dataframely/collection.py | 2 +- tests/test_typing.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/dataframely/collection.py b/dataframely/collection.py index 6aba74f..6e85040 100644 --- a/dataframely/collection.py +++ b/dataframely/collection.py @@ -20,7 +20,7 @@ if sys.version_info >= (3, 13): SamplingType = TypeVar( - "SamplingType", bound=Mapping[str, Any], default=dict[str, Any] + "SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any] ) else: # pragma: no cover SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any]) diff --git a/tests/test_typing.py b/tests/test_typing.py index 13986c0..ce403bf 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -87,6 +87,12 @@ def test_collection_filter_return_value() -> None: assert len(failure["third"]) == 0 # type: ignore[misc] +def test_collection_concat() -> None: + c1 = MyCollection.create_empty() + c2 = MyCollection.create_empty() + dy.concat_collection_members([c1, c2]) + + # ------------------------------------------------------------------------------------ # # ITER ROWS # # ------------------------------------------------------------------------------------ # From 211c3af8a09a0bf8c2199f29ae9810d9633a3c3d Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Tue, 22 Apr 2025 23:35:58 +0200 Subject: [PATCH 3/3] Fix --- dataframely/functional.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dataframely/functional.py b/dataframely/functional.py index eb42e7e..b2afc56 100644 --- a/dataframely/functional.py +++ b/dataframely/functional.py @@ -5,13 +5,16 @@ import polars as pl +from ._base_collection import BaseCollection from ._typing import LazyFrame -from .collection import Collection from .schema import Schema S = TypeVar("S", bound=Schema) T = TypeVar("T", bound=Schema) -C = TypeVar("C", bound=Collection) + +# NOTE: Binding to `BaseCollection` is required here as the TypeVar default for the +# sampling type otherwise causes issues for Python 3.13. +C = TypeVar("C", bound=BaseCollection) # ------------------------------------------------------------------------------------ # # FILTER #