Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit c7439e4

Browse files
fix more tests
1 parent b82c937 commit c7439e4

7 files changed

Lines changed: 128 additions & 11 deletions

File tree

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def arctanh_op_impl(x: ibis_types.Value):
169169
@scalar_op_compiler.register_unary_op(ops.floor_op)
170170
def floor_op_impl(x: ibis_types.Value):
171171
x_numeric = typing.cast(ibis_types.NumericValue, x)
172+
if x_numeric.type().is_boolean():
173+
return x_numeric.cast(ibis_dtypes.Int64()).cast(ibis_dtypes.Float64())
172174
if x_numeric.type().is_integer():
173175
return x_numeric.cast(ibis_dtypes.Float64())
174176
if x_numeric.type().is_floating():
@@ -181,6 +183,8 @@ def floor_op_impl(x: ibis_types.Value):
181183
@scalar_op_compiler.register_unary_op(ops.ceil_op)
182184
def ceil_op_impl(x: ibis_types.Value):
183185
x_numeric = typing.cast(ibis_types.NumericValue, x)
186+
if x_numeric.type().is_boolean():
187+
return x_numeric.cast(ibis_dtypes.Int64()).cast(ibis_dtypes.Float64())
184188
if x_numeric.type().is_integer():
185189
return x_numeric.cast(ibis_dtypes.Float64())
186190
if x_numeric.type().is_floating():

bigframes/core/compile/polars/lowering.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,10 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
174174
divisor.output_type
175175
):
176176
# exact same as floordiv impl for timedelta
177-
numeric_result = ops.floordiv_op.as_expr(
177+
numeric_result = ops.div_op.as_expr(
178178
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor
179179
)
180-
int_result = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(numeric_result)
181-
return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result)
182-
180+
return _numeric_to_timedelta(numeric_result)
183181
if (
184182
dividend.output_type == dtypes.BOOL_DTYPE
185183
and divisor.output_type == dtypes.BOOL_DTYPE
@@ -226,11 +224,10 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
226224
divisor.output_type
227225
):
228226
# this is pretty fragile as zero will break it, and must fit back into int
229-
numeric_result = expr.op.as_expr(
227+
numeric_result = ops.div_op.as_expr(
230228
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend), divisor
231229
)
232-
int_result = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(numeric_result)
233-
return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(int_result)
230+
return _numeric_to_timedelta(numeric_result)
234231

235232
if dividend.output_type == dtypes.BOOL_DTYPE:
236233
dividend = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(dividend)
@@ -319,6 +316,32 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
319316
return expr
320317

321318

319+
class LowerCeilOp(op_lowering.OpLoweringRule):
    """Lowering rule that feeds ``CeilOp`` a float operand.

    Integer and boolean operands are wrapped in a cast to ``FLOAT_DTYPE``
    before the op is re-applied; any other operand type is left untouched.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The scalar op class this rule is registered for."""
        return numeric_ops.CeilOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return ``expr`` with an int/bool operand cast to float, else ``expr``."""
        assert isinstance(expr.op, numeric_ops.CeilOp)
        operand = expr.children[0]
        castable_types = (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE)
        if operand.output_type not in castable_types:
            # Nothing to rewrite for operands that are not int/bool.
            return expr
        float_operand = ops.AsTypeOp(dtypes.FLOAT_DTYPE).as_expr(operand)
        return expr.op.as_expr(float_operand)
330+
331+
332+
class LowerFloorOp(op_lowering.OpLoweringRule):
    """Lowering rule that feeds ``FloorOp`` a float operand.

    Integer and boolean operands are wrapped in a cast to ``FLOAT_DTYPE``
    before the op is re-applied; any other operand type is left untouched.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The scalar op class this rule is registered for."""
        return numeric_ops.FloorOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return ``expr`` with an int/bool operand cast to float, else ``expr``."""
        assert isinstance(expr.op, numeric_ops.FloorOp)
        operand = expr.children[0]
        castable_types = (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE)
        if operand.output_type not in castable_types:
            # Nothing to rewrite for operands that are not int/bool.
            return expr
        float_operand = ops.AsTypeOp(dtypes.FLOAT_DTYPE).as_expr(operand)
        return expr.op.as_expr(float_operand)
343+
344+
322345
class LowerIsinOp(op_lowering.OpLoweringRule):
323346
@property
324347
def op(self) -> type[ops.ScalarOp]:
@@ -465,8 +488,21 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
465488
LowerInvertOp(),
466489
LowerIsinOp(),
467490
LowerLenOp(),
491+
LowerCeilOp(),
492+
LowerFloorOp(),
468493
)
469494

470495

471496
def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
    """Run ``op_lowering.lower_ops`` over ``root`` with ``POLARS_LOWERING_RULES``."""
    lowered_root = op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES)
    return lowered_root
498+
499+
500+
def _numeric_to_timedelta(expr: expression.Expression) -> expression.Expression:
    """Build the rounding-and-cast expression used to emulate timedelta ops.

    ``where_op`` is given ``floor(expr)``, the condition ``expr > 0`` and
    ``ceil(expr)``, in that order. Presumably this selects the floor for
    positive values and the ceil otherwise, i.e. rounds towards zero --
    TODO confirm against ``where_op``'s argument semantics. The rounded value
    is then cast to ``INT_DTYPE`` and finally to ``TIMEDELTA_DTYPE``.
    """
    is_positive = ops.gt_op.as_expr(expr, expression.const(0))
    floored = ops.floor_op.as_expr(expr)
    ceiled = ops.ceil_op.as_expr(expr)
    rounded = ops.where_op.as_expr(floored, is_positive, ceiled)
    as_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rounded)
    return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(as_int)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import bigframes_vendored.sqlglot.expressions as sge
18+
19+
20+
def round_towards_zero(expr: sge.Expression):
    """
    Round a float value to an integer, always rounding towards zero.

    Values greater than zero are floored and every other value is ceiled;
    the result is then cast to INT64. This is mostly used to handle
    duration/timedelta emulation.
    """
    is_positive = sge.GT(this=expr, expression=sge.convert(0))
    truncated = sge.If(
        this=is_positive,
        true=sge.Floor(this=expr),
        false=sge.Ceil(this=expr),
    )
    return sge.Cast(this=truncated, to="INT64")

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bigframes import dtypes
2121
from bigframes import operations as ops
2222
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
23+
from bigframes.core.compile.sqlglot.expressions.common import round_towards_zero
2324
import bigframes.core.compile.sqlglot.expressions.constants as constants
2425
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2526
from bigframes.operations import numeric_ops
@@ -467,7 +468,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
467468

468469
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
469470
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
470-
return sge.Cast(this=sge.Floor(this=result), to="INT64")
471+
return round_towards_zero(result)
471472
else:
472473
return result
473474

@@ -510,7 +511,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
510511
)
511512

512513
if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE:
513-
result = sge.Cast(this=sge.Floor(this=result), to="INT64")
514+
result = round_towards_zero(sge.func("IEEE_DIVIDE", left_expr, right_expr))
514515

515516
return result
516517

@@ -578,7 +579,7 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
578579
if (dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE) or (
579580
left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype)
580581
):
581-
return sge.Cast(this=sge.Floor(this=result), to="INT64")
582+
return round_towards_zero(result)
582583
else:
583584
return result
584585

bigframes/session/polars_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
numeric_ops.SubOp,
8080
numeric_ops.MulOp,
8181
numeric_ops.DivOp,
82+
numeric_ops.CeilOp,
83+
numeric_ops.FloorOp,
8284
numeric_ops.FloorDivOp,
8385
numeric_ops.ModOp,
8486
generic_ops.AsTypeOp,

tests/system/large/functions/test_remote_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ def double_num(x):
13541354
bf_result = bf_int64_col.to_frame().assign(result=bf_result_col).to_pandas()
13551355

13561356
pd_int64_col = scalars_pandas_df["int64_col"]
1357-
pd_result_col = pd_int64_col.apply(lambda x: x if x is None else x * x)
1357+
pd_result_col = pd_int64_col.apply(lambda x: x if x is None else x + x)
13581358
pd_result = pd_int64_col.to_frame().assign(result=pd_result_col)
13591359

13601360
assert_frame_equal(bf_result, pd_result, check_dtype=False)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,47 @@ def apply_op_pairwise(
5353
return new_arr
5454

5555

56+
def apply_op(
    array: array_value.ArrayValue, op: ops.UnaryOp, excluded_cols=()
) -> array_value.ArrayValue:
    """Apply a unary op to every column of ``array`` that accepts it.

    A column is skipped when its id is in ``excluded_cols`` or when
    ``op.output_type`` / ``op.as_expr`` raises ``TypeError`` for it. Each
    computed column is renamed to ``"<column id>_<op name>"``.

    Args:
        array: Array value whose columns are fed to the op.
        op: The unary op to apply.
        excluded_cols: Column ids to skip. The default is an immutable empty
            tuple so that no mutable object is shared between calls.

    Returns:
        The array value produced by ``array.compute_values`` with the
        computed columns renamed.

    Raises:
        AssertionError: If the op could not be applied to any column.
    """
    exprs = []
    labels = []
    for col_id in array.column_ids:
        if col_id in excluded_cols:
            continue
        try:
            # Called only for its TypeError side effect; the resulting type
            # itself is not needed here.
            op.output_type(array.get_column_type(col_id))
            expr = op.as_expr(col_id)
        except TypeError:
            continue
        # Append both lists together, outside the try, so they cannot get out
        # of step with each other.
        exprs.append(expr)
        labels.append(f"{col_id}_{op.name}")
    assert len(exprs) > 0
    new_arr, ids = array.compute_values(exprs)
    return new_arr.rename_columns(dict(zip(ids, labels)))
77+
78+
79+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_project_ceil(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Check ``ceil_op`` projections against the reference engine."""
    ceil_projection = apply_op(scalars_array_value, ops.ceil_op)
    assert_equivalence_execution(ceil_projection.node, REFERENCE_ENGINE, engine)
86+
87+
88+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_project_floor(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Check ``floor_op`` projections against the reference engine."""
    floor_projection = apply_op(scalars_array_value, ops.floor_op)
    assert_equivalence_execution(floor_projection.node, REFERENCE_ENGINE, engine)
95+
96+
5697
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5798
def test_engines_project_add(
5899
scalars_array_value: array_value.ArrayValue,

0 commit comments

Comments
 (0)