Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit cd5579b

Browse files
fix isin logic
1 parent 64d5ce9 commit cd5579b

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,23 @@
3333
@register_unary_op(ops.IsInOp, pass_op=True)
3434
def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
3535
values = []
36+
# BOOL is not comparable to numeric types in SQL, so when the expression or any of the values is a non-bool numeric, bool values and a bool expression are cast to INT64.
37+
must_upcast_bools = dtypes.is_numeric(expr.dtype, include_bool=False) or any(
38+
dtypes.is_numeric(dtypes.bigframes_type(type(value)), include_bool=False)
39+
for value in op.values
40+
)
3641
for value in op.values:
3742
if _is_null(value):
3843
continue
3944
dtype = dtypes.bigframes_type(type(value))
4045
if dtypes.can_compare(expr.dtype, dtype):
46+
if must_upcast_bools and dtype == dtypes.BOOL_DTYPE:
47+
value = int(value)
4148
values.append(sge.convert(value))
4249

50+
if expr.dtype == dtypes.BOOL_DTYPE and must_upcast_bools:
51+
expr = TypedExpr(sge.cast(expr.expr, "INT64"), dtypes.INT_DTYPE)
52+
4353
if op.match_nulls:
4454
contains_nulls = any(_is_null(value) for value in op.values)
4555
if contains_nulls:

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,26 @@ def isin_join(
351351

352352
new_column: sge.Expression
353353
if joins_nulls:
354+
force_float_domain = False
355+
if (
356+
conditions[0].dtype == dtypes.FLOAT_DTYPE
357+
or conditions[1].dtype == dtypes.FLOAT_DTYPE
358+
):
359+
force_float_domain = True
354360
part1_id = sql.identifier(next(self.uid_gen.get_uid_stream("bfpart1_")))
355361
part2_id = sql.identifier(next(self.uid_gen.get_uid_stream("bfpart2_")))
356-
left_expr1, left_expr2 = _value_to_non_null_identity(conditions[0])
362+
left_expr1, left_expr2 = _value_to_non_null_identity(
363+
conditions[0], force_float_domain
364+
)
357365
left_as_struct = sge.Struct(
358366
expressions=[
359367
sge.PropertyEQ(this=part1_id, expression=left_expr1),
360368
sge.PropertyEQ(this=part2_id, expression=left_expr2),
361369
]
362370
)
363-
right_expr1, right_expr2 = _value_to_non_null_identity(conditions[1])
371+
right_expr1, right_expr2 = _value_to_non_null_identity(
372+
conditions[1], force_float_domain
373+
)
364374
right_select = right.expr.select(
365375
*[
366376
sge.Struct(
@@ -593,27 +603,36 @@ def _join_condition(
593603
"""
594604
if not joins_nulls:
595605
return sge.EQ(this=left.expr, expression=right.expr)
596-
left_expr1, left_expr2 = _value_to_non_null_identity(left)
597-
right_expr1, right_expr2 = _value_to_non_null_identity(right)
606+
607+
force_float_domain = False
608+
if left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE:
609+
force_float_domain = True
610+
left_expr1, left_expr2 = _value_to_non_null_identity(left, force_float_domain)
611+
right_expr1, right_expr2 = _value_to_non_null_identity(right, force_float_domain)
598612
return sge.And(
599613
this=sge.EQ(this=left_expr1, expression=right_expr1),
600614
expression=sge.EQ(this=left_expr2, expression=right_expr2),
601615
)
602616

603617

604618
def _value_to_non_null_identity(
605-
value: typed_expr.TypedExpr,
619+
value: typed_expr.TypedExpr, force_float_domain: bool = False
606620
) -> tuple[sge.Expression, sge.Expression]:
607621
# normal_value -> (normal_value, normal_value)
608622
# null_value -> (0, 1)
609623
# nan_value -> (2, 3)
610624
if dtypes.is_numeric(value.dtype, include_bool=False):
611-
expr1 = sge.func("COALESCE", value.expr, sql.literal(0, value.dtype))
612-
expr2 = sge.func("COALESCE", value.expr, sql.literal(1, value.dtype))
625+
dtype = dtypes.FLOAT_DTYPE if force_float_domain else value.dtype
626+
expr1 = sge.func(
627+
"COALESCE", value.expr, sql.literal(0.0 if force_float_domain else 0, dtype)
628+
)
629+
expr2 = sge.func(
630+
"COALESCE", value.expr, sql.literal(1.0 if force_float_domain else 1, dtype)
631+
)
613632
if value.dtype == dtypes.FLOAT_DTYPE:
614633
expr1 = sge.If(
615634
this=sge.IsNan(this=value.expr),
616-
true=sql.literal(2, value.dtype),
635+
true=sql.literal(2.0, value.dtype),
617636
false=expr1,
618637
)
619638
expr2 = sge.If(

0 commit comments

Comments
 (0)