Skip to content

Commit be23334

Browse files
authored
fix(tasks): avoid dropping completed task results during collection (#639)
* fix(tasks): avoid dropping completed task results during collection * chore(tasks): make `task_result_receiver` required * refactor(tasks): make `collect_completed_results` private
1 parent f6ebc7a commit be23334

3 files changed

Lines changed: 19 additions & 17 deletions

File tree

crates/rmcp-macros/src/task_handler.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
132132
use rmcp::task_manager::current_timestamp;
133133
let task_id = request.task_id.clone();
134134
let mut processor = (#processor).lock().await;
135-
processor.collect_completed_results();
136135

137136
// Check completed results first
138137
let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id);
@@ -200,7 +199,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
200199
// Scope the lock so we can await outside if needed
201200
{
202201
let mut processor = (#processor).lock().await;
203-
processor.collect_completed_results();
204202

205203
if let Some(task_result) = processor.take_completed_result(&task_id) {
206204
match task_result.result {
@@ -256,7 +254,6 @@ pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenS
256254
) -> Result<(), McpError> {
257255
let task_id = request.task_id;
258256
let mut processor = (#processor).lock().await;
259-
processor.collect_completed_results();
260257

261258
if processor.cancel_task(&task_id) {
262259
return Ok(());

crates/rmcp/src/task_manager.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub struct OperationProcessor {
8080
running_tasks: HashMap<String, RunningTask>,
8181
/// Completed results waiting to be collected
8282
completed_results: Vec<TaskResult>,
83-
task_result_receiver: Option<mpsc::UnboundedReceiver<TaskResult>>,
83+
task_result_receiver: mpsc::UnboundedReceiver<TaskResult>,
8484
task_result_sender: mpsc::UnboundedSender<TaskResult>,
8585
}
8686

@@ -138,7 +138,7 @@ impl OperationProcessor {
138138
Self {
139139
running_tasks: HashMap::new(),
140140
completed_results: Vec::new(),
141-
task_result_receiver: Some(task_result_receiver),
141+
task_result_receiver,
142142
task_result_sender,
143143
}
144144
}
@@ -195,18 +195,16 @@ impl OperationProcessor {
195195
}
196196

197197
/// Collect completed results from running tasks and remove them from the running tasks map.
198-
pub fn collect_completed_results(&mut self) -> Vec<TaskResult> {
199-
if let Some(receiver) = &mut self.task_result_receiver {
200-
while let Ok(result) = receiver.try_recv() {
201-
self.running_tasks.remove(&result.descriptor.operation_id);
202-
self.completed_results.push(result);
203-
}
198+
fn collect_completed_results(&mut self) {
199+
while let Ok(result) = self.task_result_receiver.try_recv() {
200+
self.running_tasks.remove(&result.descriptor.operation_id);
201+
self.completed_results.push(result);
204202
}
205-
std::mem::take(&mut self.completed_results)
206203
}
207204

208205
/// Check for tasks that have exceeded their timeout and handle them appropriately.
209206
pub fn check_timeouts(&mut self) {
207+
self.collect_completed_results();
210208
let now = std::time::Instant::now();
211209
let mut timed_out_tasks = Vec::new();
212210

@@ -231,7 +229,8 @@ impl OperationProcessor {
231229
}
232230

233231
/// Get the number of running tasks.
234-
pub fn running_task_count(&self) -> usize {
232+
pub fn running_task_count(&mut self) -> usize {
233+
self.collect_completed_results();
235234
self.running_tasks.len()
236235
}
237236

@@ -240,15 +239,19 @@ impl OperationProcessor {
240239
for (_, task) in self.running_tasks.drain() {
241240
task.task_handle.abort();
242241
}
242+
while self.task_result_receiver.try_recv().is_ok() {}
243243
self.completed_results.clear();
244244
}
245+
245246
/// List running task ids.
246-
pub fn list_running(&self) -> Vec<String> {
247+
pub fn list_running(&mut self) -> Vec<String> {
248+
self.collect_completed_results();
247249
self.running_tasks.keys().cloned().collect()
248250
}
249251

250-
/// Note: collectors should call collect_completed_results; this provides a snapshot of queued results.
251-
pub fn peek_completed(&self) -> &[TaskResult] {
252+
/// Returns a snapshot of completed task results.
253+
pub fn peek_completed(&mut self) -> &[TaskResult] {
254+
self.collect_completed_results();
252255
&self.completed_results
253256
}
254257

@@ -266,6 +269,7 @@ impl OperationProcessor {
266269

267270
/// Attempt to cancel a running task.
268271
pub fn cancel_task(&mut self, task_id: &str) -> bool {
272+
self.collect_completed_results();
269273
if let Some(task) = self.running_tasks.remove(task_id) {
270274
task.task_handle.abort();
271275
// Insert a cancelled result so callers can observe the terminal state.
@@ -281,6 +285,7 @@ impl OperationProcessor {
281285

282286
/// Retrieve a completed task result if available.
283287
pub fn take_completed_result(&mut self, task_id: &str) -> Option<TaskResult> {
288+
self.collect_completed_results();
284289
if let Some(position) = self
285290
.completed_results
286291
.iter()

crates/rmcp/tests/test_task.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async fn executes_enqueued_future() {
3636
.expect("submit operation");
3737

3838
tokio::time::sleep(Duration::from_millis(30)).await;
39-
let results = processor.collect_completed_results();
39+
let results = processor.peek_completed();
4040
assert_eq!(results.len(), 1);
4141
let payload = results[0]
4242
.result

0 commit comments

Comments
 (0)