|
| 1 | +import asyncio |
| 2 | +import contextvars |
| 3 | +import os |
| 4 | +import threading |
| 5 | +from typing import List |
| 6 | + |
| 7 | +from google import genai |
| 8 | +from google.adk.tools import ToolContext |
| 9 | +from google.genai.types import GenerateContentConfig, Modality |
| 10 | + |
# Model used when neither the caller nor the IMAGE_MODEL_NAME environment
# variable selects one (Nanobanana 2).
DEFAULT_IMAGE_MODEL = "gemini-3.1-flash-image-preview"

# Generated image payloads grouped by session_id. Every read or write of the
# dict in this module happens while holding _images_lock.
_generated_images: dict[str, List[bytes]] = {}
_images_lock = threading.Lock()

# Session identifier for the current execution context. It is meant to be set
# by the request handler before the agent runs; reads fall back to "unknown"
# when nothing has been set.
current_session_id: contextvars.ContextVar[str] = contextvars.ContextVar(
    "current_session_id",
    default="unknown",
)
| 22 | + |
| 23 | + |
def get_and_clear_images(session_id: str) -> List[bytes]:
    """Pop and return every image stored for ``session_id``.

    The session's entry is removed from the module-level store while the
    lock is held, so a second call for the same session returns an empty
    list until new images are stored.

    Args:
        session_id: Key under which the images were stored.

    Returns:
        The stored image payloads, or an empty list if the session has none.
    """
    with _images_lock:
        pending = _generated_images.pop(session_id, [])
    return pending
| 28 | + |
| 29 | + |
def _split_response_parts(response):
    """Split the first candidate's parts into text snippets and image payloads.

    Parts flagged as ``thought`` are skipped. A response without candidates,
    a candidate without ``content``, or content without ``parts`` yields two
    empty lists instead of raising.
    """
    texts = []
    images = []

    candidates = getattr(response, "candidates", None)
    if not candidates:
        return texts, images

    # The candidate's content can be missing, so look it up defensively
    # rather than dereferencing ``.content.parts`` directly.
    content = getattr(candidates[0], "content", None)
    for part in getattr(content, "parts", None) or []:
        if getattr(part, "thought", None):
            continue
        text = getattr(part, "text", None)
        if text:
            texts.append(text)
            continue
        inline = getattr(part, "inline_data", None)
        data = getattr(inline, "data", None) if inline else None
        if data:
            images.append(data)

    return texts, images


# NOTE(review): the docstring below presumably doubles as the tool description
# shown to the model, and the signature is presumably introspected by the
# agent framework -- confirm before rewording either.
async def generate_image(prompt: str, tool_context: ToolContext, model: str = ""):
    """Generates images using Gemini image generation models (Nanobanana Pro / Nanobanana 2).

    Use this tool when the user asks you to create, draw, generate, or design an image.

    Args:
        prompt: A detailed description of the image to generate.
        model: The model to use for image generation.
            Use "gemini-3-pro-image-preview" (Nanobanana Pro) for higher quality.
            Use "gemini-3.1-flash-image-preview" (Nanobanana 2) for faster generation.
            Defaults to Nanobanana 2 if not specified.

    Returns:
        A dict describing the outcome: {"error": ...} when the request fails,
        {"status": "no_image_generated", "text": ...} when the response holds
        no image, or {"status": "success", "model": ..., "image_count": ...,
        "text": ...} when at least one image was produced.
    """
    # Explicit argument wins, then the environment, then the module default.
    image_model = model or os.environ.get("IMAGE_MODEL_NAME", DEFAULT_IMAGE_MODEL)

    project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
    location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")

    def call_gemini():
        # Synchronous client call; run in a worker thread so the event loop
        # is not blocked while waiting for the model.
        client = genai.Client(vertexai=True, project=project_id, location=location)
        return client.models.generate_content(
            model=image_model,
            contents=prompt,
            config=GenerateContentConfig(
                response_modalities=[Modality.TEXT, Modality.IMAGE],
            ),
        )

    try:
        response = await asyncio.to_thread(call_gemini)
    except Exception as e:
        # Tool boundary: report the failure to the caller as data instead of
        # raising.
        return {"error": f"Image generation failed: {e}"}

    text_parts, images = _split_response_parts(response)

    if not images:
        return {
            "status": "no_image_generated",
            "text": "\n".join(text_parts) if text_parts else "No image was generated.",
        }

    # Store images for the main handler to upload to Slack.
    # NOTE(review): when the context variable was never set, images land
    # under the shared "unknown" key -- confirm callers always set it.
    session_id = current_session_id.get()
    with _images_lock:
        _generated_images.setdefault(session_id, []).extend(images)

    return {
        "status": "success",
        "model": image_model,
        "image_count": len(images),
        "text": "\n".join(text_parts)
        if text_parts
        else f"{len(images)} image(s) generated successfully.",
    }
0 commit comments