|
61 | 61 | from google import genai |
62 | 62 | from google.genai import types |
63 | 63 | from google.genai.errors import ClientError, ServerError, APIError |
64 | | -from typing import List, Union, Optional, Dict, Any, Tuple, AsyncIterator, Callable |
| 64 | +from typing import ( |
| 65 | + List, |
| 66 | + Union, |
| 67 | + Optional, |
| 68 | + Dict, |
| 69 | + Any, |
| 70 | + Tuple, |
| 71 | + AsyncIterator, |
| 72 | + Callable, |
| 73 | + Awaitable, |
| 74 | +) |
65 | 75 | from pydantic_core import core_schema |
66 | 76 | from pydantic import BaseModel, Field, GetCoreSchemaHandler |
67 | 77 | from cryptography.fernet import Fernet, InvalidToken |
@@ -2413,9 +2423,10 @@ async def _configure_generation( |
2413 | 2423 | self.log.debug("Replacing 'fetch_url' with URL context grounding") |
2414 | 2424 | tools.append(types.Tool(url_context=types.UrlContext())) |
2415 | 2425 | elif tool_def.get("type") == "mcp": |
2416 | | - # Don't add mcp tools one by one, add the mcp session directly later |
2417 | | - self.log.debug(f"Skipping MCP tool {name}") |
2418 | | - continue |
| 2426 | + tool = self._create_callable_from_spec( |
| 2427 | + name, tool_def["spec"], tool_def["callable"] |
| 2428 | + ) |
| 2429 | + self.log.debug(f"Adding tool '{name}'") |
2419 | 2430 | elif ( |
2420 | 2431 | inspect.signature(tool_def["callable"]).return_annotation is types.Tool |
2421 | 2432 | ): |
@@ -2446,6 +2457,84 @@ async def _configure_generation( |
2446 | 2457 | filtered_params = {k: v for k, v in gen_config_params.items() if v is not None} |
2447 | 2458 | return types.GenerateContentConfig(**filtered_params) |
2448 | 2459 |
|
| 2460 | + @staticmethod |
| 2461 | + def _create_callable_from_spec( |
| 2462 | + name: str, spec: dict, callable_func: Callable[..., Awaitable[Any]] |
| 2463 | + ) -> Callable[..., Awaitable[Any]]: |
| 2464 | + """ |
| 2465 | + Dynamically creates a well-typed async function from an MCP-style tool specification. |
| 2466 | + This satisfies inspection-based SDKs (like Gemini) by providing proper |
| 2467 | + signatures, docstrings, and unique function names. |
| 2468 | + """ |
| 2469 | + import inspect |
| 2470 | + |
| 2471 | + description = spec.get("description", "") |
| 2472 | + parameters_spec = spec.get("parameters", spec.get("inputSchema", {})) |
| 2473 | + properties = parameters_spec.get("properties", {}) |
| 2474 | + required_params = parameters_spec.get("required", []) |
| 2475 | + |
| 2476 | + # Type mapping from JSON schema to Python |
| 2477 | + type_map = { |
| 2478 | + "string": str, |
| 2479 | + "number": float, |
| 2480 | + "integer": int, |
| 2481 | + "boolean": bool, |
| 2482 | + "object": dict, |
| 2483 | + "array": list, |
| 2484 | + } |
| 2485 | + |
| 2486 | + params = [] |
| 2487 | + doc_params = [] |
| 2488 | + |
| 2489 | + # Sort properties so required parameters come first to avoid "non-default argument follows default argument" |
| 2490 | + sorted_properties = sorted( |
| 2491 | + properties.items(), |
| 2492 | + key=lambda item: item[0] not in required_params, |
| 2493 | + ) |
| 2494 | + |
| 2495 | + for param_name, param_info in sorted_properties: |
| 2496 | + if param_name.startswith("__"): |
| 2497 | + continue |
| 2498 | + |
| 2499 | + param_type = type_map.get(param_info.get("type"), Any) |
| 2500 | + param_desc = param_info.get("description", "") |
| 2501 | + |
| 2502 | + default = inspect.Parameter.empty |
| 2503 | + if param_name not in required_params: |
| 2504 | + # If not required, default to None or a provided default |
| 2505 | + default = param_info.get("default", None) |
| 2506 | + |
| 2507 | + params.append( |
| 2508 | + inspect.Parameter( |
| 2509 | + name=param_name, |
| 2510 | + kind=inspect.Parameter.KEYWORD_ONLY, |
| 2511 | + default=default, |
| 2512 | + annotation=param_type, |
| 2513 | + ) |
| 2514 | + ) |
| 2515 | + |
| 2516 | + if param_desc: |
| 2517 | + doc_params.append(f":param {param_name}: {param_desc}") |
| 2518 | + |
| 2519 | + # Build the docstring |
| 2520 | + docstring = description |
| 2521 | + if doc_params: |
| 2522 | + docstring += "\n\n" + "\n".join(doc_params) |
| 2523 | + |
| 2524 | + # The actual wrapper function |
| 2525 | + async def wrapped_func(*args, **kwargs): |
| 2526 | + return await callable_func(*args, **kwargs) |
| 2527 | + |
| 2528 | + # Set metadata to satisfy SDK inspection |
| 2529 | + wrapped_func.__name__ = name |
| 2530 | + wrapped_func.__qualname__ = name |
| 2531 | + wrapped_func.__doc__ = docstring |
| 2532 | + wrapped_func.__signature__ = inspect.Signature( |
| 2533 | + parameters=params, return_annotation=Any |
| 2534 | + ) |
| 2535 | + |
| 2536 | + return wrapped_func |
| 2537 | + |
2449 | 2538 | @staticmethod |
2450 | 2539 | def _format_grounding_chunks_as_sources( |
2451 | 2540 | grounding_chunks: list[types.GroundingChunk], |
|
0 commit comments