Skip to content

Commit be94364

Browse files
gab23rgabrielborchero
authored
feat: Add DataFrame support in dy.Collection (#335)
Co-authored-by: gabriel <gabriel.g.robin@airbus.com> Co-authored-by: Oliver Borchert <me@borchero.com>
1 parent b339296 commit be94364

6 files changed

Lines changed: 300 additions & 63 deletions

File tree

dataframely/collection/_base.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from dataframely._filter import Filter
1717
from dataframely._polars import FrameType
18+
from dataframely._typing import DataFrame as TypedDataFrame
1819
from dataframely._typing import LazyFrame as TypedLazyFrame
1920
from dataframely.exc import AnnotationImplementationError, ImplementationError
2021
from dataframely.schema import Schema
@@ -92,6 +93,8 @@ class MemberInfo(CollectionMember):
9293
schema: type[Schema]
9394
#: Whether the member is optional.
9495
is_optional: bool
96+
#: Whether the member is a lazy frame.
97+
is_lazy: bool = True
9598

9699

97100
@dataclass
@@ -241,39 +244,46 @@ def _derive_member_info(
241244
attr, annotation_args[0], annotation_args[1]
242245
)
243246
elif origin == typing.Union:
244-
# Happy path: optional member
247+
# Happy path: optional member (e.g. dy.LazyFrame[Schema] | None)
245248
union_args = get_args(type_annotation)
246249
if len(union_args) != 2:
247250
raise AnnotationImplementationError(attr, type_annotation)
248-
if not any(get_origin(arg) is None for arg in union_args):
251+
# Check that exactly one arg is None (type(None) is NoneType)
252+
if not any(arg is type(None) for arg in union_args):
249253
raise AnnotationImplementationError(attr, type_annotation)
250254

251-
not_none_args = [arg for arg in union_args if get_origin(arg) is not None]
252-
if len(not_none_args) == 0 or not issubclass(
253-
get_origin(not_none_args[0]), TypedLazyFrame
254-
):
255+
# Get the non-None type (exactly one exists given prior checks)
256+
not_none_arg = next(arg for arg in union_args if arg is not type(None))
257+
258+
frame_origin = get_origin(not_none_arg)
259+
if frame_origin is None:
255260
raise AnnotationImplementationError(attr, type_annotation)
256261

257-
return MemberInfo(
258-
schema=get_args(not_none_args[0])[0],
259-
is_optional=True,
260-
ignored_in_filters=collection_member.ignored_in_filters,
261-
inline_for_sampling=collection_member.inline_for_sampling,
262-
propagate_row_failures=collection_member.propagate_row_failures,
263-
)
264-
elif issubclass(origin, TypedLazyFrame):
265-
# Happy path: required member
266-
return MemberInfo(
267-
schema=get_args(type_annotation)[0],
268-
is_optional=False,
269-
ignored_in_filters=collection_member.ignored_in_filters,
270-
inline_for_sampling=collection_member.inline_for_sampling,
271-
propagate_row_failures=collection_member.propagate_row_failures,
272-
)
262+
schema = get_args(not_none_arg)[0]
263+
is_optional = True
264+
elif issubclass(origin, (TypedLazyFrame, TypedDataFrame)):
265+
frame_origin = origin
266+
schema = get_args(type_annotation)[0]
267+
is_optional = False
268+
else:
269+
raise AnnotationImplementationError(attr, type_annotation)
270+
271+
if issubclass(frame_origin, TypedLazyFrame):
272+
is_lazy = True
273+
elif issubclass(frame_origin, TypedDataFrame):
274+
is_lazy = False
273275
else:
274-
# Some other unknown annotation
275276
raise AnnotationImplementationError(attr, type_annotation)
276277

278+
return MemberInfo(
279+
schema=schema,
280+
is_optional=is_optional,
281+
is_lazy=is_lazy,
282+
ignored_in_filters=collection_member.ignored_in_filters,
283+
inline_for_sampling=collection_member.inline_for_sampling,
284+
propagate_row_failures=collection_member.propagate_row_failures,
285+
)
286+
277287
def __repr__(cls) -> str:
278288
parts = [f'[Collection "{cls.__class__.__name__}"]']
279289
parts.append(textwrap.indent("Members:", prefix=" " * 2))
@@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]:
344354
if not member.ignored_in_filters
345355
}
346356

357+
@classmethod
def lazy_members(cls) -> set[str]:
    """Collect the names of all members that are annotated as lazy frames.

    Returns:
        The set of member names whose member info has ``is_lazy`` set.
    """
    names: set[str] = set()
    for member_name, info in cls.members().items():
        if info.is_lazy:
            names.add(member_name)
    return names
361+
362+
@classmethod
def eager_members(cls) -> set[str]:
    """Collect the names of all members that are annotated as (eager) data frames.

    Returns:
        The set of member names whose member info does not have ``is_lazy`` set.
    """
    names: set[str] = set()
    for member_name, info in cls.members().items():
        if not info.is_lazy:
            names.add(member_name)
    return names
366+
347367
@classmethod
348368
def _failure_propagating_members(cls) -> set[str]:
349369
"""The names of all members of the collection that propagate individual row
@@ -372,9 +392,9 @@ def _filters(cls) -> dict[str, Filter[Self]]:
372392
return getattr(cls, _FILTER_ATTR)
373393

374394
def to_dict(self) -> dict[str, pl.LazyFrame]:
    """Return a dictionary with all members as lazy frames.

    Members whose attribute is ``None`` are left out. Every other member has
    ``.lazy()`` called on it, so the values are lazy frames regardless of how the
    member is stored on the collection.
    """
    frames: dict[str, pl.LazyFrame] = {}
    for name in self.member_schemas():
        frame = getattr(self, name)
        if frame is None:
            # Member is unset: omit it from the result.
            continue
        frames[name] = frame.lazy()
    return frames
@@ -385,6 +405,9 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self:
385405
for member_name, member in cls.members().items():
386406
if member.is_optional and member_name not in data:
387407
setattr(out, member_name, None)
388-
else:
408+
elif member.is_lazy:
389409
setattr(out, member_name, data[member_name].lazy())
410+
else:
411+
setattr(out, member_name, data[member_name].lazy().collect())
412+
390413
return out

dataframely/collection/collection.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from dataframely._storage.constants import COLLECTION_METADATA_KEY
3333
from dataframely._storage.delta import DeltaStorageBackend
3434
from dataframely._storage.parquet import ParquetStorageBackend
35-
from dataframely._typing import LazyFrame, Validation
35+
from dataframely._typing import DataFrame, LazyFrame, Validation
3636
from dataframely.exc import (
3737
DeserializationError,
3838
ValidationError,
@@ -68,13 +68,13 @@ class Collection(BaseCollection, ABC):
6868
to 1-N relationships that are managed in separate data frames.
6969
7070
A collection must only have type annotations for :class:`~dataframely.LazyFrame`
71-
with known schema:
71+
or :class:`~dataframely.DataFrame` with known schema:
7272
7373
.. code:: python
7474
7575
class MyCollection(dy.Collection):
7676
first_member: dy.LazyFrame[MyFirstSchema]
77-
second_member: dy.LazyFrame[MySecondSchema]
77+
second_member: dy.DataFrame[MySecondSchema]
7878
7979
Besides, it may define *filters* (c.f. :meth:`~dataframely.filter`) and arbitrary
8080
methods.
@@ -788,17 +788,14 @@ def collect_all(self) -> Self:
788788
particularly useful when :meth:`filter` is called with lazy frame inputs.
789789
790790
Returns:
791-
The same collection with all members collected once.
792-
793-
Note:
794-
As all collection members are required to be lazy frames, the returned
795-
collection's members are still "lazy". However, they are "shallow-lazy",
796-
meaning they are obtained by calling `.collect().lazy()`.
791+
The same collection with all members collected once. Members annotated
792+
with :class:`~dataframely.DataFrame` are returned as DataFrames, while
793+
members annotated with :class:`~dataframely.LazyFrame` are returned as
794+
"shallow-lazy" frames (obtained by calling ``.collect().lazy()``).
797795
"""
798-
dfs = pl.collect_all(self.to_dict().values())
799-
return self._init(
800-
{key: dfs[i].lazy() for i, key in enumerate(self.to_dict().keys())}
801-
)
796+
lazy_dict = self.to_dict()
797+
dfs = pl.collect_all(lazy_dict.values())
798+
return self._init(dict(zip(lazy_dict, dfs)))
802799

803800
# --------------------------------- SERIALIZATION -------------------------------- #
804801

@@ -842,6 +839,7 @@ def serialize(cls) -> str:
842839
name: {
843840
"schema": info.schema._as_dict(),
844841
"is_optional": info.is_optional,
842+
"is_lazy": info.is_lazy,
845843
"ignored_in_filters": info.ignored_in_filters,
846844
"inline_for_sampling": info.inline_for_sampling,
847845
}
@@ -1330,11 +1328,14 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] |
13301328

13311329
annotations: dict[str, Any] = {}
13321330
for name, info in decoded["members"].items():
1333-
lf_type = LazyFrame[_schema_from_dict(info["schema"])] # type: ignore
1331+
schema = _schema_from_dict(info["schema"])
1332+
# Default to lazy for backwards compatibility with old serialized data
1333+
is_lazy = info.get("is_lazy", True)
1334+
frame_type = LazyFrame[schema] if is_lazy else DataFrame[schema] # type: ignore
13341335
if info["is_optional"]:
1335-
lf_type = lf_type | None # type: ignore
1336+
frame_type = frame_type | None # type: ignore
13361337
annotations[name] = Annotated[
1337-
lf_type,
1338+
frame_type,
13381339
CollectionMember(
13391340
ignored_in_filters=info["ignored_in_filters"],
13401341
inline_for_sampling=info["inline_for_sampling"],

dataframely/collection/filter_result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def collect_all(self, **kwargs: Any) -> CollectionFilterResult[C]:
4040
kwargs: Keyword arguments passed directly to :meth:`polars.collect_all`.
4141
4242
Returns:
43-
The same filter result object with all lazy frames collected and exposed as
43+
The same filter result object with all frames collected. Members annotated
44+
with :class:`~dataframely.DataFrame` are returned as DataFrames, while
45+
members annotated with :class:`~dataframely.LazyFrame` are returned as
4446
"shallow" lazy frames.
4547
4648
Attention:
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright (c) QuantCo 2025-2026
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
"""Tests for dy.DataFrame members in collections.
4+
5+
Members annotated with dy.DataFrame are collected once during _init and stored as
6+
DataFrames, while dy.LazyFrame members remain lazy.
7+
"""
8+
9+
import polars as pl
10+
import pytest
11+
12+
import dataframely as dy
13+
14+
# ------------------------------------------------------------------------------------ #
15+
# SCHEMA #
16+
# ------------------------------------------------------------------------------------ #
17+
18+
19+
class UserSchema(dy.Schema):
    """Schema of the ``users`` member: an integer primary key and a name."""

    id = dy.Integer(primary_key=True)
    name = dy.String()
22+
23+
24+
class OrderSchema(dy.Schema):
    """Schema of the ``orders`` member: primary key, user id and amount."""

    id = dy.Integer(primary_key=True)
    user_id = dy.Integer()
    # Declared with ``min=0``.
    amount = dy.Float(min=0)
28+
29+
30+
class EagerCollection(dy.Collection):
    """Collection whose members are all annotated with ``dy.DataFrame`` (eager)."""

    users: dy.DataFrame[UserSchema]
    orders: dy.DataFrame[OrderSchema]
35+
36+
37+
class MixedCollection(dy.Collection):
    """Collection mixing an eager ``users`` member with a lazy ``orders`` member."""

    users: dy.DataFrame[UserSchema]
    orders: dy.LazyFrame[OrderSchema]
42+
43+
44+
class LazyCollection(dy.Collection):
    """Traditional collection whose members are all annotated with ``dy.LazyFrame``."""

    users: dy.LazyFrame[UserSchema]
    orders: dy.LazyFrame[OrderSchema]
49+
50+
51+
class OptionalEagerCollection(dy.Collection):
    """Collection with a required eager member and an optional eager member."""

    users: dy.DataFrame[UserSchema]
    orders: dy.DataFrame[OrderSchema] | None
56+
57+
58+
# ------------------------------------------------------------------------------------ #
59+
# FIXTURES #
60+
# ------------------------------------------------------------------------------------ #
61+
62+
63+
@pytest.fixture()
def valid_data() -> dict[str, pl.DataFrame]:
    """Provide two-row ``users`` and ``orders`` data frames keyed by member name."""
    users = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
    orders = pl.DataFrame(
        {"id": [1, 2], "user_id": [1, 2], "amount": [10.0, 20.0]}
    )
    return {"users": users, "orders": orders}
71+
72+
73+
# ------------------------------------------------------------------------------------ #
74+
# MEMBER INFO TESTS #
75+
# ------------------------------------------------------------------------------------ #
76+
77+
78+
@pytest.mark.parametrize(
    ("collection_cls", "expected_lazy", "expected_eager"),
    [
        (EagerCollection, set(), {"users", "orders"}),
        (LazyCollection, {"users", "orders"}, set()),
        (MixedCollection, {"orders"}, {"users"}),
        (OptionalEagerCollection, set(), {"users", "orders"}),
    ],
)
def test_member_detection(
    collection_cls: type[dy.Collection],
    expected_lazy: set[str],
    expected_eager: set[str],
) -> None:
    """Check the per-member ``is_lazy`` flag and the lazy/eager name sets."""
    infos = collection_cls.members()

    # Per-member flag must match the annotation kind.
    assert all(infos[name].is_lazy for name in expected_lazy)
    assert not any(infos[name].is_lazy for name in expected_eager)

    # The aggregated accessors must report exactly the expected names.
    assert collection_cls.lazy_members() == expected_lazy
    assert collection_cls.eager_members() == expected_eager
99+
100+
101+
def test_optional_eager_member_detection() -> None:
    """Check that only the ``... | None`` member is reported as optional."""
    infos = OptionalEagerCollection.members()
    users_info = infos["users"]
    orders_info = infos["orders"]

    assert not users_info.is_optional
    assert orders_info.is_optional
105+
106+
107+
# ------------------------------------------------------------------------------------ #
108+
# ACCESS PATTERN TESTS #
109+
# ------------------------------------------------------------------------------------ #
110+
111+
112+
@pytest.mark.parametrize(
    ("collection_cls", "expected_types"),
    [
        (EagerCollection, {"users": pl.DataFrame, "orders": pl.DataFrame}),
        (LazyCollection, {"users": pl.LazyFrame, "orders": pl.LazyFrame}),
        (MixedCollection, {"users": pl.DataFrame, "orders": pl.LazyFrame}),
    ],
)
def test_member_access_returns_correct_type(
    collection_cls: type[dy.Collection],
    expected_types: dict[str, type],
    valid_data: dict[str, pl.DataFrame],
) -> None:
    """Check that each member of a validated collection has the annotated type."""
    validated = collection_cls.validate(valid_data)
    for member_name in expected_types:
        frame = getattr(validated, member_name)
        assert isinstance(frame, expected_types[member_name])

0 commit comments

Comments
 (0)