Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 53 additions & 54 deletions dataframely/_base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABCMeta
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Annotated, Any, Self, get_args, get_origin
from typing import Annotated, Any, Self, cast, get_args, get_origin

import polars as pl

Expand Down Expand Up @@ -183,59 +183,9 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
# Get all members via the annotations
if "__annotations__" in source:
for attr, kls in source["__annotations__"].items():
origin = get_origin(kls)

# optional annotation
collection_member = CollectionMember()

if origin is Annotated:
annotation_args = get_args(kls)
origin_arg0 = get_origin(annotation_args[0])
if not origin_arg0 or not issubclass(origin_arg0, TypedLazyFrame):
raise AnnotationImplementationError(attr, kls)
if len(annotation_args) > 2:
raise AnnotationImplementationError(attr, kls)
if not isinstance(annotation_args[1], CollectionMember):
raise AnnotationImplementationError(attr, kls)

# Continue with wrapped FrameType
collection_member = annotation_args[1]
kls = annotation_args[0]
origin = origin_arg0

if origin is None:
# `None` annotation is not allowed
raise AnnotationImplementationError(attr, kls)
elif origin == typing.Union:
# Happy path: optional member
union_args = get_args(kls)
if len(union_args) != 2:
raise AnnotationImplementationError(attr, kls)
if not any(get_origin(arg) is None for arg in union_args):
raise AnnotationImplementationError(attr, kls)

[not_none_arg] = [
arg for arg in union_args if get_origin(arg) is not None
]
if not issubclass(get_origin(not_none_arg), TypedLazyFrame):
raise AnnotationImplementationError(attr, kls)

result.members[attr] = MemberInfo(
schema=get_args(not_none_arg)[0],
is_optional=True,
ignored_in_filters=collection_member.ignored_in_filters,
)
elif issubclass(origin, TypedLazyFrame):
# Happy path: required member
result.members[attr] = MemberInfo(
schema=get_args(kls)[0],
is_optional=False,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
)
else:
# Some other unknown annotation
raise AnnotationImplementationError(attr, kls)
result.members[attr] = CollectionMeta._derive_member_info(
attr, kls, CollectionMember()
)

# Get all filters by traversing the source
for attr, value in {
Expand All @@ -246,6 +196,55 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:

return result

@staticmethod
def _derive_member_info(
attr: str, type_annotation: Any, collection_member: CollectionMember
) -> MemberInfo:
origin = get_origin(type_annotation)

if origin is None:
# `None` annotation is not allowed
raise AnnotationImplementationError(attr, type_annotation)
elif origin == Annotated:
# Maybe happy path: annotated member, dispatch recursively
annotation_args = cast(list[Any], get_args(type_annotation))
if len(annotation_args) > 2:
raise AnnotationImplementationError(attr, type_annotation)
if not isinstance(annotation_args[1], CollectionMember):
raise AnnotationImplementationError(attr, type_annotation)
return CollectionMeta._derive_member_info(
attr, annotation_args[0], annotation_args[1]
)
elif origin == typing.Union:
# Happy path: optional member
union_args = get_args(type_annotation)
if len(union_args) != 2:
raise AnnotationImplementationError(attr, type_annotation)
if not any(get_origin(arg) is None for arg in union_args):
raise AnnotationImplementationError(attr, type_annotation)

[not_none_arg] = [arg for arg in union_args if get_origin(arg) is not None]
if not issubclass(get_origin(not_none_arg), TypedLazyFrame):
raise AnnotationImplementationError(attr, type_annotation)

return MemberInfo(
schema=get_args(not_none_arg)[0],
is_optional=True,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
)
elif issubclass(origin, TypedLazyFrame):
# Happy path: required member
return MemberInfo(
schema=get_args(type_annotation)[0],
is_optional=False,
ignored_in_filters=collection_member.ignored_in_filters,
inline_for_sampling=collection_member.inline_for_sampling,
)
else:
# Some other unknown annotation
raise AnnotationImplementationError(attr, type_annotation)


class BaseCollection(metaclass=CollectionMeta):
"""Internal utility abstraction to reference collections without introducing
Expand Down
18 changes: 18 additions & 0 deletions tests/collection/test_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ def test_annotation_union_conflicting_types_failure() -> None:
)


def test_annotation_annotated_success() -> None:
"""When we use an Annotated type, it must accept a union type."""
create_collection_raw(
"test",
{
"first": Annotated[
dy.LazyFrame[MyTestSchema] | None, dy.CollectionMember()
],
},
)
create_collection_raw(
"test",
{
"first": dy.LazyFrame[MyTestSchema] | None,
},
)


def test_annotation_only_none_failure() -> None:
"""Annotations must not just be None."""
with pytest.raises(AnnotationImplementationError):
Expand Down
26 changes: 24 additions & 2 deletions tests/collection/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ def _preprocess_sample(
return sample


class MyInlinedCollectionWithOptional(dy.Collection):
first: Annotated[
dy.LazyFrame[MyFirstSchema] | None,
dy.CollectionMember(inline_for_sampling=True),
]
second: dy.LazyFrame[MySecondSchema]

@classmethod
def _preprocess_sample(
cls, sample: dict[str, Any], index: int, generator: Generator
) -> dict[str, Any]:
sample["a"] = index
return sample


class SmallCollection(dy.Collection):
first: dy.LazyFrame[MyFirstSchema]

Expand Down Expand Up @@ -117,13 +132,20 @@ def test_sample_with_overrides() -> None:
assert collection.second.collect()["c"].to_list() == [3, 4, 6]


def test_sample_inline_with_overrides() -> None:
collection = MyInlinedCollection.sample(
@pytest.mark.parametrize(
"collection_type", [MyInlinedCollection, MyInlinedCollectionWithOptional]
)
def test_sample_inline_with_overrides(
collection_type: type[MyInlinedCollection] | type[MyInlinedCollectionWithOptional],
) -> None:
collection = collection_type.sample(
overrides=[
{"b": 4, "second": [{"c": 3}, {"c": 4}]},
{"b": 8, "second": [{"c": 6}]},
]
)

assert collection.first is not None
assert collection.first.collect()["a"].to_list() == [0, 1]
assert collection.first.collect()["b"].to_list() == [4, 8]

Expand Down
Loading