|
15 | 15 |
|
16 | 16 | from dataframely._filter import Filter |
17 | 17 | from dataframely._polars import FrameType |
| 18 | +from dataframely._typing import DataFrame as TypedDataFrame |
18 | 19 | from dataframely._typing import LazyFrame as TypedLazyFrame |
19 | 20 | from dataframely.exc import AnnotationImplementationError, ImplementationError |
20 | 21 | from dataframely.schema import Schema |
@@ -92,6 +93,8 @@ class MemberInfo(CollectionMember): |
92 | 93 | schema: type[Schema] |
93 | 94 | #: Whether the member is optional. |
94 | 95 | is_optional: bool |
| 96 | + #: Whether the member is a lazy frame. |
| 97 | + is_lazy: bool = True |
95 | 98 |
|
96 | 99 |
|
97 | 100 | @dataclass |
@@ -249,31 +252,38 @@ def _derive_member_info( |
249 | 252 | raise AnnotationImplementationError(attr, type_annotation) |
250 | 253 |
|
251 | 254 | not_none_args = [arg for arg in union_args if get_origin(arg) is not None] |
252 | | - if len(not_none_args) == 0 or not issubclass( |
253 | | - get_origin(not_none_args[0]), TypedLazyFrame |
254 | | - ): |
| 255 | + if len(not_none_args) == 0: |
255 | 256 | raise AnnotationImplementationError(attr, type_annotation) |
256 | 257 |
|
257 | | - return MemberInfo( |
258 | | - schema=get_args(not_none_args[0])[0], |
259 | | - is_optional=True, |
260 | | - ignored_in_filters=collection_member.ignored_in_filters, |
261 | | - inline_for_sampling=collection_member.inline_for_sampling, |
262 | | - propagate_row_failures=collection_member.propagate_row_failures, |
263 | | - ) |
264 | | - elif issubclass(origin, TypedLazyFrame): |
265 | | - # Happy path: required member |
266 | | - return MemberInfo( |
267 | | - schema=get_args(type_annotation)[0], |
268 | | - is_optional=False, |
269 | | - ignored_in_filters=collection_member.ignored_in_filters, |
270 | | - inline_for_sampling=collection_member.inline_for_sampling, |
271 | | - propagate_row_failures=collection_member.propagate_row_failures, |
272 | | - ) |
| 258 | + frame_origin = get_origin(not_none_args[0]) |
| 259 | + if frame_origin is None: |
| 260 | + raise AnnotationImplementationError(attr, type_annotation) |
| 261 | + |
| 262 | + schema = get_args(not_none_args[0])[0] |
| 263 | + is_optional = True |
| 264 | + elif issubclass(origin, (TypedLazyFrame, TypedDataFrame)): |
| 265 | + frame_origin = origin |
| 266 | + schema = get_args(type_annotation)[0] |
| 267 | + is_optional = False |
273 | 268 | else: |
274 | | - # Some other unknown annotation |
275 | 269 | raise AnnotationImplementationError(attr, type_annotation) |
276 | 270 |
|
| 271 | + if issubclass(frame_origin, TypedLazyFrame): |
| 272 | + is_lazy = True |
| 273 | + elif issubclass(frame_origin, TypedDataFrame): |
| 274 | + is_lazy = False |
| 275 | + else: |
| 276 | + raise AnnotationImplementationError(attr, type_annotation) |
| 277 | + |
| 278 | + return MemberInfo( |
| 279 | + schema=schema, |
| 280 | + is_optional=is_optional, |
| 281 | + is_lazy=is_lazy, |
| 282 | + ignored_in_filters=collection_member.ignored_in_filters, |
| 283 | + inline_for_sampling=collection_member.inline_for_sampling, |
| 284 | + propagate_row_failures=collection_member.propagate_row_failures, |
| 285 | + ) |
| 286 | + |
277 | 287 | def __repr__(cls) -> str: |
278 | 288 | parts = [f'[Collection "{cls.__class__.__name__}"]'] |
279 | 289 | parts.append(textwrap.indent("Members:", prefix=" " * 2)) |
@@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]: |
344 | 354 | if not member.ignored_in_filters |
345 | 355 | } |
346 | 356 |
|
| 357 | + @classmethod |
| 358 | + def lazy_members(cls) -> set[str]: |
| 359 | + """The names of all members annotated as lazy frames.""" |
| 360 | + return {name for name, member in cls.members().items() if member.is_lazy} |
| 361 | + |
| 362 | + @classmethod |
| 363 | + def eager_members(cls) -> set[str]: |
| 364 | + """The names of all members annotated as data frames (eager).""" |
| 365 | + return {name for name, member in cls.members().items() if not member.is_lazy} |
| 366 | + |
347 | 367 | @classmethod |
348 | 368 | def _failure_propagating_members(cls) -> set[str]: |
349 | 369 | """The names of all members of the collection that propagate individual row |
@@ -371,20 +391,40 @@ def common_primary_key(cls) -> list[str]: |
371 | 391 | def _filters(cls) -> dict[str, Filter[Self]]: |
372 | 392 | return getattr(cls, _FILTER_ATTR) |
373 | 393 |
|
374 | | - def to_dict(self) -> dict[str, pl.LazyFrame]: |
375 | | - """Return a dictionary representation of this collection.""" |
| 394 | + def to_dict(self) -> dict[str, FrameType]: |
| 395 | + """Return a dictionary representation of this collection. |
| 396 | +
|
| 397 | + Returns: |
| 398 | + A dictionary mapping member names to their frames. |
| 399 | + Members annotated with :class:`~dataframely.DataFrame` return DataFrames, |
| 400 | + while members annotated with :class:`~dataframely.LazyFrame` return LazyFrames. |
| 401 | + """ |
376 | 402 | return { |
377 | 403 | member: getattr(self, member) |
378 | 404 | for member in self.member_schemas() |
379 | 405 | if getattr(self, member) is not None |
380 | 406 | } |
381 | 407 |
|
| 408 | + def _to_lazy_dict(self) -> dict[str, pl.LazyFrame]: |
| 409 | + """Return a dictionary with all members as lazy frames (internal use).""" |
| 410 | + return { |
| 411 | + member: getattr(self, member).lazy() |
| 412 | + for member in self.member_schemas() |
| 413 | + if getattr(self, member) is not None |
| 414 | + } |
| 415 | + |
382 | 416 | @classmethod |
383 | 417 | def _init(cls, data: Mapping[str, FrameType], /) -> Self: |
384 | 418 | out = cls() |
385 | 419 | for member_name, member in cls.members().items(): |
386 | 420 | if member.is_optional and member_name not in data: |
387 | 421 | setattr(out, member_name, None) |
388 | | - else: |
| 422 | + elif member.is_lazy: |
389 | 423 | setattr(out, member_name, data[member_name].lazy()) |
| 424 | + else: |
| 425 | + frame = data[member_name] |
| 426 | + if isinstance(frame, pl.LazyFrame): |
| 427 | + setattr(out, member_name, frame.collect()) |
| 428 | + else: |
| 429 | + setattr(out, member_name, frame) |
390 | 430 | return out |
0 commit comments