@@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
711711
@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
    """Return the longest list length found per column, over every nesting level.

    Only columns from ``self._other_common_columns`` whose left *and* right
    dtypes yield at least one list-length expression (i.e. contain a ``List``
    somewhere in their type tree) are reported. An empty dict is returned when
    no column qualifies.
    """
    # One entry per qualifying column: (name, left expression, right expression).
    selected: list[tuple[str, pl.Expr, pl.Expr]] = []

    for name in self._other_common_columns:
        lengths_left = _list_length_exprs(pl.col(name), self.left_schema[name])
        lengths_right = _list_length_exprs(pl.col(name), self.right_schema[name])
        if not lengths_left or not lengths_right:
            # At least one side has no list level for this column; skip it.
            continue
        # Collapse the per-level maxima into a single scalar named after the column.
        selected.append(
            (
                name,
                pl.max_horizontal(lengths_left).alias(name),
                pl.max_horizontal(lengths_right).alias(name),
            )
        )

    if not selected:
        return {}

    names = [entry[0] for entry in selected]
    left_selection = [entry[1] for entry in selected]
    right_selection = [entry[2] for entry in selected]

    # Both selections are handed to a single collect_all call.
    left_max, right_max = pl.collect_all(
        [self.left.select(left_selection), self.right.select(right_selection)]
    )

    result: dict[str, int] = {}
    for name in names:
        # A missing (None) maximum is treated as length 0.
        left_len = int(left_max[name].item() or 0)
        right_len = int(right_max[name].item() or 0)
        result[name] = max(left_len, right_len)
    return result
731739
732740 def _condition_equal_rows (self , columns : list [str ]) -> pl .Expr :
@@ -833,3 +841,21 @@ def right_only(self) -> Schema:
833841 """Columns that are only present in the right data frame, mapped to their data
834842 types."""
835843 return self .right () - self .left ()
844+
845+
def _list_length_exprs(
    expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
) -> list[pl.Expr]:
    """Build one max-list-length expression per ``List`` level inside ``dtype``.

    The type tree is walked recursively: a ``List`` contributes its own
    ``list.len().max()`` and then descends into its exploded inner type, an
    ``Array`` only descends into its exploded inner type, and a ``Struct``
    descends into each of its fields in order. Any other dtype contributes
    nothing, so the result is empty when no ``List`` occurs.
    """
    found: list[pl.Expr] = []

    if isinstance(dtype, pl.List):
        # This level's own length comes first, followed by any nested levels.
        found.append(expr.list.len().max())
        found.extend(_list_length_exprs(expr.explode(), dtype.inner))
    elif isinstance(dtype, pl.Array):
        # The array itself adds no length expression; only its inner type is searched.
        found.extend(_list_length_exprs(expr.explode(), dtype.inner))
    elif isinstance(dtype, pl.Struct):
        for member in dtype.fields:
            found.extend(_list_length_exprs(expr.struct[member.name], member.dtype))

    return found
0 commit comments