Skip to content

Commit 9fc0c17

Browse files
fabmaxMax Thiele
andauthored
Fabmax/fix prompt cancellation (#93)
* Handle prompt cancellation on AgentLayer instead of Protocol layer to return correct PromptResponse on cancellation (#91) * Add tests for basic agent functions --------- Co-authored-by: Max Thiele <maximilian.thiele@jetbrains.com>
1 parent 9abd3ff commit 9fc0c17

4 files changed

Lines changed: 351 additions & 26 deletions

File tree

acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@ import kotlinx.atomicfu.atomic
1313
import kotlinx.atomicfu.update
1414
import kotlinx.collections.immutable.persistentMapOf
1515
import kotlinx.coroutines.CompletableDeferred
16+
import kotlinx.coroutines.CoroutineStart
1617
import kotlinx.coroutines.ExperimentalCoroutinesApi
18+
import kotlinx.coroutines.Job
19+
import kotlinx.coroutines.coroutineScope
1720
import kotlinx.coroutines.currentCoroutineContext
21+
import kotlinx.coroutines.launch
1822
import kotlinx.coroutines.withContext
1923
import kotlinx.serialization.json.JsonElement
2024
import kotlin.coroutines.AbstractCoroutineContextElement
2125
import kotlin.coroutines.CoroutineContext
22-
import kotlin.coroutines.cancellation.CancellationException
2326
import kotlin.math.min
2427
import kotlin.uuid.ExperimentalUuidApi
2528

@@ -58,55 +61,60 @@ public class Agent(
5861
val clientOperations: ClientSessionOperations,
5962
protocol: Protocol
6063
) : BaseSessionWrapper(agent, protocol) {
61-
private class PromptSession(val currentRequestId: RequestId)
64+
private class PromptSession(val currentRequestId: RequestId, val promptJob: Job)
6265
private val _activePrompt = atomic<PromptSession?>(null)
6366

6467
suspend fun prompt(content: List<ContentBlock>, _meta: JsonElement? = null): PromptResponse {
6568
val currentRpcRequest = currentCoroutineContext().jsonRpcRequest
66-
if (!_activePrompt.compareAndSet(null, PromptSession(currentRpcRequest.id))) error("There is already active prompt execution")
67-
try {
68-
var response: PromptResponse? = null
69-
70-
agentSession.prompt(content, _meta).collect { event ->
71-
when (event) {
72-
is Event.PromptResponseEvent -> {
73-
if (response != null) {
74-
logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" }
69+
var response: PromptResponse? = null
70+
return coroutineScope {
71+
try {
72+
val promptJob = launch(start = CoroutineStart.LAZY) {
73+
agentSession.prompt(content, _meta).collect { event ->
74+
when (event) {
75+
is Event.PromptResponseEvent -> {
76+
if (response != null) {
77+
logger.error { "Received repeated prompt response: ${event.response} (previous: $response). The last is used" }
78+
}
79+
response = event.response
80+
}
81+
82+
is Event.SessionUpdateEvent -> {
83+
clientOperations.notify(event.update, _meta)
84+
}
7585
}
76-
response = event.response
7786
}
87+
}
7888

79-
is Event.SessionUpdateEvent -> {
80-
clientOperations.notify(event.update, _meta)
81-
}
89+
val promptSession = PromptSession(currentRpcRequest.id, promptJob)
90+
if (!_activePrompt.compareAndSet(null, promptSession)) {
91+
error("There is already active prompt execution")
8292
}
93+
promptJob.join()
94+
response ?: PromptResponse(
95+
stopReason = if (promptJob.isCancelled) StopReason.CANCELLED else StopReason.END_TURN
96+
)
97+
} finally {
98+
_activePrompt.getAndSet(null)
8399
}
84-
85-
return response ?: PromptResponse(StopReason.END_TURN)
86-
}
87-
catch (ce: CancellationException) {
88-
logger.trace(ce) { "Prompt job cancelled" }
89-
return PromptResponse(StopReason.CANCELLED)
90-
}
91-
finally {
92-
_activePrompt.getAndSet(null)
93100
}
94101
}
95102

96103
suspend fun cancel() {
97-
// TODO do we need it while the cancellation can be handled by coroutine mechanism (catching CE inside `prompt`)
104+
// notify AgentSession about upcoming cancellation, this way implementations can gracefully stop ongoing requests
98105
agentSession.cancel()
99106

100107
val activePrompt = _activePrompt.getAndSet(null)
101108
if (activePrompt != null) {
109+
logger.trace { "Cancelling prompt" }
102110
// we expect that all nested outgoing jobs will be cancelled automatically due to structured concurrency
103111
// -> prompt task
104112
// <- [request] read file
105113
// -> [response] read file
106114
// <- [request] permissions
107115
// |suspended|
108116
// cancelling the whole prompt should cancel all nested outgoing requests. These requests on CE will propagate cancellation to the other side
109-
protocol.cancelPendingIncomingRequest(activePrompt.currentRequestId)
117+
activePrompt.promptJob.cancel()
110118
}
111119
}
112120
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package com.agentclientprotocol.agent
2+
3+
import com.agentclientprotocol.annotations.UnstableApi
4+
import com.agentclientprotocol.model.*
5+
import kotlinx.coroutines.async
6+
import kotlinx.coroutines.delay
7+
import kotlin.test.Test
8+
import kotlin.test.assertEquals
9+
import kotlin.test.assertNotNull
10+
import kotlin.test.assertTrue
11+
import kotlin.time.Duration.Companion.seconds
12+
13+
@OptIn(UnstableApi::class)
14+
class AgentTest {
15+
16+
@Test
17+
fun `initialize agent`() {
18+
withTestAgent { testAgent ->
19+
val (response) = testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION))
20+
assertNotNull(response)
21+
assertTrue(testAgent.agentSupport.isInitialized)
22+
}
23+
}
24+
25+
@Test
26+
fun `create new session`() {
27+
withInitializedTestAgent { testAgent ->
28+
val (response) = testAgent.testNewSession(NewSessionRequest(cwd = ".", mcpServers = emptyList()))
29+
assertNotNull(response)
30+
assertTrue(response.sessionId in testAgent.agentSupport.createdSessions)
31+
}
32+
}
33+
34+
@Test
35+
fun `simple prompt turn`() {
36+
withTestAgentSession(promptHandler = echoPromptHandler) { testAgent, _ ->
37+
testAgent.simplePrompt("hello").let { (response, updates) ->
38+
assertEquals(StopReason.END_TURN, response.stopReason)
39+
assertEquals(1, updates.size)
40+
41+
val message = updates.filterIsInstance<SessionUpdate.AgentMessageChunk>()
42+
.map { (it.content as? ContentBlock.Text)?.text }
43+
.firstOrNull()
44+
assertEquals("hello", message)
45+
}
46+
47+
testAgent.simplePrompt("world").let { (response, updates) ->
48+
assertEquals(StopReason.END_TURN, response.stopReason)
49+
assertEquals(1, updates.size)
50+
51+
val message = updates.filterIsInstance<SessionUpdate.AgentMessageChunk>()
52+
.map { (it.content as? ContentBlock.Text)?.text }
53+
.firstOrNull()
54+
assertEquals("world", message)
55+
}
56+
}
57+
}
58+
59+
@Test
60+
fun `prompt cancellation`() {
61+
withTestAgentSession(promptHandler = delayEchoPromptHandler(2.seconds)) { testAgent, session ->
62+
val deferredResponse = async { testAgent.simplePrompt("hello").first }
63+
delay(1.seconds)
64+
testAgent.testCancel(CancelNotification(session.sessionId))
65+
66+
val response = deferredResponse.await()
67+
assertEquals(StopReason.CANCELLED, response.stopReason)
68+
}
69+
}
70+
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
@file:OptIn(UnstableApi::class)
2+
3+
package com.agentclientprotocol.agent
4+
5+
import com.agentclientprotocol.annotations.UnstableApi
6+
import com.agentclientprotocol.client.ClientInfo
7+
import com.agentclientprotocol.common.Event
8+
import com.agentclientprotocol.common.SessionCreationParameters
9+
import com.agentclientprotocol.model.AcpMethod
10+
import com.agentclientprotocol.model.AcpNotification
11+
import com.agentclientprotocol.model.AcpRequest
12+
import com.agentclientprotocol.model.AcpResponse
13+
import com.agentclientprotocol.model.CancelNotification
14+
import com.agentclientprotocol.model.ContentBlock
15+
import com.agentclientprotocol.model.InitializeRequest
16+
import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION
17+
import com.agentclientprotocol.model.McpServer
18+
import com.agentclientprotocol.model.NewSessionRequest
19+
import com.agentclientprotocol.model.PromptRequest
20+
import com.agentclientprotocol.model.PromptResponse
21+
import com.agentclientprotocol.model.SessionId
22+
import com.agentclientprotocol.model.SessionUpdate
23+
import com.agentclientprotocol.model.StopReason
24+
import com.agentclientprotocol.protocol.Protocol
25+
import com.agentclientprotocol.rpc.ACPJson
26+
import com.agentclientprotocol.rpc.JsonRpcNotification
27+
import com.agentclientprotocol.rpc.JsonRpcResponse
28+
import kotlinx.atomicfu.atomic
29+
import kotlinx.coroutines.CoroutineScope
30+
import kotlinx.coroutines.delay
31+
import kotlinx.coroutines.flow.Flow
32+
import kotlinx.coroutines.flow.FlowCollector
33+
import kotlinx.coroutines.flow.flow
34+
import kotlinx.coroutines.runBlocking
35+
import kotlinx.serialization.json.JsonElement
36+
import kotlin.time.Duration
37+
import kotlin.time.Duration.Companion.milliseconds
38+
import kotlin.time.Duration.Companion.seconds
39+
40+
class TestAgent(val agent: Agent, val agentSupport: TestAgentSupport, val transport: TestTransport) {
41+
suspend fun <TRequest : AcpRequest, TResponse : AcpResponse> testRequest(
42+
method: AcpMethod.AcpRequestResponseMethod<TRequest, TResponse>,
43+
request: TRequest
44+
): Pair<TResponse?, List<JsonRpcNotification>> {
45+
val received = transport.fireTestRequest(
46+
methodName = method.methodName,
47+
params = ACPJson.encodeToJsonElement(method.requestSerializer, request)
48+
)
49+
val response = (received.lastOrNull() as? JsonRpcResponse)?.result?.let {
50+
ACPJson.decodeFromJsonElement(method.responseSerializer, it)
51+
}
52+
val notifications = received.filterIsInstance<JsonRpcNotification>()
53+
return response to notifications
54+
}
55+
56+
fun <TNotification : AcpNotification> testNotification(
57+
method: AcpMethod.AcpNotificationMethod<TNotification>,
58+
notification: TNotification
59+
) {
60+
transport.fireTestNotification(method.methodName, ACPJson.encodeToJsonElement(method.serializer, notification))
61+
}
62+
63+
fun close() {
64+
agent.protocol.close()
65+
}
66+
67+
suspend fun testInitialize(request: InitializeRequest) = testRequest(AcpMethod.AgentMethods.Initialize, request)
68+
suspend fun testNewSession(request: NewSessionRequest) = testRequest(AcpMethod.AgentMethods.SessionNew, request)
69+
suspend fun testPrompt(request: PromptRequest) = testRequest(AcpMethod.AgentMethods.SessionPrompt, request)
70+
71+
fun testCancel(notification: CancelNotification) = testNotification(AcpMethod.AgentMethods.SessionCancel, notification)
72+
}
73+
74+
suspend fun TestAgent.simplePrompt(prompt: String): Pair<PromptResponse, List<SessionUpdate>> {
75+
val session = agentSupport.createdSessions.values.single()
76+
val (resp, notifications) = testPrompt(PromptRequest(session.sessionId, listOf(ContentBlock.Text(prompt))))
77+
checkNotNull(resp)
78+
79+
return resp to notifications
80+
.filter { it.method == AcpMethod.ClientMethods.SessionUpdate.methodName }
81+
.mapNotNull { it.params }
82+
.map { ACPJson.decodeFromJsonElement(AcpMethod.ClientMethods.SessionUpdate.serializer, it).update }
83+
}
84+
85+
class TestAgentSupport(val promptHandler: PromptHandler) : AgentSupport {
86+
var isInitialized = false
87+
val createdSessions = mutableMapOf<SessionId, TestAgentSession>()
88+
89+
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
90+
isInitialized = true
91+
return AgentInfo()
92+
}
93+
94+
override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
95+
val sessionId = SessionId("test-agent-session-${sessionId.incrementAndGet()}")
96+
val session = TestAgentSession(sessionId, promptHandler)
97+
createdSessions[sessionId] = session
98+
return session
99+
}
100+
101+
companion object {
102+
private val sessionId = atomic(0)
103+
}
104+
}
105+
106+
typealias PromptHandler = suspend FlowCollector<Event>.(List<ContentBlock>) -> Unit
107+
108+
class TestAgentSession(
109+
override val sessionId: SessionId,
110+
val promptHandler: PromptHandler
111+
) : AgentSession {
112+
override suspend fun prompt(content: List<ContentBlock>, _meta: JsonElement?): Flow<Event> = flow {
113+
promptHandler(content)
114+
}
115+
}
116+
117+
fun withTestAgent(
118+
timeout: Duration = 5.seconds,
119+
promptHandler: PromptHandler = echoPromptHandler,
120+
block: suspend CoroutineScope.(TestAgent) -> Unit
121+
) = runBlocking {
122+
val transport = TestTransport(timeout)
123+
val protocol = Protocol(this, transport)
124+
val agentSupport = TestAgentSupport(promptHandler)
125+
val agent = Agent(protocol, agentSupport)
126+
protocol.start()
127+
128+
// wait a little after protocol start, if messages get sent right away they can get lost
129+
delay(100.milliseconds)
130+
131+
val testAgent = TestAgent(agent, agentSupport, transport)
132+
block(testAgent)
133+
testAgent.close()
134+
}
135+
136+
fun withInitializedTestAgent(
137+
timeout: Duration = 5.seconds,
138+
promptHandler: PromptHandler = echoPromptHandler,
139+
block: suspend CoroutineScope.(TestAgent) -> Unit
140+
) = withTestAgent(
141+
timeout = timeout,
142+
promptHandler = promptHandler,
143+
) { testAgent ->
144+
testAgent.testInitialize(InitializeRequest(LATEST_PROTOCOL_VERSION))
145+
check(testAgent.agentSupport.isInitialized)
146+
block(testAgent)
147+
}
148+
149+
fun withTestAgentSession(
150+
timeout: Duration = 5.seconds,
151+
promptHandler: PromptHandler = echoPromptHandler,
152+
cwd: String = ".",
153+
mcpServers: List<McpServer> = emptyList(),
154+
block: suspend CoroutineScope.(TestAgent, TestAgentSession) -> Unit
155+
) = withInitializedTestAgent(
156+
timeout = timeout,
157+
promptHandler = promptHandler,
158+
) { testAgent ->
159+
val (newSessionResponse) = testAgent.testNewSession(NewSessionRequest(cwd, mcpServers))
160+
checkNotNull(newSessionResponse)
161+
val session = testAgent.agentSupport.createdSessions[newSessionResponse.sessionId]
162+
checkNotNull(session)
163+
block(testAgent, session)
164+
}
165+
166+
val echoPromptHandler: PromptHandler = { prompt ->
167+
prompt.filterIsInstance<ContentBlock.Text>().forEach {
168+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it)))
169+
}
170+
emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN)))
171+
}
172+
173+
fun delayEchoPromptHandler(delay: Duration): PromptHandler = { prompt ->
174+
delay(delay)
175+
prompt.filterIsInstance<ContentBlock.Text>().forEach {
176+
emit(Event.SessionUpdateEvent(SessionUpdate.AgentMessageChunk(it)))
177+
}
178+
emit(Event.PromptResponseEvent(PromptResponse(StopReason.END_TURN)))
179+
}

0 commit comments

Comments
 (0)