diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py
index 304428ef2f..6601fe5ae5 100644
--- a/bigframes/dtypes.py
+++ b/bigframes/dtypes.py
@@ -772,6 +772,13 @@ def convert_schema_field(
 ) -> typing.Tuple[str, Dtype]:
     is_repeated = field.mode == "REPEATED"
     if field.field_type == "RECORD":
+        if field.description == OBJ_REF_DESCRIPTION_TAG:
+            bf_dtype = OBJ_REF_DTYPE  # type: ignore
+            if is_repeated:
+                pa_type = pa.list_(bigframes_dtype_to_arrow_dtype(bf_dtype))
+                bf_dtype = pd.ArrowDtype(pa_type)
+            return field.name, bf_dtype
+
         mapped_fields = map(convert_schema_field, field.fields)
         fields = []
         for name, dtype in mapped_fields:
@@ -815,7 +822,11 @@ def convert_to_schema_field(
         )
         inner_field = convert_to_schema_field(name, inner_type, overrides)
         return google.cloud.bigquery.SchemaField(
-            name, inner_field.field_type, mode="REPEATED", fields=inner_field.fields
+            name,
+            inner_field.field_type,
+            mode="REPEATED",
+            fields=inner_field.fields,
+            description=inner_field.description,
         )
     if pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
         inner_fields: list[google.cloud.bigquery.SchemaField] = []
@@ -827,6 +838,14 @@ def convert_to_schema_field(
                 convert_to_schema_field(field.name, inner_bf_type, overrides)
             )
 
+        if bigframes_dtype == OBJ_REF_DTYPE:
+            return google.cloud.bigquery.SchemaField(
+                name,
+                "RECORD",
+                fields=inner_fields,
+                description=OBJ_REF_DESCRIPTION_TAG,
+            )
+
         return google.cloud.bigquery.SchemaField(
             name, "RECORD", fields=inner_fields
         )
@@ -971,6 +990,7 @@ def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype:
 
 
 TIMEDELTA_DESCRIPTION_TAG = "#microseconds"
+OBJ_REF_DESCRIPTION_TAG = "bigframes_dtype: OBJ_REF_DTYPE"
 
 
 def contains_db_dtypes_json_arrow_type(type_):
diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py
index 1e240a841c..fbcdfd33f5 100644
--- a/bigframes/session/bq_caching_executor.py
+++ b/bigframes/session/bq_caching_executor.py
@@ -334,13 +334,14 @@ def _export_gbq(
             session=array_value.session,
         )
 
-        has_timedelta_col = any(
-            t == bigframes.dtypes.TIMEDELTA_DTYPE for t in array_value.schema.dtypes
+        has_special_dtype_col = any(
+            t in (bigframes.dtypes.TIMEDELTA_DTYPE, bigframes.dtypes.OBJ_REF_DTYPE)
+            for t in array_value.schema.dtypes
         )
 
-        if spec.if_exists != "append" and has_timedelta_col:
+        if spec.if_exists != "append" and has_special_dtype_col:
             # Only update schema if this is not modifying an existing table, and the
-            # new table contains timedelta columns.
+            # new table contains special columns (like timedelta or obj_ref).
             table = self.bqclient.get_table(spec.table)
             table.schema = array_value.schema.to_bigquery()
             self.bqclient.update_table(table, ["schema"])
diff --git a/tests/system/small/test_dataframe_io.py b/tests/system/small/test_dataframe_io.py
index fece679d06..3da3544cbb 100644
--- a/tests/system/small/test_dataframe_io.py
+++ b/tests/system/small/test_dataframe_io.py
@@ -1002,6 +1002,28 @@ def test_to_gbq_timedelta_tag_ignored_when_appending(bigquery_client, dataset_id
     assert table.schema[0].description is None
+
+
+def test_to_gbq_obj_ref(session, dataset_id: str, bigquery_client):
+    destination_table = f"{dataset_id}.test_to_gbq_obj_ref"
+    sql = """
+    SELECT
+      'gs://cloud-samples-data/vision/ocr/sign.jpg' AS uri_col
+    """
+    df = session.read_gbq(sql)
+    df["obj_ref_col"] = df["uri_col"].str.to_blob()
+    df = df.drop(columns=["uri_col"])
+
+    df.to_gbq(destination_table)
+
+    table = bigquery_client.get_table(destination_table)
+    obj_ref_field = next(f for f in table.schema if f.name == "obj_ref_col")
+    assert obj_ref_field.field_type == "RECORD"
+    assert obj_ref_field.description == "bigframes_dtype: OBJ_REF_DTYPE"
+
+    reloaded_df = session.read_gbq(destination_table)
+    assert reloaded_df["obj_ref_col"].dtype == dtypes.OBJ_REF_DTYPE
+    assert len(reloaded_df) == 1
 
 
 @pytest.mark.parametrize(
     ("index"),
     [True, False],
diff --git a/tests/unit/test_dtypes.py b/tests/unit/test_dtypes.py
index 0e600de964..bb2b57d409 100644
--- a/tests/unit/test_dtypes.py
+++ b/tests/unit/test_dtypes.py
@@ -71,3 +71,11 @@ def test_infer_literal_type_arrow_scalar(scalar, expected_dtype):
 )
 def test_contains_db_dtypes_json_arrow_type(type_, expected):
     assert bigframes.dtypes.contains_db_dtypes_json_arrow_type(type_) == expected
+
+
+def test_convert_to_schema_field_list_description():
+    bf_dtype = bigframes.dtypes.OBJ_REF_DTYPE
+    list_bf_dtype = bigframes.dtypes.list_type(bf_dtype)
+    field = bigframes.dtypes.convert_to_schema_field("my_list", list_bf_dtype)
+    assert field.description == "bigframes_dtype: OBJ_REF_DTYPE"
+    assert field.mode == "REPEATED"