Skip to content

Commit 5e9ffd9

Browse files
Generate MCP tool signatures
1 parent cc24604 commit 5e9ffd9

1 file changed

Lines changed: 94 additions & 9 deletions

File tree

pipelines/google/google_gemini.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,17 @@
6161
from google import genai
6262
from google.genai import types
6363
from google.genai.errors import ClientError, ServerError, APIError
64-
from typing import List, Union, Optional, Dict, Any, Tuple, AsyncIterator, Callable
64+
from typing import (
65+
List,
66+
Union,
67+
Optional,
68+
Dict,
69+
Any,
70+
Tuple,
71+
AsyncIterator,
72+
Callable,
73+
Awaitable,
74+
)
6575
from pydantic_core import core_schema
6676
from pydantic import BaseModel, Field, GetCoreSchemaHandler
6777
from cryptography.fernet import Fernet, InvalidToken
@@ -2413,9 +2423,11 @@ async def _configure_generation(
24132423
self.log.debug("Replacing 'fetch_url' with URL context grounding")
24142424
tools.append(types.Tool(url_context=types.UrlContext()))
24152425
elif tool_def.get("type") == "mcp":
2416-
# Don't add mcp tools one by one, add the mcp session directly later
2417-
self.log.debug(f"Skipping MCP tool {name}")
2418-
continue
2426+
self.log.debug(f"Adding MCP tool '{name}'")
2427+
mcp_tool = self._create_callable_from_spec(
2428+
name, tool_def["spec"], tool_def["callable"]
2429+
)
2430+
tools.append(mcp_tool)
24192431
elif (
24202432
inspect.signature(tool_def["callable"]).return_annotation is types.Tool
24212433
):
@@ -2434,18 +2446,91 @@ async def _configure_generation(
24342446
self.log.debug(f"Adding tool '{name}'")
24352447
tools.append(tool_def["callable"])
24362448

2437-
# Add MCP server sessions
2438-
for name, mcp_client in __metadata__.get("mcp_clients", {}).items():
2439-
self.log.debug(f"Adding MCP server '{name}'")
2440-
tools.append(mcp_client.session)
2441-
24422449
if tools:
24432450
gen_config_params["tools"] = tools
24442451

24452452
# Filter out None values for generation config
24462453
filtered_params = {k: v for k, v in gen_config_params.items() if v is not None}
24472454
return types.GenerateContentConfig(**filtered_params)
24482455

2456+
@staticmethod
2457+
def _create_callable_from_spec(
2458+
name: str, spec: dict, callable_func: Callable[..., Awaitable[Any]]
2459+
) -> Callable[..., Awaitable[Any]]:
2460+
"""
2461+
Dynamically creates a well-typed async function from an MCP-style tool specification.
2462+
This satisfies inspection-based SDKs (like Gemini) by providing proper
2463+
signatures, docstrings, and unique function names.
2464+
"""
2465+
import inspect
2466+
2467+
description = spec.get("description", "")
2468+
parameters_spec = spec.get("parameters", spec.get("inputSchema", {}))
2469+
properties = parameters_spec.get("properties", {})
2470+
required_params = parameters_spec.get("required", [])
2471+
2472+
# Type mapping from JSON schema to Python
2473+
type_map = {
2474+
"string": str,
2475+
"number": float,
2476+
"integer": int,
2477+
"boolean": bool,
2478+
"object": dict,
2479+
"array": list,
2480+
}
2481+
2482+
params = []
2483+
doc_params = []
2484+
2485+
# Sort properties so required parameters come first to avoid "non-default argument follows default argument"
2486+
sorted_properties = sorted(
2487+
properties.items(),
2488+
key=lambda item: item[0] not in required_params,
2489+
)
2490+
2491+
for param_name, param_info in sorted_properties:
2492+
if param_name.startswith("__"):
2493+
continue
2494+
2495+
param_type = type_map.get(param_info.get("type"), Any)
2496+
param_desc = param_info.get("description", "")
2497+
2498+
default = inspect.Parameter.empty
2499+
if param_name not in required_params:
2500+
# If not required, default to None or a provided default
2501+
default = param_info.get("default", None)
2502+
2503+
params.append(
2504+
inspect.Parameter(
2505+
name=param_name,
2506+
kind=inspect.Parameter.KEYWORD_ONLY,
2507+
default=default,
2508+
annotation=param_type,
2509+
)
2510+
)
2511+
2512+
if param_desc:
2513+
doc_params.append(f":param {param_name}: {param_desc}")
2514+
2515+
# Build the docstring
2516+
docstring = description
2517+
if doc_params:
2518+
docstring += "\n\n" + "\n".join(doc_params)
2519+
2520+
# The actual wrapper function
2521+
async def wrapped_func(*args, **kwargs):
2522+
return await callable_func(*args, **kwargs)
2523+
2524+
# Set metadata to satisfy SDK inspection
2525+
wrapped_func.__name__ = name
2526+
wrapped_func.__qualname__ = name
2527+
wrapped_func.__doc__ = docstring
2528+
wrapped_func.__signature__ = inspect.Signature(
2529+
parameters=params, return_annotation=Any
2530+
)
2531+
2532+
return wrapped_func
2533+
24492534
@staticmethod
24502535
def _format_grounding_chunks_as_sources(
24512536
grounding_chunks: list[types.GroundingChunk],

0 commit comments

Comments
 (0)