Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 648518f

Browse files
fix time roundings issues
1 parent efe9ac9 commit 648518f

File tree

7 files changed

+62
-24
lines changed

7 files changed

+62
-24
lines changed

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,12 @@ def to_timedelta_op_impl(x: ibis_types.Value, op: ops.ToTimedeltaOp):
10261026

10271027
@scalar_op_compiler.register_unary_op(ops.timedelta_floor_op)
10281028
def timedelta_floor_op_impl(x: ibis_types.NumericValue):
1029-
return x.floor()
1029+
return ibis_api.case().when(x > 0, x.floor()).else_(x.ceil()).end()
1030+
1031+
1032+
@scalar_op_compiler.register_unary_op(ops.timedelta_round_op)
1033+
def timedelta_round_op_impl(x: ibis_types.NumericValue):
1034+
return ibis_api.case().when(x > 0, x.floor()).else_(x.ceil()).end()
10301035

10311036

10321037
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)

bigframes/core/compile/sqlglot/expressions/timedelta_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ def _(expr: TypedExpr) -> sge.Expression:
3030
return sge.Floor(this=expr.expr)
3131

3232

33+
@register_unary_op(ops.timedelta_round_op)
34+
def _(expr: TypedExpr) -> sge.Expression:
35+
return sge.Round(this=expr.expr).cast()
36+
37+
3338
@register_unary_op(ops.ToTimedeltaOp, pass_op=True)
3439
def _(expr: TypedExpr, op: ops.ToTimedeltaOp) -> sge.Expression:
3540
value = expr.expr

bigframes/core/rewrite/timedeltas.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
189189
result = _TypedExpr.create_op_expr(ops.mul_op, left, right)
190190

191191
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
192-
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
192+
return _TypedExpr.create_op_expr(ops.timedelta_round_op, result)
193193
if dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE:
194-
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
194+
return _TypedExpr.create_op_expr(ops.timedelta_round_op, result)
195195

196196
return result
197197

@@ -200,18 +200,18 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
200200
result = _TypedExpr.create_op_expr(ops.div_op, left, right)
201201

202202
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
203-
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
203+
return _TypedExpr.create_op_expr(ops.timedelta_round_op, result)
204204

205205
return result
206206

207207

208208
def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
209-
result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right)
210-
211209
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
212-
return _TypedExpr.create_op_expr(ops.timedelta_floor_op, result)
210+
return _TypedExpr.create_op_expr(
211+
ops.timedelta_round_op, _TypedExpr.create_op_expr(ops.div_op, left, right)
212+
)
213213

214-
return result
214+
return _TypedExpr.create_op_expr(ops.floordiv_op, left, right)
215215

216216

217217
def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr):

bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@
224224
date_add_op,
225225
date_sub_op,
226226
timedelta_floor_op,
227+
timedelta_round_op,
227228
timestamp_add_op,
228229
timestamp_sub_op,
229230
ToTimedeltaOp,
@@ -308,6 +309,7 @@
308309
"date_add_op",
309310
"date_sub_op",
310311
"timedelta_floor_op",
312+
"timedelta_round_op",
311313
"timestamp_add_op",
312314
"timestamp_sub_op",
313315
"ToTimedeltaOp",

bigframes/operations/timedelta_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
5454
timedelta_floor_op = TimedeltaFloorOp()
5555

5656

57+
@dataclasses.dataclass(frozen=True)
58+
class TimedeltaRoundOp(base_ops.UnaryOp):
59+
"""Rounds the numeric value to the nearest integer and use it to represent a timedelta.
60+
61+
This operator is only meant to be used during expression tree rewrites. Do not use it anywhere else!
62+
"""
63+
64+
name: typing.ClassVar[str] = "timedelta_round"
65+
66+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
67+
input_type = input_types[0]
68+
if dtypes.is_numeric(input_type) or input_type == dtypes.TIMEDELTA_DTYPE:
69+
return dtypes.TIMEDELTA_DTYPE
70+
raise TypeError(f"unsupported type: {input_type}")
71+
72+
73+
timedelta_round_op = TimedeltaRoundOp()
74+
75+
5776
@dataclasses.dataclass(frozen=True)
5877
class TimestampAddOp(base_ops.BinaryOp):
5978
name: typing.ClassVar[str] = "timestamp_add"

tests/system/small/operations/test_timedeltas.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,9 @@ def temporal_dfs(session):
8787

8888
def _assert_series_equal(actual: pd.Series, expected: pd.Series):
8989
"""Helper function specifically for timedelta testing. Don't use it outside of this module."""
90-
# expected[expected.select_dtypes('timedelta64').columns] = expected.select_dtypes('timedelta64').astype(dtypes.TIMEDELTA_DTYPE)
9190
bigframes.testing.assert_series_equal(
9291
actual,
93-
expected, # .convert_dtypes(dtype_backend="pyarrow"),
92+
expected,
9493
check_index_type=False,
9594
check_dtype=False,
9695
)
@@ -158,25 +157,35 @@ def test_timedelta_binary_ops_series_and_literal(
158157

159158

160159
@pytest.mark.parametrize(
161-
("op", "col", "literal"),
160+
("op", "col", "literal", "arrow_supported"),
162161
[
163-
(operator.add, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us")),
164-
(operator.sub, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us")),
165-
(operator.truediv, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us")),
166-
(operator.floordiv, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us")),
167-
(operator.truediv, "float_col", pd.Timedelta(2, "s").as_unit("us")),
168-
(operator.floordiv, "float_col", pd.Timedelta(2, "s").as_unit("us")),
169-
(operator.mul, "timedelta_col_1", 3),
170-
(operator.mul, "float_col", pd.Timedelta(1, "s").as_unit("us")),
171-
(operator.mod, "timedelta_col_1", pd.Timedelta(7, "s").as_unit("us")),
162+
(operator.add, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us"), True),
163+
(operator.sub, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us"), True),
164+
(operator.truediv, "timedelta_col_1", pd.Timedelta(2, "s").as_unit("us"), True),
165+
(
166+
operator.floordiv,
167+
"timedelta_col_1",
168+
pd.Timedelta(2, "s").as_unit("us"),
169+
True,
170+
),
171+
(operator.truediv, "float_col", pd.Timedelta(2, "s").as_unit("us"), True),
172+
(operator.floordiv, "float_col", pd.Timedelta(2, "s").as_unit("us"), True),
173+
(operator.mul, "timedelta_col_1", 3, True),
174+
(operator.mul, "float_col", pd.Timedelta(1, "s").as_unit("us"), False),
175+
(operator.mod, "timedelta_col_1", pd.Timedelta(7, "s").as_unit("us"), False),
172176
],
173177
)
174-
def test_timedelta_binary_ops_literal_and_series(temporal_dfs, op, col, literal):
178+
def test_timedelta_binary_ops_literal_and_series(
179+
temporal_dfs, op, col, literal, arrow_supported
180+
):
175181
bf_df, pd_df = temporal_dfs
176182

177183
actual_result = op(literal, bf_df[col]).to_pandas()
178184

179-
expected_result = op(literal, pd_df[col])
185+
if not arrow_supported:
186+
expected_result = pd_df[col].map(lambda x: op(literal, x))
187+
else:
188+
expected_result = op(literal, pd_df[col])
180189
_assert_series_equal(actual_result, expected_result)
181190

182191

third_party/bigframes_vendored/ibis/expr/operations/numeric.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,6 @@ class Round(Value):
158158
def dtype(self):
159159
if self.arg.dtype.is_decimal():
160160
return self.arg.dtype
161-
elif self.digits is None:
162-
return dt.int64
163161
else:
164162
return dt.double
165163

0 commit comments

Comments
 (0)