Skip to content

Commit e299a42

Browse files
anbeckham and claude committed
Add reference image support for style-consistent generation
generate_image now accepts an optional reference_image path so the model can match an existing image's art style, color palette, and visual mood when creating new images. This solves cross-session consistency problems (e.g. game characters that all share the same look) by sending the reference pixels alongside the prompt via Gemini's multimodal contents API.

Key changes:
- gemini_client: generate_image_gemini builds multi-part contents when reference image bytes and a MIME type are provided
- image_gen: threads reference_image through generate_with_gemini and auto_generate with style-reference prompt framing; auto-loads from style profile when no explicit reference is passed
- server: adds reference_image to generate_image and init_style_profile MCP tool schemas
- style_profile: adds reference_image field to DEFAULT_PROFILE and create_profile for cross-session persistence

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8387e6c commit e299a42

7 files changed

Lines changed: 276 additions & 1 deletion

File tree

src/gemini_visual_mcp/gemini_client.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,18 +106,32 @@ async def generate_image_gemini(
106106
self,
107107
prompt: str,
108108
aspect_ratio: str = "16:9",
109+
reference_image_data: bytes | None = None,
110+
reference_mime_type: str | None = None,
109111
) -> list[dict]:
110112
"""Generate image(s) using Gemini native image generation.
111113
112114
Uses responseModalities: ["TEXT", "IMAGE"] to get inline image data.
115+
When reference_image_data is provided, the image is sent alongside the
116+
prompt so the model can match its style.
113117
114118
Returns list of dicts with keys: 'data' (bytes), 'mime_type' (str), 'text' (str|None)
115119
"""
116120

117121
def _call():
122+
if reference_image_data and reference_mime_type:
123+
contents = [
124+
types.Part.from_bytes(
125+
data=reference_image_data, mime_type=reference_mime_type
126+
),
127+
prompt,
128+
]
129+
else:
130+
contents = prompt
131+
118132
response = self._client.models.generate_content(
119133
model=GEMINI_FLASH_IMAGE,
120-
contents=prompt,
134+
contents=contents,
121135
config=types.GenerateContentConfig(
122136
response_modalities=["TEXT", "IMAGE"],
123137
),

src/gemini_visual_mcp/image_gen.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,13 @@
55
"""
66

77
import logging
8+
from pathlib import Path
89
from typing import Optional
910

1011
from .asset_manager import save_generated
1112
from .config import DEFAULT_ASPECT_RATIO, DEFAULT_IMAGE_COUNT
1213
from .gemini_client import GeminiClient
14+
from .image_utils import read_image
1315
from .prompt_engine import enhance
1416
from .style_profile import load_profile
1517

@@ -23,21 +25,45 @@ async def generate_with_gemini(
2325
cwd: str = ".",
2426
use_profile: bool = True,
2527
template: Optional[str] = None,
28+
reference_image: Optional[str] = None,
2629
) -> list[dict]:
2730
"""Generate image(s) using Gemini 2.5 Flash (fast, iterative drafts).
2831
32+
When reference_image is provided, the model receives the image alongside
33+
the prompt and is instructed to match its art style in the new generation.
34+
2935
Returns list of dicts with: path, enhanced_prompt, warnings, model, metadata
3036
"""
3137
# Load profile
3238
profile = load_profile(cwd) if use_profile else None
3339

40+
# Auto-load reference image from profile if none provided explicitly
41+
if not reference_image and profile and profile.get("reference_image"):
42+
ref_path = profile["reference_image"]
43+
if Path(ref_path).is_file():
44+
reference_image = ref_path
45+
3446
# Enhance prompt
3547
enhanced_prompt, warnings = enhance(prompt, profile=profile, template=template)
3648

49+
# Build style-reference prompt and read image bytes when a reference is provided
50+
ref_data = None
51+
ref_mime = None
52+
if reference_image:
53+
ref_data, ref_mime = read_image(reference_image)
54+
enhanced_prompt = (
55+
"Use the provided image ONLY as a style and aesthetic reference. "
56+
"Do NOT reproduce or edit the reference image. Generate a completely "
57+
"new image matching its art style, color palette, rendering technique, "
58+
"and visual mood. The new image should depict: " + enhanced_prompt
59+
)
60+
3761
# Generate
3862
results = await client.generate_image_gemini(
3963
prompt=enhanced_prompt,
4064
aspect_ratio=aspect_ratio,
65+
reference_image_data=ref_data,
66+
reference_mime_type=ref_mime,
4167
)
4268

4369
# Save results
@@ -49,6 +75,7 @@ async def generate_with_gemini(
4975
"model": "gemini-2.5-flash-image",
5076
"aspect_ratio": aspect_ratio,
5177
"template": template or "",
78+
"reference_image": reference_image or "",
5279
"warnings": [w.to_dict() for w in warnings],
5380
}
5481

@@ -141,17 +168,32 @@ async def auto_generate(
141168
cwd: str = ".",
142169
use_profile: bool = True,
143170
template: Optional[str] = None,
171+
reference_image: Optional[str] = None,
144172
) -> list[dict]:
145173
"""Generate with automatic model selection.
146174
147175
- "gemini": Use Gemini Flash (fast drafts, iterative editing)
148176
- "imagen": Use Imagen 4 (high quality finals)
149177
- "auto": Use Gemini for drafts, Imagen for production-quality assets
150178
179+
When reference_image is provided, the Gemini path is always used
180+
(Imagen's text-to-image API does not accept reference images).
181+
151182
Auto logic: Use Gemini by default. Use Imagen when:
152183
- User explicitly says "final", "production", "high quality", "polished"
153184
- Template recommends Imagen
154185
"""
186+
# Reference images require Gemini — Imagen doesn't support image input for generation
187+
if reference_image:
188+
if model == "imagen":
189+
logger.warning(
190+
"Reference image provided with model='imagen'. "
191+
"Falling back to Gemini (Imagen does not support reference images)."
192+
)
193+
return await generate_with_gemini(
194+
client, prompt, aspect_ratio, cwd, use_profile, template, reference_image
195+
)
196+
155197
if model == "imagen":
156198
return await generate_with_imagen(
157199
client, prompt, count, aspect_ratio, cwd, use_profile, template

src/gemini_visual_mcp/server.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,17 @@ async def list_tools() -> list[Tool]:
122122
"default": True,
123123
"description": "Apply project style profile to the prompt",
124124
},
125+
"reference_image": {
126+
"type": "string",
127+
"minLength": 1,
128+
"description": (
129+
"Path to an existing image to use as a style reference. "
130+
"The new image will match the reference's art style, colors, "
131+
"and visual feel while depicting what the prompt describes. "
132+
"Use this for visual consistency (e.g., game characters in the same style). "
133+
"Auto-selects Gemini model when provided."
134+
),
135+
},
125136
},
126137
"required": ["prompt"],
127138
},
@@ -322,6 +333,14 @@ async def list_tools() -> list[Tool]:
322333
"type": "string",
323334
"description": "Design system (e.g., 'Material Design 3', 'custom')",
324335
},
336+
"reference_image": {
337+
"type": "string",
338+
"description": (
339+
"Path to a default reference image for style consistency. "
340+
"When set, all image generations will match this image's style "
341+
"unless overridden by an explicit reference_image in generate_image."
342+
),
343+
},
325344
},
326345
"required": ["project_type"],
327346
},
@@ -382,6 +401,7 @@ async def _handle_tool(self, name: str, args: dict) -> Any:
382401
cwd=self._cwd(),
383402
use_profile=args.get("use_profile", True),
384403
template=args.get("template"),
404+
reference_image=args.get("reference_image"),
385405
)
386406
# Clean up old previews on generation
387407
cleanup_old()
@@ -480,6 +500,7 @@ async def _handle_tool(self, name: str, args: dict) -> Any:
480500
visual_style=args.get("visual_style"),
481501
framework=args.get("framework"),
482502
design_system=args.get("design_system"),
503+
reference_image=args.get("reference_image"),
483504
)
484505

485506
elif name == "get_prompt_templates":
@@ -586,6 +607,7 @@ def _init_style_profile(
586607
visual_style: str | None = None,
587608
framework: str | None = None,
588609
design_system: str | None = None,
610+
reference_image: str | None = None,
589611
) -> dict:
590612
"""Create or update the project style profile."""
591613
cwd = self._cwd()
@@ -610,6 +632,8 @@ def _init_style_profile(
610632
detected["framework"] = framework
611633
if design_system:
612634
detected["design_system"] = design_system
635+
if reference_image:
636+
detected["reference_image"] = reference_image
613637

614638
# Create the profile
615639
path = create_profile(
@@ -624,6 +648,7 @@ def _init_style_profile(
624648
image_style=detected.get("image_style", ""),
625649
aspect_ratio=detected.get("default_aspect_ratio", "16:9"),
626650
resolution=detected.get("default_resolution", "1K"),
651+
reference_image=detected.get("reference_image", ""),
627652
)
628653

629654
return {

src/gemini_visual_mcp/style_profile.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
"image_style": "modern illustrations",
3636
"default_aspect_ratio": "16:9",
3737
"default_resolution": "1K",
38+
"reference_image": "",
3839
}
3940

4041

@@ -80,6 +81,7 @@ def create_profile(
8081
image_style: str = "",
8182
aspect_ratio: str = "16:9",
8283
resolution: str = "1K",
84+
reference_image: str = "",
8385
) -> Path:
8486
"""Create a new style profile in the target directory."""
8587
profile = dict(DEFAULT_PROFILE)
@@ -88,6 +90,7 @@ def create_profile(
8890
profile["design_system"] = design_system
8991
profile["default_aspect_ratio"] = aspect_ratio
9092
profile["default_resolution"] = resolution
93+
profile["reference_image"] = reference_image
9194

9295
if colors:
9396
profile["colors"] = {**profile["colors"], **colors}

tests/test_gemini_client.py

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,73 @@ def test_sync_call_success_after_retry(self):
6565
assert func.call_count == 2
6666

6767

68+
class TestGenerateImageWithReference:
69+
"""Tests for reference image support in generate_image_gemini."""
70+
71+
@pytest.mark.asyncio
72+
async def test_reference_image_sent_as_multipart_contents(self):
73+
"""When reference image is provided, contents should be a list with Part + prompt."""
74+
with patch("gemini_visual_mcp.gemini_client.genai") as mock_genai:
75+
mock_models = MagicMock()
76+
mock_genai.Client.return_value.models = mock_models
77+
78+
# Build a fake response with an image part
79+
mock_part = MagicMock()
80+
mock_part.inline_data = MagicMock()
81+
mock_part.inline_data.data = b"generated-image"
82+
mock_part.inline_data.mime_type = "image/png"
83+
mock_part.text = None
84+
85+
mock_candidate = MagicMock()
86+
mock_candidate.content.parts = [mock_part]
87+
mock_response = MagicMock()
88+
mock_response.candidates = [mock_candidate]
89+
mock_models.generate_content = MagicMock(return_value=mock_response)
90+
91+
client = GeminiClient(api_key="test-key")
92+
await client.generate_image_gemini(
93+
prompt="A warrior in matching style",
94+
reference_image_data=b"ref-image-bytes",
95+
reference_mime_type="image/png",
96+
)
97+
98+
call_args = mock_models.generate_content.call_args
99+
contents = call_args.kwargs["contents"]
100+
# Should be a list with image Part and text prompt
101+
assert isinstance(contents, list)
102+
assert len(contents) == 2
103+
assert contents[1] == "A warrior in matching style"
104+
105+
@pytest.mark.asyncio
106+
async def test_no_reference_sends_plain_string(self):
107+
"""Without reference image, contents should be a plain string."""
108+
with patch("gemini_visual_mcp.gemini_client.genai") as mock_genai:
109+
mock_models = MagicMock()
110+
mock_genai.Client.return_value.models = mock_models
111+
112+
mock_part = MagicMock()
113+
mock_part.inline_data = MagicMock()
114+
mock_part.inline_data.data = b"generated-image"
115+
mock_part.inline_data.mime_type = "image/png"
116+
mock_part.text = None
117+
118+
mock_candidate = MagicMock()
119+
mock_candidate.content.parts = [mock_part]
120+
mock_response = MagicMock()
121+
mock_response.candidates = [mock_candidate]
122+
mock_models.generate_content = MagicMock(return_value=mock_response)
123+
124+
client = GeminiClient(api_key="test-key")
125+
await client.generate_image_gemini(
126+
prompt="A simple landscape",
127+
)
128+
129+
call_args = mock_models.generate_content.call_args
130+
contents = call_args.kwargs["contents"]
131+
assert isinstance(contents, str)
132+
assert contents == "A simple landscape"
133+
134+
68135
class TestVideoModelMap:
69136
"""Tests for video model name mapping."""
70137

0 commit comments

Comments
 (0)