11import base64
22from dataclasses import dataclass
33from io import BytesIO
4- from typing import List , Optional , Protocol , Dict
4+ from typing import Dict , List , Optional , Protocol , Tuple
55
66from google import genai
77from google .genai import types
88from google .genai .types import MediaModality , GenerateContentResponseUsageMetadata
99from loguru import logger
1010from PIL import Image
11-
1211from pinjected import injected
1312
1413from .genai_pricing import GenAIModelTable , GenAIState , log_generation_cost
1514
1615
# Side length, in pixels, of the square canvas that images are padded onto.
TARGET_EDIT_IMAGE_DIMENSION = 1024


@dataclass
class SquarePadTransform:
    """Record how an image was resized and padded onto a square canvas.

    ``prepare`` builds the padded image together with an instance of this
    class; ``restore`` later uses the recorded geometry to crop the padding
    out of a generated image and resize it back to the original dimensions.
    """

    # (width, height) of the source image before any processing.
    original_size: Tuple[int, int]
    # (width, height) of the image after scaling to fit inside the canvas.
    scaled_size: Tuple[int, int]
    # (x, y) of the scaled image's top-left corner on the canvas.
    offset: Tuple[int, int]
    # PIL mode the image was converted to before padding.
    processing_mode: str
    # Upper-cased PIL format name used when serialising (e.g. "PNG").
    image_format: str
    # Side length of the square canvas the geometry above refers to.
    target_dim: int = TARGET_EDIT_IMAGE_DIMENSION

    @classmethod
    def prepare(
        cls,
        image: "Image.Image",
        *,
        target_dim: int,
        logger,
        image_index: Optional[int] = None,
    ) -> Tuple["Image.Image", "SquarePadTransform"]:
        """Return a padded image plus the transform metadata used to produce it.

        The image is scaled (up or down) so that its longest side equals
        ``target_dim`` and is then centred on a ``target_dim`` square canvas.

        Args:
            image: Source image.
            target_dim: Side length of the square canvas in pixels.
            logger: Object exposing ``warning(message)``.
            image_index: Zero-based position of the image, used only to
                label the log message ("?" when omitted).

        Returns:
            ``(padded_image, transform)``.

        Raises:
            ZeroDivisionError: If both dimensions of ``image`` are zero.
        """
        original_width, original_height = image.size
        scale_ratio = target_dim / max(original_width, original_height)
        scaled_width = max(1, round(original_width * scale_ratio))
        scaled_height = max(1, round(original_height * scale_ratio))
        # Centre the scaled image; clamp so rounding can never push it off-canvas.
        offset_x = max(0, (target_dim - scaled_width) // 2)
        offset_y = max(0, (target_dim - scaled_height) // 2)

        processing_mode, pad_color = cls._processing_mode_and_pad_color(image)
        working_image = image.convert(processing_mode)
        resized_image = working_image.resize(
            (scaled_width, scaled_height), Image.LANCZOS
        )
        padded_image = Image.new(processing_mode, (target_dim, target_dim), pad_color)
        padded_image.paste(resized_image, (offset_x, offset_y))

        if (scaled_width, scaled_height) != (target_dim, target_dim):
            image_number = image_index + 1 if image_index is not None else "?"
            # Pre-format the message: printf-style placeholders are not
            # interpolated by brace-formatting loggers such as loguru.
            logger.warning(
                f"Scaled and padded input image {image_number} "
                f"from {original_width}x{original_height} "
                f"to {scaled_width}x{scaled_height} "
                f"with offsets ({offset_x}, {offset_y})"
            )

        transform = cls(
            original_size=(original_width, original_height),
            scaled_size=(scaled_width, scaled_height),
            offset=(offset_x, offset_y),
            processing_mode=processing_mode,
            image_format=(image.format or "PNG").upper(),
            target_dim=target_dim,
        )
        return padded_image, transform

    @staticmethod
    def _processing_mode_and_pad_color(image: "Image.Image") -> Tuple[str, object]:
        """Pick the working mode and the matching padding colour.

        Images with an alpha channel (or a palette with transparency) are
        padded with transparent pixels; everything else is padded with black.
        """
        if image.mode in {"RGBA", "LA"} or (
            image.mode == "P" and "transparency" in image.info
        ):
            return "RGBA", (0, 0, 0, 0)
        if image.mode == "L":
            return "L", 0
        return "RGB", (0, 0, 0)

    @staticmethod
    def _canonical_format(image_format: str) -> str:
        """Upper-case a format name and map the "JPG" alias to "JPEG".

        Pillow registers its JPEG writer under "JPEG" only, so saving with
        ``format="JPG"`` fails.
        """
        name = image_format.upper()
        return "JPEG" if name == "JPG" else name

    @staticmethod
    def _jpeg_safe(image: "Image.Image", target_format: str) -> "Image.Image":
        """Return ``image`` in a mode the target format can store.

        JPEG has no alpha channel or palette, so such images are flattened
        to RGB; other formats get the image back unchanged.
        """
        if target_format == "JPEG" and image.mode in {"RGBA", "LA", "P", "PA"}:
            return image.convert("RGB")
        return image

    @property
    def mime_type(self) -> str:
        """MIME type matching ``image_format`` (e.g. ``image/png``)."""
        return f"image/{self._canonical_format(self.image_format).lower()}"

    def to_bytes(self, padded_image: "Image.Image") -> bytes:
        """Serialise ``padded_image`` using this transform's image format."""
        target_format = self._canonical_format(self.image_format)
        save_image = self._jpeg_safe(padded_image, target_format)
        buffer = BytesIO()
        save_image.save(buffer, format=target_format)
        return buffer.getvalue()

    @staticmethod
    def _format_from_mime_type(mime_type: Optional[str]) -> Optional[str]:
        """Derive a PIL format name from a MIME type, or ``None`` if unusable."""
        if mime_type and "/" in mime_type:
            # Drop any ";parameter" suffix before reading the subtype.
            subtype = mime_type.split("/")[-1].split(";")[0].strip()
            if subtype:
                return SquarePadTransform._canonical_format(subtype)
        return None

    def _crop_box(self, width: int, height: int) -> Tuple[int, int, int, int]:
        """Map the recorded content region onto an image of the given size.

        The recorded offset and scaled size are expressed in canvas pixels
        (``target_dim`` per side). They are rescaled per axis so the box
        still targets the same region when the generated image is not the
        same size as the canvas; when it is the same size the box is
        unchanged. The result is clamped to the image bounds.

        Returns:
            ``(left, top, right, bottom)``; may be empty or inverted, which
            the caller must check.
        """
        canvas = self.target_dim
        scale_x = width / canvas if canvas > 0 else 1.0
        scale_y = height / canvas if canvas > 0 else 1.0

        offset_x, offset_y = self.offset
        scaled_width, scaled_height = self.scaled_size

        left = min(max(round(offset_x * scale_x), 0), width)
        top = min(max(round(offset_y * scale_y), 0), height)
        right = min(width, left + round(scaled_width * scale_x))
        bottom = min(height, top + round(scaled_height * scale_y))
        return left, top, right, bottom

    def restore(self, image_bytes: bytes, mime_type: Optional[str], logger) -> bytes:
        """Crop and resize the generated image back to the original dimensions.

        This is best-effort: on any failure a warning is logged and the
        input bytes are returned untouched.

        Args:
            image_bytes: Encoded image to post-process.
            mime_type: MIME type of ``image_bytes``; selects the output
                format, falling back to ``image_format`` when unusable.
            logger: Object exposing ``warning(message)``.

        Returns:
            The re-encoded, restored image, or ``image_bytes`` unchanged if
            post-processing was skipped or failed.
        """
        try:
            with Image.open(BytesIO(image_bytes)) as edited_image:
                width, height = edited_image.size
                if width == 0 or height == 0:
                    logger.warning(
                        f"Generated image had invalid dimensions: {width}x{height}"
                    )
                    return image_bytes

                left, top, right, bottom = self._crop_box(width, height)
                if right <= left or bottom <= top:
                    logger.warning(
                        "Skipping post-processing; computed crop box "
                        f"({left}, {top}, {right}, {bottom}) "
                        f"is invalid for size {width}x{height}"
                    )
                    return image_bytes

                cropped = edited_image.crop((left, top, right, bottom))
                resized = cropped.resize(self.original_size, Image.LANCZOS)

                target_format = self._format_from_mime_type(
                    mime_type
                ) or self._canonical_format(self.image_format)
                output_image = self._jpeg_safe(resized, target_format)

                buffer = BytesIO()
                output_image.save(buffer, format=target_format)
                return buffer.getvalue()

        except Exception as exc:  # pragma: no cover - defensive logging
            # Deliberately broad: a failed restore must not lose the image.
            logger.warning(f"Failed to post-process generated image: {exc}")

        return image_bytes
160+
161+
17162def extract_modality_specific_tokens (
18163 usage_metadata : GenerateContentResponseUsageMetadata ,
19164) -> Dict [str , int ]:
@@ -102,7 +247,6 @@ async def __call__(
102247 self ,
103248 prompt : str ,
104249 model : str ,
105- temperature : float = 0.9 ,
106250 ) -> GenerationResult : ...
107251
108252
@@ -115,7 +259,6 @@ async def a_generate_image__genai(
115259 / ,
116260 prompt : str ,
117261 model : str ,
118- temperature : float = 0.9 ,
119262) -> GenerationResult :
120263 """Generate an image using Google Gen AI SDK with nano-banana model."""
121264
@@ -124,7 +267,7 @@ async def a_generate_image__genai(
124267 try :
125268 # Configure generation with image output
126269 config = types .GenerateContentConfig (
127- temperature = temperature ,
270+ temperature = 0.9 ,
128271 max_output_tokens = 8192 ,
129272 response_modalities = ["TEXT" , "IMAGE" ], # Enable image generation
130273 )
@@ -230,39 +373,38 @@ async def a_edit_image__genai(
230373 try :
231374 # Build contents list with prompt and any input images
232375 contents = [prompt ]
376+ first_transform : Optional [SquarePadTransform ] = None
233377
234378 # Add each input image to the contents
235379 for idx , img in enumerate (input_images ):
236380 logger .info (f"Processing input image { idx + 1 } " )
237381
238- image_format = (img .format or "PNG" ).upper ()
239- processed_img = img
240382 original_width , original_height = img .size
241- max_dimension = max (original_width , original_height )
242383
243- if max_dimension > 1024 :
244- scale_ratio = 1024 / max_dimension
245- new_width = max (1 , int (original_width * scale_ratio ))
246- new_height = max (1 , int (original_height * scale_ratio ))
247- processed_img = img .resize ((new_width , new_height ), Image .LANCZOS )
384+ if original_width == 0 or original_height == 0 :
248385 logger .warning (
249- "Scaled input image %s from %sx%s to %sx%s to satisfy 1024px max dimension " ,
386+ "Skipping input image %s due to zero dimension: %sx%s" ,
250387 idx + 1 ,
251388 original_width ,
252389 original_height ,
253- new_width ,
254- new_height ,
255390 )
391+ continue
256392
257- # Convert PIL Image to bytes after optional scaling
258- img_byte_arr = BytesIO ()
259- processed_img .save (img_byte_arr , format = image_format )
260- image_bytes = img_byte_arr .getvalue ()
393+ padded_image , transform = SquarePadTransform .prepare (
394+ img ,
395+ target_dim = TARGET_EDIT_IMAGE_DIMENSION ,
396+ logger = logger ,
397+ image_index = idx ,
398+ )
399+ if first_transform is None :
400+ first_transform = transform
401+
402+ image_bytes = transform .to_bytes (padded_image )
261403
262404 contents .append (
263405 types .Part .from_bytes (
264406 data = image_bytes ,
265- mime_type = f"image/ { image_format . lower () } " ,
407+ mime_type = transform . mime_type ,
266408 )
267409 )
268410
@@ -295,8 +437,15 @@ async def a_edit_image__genai(
295437 image_bytes = part .inline_data .data
296438 mime_type = part .inline_data .mime_type
297439 if image_bytes and not generated_image : # Take first image
440+ processed_bytes = image_bytes
441+ if first_transform :
442+ processed_bytes = first_transform .restore (
443+ image_bytes = image_bytes ,
444+ mime_type = mime_type ,
445+ logger = logger ,
446+ )
298447 generated_image = GeneratedImage (
299- image_data = image_bytes ,
448+ image_data = processed_bytes ,
300449 mime_type = mime_type or "image/png" ,
301450 prompt_used = f"Edit: { prompt } " ,
302451 )
@@ -342,7 +491,6 @@ async def __call__(
342491 image_path : str ,
343492 prompt : Optional [str ] = None ,
344493 model : str = "gemini-2.5-flash" ,
345- temperature : float = 0.7 ,
346494 ) -> str : ...
347495
348496
@@ -356,7 +504,6 @@ async def a_describe_image__genai(
356504 image_path : str ,
357505 prompt : Optional [str ] = None ,
358506 model : str = "gemini-2.5-flash" ,
359- temperature : float = 0.7 ,
360507) -> str :
361508 """Describe an image using Google Gen AI SDK."""
362509 logger .info (f"Describing image: { image_path } " )
@@ -384,7 +531,7 @@ async def a_describe_image__genai(
384531
385532 # Configure generation
386533 config = types .GenerateContentConfig (
387- temperature = temperature ,
534+ temperature = 0.7 ,
388535 max_output_tokens = 2048 ,
389536 )
390537
0 commit comments