@@ -351,16 +351,26 @@ def isin_join(
351351
352352 new_column : sge .Expression
353353 if joins_nulls :
354+ force_float_domain = False
355+ if (
356+ conditions [0 ].dtype == dtypes .FLOAT_DTYPE
357+ or conditions [1 ].dtype == dtypes .FLOAT_DTYPE
358+ ):
359+ force_float_domain = True
354360 part1_id = sql .identifier (next (self .uid_gen .get_uid_stream ("bfpart1_" )))
355361 part2_id = sql .identifier (next (self .uid_gen .get_uid_stream ("bfpart2_" )))
356- left_expr1 , left_expr2 = _value_to_non_null_identity (conditions [0 ])
362+ left_expr1 , left_expr2 = _value_to_non_null_identity (
363+ conditions [0 ], force_float_domain
364+ )
357365 left_as_struct = sge .Struct (
358366 expressions = [
359367 sge .PropertyEQ (this = part1_id , expression = left_expr1 ),
360368 sge .PropertyEQ (this = part2_id , expression = left_expr2 ),
361369 ]
362370 )
363- right_expr1 , right_expr2 = _value_to_non_null_identity (conditions [1 ])
371+ right_expr1 , right_expr2 = _value_to_non_null_identity (
372+ conditions [1 ], force_float_domain
373+ )
364374 right_select = right .expr .select (
365375 * [
366376 sge .Struct (
@@ -593,27 +603,36 @@ def _join_condition(
593603 """
594604 if not joins_nulls :
595605 return sge .EQ (this = left .expr , expression = right .expr )
596- left_expr1 , left_expr2 = _value_to_non_null_identity (left )
597- right_expr1 , right_expr2 = _value_to_non_null_identity (right )
606+
607+ force_float_domain = False
608+ if left .dtype == dtypes .FLOAT_DTYPE or right .dtype == dtypes .FLOAT_DTYPE :
609+ force_float_domain = True
610+ left_expr1 , left_expr2 = _value_to_non_null_identity (left , force_float_domain )
611+ right_expr1 , right_expr2 = _value_to_non_null_identity (right , force_float_domain )
598612 return sge .And (
599613 this = sge .EQ (this = left_expr1 , expression = right_expr1 ),
600614 expression = sge .EQ (this = left_expr2 , expression = right_expr2 ),
601615 )
602616
603617
604618def _value_to_non_null_identity (
605- value : typed_expr .TypedExpr ,
619+ value : typed_expr .TypedExpr , force_float_domain : bool = False
606620) -> tuple [sge .Expression , sge .Expression ]:
607621 # normal_value -> (normal_value, normal_value)
608622 # null_value -> (0, 1)
609623 # nan_value -> (2, 3)
610624 if dtypes .is_numeric (value .dtype , include_bool = False ):
611- expr1 = sge .func ("COALESCE" , value .expr , sql .literal (0 , value .dtype ))
612- expr2 = sge .func ("COALESCE" , value .expr , sql .literal (1 , value .dtype ))
625+ dtype = dtypes .FLOAT_DTYPE if force_float_domain else value .dtype
626+ expr1 = sge .func (
627+ "COALESCE" , value .expr , sql .literal (0.0 if force_float_domain else 0 , dtype )
628+ )
629+ expr2 = sge .func (
630+ "COALESCE" , value .expr , sql .literal (1.0 if force_float_domain else 1 , dtype )
631+ )
613632 if value .dtype == dtypes .FLOAT_DTYPE :
614633 expr1 = sge .If (
615634 this = sge .IsNan (this = value .expr ),
616- true = sql .literal (2 , value .dtype ),
635+ true = sql .literal (2.0 , value .dtype ),
617636 false = expr1 ,
618637 )
619638 expr2 = sge .If (
0 commit comments