Skip to content

Commit 3dff1b6

Browse files
copybara-service[bot]Zhenyi Qi
andauthored
chore: [vertexai]Add integration test for ChatSession. (#10641)
PiperOrigin-RevId: 620268882 Co-authored-by: Zhenyi Qi <zhenyiqi@google.com>
1 parent 0eacdf6 commit 3dff1b6

1 file changed

Lines changed: 135 additions & 0 deletions

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.cloud.vertexai.it;
17+
18+
import static com.google.common.truth.Truth.assertThat;
19+
20+
import com.google.cloud.vertexai.VertexAI;
21+
import com.google.cloud.vertexai.api.Content;
22+
import com.google.cloud.vertexai.api.GenerateContentResponse;
23+
import com.google.cloud.vertexai.api.GenerationConfig;
24+
import com.google.cloud.vertexai.api.HarmCategory;
25+
import com.google.cloud.vertexai.api.SafetySetting;
26+
import com.google.cloud.vertexai.generativeai.ChatSession;
27+
import com.google.cloud.vertexai.generativeai.ContentMaker;
28+
import com.google.cloud.vertexai.generativeai.GenerativeModel;
29+
import com.google.cloud.vertexai.generativeai.ResponseStream;
30+
import java.io.IOException;
31+
import java.util.Arrays;
32+
import java.util.List;
33+
import java.util.logging.Logger;
34+
import org.junit.After;
35+
import org.junit.Before;
36+
import org.junit.Test;
37+
import org.junit.runner.RunWith;
38+
import org.junit.runners.JUnit4;
39+
40+
@RunWith(JUnit4.class)
41+
public class ITChatSessionIntegrationTest {
42+
private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT");
43+
private static final String MODEL_NAME_TEXT = "gemini-pro";
44+
private static final String LOCATION = "us-central1";
45+
private static final Logger logger =
46+
Logger.getLogger(ITGenerativeModelIntegrationTest.class.getName());
47+
48+
private VertexAI vertexAi;
49+
private GenerativeModel model;
50+
private ChatSession chat;
51+
52+
@Before
53+
public void setUp() throws IOException {
54+
vertexAi = new VertexAI(PROJECT_ID, LOCATION);
55+
model = new GenerativeModel(MODEL_NAME_TEXT, vertexAi);
56+
}
57+
58+
@After
59+
public void tearDown() throws IOException {
60+
vertexAi.close();
61+
}
62+
63+
@Test
64+
public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
65+
// Arrange
66+
String firstMessage = "hello!";
67+
String secondMessage = "how old are you?";
68+
Content expectedFirstContent = ContentMaker.fromString(firstMessage);
69+
Content expectedThirdContent = ContentMaker.fromString(secondMessage);
70+
71+
// Act
72+
chat = model.startChat();
73+
ResponseStream<GenerateContentResponse> stream = chat.sendMessageStream(firstMessage);
74+
// We consume the stream before sending another message
75+
for (GenerateContentResponse resp : stream) {
76+
// Assert while consuming
77+
assertThat(resp.getCandidatesList()).isNotEmpty();
78+
}
79+
GenerateContentResponse response = chat.sendMessage(secondMessage);
80+
List<Content> history = chat.getHistory();
81+
82+
// Assert
83+
// GenAI output is flaky so we always print out the response.
84+
// For the same reason, we don't do assertions much.
85+
logger.info(String.format("The whole history is:\n%s", history));
86+
assertThat(history.size()).isEqualTo(4);
87+
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
88+
assertThat(history.get(1).getRole()).isEqualTo("model");
89+
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
90+
assertThat(history.get(3).getRole()).isEqualTo("model");
91+
}
92+
93+
@Test
94+
public void sendMessageWithNewConfigs_historyContainsFullConversation() throws IOException {
95+
// Arrange
96+
String firstMessage = "hello!";
97+
String secondMessage = "how old are you?";
98+
Content expectedFirstContent = ContentMaker.fromString(firstMessage);
99+
Content expectedThirdContent = ContentMaker.fromString(secondMessage);
100+
GenerationConfig config = GenerationConfig.newBuilder().setTemperature(0.7F).build();
101+
List<SafetySetting> safetySettings =
102+
Arrays.asList(
103+
SafetySetting.newBuilder()
104+
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
105+
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH)
106+
.build());
107+
108+
// Act
109+
chat = model.startChat();
110+
ResponseStream<GenerateContentResponse> stream =
111+
chat.withGenerationConfig(config)
112+
.withSafetySettings(safetySettings)
113+
.sendMessageStream(firstMessage);
114+
// We consume the stream before sending another message
115+
for (GenerateContentResponse resp : stream) {
116+
// Assert while consuming
117+
assertThat(resp.getCandidatesList()).isNotEmpty();
118+
}
119+
GenerateContentResponse response =
120+
chat.withGenerationConfig(config)
121+
.withSafetySettings(safetySettings)
122+
.sendMessage(secondMessage);
123+
124+
// Assert
125+
List<Content> history = chat.getHistory();
126+
// GenAI output is flaky so we always print out the response.
127+
// For the same reason, we don't do assertions much.
128+
logger.info(String.format("The whole history is:\n%s", history));
129+
assertThat(history.size()).isEqualTo(4);
130+
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
131+
assertThat(history.get(1).getRole()).isEqualTo("model");
132+
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
133+
assertThat(history.get(3).getRole()).isEqualTo("model");
134+
}
135+
}

0 commit comments

Comments
 (0)