Skip to content

Commit a74de0f

Browse files
authored
feat(unstable): Add logout support (#84)
1 parent 06ca772 commit a74de0f

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

src/agent-client-protocol/src/agent.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use agent_client_protocol_schema::{
1111
use agent_client_protocol_schema::{CloseSessionRequest, CloseSessionResponse};
1212
#[cfg(feature = "unstable_session_fork")]
1313
use agent_client_protocol_schema::{ForkSessionRequest, ForkSessionResponse};
14+
#[cfg(feature = "unstable_logout")]
15+
use agent_client_protocol_schema::{LogoutRequest, LogoutResponse};
1416
#[cfg(feature = "unstable_session_resume")]
1517
use agent_client_protocol_schema::{ResumeSessionRequest, ResumeSessionResponse};
1618
#[cfg(feature = "unstable_session_model")]
@@ -46,6 +48,21 @@ pub trait Agent {
4648
/// See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization)
4749
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse>;
4850

51+
/// **UNSTABLE**
52+
///
53+
/// This capability is not part of the spec yet, and may be removed or changed at any point.
54+
///
55+
/// Logs out of the current authenticated state.
56+
///
57+
/// After a successful logout, all new sessions will require authentication.
58+
/// There is no guarantee about the behavior of already running sessions.
59+
///
60+
/// Only available if the Agent supports the `auth.logout` capability.
61+
#[cfg(feature = "unstable_logout")]
62+
async fn logout(&self, _args: LogoutRequest) -> Result<LogoutResponse> {
63+
Err(Error::method_not_found())
64+
}
65+
4966
/// Creates a new conversation session with the agent.
5067
///
5168
/// Sessions represent independent conversation contexts with their own history and state.
@@ -229,6 +246,10 @@ impl<T: Agent> Agent for Rc<T> {
229246
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
230247
self.as_ref().authenticate(args).await
231248
}
249+
#[cfg(feature = "unstable_logout")]
250+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
251+
self.as_ref().logout(args).await
252+
}
232253
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
233254
self.as_ref().new_session(args).await
234255
}
@@ -292,6 +313,10 @@ impl<T: Agent> Agent for Arc<T> {
292313
async fn authenticate(&self, args: AuthenticateRequest) -> Result<AuthenticateResponse> {
293314
self.as_ref().authenticate(args).await
294315
}
316+
#[cfg(feature = "unstable_logout")]
317+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
318+
self.as_ref().logout(args).await
319+
}
295320
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
296321
self.as_ref().new_session(args).await
297322
}

src/agent-client-protocol/src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ impl Agent for ClientSideConnection {
9595
.map(Option::unwrap_or_default)
9696
}
9797

98+
#[cfg(feature = "unstable_logout")]
99+
async fn logout(&self, args: LogoutRequest) -> Result<LogoutResponse> {
100+
self.conn
101+
.request::<Option<_>>(
102+
AGENT_METHOD_NAMES.logout,
103+
Some(ClientRequest::LogoutRequest(args)),
104+
)
105+
.await
106+
.map(Option::unwrap_or_default)
107+
}
108+
98109
async fn new_session(&self, args: NewSessionRequest) -> Result<NewSessionResponse> {
99110
self.conn
100111
.request(
@@ -554,6 +565,10 @@ impl Side for AgentSide {
554565
m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
555566
.map(ClientRequest::AuthenticateRequest)
556567
.map_err(Into::into),
568+
#[cfg(feature = "unstable_logout")]
569+
m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
570+
.map(ClientRequest::LogoutRequest)
571+
.map_err(Into::into),
557572
m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
558573
.map(ClientRequest::NewSessionRequest)
559574
.map_err(Into::into),
@@ -635,6 +650,11 @@ impl<T: Agent> MessageHandler<AgentSide> for T {
635650
let response = self.authenticate(args).await?;
636651
Ok(AgentResponse::AuthenticateResponse(response))
637652
}
653+
#[cfg(feature = "unstable_logout")]
654+
ClientRequest::LogoutRequest(args) => {
655+
let response = self.logout(args).await?;
656+
Ok(AgentResponse::LogoutResponse(response))
657+
}
638658
ClientRequest::NewSessionRequest(args) => {
639659
let response = self.new_session(args).await?;
640660
Ok(AgentResponse::NewSessionResponse(response))

src/agent-client-protocol/src/rpc_tests.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ struct TestAgent {
133133
sessions: Arc<Mutex<std::collections::HashMap<SessionId, std::path::PathBuf>>>,
134134
prompts_received: Arc<Mutex<Vec<PromptReceived>>>,
135135
cancellations_received: Arc<Mutex<Vec<SessionId>>>,
136+
#[cfg(feature = "unstable_logout")]
137+
logout_count: Arc<Mutex<u32>>,
136138
extension_notifications: Arc<Mutex<Vec<(String, ExtNotification)>>>,
137139
}
138140

@@ -144,6 +146,8 @@ impl TestAgent {
144146
sessions: Arc::new(Mutex::new(std::collections::HashMap::new())),
145147
prompts_received: Arc::new(Mutex::new(Vec::new())),
146148
cancellations_received: Arc::new(Mutex::new(Vec::new())),
149+
#[cfg(feature = "unstable_logout")]
150+
logout_count: Arc::new(Mutex::new(0)),
147151
extension_notifications: Arc::new(Mutex::new(Vec::new())),
148152
}
149153
}
@@ -153,13 +157,28 @@ impl TestAgent {
153157
impl Agent for TestAgent {
154158
async fn initialize(&self, arguments: InitializeRequest) -> Result<InitializeResponse> {
155159
Ok(InitializeResponse::new(arguments.protocol_version)
160+
.agent_capabilities(
161+
AgentCapabilities::new().auth(
162+
agent_client_protocol_schema::AgentAuthCapabilities::new()
163+
.logout(agent_client_protocol_schema::LogoutCapabilities::new()),
164+
),
165+
)
156166
.agent_info(Implementation::new("test-agent", "0.0.0").title("Test Agent")))
157167
}
158168

159169
async fn authenticate(&self, _arguments: AuthenticateRequest) -> Result<AuthenticateResponse> {
160170
Ok(AuthenticateResponse::default())
161171
}
162172

173+
#[cfg(feature = "unstable_logout")]
174+
async fn logout(
175+
&self,
176+
_arguments: agent_client_protocol_schema::LogoutRequest,
177+
) -> Result<agent_client_protocol_schema::LogoutResponse> {
178+
*self.logout_count.lock().unwrap() += 1;
179+
Ok(agent_client_protocol_schema::LogoutResponse::default())
180+
}
181+
163182
async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse> {
164183
let session_id = SessionId::new("test-session-123");
165184
self.sessions
@@ -886,6 +905,44 @@ async fn test_session_info_update() {
886905
.await;
887906
}
888907

908+
#[cfg(feature = "unstable_logout")]
909+
#[tokio::test]
910+
async fn test_logout() {
911+
let local_set = tokio::task::LocalSet::new();
912+
local_set
913+
.run_until(async {
914+
let client = TestClient::new();
915+
let agent = TestAgent::new();
916+
917+
let (agent_conn, _client_conn) = create_connection_pair(&client, &agent);
918+
919+
let initialize_response =
920+
agent_conn
921+
.initialize(InitializeRequest::new(ProtocolVersion::LATEST).client_info(
922+
Implementation::new("test-client", "0.0.0").title("Test Client"),
923+
))
924+
.await
925+
.expect("initialize failed");
926+
927+
assert!(
928+
initialize_response.agent_capabilities.auth.logout.is_some(),
929+
"agent should advertise auth.logout capability"
930+
);
931+
932+
let response = agent_conn
933+
.logout(agent_client_protocol_schema::LogoutRequest::new())
934+
.await
935+
.expect("logout failed");
936+
937+
assert_eq!(
938+
response,
939+
agent_client_protocol_schema::LogoutResponse::default()
940+
);
941+
assert_eq!(*agent.logout_count.lock().unwrap(), 1);
942+
})
943+
.await;
944+
}
945+
889946
#[tokio::test]
890947
async fn test_set_session_config_option() {
891948
let local_set = tokio::task::LocalSet::new();

0 commit comments

Comments
 (0)