Skip to content

Commit ad018e3

Browse files
authored
Merge pull request #357 from CyberAgentAILab/feature/add-temperature-parameter
feat: add temperature parameter to GenAI image functions
2 parents 5ec5772 + f7f46e7 commit ad018e3

2 files changed

Lines changed: 18 additions & 6 deletions

File tree

packages/pinjected-genai/src/pinjected_genai/image_generation.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ async def __call__(
102102
self,
103103
prompt: str,
104104
model: str,
105+
temperature: float = 0.9,
105106
) -> GenerationResult: ...
106107

107108

@@ -114,6 +115,7 @@ async def a_generate_image__genai(
114115
/,
115116
prompt: str,
116117
model: str,
118+
temperature: float = 0.9,
117119
) -> GenerationResult:
118120
"""Generate an image using Google Gen AI SDK with nano-banana model."""
119121

@@ -122,7 +124,7 @@ async def a_generate_image__genai(
122124
try:
123125
# Configure generation with image output
124126
config = types.GenerateContentConfig(
125-
temperature=0.9,
127+
temperature=temperature,
126128
max_output_tokens=8192,
127129
response_modalities=["TEXT", "IMAGE"], # Enable image generation
128130
)
@@ -205,6 +207,7 @@ async def __call__(
205207
input_images: List[Image.Image],
206208
prompt: str,
207209
model: str,
210+
temperature: float = 0.9,
208211
) -> GenerationResult: ...
209212

210213

@@ -218,6 +221,7 @@ async def a_edit_image__genai(
218221
input_images: List[Image.Image],
219222
prompt: str,
220223
model: str,
224+
temperature: float = 0.9,
221225
) -> GenerationResult:
222226
"""Edit/generate an image based on input images (can be empty or multiple) using Google Gen AI SDK."""
223227

@@ -247,7 +251,7 @@ async def a_edit_image__genai(
247251

248252
# Configure generation with image output
249253
config = types.GenerateContentConfig(
250-
temperature=0.9,
254+
temperature=temperature,
251255
max_output_tokens=8192,
252256
response_modalities=["TEXT", "IMAGE"], # Enable image generation
253257
)
@@ -321,6 +325,7 @@ async def __call__(
321325
image_path: str,
322326
prompt: Optional[str] = None,
323327
model: str = "gemini-2.5-flash",
328+
temperature: float = 0.7,
324329
) -> str: ...
325330

326331

@@ -334,6 +339,7 @@ async def a_describe_image__genai(
334339
image_path: str,
335340
prompt: Optional[str] = None,
336341
model: str = "gemini-2.5-flash",
342+
temperature: float = 0.7,
337343
) -> str:
338344
"""Describe an image using Google Gen AI SDK."""
339345
logger.info(f"Describing image: {image_path}")
@@ -361,7 +367,7 @@ async def a_describe_image__genai(
361367

362368
# Configure generation
363369
config = types.GenerateContentConfig(
364-
temperature=0.7,
370+
temperature=temperature,
365371
max_output_tokens=2048,
366372
)
367373

packages/pinjected-genai/src/pinjected_genai/image_generation.pyi

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class AGenerateImageProtocol(Protocol):
2727
self,
2828
prompt: str,
2929
model: str,
30+
temperature: float = 0.9,
3031
) -> GenerationResult: ...
3132

3233
class AEditImageProtocol(Protocol):
@@ -35,6 +36,7 @@ class AEditImageProtocol(Protocol):
3536
input_images: List[Image.Image],
3637
prompt: str,
3738
model: str,
39+
temperature: float = 0.9,
3840
) -> GenerationResult: ...
3941

4042
class ADescribeImageProtocol(Protocol):
@@ -43,6 +45,7 @@ class ADescribeImageProtocol(Protocol):
4345
image_path: str,
4446
prompt: Optional[str] = None,
4547
model: str = "gemini-2.5-flash",
48+
temperature: float = 0.7,
4649
) -> str: ...
4750

4851
# Gen AI SDK functions
@@ -52,15 +55,18 @@ a_describe_image__genai: ADescribeImageProtocol
5255

5356
@overload
5457
async def a_generate_image__genai(
55-
prompt: str, model: str
58+
prompt: str, model: str, temperature: float = 0.9
5659
) -> IProxy[GenerationResult]: ...
5760
@overload
5861
async def a_edit_image__genai(
59-
input_images: List[Image.Image], prompt: str, model: str
62+
input_images: List[Image.Image], prompt: str, model: str, temperature: float = 0.9
6063
) -> IProxy[GenerationResult]: ...
6164
@overload
6265
async def a_describe_image__genai(
63-
image_path: str, prompt: Optional[str] = ..., model: str = ...
66+
image_path: str,
67+
prompt: Optional[str] = ...,
68+
model: str = ...,
69+
temperature: float = ...,
6470
) -> IProxy[str]: ...
6571

6672
# Additional symbols:

0 commit comments

Comments
 (0)