1515
1616from dataframely ._filter import Filter
1717from dataframely ._polars import FrameType
18+ from dataframely ._typing import DataFrame as TypedDataFrame
1819from dataframely ._typing import LazyFrame as TypedLazyFrame
1920from dataframely .exc import AnnotationImplementationError , ImplementationError
2021from dataframely .schema import Schema
@@ -92,6 +93,8 @@ class MemberInfo(CollectionMember):
9293 schema : type [Schema ]
9394 #: Whether the member is optional.
9495 is_optional : bool
96+ #: Whether the member is a lazy frame.
97+ is_lazy : bool = True
9598
9699
97100@dataclass
@@ -241,39 +244,46 @@ def _derive_member_info(
241244 attr , annotation_args [0 ], annotation_args [1 ]
242245 )
243246 elif origin == typing .Union :
244- # Happy path: optional member
247+ # Happy path: optional member (e.g. dy.LazyFrame[Schema] | None)
245248 union_args = get_args (type_annotation )
246249 if len (union_args ) != 2 :
247250 raise AnnotationImplementationError (attr , type_annotation )
248- if not any (get_origin (arg ) is None for arg in union_args ):
251+ # Check that exactly one arg is None (type(None) is NoneType)
252+ if not any (arg is type (None ) for arg in union_args ):
249253 raise AnnotationImplementationError (attr , type_annotation )
250254
251- not_none_args = [arg for arg in union_args if get_origin (arg ) is not None ]
252- if len (not_none_args ) == 0 or not issubclass (
253- get_origin (not_none_args [0 ]), TypedLazyFrame
254- ):
255+ # Get the non-None type (exactly one exists given prior checks)
256+ not_none_arg = next (arg for arg in union_args if arg is not type (None ))
257+
258+ frame_origin = get_origin (not_none_arg )
259+ if frame_origin is None :
255260 raise AnnotationImplementationError (attr , type_annotation )
256261
257- return MemberInfo (
258- schema = get_args (not_none_args [0 ])[0 ],
259- is_optional = True ,
260- ignored_in_filters = collection_member .ignored_in_filters ,
261- inline_for_sampling = collection_member .inline_for_sampling ,
262- propagate_row_failures = collection_member .propagate_row_failures ,
263- )
264- elif issubclass (origin , TypedLazyFrame ):
265- # Happy path: required member
266- return MemberInfo (
267- schema = get_args (type_annotation )[0 ],
268- is_optional = False ,
269- ignored_in_filters = collection_member .ignored_in_filters ,
270- inline_for_sampling = collection_member .inline_for_sampling ,
271- propagate_row_failures = collection_member .propagate_row_failures ,
272- )
262+ schema = get_args (not_none_arg )[0 ]
263+ is_optional = True
264+ elif issubclass (origin , (TypedLazyFrame , TypedDataFrame )):
265+ frame_origin = origin
266+ schema = get_args (type_annotation )[0 ]
267+ is_optional = False
268+ else :
269+ raise AnnotationImplementationError (attr , type_annotation )
270+
271+ if issubclass (frame_origin , TypedLazyFrame ):
272+ is_lazy = True
273+ elif issubclass (frame_origin , TypedDataFrame ):
274+ is_lazy = False
273275 else :
274- # Some other unknown annotation
275276 raise AnnotationImplementationError (attr , type_annotation )
276277
278+ return MemberInfo (
279+ schema = schema ,
280+ is_optional = is_optional ,
281+ is_lazy = is_lazy ,
282+ ignored_in_filters = collection_member .ignored_in_filters ,
283+ inline_for_sampling = collection_member .inline_for_sampling ,
284+ propagate_row_failures = collection_member .propagate_row_failures ,
285+ )
286+
277287 def __repr__ (cls ) -> str :
278288 parts = [f'[Collection "{ cls .__class__ .__name__ } "]' ]
279289 parts .append (textwrap .indent ("Members:" , prefix = " " * 2 ))
@@ -344,6 +354,16 @@ def non_ignored_members(cls) -> set[str]:
344354 if not member .ignored_in_filters
345355 }
346356
@classmethod
def lazy_members(cls) -> set[str]:
    """The names of all members annotated as lazy frames.

    Returns:
        The names from ``cls.members()`` whose member info has ``is_lazy`` set.
        The set is empty if there is no such member.
    """
    return {name for name, member in cls.members().items() if member.is_lazy}
361+
@classmethod
def eager_members(cls) -> set[str]:
    """The names of all members annotated as data frames (eager).

    Returns:
        The names from ``cls.members()`` whose member info does not have
        ``is_lazy`` set. The set is empty if there is no such member.
    """
    return {name for name, member in cls.members().items() if not member.is_lazy}
366+
347367 @classmethod
348368 def _failure_propagating_members (cls ) -> set [str ]:
349369 """The names of all members of the collection that propagate individual row
@@ -372,9 +392,9 @@ def _filters(cls) -> dict[str, Filter[Self]]:
372392 return getattr (cls , _FILTER_ATTR )
373393
def to_dict(self) -> dict[str, pl.LazyFrame]:
    """Return a dictionary with all members as lazy frames.

    Members whose attribute on this instance is ``None`` are omitted.
    ``.lazy()`` is called on every remaining member, so that the values
    match the annotated ``pl.LazyFrame`` type.

    Returns:
        A mapping from member name to the result of ``.lazy()`` on that member.
    """
    return {
        member: frame.lazy()
        for member in self.member_schemas()
        # Read the attribute once; the filter runs before the value expression.
        if (frame := getattr(self, member)) is not None
    }
@@ -385,6 +405,9 @@ def _init(cls, data: Mapping[str, FrameType], /) -> Self:
385405 for member_name , member in cls .members ().items ():
386406 if member .is_optional and member_name not in data :
387407 setattr (out , member_name , None )
388- else :
408+ elif member . is_lazy :
389409 setattr (out , member_name , data [member_name ].lazy ())
410+ else :
411+ setattr (out , member_name , data [member_name ].lazy ().collect ())
412+
390413 return out
0 commit comments