diff --git a/dataframely/_base_collection.py b/dataframely/_base_collection.py index eddb1e9..8ddf141 100644 --- a/dataframely/_base_collection.py +++ b/dataframely/_base_collection.py @@ -7,7 +7,7 @@ from abc import ABCMeta from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Annotated, Any, Self, get_args, get_origin +from typing import Annotated, Any, Self, cast, get_args, get_origin import polars as pl @@ -183,59 +183,9 @@ def _get_metadata(source: dict[str, Any]) -> Metadata: # Get all members via the annotations if "__annotations__" in source: for attr, kls in source["__annotations__"].items(): - origin = get_origin(kls) - - # optional annotation - collection_member = CollectionMember() - - if origin is Annotated: - annotation_args = get_args(kls) - origin_arg0 = get_origin(annotation_args[0]) - if not origin_arg0 or not issubclass(origin_arg0, TypedLazyFrame): - raise AnnotationImplementationError(attr, kls) - if len(annotation_args) > 2: - raise AnnotationImplementationError(attr, kls) - if not isinstance(annotation_args[1], CollectionMember): - raise AnnotationImplementationError(attr, kls) - - # Continue with wrapped FrameType - collection_member = annotation_args[1] - kls = annotation_args[0] - origin = origin_arg0 - - if origin is None: - # `None` annotation is not allowed - raise AnnotationImplementationError(attr, kls) - elif origin == typing.Union: - # Happy path: optional member - union_args = get_args(kls) - if len(union_args) != 2: - raise AnnotationImplementationError(attr, kls) - if not any(get_origin(arg) is None for arg in union_args): - raise AnnotationImplementationError(attr, kls) - - [not_none_arg] = [ - arg for arg in union_args if get_origin(arg) is not None - ] - if not issubclass(get_origin(not_none_arg), TypedLazyFrame): - raise AnnotationImplementationError(attr, kls) - - result.members[attr] = MemberInfo( - schema=get_args(not_none_arg)[0], - is_optional=True, - ignored_in_filters=collection_member.ignored_in_filters, - ) - elif issubclass(origin, TypedLazyFrame): - # Happy path: required member - result.members[attr] = MemberInfo( - schema=get_args(kls)[0], - is_optional=False, - ignored_in_filters=collection_member.ignored_in_filters, - inline_for_sampling=collection_member.inline_for_sampling, - ) - else: - # Some other unknown annotation - raise AnnotationImplementationError(attr, kls) + result.members[attr] = CollectionMeta._derive_member_info( + attr, kls, CollectionMember() + ) # Get all filters by traversing the source for attr, value in { @@ -246,6 +196,55 @@ def _get_metadata(source: dict[str, Any]) -> Metadata: return result + @staticmethod + def _derive_member_info( + attr: str, type_annotation: Any, collection_member: CollectionMember + ) -> MemberInfo: + origin = get_origin(type_annotation) + + if origin is None: + # `None` annotation is not allowed + raise AnnotationImplementationError(attr, type_annotation) + elif origin == Annotated: + # Maybe happy path: annotated member, dispatch recursively + annotation_args = cast(list[Any], get_args(type_annotation)) + if len(annotation_args) > 2: + raise AnnotationImplementationError(attr, type_annotation) + if not isinstance(annotation_args[1], CollectionMember): + raise AnnotationImplementationError(attr, type_annotation) + return CollectionMeta._derive_member_info( + attr, annotation_args[0], annotation_args[1] + ) + elif origin == typing.Union: + # Happy path: optional member + union_args = get_args(type_annotation) + if len(union_args) != 2: + raise AnnotationImplementationError(attr, type_annotation) + if not any(get_origin(arg) is None for arg in union_args): + raise AnnotationImplementationError(attr, type_annotation) + + [not_none_arg] = [arg for arg in union_args if get_origin(arg) is not None] + if not issubclass(get_origin(not_none_arg), TypedLazyFrame): + raise AnnotationImplementationError(attr, type_annotation) + + return MemberInfo( + schema=get_args(not_none_arg)[0], + is_optional=True, + ignored_in_filters=collection_member.ignored_in_filters, + inline_for_sampling=collection_member.inline_for_sampling, + ) + elif issubclass(origin, TypedLazyFrame): + # Happy path: required member + return MemberInfo( + schema=get_args(type_annotation)[0], + is_optional=False, + ignored_in_filters=collection_member.ignored_in_filters, + inline_for_sampling=collection_member.inline_for_sampling, + ) + else: + # Some other unknown annotation + raise AnnotationImplementationError(attr, type_annotation) + class BaseCollection(metaclass=CollectionMeta): """Internal utility abstraction to reference collections without introducing diff --git a/tests/collection/test_implementation.py b/tests/collection/test_implementation.py index 74c4094..014e2ca 100644 --- a/tests/collection/test_implementation.py +++ b/tests/collection/test_implementation.py @@ -81,6 +81,24 @@ def test_annotation_union_conflicting_types_failure() -> None: ) +def test_annotation_annotated_success() -> None: + """When we use an Annotated type, it must accept a union type.""" + create_collection_raw( + "test", + { + "first": Annotated[ + dy.LazyFrame[MyTestSchema] | None, dy.CollectionMember() + ], + }, + ) + create_collection_raw( + "test", + { + "first": dy.LazyFrame[MyTestSchema] | None, + }, + ) + + def test_annotation_only_none_failure() -> None: """Annotations must not just be None.""" with pytest.raises(AnnotationImplementationError): diff --git a/tests/collection/test_sample.py b/tests/collection/test_sample.py index 42d388f..2d7de77 100644 --- a/tests/collection/test_sample.py +++ b/tests/collection/test_sample.py @@ -54,6 +54,21 @@ def _preprocess_sample( return sample +class MyInlinedCollectionWithOptional(dy.Collection): + first: Annotated[ + dy.LazyFrame[MyFirstSchema] | None, + dy.CollectionMember(inline_for_sampling=True), + ] + second: dy.LazyFrame[MySecondSchema] + + @classmethod + def _preprocess_sample( + cls, sample: dict[str, Any], index: int, generator: Generator + ) -> dict[str, Any]: + sample["a"] = index + return sample + + class SmallCollection(dy.Collection): first: dy.LazyFrame[MyFirstSchema] @@ -117,13 +132,20 @@ def test_sample_with_overrides() -> None: assert collection.second.collect()["c"].to_list() == [3, 4, 6] -def test_sample_inline_with_overrides() -> None: - collection = MyInlinedCollection.sample( +@pytest.mark.parametrize( + "collection_type", [MyInlinedCollection, MyInlinedCollectionWithOptional] +) +def test_sample_inline_with_overrides( + collection_type: type[MyInlinedCollection] | type[MyInlinedCollectionWithOptional], +) -> None: + collection = collection_type.sample( overrides=[ {"b": 4, "second": [{"c": 3}, {"c": 4}]}, {"b": 8, "second": [{"c": 6}]}, ] ) + + assert collection.first is not None assert collection.first.collect()["a"].to_list() == [0, 1] assert collection.first.collect()["b"].to_list() == [4, 8]