Skip to content

Commit 7552e7b

Browse files
Generate MCP tool signatures
1 parent cc24604 commit 7552e7b

1 file changed

Lines changed: 93 additions & 4 deletions

File tree

pipelines/google/google_gemini.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,17 @@
6161
from google import genai
6262
from google.genai import types
6363
from google.genai.errors import ClientError, ServerError, APIError
64-
from typing import List, Union, Optional, Dict, Any, Tuple, AsyncIterator, Callable
64+
from typing import (
65+
List,
66+
Union,
67+
Optional,
68+
Dict,
69+
Any,
70+
Tuple,
71+
AsyncIterator,
72+
Callable,
73+
Awaitable,
74+
)
6575
from pydantic_core import core_schema
6676
from pydantic import BaseModel, Field, GetCoreSchemaHandler
6777
from cryptography.fernet import Fernet, InvalidToken
@@ -2413,9 +2423,10 @@ async def _configure_generation(
24132423
self.log.debug("Replacing 'fetch_url' with URL context grounding")
24142424
tools.append(types.Tool(url_context=types.UrlContext()))
24152425
elif tool_def.get("type") == "mcp":
2416-
# Don't add mcp tools one by one, add the mcp session directly later
2417-
self.log.debug(f"Skipping MCP tool {name}")
2418-
continue
2426+
tool = self._create_callable_from_spec(
2427+
name, tool_def["spec"], tool_def["callable"]
2428+
)
2429+
self.log.debug(f"Adding tool '{name}'")
24192430
elif (
24202431
inspect.signature(tool_def["callable"]).return_annotation is types.Tool
24212432
):
@@ -2446,6 +2457,84 @@ async def _configure_generation(
24462457
filtered_params = {k: v for k, v in gen_config_params.items() if v is not None}
24472458
return types.GenerateContentConfig(**filtered_params)
24482459

2460+
@staticmethod
2461+
def _create_callable_from_spec(
2462+
name: str, spec: dict, callable_func: Callable[..., Awaitable[Any]]
2463+
) -> Callable[..., Awaitable[Any]]:
2464+
"""
2465+
Dynamically creates a well-typed async function from an MCP-style tool specification.
2466+
This satisfies inspection-based SDKs (like Gemini) by providing proper
2467+
signatures, docstrings, and unique function names.
2468+
"""
2469+
import inspect
2470+
2471+
description = spec.get("description", "")
2472+
parameters_spec = spec.get("parameters", spec.get("inputSchema", {}))
2473+
properties = parameters_spec.get("properties", {})
2474+
required_params = parameters_spec.get("required", [])
2475+
2476+
# Type mapping from JSON schema to Python
2477+
type_map = {
2478+
"string": str,
2479+
"number": float,
2480+
"integer": int,
2481+
"boolean": bool,
2482+
"object": dict,
2483+
"array": list,
2484+
}
2485+
2486+
params = []
2487+
doc_params = []
2488+
2489+
# Sort properties so required parameters come first to avoid "non-default argument follows default argument"
2490+
sorted_properties = sorted(
2491+
properties.items(),
2492+
key=lambda item: item[0] not in required_params,
2493+
)
2494+
2495+
for param_name, param_info in sorted_properties:
2496+
if param_name.startswith("__"):
2497+
continue
2498+
2499+
param_type = type_map.get(param_info.get("type"), Any)
2500+
param_desc = param_info.get("description", "")
2501+
2502+
default = inspect.Parameter.empty
2503+
if param_name not in required_params:
2504+
# If not required, default to None or a provided default
2505+
default = param_info.get("default", None)
2506+
2507+
params.append(
2508+
inspect.Parameter(
2509+
name=param_name,
2510+
kind=inspect.Parameter.KEYWORD_ONLY,
2511+
default=default,
2512+
annotation=param_type,
2513+
)
2514+
)
2515+
2516+
if param_desc:
2517+
doc_params.append(f":param {param_name}: {param_desc}")
2518+
2519+
# Build the docstring
2520+
docstring = description
2521+
if doc_params:
2522+
docstring += "\n\n" + "\n".join(doc_params)
2523+
2524+
# The actual wrapper function
2525+
async def wrapped_func(*args, **kwargs):
2526+
return await callable_func(*args, **kwargs)
2527+
2528+
# Set metadata to satisfy SDK inspection
2529+
wrapped_func.__name__ = name
2530+
wrapped_func.__qualname__ = name
2531+
wrapped_func.__doc__ = docstring
2532+
wrapped_func.__signature__ = inspect.Signature(
2533+
parameters=params, return_annotation=Any
2534+
)
2535+
2536+
return wrapped_func
2537+
24492538
@staticmethod
24502539
def _format_grounding_chunks_as_sources(
24512540
grounding_chunks: list[types.GroundingChunk],

0 commit comments

Comments
 (0)