Skip to content

Commit 0edf45f

Browse files
authored
Merge pull request #362 from CyberAgentAILab/fix/square-genai-edit-images
Pad GenAI edit inputs to 1024 square
2 parents b183158 + 476f82b commit 0edf45f

3 files changed

Lines changed: 238 additions & 32 deletions

File tree

packages/pinjected-genai/src/pinjected_genai/image_generation.py

Lines changed: 172 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,164 @@
11
import base64
22
from dataclasses import dataclass
33
from io import BytesIO
4-
from typing import List, Optional, Protocol, Dict
4+
from typing import Dict, List, Optional, Protocol, Tuple
55

66
from google import genai
77
from google.genai import types
88
from google.genai.types import MediaModality, GenerateContentResponseUsageMetadata
99
from loguru import logger
1010
from PIL import Image
11-
1211
from pinjected import injected
1312

1413
from .genai_pricing import GenAIModelTable, GenAIState, log_generation_cost
1514

1615

16+
TARGET_EDIT_IMAGE_DIMENSION = 1024
17+
18+
19+
@dataclass
20+
class SquarePadTransform:
21+
"""Encapsulates how an image was resized and padded to fit a square canvas."""
22+
23+
original_size: Tuple[int, int]
24+
scaled_size: Tuple[int, int]
25+
offset: Tuple[int, int]
26+
processing_mode: str
27+
image_format: str
28+
target_dim: int = TARGET_EDIT_IMAGE_DIMENSION
29+
30+
@classmethod
31+
def prepare(
32+
cls,
33+
image: Image.Image,
34+
*,
35+
target_dim: int,
36+
logger,
37+
image_index: Optional[int] = None,
38+
) -> Tuple[Image.Image, "SquarePadTransform"]:
39+
"""Return a padded image plus the transform metadata used to produce it."""
40+
41+
original_width, original_height = image.size
42+
scale_ratio = target_dim / max(original_width, original_height)
43+
scaled_width = max(1, round(original_width * scale_ratio))
44+
scaled_height = max(1, round(original_height * scale_ratio))
45+
offset_x = max(0, (target_dim - scaled_width) // 2)
46+
offset_y = max(0, (target_dim - scaled_height) // 2)
47+
48+
processing_mode, pad_color = cls._processing_mode_and_pad_color(image)
49+
working_image = image.convert(processing_mode)
50+
resized_image = working_image.resize(
51+
(scaled_width, scaled_height), Image.LANCZOS
52+
)
53+
padded_image = Image.new(processing_mode, (target_dim, target_dim), pad_color)
54+
padded_image.paste(resized_image, (offset_x, offset_y))
55+
56+
if (scaled_width, scaled_height) != (target_dim, target_dim):
57+
image_number = image_index + 1 if image_index is not None else "?"
58+
logger.warning(
59+
"Scaled and padded input image %s from %sx%s to %sx%s with offsets (%s, %s)",
60+
image_number,
61+
original_width,
62+
original_height,
63+
scaled_width,
64+
scaled_height,
65+
offset_x,
66+
offset_y,
67+
)
68+
69+
transform = cls(
70+
original_size=(original_width, original_height),
71+
scaled_size=(scaled_width, scaled_height),
72+
offset=(offset_x, offset_y),
73+
processing_mode=processing_mode,
74+
image_format=(image.format or "PNG").upper(),
75+
target_dim=target_dim,
76+
)
77+
return padded_image, transform
78+
79+
@staticmethod
80+
def _processing_mode_and_pad_color(image: Image.Image) -> Tuple[str, object]:
81+
if image.mode in {"RGBA", "LA"} or (
82+
image.mode == "P" and "transparency" in image.info
83+
):
84+
return "RGBA", (0, 0, 0, 0)
85+
if image.mode == "L":
86+
return "L", 0
87+
return "RGB", (0, 0, 0)
88+
89+
@property
90+
def mime_type(self) -> str:
91+
return f"image/{self.image_format.lower()}"
92+
93+
def to_bytes(self, padded_image: Image.Image) -> bytes:
94+
save_image = padded_image
95+
if self.image_format in {"JPEG", "JPG"} and save_image.mode == "RGBA":
96+
save_image = save_image.convert("RGB")
97+
buffer = BytesIO()
98+
save_image.save(buffer, format=self.image_format)
99+
return buffer.getvalue()
100+
101+
@staticmethod
102+
def _format_from_mime_type(mime_type: Optional[str]) -> Optional[str]:
103+
if mime_type and "/" in mime_type:
104+
return mime_type.split("/")[-1].upper()
105+
return None
106+
107+
def restore(self, image_bytes: bytes, mime_type: Optional[str], logger) -> bytes:
108+
"""Crop and resize the generated image back to the original dimensions."""
109+
110+
try:
111+
with Image.open(BytesIO(image_bytes)) as edited_image:
112+
width, height = edited_image.size
113+
if width == 0 or height == 0:
114+
logger.warning(
115+
"Generated image had invalid dimensions: %sx%s", width, height
116+
)
117+
return image_bytes
118+
119+
offset_x, offset_y = self.offset
120+
scaled_width, scaled_height = self.scaled_size
121+
122+
left = min(max(offset_x, 0), width)
123+
top = min(max(offset_y, 0), height)
124+
right = min(width, left + scaled_width)
125+
bottom = min(height, top + scaled_height)
126+
127+
if right <= left or bottom <= top:
128+
logger.warning(
129+
"Skipping post-processing; computed crop box (%s, %s, %s, %s) is invalid for size %sx%s",
130+
left,
131+
top,
132+
right,
133+
bottom,
134+
width,
135+
height,
136+
)
137+
return image_bytes
138+
139+
cropped = edited_image.crop((left, top, right, bottom))
140+
resized = cropped.resize(self.original_size, Image.LANCZOS)
141+
142+
target_format = (
143+
self._format_from_mime_type(mime_type) or self.image_format
144+
)
145+
output_image = resized
146+
if target_format in {"JPEG", "JPG"} and output_image.mode in {
147+
"RGBA",
148+
"LA",
149+
}:
150+
output_image = output_image.convert("RGB")
151+
152+
buffer = BytesIO()
153+
output_image.save(buffer, format=target_format)
154+
return buffer.getvalue()
155+
156+
except Exception as exc: # pragma: no cover - defensive logging
157+
logger.warning("Failed to post-process generated image: %s", exc)
158+
159+
return image_bytes
160+
161+
17162
def extract_modality_specific_tokens(
18163
usage_metadata: GenerateContentResponseUsageMetadata,
19164
) -> Dict[str, int]:
@@ -102,7 +247,6 @@ async def __call__(
102247
self,
103248
prompt: str,
104249
model: str,
105-
temperature: float = 0.9,
106250
) -> GenerationResult: ...
107251

108252

@@ -115,7 +259,6 @@ async def a_generate_image__genai(
115259
/,
116260
prompt: str,
117261
model: str,
118-
temperature: float = 0.9,
119262
) -> GenerationResult:
120263
"""Generate an image using Google Gen AI SDK with nano-banana model."""
121264

@@ -124,7 +267,7 @@ async def a_generate_image__genai(
124267
try:
125268
# Configure generation with image output
126269
config = types.GenerateContentConfig(
127-
temperature=temperature,
270+
temperature=0.9,
128271
max_output_tokens=8192,
129272
response_modalities=["TEXT", "IMAGE"], # Enable image generation
130273
)
@@ -230,39 +373,38 @@ async def a_edit_image__genai(
230373
try:
231374
# Build contents list with prompt and any input images
232375
contents = [prompt]
376+
first_transform: Optional[SquarePadTransform] = None
233377

234378
# Add each input image to the contents
235379
for idx, img in enumerate(input_images):
236380
logger.info(f"Processing input image {idx + 1}")
237381

238-
image_format = (img.format or "PNG").upper()
239-
processed_img = img
240382
original_width, original_height = img.size
241-
max_dimension = max(original_width, original_height)
242383

243-
if max_dimension > 1024:
244-
scale_ratio = 1024 / max_dimension
245-
new_width = max(1, int(original_width * scale_ratio))
246-
new_height = max(1, int(original_height * scale_ratio))
247-
processed_img = img.resize((new_width, new_height), Image.LANCZOS)
384+
if original_width == 0 or original_height == 0:
248385
logger.warning(
249-
"Scaled input image %s from %sx%s to %sx%s to satisfy 1024px max dimension",
386+
"Skipping input image %s due to zero dimension: %sx%s",
250387
idx + 1,
251388
original_width,
252389
original_height,
253-
new_width,
254-
new_height,
255390
)
391+
continue
256392

257-
# Convert PIL Image to bytes after optional scaling
258-
img_byte_arr = BytesIO()
259-
processed_img.save(img_byte_arr, format=image_format)
260-
image_bytes = img_byte_arr.getvalue()
393+
padded_image, transform = SquarePadTransform.prepare(
394+
img,
395+
target_dim=TARGET_EDIT_IMAGE_DIMENSION,
396+
logger=logger,
397+
image_index=idx,
398+
)
399+
if first_transform is None:
400+
first_transform = transform
401+
402+
image_bytes = transform.to_bytes(padded_image)
261403

262404
contents.append(
263405
types.Part.from_bytes(
264406
data=image_bytes,
265-
mime_type=f"image/{image_format.lower()}",
407+
mime_type=transform.mime_type,
266408
)
267409
)
268410

@@ -295,8 +437,15 @@ async def a_edit_image__genai(
295437
image_bytes = part.inline_data.data
296438
mime_type = part.inline_data.mime_type
297439
if image_bytes and not generated_image: # Take first image
440+
processed_bytes = image_bytes
441+
if first_transform:
442+
processed_bytes = first_transform.restore(
443+
image_bytes=image_bytes,
444+
mime_type=mime_type,
445+
logger=logger,
446+
)
298447
generated_image = GeneratedImage(
299-
image_data=image_bytes,
448+
image_data=processed_bytes,
300449
mime_type=mime_type or "image/png",
301450
prompt_used=f"Edit: {prompt}",
302451
)
@@ -342,7 +491,6 @@ async def __call__(
342491
image_path: str,
343492
prompt: Optional[str] = None,
344493
model: str = "gemini-2.5-flash",
345-
temperature: float = 0.7,
346494
) -> str: ...
347495

348496

@@ -356,7 +504,6 @@ async def a_describe_image__genai(
356504
image_path: str,
357505
prompt: Optional[str] = None,
358506
model: str = "gemini-2.5-flash",
359-
temperature: float = 0.7,
360507
) -> str:
361508
"""Describe an image using Google Gen AI SDK."""
362509
logger.info(f"Describing image: {image_path}")
@@ -384,7 +531,7 @@ async def a_describe_image__genai(
384531

385532
# Configure generation
386533
config = types.GenerateContentConfig(
387-
temperature=temperature,
534+
temperature=0.7,
388535
max_output_tokens=2048,
389536
)
390537

packages/pinjected-genai/tests/test_cost_tracking.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from dataclasses import dataclass
5+
from io import BytesIO
56
from unittest.mock import Mock
67
from pinjected_genai.genai_pricing import (
78
GenAIModelTable,
@@ -13,6 +14,8 @@
1314
from pinjected import design
1415
from pinjected.test import injected_pytest
1516
from loguru import logger
17+
from PIL import Image as PILImage
18+
from google.genai.types import GenerateContentResponseUsageMetadata
1619

1720

1821
@dataclass
@@ -299,6 +302,12 @@ def test_log_generation_cost_periodic_breakdown(
299302
assert breakdown_logged, "Detailed breakdown should be logged on 10th request"
300303

301304

305+
def _dummy_png_bytes(size=(1024, 1024), color=(0, 255, 0)) -> bytes:
306+
buffer = BytesIO()
307+
PILImage.new("RGB", size, color=color).save(buffer, format="PNG")
308+
return buffer.getvalue()
309+
310+
302311
def create_mock_image_response(text_parts=None, image_count=1):
303312
"""Create a mock response for image generation."""
304313
response = Mock()
@@ -322,12 +331,21 @@ def create_mock_image_response(text_parts=None, image_count=1):
322331
part = Mock()
323332
part.text = None
324333
part.inline_data = Mock()
325-
part.inline_data.data = f"fake_image_data_{i}".encode()
334+
part.inline_data.data = _dummy_png_bytes(
335+
color=(10 * i % 255, 40 * i % 255, 70 * i % 255)
336+
)
326337
part.inline_data.mime_type = "image/png"
327338
candidate.content.parts.append(part)
328339

329340
response.candidates.append(candidate)
330341

342+
response.usage_metadata = GenerateContentResponseUsageMetadata(
343+
prompt_token_count=0,
344+
candidates_token_count=0,
345+
prompt_tokens_details=[],
346+
candidates_tokens_details=[],
347+
)
348+
331349
return response
332350

333351

0 commit comments

Comments
 (0)