Skip to content

Commit 4b7f6c4

Browse files
sararobcopybara-github
authored andcommitted
feat: Add MCP support to async generate_content
PiperOrigin-RevId: 912562893
1 parent d5a9527 commit 4b7f6c4

6 files changed

Lines changed: 309 additions & 107 deletions

File tree

google/genai/_adapters.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ def __init__(
3030
self,
3131
session: "mcp.ClientSession", # type: ignore # noqa: F821
3232
list_tools_result: "mcp_types.ListToolsResult", # type: ignore
33+
is_agent_platform: bool = False,
3334
) -> None:
3435
self._mcp_session = session
3536
self._list_tools_result = list_tools_result
37+
self._is_agent_platform = is_agent_platform
3638

3739
async def call_tool(
3840
self, function_call: FunctionCall

google/genai/_extra_utils.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,9 @@ def format_destination(
120120

121121
def find_afc_incompatible_tool_indexes(
122122
config: Optional[types.GenerateContentConfigOrDict] = None,
123+
is_agent_platform: bool = False,
123124
) -> list[int]:
124-
"""Checks if the config contains any AFC incompatible tools.
125-
126-
A `types.Tool` object that contains `function_declarations` is considered a
127-
non-AFC tool for this execution path.
128-
129-
Args:
130-
config: The GenerateContentConfig to check for incompatible tools.
131-
132-
Returns:
133-
A list of indexes of the incompatible tools in the config.
134-
"""
125+
"""Checks if the config contains any AFC incompatible tools."""
135126
if not config:
136127
return []
137128
config_model = _create_generate_content_config_model(config)
@@ -145,7 +136,9 @@ def find_afc_incompatible_tool_indexes(
145136
continue
146137
if tool.function_declarations:
147138
incompatible_tools_indexes.append(index)
148-
if tool.mcp_servers:
139+
140+
# Only mark it incompatible if it's MLDev, not Agent Platform.
141+
if tool.mcp_servers and not is_agent_platform:
149142
incompatible_tools_indexes.append(index)
150143
return incompatible_tools_indexes
151144

@@ -383,12 +376,15 @@ async def get_function_response_parts_async(
383376
if not part.function_call:
384377
continue
385378
func_name = part.function_call.name
386-
if func_name is not None and part.function_call.args is not None:
379+
if func_name is not None:
387380
func = function_map[func_name]
388-
args = convert_number_values_for_dict_function_call_args(
381+
# Treat None as an empty dictionary for execution
382+
raw_args = (
389383
part.function_call.args
384+
if part.function_call.args is not None
385+
else {}
390386
)
391-
func_response: _common.StringDict
387+
args = convert_number_values_for_dict_function_call_args(raw_args)
392388
try:
393389
if isinstance(func, McpToGenAiToolAdapter):
394390
mcp_tool_response = await func.call_tool(
@@ -551,6 +547,7 @@ def parse_config_for_mcp_usage(
551547

552548
async def parse_config_for_mcp_sessions(
553549
config: Optional[types.GenerateContentConfigOrDict] = None,
550+
is_agent_platform: bool = False,
554551
) -> tuple[
555552
Optional[types.GenerateContentConfig],
556553
dict[str, McpToGenAiToolAdapter],
@@ -571,7 +568,7 @@ async def parse_config_for_mcp_sessions(
571568
for tool in parsed_config.tools:
572569
if McpClientSession is not None and isinstance(tool, McpClientSession):
573570
mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
574-
tool, await tool.list_tools()
571+
tool, await tool.list_tools(), is_agent_platform=is_agent_platform
575572
)
576573
# Extend the config with the MCP session tools converted to GenAI tools.
577574
parsed_config_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
@@ -677,3 +674,19 @@ def prepare_resumable_upload(
677674
http_options.headers = {}
678675
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
679676
return http_options, size_bytes, mime_type
677+
678+
679+
def has_agent_platform_mcp_servers(
680+
config: Optional[types.GenerateContentConfigOrDict] = None,
681+
) -> bool:
682+
"""Checks whether the configuration contains any MCP server requests."""
683+
if not config:
684+
return False
685+
config_model = _create_generate_content_config_model(config)
686+
if not config_model.tools:
687+
return False
688+
689+
for tool in config_model.tools:
690+
if isinstance(tool, types.Tool) and tool.mcp_servers:
691+
return True
692+
return False

google/genai/_live_converters.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,26 @@ def _LiveServerMessage_from_vertex(
14831483
return to_object
14841484

14851485

1486+
def _McpServer_to_vertex(
1487+
from_object: Union[dict[str, Any], object],
1488+
parent_object: Optional[dict[str, Any]] = None,
1489+
) -> dict[str, Any]:
1490+
to_object: dict[str, Any] = {}
1491+
if getv(from_object, ['name']) is not None:
1492+
raise ValueError(
1493+
'name parameter is only supported in Gemini Developer API mode, not in'
1494+
' Gemini Enterprise Agent Platform mode.'
1495+
)
1496+
1497+
if getv(from_object, ['streamable_http_transport']) is not None:
1498+
raise ValueError(
1499+
'streamable_http_transport parameter is only supported in Gemini'
1500+
' Developer API mode, not in Gemini Enterprise Agent Platform mode.'
1501+
)
1502+
1503+
return to_object
1504+
1505+
14861506
def _MultiSpeakerVoiceConfig_to_vertex(
14871507
from_object: Union[dict[str, Any], object],
14881508
parent_object: Optional[dict[str, Any]] = None,
@@ -1893,9 +1913,13 @@ def _Tool_to_vertex(
18931913
setv(to_object, ['urlContext'], getv(from_object, ['url_context']))
18941914

18951915
if getv(from_object, ['mcp_servers']) is not None:
1896-
raise ValueError(
1897-
'mcp_servers parameter is only supported in Gemini Developer API mode,'
1898-
' not in Gemini Enterprise Agent Platform mode.'
1916+
setv(
1917+
to_object,
1918+
['mcpServers'],
1919+
[
1920+
_McpServer_to_vertex(item, to_object)
1921+
for item in getv(from_object, ['mcp_servers'])
1922+
],
18991923
)
19001924

19011925
return to_object

google/genai/caches.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,26 @@ def _ListCachedContentsResponse_from_vertex(
625625
return to_object
626626

627627

628+
def _McpServer_to_vertex(
629+
from_object: Union[dict[str, Any], object],
630+
parent_object: Optional[dict[str, Any]] = None,
631+
) -> dict[str, Any]:
632+
to_object: dict[str, Any] = {}
633+
if getv(from_object, ['name']) is not None:
634+
raise ValueError(
635+
'name parameter is only supported in Gemini Developer API mode, not in'
636+
' Gemini Enterprise Agent Platform mode.'
637+
)
638+
639+
if getv(from_object, ['streamable_http_transport']) is not None:
640+
raise ValueError(
641+
'streamable_http_transport parameter is only supported in Gemini'
642+
' Developer API mode, not in Gemini Enterprise Agent Platform mode.'
643+
)
644+
645+
return to_object
646+
647+
628648
def _Part_to_mldev(
629649
from_object: Union[dict[str, Any], object],
630650
parent_object: Optional[dict[str, Any]] = None,
@@ -961,9 +981,13 @@ def _Tool_to_vertex(
961981
setv(to_object, ['urlContext'], getv(from_object, ['url_context']))
962982

963983
if getv(from_object, ['mcp_servers']) is not None:
964-
raise ValueError(
965-
'mcp_servers parameter is only supported in Gemini Developer API mode,'
966-
' not in Gemini Enterprise Agent Platform mode.'
984+
setv(
985+
to_object,
986+
['mcpServers'],
987+
[
988+
_McpServer_to_vertex(item, to_object)
989+
for item in getv(from_object, ['mcp_servers'])
990+
],
967991
)
968992

969993
return to_object

0 commit comments

Comments
 (0)