Skip to content

Commit f2bde3d

Browse files
llcourage and LIT team
authored and committed
Add a GCP text to image demo for LIT.
PiperOrigin-RevId: 758748858
1 parent 7c2e754 commit f2bde3d

4 files changed

Lines changed: 376 additions & 0 deletions

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Data loaders for text to image models."""
2+
3+
from typing import Optional

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types
5+
6+
7+
class TextToImageDataset(lit_dataset.Dataset):
    """Dataset holding text prompts for text-to-image models.

    Every example is a dict with a single "prompt" field.
    """

    def __init__(
        self,
        prompts: Optional[list[str]] = None,
        prompt: Optional[str] = None,
    ):
        """Builds the examples from the given prompt(s).

        Args:
            prompts: Prompts to load, one example per entry. A bare string is
                treated as a single prompt rather than iterated per character.
            prompt: A single additional prompt. This is the keyword advertised
                by `init_spec()`, so it must be accepted here for the dataset
                to be constructible from that spec.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        all_prompts = list(prompts) if prompts else []
        if prompt is not None:
            all_prompts.append(prompt)
        self._examples = [{"prompt": text} for text in all_prompts]

    @classmethod
    def init_spec(cls) -> lit_types.Spec:
        """Returns the constructor spec: one required text field, "prompt"."""
        return {"prompt": lit_types.TextSegment(required=True)}

    def spec(self) -> lit_types.Spec:
        """Returns the example spec: each example has a "prompt" text field."""
        return {"prompt": lit_types.TextSegment()}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
r"""A blank demo ready to load generative text to image models and datasets.
2+
3+
To use with VertexAI Model Garden models, you must install the following packages:
4+
pip install "vertexai>=1.49.0"
5+
To run the demo, you must set your GCP project location and project id.
6+
7+
Currently, the demo only supports the image generation models in the Model
8+
Garden.
9+
10+
The following command can be used to run the demo:
11+
blaze run -c opt examples/gcp_text_to_image:demo -- \
12+
--project_id=$GCP_PROJECT_ID \
13+
--project_location=$GCP_PROJECT_LOCATION \
14+
--alsologtostderr
15+
Then navigate to localhost:5432 to access the demo UI.
16+
"""
17+
18+
from collections.abc import Sequence
19+
import sys
20+
from typing import Optional
21+
22+
from absl import app
23+
from absl import flags
24+
from absl import logging
25+
import google.auth
26+
from google.cloud.aiplatform import vertexai
27+
from lit_nlp import app as lit_app
28+
from lit_nlp import dev_server
29+
from lit_nlp import server_flags
30+
from lit_nlp.api import layout
31+
from lit_nlp.examples.gcp_text_to_image import datasets as gcp_text_to_image_datasets
32+
from lit_nlp.examples.gcp_text_to_image import models as gcp_text_to_image_models
33+
34+
35+
FLAGS = flags.FLAGS
# GCP project settings. Both flags are required and have no default.
LOCATION = flags.DEFINE_string(
    'project_location',
    None,
    'Please enter your GCP project location',
    required=True,
)
PROJECT_ID = flags.DEFINE_string(
    'project_id',
    None,
    'Please enter your project id',
    required=True,
)

# Custom frontend layout; see api/layout.py
# Upper area: data table and datapoint editor. Lower area: generated image and
# generated text modules.
_modules = layout.LitModuleName
_IMAGE_LAYOUT = layout.LitCanonicalLayout(
    upper={
        'Main': [
            _modules.DataTableModule,
            _modules.DatapointEditorModule,
        ]
    },
    lower={
        'Predictions': [
            _modules.GeneratedImageModule,
            _modules.GeneratedTextModule,
        ],
    },
    description='Custom layout for Text to Image models.',
)


# The default LIT layouts plus the image layout, keyed as '_IMAGE_LAYOUT'.
CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {'_IMAGE_LAYOUT': _IMAGE_LAYOUT}

# Prompts preloaded into the demo's built-in dataset.
_CANNED_PROMPTS = ['I have a dream', 'I have a shiba dog named cola']
72+
73+
74+
def get_wsgi_app() -> Optional[dev_server.LitServerType]:
    """Return WSGI app for container-hosted demos.

    Sets the 'server_type' and 'demo_mode' flag defaults, parses the known
    flags from `sys.argv`, and delegates to `main()`.

    Returns:
        Whatever `main([])` returns.
    """
    FLAGS.set_default('server_type', 'external')
    FLAGS.set_default('demo_mode', True)
    # Parse flags without calling app.run(main), to avoid conflict with
    # gunicorn command line flags.
    unused = flags.FLAGS(sys.argv, known_only=True)
    if unused:
        # The message names this demo so the log line can be traced back here.
        logging.info(
            'gcp_text_to_image demo:get_wsgi_app() called with unused args: %s',
            unused,
        )
    return main([])
86+
87+
88+
def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
    """Builds and serves the text-to-image LIT demo.

    Initializes Vertex AI with application-default credentials for the
    configured project, registers the model and dataset loaders, preloads the
    canned prompts, and starts the LIT server.

    Args:
        argv: Positional command-line arguments; at most one element is
            accepted.

    Returns:
        Whatever `dev_server.Server.serve()` returns.

    Raises:
        app.UsageError: If more than one positional argument is given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    project_id = PROJECT_ID.value
    creds, _ = google.auth.default(
        scopes=['https://www.googleapis.com/auth/cloud-platform']
    )
    creds = creds.with_quota_project(project_id)
    vertexai.init(
        project=project_id,
        location=LOCATION.value,
        credentials=creds,
    )

    # No model is preloaded; models are created on demand via the loader.
    models = {}
    model_cls = gcp_text_to_image_models.VertexModelGardenModel
    model_loaders: lit_app.ModelLoadersMap = {
        'text_to_image': (model_cls, model_cls.init_spec()),
    }

    dataset_cls = gcp_text_to_image_datasets.TextToImageDataset
    datasets = {'prompts': dataset_cls(_CANNED_PROMPTS)}
    dataset_loaders: lit_app.DatasetLoadersMap = {
        'text_to_image': (dataset_cls, dataset_cls.init_spec()),
    }

    lit_demo = dev_server.Server(
        models=models,
        model_loaders=model_loaders,
        datasets=datasets,
        dataset_loaders=dataset_loaders,
        # The LIT app keyword is `layouts` (plural). Passing CUSTOM_LAYOUTS
        # makes the image layout defined in this module selectable alongside
        # the defaults.
        layouts=CUSTOM_LAYOUTS,
        **server_flags.get_flags(),
    )
    return lit_demo.serve()
126+
127+
128+
# Script entry point: run main() through absl's app runner.
if __name__ == '__main__':
    app.run(main)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Model Wrapper for generative models."""
2+
3+
from collections.abc import Iterable
4+
import io
5+
import logging
6+
import time
7+
from typing import Literal, Optional, Union
8+
from vertexai import vision_models
9+
from lit_nlp.api import model as lit_model
10+
from lit_nlp.api import types as lit_types
11+
from lit_nlp.lib import image_utils
12+
from PIL import Image
13+
14+
# Upper bound on the number of generate_images attempts made per prompt.
_MAX_NUM_RETRIES = 5

# NOTE(review): the three constants below are not referenced by any code
# visible in this module -- confirm whether they are still needed.
_DEFAULT_CANDIDATE_COUNT = 1

_DEFAULT_MAX_OUTPUT_TOKENS = 256

_IMAGE_PREFIX = 'data:image/png;base64,'
21+
22+
23+
class VertexModelGardenModel(lit_model.BatchedRemoteModel):
    """LIT wrapper for a Vertex AI Model Garden image generation model.

    For each prompt the wrapper calls the remote model's `generate_images`,
    resizes every returned image to `width` x `height`, and encodes it with
    `image_utils.convert_pil_to_image_str`.
    """

    def __init__(
        self,
        model_name: str = 'imagen-3.0-generate-002',
        max_concurrent_requests: int = 4,
        max_qps: Union[int, float] = 25,
        aspect_ratio: Optional[
            Literal['16:9', '1:1', '3:4', '4:3', '9:16']
        ] = None,
        width: int = 256,
        height: int = 256,
    ):
        """Connects to the remote model and stores the rendering options.

        Args:
            model_name: The name of the model to load.
            max_concurrent_requests: Forwarded to the base class; the maximum
                number of concurrent requests to the model.
            max_qps: Forwarded to the base class; the maximum number of
                queries per second to the model.
            aspect_ratio: Aspect ratio passed to `generate_images`. `None` is
                passed through unchanged.
            width: Width in pixels that generated images are resized to.
            height: Height in pixels that generated images are resized to.
        """
        super().__init__(max_concurrent_requests, max_qps)
        # Connect to the remote model.
        self._model = vision_models.ImageGenerationModel.from_pretrained(
            model_name
        )
        self._aspect_ratio = aspect_ratio
        self._width = width
        self._height = height

    def query_model(self, prompt: str, **unused_kw) -> list[str]:
        """Generates images for one prompt, retrying on failure.

        Makes up to `_MAX_NUM_RETRIES` attempts, backing off exponentially
        between failed attempts.

        Args:
            prompt: The text prompt to send to the model.
            **unused_kw: Ignored.

        Returns:
            One encoded image string per generated image, each resized to the
            configured width and height.

        Raises:
            ValueError: If no attempt produced a response, or the response is
                not iterable.
        """
        predictions = None
        exception = None

        for attempt in range(1, _MAX_NUM_RETRIES + 1):
            try:
                predictions = self._model.generate_images(
                    prompt=prompt,
                    aspect_ratio=self._aspect_ratio,
                )
            except Exception as e:  # pylint: disable=broad-except
                exception = e
                if attempt < _MAX_NUM_RETRIES:
                    wait_time = 2**attempt
                    logging.warning(
                        'Waiting %ds to retry... (%s)', wait_time, e
                    )
                    time.sleep(wait_time)
                else:
                    # No retry follows the last attempt, so sleeping here
                    # would only delay the error.
                    logging.warning(
                        'Giving up after %d attempts. (%s)', attempt, e
                    )
            if predictions is not None:
                break

        if predictions is None:
            raise ValueError(
                f'Failed to get predictions. ({exception})'
            ) from exception

        if not isinstance(predictions, Iterable):
            raise ValueError(
                f'Predictions is not an Iterable: {type(predictions)}'
            )

        size = (self._width, self._height)
        images = []
        for image_ in predictions.images:
            # The raw bytes are read from the image object's private
            # `_image_bytes` attribute.
            raw_bytes = getattr(image_, '_image_bytes')
            pil_img = Image.open(io.BytesIO(raw_bytes)).resize(size)
            images.append(image_utils.convert_pil_to_image_str(pil_img))

        return images

    def predict_minibatch(
        self, inputs: list[lit_types.JsonDict]
    ) -> list[lit_types.JsonDict]:
        """Generates one image per input prompt.

        The model can generate up to 8 images per run, but LIT may only show
        one due to frontend limitations, so only the first generated image is
        returned for each prompt.

        In MinDalle demos, the grid_size parameter controls layout -- for
        example, grid_size=2 creates a 2x2 grid of sub-images, rendered as a
        single final image. That's why only one image might appear even if
        multiple are generated.

        Args:
            inputs: A list of input dictionaries, each containing a 'prompt'.

        Returns:
            A list of dictionaries, each containing the generated 'image' and
            the original 'prompt'.

        Raises:
            ValueError: If the model returned no images for a prompt.
        """
        results = []
        for inp in inputs:
            prompt = inp['prompt']
            b64_strs = self.query_model(prompt)
            if not b64_strs:
                raise ValueError(f'No images generated for prompt: {prompt}')
            results.append({
                'image': b64_strs[0],
                'prompt': prompt,
            })
        return results

    @classmethod
    def init_spec(cls) -> lit_types.Spec:
        """Returns the spec of the constructor arguments exposed to LIT."""
        # NOTE(review): 'aspect_ratio' defaults to '1:1' here but to None in
        # __init__ -- confirm which default is intended.
        return {
            'model_name': lit_types.String(
                default='imagen-3.0-generate-002', required=True
            ),
            'aspect_ratio': lit_types.String(default='1:1', required=False),
            'width': lit_types.Integer(default=256, required=False),
            'height': lit_types.Integer(default=256, required=False),
        }

    def input_spec(self) -> lit_types.Spec:
        """Returns the input spec: a single 'prompt' text field."""
        return {
            'prompt': lit_types.TextSegment(),
        }

    def output_spec(self):
        """Returns the output spec: a single 'image' field."""
        # NOTE(review): predict_minibatch stores a single string under
        # 'image', while this spec declares a list type -- confirm which shape
        # the frontend expects.
        return {
            'image': lit_types.ImageBytesList(),
        }
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import base64
2+
from unittest import mock
3+
from absl.testing import absltest
4+
from vertexai import vision_models
5+
from lit_nlp.examples.gcp_text_to_image import models
6+
7+
8+
class MockModel:
    """Stand-in for the remote image generation model used in tests.

    Counts calls to `generate_images`, and either raises, returns a response
    holding one fake image, or returns an empty response.
    """

    def __init__(
        self, images=None, raise_exception=False, sample_image_bytes=None
    ):
        self.images = images or []
        self.raise_exception = raise_exception
        self.call_count = 0
        self.sample_image_bytes = sample_image_bytes

    def generate_images(self, prompt, aspect_ratio=None):
        """Records the call and returns a canned response.

        Raises:
            ValueError: If the mock was configured with raise_exception=True.
        """
        del prompt, aspect_ratio  # Unused by the mock.
        self.call_count += 1
        if self.raise_exception:
            raise ValueError("Mock Model Error")

        generated = []
        if self.sample_image_bytes:
            # Fake GeneratedImage carrying the configured bytes.
            fake_image = mock.create_autospec(
                vision_models.GeneratedImage, instance=True
            )
            fake_image._image_bytes = self.sample_image_bytes
            generated.append(fake_image)
        return vision_models.ImageGenerationResponse(images=generated)
34+
35+
36+
class ModelsTest(absltest.TestCase):
    """Tests VertexModelGardenModel with the remote model mocked out."""

    def setUp(self):
        super().setUp()
        # Small base64-encoded PNG used as the fake model output.
        encoded_png = b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMBAKh72VgAAAAASUVORK5CYII="
        self.sample_image_bytes = base64.b64decode(encoded_png)

    @mock.patch("vertexai.vision_models.ImageGenerationModel.from_pretrained")
    @mock.patch("PIL.Image.open")
    def test_query_model(self, mock_image_open, mock_from_pretrained):
        # The wrapper obtains its backend from from_pretrained; hand it the
        # fake.
        fake_backend = MockModel(sample_image_bytes=self.sample_image_bytes)
        mock_from_pretrained.return_value = fake_backend

        wrapper = models.VertexModelGardenModel(model_name="test_model_name")

        # Image.open yields a mock whose resize() returns itself.
        fake_pil_image = mock.Mock()
        fake_pil_image.resize.return_value = fake_pil_image
        mock_image_open.return_value = fake_pil_image

        prompt_text = "I say yes you say no"
        result = list(wrapper.predict_minibatch(inputs=[{"prompt": prompt_text}]))

        self.assertLen(result, 1)
        first = result[0]
        self.assertIn("image", first)
        self.assertIn("prompt", first)
        self.assertEqual(first["prompt"], prompt_text)

        # The image must be a string carrying the PNG data-URL prefix.
        self.assertIsInstance(first["image"], str)
        self.assertTrue(first["image"].startswith("data:image/png"))

        mock_from_pretrained.assert_called_once_with("test_model_name")

        # generate_images must have been invoked exactly once.
        self.assertEqual(fake_backend.call_count, 1)
80+
81+
82+
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    absltest.main()

0 commit comments

Comments
 (0)