3939from google .cloud .aiplatform .compat .types import (
4040 publisher_model as gca_publisher_model ,
4141)
42+ import vertexai
4243from vertexai import vision_models as ga_vision_models
43- from vertexai .preview import vision_models
44+ from vertexai .preview import (
45+ vision_models as preview_vision_models ,
46+ )
4447
4548from PIL import Image as PIL_Image
4649import pytest
@@ -121,12 +124,12 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
121124
122125def generate_image_from_file (
123126 width : int = 100 , height : int = 100
124- ) -> vision_models .Image :
127+ ) -> ga_vision_models .Image :
125128 with tempfile .TemporaryDirectory () as temp_dir :
126129 image_path = os .path .join (temp_dir , "image.png" )
127130 pil_image = PIL_Image .new (mode = "RGB" , size = (width , height ))
128131 pil_image .save (image_path , format = "PNG" )
129- return vision_models .Image .load_from_file (image_path )
132+ return ga_vision_models .Image .load_from_file (image_path )
130133
131134
132135@pytest .mark .usefixtures ("google_auth_mock" )
@@ -140,7 +143,7 @@ def setup_method(self):
140143 def teardown_method (self ):
141144 initializer .global_pool .shutdown (wait = True )
142145
143- def _get_image_generation_model (self ) -> vision_models .ImageGenerationModel :
146+ def _get_image_generation_model (self ) -> preview_vision_models .ImageGenerationModel :
144147 """Gets the image generation model."""
145148 aiplatform .init (
146149 project = _TEST_PROJECT ,
@@ -153,7 +156,7 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
153156 _IMAGE_GENERATION_PUBLISHER_MODEL_DICT
154157 ),
155158 ) as mock_get_publisher_model :
156- model = vision_models .ImageGenerationModel .from_pretrained (
159+ model = preview_vision_models .ImageGenerationModel .from_pretrained (
157160 "imagegeneration@002"
158161 )
159162
@@ -164,13 +167,48 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
164167
165168 return model
166169
170+ def _get_preview_image_generation_model_top_level_from_pretrained (
171+ self ,
172+ ) -> preview_vision_models .ImageGenerationModel :
173+ """Gets the image generation model from the top-level vertexai.preview.from_pretrained method."""
174+ aiplatform .init (
175+ project = _TEST_PROJECT ,
176+ location = _TEST_LOCATION ,
177+ )
178+ with mock .patch .object (
179+ target = model_garden_service_client .ModelGardenServiceClient ,
180+ attribute = "get_publisher_model" ,
181+ return_value = gca_publisher_model .PublisherModel (
182+ _IMAGE_GENERATION_PUBLISHER_MODEL_DICT
183+ ),
184+ ) as mock_get_publisher_model :
185+ model = vertexai .preview .from_pretrained (
186+ foundation_model_name = "imagegeneration@002"
187+ )
188+
189+ mock_get_publisher_model .assert_called_with (
190+ name = "publishers/google/models/imagegeneration@002" ,
191+ retry = base ._DEFAULT_RETRY ,
192+ )
193+
194+ assert mock_get_publisher_model .call_count == 1
195+
196+ return model
197+
167198 def test_from_pretrained (self ):
168199 model = self ._get_image_generation_model ()
169200 assert (
170201 model ._endpoint_name
171202 == f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/imagegeneration@002"
172203 )
173204
205+ def test_top_level_from_pretrained_preview (self ):
206+ model = self ._get_preview_image_generation_model_top_level_from_pretrained ()
207+ assert (
208+ model ._endpoint_name
209+ == f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/imagegeneration@002"
210+ )
211+
174212 def test_generate_images (self ):
175213 """Tests the image generation model."""
176214 model = self ._get_image_generation_model ()
@@ -238,7 +276,7 @@ def test_generate_images(self):
238276 with tempfile .TemporaryDirectory () as temp_dir :
239277 image_path = os .path .join (temp_dir , "image.png" )
240278 image_response [0 ].save (location = image_path )
241- image1 = vision_models .GeneratedImage .load_from_file (image_path )
279+ image1 = preview_vision_models .GeneratedImage .load_from_file (image_path )
242280 # assert image1._pil_image.size == (width, height)
243281 assert image1 .generation_parameters
244282 assert image1 .generation_parameters ["prompt" ] == prompt1
@@ -247,7 +285,7 @@ def test_generate_images(self):
247285 mask_path = os .path .join (temp_dir , "mask.png" )
248286 mask_pil_image = PIL_Image .new (mode = "RGB" , size = image1 ._pil_image .size )
249287 mask_pil_image .save (mask_path , format = "PNG" )
250- mask_image = vision_models .Image .load_from_file (mask_path )
288+ mask_image = preview_vision_models .Image .load_from_file (mask_path )
251289
252290 # Test generating image from base image
253291 with mock .patch .object (
@@ -408,7 +446,7 @@ def test_upscale_image_on_provided_image(self):
408446 assert image_upscale_parameters ["mode" ] == "upscale"
409447
410448 assert upscaled_image ._image_bytes
411- assert isinstance (upscaled_image , vision_models .GeneratedImage )
449+ assert isinstance (upscaled_image , preview_vision_models .GeneratedImage )
412450
413451 def test_upscale_image_raises_if_not_1024x1024 (self ):
414452 """Tests image upscaling on generated images."""
@@ -457,7 +495,7 @@ def test_get_captions(self):
457495 image_path = os .path .join (temp_dir , "image.png" )
458496 pil_image = PIL_Image .new (mode = "RGB" , size = (100 , 100 ))
459497 pil_image .save (image_path , format = "PNG" )
460- image = vision_models .Image .load_from_file (image_path )
498+ image = preview_vision_models .Image .load_from_file (image_path )
461499
462500 with mock .patch .object (
463501 target = prediction_service_client .PredictionServiceClient ,
@@ -544,7 +582,7 @@ def test_image_embedding_model_with_only_image(self):
544582 _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
545583 ),
546584 ) as mock_get_publisher_model :
547- model = vision_models .MultiModalEmbeddingModel .from_pretrained (
585+ model = preview_vision_models .MultiModalEmbeddingModel .from_pretrained (
548586 "multimodalembedding@001"
549587 )
550588
@@ -583,7 +621,7 @@ def test_image_embedding_model_with_image_and_text(self):
583621 _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
584622 ),
585623 ):
586- model = vision_models .MultiModalEmbeddingModel .from_pretrained (
624+ model = preview_vision_models .MultiModalEmbeddingModel .from_pretrained (
587625 "multimodalembedding@001"
588626 )
589627
@@ -715,7 +753,7 @@ def test_get_captions(self):
715753 image_path = os .path .join (temp_dir , "image.png" )
716754 pil_image = PIL_Image .new (mode = "RGB" , size = (100 , 100 ))
717755 pil_image .save (image_path , format = "PNG" )
718- image = vision_models .Image .load_from_file (image_path )
756+ image = preview_vision_models .Image .load_from_file (image_path )
719757
720758 with mock .patch .object (
721759 target = prediction_service_client .PredictionServiceClient ,
0 commit comments