|
16 | 16 | import dataclasses |
17 | 17 | import functools |
18 | 18 | import itertools |
| 19 | +import json |
19 | 20 | from typing import cast, Literal, Optional, Sequence, Tuple, Type, TYPE_CHECKING |
20 | 21 |
|
21 | 22 | import pandas as pd |
@@ -429,13 +430,68 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
@compile_op.register(json_ops.JSONDecode)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile a ``JSONDecode`` op into a Polars expression.

    When ``op.safe`` is false the input is decoded with Polars' native
    ``str.json_decode``. When ``op.safe`` is true each element is decoded in
    Python through ``map_elements`` and any value that cannot be parsed, or
    cannot be represented as ``op.to_type``, becomes null instead of raising.

    Args:
        op: The ``json_ops.JSONDecode`` op; ``op.to_type`` and ``op.safe``
            are read.
        input: Expression producing the JSON text to decode.

    Returns:
        An expression yielding values of the Polars dtype that corresponds
        to ``op.to_type``.
    """
    assert isinstance(op, json_ops.JSONDecode)
    target_dtype = _bigframes_dtype_to_polars_dtype(op.to_type)
    if not op.safe:
        return input.str.json_decode(target_dtype)

    # Polars does not support safe JSON decoding (returning null on failure),
    # so the safe path decodes element by element in Python.

    # NOTE(review): assumes INT_DTYPE is a signed 64-bit integer -- confirm.
    # Values outside this range cannot be stored, so they decode to null.
    int_min = -(2**63)
    int_max = 2**63 - 1
    bool_words = {"true": True, "false": False}

    def as_int(decoded):
        # bool is a subclass of int; JSON true/false must not become 1/0.
        if isinstance(decoded, bool):
            return None
        if isinstance(decoded, float):
            # Only floats with no fractional part are accepted (inf and nan
            # report is_integer() == False, so int() cannot overflow here).
            if not decoded.is_integer():
                return None
            decoded = int(decoded)
        elif isinstance(decoded, str):
            try:
                decoded = int(decoded)
            except ValueError:
                return None
        elif not isinstance(decoded, int):
            return None
        return decoded if int_min <= decoded <= int_max else None

    def as_float(decoded):
        if isinstance(decoded, bool):
            return None
        if isinstance(decoded, (int, float, str)):
            try:
                return float(decoded)
            except (ValueError, OverflowError):
                # ValueError: non-numeric string. OverflowError: a JSON
                # integer too large to be represented as a float.
                return None
        return None

    def as_bool(decoded):
        if isinstance(decoded, bool):
            return decoded
        if isinstance(decoded, str):
            # Case-insensitive "true"/"false"; any other string is null.
            return bool_words.get(decoded.lower())
        return None

    def as_str(decoded):
        return decoded if isinstance(decoded, str) else None

    # Pick the coercion once, rather than comparing dtypes for every element.
    to_type = op.to_type
    if to_type == bigframes.dtypes.INT_DTYPE:
        coerce = as_int
    elif to_type == bigframes.dtypes.FLOAT_DTYPE:
        coerce = as_float
    elif to_type == bigframes.dtypes.BOOL_DTYPE:
        coerce = as_bool
    elif to_type == bigframes.dtypes.STRING_DTYPE:
        coerce = as_str
    else:
        # Any other target type: the parsed value is passed through as-is.
        coerce = None

    def safe_decode(val):
        if val is None:
            return None
        try:
            decoded = json.loads(val)
        except (ValueError, TypeError, RecursionError):
            # Malformed JSON, non-text input, or pathologically deep nesting.
            return None
        if decoded is None or coerce is None:
            return decoded
        return coerce(decoded)

    return input.map_elements(safe_decode, return_dtype=target_dtype)
439 | 495 |
|
440 | 496 | @compile_op.register(arr_ops.ToArrayOp) |
441 | 497 | def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr: |
|
0 commit comments