66import polars as pl
77import pytest
88
9- from diffly ._conditions import _can_compare_dtypes , condition_equal_columns
9+ from diffly ._conditions import (
10+ _can_compare_dtypes ,
11+ _needs_element_wise_comparison ,
12+ condition_equal_columns ,
13+ )
1014from diffly .comparison import compare_frames
1115
1216
@@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None:
512516 assert actual .to_list () == [True , False ]
513517
514518
519+ def test_condition_equal_columns_list_of_different_enums () -> None :
520+ # Arrange
521+ first_enum = pl .Enum (["one" , "two" ])
522+ second_enum = pl .Enum (["one" , "two" , "three" ])
523+
524+ lhs = pl .DataFrame (
525+ {"pk" : [1 , 2 ], "a" : [["one" , "two" ], ["one" , "one" ]]},
526+ schema_overrides = {"a" : pl .List (first_enum )},
527+ )
528+ rhs = pl .DataFrame (
529+ {"pk" : [1 , 2 ], "a" : [["one" , "two" ], ["one" , "three" ]]},
530+ schema_overrides = {"a" : pl .List (second_enum )},
531+ )
532+ c = compare_frames (lhs , rhs , primary_key = "pk" )
533+
534+ # Act
535+ lhs = lhs .rename ({"a" : "a_left" })
536+ rhs = rhs .rename ({"a" : "a_right" })
537+ actual = (
538+ lhs .join (rhs , on = "pk" , maintain_order = "left" )
539+ .select (
540+ condition_equal_columns (
541+ "a" ,
542+ dtype_left = lhs .schema ["a_left" ],
543+ dtype_right = rhs .schema ["a_right" ],
544+ max_list_length = c ._max_list_lengths_by_column .get ("a" ),
545+ abs_tol = c .abs_tol_by_column ["a" ],
546+ rel_tol = c .rel_tol_by_column ["a" ],
547+ )
548+ )
549+ .to_series ()
550+ )
551+
552+ # Assert
553+ assert c ._max_list_lengths_by_column == {"a" : 2 }
554+ assert _needs_element_wise_comparison (first_enum , second_enum )
555+ assert actual .to_list () == [True , False ]
556+
557+
515558@pytest .mark .parametrize (
516559 ("dtype_left" , "dtype_right" , "can_compare_dtypes" ),
517560 [
@@ -534,3 +577,73 @@ def test_can_compare_dtypes(
534577 dtype_left = dtype_left , dtype_right = dtype_right
535578 )
536579 assert can_compare_dtypes_actual == can_compare_dtypes
580+
581+
582+ @pytest .mark .parametrize (
583+ ("dtype_left" , "dtype_right" , "expected" ),
584+ [
585+ # Primitives that don't need element-wise comparison
586+ (pl .Int64 , pl .Int64 , False ),
587+ (pl .String , pl .String , False ),
588+ (pl .Boolean , pl .Boolean , False ),
589+ # Float/numeric pairs
590+ (pl .Float64 , pl .Float64 , True ),
591+ (pl .Int64 , pl .Float64 , True ),
592+ (pl .Float32 , pl .Int32 , True ),
593+ # Temporal pairs
594+ (pl .Datetime , pl .Datetime , True ),
595+ (pl .Date , pl .Date , True ),
596+ (pl .Datetime , pl .Date , True ),
597+ # Enum/categorical
598+ (pl .Enum (["a" , "b" ]), pl .Enum (["a" , "b" ]), False ),
599+ (pl .Enum (["a" , "b" ]), pl .Enum (["a" , "b" , "c" ]), True ),
600+ (pl .Enum (["a" ]), pl .Categorical (), True ),
601+ (pl .Categorical (), pl .Enum (["a" ]), True ),
602+ # Struct with no tolerance-requiring fields
603+ (
604+ pl .Struct ({"x" : pl .Int64 , "y" : pl .String }),
605+ pl .Struct ({"x" : pl .Int64 , "y" : pl .String }),
606+ False ,
607+ ),
608+ # Struct with a float field
609+ (
610+ pl .Struct ({"x" : pl .Int64 , "y" : pl .Float64 }),
611+ pl .Struct ({"x" : pl .Int64 , "y" : pl .Float64 }),
612+ True ,
613+ ),
614+ # Struct with different-category enums
615+ (
616+ pl .Struct ({"x" : pl .Enum (["a" ])}),
617+ pl .Struct ({"x" : pl .Enum (["b" ])}),
618+ True ,
619+ ),
620+ # List/Array with non-tolerance inner type
621+ (pl .List (pl .Int64 ), pl .List (pl .Int64 ), False ),
622+ (pl .Array (pl .String , shape = 3 ), pl .Array (pl .String , shape = 3 ), False ),
623+ # List/Array with tolerance-requiring inner type
624+ (pl .List (pl .Float64 ), pl .List (pl .Float64 ), True ),
625+ (pl .Array (pl .Datetime , shape = 2 ), pl .Array (pl .Datetime , shape = 2 ), True ),
626+ # Nested: list of structs with a float field
627+ (
628+ pl .List (pl .Struct ({"x" : pl .Float64 })),
629+ pl .List (pl .Struct ({"x" : pl .Float64 })),
630+ True ,
631+ ),
632+ # Nested: list of structs without tolerance-requiring fields
633+ (
634+ pl .List (pl .Struct ({"x" : pl .Int64 })),
635+ pl .List (pl .Struct ({"x" : pl .Int64 })),
636+ False ,
637+ ),
638+ # Deeply nested: struct with a list of structs with a float field
639+ (
640+ pl .List (pl .Struct ({"x" : pl .String , "y" : pl .List (pl .Float64 )})),
641+ pl .List (pl .Struct ({"x" : pl .String , "y" : pl .List (pl .Float64 )})),
642+ True ,
643+ ),
644+ ],
645+ )
646+ def test_needs_element_wise_comparison (
647+ dtype_left : pl .DataType , dtype_right : pl .DataType , expected : bool
648+ ) -> None :
649+ assert _needs_element_wise_comparison (dtype_left , dtype_right ) == expected
0 commit comments