Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@
logger = logging.getLogger('google_genai.models')


def _merge_content_parts(contents: list[types.Content]) -> Optional[types.Content]:
  """Collapses a sequence of streamed content chunks into one content.

  Args:
    contents: Content chunks, in the order they were received.

  Returns:
    A single content holding every part from `contents` in order, or None
    when none of the chunks carries any parts. The merged role is the first
    non-None role found while scanning `contents`.
  """
  # Chunks whose `parts` is None or empty contribute nothing.
  merged_parts: list[types.Part] = [
      part for item in contents for part in (item.parts or [])
  ]
  if not merged_parts:
    return None

  # The role is picked independently of parts: first chunk with a role wins.
  merged_role = next(
      (item.role for item in contents if item.role is not None), None
  )
  return types.Content(role=merged_role, parts=merged_parts)


def _PersonGeneration_to_mldev_enum_validate(enum_value: Any) -> None:
if enum_value in set(['ALLOW_ALL']):
raise ValueError(
Expand Down Expand Up @@ -6567,6 +6581,7 @@ def generate_content_stream(
automatic_function_calling_history: list[types.Content] = []
chunk = None
func_response_parts = None
func_call_content = None
i = 0
while remaining_remote_calls_afc > 0:
i += 1
Expand All @@ -6575,6 +6590,8 @@ def generate_content_stream(
)

function_map = _extra_utils.get_function_map(parsed_config)
func_response_parts = []
func_call_contents = []

if i == 1:
# First request gets a function call.
Expand All @@ -6591,12 +6608,16 @@ def generate_content_stream(
or not chunk.candidates[0].content.parts
):
break
func_response_parts = _extra_utils.get_function_response_parts(
response_parts = _extra_utils.get_function_response_parts(
chunk, function_map
)
if not func_response_parts:
if response_parts:
func_response_parts.extend(response_parts)
func_call_contents.append(chunk.candidates[0].content)
else:
contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment]
yield chunk
func_call_content = _merge_content_parts(func_call_contents)

else:
# Second request and beyond, yield chunks.
Expand All @@ -6617,6 +6638,7 @@ def generate_content_stream(
func_response_parts = _extra_utils.get_function_response_parts(
chunk, function_map
)
func_call_content = chunk.candidates[0].content

if not function_map:
break
Expand All @@ -6629,7 +6651,6 @@ def generate_content_stream(

# Append function response parts to contents for the next request.
if chunk is not None and chunk.candidates is not None:
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
Expand Down Expand Up @@ -8667,6 +8688,7 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
)
automatic_function_calling_history: list[types.Content] = []
func_response_parts = None
func_call_content = None
chunk = None
i = 0
while remaining_remote_calls_afc > 0:
Expand All @@ -8686,6 +8708,8 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
function_map = _extra_utils.get_function_map(
config, mcp_to_genai_tool_adapters, is_caller_method_async=True
)
func_response_parts = []
func_call_contents = []

if i == 1:
# First request gets a function call.
Expand All @@ -8702,14 +8726,18 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)
if not func_response_parts:
if response_parts:
func_response_parts.extend(response_parts)
func_call_contents.append(chunk.candidates[0].content)
else:
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
func_call_content = _merge_content_parts(func_call_contents)

else:
# Second request and beyond, yield chunks.
Expand All @@ -8733,6 +8761,7 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
chunk, function_map
)
)
func_call_content = chunk.candidates[0].content
if not function_map:
break

Expand All @@ -8742,7 +8771,6 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
if chunk is None:
continue
# Append function response parts to contents for the next request.
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
Expand Down
133 changes: 133 additions & 0 deletions google/genai/tests/afc/test_generate_content_stream_afc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,31 @@
]


def _function_call_response(
    name: str,
    args: dict[str, object],
    thought_signature: bytes,
) -> types.GenerateContentResponse:
  """Builds a one-candidate response holding a single function-call part.

  Args:
    name: Name placed on the function call.
    args: Arguments placed on the function call.
    thought_signature: Signature attached to the part that carries the call.

  Returns:
    A response whose only candidate has 'model' role content with exactly
    one part: the function call together with its thought signature.
  """
  call_part = types.Part(
      function_call=types.FunctionCall(name=name, args=args),
      thought_signature=thought_signature,
  )
  model_content = types.Content(parts=[call_part], role='model')
  return types.GenerateContentResponse(
      candidates=[types.Candidate(content=model_content)]
  )


def get_current_weather(location: str) -> str:
"""Returns the current weather.

Expand Down Expand Up @@ -359,6 +384,58 @@ def test_generate_content_stream_with_thought_summaries(
) == TEST_AFC_HISTORY[i].model_dump(exclude_none=True)


def test_generate_content_stream_merges_function_call_chunks_with_signatures():
  """Function calls streamed in separate chunks are sent back as one content.

  The first mocked stream yields two chunks, each with one function call and
  its own thought signature. The second request must then carry a single
  function-call content with both parts (signatures intact), followed by a
  function-response content with two parts.
  """
  # First stream: two separate function-call chunks.
  first_stream = [
      _function_call_response(
          'get_current_weather',
          {'location': 'San Francisco'},
          b'weather-signature',
      ),
      _function_call_response(
          'get_aqi_from_city',
          {'location': 'San Francisco'},
          b'aqi-signature',
      ),
  ]
  # Second stream: the final text answer.
  second_stream = [
      types.GenerateContentResponse(
          candidates=[types.Candidate(content=TEST_AFC_TEXT_CONTENT)]
      )
  ]

  with mock.patch.object(
      models.Models, '_generate_content_stream'
  ) as patched_stream:
    patched_stream.side_effect = [first_stream, second_stream]
    client_models = models.Models(api_client_=mock_api_client)
    afc_config = types.GenerateContentConfig(
        tools=[get_current_weather, get_aqi_from_city]
    )

    # Drain the generator while the patch is still active.
    yielded = list(
        client_models.generate_content_stream(
            model='test_model',
            contents='what is the weather and AQI in San Francisco?',
            config=afc_config,
        )
    )

    assert len(yielded) == 1
    assert patched_stream.call_count == 2

    # The last two entries of the second request are the model's function
    # calls and the function responses produced for them.
    followup_contents = patched_stream.call_args_list[1].kwargs['contents']
    call_content = followup_contents[-2]
    response_content = followup_contents[-1]

    assert len(call_content.parts) == 2
    call_names = [p.function_call.name for p in call_content.parts]
    assert call_names == ['get_current_weather', 'get_aqi_from_city']
    signatures = [p.thought_signature for p in call_content.parts]
    assert signatures == [b'weather-signature', b'aqi-signature']
    assert len(response_content.parts) == 2


@pytest.mark.asyncio
async def test_generate_content_stream_no_function_map_async(
mock_generate_content_stream_no_afc,
Expand Down Expand Up @@ -528,3 +605,59 @@ async def test_generate_content_stream_with_thought_summaries_async(
assert chunk.automatic_function_calling_history[i].model_dump(
exclude_none=True
) == TEST_AFC_HISTORY[i].model_dump(exclude_none=True)


@pytest.mark.asyncio
async def test_generate_content_stream_merges_function_call_chunks_async():
  """Async variant: separately streamed function calls are merged.

  The first mocked stream yields two chunks, each with one function call and
  its own thought signature. The second request must then carry a single
  function-call content with both parts (signatures intact), followed by a
  function-response content with two parts.
  """

  async def function_call_stream():
    # First stream: two separate function-call chunks.
    yield _function_call_response(
        'get_current_weather',
        {'location': 'San Francisco'},
        b'weather-signature',
    )
    yield _function_call_response(
        'get_aqi_from_city',
        {'location': 'San Francisco'},
        b'aqi-signature',
    )

  async def final_text_stream():
    # Second stream: the final text answer.
    yield types.GenerateContentResponse(
        candidates=[types.Candidate(content=TEST_AFC_TEXT_CONTENT)]
    )

  with mock.patch.object(
      models.AsyncModels, '_generate_content_stream'
  ) as patched_stream:
    patched_stream.side_effect = [
        function_call_stream(),
        final_text_stream(),
    ]
    client_models = models.AsyncModels(api_client_=mock_api_client)
    afc_config = types.GenerateContentConfig(
        tools=[get_current_weather, get_aqi_from_city]
    )
    response_stream = await client_models.generate_content_stream(
        model='test_model',
        contents='what is the weather and AQI in San Francisco?',
        config=afc_config,
    )

    # Drain the async generator while the patch is still active.
    yielded = [item async for item in response_stream]

    assert len(yielded) == 1
    assert patched_stream.call_count == 2

    # The last two entries of the second request are the model's function
    # calls and the function responses produced for them.
    followup_contents = patched_stream.call_args_list[1].kwargs['contents']
    call_content = followup_contents[-2]
    response_content = followup_contents[-1]

    assert len(call_content.parts) == 2
    call_names = [p.function_call.name for p in call_content.parts]
    assert call_names == ['get_current_weather', 'get_aqi_from_city']
    signatures = [p.thought_signature for p in call_content.parts]
    assert signatures == [b'weather-signature', b'aqi-signature']
    assert len(response_content.parts) == 2