Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dataframely/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import warnings
from abc import ABC
from collections.abc import Mapping, MutableMapping, Sequence
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Generic, Self, TypeVar, cast

Expand All @@ -20,10 +20,10 @@

if sys.version_info >= (3, 13):
SamplingType = TypeVar(
"SamplingType", bound=MutableMapping[str, Any], default=dict[str, Any]
"SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any]
)
else: # pragma: no cover
SamplingType = TypeVar("SamplingType", bound=MutableMapping[str, Any])
SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any])


class Collection(BaseCollection, ABC, Generic[SamplingType]):
Expand Down
7 changes: 5 additions & 2 deletions dataframely/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

import polars as pl

from ._base_collection import BaseCollection
from ._typing import LazyFrame
from .collection import Collection
from .schema import Schema

S = TypeVar("S", bound=Schema)
T = TypeVar("T", bound=Schema)
C = TypeVar("C", bound=Collection)

# NOTE: Binding to `BaseCollection` is required here as the TypeVar default for the
# sampling type otherwise causes issues for Python 3.13.
C = TypeVar("C", bound=BaseCollection)

# ------------------------------------------------------------------------------------ #
# FILTER #
Expand Down
24 changes: 22 additions & 2 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import datetime
import decimal
import functools
from typing import Any
from typing import Any, NotRequired, TypedDict

import polars as pl
import pytest
Expand Down Expand Up @@ -61,7 +61,21 @@ class MySecondSchema(dy.Schema):
b = dy.Integer()


class MyCollection(dy.Collection):
class SamplingTypeFirst(TypedDict):
a: NotRequired[int]


class SamplingTypeSecond(TypedDict):
a: NotRequired[int]
b: NotRequired[int]


class SamplingType(TypedDict):
first: NotRequired[SamplingTypeFirst]
second: NotRequired[SamplingTypeSecond]


class MyCollection(dy.Collection[SamplingType]):
first: dy.LazyFrame[MyFirstSchema]
second: dy.LazyFrame[MySecondSchema]

Expand All @@ -73,6 +87,12 @@ def test_collection_filter_return_value() -> None:
assert len(failure["third"]) == 0 # type: ignore[misc]


def test_collection_concat() -> None:
c1 = MyCollection.create_empty()
c2 = MyCollection.create_empty()
dy.concat_collection_members([c1, c2])


# ------------------------------------------------------------------------------------ #
# ITER ROWS #
# ------------------------------------------------------------------------------------ #
Expand Down
Loading