Skip to content

Commit 85b3780

Browse files
speedstorm1 and copybara-github
authored and committed
chore: Handle recursive JSON schema references in type conversion.
Fixes #2181 PiperOrigin-RevId: 894740183
1 parent 07e932f commit 85b3780

File tree

4 files changed

+97
-6
lines changed

4 files changed

+97
-6
lines changed

google/genai/_transformers.py

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -665,6 +665,7 @@ def process_schema(
665665
defs: Optional[_common.StringDict] = None,
666666
*,
667667
order_properties: bool = True,
668+
visited_dicts_path: Optional[set[int]] = None,
668669
) -> None:
669670
"""Updates the schema and each sub-schema inplace to be API-compatible.
670671
@@ -726,6 +727,13 @@ def process_schema(
726727
'type': 'array'
727728
}
728729
"""
730+
if visited_dicts_path is None:
731+
visited_dicts_path = set()
732+
733+
if id(schema) in visited_dicts_path:
734+
return
735+
visited_dicts_path.add(id(schema))
736+
729737
if schema.get('title') == 'PlaceholderLiteralEnum':
730738
del schema['title']
731739

@@ -750,7 +758,11 @@ def process_schema(
750758
# directly referencing another '$ref':
751759
# https://json-schema.org/understanding-json-schema/structuring#recursion
752760
process_schema(
753-
sub_schema, client, defs, order_properties=order_properties
761+
sub_schema,
762+
client,
763+
defs,
764+
order_properties=order_properties,
765+
visited_dicts_path=visited_dicts_path,
754766
)
755767

756768
handle_null_fields(schema)
@@ -765,11 +777,21 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
765777
"""Returns the processed `sub_schema`, resolving its '$ref' if any."""
766778
if (ref := sub_schema.pop('$ref', None)) is not None:
767779
sub_schema = defs[ref.split('defs/')[-1]]
768-
process_schema(sub_schema, client, defs, order_properties=order_properties)
780+
if id(sub_schema) in visited_dicts_path:
781+
return {}
782+
783+
process_schema(
784+
sub_schema,
785+
client,
786+
defs,
787+
order_properties=order_properties,
788+
visited_dicts_path=visited_dicts_path,
789+
)
769790
return sub_schema
770791

771792
if (any_of := schema.get('anyOf')) is not None:
772793
schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
794+
visited_dicts_path.remove(id(schema))
773795
return
774796

775797
schema_type = schema.get('type')
@@ -809,6 +831,7 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
809831
if (prefixes := schema.get('prefixItems')) is not None:
810832
schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]
811833

834+
visited_dicts_path.remove(id(schema))
812835

813836
def _process_enum(
814837
enum: EnumMeta, client: Optional[_api_client.BaseApiClient]

google/genai/tests/transformers/test_schema.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -607,6 +607,39 @@ def test_process_schema_order_properties_propagates_into_any_of(
607607
assert schema == schema_without_property_ordering
608608

609609

610+
@pytest.mark.parametrize('use_vertex', [True, False])
611+
def test_process_schema_with_cycle(client):
612+
schema = {
613+
'type': 'OBJECT',
614+
'properties': {
615+
'recursive': {'$ref': '#/$defs/RecursiveObject'},
616+
},
617+
'$defs': {
618+
'RecursiveObject': {
619+
'type': 'OBJECT',
620+
'properties': {
621+
'self': {'$ref': '#/$defs/RecursiveObject'},
622+
}
623+
}
624+
}
625+
}
626+
627+
_transformers.process_schema(schema, client)
628+
629+
expected = {
630+
'type': 'OBJECT',
631+
'properties': {
632+
'recursive': {
633+
'type': 'OBJECT',
634+
'properties': {
635+
'self': {}
636+
}
637+
}
638+
}
639+
}
640+
assert schema == expected
641+
642+
610643
@pytest.mark.parametrize('use_vertex', [True, False])
611644
def test_t_schema_does_not_change_property_ordering_if_set(client):
612645
"""Tests t_schema doesn't overwrite the property_ordering field if already set."""

google/genai/tests/types/test_types.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2687,6 +2687,30 @@ def func_under_test(a: int) -> str:
26872687
assert actual_schema_vertex == expected_schema_vertex
26882688

26892689

2690+
def test_convert_json_schema_with_cycle():
2691+
json_schema_dict = {
2692+
'type': 'object',
2693+
'properties': {
2694+
'foo': {'$ref': '#/$defs/Foo'}
2695+
},
2696+
'$defs': {
2697+
'Foo': {
2698+
'type': 'object',
2699+
'properties': {
2700+
'foo': {'$ref': '#/$defs/Foo'}
2701+
}
2702+
}
2703+
}
2704+
}
2705+
2706+
json_schema = types.JSONSchema(**json_schema_dict)
2707+
schema = types.Schema.from_json_schema(json_schema=json_schema)
2708+
2709+
assert schema.type == types.Type.OBJECT
2710+
assert schema.properties['foo'].type == types.Type.OBJECT
2711+
assert schema.properties['foo'].properties['foo'] == types.Schema()
2712+
2713+
26902714
def test_case_insensitive_enum():
26912715
assert types.Type('STRING') == types.Type.STRING
26922716
assert types.Type('string') == types.Type.STRING

google/genai/types.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2911,14 +2911,20 @@ def convert_json_schema(
29112911
root_json_schema_dict: dict[str, Any],
29122912
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
29132913
raise_error_on_unsupported_field: bool,
2914+
visited_refs: Optional[set[str]] = None,
29142915
) -> 'Schema':
2916+
if visited_refs is None:
2917+
visited_refs = set()
2918+
29152919
schema = Schema()
29162920
json_schema_dict = current_json_schema.model_dump()
29172921

2918-
if json_schema_dict.get('ref'):
2919-
json_schema_dict = _resolve_ref(
2920-
json_schema_dict['ref'], root_json_schema_dict
2921-
)
2922+
ref = json_schema_dict.get('ref')
2923+
if ref:
2924+
if ref in visited_refs:
2925+
return Schema()
2926+
visited_refs.add(ref)
2927+
json_schema_dict = _resolve_ref(ref, root_json_schema_dict)
29222928

29232929
raise_error_if_cannot_convert(
29242930
json_schema_dict=json_schema_dict,
@@ -2985,6 +2991,7 @@ def convert_json_schema(
29852991
root_json_schema_dict=root_json_schema_dict,
29862992
api_option=api_option,
29872993
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
2994+
visited_refs=visited_refs,
29882995
)
29892996
setattr(schema, field_name, schema_field_value)
29902997
elif field_name in list_schema_field_names:
@@ -2994,6 +3001,7 @@ def convert_json_schema(
29943001
root_json_schema_dict=root_json_schema_dict,
29953002
api_option=api_option,
29963003
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
3004+
visited_refs=visited_refs,
29973005
)
29983006
for this_field_value in field_value
29993007
]
@@ -3007,6 +3015,7 @@ def convert_json_schema(
30073015
root_json_schema_dict=root_json_schema_dict,
30083016
api_option=api_option,
30093017
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
3018+
visited_refs=visited_refs,
30103019
)
30113020
for key, value in field_value.items()
30123021
}
@@ -3051,6 +3060,8 @@ def convert_json_schema(
30513060
if default_value is not None:
30523061
schema.default = default_value
30533062

3063+
if ref:
3064+
visited_refs.remove(ref)
30543065
return schema
30553066

30563067
# This is the initial call to the recursive function.

0 commit comments

Comments
 (0)