Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions google/genai/_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def __init__(
self,
session: "mcp.ClientSession", # type: ignore # noqa: F821
list_tools_result: "mcp_types.ListToolsResult", # type: ignore
is_agent_platform: bool = False,
) -> None:
self._mcp_session = session
self._list_tools_result = list_tools_result
self._is_agent_platform = is_agent_platform

async def call_tool(
self, function_call: FunctionCall
Expand Down
45 changes: 29 additions & 16 deletions google/genai/_extra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,9 @@ def format_destination(

def find_afc_incompatible_tool_indexes(
config: Optional[types.GenerateContentConfigOrDict] = None,
is_agent_platform: bool = False,
) -> list[int]:
"""Checks if the config contains any AFC incompatible tools.

A `types.Tool` object that contains `function_declarations` is considered a
non-AFC tool for this execution path.

Args:
config: The GenerateContentConfig to check for incompatible tools.

Returns:
A list of indexes of the incompatible tools in the config.
"""
"""Checks if the config contains any AFC incompatible tools."""
if not config:
return []
config_model = _create_generate_content_config_model(config)
Expand All @@ -145,7 +136,9 @@ def find_afc_incompatible_tool_indexes(
continue
if tool.function_declarations:
incompatible_tools_indexes.append(index)
if tool.mcp_servers:

# Only mark it incompatible if it's MLDev, not Agent Platform.
if tool.mcp_servers and not is_agent_platform:
incompatible_tools_indexes.append(index)
return incompatible_tools_indexes

Expand Down Expand Up @@ -383,12 +376,15 @@ async def get_function_response_parts_async(
if not part.function_call:
continue
func_name = part.function_call.name
if func_name is not None and part.function_call.args is not None:
if func_name is not None:
func = function_map[func_name]
args = convert_number_values_for_dict_function_call_args(
# Treat None as an empty dictionary for execution
raw_args = (
part.function_call.args
if part.function_call.args is not None
else {}
)
func_response: _common.StringDict
args = convert_number_values_for_dict_function_call_args(raw_args)
try:
if isinstance(func, McpToGenAiToolAdapter):
mcp_tool_response = await func.call_tool(
Expand Down Expand Up @@ -551,6 +547,7 @@ def parse_config_for_mcp_usage(

async def parse_config_for_mcp_sessions(
config: Optional[types.GenerateContentConfigOrDict] = None,
is_agent_platform: bool = False,
) -> tuple[
Optional[types.GenerateContentConfig],
dict[str, McpToGenAiToolAdapter],
Expand All @@ -571,7 +568,7 @@ async def parse_config_for_mcp_sessions(
for tool in parsed_config.tools:
if McpClientSession is not None and isinstance(tool, McpClientSession):
mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
tool, await tool.list_tools()
tool, await tool.list_tools(), is_agent_platform=is_agent_platform
)
# Extend the config with the MCP session tools converted to GenAI tools.
parsed_config_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
Expand Down Expand Up @@ -677,3 +674,19 @@ def prepare_resumable_upload(
http_options.headers = {}
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
return http_options, size_bytes, mime_type


def has_agent_platform_mcp_servers(
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
"""Checks whether the configuration contains any MCP server requests."""
if not config:
return False
config_model = _create_generate_content_config_model(config)
if not config_model.tools:
return False

for tool in config_model.tools:
if isinstance(tool, types.Tool) and tool.mcp_servers:
return True
return False
30 changes: 27 additions & 3 deletions google/genai/_live_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,26 @@ def _LiveServerMessage_from_vertex(
return to_object


def _McpServer_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['name']) is not None:
raise ValueError(
'name parameter is only supported in Gemini Developer API mode, not in'
' Gemini Enterprise Agent Platform mode.'
)

if getv(from_object, ['streamable_http_transport']) is not None:
raise ValueError(
'streamable_http_transport parameter is only supported in Gemini'
' Developer API mode, not in Gemini Enterprise Agent Platform mode.'
)

return to_object


def _MultiSpeakerVoiceConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -1893,9 +1913,13 @@ def _Tool_to_vertex(
setv(to_object, ['urlContext'], getv(from_object, ['url_context']))

if getv(from_object, ['mcp_servers']) is not None:
raise ValueError(
'mcp_servers parameter is only supported in Gemini Developer API mode,'
' not in Gemini Enterprise Agent Platform mode.'
setv(
to_object,
['mcpServers'],
[
_McpServer_to_vertex(item, to_object)
for item in getv(from_object, ['mcp_servers'])
],
)

return to_object
Expand Down
30 changes: 27 additions & 3 deletions google/genai/caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,26 @@ def _ListCachedContentsResponse_from_vertex(
return to_object


def _McpServer_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['name']) is not None:
raise ValueError(
'name parameter is only supported in Gemini Developer API mode, not in'
' Gemini Enterprise Agent Platform mode.'
)

if getv(from_object, ['streamable_http_transport']) is not None:
raise ValueError(
'streamable_http_transport parameter is only supported in Gemini'
' Developer API mode, not in Gemini Enterprise Agent Platform mode.'
)

return to_object


def _Part_to_mldev(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -961,9 +981,13 @@ def _Tool_to_vertex(
setv(to_object, ['urlContext'], getv(from_object, ['url_context']))

if getv(from_object, ['mcp_servers']) is not None:
raise ValueError(
'mcp_servers parameter is only supported in Gemini Developer API mode,'
' not in Gemini Enterprise Agent Platform mode.'
setv(
to_object,
['mcpServers'],
[
_McpServer_to_vertex(item, to_object)
for item in getv(from_object, ['mcp_servers'])
],
)

return to_object
Expand Down
Loading
Loading