Skip to content

Commit 51ced99

Browse files
author
gabriel
committed
feat: Add DataFrame support in Collection
1 parent b993da6 commit 51ced99

5 files changed

Lines changed: 242 additions & 48 deletions

File tree

dataframely/collection/_base.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from dataframely._filter import Filter
1717
from dataframely._polars import FrameType
18+
from dataframely._typing import DataFrame as TypedDataFrame
1819
from dataframely._typing import LazyFrame as TypedLazyFrame
1920
from dataframely.exc import AnnotationImplementationError, ImplementationError
2021
from dataframely.schema import Schema
@@ -92,6 +93,8 @@ class MemberInfo(CollectionMember):
9293
schema: type[Schema]
9394
#: Whether the member is optional.
9495
is_optional: bool
96+
#: Whether the member is a lazy frame.
97+
is_lazy: bool = True
9598

9699

97100
@dataclass
@@ -249,31 +252,38 @@ def _derive_member_info(
249252
raise AnnotationImplementationError(attr, type_annotation)
250253

251254
not_none_args = [arg for arg in union_args if get_origin(arg) is not None]
252-
if len(not_none_args) == 0 or not issubclass(
253-
get_origin(not_none_args[0]), TypedLazyFrame
254-
):
255+
if len(not_none_args) == 0:
255256
raise AnnotationImplementationError(attr, type_annotation)
256257

257-
return MemberInfo(
258-
schema=get_args(not_none_args[0])[0],
259-
is_optional=True,
260-
ignored_in_filters=collection_member.ignored_in_filters,
261-
inline_for_sampling=collection_member.inline_for_sampling,
262-
propagate_row_failures=collection_member.propagate_row_failures,
263-
)
264-
elif issubclass(origin, TypedLazyFrame):
265-
# Happy path: required member
266-
return MemberInfo(
267-
schema=get_args(type_annotation)[0],
268-
is_optional=False,
269-
ignored_in_filters=collection_member.ignored_in_filters,
270-
inline_for_sampling=collection_member.inline_for_sampling,
271-
propagate_row_failures=collection_member.propagate_row_failures,
272-
)
258+
frame_origin = get_origin(not_none_args[0])
259+
if frame_origin is None:
260+
raise AnnotationImplementationError(attr, type_annotation)
261+
262+
schema = get_args(not_none_args[0])[0]
263+
is_optional = True
264+
elif issubclass(origin, (TypedLazyFrame, TypedDataFrame)):
265+
frame_origin = origin
266+
schema = get_args(type_annotation)[0]
267+
is_optional = False
273268
else:
274-
# Some other unknown annotation
275269
raise AnnotationImplementationError(attr, type_annotation)
276270

271+
if issubclass(frame_origin, TypedLazyFrame):
272+
is_lazy = True
273+
elif issubclass(frame_origin, TypedDataFrame):
274+
is_lazy = False
275+
else:
276+
raise AnnotationImplementationError(attr, type_annotation)
277+
278+
return MemberInfo(
279+
schema=schema,
280+
is_optional=is_optional,
281+
is_lazy=is_lazy,
282+
ignored_in_filters=collection_member.ignored_in_filters,
283+
inline_for_sampling=collection_member.inline_for_sampling,
284+
propagate_row_failures=collection_member.propagate_row_failures,
285+
)
286+
277287
def __repr__(cls) -> str:
278288
parts = [f'[Collection "{cls.__class__.__name__}"]']
279289
parts.append(textwrap.indent("Members:", prefix=" " * 2))
@@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]:
344354
if not member.ignored_in_filters
345355
}
346356

357+
@classmethod
358+
def lazy_members(cls) -> set[str]:
359+
"""The names of all members annotated as lazy frames."""
360+
return {name for name, member in cls.members().items() if member.is_lazy}
361+
362+
@classmethod
363+
def eager_members(cls) -> set[str]:
364+
"""The names of all members annotated as data frames (eager)."""
365+
return {name for name, member in cls.members().items() if not member.is_lazy}
366+
347367
@classmethod
348368
def _failure_propagating_members(cls) -> set[str]:
349369
"""The names of all members of the collection that propagate individual row
@@ -371,20 +391,40 @@ def common_primary_key(cls) -> list[str]:
371391
def _filters(cls) -> dict[str, Filter[Self]]:
372392
return getattr(cls, _FILTER_ATTR)
373393

374-
def to_dict(self) -> dict[str, pl.LazyFrame]:
375-
"""Return a dictionary representation of this collection."""
394+
def to_dict(self) -> dict[str, FrameType]:
395+
"""Return a dictionary representation of this collection.
396+
397+
Returns:
398+
A dictionary mapping member names to their frames.
399+
Members annotated with :class:`~dataframely.DataFrame` return DataFrames,
400+
while members annotated with :class:`~dataframely.LazyFrame` return LazyFrames.
401+
"""
376402
return {
377403
member: getattr(self, member)
378404
for member in self.member_schemas()
379405
if getattr(self, member) is not None
380406
}
381407

408+
def _to_lazy_dict(self) -> dict[str, pl.LazyFrame]:
409+
"""Return a dictionary with all members as lazy frames (internal use)."""
410+
return {
411+
member: getattr(self, member).lazy()
412+
for member in self.member_schemas()
413+
if getattr(self, member) is not None
414+
}
415+
382416
@classmethod
383417
def _init(cls, data: Mapping[str, FrameType], /) -> Self:
384418
out = cls()
385419
for member_name, member in cls.members().items():
386420
if member.is_optional and member_name not in data:
387421
setattr(out, member_name, None)
388-
else:
422+
elif member.is_lazy:
389423
setattr(out, member_name, data[member_name].lazy())
424+
else:
425+
frame = data[member_name]
426+
if isinstance(frame, pl.LazyFrame):
427+
setattr(out, member_name, frame.collect())
428+
else:
429+
setattr(out, member_name, frame)
390430
return out

dataframely/collection/collection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def join(
735735
how=how,
736736
maintain_order=maintain_order,
737737
)
738-
for key, lf in self.to_dict().items()
738+
for key, lf in self._to_lazy_dict().items()
739739
}
740740
)
741741

@@ -795,9 +795,10 @@ def collect_all(self) -> Self:
795795
collection's members are still "lazy". However, they are "shallow-lazy",
796796
meaning they are obtained by calling `.collect().lazy()`.
797797
"""
798-
dfs = pl.collect_all(self.to_dict().values())
798+
lazy_dict = self._to_lazy_dict()
799+
dfs = pl.collect_all(lazy_dict.values())
799800
return self._init(
800-
{key: dfs[i].lazy() for i, key in enumerate(self.to_dict().keys())}
801+
{key: dfs[i].lazy() for i, key in enumerate(lazy_dict.keys())}
801802
)
802803

803804
# --------------------------------- SERIALIZATION -------------------------------- #
@@ -1172,7 +1173,7 @@ def _write(self, backend: StorageBackend, **kwargs: Any) -> None:
11721173
# Utility method encapsulating the interaction with the StorageBackend
11731174

11741175
backend.write_collection(
1175-
self.to_dict(),
1176+
self._to_lazy_dict(),
11761177
serialized_collection=self.serialize(),
11771178
serialized_schemas={
11781179
key: schema.serialize() for key, schema in self.member_schemas().items()
@@ -1184,7 +1185,7 @@ def _sink(self, backend: StorageBackend, **kwargs: Any) -> None:
11841185
# Utility method encapsulating the interaction with the StorageBackend
11851186

11861187
backend.sink_collection(
1187-
self.to_dict(),
1188+
self._to_lazy_dict(),
11881189
serialized_collection=self.serialize(),
11891190
serialized_schemas={
11901191
key: schema.serialize() for key, schema in self.member_schemas().items()

dataframely/collection/filter_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def collect_all(self, **kwargs: Any) -> CollectionFilterResult[C]:
4747
Until https://github.com/pola-rs/polars/pull/24129 is released, the
4848
performance advantage of this method is limited.
4949
"""
50-
members = self.result.to_dict()
50+
members = self.result._to_lazy_dict()
5151
collected = pl.collect_all(
5252
itertools.chain(
5353
members.values(),
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) QuantCo 2025-2026
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
"""Tests for dy.DataFrame members in collections.
5+
6+
Members annotated with dy.DataFrame are collected once during _init and stored
7+
as DataFrames, while dy.LazyFrame members remain lazy.
8+
"""
9+
10+
import polars as pl
11+
import pytest
12+
13+
import dataframely as dy
14+
15+
# ------------------------------------------------------------------------------------ #
16+
# SCHEMA #
17+
# ------------------------------------------------------------------------------------ #
18+
19+
20+
class UserSchema(dy.Schema):
21+
id = dy.Integer(primary_key=True)
22+
name = dy.String()
23+
24+
25+
class OrderSchema(dy.Schema):
26+
id = dy.Integer(primary_key=True)
27+
user_id = dy.Integer()
28+
amount = dy.Float(min=0)
29+
30+
31+
class EagerCollection(dy.Collection):
32+
"""Collection with only DataFrame (eager) members."""
33+
34+
users: dy.DataFrame[UserSchema]
35+
orders: dy.DataFrame[OrderSchema]
36+
37+
38+
class MixedCollection(dy.Collection):
39+
"""Collection with mixed DataFrame and LazyFrame members."""
40+
41+
users: dy.DataFrame[UserSchema]
42+
orders: dy.LazyFrame[OrderSchema]
43+
44+
45+
class LazyCollection(dy.Collection):
46+
"""Collection with only LazyFrame members (traditional)."""
47+
48+
users: dy.LazyFrame[UserSchema]
49+
orders: dy.LazyFrame[OrderSchema]
50+
51+
52+
class OptionalEagerCollection(dy.Collection):
53+
"""Collection with optional DataFrame member."""
54+
55+
users: dy.DataFrame[UserSchema]
56+
orders: dy.DataFrame[OrderSchema] | None
57+
58+
59+
# ------------------------------------------------------------------------------------ #
60+
# FIXTURES #
61+
# ------------------------------------------------------------------------------------ #
62+
63+
64+
@pytest.fixture()
65+
def valid_data() -> dict[str, pl.DataFrame]:
66+
return {
67+
"users": pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}),
68+
"orders": pl.DataFrame(
69+
{"id": [1, 2], "user_id": [1, 2], "amount": [10.0, 20.0]}
70+
),
71+
}
72+
73+
74+
# ------------------------------------------------------------------------------------ #
75+
# MEMBER INFO TESTS #
76+
# ------------------------------------------------------------------------------------ #
77+
78+
79+
def test_eager_member_detection() -> None:
80+
members = EagerCollection.members()
81+
assert not members["users"].is_lazy
82+
assert not members["orders"].is_lazy
83+
84+
85+
def test_lazy_member_detection() -> None:
86+
members = LazyCollection.members()
87+
assert members["users"].is_lazy
88+
assert members["orders"].is_lazy
89+
90+
91+
def test_mixed_member_detection() -> None:
92+
members = MixedCollection.members()
93+
assert not members["users"].is_lazy
94+
assert members["orders"].is_lazy
95+
96+
97+
def test_optional_eager_member_detection() -> None:
98+
members = OptionalEagerCollection.members()
99+
assert not members["users"].is_lazy
100+
assert not members["orders"].is_lazy
101+
assert not members["users"].is_optional
102+
assert members["orders"].is_optional
103+
104+
105+
def test_lazy_members_helper() -> None:
106+
assert EagerCollection.lazy_members() == set()
107+
assert LazyCollection.lazy_members() == {"users", "orders"}
108+
assert MixedCollection.lazy_members() == {"orders"}
109+
110+
111+
def test_eager_members_helper() -> None:
112+
assert EagerCollection.eager_members() == {"users", "orders"}
113+
assert LazyCollection.eager_members() == set()
114+
assert MixedCollection.eager_members() == {"users"}
115+
116+
117+
# ------------------------------------------------------------------------------------ #
118+
# ACCESS PATTERN TESTS #
119+
# ------------------------------------------------------------------------------------ #
120+
121+
122+
def test_eager_member_returns_dataframe(valid_data: dict[str, pl.DataFrame]) -> None:
123+
collection = EagerCollection.validate(valid_data)
124+
assert isinstance(collection.users, pl.DataFrame)
125+
assert isinstance(collection.orders, pl.DataFrame)
126+
127+
128+
def test_lazy_member_returns_lazyframe(valid_data: dict[str, pl.DataFrame]) -> None:
129+
collection = LazyCollection.validate(valid_data)
130+
assert isinstance(collection.users, pl.LazyFrame)
131+
assert isinstance(collection.orders, pl.LazyFrame)
132+
133+
134+
def test_mixed_collection_returns_correct_types(
135+
valid_data: dict[str, pl.DataFrame],
136+
) -> None:
137+
collection = MixedCollection.validate(valid_data)
138+
assert isinstance(collection.users, pl.DataFrame)
139+
assert isinstance(collection.orders, pl.LazyFrame)
140+
141+
142+
def test_to_dict_returns_correct_types(valid_data: dict[str, pl.DataFrame]) -> None:
143+
eager = EagerCollection.validate(valid_data)
144+
result = eager.to_dict()
145+
assert isinstance(result["users"], pl.DataFrame)
146+
assert isinstance(result["orders"], pl.DataFrame)
147+
148+
mixed = MixedCollection.validate(valid_data)
149+
result = mixed.to_dict()
150+
assert isinstance(result["users"], pl.DataFrame)
151+
assert isinstance(result["orders"], pl.LazyFrame)

tests/collection/test_implementation.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@ class MyTestSchema(dy.Schema):
1616
a = dy.Integer(primary_key=True)
1717

1818

19-
def test_annotation_type_failure() -> None:
20-
with pytest.raises(
21-
AnnotationImplementationError,
22-
):
23-
create_collection(
24-
"test",
25-
{
26-
"first": create_schema("first", {"a": dy.Integer()}),
27-
},
28-
annotation_base_class=dy.DataFrame,
29-
)
19+
def test_annotation_dataframe_success() -> None:
20+
"""DataFrame annotations are now supported."""
21+
collection = create_collection(
22+
"test",
23+
{
24+
"first": create_schema("first", {"a": dy.Integer()}),
25+
},
26+
annotation_base_class=dy.DataFrame,
27+
)
28+
members = collection.members()
29+
assert not members["first"].is_lazy
3030

3131

3232
def test_annotation_union_success() -> None:
@@ -40,14 +40,16 @@ def test_annotation_union_success() -> None:
4040

4141

4242
def test_annotation_union_with_data_frame() -> None:
43-
"""When we use a union annotation, it must contain one typed LazyFrame and None."""
44-
with pytest.raises(AnnotationImplementationError):
45-
create_collection_raw(
46-
"test",
47-
{
48-
"first": dy.DataFrame[MyTestSchema] | None,
49-
},
50-
)
43+
"""DataFrame union with None is now supported for optional eager members."""
44+
collection = create_collection_raw(
45+
"test",
46+
{
47+
"first": dy.DataFrame[MyTestSchema] | None,
48+
},
49+
)
50+
members = collection.members()
51+
assert not members["first"].is_lazy
52+
assert members["first"].is_optional
5153

5254

5355
def test_annotation_union_too_many_arg_failure() -> None:

0 commit comments

Comments
 (0)