Skip to content

Commit 795cc81

Browse files
Ark-kun and copybara-github
authored and committed
chore: [LLM] Added unit tests
The tests cover the Text Generation, Chat and Text Embedding models.
PiperOrigin-RevId: 530514104
1 parent 4793740 commit 795cc81

1 file changed

Lines changed: 289 additions & 0 deletions

File tree

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2023 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# pylint: disable=protected-access,bad-continuation
19+
20+
import pytest
21+
22+
from importlib import reload
23+
from unittest import mock
24+
25+
from google.cloud import aiplatform
26+
from google.cloud.aiplatform import base
27+
from google.cloud.aiplatform import initializer
28+
29+
from google.cloud.aiplatform.compat.services import (
30+
model_garden_service_client_v1beta1,
31+
)
32+
from google.cloud.aiplatform.compat.services import prediction_service_client
33+
from google.cloud.aiplatform.compat.types import (
34+
prediction_service as gca_prediction_service,
35+
)
36+
from google.cloud.aiplatform_v1beta1.types import (
37+
publisher_model as gca_publisher_model,
38+
)
39+
40+
from vertexai.preview import language_models
41+
42+
43+
_TEST_PROJECT = "test-project"
44+
_TEST_LOCATION = "us-central1"
45+
46+
_TEXT_BISON_PUBLISHER_MODEL_DICT = {
47+
"name": "publishers/google/models/text-bison",
48+
"version_id": "001",
49+
"open_source_category": "PROPRIETARY",
50+
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/text-bison@001",
51+
"predict_schemata": {
52+
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml",
53+
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml",
54+
"prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_generation_1.0.0.yaml",
55+
},
56+
}
57+
58+
_CHAT_BISON_PUBLISHER_MODEL_DICT = {
59+
"name": "publishers/google/models/chat-bison",
60+
"version_id": "001",
61+
"open_source_category": "PROPRIETARY",
62+
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
63+
"predict_schemata": {
64+
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
65+
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/chat_generation_1.0.0.yaml",
66+
"prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/chat_generation_1.0.0.yaml",
67+
},
68+
}
69+
70+
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT = {
71+
"name": "publishers/google/models/textembedding-gecko",
72+
"version_id": "001",
73+
"open_source_category": "PROPRIETARY",
74+
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
75+
"predict_schemata": {
76+
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",
77+
"parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml",
78+
"prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_embedding_1.0.0.yaml",
79+
},
80+
}
81+
82+
_TEST_TEXT_GENERATION_PREDICTION = {
83+
"safetyAttributes": {
84+
"categories": ["Violent"],
85+
"blocked": False,
86+
"scores": [0.10000000149011612],
87+
},
88+
"content": """
89+
Ingredients:
90+
* 3 cups all-purpose flour
91+
92+
Instructions:
93+
1. Preheat oven to 350 degrees F (175 degrees C).""",
94+
}
95+
96+
_TEST_CHAT_GENERATION_PREDICTION1 = {
97+
"safetyAttributes": {
98+
"scores": [],
99+
"blocked": False,
100+
"categories": [],
101+
},
102+
"candidates": [
103+
{
104+
"author": "1",
105+
"content": "Chat response 1",
106+
}
107+
],
108+
}
109+
_TEST_CHAT_GENERATION_PREDICTION2 = {
110+
"safetyAttributes": {
111+
"scores": [],
112+
"blocked": False,
113+
"categories": [],
114+
},
115+
"candidates": [
116+
{
117+
"author": "1",
118+
"content": "Chat response 2",
119+
}
120+
],
121+
}
122+
123+
_TEXT_EMBEDDING_VECTOR_LENGTH = 768
124+
_TEST_TEXT_EMBEDDING_PREDICTION = {
125+
"embeddings": {
126+
"values": list([1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH),
127+
}
128+
}
129+
130+
131+
@pytest.mark.usefixtures("google_auth_mock")
class TestLanguageModels:
    """Unit tests for the language models.

    Each test patches ModelGardenServiceClient.get_publisher_model and
    PredictionServiceClient.predict, so the model classes are exercised
    against canned responses instead of the real service.
    """

    def setup_method(self):
        """Reload the initializer and aiplatform modules before each test."""
        # NOTE(review): presumably this clears state left behind by a previous
        # test's aiplatform.init() call -- confirm against initializer.
        reload(initializer)
        reload(aiplatform)

    def teardown_method(self):
        """Shut down the initializer's global pool, waiting for completion."""
        initializer.global_pool.shutdown(wait=True)

    def test_text_generation(self):
        """Tests the text generation model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _TEXT_BISON_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = language_models.TextGenerationModel.from_pretrained(
                "google/text-bison@001"
            )

        # The "google/<model>@<version>" id must be looked up exactly once
        # under its full publisher resource name.
        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/text-bison@001",
            retry=base._DEFAULT_RETRY,
        )

        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            response = model.predict(
                "What is the best recipe for banana bread? Recipe:",
                max_output_tokens=128,
                temperature=0,
                top_p=1,
                top_k=5,
            )

            # The response text must be the "content" of the canned prediction.
            assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

    def test_chat(self):
        """Tests the chat generation model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _CHAT_BISON_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = language_models.ChatModel.from_pretrained("google/chat-bison@001")

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/chat-bison@001",
            retry=base._DEFAULT_RETRY,
        )

        chat = model.start_chat(
            context="""
                My name is Ned.
                You are my personal assistant.
                My favorite movies are Lord of the Rings and Hobbit.
            """,
            examples=[
                language_models.InputOutputTextPair(
                    input_text="Who do you work for?",
                    output_text="I work for Ned.",
                ),
                language_models.InputOutputTextPair(
                    input_text="What do I like?",
                    output_text="Ned likes watching movies.",
                ),
            ],
            temperature=0.0,
        )

        # First turn: served from the first canned prediction.
        gca_predict_response1 = gca_prediction_service.PredictResponse()
        gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response1,
        ):
            response = chat.send_message(
                "Are my favorite movies based on a book series?"
            )
            assert (
                response.text
                == _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"]
            )
            # One history entry is expected after the first exchange.
            assert len(chat._history) == 1

        # Second turn: served from the second canned prediction, with a
        # per-message temperature passed to send_message.
        gca_predict_response2 = gca_prediction_service.PredictResponse()
        gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response2,
        ):
            response = chat.send_message(
                "When where these books published?",
                temperature=0.1,
            )
            assert (
                response.text
                == _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"]
            )
            # The history is expected to grow by one entry per exchange.
            assert len(chat._history) == 2

    def test_text_embedding(self):
        """Tests the text embedding model."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        with mock.patch.object(
            target=model_garden_service_client_v1beta1.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = language_models.TextEmbeddingModel.from_pretrained(
                "google/textembedding-gecko@001"
            )

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/textembedding-gecko@001",
            retry=base._DEFAULT_RETRY,
        )

        gca_predict_response = gca_prediction_service.PredictResponse()
        gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response,
        ):
            embeddings = model.get_embeddings(["What is life?"])
            # Guard against an empty result, which would make the loop below
            # pass without checking anything.
            assert embeddings
            for embedding in embeddings:
                vector = embedding.values
                assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
                assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]

0 commit comments

Comments
 (0)