Skip to content

Commit 1fca9dc

Browse files
committed
feat(unstable): Add logout support
1 parent 06ca772 commit 1fca9dc

3 files changed

Lines changed: 103 additions & 0 deletions

File tree

src/agent-client-protocol/src/agent.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use agent_client_protocol_schema::{
1111
use agent_client_protocol_schema::{CloseSessionRequest, CloseSessionResponse};
1212
#[cfg(feature = "unstable_session_fork")]
1313
use agent_client_protocol_schema::{ForkSessionRequest, ForkSessionResponse};
14+
#[cfg(feature = "unstable_logout")]
15+
use agent_client_protocol_schema::{LogoutRequest, LogoutResponse};
1416
#[cfg(feature = "unstable_session_resume")]
1517
use agent_client_protocol_schema::{ResumeSessionRequest, ResumeSessionResponse};
1618
#[cfg(feature = "unstable_session_model")]
@@ -46,6 +48,21 @@ pub trait Agent {
4648
/// See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)
4749
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse>;
4850

51+
/// **UNSTABLE**
52+
///
53+
/// This capability is not part of the spec yet, and may be removed or changed at any point.
54+
///
55+
/// Logs out of the current authenticated state.
56+
///
57+
/// After a successful logout, all new sessions will require authentication.
58+
/// There is no guarantee about the behavior of already running sessions.
59+
///
60+
/// Only available if the Agent supports the `auth.logout` capability.
61+
#[cfg(feature = "unstable_logout")]
62+
async fn logout(&self, _args: LogoutRequest) -> Result<LogoutResponse> {
63+
Err(Error::method_not_found())
64+
}
65+
4966
/// Creates a new conversation session with the agent.
5067
///
5168
/// Sessions represent independent conversation contexts with their own history and state.
@@ -229,6 +246,10 @@ impl<T: Agent> Agent for Rc<T> {
229246
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
230247
self.as_ref().authenticate(args).await
231248
}
249+
#[cfg(feature = "unstable_logout")]
250+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
251+
self.as_ref().logout(args).await
252+
}
232253
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
233254
self.as_ref().new_session(args).await
234255
}
@@ -292,6 +313,10 @@ impl<T: Agent> Agent for Arc<T> {
292313
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
293314
self.as_ref().authenticate(args).await
294315
}
316+
#[cfg(feature = "unstable_logout")]
317+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
318+
self.as_ref().logout(args).await
319+
}
295320
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
296321
self.as_ref().new_session(args).await
297322
}

src/agent-client-protocol/src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ impl Agent for ClientSideConnection {
9595
.map(Option::unwrap_or_default)
9696
}
9797

98+
#[cfg(feature = "unstable_logout")]
99+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
100+
self.conn
101+
.request::<Option<_>>(
102+
AGENT_METHOD_NAMES.logout,
103+
Some(ClientRequest::LogoutRequest(args)),
104+
)
105+
.await
106+
.map(Option::unwrap_or_default)
107+
}
108+
98109
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99110
self.conn
100111
.request(
@@ -554,6 +565,10 @@ impl Side for AgentSide {
554565
m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
555566
.map(ClientRequest::AuthenticateRequest)
556567
.map_err(Into::into),
568+
#[cfg(feature = "unstable_logout")]
569+
m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
570+
.map(ClientRequest::LogoutRequest)
571+
.map_err(Into::into),
557572
m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
558573
.map(ClientRequest::NewSessionRequest)
559574
.map_err(Into::into),
@@ -635,6 +650,11 @@ impl<T: Agent> MessageHandler<AgentSide> for T {
635650
let response = self.authenticate(args).await?;
636651
Ok(AgentResponse::AuthenticateResponse(response))
637652
}
653+
#[cfg(feature = "unstable_logout")]
654+
ClientRequest::LogoutRequest(args) => {
655+
let response = self.logout(args).await?;
656+
Ok(AgentResponse::LogoutResponse(response))
657+
}
638658
ClientRequest::NewSessionRequest(args) => {
639659
let response = self.new_session(args).await?;
640660
Ok(AgentResponse::NewSessionResponse(response))

src/agent-client-protocol/src/rpc_tests.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ struct TestAgent {
133133
sessions: Arc<Mutex<std::collections::HashMap<SessionId, std::path::PathBuf>>>,
134134
prompts_received: Arc<Mutex<Vec<PromptReceived>>>,
135135
cancellations_received: Arc<Mutex<Vec<SessionId>>>,
136+
#[cfg(feature = "unstable_logout")]
137+
logout_count: Arc<Mutex<u32>>,
136138
extension_notifications: Arc<Mutex<Vec<(String, ExtNotification)>>>,
137139
}
138140

@@ -144,6 +146,8 @@ impl TestAgent {
144146
sessions: Arc::new(Mutex::new(std::collections::HashMap::new())),
145147
prompts_received: Arc::new(Mutex::new(Vec::new())),
146148
cancellations_received: Arc::new(Mutex::new(Vec::new())),
149+
#[cfg(feature = "unstable_logout")]
150+
logout_count: Arc::new(Mutex::new(0)),
147151
extension_notifications: Arc::new(Mutex::new(Vec::new())),
148152
}
149153
}
@@ -153,13 +157,29 @@ impl TestAgent {
153157
impl Agent for TestAgent {
154158
async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse> {
155159
Ok(InitializeResponse::new(arguments.protocol_version)
160+
#[cfg(feature = "unstable_logout")]
161+
.agent_capabilities(
162+
AgentCapabilities::new().auth(
163+
agent_client_protocol_schema::AgentAuthCapabilities::new()
164+
.logout(agent_client_protocol_schema::LogoutCapabilities::new()),
165+
),
166+
)
156167
.agent_info(Implementation::new("test-agent", "0.0.0").title("Test Agent")))
157168
}
158169

159170
async fn authenticate(&self, _arguments: AuthenticateRequest) -> Result<AuthenticateResponse> {
160171
Ok(AuthenticateResponse::default())
161172
}
162173

174+
#[cfg(feature = "unstable_logout")]
175+
async fn logout(
176+
&self,
177+
_arguments: agent_client_protocol_schema::LogoutRequest,
178+
) -> Result<agent_client_protocol_schema::LogoutResponse> {
179+
*self.logout_count.lock().unwrap() += 1;
180+
Ok(agent_client_protocol_schema::LogoutResponse::default())
181+
}
182+
163183
async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse> {
164184
let session_id = SessionId::new("test-session-123");
165185
self.sessions
@@ -886,6 +906,44 @@ async fn test_session_info_update() {
886906
.await;
887907
}
888908

909+
#[cfg(feature = "unstable_logout")]
910+
#[tokio::test]
911+
async fn test_logout() {
912+
let local_set = tokio::task::LocalSet::new();
913+
local_set
914+
.run_until(async {
915+
let client = TestClient::new();
916+
let agent = TestAgent::new();
917+
918+
let (agent_conn, _client_conn) = create_connection_pair(&client, &agent);
919+
920+
let initialize_response =
921+
agent_conn
922+
.initialize(InitializeRequest::new(ProtocolVersion::LATEST).client_info(
923+
Implementation::new("test-client", "0.0.0").title("Test Client"),
924+
))
925+
.await
926+
.expect("initialize failed");
927+
928+
assert!(
929+
initialize_response.agent_capabilities.auth.logout.is_some(),
930+
"agent should advertise auth.logout capability"
931+
);
932+
933+
let response = agent_conn
934+
.logout(agent_client_protocol_schema::LogoutRequest::new())
935+
.await
936+
.expect("logout failed");
937+
938+
assert_eq!(
939+
response,
940+
agent_client_protocol_schema::LogoutResponse::default()
941+
);
942+
assert_eq!(*agent.logout_count.lock().unwrap(), 1);
943+
})
944+
.await;
945+
}
946+
889947
#[tokio::test]
890948
async fn test_set_session_config_option() {
891949
let local_set = tokio::task::LocalSet::new();

0 commit comments

Comments
 (0)