Skip to content

Commit 2f792ab

Browse files
feat(bigframes): Add numpy ufunc support to col expressions (googleapis#16554)
1 parent 5ccfd64 commit 2f792ab

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

packages/bigframes/bigframes/core/col.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import TYPE_CHECKING, Any, Hashable, Literal
1818

1919
import bigframes_vendored.pandas.core.col as pd_col
20+
import numpy
2021

2122
import bigframes.core.expression as bf_expression
2223
import bigframes.operations as bf_ops
@@ -56,14 +57,10 @@ def _apply_binary_op(
5657
alignment: Literal["outer", "left"] = "outer",
5758
reverse: bool = False,
5859
):
59-
if isinstance(other, Expression):
60-
other_value = other._value
61-
else:
62-
other_value = bf_expression.const(other)
6360
if reverse:
64-
return Expression(op.as_expr(other_value, self._value))
61+
return Expression(op.as_expr(_as_bf_expr(other), self._value))
6562
else:
66-
return Expression(op.as_expr(self._value, other_value))
63+
return Expression(op.as_expr(self._value, _as_bf_expr(other)))
6764

6865
def __add__(self, other: Any) -> Expression:
6966
return self._apply_binary_op(other, bf_ops.add_op)
@@ -164,13 +161,42 @@ def dt(self) -> datetimes.DatetimeSimpleMethods:
164161

165162
return datetimes.DatetimeSimpleMethods(self)
166163

164+
def __array_ufunc__(
165+
self, ufunc: numpy.ufunc, method: str, *inputs, **kwargs
166+
) -> Expression:
167+
"""Used to support numpy ufuncs.
168+
See: https://numpy.org/doc/stable/reference/ufuncs.html
169+
"""
170+
# Only __call__ supported with zero arguments
171+
if method != "__call__" or len(inputs) > 2 or len(kwargs) > 0:
172+
return NotImplemented
173+
174+
if len(inputs) == 1 and ufunc in bf_ops.NUMPY_TO_OP:
175+
op = bf_ops.NUMPY_TO_OP[ufunc]
176+
return Expression(op.as_expr(self._value))
177+
if len(inputs) == 2 and ufunc in bf_ops.NUMPY_TO_BINOP:
178+
binop = bf_ops.NUMPY_TO_BINOP[ufunc]
179+
if inputs[0] is self:
180+
return Expression(binop.as_expr(self._value, _as_bf_expr(inputs[1])))
181+
else:
182+
return Expression(binop.as_expr(_as_bf_expr(inputs[0]), self._value))
183+
184+
return NotImplemented
185+
186+
# keep this last as str declaration can shadow builtins.str
167187
@property
168188
def str(self) -> strings.StringMethods:
169189
import bigframes.operations.strings as strings
170190

171191
return strings.StringMethods(self)
172192

173193

194+
def _as_bf_expr(arg: Any) -> bf_expression.Expression:
195+
if isinstance(arg, Expression):
196+
return arg._value
197+
return bf_expression.const(arg)
198+
199+
174200
def col(col_name: Hashable) -> Expression:
175201
return Expression(bf_expression.free_var(col_name))
176202

packages/bigframes/tests/unit/test_col.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import bigframes
2323
import bigframes.pandas as bpd
2424
from bigframes.testing.utils import assert_frame_equal, convert_pandas_dtypes
25+
import numpy as np
2526

2627
pytest.importorskip("polars")
2728
pytest.importorskip("pandas", minversion="3.0.0")
@@ -246,3 +247,23 @@ def test_col_dt_accessor(scalars_dfs):
246247

247248
# int64[pyarrow] vs Int64
248249
assert_frame_equal(bf_result, pd_result, check_dtype=False)
250+
251+
252+
def test_col_numpy_ufunc(scalars_dfs):
253+
scalars_df, scalars_pandas_df = scalars_dfs
254+
255+
bf_result = scalars_df.assign(
256+
sqrt=np.sqrt(bpd.col("float64_col")), # type: ignore
257+
add_const=np.add(bpd.col("float64_col"), 2.4), # type: ignore
258+
radd_const=np.add(2.4, bpd.col("float64_col")), # type: ignore
259+
add_cols=np.add(bpd.col("float64_col"), bpd.col("int64_col")), # type: ignore
260+
).to_pandas()
261+
pd_result = scalars_pandas_df.assign(
262+
sqrt=np.sqrt(pd.col("float64_col")), # type: ignore
263+
add_const=np.add(pd.col("float64_col"), 2.4), # type: ignore
264+
radd_const=np.add(2.4, pd.col("float64_col")), # type: ignore
265+
add_cols=np.add(pd.col("float64_col"), pd.col("int64_col")), # type: ignore
266+
)
267+
268+
# int64[pyarrow] vs Int64
269+
assert_frame_equal(bf_result, pd_result, check_dtype=False)

0 commit comments

Comments
 (0)