Skip to content

Commit 769534f

Browse files
committed
Fix AI tests signature mismatch
1 parent e625cd2 commit 769534f

1 file changed

Lines changed: 35 additions & 11 deletions

File tree

  • packages/bigframes/tests/system/small/bigquery

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@
2626
from bigframes.testing import utils as test_utils
2727

2828

29-
def _create_mock_obj_ref_df(session, uris, name="image"):
29+
def _create_mock_obj_ref_df(session, uris, name="image", connection=None):
3030
df = bpd.DataFrame({name: uris}, session=session)
31+
# Convert string URIs to ObjectRef structs
32+
if connection is None:
33+
connection = "us.bigframes-rf-conn"
34+
df[name] = bbq.obj.make_ref(df[name], authorizer=connection)
35+
3136
table_id = f"bigframes-dev.bigframes_tests_sys.tmp_obj_ref_{uuid.uuid4().hex}"
3237
df.to_gbq(table_id, if_exists="replace")
3338

@@ -41,6 +46,7 @@ def _create_mock_obj_ref_df(session, uris, name="image"):
4146
field_type=field.field_type,
4247
mode=field.mode,
4348
description="bigframes_dtype: OBJ_REF_DTYPE",
49+
fields=field.fields,
4450
)
4551
break
4652
table.schema = schema
@@ -183,9 +189,12 @@ def test_ai_generate_bool(session):
183189
)
184190

185191

186-
def test_ai_generate_bool_multi_model(session):
192+
def test_ai_generate_bool_multi_model(session, bq_connection):
187193
df = _create_mock_obj_ref_df(
188-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
194+
session,
195+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
196+
name="image",
197+
connection=bq_connection,
189198
)
190199

191200
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
@@ -221,9 +230,12 @@ def test_ai_generate_int(session):
221230
)
222231

223232

224-
def test_ai_generate_int_multi_model(session):
233+
def test_ai_generate_int_multi_model(session, bq_connection):
225234
df = _create_mock_obj_ref_df(
226-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
235+
session,
236+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
237+
name="image",
238+
connection=bq_connection,
227239
)
228240

229241
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
@@ -261,9 +273,12 @@ def test_ai_generate_double(session):
261273
)
262274

263275

264-
def test_ai_generate_double_multi_model(session):
276+
def test_ai_generate_double_multi_model(session, bq_connection):
265277
df = _create_mock_obj_ref_df(
266-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
278+
session,
279+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
280+
name="image",
281+
connection=bq_connection,
267282
)
268283

269284
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
@@ -363,7 +378,10 @@ def test_ai_if(session):
363378

364379
def test_ai_if_multi_model(session, bq_connection):
365380
df = _create_mock_obj_ref_df(
366-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
381+
session,
382+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
383+
name="image",
384+
connection=bq_connection,
367385
)
368386

369387
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
@@ -393,7 +411,10 @@ def test_ai_classify_with_examples(session):
393411

394412
def test_ai_classify_multi_model(session, bq_connection):
395413
df = _create_mock_obj_ref_df(
396-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
414+
session,
415+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
416+
name="image",
417+
connection=bq_connection,
397418
)
398419

399420
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
@@ -413,9 +434,12 @@ def test_ai_score(session):
413434
assert result.dtype == dtypes.FLOAT_DTYPE
414435

415436

416-
def test_ai_score_multi_model(session):
437+
def test_ai_score_multi_model(session, bq_connection):
417438
df = _create_mock_obj_ref_df(
418-
session, ["gs://cloud-samples-data/vision/ocr/sign.jpg"], name="image"
439+
session,
440+
["gs://cloud-samples-data/vision/ocr/sign.jpg"],
441+
name="image",
442+
connection=bq_connection,
419443
)
420444
image_runtime = bbq.obj.get_access_url(df["image"], mode="R")
421445
prompt = ("Rank the liveliness of ", image_runtime, "on the scale from 1 to 3")

0 commit comments

Comments
 (0)