Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions google/genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def process_schema(
defs: Optional[_common.StringDict] = None,
*,
order_properties: bool = True,
visited_dicts_path: Optional[set[int]] = None,
) -> None:
"""Updates the schema and each sub-schema inplace to be API-compatible.

Expand Down Expand Up @@ -726,6 +727,13 @@ def process_schema(
'type': 'array'
}
"""
if visited_dicts_path is None:
visited_dicts_path = set()

if id(schema) in visited_dicts_path:
return
visited_dicts_path.add(id(schema))

if schema.get('title') == 'PlaceholderLiteralEnum':
del schema['title']

Expand All @@ -750,7 +758,11 @@ def process_schema(
# directly referencing another '$ref':
# https://json-schema.org/understanding-json-schema/structuring#recursion
process_schema(
sub_schema, client, defs, order_properties=order_properties
sub_schema,
client,
defs,
order_properties=order_properties,
visited_dicts_path=visited_dicts_path,
)

handle_null_fields(schema)
Expand All @@ -765,11 +777,21 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
"""Returns the processed `sub_schema`, resolving its '$ref' if any."""
if (ref := sub_schema.pop('$ref', None)) is not None:
sub_schema = defs[ref.split('defs/')[-1]]
process_schema(sub_schema, client, defs, order_properties=order_properties)
if id(sub_schema) in visited_dicts_path:
return {}

process_schema(
sub_schema,
client,
defs,
order_properties=order_properties,
visited_dicts_path=visited_dicts_path,
)
return sub_schema

if (any_of := schema.get('anyOf')) is not None:
schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
visited_dicts_path.remove(id(schema))
return

schema_type = schema.get('type')
Expand Down Expand Up @@ -809,6 +831,7 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
if (prefixes := schema.get('prefixItems')) is not None:
schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]

visited_dicts_path.remove(id(schema))

def _process_enum(
enum: EnumMeta, client: Optional[_api_client.BaseApiClient]
Expand Down
33 changes: 33 additions & 0 deletions google/genai/tests/transformers/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,39 @@ def test_process_schema_order_properties_propagates_into_any_of(
assert schema == schema_without_property_ordering


@pytest.mark.parametrize('use_vertex', [True, False])
def test_process_schema_with_cycle(client):
schema = {
'type': 'OBJECT',
'properties': {
'recursive': {'$ref': '#/$defs/RecursiveObject'},
},
'$defs': {
'RecursiveObject': {
'type': 'OBJECT',
'properties': {
'self': {'$ref': '#/$defs/RecursiveObject'},
}
}
}
}

_transformers.process_schema(schema, client)

expected = {
'type': 'OBJECT',
'properties': {
'recursive': {
'type': 'OBJECT',
'properties': {
'self': {}
}
}
}
}
assert schema == expected


@pytest.mark.parametrize('use_vertex', [True, False])
def test_t_schema_does_not_change_property_ordering_if_set(client):
"""Tests t_schema doesn't overwrite the property_ordering field if already set."""
Expand Down
24 changes: 24 additions & 0 deletions google/genai/tests/types/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,30 @@ def func_under_test(a: int) -> str:
assert actual_schema_vertex == expected_schema_vertex


def test_convert_json_schema_with_cycle():
json_schema_dict = {
'type': 'object',
'properties': {
'foo': {'$ref': '#/$defs/Foo'}
},
'$defs': {
'Foo': {
'type': 'object',
'properties': {
'foo': {'$ref': '#/$defs/Foo'}
}
}
}
}

json_schema = types.JSONSchema(**json_schema_dict)
schema = types.Schema.from_json_schema(json_schema=json_schema)

assert schema.type == types.Type.OBJECT
assert schema.properties['foo'].type == types.Type.OBJECT
assert schema.properties['foo'].properties['foo'] == types.Schema()


def test_case_insensitive_enum():
assert types.Type('STRING') == types.Type.STRING
assert types.Type('string') == types.Type.STRING
Expand Down
19 changes: 15 additions & 4 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2911,14 +2911,20 @@ def convert_json_schema(
root_json_schema_dict: dict[str, Any],
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
raise_error_on_unsupported_field: bool,
visited_refs: Optional[set[str]] = None,
) -> 'Schema':
if visited_refs is None:
visited_refs = set()

schema = Schema()
json_schema_dict = current_json_schema.model_dump()

if json_schema_dict.get('ref'):
json_schema_dict = _resolve_ref(
json_schema_dict['ref'], root_json_schema_dict
)
ref = json_schema_dict.get('ref')
if ref:
if ref in visited_refs:
return Schema()
visited_refs.add(ref)
json_schema_dict = _resolve_ref(ref, root_json_schema_dict)

raise_error_if_cannot_convert(
json_schema_dict=json_schema_dict,
Expand Down Expand Up @@ -2985,6 +2991,7 @@ def convert_json_schema(
root_json_schema_dict=root_json_schema_dict,
api_option=api_option,
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
visited_refs=visited_refs,
)
setattr(schema, field_name, schema_field_value)
elif field_name in list_schema_field_names:
Expand All @@ -2994,6 +3001,7 @@ def convert_json_schema(
root_json_schema_dict=root_json_schema_dict,
api_option=api_option,
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
visited_refs=visited_refs,
)
for this_field_value in field_value
]
Expand All @@ -3007,6 +3015,7 @@ def convert_json_schema(
root_json_schema_dict=root_json_schema_dict,
api_option=api_option,
raise_error_on_unsupported_field=raise_error_on_unsupported_field,
visited_refs=visited_refs,
)
for key, value in field_value.items()
}
Expand Down Expand Up @@ -3051,6 +3060,8 @@ def convert_json_schema(
if default_value is not None:
schema.default = default_value

if ref:
visited_refs.remove(ref)
return schema

# This is the initial call to the recursive function.
Expand Down
Loading