Skip to content

Commit 70041fe

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Adding a new BaseSessionService.createSession() that takes in map
PiperOrigin-RevId: 875673726
1 parent e162a6b commit 70041fe

7 files changed

Lines changed: 90 additions & 27 deletions

File tree

contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import java.util.concurrent.ConcurrentMap;
5151
import java.util.concurrent.atomic.AtomicBoolean;
5252
import java.util.regex.Matcher;
53+
import javax.annotation.Nullable;
5354
import org.slf4j.Logger;
5455
import org.slf4j.LoggerFactory;
5556

@@ -88,7 +89,20 @@ private CollectionReference getSessionsCollection(String userId) {
8889
/** Creates a new session in Firestore. */
8990
@Override
9091
public Single<Session> createSession(
91-
String appName, String userId, ConcurrentMap<String, Object> state, String sessionId) {
92+
String appName,
93+
String userId,
94+
@Nullable ConcurrentMap<String, Object> state,
95+
@Nullable String sessionId) {
96+
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
97+
}
98+
99+
/** Creates a new session in Firestore. */
100+
@Override
101+
public Single<Session> createSession(
102+
String appName,
103+
String userId,
104+
@Nullable Map<String, Object> state,
105+
@Nullable String sessionId) {
92106
return Single.fromCallable(
93107
() -> {
94108
Objects.requireNonNull(appName, "appName cannot be null");

core/src/main/java/com/google/adk/sessions/BaseSessionService.java

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import io.reactivex.rxjava3.core.Maybe;
2424
import io.reactivex.rxjava3.core.Single;
2525
import java.util.List;
26+
import java.util.Map;
2627
import java.util.Objects;
2728
import java.util.Optional;
29+
import java.util.concurrent.ConcurrentHashMap;
2830
import java.util.concurrent.ConcurrentMap;
2931
import javax.annotation.Nullable;
3032

@@ -47,13 +49,35 @@ public interface BaseSessionService {
4749
* service should generate a unique ID.
4850
* @return The newly created {@link Session} instance.
4951
* @throws SessionException if creation fails.
52+
* @deprecated Use {@link #createSession(String, String, Map, String)} instead.
5053
*/
54+
@Deprecated
5155
Single<Session> createSession(
5256
String appName,
5357
String userId,
5458
@Nullable ConcurrentMap<String, Object> state,
5559
@Nullable String sessionId);
5660

61+
/**
62+
* Creates a new session with the specified parameters.
63+
*
64+
* @param appName The name of the application associated with the session.
65+
* @param userId The identifier for the user associated with the session.
66+
* @param state An optional map representing the initial state of the session. Can be null or
67+
* empty.
68+
* @param sessionId An optional client-provided identifier for the session. If empty or null, the
69+
* service should generate a unique ID.
70+
* @return The newly created {@link Session} instance.
71+
* @throws SessionException if creation fails.
72+
*/
73+
default Single<Session> createSession(
74+
String appName,
75+
String userId,
76+
@Nullable Map<String, Object> state,
77+
@Nullable String sessionId) {
78+
return createSession(appName, userId, ensureConcurrentMap(state), sessionId);
79+
}
80+
5781
/**
5882
* Creates a new session with the specified application name and user ID, using a default state
5983
* (null) and allowing the service to generate a unique session ID.
@@ -165,9 +189,9 @@ default Single<Event> appendEvent(Session session, Event event) {
165189

166190
EventActions actions = event.actions();
167191
if (actions != null) {
168-
ConcurrentMap<String, Object> stateDelta = actions.stateDelta();
192+
Map<String, Object> stateDelta = actions.stateDelta();
169193
if (stateDelta != null && !stateDelta.isEmpty()) {
170-
ConcurrentMap<String, Object> sessionState = session.state();
194+
Map<String, Object> sessionState = session.state();
171195
if (sessionState != null) {
172196
stateDelta.forEach(
173197
(key, value) -> {
@@ -190,4 +214,21 @@ default Single<Event> appendEvent(Session session, Event event) {
190214

191215
return Single.just(event);
192216
}
217+
218+
/**
219+
* Ensures the given {@link Map} is a {@link ConcurrentMap}. If the input is null, returns null.
220+
* If the input is already a {@link ConcurrentMap}, it is cast and returned. Otherwise, a new
221+
* {@link ConcurrentHashMap} is created from the input map.
222+
*/
223+
@Nullable
224+
private static ConcurrentMap<String, Object> ensureConcurrentMap(
225+
@Nullable Map<String, Object> state) {
226+
if (state == null) {
227+
return null;
228+
}
229+
if (state instanceof ConcurrentMap<String, Object> concurrentMap) {
230+
return concurrentMap;
231+
}
232+
return new ConcurrentHashMap<>(state);
233+
}
193234
}

core/src/main/java/com/google/adk/sessions/InMemorySessionService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ public Single<Session> createSession(
7171
String userId,
7272
@Nullable ConcurrentMap<String, Object> state,
7373
@Nullable String sessionId) {
74+
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
75+
}
76+
77+
@Override
78+
public Single<Session> createSession(
79+
String appName,
80+
String userId,
81+
@Nullable Map<String, Object> state,
82+
@Nullable String sessionId) {
7483
Objects.requireNonNull(appName, "appName cannot be null");
7584
Objects.requireNonNull(userId, "userId cannot be null");
7685

@@ -83,15 +92,13 @@ public Single<Session> createSession(
8392
// Ensure state map and events list are mutable for the new session
8493
ConcurrentMap<String, Object> initialState =
8594
(state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state);
86-
List<Event> initialEvents = new ArrayList<>();
8795

8896
// Assuming Session constructor or setters allow setting these mutable collections
8997
Session newSession =
9098
Session.builder(resolvedSessionId)
9199
.appName(appName)
92100
.userId(userId)
93101
.state(initialState)
94-
.events(initialEvents)
95102
.lastUpdateTime(Instant.now())
96103
.build();
97104

core/src/main/java/com/google/adk/sessions/VertexAiClient.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
import io.reactivex.rxjava3.core.Single;
1515
import java.io.IOException;
1616
import java.io.UncheckedIOException;
17+
import java.util.HashMap;
1718
import java.util.List;
19+
import java.util.Map;
1820
import java.util.Optional;
19-
import java.util.concurrent.ConcurrentHashMap;
20-
import java.util.concurrent.ConcurrentMap;
2121
import java.util.concurrent.TimeoutException;
2222
import javax.annotation.Nullable;
2323
import okhttp3.ResponseBody;
@@ -51,8 +51,8 @@ final class VertexAiClient {
5151
}
5252

5353
Maybe<JsonNode> createSession(
54-
String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) {
55-
ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>();
54+
String reasoningEngineId, String userId, Map<String, Object> state) {
55+
Map<String, Object> sessionJsonMap = new HashMap<>();
5656
sessionJsonMap.put("userId", userId);
5757
if (state != null) {
5858
sessionJsonMap.put("sessionState", state);

core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ public Single<Session> createSession(
7676
String userId,
7777
@Nullable ConcurrentMap<String, Object> state,
7878
@Nullable String sessionId) {
79+
return createSession(appName, userId, (Map<String, Object>) state, sessionId);
80+
}
81+
82+
@Override
83+
public Single<Session> createSession(
84+
String appName,
85+
String userId,
86+
@Nullable Map<String, Object> state,
87+
@Nullable String sessionId) {
7988

8089
String reasoningEngineId = parseReasoningEngineId(appName);
8190
return client

core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.google.adk.events.Event;
2121
import com.google.adk.events.EventActions;
2222
import io.reactivex.rxjava3.core.Single;
23+
import java.util.HashMap;
2324
import java.util.Optional;
2425
import java.util.concurrent.ConcurrentHashMap;
2526
import java.util.concurrent.ConcurrentMap;
@@ -84,7 +85,7 @@ public void lifecycle_listSessions() {
8485

8586
Session session =
8687
sessionService
87-
.createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1")
88+
.createSession("app-name", "user-id", new HashMap<>(), "session-1")
8889
.blockingGet();
8990

9091
ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
@@ -130,9 +131,7 @@ public void lifecycle_deleteSession() {
130131
public void appendEvent_updatesSessionState() {
131132
InMemorySessionService sessionService = new InMemorySessionService();
132133
Session session =
133-
sessionService
134-
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
135-
.blockingGet();
134+
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();
136135

137136
ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
138137
stateDelta.put("sessionKey", "sessionValue");
@@ -167,9 +166,7 @@ public void appendEvent_updatesSessionState() {
167166
public void appendEvent_removesState() {
168167
InMemorySessionService sessionService = new InMemorySessionService();
169168
Session session =
170-
sessionService
171-
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
172-
.blockingGet();
169+
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();
173170

174171
ConcurrentMap<String, Object> stateDeltaAdd = new ConcurrentHashMap<>();
175172
stateDeltaAdd.put("sessionKey", "sessionValue");
@@ -221,9 +218,7 @@ public void appendEvent_removesState() {
221218
public void sequentialAgents_shareTempState() {
222219
InMemorySessionService sessionService = new InMemorySessionService();
223220
Session session =
224-
sessionService
225-
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
226-
.blockingGet();
221+
sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet();
227222

228223
// Agent 1 writes to temp state
229224
ConcurrentMap<String, Object> stateDelta1 = new ConcurrentHashMap<>();

core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ public void setUp() throws Exception {
167167

168168
@Test
169169
public void createSession_success() throws Exception {
170-
ConcurrentMap<String, Object> sessionStateMap =
171-
new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value"));
170+
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value"));
172171
Single<Session> sessionSingle =
173172
vertexAiSessionService.createSession("123", "test_user", sessionStateMap, null);
174173
Session createdSession = sessionSingle.blockingGet();
@@ -190,8 +189,7 @@ public void createSession_success() throws Exception {
190189

191190
@Test
192191
public void createSession_getSession_success() throws Exception {
193-
ConcurrentMap<String, Object> sessionStateMap =
194-
new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value"));
192+
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value"));
195193
Single<Session> sessionSingle =
196194
vertexAiSessionService.createSession("789", "test_user", sessionStateMap, null);
197195
Session createdSession = sessionSingle.blockingGet();
@@ -252,8 +250,7 @@ public void getAndDeleteSession_success() throws Exception {
252250

253251
@Test
254252
public void createSessionAndGetSession_success() throws Exception {
255-
ConcurrentMap<String, Object> sessionStateMap =
256-
new ConcurrentHashMap<>(ImmutableMap.of("key", "value"));
253+
Map<String, Object> sessionStateMap = new HashMap<>(ImmutableMap.of("key", "value"));
257254
Single<Session> sessionSingle =
258255
vertexAiSessionService.createSession("123", "user", sessionStateMap, null);
259256
Session createdSession = sessionSingle.blockingGet();
@@ -341,8 +338,8 @@ public void listEmptySession_success() {
341338
@Test
342339
public void appendEvent_withStateRemoved_updatesSessionState() {
343340
String userId = "userB";
344-
ConcurrentMap<String, Object> initialState =
345-
new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2"));
341+
Map<String, Object> initialState =
342+
new HashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2"));
346343
Session session =
347344
vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet();
348345

0 commit comments

Comments
 (0)