Skip to content

Commit b67abea

Browse files
Skobeltsynclaude
andcommitted
feat(#1740): thread cumulative tokensUsed into SkillCompleted + Completed
Step 3.5: surface what executeAgentic already tracks onto the event surface that's been carrying tokensUsed: TokenUsage? = null placeholders. executeAgentic now returns AgenticResult(output, tokenUsage) instead of raw Any. cumulativeUsage builds up by summing promptTokens and completionTokens across all turns (TokenUsage.total is derived). executeAgentic became internal because AgenticResult is internal; only in-package callers (Agent.kt) use it. Agent.invokeSuspendForSession gains onSkillCompleted: (TokenUsage?) -> Unit callback with a default no-op. session() captures it into capturedUsage and threads it into both SkillCompleted and Completed events. Agent.invokeSuspend and Agent.invokeSuspendWithPromptOverride unwrap .output — preserves their OUT return contract byte-for-byte. For implementedBy skills the callback never fires; tokensUsed stays null. TDD red-first: two new tests in AgentSessionIntegrationTest. The single-turn case asserts SkillCompleted.tokensUsed equals the stub's turn-1 TokenUsage. The two-turn case (ToolCalls→Text with distinct usages per turn) asserts cumulative equals the field-wise sum. Full suite (root + KSP + no-reflect) green; live π test path unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 449e465 commit b67abea

4 files changed

Lines changed: 109 additions & 9 deletions

File tree

src/main/kotlin/agents_engine/core/Agent.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,20 @@ class Agent<IN, OUT>(
268268
internal suspend fun invokeSuspendForSession(
269269
input: IN,
270270
emitter: agents_engine.model.AgentEventEmitter? = null,
271+
onSkillCompleted: (agents_engine.model.TokenUsage?) -> Unit = { /* no-op */ },
271272
onSkillStarted: (String) -> Unit,
272273
): OUT {
273274
try {
274275
val skill = resolveSkill(input)
275276
skillChosenListener?.invoke(skill.name)
276277
onSkillStarted(skill.name)
277278
return if (skill.isAgentic) {
278-
castOut(executeAgentic(this, skill, input, emitter = emitter))
279+
val result = executeAgentic(this, skill, input, emitter = emitter)
280+
// #1740: surface cumulative usage on the way out. Non-agentic
281+
// skills don't go through executeAgentic, so onSkillCompleted
282+
// stays at its default null for the implementedBy path below.
283+
onSkillCompleted(result.tokenUsage)
284+
castOut(result.output)
279285
} else {
280286
castOut(executors[skill.name]!!(input))
281287
}
@@ -318,7 +324,7 @@ class Agent<IN, OUT>(
318324
val skill = resolveSkill(input)
319325
skillChosenListener?.invoke(skill.name)
320326
return if (skill.isAgentic) {
321-
castOut(executeAgentic(this, skill, input, effectivePrompt = promptOverride))
327+
castOut(executeAgentic(this, skill, input, effectivePrompt = promptOverride).output)
322328
} else {
323329
// Non-agentic skills don't read prompt — implementedBy lambdas
324330
// ignore the override. Same behavior as the legacy path.

src/main/kotlin/agents_engine/model/AgenticLoop.kt

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,20 @@ import kotlinx.coroutines.withContext
1313

1414
private const val MAX_ARGUMENT_REPAIR_STEPS = 8
1515

16+
/**
17+
* #1740 — return shape from [executeAgentic]. Carries the parsed output
18+
* alongside cumulative [TokenUsage] summed across all LLM turns of the
19+
* invocation. [tokenUsage] is null when the provider never reported
20+
* usage for any turn.
21+
*/
22+
internal data class AgenticResult(val output: Any, val tokenUsage: TokenUsage?)
23+
1624
/**
1725
* Runs the agentic loop for [skill] on [agent] with [input].
18-
* Returns the parsed output as [Any]; the caller casts it via the agent's castOut.
26+
* Returns the parsed output paired with cumulative token usage;
27+
* the caller casts the output via the agent's castOut.
1928
*/
20-
suspend fun <IN> executeAgentic(
29+
internal suspend fun <IN> executeAgentic(
2130
agent: Agent<IN, *>,
2231
skill: Skill<*, *>,
2332
input: IN,
@@ -37,7 +46,7 @@ suspend fun <IN> executeAgentic(
3746
* callers (`Agent.invoke`, `Agent.invokeSuspend`) pay no overhead.
3847
*/
3948
emitter: AgentEventEmitter? = null,
40-
): Any {
49+
): AgenticResult {
4150
val config = requireNotNull(agent.modelConfig) {
4251
"Agent '${agent.name}' has no model configured. Add a model { } block."
4352
}
@@ -123,6 +132,9 @@ suspend fun <IN> executeAgentic(
123132
var turns = 0
124133
var toolCalls = 0
125134
var totalTokens = 0
135+
// #1740: cumulative usage across all turns. Provider reports per-turn;
136+
// we sum prompt and completion independently (TokenUsage.total is derived).
137+
var cumulativeUsage: TokenUsage? = null
126138
var lastToolName: String? = null
127139
var consecutiveSameTool = 0
128140
val invocationStartNanos = System.nanoTime()
@@ -171,6 +183,13 @@ suspend fun <IN> executeAgentic(
171183
// even if it tips us over: the throw still surfaces the breach.
172184
response.tokenUsage?.let { usage ->
173185
totalTokens += usage.total
186+
// #1740: build cumulative TokenUsage for the event surface.
187+
cumulativeUsage = cumulativeUsage?.let { prev ->
188+
TokenUsage(
189+
promptTokens = prev.promptTokens + usage.promptTokens,
190+
completionTokens = prev.completionTokens + usage.completionTokens,
191+
)
192+
} ?: usage
174193
val cap = budget.maxTokens
175194
if (cap != null) {
176195
maybeFireThreshold(BudgetReason.TOKENS, totalTokens.toDouble() / cap)
@@ -185,9 +204,10 @@ suspend fun <IN> executeAgentic(
185204

186205
when (response) {
187206
is LlmResponse.Text -> {
188-
return skill.outputTransformer?.invoke(response.content)
207+
val parsed = skill.outputTransformer?.invoke(response.content)
189208
?: parseOutput(response.content, agent.outType)
190209
?: error("Could not parse LLM output as ${agent.outType.simpleName}: '${response.content}'")
210+
return AgenticResult(parsed, cumulativeUsage)
191211
}
192212
is LlmResponse.ToolCalls -> {
193213
messages.add(LlmMessage("assistant", "", response.calls))

src/main/kotlin/agents_engine/runtime/events/AgentSessionExtension.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ fun <IN, OUT> Agent<IN, OUT>.session(input: IN): AgentSession<OUT> {
4242
// Captured-on-the-stack: each session has its own holder, so
4343
// concurrent sessions can't race on a shared field.
4444
var capturedSkillName: String? = null
45+
// #1740: per-session usage capture from the agentic loop's cumulative
46+
// total. Stays null for implementedBy skills (no LLM round-trip).
47+
var capturedUsage: agents_engine.model.TokenUsage? = null
4548
// #1739: emitter forwards AgentEvents from inside the agentic loop
4649
// (Token, ToolCallStarted, ToolCallArgumentsDelta, ToolCallFinished)
4750
// into the same channel as the bracket events. trySend is non-
@@ -53,12 +56,16 @@ fun <IN, OUT> Agent<IN, OUT>.session(input: IN): AgentSession<OUT> {
5356
channel.trySend(event as AgentEvent<OUT>)
5457
}
5558
try {
56-
val output = agent.invokeSuspendForSession(input, emitter = streamingEmitter) { skillName ->
59+
val output = agent.invokeSuspendForSession(
60+
input,
61+
emitter = streamingEmitter,
62+
onSkillCompleted = { usage -> capturedUsage = usage },
63+
) { skillName ->
5764
capturedSkillName = skillName
5865
channel.trySend(AgentEvent.SkillStarted(agent.name, skillName))
5966
}
60-
channel.trySend(AgentEvent.SkillCompleted(agent.name, capturedSkillName ?: "?", null))
61-
channel.trySend(AgentEvent.Completed(agent.name, output, null))
67+
channel.trySend(AgentEvent.SkillCompleted(agent.name, capturedSkillName ?: "?", capturedUsage))
68+
channel.trySend(AgentEvent.Completed(agent.name, output, capturedUsage))
6269
channel.close()
6370
result.complete(output)
6471
} catch (t: Throwable) {

src/test/kotlin/agents_engine/runtime/events/AgentSessionIntegrationTest.kt

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,73 @@ class AgentSessionIntegrationTest {
202202
assertTrue(finishedIdx < tokenIdx, "ToolCallFinished (from turn 1) must precede the final Token (from turn 2)")
203203
}
204204

205+
@Test
206+
fun `tokensUsed on SkillCompleted and Completed reflects single-turn stub usage`() = runTest {
207+
// #1740 — one-turn agentic stub with explicit TokenUsage.
208+
// Cumulative usage for a one-turn run equals that turn's usage.
209+
val usage = TokenUsage(promptTokens = 12, completionTokens = 5)
210+
val stub = ModelClient { _ -> LlmResponse.Text("done", usage) }
211+
212+
val agentic = agent<String, String>("tu") {
213+
prompt("Single-turn stub.")
214+
model { ollama("llama3"); client = stub }
215+
skills { skill<String, String>("respond", "Echoes via the model") { tools() } }
216+
}
217+
218+
val events = agentic.session("kick").events.toList()
219+
220+
val skillCompleted = events.filterIsInstance<AgentEvent.SkillCompleted>().single()
221+
val completed = events.filterIsInstance<AgentEvent.Completed<String>>().single()
222+
assertEquals(usage, skillCompleted.tokensUsed, "SkillCompleted.tokensUsed must reflect the stub's TokenUsage")
223+
assertEquals(usage, completed.tokensUsed, "Completed.tokensUsed must reflect the stub's TokenUsage")
224+
}
225+
226+
@Test
227+
fun `tokensUsed sums prompt and completion tokens across multiple turns`() = runTest {
228+
// #1740 — two-turn stub (ToolCalls then Text). Each turn reports
229+
// distinct usage. Cumulative on SkillCompleted/Completed must sum
230+
// prompt and completion tokens independently across turns.
231+
val turn1Usage = TokenUsage(promptTokens = 100, completionTokens = 20)
232+
val turn2Usage = TokenUsage(promptTokens = 150, completionTokens = 35)
233+
val turn1 = LlmResponse.ToolCalls(
234+
listOf(
235+
ToolCall(
236+
name = "ping",
237+
arguments = emptyMap(),
238+
rawArguments = "{}",
239+
callId = "call-multi-turn",
240+
),
241+
),
242+
turn1Usage,
243+
)
244+
val turn2 = LlmResponse.Text("pong", turn2Usage)
245+
val responses = ArrayDeque<LlmResponse>().apply { add(turn1); add(turn2) }
246+
val stub = ModelClient { _ -> responses.removeFirst() }
247+
248+
val agentic = agent<String, String>("multi") {
249+
prompt("Two-turn stub.")
250+
model { ollama("llama3"); client = stub }
251+
tools { tool("ping", "Returns pong") { _: Map<String, Any?> -> "pong" } }
252+
skills {
253+
skill<String, String>("respond", "Two-turn skill") {
254+
@Suppress("DEPRECATION")
255+
tools("ping")
256+
}
257+
}
258+
}
259+
260+
val events = agentic.session("kick").events.toList()
261+
262+
val expected = TokenUsage(
263+
promptTokens = turn1Usage.promptTokens + turn2Usage.promptTokens,
264+
completionTokens = turn1Usage.completionTokens + turn2Usage.completionTokens,
265+
)
266+
val skillCompleted = events.filterIsInstance<AgentEvent.SkillCompleted>().single()
267+
val completed = events.filterIsInstance<AgentEvent.Completed<String>>().single()
268+
assertEquals(expected, skillCompleted.tokensUsed, "SkillCompleted.tokensUsed must sum prompt and completion tokens across turns")
269+
assertEquals(expected, completed.tokensUsed, "Completed.tokensUsed must sum prompt and completion tokens across turns")
270+
}
271+
205272
// Tiny generic 4-tuple — assertable via destructuring in the concurrent test.
206273
private data class Quad<A, B, C, D>(val a: A, val b: B, val c: C, val d: D)
207274
}

0 commit comments

Comments
 (0)