diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 72a14cc4d..c11e3d1db 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Completable; @@ -32,10 +33,13 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * An in-memory implementation of {@link BaseSessionService} assuming {@link Session} objects are @@ -49,6 +53,23 @@ * during retrieval operations ({@code getSession}, {@code createSession}). */ public final class InMemorySessionService implements BaseSessionService { + + private static final Logger log = LoggerFactory.getLogger(InMemorySessionService.class); + + /** + * Reserved session-state keys that are managed internally by the ADK framework. Callers are not + * permitted to set or override these keys through the public API (initial session state or + * per-run stateDelta). Allowing external writes to these keys would let an untrusted caller steer + * internal framework behaviour, such as hijacking the code-execution session identifier used by + * {@code VertexAiCodeExecutor}. + */ + private static final Set RESERVED_STATE_KEYS = + ImmutableSet.of( + "_code_execution_context", + "_code_executor_input_files", + "_code_executor_error_counts", + "_code_execution_results"); + // Structure: appName -> userId -> sessionId -> Session private final ConcurrentMap>> sessions; @@ -65,6 +86,31 @@ public InMemorySessionService() { this.appState = new ConcurrentHashMap<>(); } + /** + * Removes reserved internal keys from a caller-supplied state map before it is persisted. + * Logs a warning for each key that is dropped. + * + * @param state The caller-supplied state map (may be null). + * @return A new {@link ConcurrentHashMap} containing only the non-reserved entries. + */ + private static ConcurrentMap sanitizeCallerState( + @Nullable Map state) { + if (state == null) { + return new ConcurrentHashMap<>(); + } + ConcurrentMap sanitized = new ConcurrentHashMap<>(); + state.forEach( + (key, value) -> { + if (RESERVED_STATE_KEYS.contains(key)) { + log.warn( + "Caller attempted to set reserved internal state key '{}'; ignoring.", key); + } else { + sanitized.put(key, value); + } + }); + return sanitized; + } + @Override public Single createSession( String appName, @@ -89,9 +135,8 @@ public Single createSession( .filter(s -> !s.isEmpty()) .orElseGet(() -> UUID.randomUUID().toString()); - // Ensure state map and events list are mutable for the new session - ConcurrentMap initialState = - (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); + // Sanitize caller-supplied state: strip reserved internal keys before persisting. + ConcurrentMap initialState = sanitizeCallerState(state); // Assuming Session constructor or setters allow setting these mutable collections Session newSession = @@ -268,6 +313,12 @@ public Single appendEvent(Session session, Event event) { .put(userStateKey, value); } } else { + // Reject writes to reserved internal keys from any external stateDelta. + if (RESERVED_STATE_KEYS.contains(key)) { + log.warn( + "stateDelta contains reserved internal key '{}'; ignoring write.", key); + return; + } if (value == State.REMOVED) { session.state().remove(key); } else { diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 6a271efac..3eb8673e2 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -317,4 +317,55 @@ public void deleteSession_doesNotRemoveUserMapWhenOtherSessionsExist() throws Ex assertThat(sessions.get("app-name").get("user-id")).isNotNull(); assertThat(sessions.get("app-name").get("user-id")).hasSize(1); } -} + + @Test + public void createSession_stripsReservedStateKeys() { + InMemorySessionService sessionService = new InMemorySessionService(); + + HashMap state = new HashMap<>(); + state.put("user_key", "user_value"); + state.put("_code_execution_context", "should_be_stripped"); + state.put("_code_executor_input_files", "should_be_stripped"); + state.put("_code_executor_error_counts", "should_be_stripped"); + state.put("_code_execution_results", "should_be_stripped"); + + Session session = + sessionService.createSession("app", "user", state, "session-id").blockingGet(); + + // Only the non-reserved key should survive. + assertThat(session.state()).containsExactly("user_key", "user_value"); + assertThat(session.state()).doesNotContainKey("_code_execution_context"); + assertThat(session.state()).doesNotContainKey("_code_executor_input_files"); + assertThat(session.state()).doesNotContainKey("_code_executor_error_counts"); + assertThat(session.state()).doesNotContainKey("_code_execution_results"); + } + + @Test + public void appendEvent_stateDelta_stripsReservedStateKeys() { + InMemorySessionService sessionService = new InMemorySessionService(); + + Session session = sessionService.createSession("app", "user").blockingGet(); + + // Attempt to inject a reserved key via stateDelta. + ConcurrentMap stateDelta = new ConcurrentHashMap<>(); + stateDelta.put("safe_key", "safe_value"); + stateDelta.put("_code_execution_context", "attacker_controlled"); + + Event event = + Event.builder().actions(EventActions.builder().stateDelta(stateDelta).build()).build(); + sessionService.appendEvent(session, event).blockingGet(); + + assertThat(session.state()).containsEntry("safe_key", "safe_value"); + assertThat(session.state()).doesNotContainKey("_code_execution_context"); + } + + @Test + public void createSession_nullState_producesEmptyState() { + InMemorySessionService sessionService = new InMemorySessionService(); + + Session session = sessionService.createSession("app", "user", null, "s1").blockingGet(); + + assertThat(session.state()).isEmpty(); + } + +} \ No newline at end of file