Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,12 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
private Flowable<Event> run(
InvocationContext parentContext,
Function<InvocationContext, Flowable<Event>> runImplementation) {
Context otelContext = Context.current();
return Flowable.using(
() ->
Instrumentation.recordAgentInvocation(
createInvocationContext(parentContext), this, otelContext),
() -> {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseAgent#run — Capture the parent context inside deferred execution so invoke_agent is correctly parented under the active invocation span.

Context otelContext = Context.current();
return Instrumentation.recordAgentInvocation(
createInvocationContext(parentContext), this, otelContext);
},
agentInvocation -> {
InvocationContext invocationContext = agentInvocation.getCtx();
Flowable<Event> mainAndAfterEvents =
Expand Down
146 changes: 81 additions & 65 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

@r4inee r4inee Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseLlmFlow#callLlm — Create the call_llm span before beforeModelCallback, so beforeModelCallback, model execution, afterModelCallback, and model error callbacks share the same call_llm span.

Original file line number Diff line number Diff line change
Expand Up @@ -218,74 +218,87 @@ private Flowable<Event> callLlm(
Event eventForCallbackUsage) {
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();

return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
.toFlowable()
.concatMap(
llmResp ->
postprocess(
context,
eventForCallbackUsage,
llmRequestBuilder.build(),
llmResp,
spanContext))
.switchIfEmpty(
Flowable.defer(
() -> {
LlmAgent agent = (LlmAgent) context.agent();
BaseLlm llm =
agent.resolvedModel().model().isPresent()
? agent.resolvedModel().model().get()
: LlmRegistry.getLlm(agent.resolvedModel().modelName().get());
LlmRequest finalLlmRequest = llmRequestBuilder.build();

Span span =
Tracing.getTracer()
.spanBuilder("call_llm")
.setParent(spanContext)
.startSpan();
Context callLlmContext = spanContext.with(span);

Flowable<Event> flowable =
llm.generateContent(
finalLlmRequest,
context.runConfig().streamingMode() == StreamingMode.SSE)
.onErrorResumeNext(
exception ->
handleOnModelErrorCallback(
context,
llmRequestBuilder,
eventForCallbackUsage,
exception)
.switchIfEmpty(Single.error(exception))
.toFlowable())
.doOnError(
error -> {
span.setStatus(StatusCode.ERROR, error.getMessage());
span.recordException(error);
})
return Flowable.defer(
() -> {
Span span =
Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan();
Context callLlmContext = spanContext.with(span);

return Tracing.traceFlowable(
callLlmContext,
span,
() ->
handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
.toFlowable()
.concatMap(
llmResp ->
handleAfterModelCallback(context, llmResp, eventForCallbackUsage)
.toFlowable())
.flatMap(
llmResp ->
postprocess(
context,
eventForCallbackUsage,
finalLlmRequest,
llmRequestBuilder.build(),
llmResp,
callLlmContext)
.doOnSubscribe(
s ->
subscription ->
traceCallLlm(
span,
context,
eventForCallbackUsage.id(),
finalLlmRequest,
llmResp)));

return Tracing.traceFlowable(callLlmContext, span, () -> flowable);
}));
llmRequestBuilder.build(),
llmResp)))
.switchIfEmpty(
Flowable.defer(
() -> {
LlmAgent agent = (LlmAgent) context.agent();
BaseLlm llm =
agent.resolvedModel().model().isPresent()
? agent.resolvedModel().model().get()
: LlmRegistry.getLlm(
agent.resolvedModel().modelName().get());
LlmRequest finalLlmRequest = llmRequestBuilder.build();

return llm.generateContent(
finalLlmRequest,
context.runConfig().streamingMode()
== StreamingMode.SSE)
.onErrorResumeNext(
exception ->
handleOnModelErrorCallback(
context,
llmRequestBuilder,
eventForCallbackUsage,
exception)
.switchIfEmpty(Single.error(exception))
.toFlowable())
.doOnError(
error -> {
span.setStatus(StatusCode.ERROR, error.getMessage());
span.recordException(error);
})
.concatMap(
llmResp ->
handleAfterModelCallback(
context, llmResp, eventForCallbackUsage)
.toFlowable())
.flatMap(
llmResp ->
postprocess(
context,
eventForCallbackUsage,
finalLlmRequest,
llmResp,
callLlmContext)
.doOnSubscribe(
subscription ->
traceCallLlm(
span,
context,
eventForCallbackUsage.id(),
finalLlmRequest,
llmResp)));
})))
.compose(Tracing.withContext(spanContext));
});
}

/**
Expand Down Expand Up @@ -667,10 +680,12 @@ public void onError(Throwable e) {
"Agent not found: " + event.actions().transferToAgent().get());
}
Flowable<Event> nextAgentEvents =
nextAgent
.get()
.runLive(invocationContext)
.compose(Tracing.withContext(spanContext));
Flowable.defer(
() -> {
try (Scope scope = spanContext.makeCurrent()) {
return nextAgent.get().runLive(invocationContext);
}
});
events = Flowable.concat(events, nextAgentEvents);
}
return events;
Expand All @@ -693,11 +708,12 @@ public void onError(Throwable e) {
});

return Tracing.traceFlowable(
callLlmContext,
span,
() ->
receiveFlow.takeWhile(
event -> !event.actions().endInvocation().orElse(false)));
callLlmContext,
span,
() ->
receiveFlow.takeWhile(
event -> !event.actions().endInvocation().orElse(false)))
.compose(Tracing.withContext(spanContext));

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseLlmFlow#callLlm — Re-scope emissions back to the parent context after the LLM segment, preventing follow-up tool/agent/LLM work from inheriting the previous call_llm context.

}));
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/com/google/adk/runner/Runner.java

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Runner#runAsyncImpl — Capture Context.current() inside Flowable.defer(...) so the runner captures execution-time context, not stale assembly-time context.

Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,9 @@ protected Flowable<Event> runAsyncImpl(
Preconditions.checkNotNull(session, "session cannot be null");
Preconditions.checkNotNull(newMessage, "newMessage cannot be null");
Preconditions.checkNotNull(runConfig, "runConfig cannot be null");
Context capturedContext = Context.current();
return Flowable.defer(
() -> {
Context capturedContext = Context.current();
BaseAgent rootAgent = this.agent;
String invocationId = InvocationContext.newInvocationContextId();

Expand Down
12 changes: 8 additions & 4 deletions core/src/main/java/com/google/adk/telemetry/Tracing.java

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracing.TracerProvider — Start spans before subscribing to upstream streams, instead of in doOnSubscribe.

This ensures deferred RxJava sources see the tracing span as Span.current().

Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,8 @@ public Publisher<T> apply(Flowable<T> upstream) {
return Flowable.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
Flowable<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
lifecycle.start();
Flowable<T> pipeline = upstream;
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
Expand All @@ -556,7 +557,8 @@ public SingleSource<T> apply(Single<T> upstream) {
return Single.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
Single<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
lifecycle.start();
Single<T> pipeline = upstream;
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
Expand All @@ -569,7 +571,8 @@ public MaybeSource<T> apply(Maybe<T> upstream) {
return Maybe.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
Maybe<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
lifecycle.start();
Maybe<T> pipeline = upstream;
if (onSuccessConsumer != null) {
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
}
Expand All @@ -582,7 +585,8 @@ public CompletableSource apply(Completable upstream) {
return Completable.defer(
() -> {
TracingLifecycle lifecycle = new TracingLifecycle();
return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
lifecycle.start();
return upstream.doFinally(lifecycle::end);
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -215,6 +216,70 @@ public void testTraceTransformer() throws InterruptedException {
assertTrue(transformerSpanData.hasEnded());
}

@Test
public void testTraceTransformerStartsSpanBeforeSubscribingToDeferredUpstream()
throws InterruptedException {
Span parentSpan = tracer.spanBuilder("parent").startSpan();
AtomicReference<String> flowableSpanId = new AtomicReference<>();
AtomicReference<String> singleSpanId = new AtomicReference<>();
AtomicReference<String> maybeSpanId = new AtomicReference<>();
AtomicReference<String> completableSpanId = new AtomicReference<>();

try (Scope s = parentSpan.makeCurrent()) {
Flowable.defer(
() -> {
flowableSpanId.set(Span.current().getSpanContext().getSpanId());
return Flowable.just(1);
})
.compose(Tracing.trace("flowable-transformer"))
.test()
.await()
.assertComplete();

Single.defer(
() -> {
singleSpanId.set(Span.current().getSpanContext().getSpanId());
return Single.just(1);
})
.compose(Tracing.trace("single-transformer"))
.test()
.await()
.assertComplete();

Maybe.defer(
() -> {
maybeSpanId.set(Span.current().getSpanContext().getSpanId());
return Maybe.just(1);
})
.compose(Tracing.trace("maybe-transformer"))
.test()
.await()
.assertComplete();

Completable.defer(
() -> {
completableSpanId.set(Span.current().getSpanContext().getSpanId());
return Completable.complete();
})
.compose(Tracing.trace("completable-transformer"))
.test()
.await()
.assertComplete();
} finally {
parentSpan.end();
}

SpanData parentSpanData = findSpanByName("parent");
assertDeferredUpstreamSawTransformerSpan(
parentSpanData, findSpanByName("flowable-transformer"), flowableSpanId);
assertDeferredUpstreamSawTransformerSpan(
parentSpanData, findSpanByName("single-transformer"), singleSpanId);
assertDeferredUpstreamSawTransformerSpan(
parentSpanData, findSpanByName("maybe-transformer"), maybeSpanId);
assertDeferredUpstreamSawTransformerSpan(
parentSpanData, findSpanByName("completable-transformer"), completableSpanId);
}

@Test
public void testTraceAgentInvocation() {
Span span = tracer.spanBuilder("test").startSpan();
Expand Down Expand Up @@ -464,6 +529,38 @@ public void runnerRunLive_propagatesContext() throws InterruptedException {
assertParent(invocation, agentSpan);
}

@Test
public void testModelCallbacksObserveCallLlmSpan() throws InterruptedException {
TestLlm testLlm =
TestUtils.createTestLlm(
TestUtils.createLlmResponse(Content.fromParts(Part.fromText("response"))));
AtomicReference<String> beforeModelSpanId = new AtomicReference<>();
AtomicReference<String> afterModelSpanId = new AtomicReference<>();

LlmAgent agentWithCallbacks =
LlmAgent.builder()
.name("test_agent")
.description("description")
.model(testLlm)
.beforeModelCallback(
(callbackContext, llmRequest) -> {
beforeModelSpanId.set(Span.current().getSpanContext().getSpanId());
return Maybe.empty();
})
.afterModelCallback(
(callbackContext, llmResponse) -> {
afterModelSpanId.set(Span.current().getSpanContext().getSpanId());
return Maybe.empty();
})
.build();

runAgent(agentWithCallbacks);

SpanData callLlm = findSpanByName("call_llm");
assertEquals(callLlm.getSpanContext().getSpanId(), beforeModelSpanId.get());
assertEquals(callLlm.getSpanContext().getSpanId(), afterModelSpanId.get());
}

@Test
public void testAgentWithToolCallTraceHierarchy() throws InterruptedException {
// This test verifies the trace hierarchy created when an agent calls an LLM,
Expand Down Expand Up @@ -594,9 +691,9 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException {
assertParent(agentASpan, agentACallLlm1);
// ├── execute_tool transfer_to_agent
assertParent(agentACallLlm1, executeTool);
// └── invoke_agent AgentB
assertParent(agentACallLlm1, agentBSpan);
// └── call_llm 2
// └── invoke_agent AgentB
assertParent(agentASpan, agentBSpan);
// └── call_llm 2
assertParent(agentBSpan, agentBCallLlm);
}

Expand Down Expand Up @@ -645,6 +742,13 @@ private void assertParent(SpanData parent, SpanData child) {
assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId());
}

private void assertDeferredUpstreamSawTransformerSpan(
SpanData parent, SpanData transformer, AtomicReference<String> observedSpanId) {
assertParent(parent, transformer);
assertTrue(transformer.hasEnded());
assertEquals(transformer.getSpanContext().getSpanId(), observedSpanId.get());
}

/**
* Finds a span by name, polling multiple times.
*
Expand Down