Skip to content

Commit a577949

Browse files
committed
Cleanup and add test
1 parent 554233e commit a577949

2 files changed

Lines changed: 126 additions & 17 deletions

File tree

src/agent-client-protocol-core/src/jsonrpc.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,11 +2349,17 @@ impl<Req: JsonRpcRequest, Notif: JsonRpcMessage> Dispatch<Req, Notif> {
23492349
/// * **Requests** – sends the error back to the caller via the [`Responder`].
23502350
/// * **Responses** – forwards the error to the waiting handler via the
23512351
/// [`ResponseRouter`].
2352-
/// * **Notifications** – there is no request ID to reply to, so the error
2352+
/// * **Notifications** – there is no request ID to reply to, and no
2353+
/// connection is available to send an error notification, so the error
23532354
/// is logged and swallowed.
23542355
///
23552356
/// Returns `Ok(Handled::Yes)` in all cases so the connection loop
23562357
/// continues.
2358+
///
2359+
/// **Prefer [`respond_with_error`](Self::respond_with_error)** when a
2360+
/// [`ConnectionTo`] is available — it can send an error notification for
2361+
/// malformed notifications, which is consistent with
2362+
/// [`TypeNotification`](crate::util::TypeNotification).
23572363
pub(crate) fn reject_parse_error(
23582364
self,
23592365
error: crate::Error,
@@ -2570,13 +2576,17 @@ impl Dispatch {
25702576
}
25712577
}
25722578
Dispatch::Response(result, cx) => {
2573-
let method = cx.method().to_string();
2574-
if Req::matches_method(&method) {
2575-
let typed_result = match result {
2576-
Ok(value) => match <Req::Response as JsonRpcResponse>::from_value(
2577-
&method,
2579+
if !Req::matches_method(cx.method()) {
2580+
tracing::trace!("method doesn't match");
2581+
return TypedDispatchOutcome::Unhandled(Dispatch::Response(result, cx));
2582+
}
2583+
let typed_result = match result {
2584+
Ok(value) => {
2585+
let parsed = <Req::Response as JsonRpcResponse>::from_value(
2586+
cx.method(),
25782587
value.clone(),
2579-
) {
2588+
);
2589+
match parsed {
25802590
Ok(parsed) => {
25812591
tracing::trace!(?parsed, "parse ok");
25822592
Ok(parsed)
@@ -2588,17 +2598,14 @@ impl Dispatch {
25882598
error: err,
25892599
};
25902600
}
2591-
},
2592-
Err(err) => {
2593-
tracing::trace!("error, passthrough");
2594-
Err(err)
25952601
}
2596-
};
2597-
TypedDispatchOutcome::Matched(Dispatch::Response(typed_result, cx.cast()))
2598-
} else {
2599-
tracing::trace!("method doesn't match");
2600-
TypedDispatchOutcome::Unhandled(Dispatch::Response(result, cx))
2601-
}
2602+
}
2603+
Err(err) => {
2604+
tracing::trace!("error, passthrough");
2605+
Err(err)
2606+
}
2607+
};
2608+
TypedDispatchOutcome::Matched(Dispatch::Response(typed_result, cx.cast()))
26022609
}
26032610
}
26042611
}

src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,108 @@ async fn test_bad_request_params_return_invalid_params_and_connection_stays_aliv
590590
.await;
591591
}
592592

593+
#[tokio::test(flavor = "current_thread")]
594+
async fn test_bad_notification_params_swallowed_and_connection_stays_alive() {
595+
use tokio::io::{AsyncWriteExt, BufReader};
596+
use tokio::task::LocalSet;
597+
598+
let local = LocalSet::new();
599+
600+
local
601+
.run_until(async {
602+
let (mut client_writer, server_reader) = tokio::io::duplex(2048);
603+
let (server_writer, client_reader) = tokio::io::duplex(2048);
604+
605+
let server_reader = server_reader.compat();
606+
let server_writer = server_writer.compat_write();
607+
608+
let server_transport =
609+
agent_client_protocol_core::ByteStreams::new(server_writer, server_reader);
610+
let server = UntypedRole
611+
.builder()
612+
.on_receive_notification(
613+
async |_notif: SimpleNotification,
614+
_connection: ConnectionTo<UntypedRole>| {
615+
// If we get here, the notification parsed successfully.
616+
Ok(())
617+
},
618+
agent_client_protocol_core::on_receive_notification!(),
619+
)
620+
.on_receive_request(
621+
async |request: SimpleRequest,
622+
responder: Responder<SimpleResponse>,
623+
_connection: ConnectionTo<UntypedRole>| {
624+
responder.respond(SimpleResponse {
625+
result: format!("echo: {}", request.message),
626+
})
627+
},
628+
agent_client_protocol_core::on_receive_request!(),
629+
);
630+
631+
tokio::task::spawn_local(async move {
632+
if let Err(err) = server.connect_to(server_transport).await {
633+
panic!("server should stay alive: {err:?}");
634+
}
635+
});
636+
637+
let mut client_reader = BufReader::new(client_reader);
638+
639+
// Send a notification with bad params (wrong field name).
640+
// Notifications have no "id", so the server sends an error
641+
// notification (id: null) and keeps the connection alive.
642+
client_writer
643+
.write_all(
644+
br#"{"jsonrpc":"2.0","method":"simple_notification","params":{"wrong_field":"hello"}}
645+
"#,
646+
)
647+
.await
648+
.unwrap();
649+
client_writer.flush().await.unwrap();
650+
651+
// The server sends an error notification (id: null) for the
652+
// malformed notification.
653+
let error_notification = read_jsonrpc_response_line(&mut client_reader).await;
654+
expect![[r#"
655+
{
656+
"error": {
657+
"code": -32602,
658+
"data": {
659+
"error": "missing field `message`",
660+
"json": {
661+
"wrong_field": "hello"
662+
},
663+
"phase": "deserialization"
664+
},
665+
"message": "Invalid params"
666+
},
667+
"jsonrpc": "2.0"
668+
}"#]]
669+
.assert_eq(&serde_json::to_string_pretty(&error_notification).unwrap());
670+
671+
// Now send a valid request to prove the connection is still alive.
672+
client_writer
673+
.write_all(
674+
br#"{"jsonrpc":"2.0","id":10,"method":"simple_method","params":{"message":"after bad notification"}}
675+
"#,
676+
)
677+
.await
678+
.unwrap();
679+
client_writer.flush().await.unwrap();
680+
681+
let ok_response = read_jsonrpc_response_line(&mut client_reader).await;
682+
expect![[r#"
683+
{
684+
"id": 10,
685+
"jsonrpc": "2.0",
686+
"result": {
687+
"result": "echo: after bad notification"
688+
}
689+
}"#]]
690+
.assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap());
691+
})
692+
.await;
693+
}
694+
593695
#[tokio::test(flavor = "current_thread")]
594696
async fn test_match_dispatch_from_if_message_invalid_params_keeps_connection_alive() {
595697
use tokio::io::{AsyncWriteExt, BufReader};

0 commit comments

Comments
 (0)