Skip to content

Commit cd1b4ac

Browse files
cleop-googlecopybara-github
authored andcommitted
fix: GenAI SDK client(multimodal) - fix Pydantic validation errors when using create_* in some cases
PiperOrigin-RevId: 901268856
1 parent f2d73fd commit cd1b4ac

2 files changed

Lines changed: 172 additions & 33 deletions

File tree

vertexai/_genai/_datasets_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,6 @@
3434
T = TypeVar("T", bound=BaseModel)
3535

3636

37-
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
38-
"""Creates a model from a response."""
39-
model_field_names = model_type.model_fields.keys()
40-
filtered_response = {}
41-
for key, value in response.items():
42-
snake_key = common.camel_to_snake(key)
43-
if snake_key in model_field_names:
44-
filtered_response[snake_key] = value
45-
return model_type(**filtered_response)
46-
47-
4837
def validate_multimodal_dataset_bigquery_uri(
4938
multimodal_dataset: common.MultimodalDataset,
5039
) -> None:

vertexai/_genai/datasets.py

Lines changed: 172 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,24 @@ def create_from_bigquery(
947947
operation=multimodal_dataset_operation,
948948
timeout_seconds=config.timeout,
949949
)
950-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
950+
return types.MultimodalDataset._from_response(
951+
response=response,
952+
kwargs=(
953+
{
954+
"config": {
955+
"response_schema": getattr(config, "response_schema", None),
956+
"response_json_schema": getattr(
957+
config, "response_json_schema", None
958+
),
959+
"include_all_fields": getattr(
960+
config, "include_all_fields", None
961+
),
962+
}
963+
}
964+
if config
965+
else {}
966+
),
967+
)
951968

952969
def create_from_pandas(
953970
self,
@@ -1267,9 +1284,23 @@ def assess_tuning_resources(
12671284
operation=operation,
12681285
timeout_seconds=config.timeout,
12691286
)
1270-
return _datasets_utils.create_from_response(
1271-
types.TuningResourceUsageAssessmentResult,
1272-
response["tuningResourceUsageAssessmentResult"],
1287+
return types.TuningResourceUsageAssessmentResult._from_response(
1288+
response=response["tuningResourceUsageAssessmentResult"],
1289+
kwargs=(
1290+
{
1291+
"config": {
1292+
"response_schema": getattr(config, "response_schema", None),
1293+
"response_json_schema": getattr(
1294+
config, "response_json_schema", None
1295+
),
1296+
"include_all_fields": getattr(
1297+
config, "include_all_fields", None
1298+
),
1299+
}
1300+
}
1301+
if config
1302+
else {}
1303+
),
12731304
)
12741305

12751306
def assess_tuning_validity(
@@ -1329,9 +1360,23 @@ def assess_tuning_validity(
13291360
operation=operation,
13301361
timeout_seconds=config.timeout,
13311362
)
1332-
return _datasets_utils.create_from_response(
1333-
types.TuningValidationAssessmentResult,
1334-
response["tuningValidationAssessmentResult"],
1363+
return types.TuningValidationAssessmentResult._from_response(
1364+
response=response["tuningValidationAssessmentResult"],
1365+
kwargs=(
1366+
{
1367+
"config": {
1368+
"response_schema": getattr(config, "response_schema", None),
1369+
"response_json_schema": getattr(
1370+
config, "response_json_schema", None
1371+
),
1372+
"include_all_fields": getattr(
1373+
config, "include_all_fields", None
1374+
),
1375+
}
1376+
}
1377+
if config
1378+
else {}
1379+
),
13351380
)
13361381

13371382
def assess_batch_prediction_resources(
@@ -1389,8 +1434,23 @@ def assess_batch_prediction_resources(
13891434
timeout_seconds=config.timeout,
13901435
)
13911436
result = response["batchPredictionResourceUsageAssessmentResult"]
1392-
return _datasets_utils.create_from_response(
1393-
types.BatchPredictionResourceUsageAssessmentResult, result
1437+
return types.BatchPredictionResourceUsageAssessmentResult._from_response(
1438+
response=result,
1439+
kwargs=(
1440+
{
1441+
"config": {
1442+
"response_schema": getattr(config, "response_schema", None),
1443+
"response_json_schema": getattr(
1444+
config, "response_json_schema", None
1445+
),
1446+
"include_all_fields": getattr(
1447+
config, "include_all_fields", None
1448+
),
1449+
}
1450+
}
1451+
if config
1452+
else {}
1453+
),
13941454
)
13951455

13961456
def assess_batch_prediction_validity(
@@ -1448,8 +1508,23 @@ def assess_batch_prediction_validity(
14481508
timeout_seconds=config.timeout,
14491509
)
14501510
result = response["batchPredictionValidationAssessmentResult"]
1451-
return _datasets_utils.create_from_response(
1452-
types.BatchPredictionValidationAssessmentResult, result
1511+
return types.BatchPredictionValidationAssessmentResult._from_response(
1512+
response=result,
1513+
kwargs=(
1514+
{
1515+
"config": {
1516+
"response_schema": getattr(config, "response_schema", None),
1517+
"response_json_schema": getattr(
1518+
config, "response_json_schema", None
1519+
),
1520+
"include_all_fields": getattr(
1521+
config, "include_all_fields", None
1522+
),
1523+
}
1524+
}
1525+
if config
1526+
else {}
1527+
),
14531528
)
14541529

14551530

@@ -2171,7 +2246,24 @@ async def create_from_bigquery(
21712246
operation=multimodal_dataset_operation,
21722247
timeout_seconds=config.timeout,
21732248
)
2174-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
2249+
return types.MultimodalDataset._from_response(
2250+
response=response,
2251+
kwargs=(
2252+
{
2253+
"config": {
2254+
"response_schema": getattr(config, "response_schema", None),
2255+
"response_json_schema": getattr(
2256+
config, "response_json_schema", None
2257+
),
2258+
"include_all_fields": getattr(
2259+
config, "include_all_fields", None
2260+
),
2261+
}
2262+
}
2263+
if config
2264+
else {}
2265+
),
2266+
)
21752267

21762268
async def create_from_pandas(
21772269
self,
@@ -2489,9 +2581,23 @@ async def assess_tuning_resources(
24892581
operation=operation,
24902582
timeout_seconds=config.timeout,
24912583
)
2492-
return _datasets_utils.create_from_response(
2493-
types.TuningResourceUsageAssessmentResult,
2494-
response["tuningResourceUsageAssessmentResult"],
2584+
return types.TuningResourceUsageAssessmentResult._from_response(
2585+
response=response["tuningResourceUsageAssessmentResult"],
2586+
kwargs=(
2587+
{
2588+
"config": {
2589+
"response_schema": getattr(config, "response_schema", None),
2590+
"response_json_schema": getattr(
2591+
config, "response_json_schema", None
2592+
),
2593+
"include_all_fields": getattr(
2594+
config, "include_all_fields", None
2595+
),
2596+
}
2597+
}
2598+
if config
2599+
else {}
2600+
),
24952601
)
24962602

24972603
async def assess_tuning_validity(
@@ -2551,9 +2657,23 @@ async def assess_tuning_validity(
25512657
operation=operation,
25522658
timeout_seconds=config.timeout,
25532659
)
2554-
return _datasets_utils.create_from_response(
2555-
types.TuningValidationAssessmentResult,
2556-
response["tuningValidationAssessmentResult"],
2660+
return types.TuningValidationAssessmentResult._from_response(
2661+
response=response["tuningValidationAssessmentResult"],
2662+
kwargs=(
2663+
{
2664+
"config": {
2665+
"response_schema": getattr(config, "response_schema", None),
2666+
"response_json_schema": getattr(
2667+
config, "response_json_schema", None
2668+
),
2669+
"include_all_fields": getattr(
2670+
config, "include_all_fields", None
2671+
),
2672+
}
2673+
}
2674+
if config
2675+
else {}
2676+
),
25572677
)
25582678

25592679
async def assess_batch_prediction_resources(
@@ -2611,8 +2731,23 @@ async def assess_batch_prediction_resources(
26112731
timeout_seconds=config.timeout,
26122732
)
26132733
result = response["batchPredictionResourceUsageAssessmentResult"]
2614-
return _datasets_utils.create_from_response(
2615-
types.BatchPredictionResourceUsageAssessmentResult, result
2734+
return types.BatchPredictionResourceUsageAssessmentResult._from_response(
2735+
response=result,
2736+
kwargs=(
2737+
{
2738+
"config": {
2739+
"response_schema": getattr(config, "response_schema", None),
2740+
"response_json_schema": getattr(
2741+
config, "response_json_schema", None
2742+
),
2743+
"include_all_fields": getattr(
2744+
config, "include_all_fields", None
2745+
),
2746+
}
2747+
}
2748+
if config
2749+
else {}
2750+
),
26162751
)
26172752

26182753
async def assess_batch_prediction_validity(
@@ -2670,6 +2805,21 @@ async def assess_batch_prediction_validity(
26702805
timeout_seconds=config.timeout,
26712806
)
26722807
result = response["batchPredictionValidationAssessmentResult"]
2673-
return _datasets_utils.create_from_response(
2674-
types.BatchPredictionValidationAssessmentResult, result
2808+
return types.BatchPredictionValidationAssessmentResult._from_response(
2809+
response=result,
2810+
kwargs=(
2811+
{
2812+
"config": {
2813+
"response_schema": getattr(config, "response_schema", None),
2814+
"response_json_schema": getattr(
2815+
config, "response_json_schema", None
2816+
),
2817+
"include_all_fields": getattr(
2818+
config, "include_all_fields", None
2819+
),
2820+
}
2821+
}
2822+
if config
2823+
else {}
2824+
),
26752825
)

0 commit comments

Comments
 (0)