Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 4740708

Browse files
committed
revert: Roll back to state at 4134750
1 parent 196c111 commit 4740708

File tree

12 files changed

+32
-104
lines changed

12 files changed

+32
-104
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import dataclasses
1717
import functools
1818
import itertools
19-
import json
2019
from typing import cast, Literal, Optional, Sequence, Tuple, Type, TYPE_CHECKING
2120

2221
import pandas as pd
@@ -430,68 +429,7 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
430429
@compile_op.register(json_ops.JSONDecode)
431430
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
432431
assert isinstance(op, json_ops.JSONDecode)
433-
target_dtype = _bigframes_dtype_to_polars_dtype(op.to_type)
434-
if op.safe:
435-
# Polars does not support safe JSON decoding (returning null on failure).
436-
# We use map_elements to provide safe JSON decoding.
437-
def safe_decode(val):
438-
if val is None:
439-
return None
440-
try:
441-
decoded = json.loads(val)
442-
except Exception:
443-
return None
444-
445-
if decoded is None:
446-
return None
447-
448-
if op.to_type == bigframes.dtypes.INT_DTYPE:
449-
if type(decoded) is bool:
450-
return None
451-
if isinstance(decoded, int):
452-
return decoded
453-
if isinstance(decoded, float):
454-
if decoded.is_integer():
455-
return int(decoded)
456-
if isinstance(decoded, str):
457-
try:
458-
return int(decoded)
459-
except Exception:
460-
pass
461-
return None
462-
463-
if op.to_type == bigframes.dtypes.FLOAT_DTYPE:
464-
if type(decoded) is bool:
465-
return None
466-
if isinstance(decoded, (int, float)):
467-
return float(decoded)
468-
if isinstance(decoded, str):
469-
try:
470-
return float(decoded)
471-
except Exception:
472-
pass
473-
return None
474-
475-
if op.to_type == bigframes.dtypes.BOOL_DTYPE:
476-
if isinstance(decoded, bool):
477-
return decoded
478-
if isinstance(decoded, str):
479-
if decoded.lower() == "true":
480-
return True
481-
if decoded.lower() == "false":
482-
return False
483-
return None
484-
485-
if op.to_type == bigframes.dtypes.STRING_DTYPE:
486-
if isinstance(decoded, str):
487-
return decoded
488-
return None
489-
490-
return decoded
491-
492-
return input.map_elements(safe_decode, return_dtype=target_dtype)
493-
494-
return input.str.json_decode(target_dtype)
432+
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])
495433

496434
@compile_op.register(arr_ops.ToArrayOp)
497435
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:

bigframes/core/compile/polars/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
391391
return arg
392392

393393
if arg.output_type == dtypes.JSON_DTYPE:
394-
return json_ops.JSONDecode(cast_op.to_type, safe=cast_op.safe).as_expr(arg)
394+
return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
395395
if (
396396
arg.output_type == dtypes.STRING_DTYPE
397397
and cast_op.to_type == dtypes.DATETIME_DTYPE

bigframes/ml/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
"""
2525

2626
import abc
27-
import typing
28-
from typing import Optional, TypeVar, Union
27+
from typing import cast, Optional, TypeVar, Union
2928
import warnings
3029

3130
import bigframes_vendored.sklearn.base
@@ -134,7 +133,7 @@ def register(self: _T, vertex_ai_model_id: Optional[str] = None) -> _T:
134133
self._bqml_model = self._create_bqml_model() # type: ignore
135134
except AttributeError:
136135
raise RuntimeError("A model must be trained before register.")
137-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
136+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
138137

139138
self._bqml_model.register(vertex_ai_model_id)
140139
return self
@@ -287,7 +286,7 @@ def _predict_and_retry(
287286
bpd.concat([df_result, df_succ]) if df_result is not None else df_succ
288287
)
289288

290-
df_result = typing.cast(
289+
df_result = cast(
291290
bpd.DataFrame,
292291
bpd.concat([df_result, df_fail]) if df_result is not None else df_fail,
293292
)
@@ -307,7 +306,7 @@ def _extract_output_names(self):
307306

308307
output_names = []
309308
for transform_col in self._bqml_model._model._properties["transformColumns"]:
310-
transform_col_dict = typing.cast(dict, transform_col)
309+
transform_col_dict = cast(dict, transform_col)
311310
# pass the columns that are not transformed
312311
if "transformSql" not in transform_col_dict:
313312
continue

bigframes/ml/compose.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import types
2323
import typing
24-
from typing import Iterable, List, Optional, Set, Tuple, Union
24+
from typing import cast, Iterable, List, Optional, Set, Tuple, Union
2525

2626
from bigframes_vendored import constants
2727
import bigframes_vendored.sklearn.compose._column_transformer
@@ -218,7 +218,7 @@ def camel_to_snake(name):
218218

219219
output_names = []
220220
for transform_col in bq_model._properties["transformColumns"]:
221-
transform_col_dict = typing.cast(dict, transform_col)
221+
transform_col_dict = cast(dict, transform_col)
222222
# pass the columns that are not transformed
223223
if "transformSql" not in transform_col_dict:
224224
continue
@@ -282,7 +282,7 @@ def _merge(
282282
return self # SQLScalarColumnTransformer only work inside ColumnTransformer
283283
feature_columns_sorted = sorted(
284284
[
285-
typing.cast(str, feature_column.name)
285+
cast(str, feature_column.name)
286286
for feature_column in bq_model.feature_columns
287287
]
288288
)

bigframes/ml/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
import dataclasses
2020
import datetime
21-
import typing
22-
from typing import Callable, Iterable, Mapping, Optional, Union
21+
from typing import Callable, cast, Iterable, Mapping, Optional, Union
2322
import uuid
2423

2524
from google.cloud import bigquery
@@ -377,7 +376,7 @@ def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
377376
def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
378377
if vertex_ai_model_id is None:
379378
# vertex id needs to start with letters. https://cloud.google.com/vertex-ai/docs/general/resource-naming
380-
vertex_ai_model_id = "bigframes_" + typing.cast(str, self._model.model_id)
379+
vertex_ai_model_id = "bigframes_" + cast(str, self._model.model_id)
381380

382381
# truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
383382
# The possibility of conflicts should be low.

bigframes/ml/imported.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
import typing
20-
from typing import Mapping, Optional
19+
from typing import cast, Mapping, Optional
2120

2221
from google.cloud import bigquery
2322

@@ -79,7 +78,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
7978
if self.model_path is None:
8079
raise ValueError("Model GCS path must be provided.")
8180
self._bqml_model = self._create_bqml_model()
82-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
81+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
8382

8483
(X,) = utils.batch_convert_to_dataframe(X)
8584

@@ -100,7 +99,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> TensorFlowModel:
10099
if self.model_path is None:
101100
raise ValueError("Model GCS path must be provided.")
102101
self._bqml_model = self._create_bqml_model()
103-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
102+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
104103

105104
new_model = self._bqml_model.copy(model_name, replace)
106105
return new_model.session.read_gbq_model(model_name)
@@ -158,7 +157,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
158157
if self.model_path is None:
159158
raise ValueError("Model GCS path must be provided.")
160159
self._bqml_model = self._create_bqml_model()
161-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
160+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
162161

163162
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
164163

@@ -179,7 +178,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> ONNXModel:
179178
if self.model_path is None:
180179
raise ValueError("Model GCS path must be provided.")
181180
self._bqml_model = self._create_bqml_model()
182-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
181+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
183182

184183
new_model = self._bqml_model.copy(model_name, replace)
185184
return new_model.session.read_gbq_model(model_name)
@@ -277,7 +276,7 @@ def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
277276
if self.model_path is None:
278277
raise ValueError("Model GCS path must be provided.")
279278
self._bqml_model = self._create_bqml_model()
280-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
279+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
281280

282281
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
283282

@@ -298,7 +297,7 @@ def to_gbq(self, model_name: str, replace: bool = False) -> XGBoostModel:
298297
if self.model_path is None:
299298
raise ValueError("Model GCS path must be provided.")
300299
self._bqml_model = self._create_bqml_model()
301-
self._bqml_model = typing.cast(core.BqmlModel, self._bqml_model)
300+
self._bqml_model = cast(core.BqmlModel, self._bqml_model)
302301

303302
new_model = self._bqml_model.copy(model_name, replace)
304303
return new_model.session.read_gbq_model(model_name)

bigframes/ml/llm.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
import typing
20-
from typing import Iterable, Literal, Mapping, Optional, Union
19+
from typing import cast, Iterable, Literal, Mapping, Optional, Union
2120
import warnings
2221

2322
import bigframes_vendored.constants as constants
@@ -253,7 +252,7 @@ def predict(
253252

254253
if len(X.columns) == 1:
255254
# BQML identified the column by name
256-
col_label = typing.cast(blocks.Label, X.columns[0])
255+
col_label = cast(blocks.Label, X.columns[0])
257256
X = X.rename(columns={col_label: "content"})
258257

259258
options: dict = {}
@@ -392,7 +391,7 @@ def predict(
392391

393392
if len(X.columns) == 1:
394393
# BQML identified the column by name
395-
col_label = typing.cast(blocks.Label, X.columns[0])
394+
col_label = cast(blocks.Label, X.columns[0])
396395
X = X.rename(columns={col_label: "content"})
397396

398397
# TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
@@ -605,10 +604,7 @@ def fit(
605604
options["prompt_col"] = X.columns.tolist()[0]
606605

607606
self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
608-
X,
609-
y,
610-
options=options,
611-
connection_name=typing.cast(str, self.connection_name),
607+
X, y, options=options, connection_name=cast(str, self.connection_name)
612608
)
613609
return self
614610

@@ -739,7 +735,7 @@ def predict(
739735

740736
if len(X.columns) == 1:
741737
# BQML identified the column by name
742-
col_label = typing.cast(blocks.Label, X.columns[0])
738+
col_label = cast(blocks.Label, X.columns[0])
743739
X = X.rename(columns={col_label: "prompt"})
744740

745741
options: dict = {
@@ -824,8 +820,8 @@ def score(
824820
)
825821

826822
# BQML identified the column by name
827-
X_col_label = typing.cast(blocks.Label, X.columns[0])
828-
y_col_label = typing.cast(blocks.Label, y.columns[0])
823+
X_col_label = cast(blocks.Label, X.columns[0])
824+
y_col_label = cast(blocks.Label, y.columns[0])
829825
X = X.rename(columns={X_col_label: "input_text"})
830826
y = y.rename(columns={y_col_label: "output_text"})
831827

@@ -1037,7 +1033,7 @@ def predict(
10371033

10381034
if len(X.columns) == 1:
10391035
# BQML identified the column by name
1040-
col_label = typing.cast(blocks.Label, X.columns[0])
1036+
col_label = cast(blocks.Label, X.columns[0])
10411037
X = X.rename(columns={col_label: "prompt"})
10421038

10431039
options = {

bigframes/ml/model_selection.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import inspect
2121
from itertools import chain
2222
import time
23-
import typing
24-
from typing import Generator, List, Optional, Union
23+
from typing import cast, Generator, List, Optional, Union
2524

2625
import bigframes_vendored.sklearn.model_selection._split as vendored_model_selection_split
2726
import bigframes_vendored.sklearn.model_selection._validation as vendored_model_selection_validation
@@ -100,10 +99,10 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra
10099
train_dfs.append(train)
101100
test_dfs.append(test)
102101

103-
train_df = typing.cast(
102+
train_df = cast(
104103
bpd.DataFrame, bpd.concat(train_dfs).drop(columns="bigframes_stratify_col")
105104
)
106-
test_df = typing.cast(
105+
test_df = cast(
107106
bpd.DataFrame, bpd.concat(test_dfs).drop(columns="bigframes_stratify_col")
108107
)
109108
return [train_df, test_df]

bigframes/ml/preprocessing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import typing
21-
from typing import Iterable, List, Literal, Optional, Union
21+
from typing import cast, Iterable, List, Literal, Optional, Union
2222

2323
import bigframes_vendored.sklearn.preprocessing._data
2424
import bigframes_vendored.sklearn.preprocessing._discretization
@@ -470,7 +470,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]:
470470
s = sql[sql.find("(") + 1 : sql.find(")")]
471471
col_label, drop_str, top_k, frequency_threshold = s.split(", ")
472472
drop = (
473-
typing.cast(Literal["most_frequent"], "most_frequent")
473+
cast(Literal["most_frequent"], "most_frequent")
474474
if drop_str.lower() == "'most_frequent'"
475475
else None
476476
)

bigframes/operations/json_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def output_type(self, *input_types):
220220
class JSONDecode(base_ops.UnaryOp):
221221
name: typing.ClassVar[str] = "json_decode"
222222
to_type: dtypes.Dtype
223-
safe: bool = False
224223

225224
def output_type(self, *input_types):
226225
input_type = input_types[0]

0 commit comments

Comments
 (0)