Skip to content

Commit 8732a6d

Browse files
committed
#60 Add ability to schedule session post init activity
1 parent 34b8bd7 commit 8732a6d

7 files changed

Lines changed: 104 additions & 5 deletions

File tree

acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,4 +970,60 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
970970
assertEquals(1, sessions.size)
971971
}
972972

973+
@Test
974+
@OptIn(UnstableApi::class)
975+
fun `session ready notifies available commands`() = testWithProtocols { clientProtocol, agentProtocol ->
976+
val readyUpdate = CompletableDeferred<SessionUpdate>()
977+
val client = Client(protocol = clientProtocol)
978+
Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
979+
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
980+
return AgentInfo(clientInfo.protocolVersion)
981+
}
982+
983+
override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
984+
return object : AgentSession {
985+
override val sessionId: SessionId = SessionId("ready-session-id")
986+
987+
override suspend fun postInitialize() {
988+
currentCoroutineContext().client.notify(
989+
SessionUpdate.AvailableCommandsUpdate(
990+
listOf(AvailableCommand("help", "Show available commands", AvailableCommandInput.Unstructured("topic")))
991+
)
992+
)
993+
}
994+
995+
override suspend fun prompt(
996+
content: List<ContentBlock>,
997+
_meta: JsonElement?,
998+
): Flow<Event> = emptyFlow()
999+
}
1000+
}
1001+
})
1002+
1003+
client.initialize(ClientInfo(protocolVersion = 10))
1004+
client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ ->
1005+
object : ClientSessionOperations {
1006+
override suspend fun requestPermissions(
1007+
toolCall: SessionUpdate.ToolCallUpdate,
1008+
permissions: List<PermissionOption>,
1009+
_meta: JsonElement?,
1010+
): RequestPermissionResponse {
1011+
return RequestPermissionResponse(RequestPermissionOutcome.Cancelled)
1012+
}
1013+
1014+
override suspend fun notify(
1015+
notification: SessionUpdate,
1016+
_meta: JsonElement?,
1017+
) {
1018+
readyUpdate.complete(notification)
1019+
}
1020+
}
1021+
}
1022+
1023+
val update = withTimeout(1000) { readyUpdate.await() }
1024+
assertTrue(update is SessionUpdate.AvailableCommandsUpdate)
1025+
val command = (update as SessionUpdate.AvailableCommandsUpdate).availableCommands.single()
1026+
assertEquals("help", command.name)
1027+
}
1028+
9731029
}

acp/api/acp.api

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ public abstract interface class com/agentclientprotocol/agent/AgentSession {
4949
public fun getDefaultMode-kyZWFqk ()Ljava/lang/String;
5050
public fun getDefaultModel-GMZLII8 ()Ljava/lang/String;
5151
public abstract fun getSessionId-7EW-EgU ()Ljava/lang/String;
52+
public fun postInitialize (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
53+
public static synthetic fun postInitialize$suspendImpl (Lcom/agentclientprotocol/agent/AgentSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
5254
public abstract fun prompt (Ljava/util/List;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
5355
public static synthetic fun prompt$default (Lcom/agentclientprotocol/agent/AgentSession;Ljava/util/List;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
5456
public fun setConfigOption-p1da2xM (Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
@@ -66,6 +68,7 @@ public final class com/agentclientprotocol/agent/AgentSession$DefaultImpls {
6668
public static fun getConfigOptions (Lcom/agentclientprotocol/agent/AgentSession;)Ljava/util/List;
6769
public static fun getDefaultMode-kyZWFqk (Lcom/agentclientprotocol/agent/AgentSession;)Ljava/lang/String;
6870
public static fun getDefaultModel-GMZLII8 (Lcom/agentclientprotocol/agent/AgentSession;)Ljava/lang/String;
71+
public static fun postInitialize (Lcom/agentclientprotocol/agent/AgentSession;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
6972
public static synthetic fun prompt$default (Lcom/agentclientprotocol/agent/AgentSession;Ljava/util/List;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
7073
public static fun setConfigOption-p1da2xM (Lcom/agentclientprotocol/agent/AgentSession;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
7174
public static fun setMode-tFeI3nk (Lcom/agentclientprotocol/agent/AgentSession;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;

acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ public class Agent(
249249
RemoteClientSessionOperations(protocol, session.sessionId, clientInfo.capabilities),
250250
protocol
251251
)
252+
currentCoroutineContext().executeAfterCurrentRequest { sessionWrapper.executeWithSession { session.postInitialize() } }
252253

253254
_sessions.update {
254255
it.put(session.sessionId, sessionWrapper)

acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSession.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ import kotlinx.serialization.json.JsonElement
99
public interface AgentSession {
1010
public val sessionId: SessionId
1111

12+
/**
13+
* Executes after the session is created and send to the client. Can be used to send additional notifications like Commands.
14+
*
15+
* To access client operations use:
16+
* ```
17+
* currentCoroutineContext().client
18+
* ```
19+
*
20+
* This method shouldn't throw exceptions.
21+
*/
22+
public suspend fun postInitialize() {}
1223
/**
1324
* Sends a message to the agent for execution and waits for the whole turn to be completed.
1425
* During execution, the agent can send notifications or requests to the client.

acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.extensions.kt

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,28 @@ public operator fun <TNotification : AcpNotification> AcpMethod.AcpNotificationM
111111
return rpc.sendNotification(this, notification)
112112
}
113113

114-
internal class JsonRpcRequestContextElement(val jsonRpcRequest: JsonRpcRequest) : AbstractCoroutineContextElement(Key) {
114+
internal class RequestHolder(val jsonRpcRequest: JsonRpcRequest) {
115+
// probably make it thread safe
116+
internal val handlers = mutableListOf<suspend () -> Unit>()
117+
fun executeAfterCurrentRequest(block: suspend () -> Unit) {
118+
handlers.add(block)
119+
}
120+
}
121+
122+
internal class JsonRpcRequestContextElement(val requestHolder: RequestHolder) : AbstractCoroutineContextElement(Key) {
115123
object Key : CoroutineContext.Key<JsonRpcRequestContextElement>
116124
}
117125

126+
internal val CoroutineContext.requestHolder: RequestHolder
127+
get() = this[JsonRpcRequestContextElement.Key]?.requestHolder ?: error("There is no active incoming request in this context")
128+
118129
public val CoroutineContext.jsonRpcRequest: JsonRpcRequest
119-
get() = this[JsonRpcRequestContextElement.Key]?.jsonRpcRequest ?: error("No JsonRpcRequest found in context")
130+
get() = this.requestHolder.jsonRpcRequest
120131

121-
internal fun JsonRpcRequest.asContextElement() = JsonRpcRequestContextElement(this)
132+
133+
/**
134+
* Execute a block after the current request is processed and the response is sent back to the client.
135+
*/
136+
internal fun CoroutineContext.executeAfterCurrentRequest(block: suspend () -> Unit) {
137+
requestHolder.executeAfterCurrentRequest(block)
138+
}

acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,11 @@ public class Protocol(
372372
}
373373

374374
private suspend fun handleRequest(request: JsonRpcRequest) {
375+
val requestHolder = RequestHolder(request)
375376
val handler = requestHandlers.value[request.method]
376377
if (handler != null) {
377378
try {
378-
val result = withContext(request.asContextElement()) {
379+
val result = withContext(JsonRpcRequestContextElement(requestHolder)) {
379380
handler(request)
380381
}
381382
sendResponse(request.id, result, null)
@@ -415,6 +416,16 @@ public class Protocol(
415416
)
416417
)
417418
}
419+
for (handler in requestHolder.handlers) {
420+
runCatching { handler() }.onFailure { t ->
421+
if (t is CancellationException) {
422+
// ignore CE
423+
logger.trace(t) { "Request handler for '${request.method}' cancelled" }
424+
} else {
425+
logger.error(t) { "Error handling after request handlers for ${request.method}" }
426+
}
427+
}
428+
}
418429
} else {
419430
val error = JsonRpcError(
420431
code = JsonRpcErrorCode.METHOD_NOT_FOUND.code, message = "Method not supported: ${request.method}"

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ plugins {
77
private val buildNumber: String? = System.getenv("GITHUB_RUN_NUMBER")
88
private val isReleasePublication = System.getenv("RELEASE_PUBLICATION")?.toBoolean() ?: false
99

10-
private val baseVersion = "0.14.1"
10+
private val baseVersion = "0.15.1"
1111

1212
allprojects {
1313
group = "com.agentclientprotocol"

0 commit comments

Comments
 (0)