Skip to content

Commit 690c15d

Browse files
refactor(bigframes): Extract json conversions to distinct ops (#17473)
1 parent 1890637 commit 690c15d

18 files changed

Lines changed: 204 additions & 174 deletions

File tree

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -922,35 +922,6 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
922922
elif to_type == ibis_dtypes.time:
923923
return x_converted.time()
924924

925-
if to_type == ibis_dtypes.json:
926-
if x.type() == ibis_dtypes.string:
927-
return parse_json_in_safe(x) if op.safe else parse_json(x)
928-
if x.type() == ibis_dtypes.bool:
929-
x_bool = typing.cast(
930-
ibis_types.StringValue,
931-
bigframes.core.compile.ibis_types.cast_ibis_value(
932-
x, ibis_dtypes.string, safe=op.safe
933-
),
934-
).lower()
935-
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
936-
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
937-
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
938-
x, ibis_dtypes.string, safe=op.safe
939-
)
940-
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)
941-
942-
if x.type() == ibis_dtypes.json:
943-
if to_type == ibis_dtypes.int64:
944-
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
945-
if to_type == ibis_dtypes.float64:
946-
return (
947-
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
948-
)
949-
if to_type == ibis_dtypes.bool:
950-
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
951-
if to_type == ibis_dtypes.string:
952-
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
953-
954925
# TODO: either inline this function, or push rest of this op into the function
955926
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)
956927

@@ -1193,9 +1164,27 @@ def parse_json_op_impl(x: ibis_types.Value, op: ops.ParseJSON):
11931164
return parse_json(json_str=x)
11941165

11951166

1196-
@scalar_op_compiler.register_unary_op(ops.ToJSON)
1197-
def to_json_op_impl(json_obj: ibis_types.Value):
1198-
return to_json(json_obj=json_obj)
1167+
@scalar_op_compiler.register_unary_op(ops.ToJSON, pass_op=True)
1168+
def to_json_op_impl(x: ibis_types.Value, op: ops.ToJSON):
1169+
if x.type() == ibis_dtypes.string:
1170+
return parse_json_in_safe(x) if op.safe else parse_json(x)
1171+
return x.isnull().ifelse(ibis.null().cast(ibis_dtypes.json), to_json(x))
1172+
1173+
1174+
@scalar_op_compiler.register_unary_op(ops.JSONDecode, pass_op=True)
1175+
def json_decode_op_impl(x: ibis_types.Value, op: ops.JSONDecode):
1176+
to_type = bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(
1177+
op.to_type
1178+
)
1179+
if to_type == ibis_dtypes.int64:
1180+
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
1181+
if to_type == ibis_dtypes.float64:
1182+
return cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
1183+
if to_type == ibis_dtypes.bool:
1184+
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
1185+
if to_type == ibis_dtypes.string:
1186+
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
1187+
raise TypeError(f"Cannot cast from JSON to type {to_type}")
11991188

12001189

12011190
@scalar_op_compiler.register_unary_op(ops.ToJSONString)

packages/bigframes/bigframes/core/compile/polars/compiler.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,20 @@ class PolarsExpressionCompiler:
138138
Should be extended to dispatch based on bigframes schema types.
139139
"""
140140

141-
@functools.singledispatchmethod
141+
_expr_types: dict[int, bigframes.dtypes.ExpressionType] = dataclasses.field(
142+
default_factory=dict, init=False, compare=False
143+
)
144+
142145
def compile_expression(self, expression: ex.Expression) -> pl.Expr:
146+
res = self._compile_expression(expression)
147+
self._expr_types[id(res)] = expression.output_type
148+
return res
149+
150+
@functools.singledispatchmethod
151+
def _compile_expression(self, expression: ex.Expression) -> pl.Expr:
143152
raise NotImplementedError(f"Cannot compile expression: {expression}")
144153

145-
@compile_expression.register
154+
@_compile_expression.register
146155
def _(
147156
self,
148157
expression: ex.ScalarConstantExpression,
@@ -159,21 +168,21 @@ def _(
159168

160169
return pl.lit(value, _bigframes_dtype_to_polars_dtype(expression.dtype))
161170

162-
@compile_expression.register
171+
@_compile_expression.register
163172
def _(
164173
self,
165174
expression: ex.DerefOp,
166175
) -> pl.Expr:
167176
return pl.col(expression.id.sql)
168177

169-
@compile_expression.register
178+
@_compile_expression.register
170179
def _(
171180
self,
172181
expression: ex.ResolvedDerefOp,
173182
) -> pl.Expr:
174183
return pl.col(expression.id.sql)
175184

176-
@compile_expression.register
185+
@_compile_expression.register
177186
def _(
178187
self,
179188
expression: ex.OpExpression,
@@ -478,10 +487,21 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
478487
)
479488

480489
@compile_op.register(json_ops.JSONDecode)
481-
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
490+
def _(self, op: json_ops.JSONDecode, input: pl.Expr) -> pl.Expr:
482491
assert isinstance(op, json_ops.JSONDecode)
483492
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])
484493

494+
@compile_op.register(json_ops.ToJSON)
495+
def _(self, op: json_ops.ToJSON, input: pl.Expr) -> pl.Expr:
496+
from_type = self._expr_types.get(id(input))
497+
if from_type in (
498+
bigframes.dtypes.STRING_DTYPE,
499+
bigframes.dtypes.JSON_DTYPE,
500+
):
501+
return input
502+
else:
503+
return input.cast(pl.String())
504+
485505
@compile_op.register(arr_ops.ToArrayOp)
486506
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
487507
return pl.concat_list(*inputs)

packages/bigframes/bigframes/core/compile/polars/lowering.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
comparison_ops,
2727
datetime_ops,
2828
generic_ops,
29-
json_ops,
3029
numeric_ops,
3130
string_ops,
3231
)
@@ -412,9 +411,6 @@ def _coerce_comparables(
412411
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
413412
if arg.output_type == cast_op.to_type:
414413
return arg
415-
416-
if arg.output_type == dtypes.JSON_DTYPE:
417-
return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
418414
if (
419415
arg.output_type == dtypes.STRING_DTYPE
420416
and cast_op.to_type == dtypes.DATETIME_DTYPE

packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3636
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
3737
sg_expr = expr.expr
3838

39-
if to_type == dtypes.JSON_DTYPE:
40-
return _cast_to_json(expr, op)
41-
42-
if from_type == dtypes.JSON_DTYPE:
43-
return _cast_from_json(expr, op)
44-
4539
if to_type == dtypes.INT_DTYPE:
4640
result = _cast_to_int(expr, op)
4741
if result is not None:
@@ -251,35 +245,6 @@ def _(*values: TypedExpr) -> sge.Expression:
251245

252246

253247
# Helper functions
254-
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
255-
from_type = expr.dtype
256-
sg_expr = expr.expr
257-
258-
if from_type == dtypes.STRING_DTYPE:
259-
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
260-
return sge.func(func_name, sg_expr)
261-
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
262-
sg_expr = sge.Cast(this=sg_expr, to="STRING")
263-
return sge.func("PARSE_JSON", sg_expr)
264-
raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")
265-
266-
267-
def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
268-
to_type = op.to_type
269-
sg_expr = expr.expr
270-
func_name = ""
271-
if to_type == dtypes.INT_DTYPE:
272-
func_name = "INT64"
273-
elif to_type == dtypes.FLOAT_DTYPE:
274-
func_name = "FLOAT64"
275-
elif to_type == dtypes.BOOL_DTYPE:
276-
func_name = "BOOL"
277-
elif to_type == dtypes.STRING_DTYPE:
278-
func_name = "STRING"
279-
if func_name:
280-
func_name = "SAFE." + func_name if op.safe else func_name
281-
return sge.func(func_name, sg_expr)
282-
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")
283248

284249

285250
def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:

packages/bigframes/bigframes/core/compile/sqlglot/expressions/json_ops.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import bigframes_vendored.sqlglot.expressions as sge
1818

1919
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
20+
from bigframes import dtypes
2021
from bigframes import operations as ops
2122
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2223

@@ -69,9 +70,39 @@ def _(expr: TypedExpr) -> sge.Expression:
6970
return sge.func("PARSE_JSON", expr.expr)
7071

7172

72-
@register_unary_op(ops.ToJSON)
73-
def _(expr: TypedExpr) -> sge.Expression:
74-
return sge.func("TO_JSON", expr.expr)
73+
@register_unary_op(ops.ToJSON, pass_op=True)
74+
def _(expr: TypedExpr, op: ops.ToJSON) -> sge.Expression:
75+
from_type = expr.dtype
76+
sg_expr = expr.expr
77+
78+
# Parsing really should be a distinct operation from serialization, but
79+
# this was the way things were initially launched.
80+
if from_type == dtypes.STRING_DTYPE:
81+
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
82+
return sge.func(func_name, sg_expr)
83+
else:
84+
return sge.func(
85+
"IF", sg_expr.is_(sge.Null()), sge.Null(), sge.func("TO_JSON", sg_expr)
86+
)
87+
88+
89+
@register_unary_op(ops.JSONDecode, pass_op=True)
90+
def _(expr: TypedExpr, op: ops.JSONDecode) -> sge.Expression:
91+
to_type = op.to_type
92+
sg_expr = expr.expr
93+
func_name = ""
94+
if to_type == dtypes.INT_DTYPE:
95+
func_name = "INT64"
96+
elif to_type == dtypes.FLOAT_DTYPE:
97+
func_name = "FLOAT64"
98+
elif to_type == dtypes.BOOL_DTYPE:
99+
func_name = "BOOL"
100+
elif to_type == dtypes.STRING_DTYPE:
101+
func_name = "STRING"
102+
if func_name:
103+
func_name = "SAFE." + func_name if op.safe else func_name
104+
return sge.func(func_name, sg_expr)
105+
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")
75106

76107

77108
@register_unary_op(ops.ToJSONString)

packages/bigframes/bigframes/dataframe.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -442,17 +442,41 @@ def astype(
442442
if errors not in ["raise", "null"]:
443443
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")
444444

445+
if isinstance(dtype, dict):
446+
for col in dtype:
447+
if col not in self.columns:
448+
raise KeyError(
449+
f"Only Column Names are allowed in dtypes dict. '{col}' is not in the columns."
450+
)
451+
445452
safe_cast = errors == "null"
446453

447-
if isinstance(dtype, dict):
448-
result = self.copy()
449-
for col, to_type in dtype.items():
450-
result[col] = result[col].astype(to_type)
451-
return result
454+
exprs: list[ex.Expression] = []
455+
for col_id, col_label in zip(
456+
self._block.value_columns, self._block.column_labels
457+
):
458+
from_type = self._block._column_type(col_id)
452459

453-
dtype = bigframes.dtypes.bigframes_type(dtype)
460+
if isinstance(dtype, dict):
461+
if col_label not in dtype:
462+
exprs.append(ex.deref(col_id))
463+
continue
464+
to_type = bigframes.dtypes.bigframes_type(dtype[col_label])
465+
else:
466+
to_type = bigframes.dtypes.bigframes_type(dtype)
467+
468+
op: ops.UnaryOp
469+
if to_type == bigframes.dtypes.JSON_DTYPE:
470+
op = ops.ToJSON(safe=safe_cast)
471+
elif from_type == bigframes.dtypes.JSON_DTYPE:
472+
op = ops.JSONDecode(to_type=to_type, safe=safe_cast)
473+
else:
474+
op = ops.AsTypeOp(to_type=to_type, safe=safe_cast)
454475

455-
return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
476+
exprs.append(op.as_expr(ex.deref(col_id)))
477+
478+
block = self._block.project_exprs(exprs, labels=self.columns, drop=True)
479+
return DataFrame(block)
456480

457481
def _should_sql_have_index(self) -> bool:
458482
"""Should the SQL we pass to BQML and other I/O include the index?"""

packages/bigframes/bigframes/dtypes.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,30 @@ def is_json_like(type_: ExpressionType) -> bool:
364364
return type_ == JSON_DTYPE or type_ == STRING_DTYPE # Including JSON string
365365

366366

367-
def is_json_encoding_type(type_: ExpressionType) -> bool:
367+
def is_json_encoding_type(type_: ExpressionType, strict: bool = False) -> bool:
368368
# Types that can be converted into JSON.
369369
# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_encodings
370-
return type_ != GEO_DTYPE
370+
if is_array_like(type_):
371+
return is_json_encoding_type(get_array_inner_type(type_), strict=strict)
372+
if is_struct_like(type_):
373+
return all(
374+
is_json_encoding_type(field_type, strict=strict)
375+
for field_type in get_struct_fields(type_).values()
376+
)
377+
378+
if strict:
379+
# Strict mode allows only the types (mostly) defined by the JSON spec, with no/minimal
380+
# encoding/decoding involved. So no temporal types.
381+
return type_ in (
382+
INT_DTYPE,
383+
FLOAT_DTYPE,
384+
BOOL_DTYPE,
385+
STRING_DTYPE,
386+
JSON_DTYPE,
387+
)
388+
else:
389+
# GoogleSQL implementation handles anything but GEO
390+
return type_ != GEO_DTYPE
371391

372392

373393
def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool:

packages/bigframes/bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
)
129129
from bigframes.operations.googlesql import GoogleSqlScalarOp
130130
from bigframes.operations.json_ops import (
131+
JSONDecode,
131132
JSONExtract,
132133
JSONExtractArray,
133134
JSONExtractStringArray,
@@ -382,6 +383,7 @@
382383
"FloorDtOp",
383384
"IntegerLabelToDatetimeOp",
384385
# JSON ops
386+
"JSONDecode",
385387
"JSONExtract",
386388
"JSONExtractArray",
387389
"JSONExtractStringArray",

0 commit comments

Comments
 (0)