diff --git a/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java b/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java index e88a83cef..3414b6f5e 100644 --- a/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java +++ b/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java @@ -21,6 +21,7 @@ import com.google.adk.events.Event; import com.google.adk.runner.Runner; import com.google.adk.web.dto.AgentRunRequest; +import com.google.adk.web.security.CallerStateGuard; import com.google.adk.web.service.RunnerService; import com.google.common.collect.Lists; import io.reactivex.rxjava3.core.Flowable; @@ -73,6 +74,7 @@ public List agentRun(@RequestBody AgentRunRequest request) { throw new ResponseStatusException( HttpStatus.BAD_REQUEST, "sessionId cannot be null or empty"); } + CallerStateGuard.validateCallerState(request.stateDelta); log.info("Request received for POST /run for session: {}", request.sessionId); Runner runner = this.runnerService.getRunner(request.appName); @@ -124,6 +126,12 @@ public SseEmitter agentRunSse(@RequestBody AgentRunRequest request) { new ResponseStatusException(HttpStatus.BAD_REQUEST, "sessionId cannot be null or empty")); return emitter; } + try { + CallerStateGuard.validateCallerState(request.stateDelta); + } catch (ResponseStatusException e) { + emitter.completeWithError(e); + return emitter; + } log.info( "SseEmitter Request received for POST /run_sse_emitter for session: {}", request.sessionId); diff --git a/dev/src/main/java/com/google/adk/web/controller/SessionController.java b/dev/src/main/java/com/google/adk/web/controller/SessionController.java index 52420894e..029b94e6b 100644 --- a/dev/src/main/java/com/google/adk/web/controller/SessionController.java +++ b/dev/src/main/java/com/google/adk/web/controller/SessionController.java @@ -22,6 +22,7 @@ import com.google.adk.sessions.ListSessionsResponse; import com.google.adk.sessions.Session; import com.google.adk.web.dto.SessionRequest; +import com.google.adk.web.security.CallerStateGuard; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.Collections; @@ -176,6 +177,7 @@ public Session createSessionWithId( body); Map initialState = (body != null) ? body.getState() : Collections.emptyMap(); + CallerStateGuard.validateCallerState(initialState); try { findSessionOrThrow(appName, userId, sessionId); @@ -237,8 +239,9 @@ public Session createSession( userId, body); + Map initialState = (body != null) ? body.getState() : Collections.emptyMap(); + CallerStateGuard.validateCallerState(initialState); try { - Map initialState = (body != null) ? body.getState() : Collections.emptyMap(); Session createdSession = sessionService .createSession(appName, userId, new ConcurrentHashMap<>(initialState), null) diff --git a/dev/src/main/java/com/google/adk/web/security/CallerStateGuard.java b/dev/src/main/java/com/google/adk/web/security/CallerStateGuard.java new file mode 100644 index 000000000..50ce8656c --- /dev/null +++ b/dev/src/main/java/com/google/adk/web/security/CallerStateGuard.java @@ -0,0 +1,64 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.web.security; + +import com.google.adk.sessions.State; +import java.util.Map; +import org.springframework.http.HttpStatus; +import org.springframework.web.server.ResponseStatusException; + +/** + * Validates caller-supplied session state received over the development web API. + * + *

Keys prefixed with {@link State#APP_PREFIX} or {@link State#USER_PREFIX} are stored in app- or + * user-scoped state that is shared across sessions. Such keys are rejected when supplied by an HTTP + * caller, so the development server cannot be used to write cross-session state from external + * input. Programmatic callers using the session APIs directly are unaffected. + */ +public final class CallerStateGuard { + + private CallerStateGuard() {} + + /** + * Rejects caller-supplied state containing cross-session ({@code app:}/{@code user:}) keys. + * + * @param state caller-supplied state map (may be {@code null} or empty) + * @throws ResponseStatusException with {@link HttpStatus#BAD_REQUEST} if a disallowed key is + * found + */ + public static void validateCallerState(Map state) { + if (state == null || state.isEmpty()) { + return; + } + for (String key : state.keySet()) { + if (key == null) { + continue; + } + if (key.startsWith(State.APP_PREFIX) || key.startsWith(State.USER_PREFIX)) { + throw new ResponseStatusException( + HttpStatus.BAD_REQUEST, + "Caller-supplied state may not write the key '" + + key + + "'. Keys prefixed with '" + + State.APP_PREFIX + + "' or '" + + State.USER_PREFIX + + "' are shared across sessions and cannot be set via the development web API."); + } + } + } +} diff --git a/dev/src/test/java/com/google/adk/web/AdkWebServerTest.java b/dev/src/test/java/com/google/adk/web/AdkWebServerTest.java index 275001200..0e657705f 100644 --- a/dev/src/test/java/com/google/adk/web/AdkWebServerTest.java +++ b/dev/src/test/java/com/google/adk/web/AdkWebServerTest.java @@ -139,4 +139,63 @@ public void listSessions_shouldReturnOk() throws Exception { mockMvc.perform(delete("/apps/test-app/users/test-user/sessions/test-session-1")); mockMvc.perform(delete("/apps/test-app/users/test-user/sessions/test-session-2")); } + + @Test + public void createSession_withPlainSessionScopedState_returnsOk() throws Exception { + var result = + mockMvc + .perform( + post("/apps/test-app/users/test-user/sessions") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"state\":{\"topic\":\"weather\"}}")) + .andExpect(status().isOk()) + .andReturn(); + + var sessionId = + com.jayway.jsonpath.JsonPath.read(result.getResponse().getContentAsString(), "$.id"); + mockMvc.perform(delete("/apps/test-app/users/test-user/sessions/" + sessionId)); + } + + @Test + public void createSession_withAppScopedState_returnsBadRequest() throws Exception { + mockMvc + .perform( + post("/apps/test-app/users/test-user/sessions") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"state\":{\"app:operationalScope\":\"everything\"}}")) + .andExpect(status().isBadRequest()); + } + + @Test + public void createSession_withUserScopedState_returnsBadRequest() throws Exception { + mockMvc + .perform( + post("/apps/test-app/users/test-user/sessions") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"state\":{\"user:role\":\"admin\"}}")) + .andExpect(status().isBadRequest()); + } + + @Test + public void createSessionWithId_withAppScopedState_returnsBadRequest() throws Exception { + mockMvc + .perform( + post("/apps/test-app/users/test-user/sessions/test-session-scoped") + .contentType(MediaType.APPLICATION_JSON) + .content("{\"state\":{\"app:operationalScope\":\"everything\"}}")) + .andExpect(status().isBadRequest()); + } + + @Test + public void run_withAppScopedStateDelta_returnsBadRequest() throws Exception { + mockMvc + .perform( + post("/run") + .contentType(MediaType.APPLICATION_JSON) + .content( + "{\"appName\":\"test-app\",\"userId\":\"test-user\"," + + "\"sessionId\":\"test-session\"," + + "\"stateDelta\":{\"app:operationalScope\":\"everything\"}}")) + .andExpect(status().isBadRequest()); + } } diff --git a/dev/src/test/java/com/google/adk/web/security/CallerStateGuardTest.java b/dev/src/test/java/com/google/adk/web/security/CallerStateGuardTest.java new file mode 100644 index 000000000..fb1dd9ff9 --- /dev/null +++ b/dev/src/test/java/com/google/adk/web/security/CallerStateGuardTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.web.security; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpStatus; +import org.springframework.web.server.ResponseStatusException; + +/** Unit tests for {@link CallerStateGuard}. */ +public class CallerStateGuardTest { + + @Test + public void validateCallerState_nullState_doesNotThrow() { + CallerStateGuard.validateCallerState(null); + } + + @Test + public void validateCallerState_emptyState_doesNotThrow() { + CallerStateGuard.validateCallerState(ImmutableMap.of()); + } + + @Test + public void validateCallerState_plainSessionScopedKey_doesNotThrow() { + CallerStateGuard.validateCallerState(ImmutableMap.of("userContext", "Alice")); + } + + @Test + public void validateCallerState_tempPrefix_doesNotThrow() { + CallerStateGuard.validateCallerState(ImmutableMap.of("temp:scratch", "x")); + } + + @Test + public void validateCallerState_underscorePrefixedKey_doesNotThrow() { + CallerStateGuard.validateCallerState(ImmutableMap.of("_adk_replay_config", ImmutableMap.of())); + } + + @Test + public void validateCallerState_appPrefix_throwsBadRequest() { + ResponseStatusException exception = + assertThrows( + ResponseStatusException.class, + () -> + CallerStateGuard.validateCallerState(ImmutableMap.of("app:operationalScope", "x"))); + + assertThat(exception.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(exception.getReason()).contains("app:operationalScope"); + } + + @Test + public void validateCallerState_userPrefix_throwsBadRequest() { + ResponseStatusException exception = + assertThrows( + ResponseStatusException.class, + () -> CallerStateGuard.validateCallerState(ImmutableMap.of("user:role", "admin"))); + + assertThat(exception.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(exception.getReason()).contains("user:role"); + } + + @Test + public void validateCallerState_mixedValidAndScopedKeys_throwsBadRequest() { + Map state = new HashMap<>(); + state.put("userContext", "Alice"); + state.put("app:scope", "everything"); + + assertThrows(ResponseStatusException.class, () -> CallerStateGuard.validateCallerState(state)); + } +}