Skip to content

Commit aee30b3

Browse files
committed
Reject goals for unknown sessions
1 parent 005329f commit aee30b3

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

crates/server/src/runtime/goal_handlers.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ impl ServerRuntime {
2121
let session_id = params.session_id;
2222
let replace_existing = params.replace_existing;
2323
let title_input = params.objective.trim().to_string();
24+
if !self.sessions.lock().await.contains_key(&session_id) {
25+
return self.error_response(
26+
request_id,
27+
ProtocolErrorCode::SessionNotFound,
28+
"session does not exist",
29+
);
30+
}
2431

2532
let mut stores = self.goal_stores.lock().await;
2633
let store = stores.entry(session_id).or_insert_with(GoalStore::new);
@@ -90,6 +97,13 @@ impl ServerRuntime {
9097
== Some(devo_protocol::ThreadGoalStatus::Paused)
9198
&& params.objective.is_none()
9299
&& params.token_budget.is_none();
100+
if !self.sessions.lock().await.contains_key(&session_id) {
101+
return self.error_response(
102+
request_id,
103+
ProtocolErrorCode::SessionNotFound,
104+
"session does not exist",
105+
);
106+
}
93107

94108
let mut stores = self.goal_stores.lock().await;
95109
let store = stores.entry(session_id).or_insert_with(GoalStore::new);

crates/server/tests/goal_title_generation.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,82 @@ async fn goal_set_objective_generates_session_title_for_new_session() -> Result<
128128
Ok(())
129129
}
130130

131+
#[tokio::test]
132+
async fn goal_create_rejects_unknown_session() -> Result<()> {
133+
let data_root = TempDir::new()?;
134+
let provider = Arc::new(GoalTitleProvider::default());
135+
let runtime = build_runtime(data_root.path(), provider.clone())?;
136+
let (connection_id, _notifications_rx) = initialize_connection(&runtime).await?;
137+
let unknown_session_id = SessionId::new();
138+
139+
let response = runtime
140+
.handle_incoming(
141+
connection_id,
142+
serde_json::json!({
143+
"id": 5,
144+
"method": "goal/create",
145+
"params": {
146+
"sessionId": unknown_session_id,
147+
"objective": "unknown session goal",
148+
"replaceExisting": false
149+
}
150+
}),
151+
)
152+
.await
153+
.context("goal/create response")?;
154+
155+
assert_session_not_found(response)?;
156+
assert_eq!(
157+
provider
158+
.title_requests
159+
.lock()
160+
.expect("lock title requests")
161+
.len(),
162+
0
163+
);
164+
assert_eq!(provider.stream_requests.load(Ordering::SeqCst), 0);
165+
assert_goal_status_empty(&runtime, connection_id, unknown_session_id).await?;
166+
Ok(())
167+
}
168+
169+
#[tokio::test]
170+
async fn goal_set_rejects_unknown_session() -> Result<()> {
171+
let data_root = TempDir::new()?;
172+
let provider = Arc::new(GoalTitleProvider::default());
173+
let runtime = build_runtime(data_root.path(), provider.clone())?;
174+
let (connection_id, _notifications_rx) = initialize_connection(&runtime).await?;
175+
let unknown_session_id = SessionId::new();
176+
177+
let response = runtime
178+
.handle_incoming(
179+
connection_id,
180+
serde_json::json!({
181+
"id": 6,
182+
"method": "goal/set",
183+
"params": {
184+
"sessionId": unknown_session_id,
185+
"objective": "unknown session goal",
186+
"status": "active"
187+
}
188+
}),
189+
)
190+
.await
191+
.context("goal/set response")?;
192+
193+
assert_session_not_found(response)?;
194+
assert_eq!(
195+
provider
196+
.title_requests
197+
.lock()
198+
.expect("lock title requests")
199+
.len(),
200+
0
201+
);
202+
assert_eq!(provider.stream_requests.load(Ordering::SeqCst), 0);
203+
assert_goal_status_empty(&runtime, connection_id, unknown_session_id).await?;
204+
Ok(())
205+
}
206+
131207
fn build_runtime(
132208
data_root: &std::path::Path,
133209
provider: Arc<GoalTitleProvider>,
@@ -261,3 +337,36 @@ fn title_request_contains(request: &ModelRequest, needle: &str) -> bool {
261337
})
262338
})
263339
}
340+
341+
fn assert_session_not_found(response: serde_json::Value) -> Result<()> {
342+
let response: devo_server::ErrorResponse = serde_json::from_value(response)?;
343+
assert_eq!(
344+
response.error.code,
345+
devo_server::ProtocolErrorCode::SessionNotFound
346+
);
347+
Ok(())
348+
}
349+
350+
async fn assert_goal_status_empty(
351+
runtime: &Arc<ServerRuntime>,
352+
connection_id: u64,
353+
session_id: SessionId,
354+
) -> Result<()> {
355+
let response = runtime
356+
.handle_incoming(
357+
connection_id,
358+
serde_json::json!({
359+
"id": 7,
360+
"method": "goal/status",
361+
"params": {
362+
"sessionId": session_id
363+
}
364+
}),
365+
)
366+
.await
367+
.context("goal/status response")?;
368+
let response: devo_server::SuccessResponse<devo_protocol::GoalStatusResult> =
369+
serde_json::from_value(response)?;
370+
assert_eq!(response.result.goal, None);
371+
Ok(())
372+
}

0 commit comments

Comments
 (0)