Skip to content

Commit 334fc81

Browse files
committed
Fix RxJava tracing context propagation
1 parent ec93f50 commit 334fc81

5 files changed

Lines changed: 170 additions & 77 deletions

File tree

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -323,11 +323,12 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
323323
private Flowable<Event> run(
324324
InvocationContext parentContext,
325325
Function<InvocationContext, Flowable<Event>> runImplementation) {
326-
Context otelContext = Context.current();
327326
return Flowable.using(
328-
() ->
329-
Instrumentation.recordAgentInvocation(
330-
createInvocationContext(parentContext), this, otelContext),
327+
() -> {
328+
Context otelContext = Context.current();
329+
return Instrumentation.recordAgentInvocation(
330+
createInvocationContext(parentContext), this, otelContext);
331+
},
331332
agentInvocation -> {
332333
InvocationContext invocationContext = agentInvocation.getCtx();
333334
Flowable<Event> mainAndAfterEvents =

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 81 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -218,74 +218,87 @@ private Flowable<Event> callLlm(
218218
Event eventForCallbackUsage) {
219219
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
220220

221-
return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
222-
.toFlowable()
223-
.concatMap(
224-
llmResp ->
225-
postprocess(
226-
context,
227-
eventForCallbackUsage,
228-
llmRequestBuilder.build(),
229-
llmResp,
230-
spanContext))
231-
.switchIfEmpty(
232-
Flowable.defer(
233-
() -> {
234-
LlmAgent agent = (LlmAgent) context.agent();
235-
BaseLlm llm =
236-
agent.resolvedModel().model().isPresent()
237-
? agent.resolvedModel().model().get()
238-
: LlmRegistry.getLlm(agent.resolvedModel().modelName().get());
239-
LlmRequest finalLlmRequest = llmRequestBuilder.build();
240-
241-
Span span =
242-
Tracing.getTracer()
243-
.spanBuilder("call_llm")
244-
.setParent(spanContext)
245-
.startSpan();
246-
Context callLlmContext = spanContext.with(span);
247-
248-
Flowable<Event> flowable =
249-
llm.generateContent(
250-
finalLlmRequest,
251-
context.runConfig().streamingMode() == StreamingMode.SSE)
252-
.onErrorResumeNext(
253-
exception ->
254-
handleOnModelErrorCallback(
255-
context,
256-
llmRequestBuilder,
257-
eventForCallbackUsage,
258-
exception)
259-
.switchIfEmpty(Single.error(exception))
260-
.toFlowable())
261-
.doOnError(
262-
error -> {
263-
span.setStatus(StatusCode.ERROR, error.getMessage());
264-
span.recordException(error);
265-
})
221+
return Flowable.defer(
222+
() -> {
223+
Span span =
224+
Tracing.getTracer().spanBuilder("call_llm").setParent(spanContext).startSpan();
225+
Context callLlmContext = spanContext.with(span);
226+
227+
return Tracing.traceFlowable(
228+
callLlmContext,
229+
span,
230+
() ->
231+
handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
232+
.toFlowable()
266233
.concatMap(
267-
llmResp ->
268-
handleAfterModelCallback(context, llmResp, eventForCallbackUsage)
269-
.toFlowable())
270-
.flatMap(
271234
llmResp ->
272235
postprocess(
273236
context,
274237
eventForCallbackUsage,
275-
finalLlmRequest,
238+
llmRequestBuilder.build(),
276239
llmResp,
277240
callLlmContext)
278241
.doOnSubscribe(
279-
s ->
242+
subscription ->
280243
traceCallLlm(
281244
span,
282245
context,
283246
eventForCallbackUsage.id(),
284-
finalLlmRequest,
285-
llmResp)));
286-
287-
return Tracing.traceFlowable(callLlmContext, span, () -> flowable);
288-
}));
247+
llmRequestBuilder.build(),
248+
llmResp)))
249+
.switchIfEmpty(
250+
Flowable.defer(
251+
() -> {
252+
LlmAgent agent = (LlmAgent) context.agent();
253+
BaseLlm llm =
254+
agent.resolvedModel().model().isPresent()
255+
? agent.resolvedModel().model().get()
256+
: LlmRegistry.getLlm(
257+
agent.resolvedModel().modelName().get());
258+
LlmRequest finalLlmRequest = llmRequestBuilder.build();
259+
260+
return llm.generateContent(
261+
finalLlmRequest,
262+
context.runConfig().streamingMode()
263+
== StreamingMode.SSE)
264+
.onErrorResumeNext(
265+
exception ->
266+
handleOnModelErrorCallback(
267+
context,
268+
llmRequestBuilder,
269+
eventForCallbackUsage,
270+
exception)
271+
.switchIfEmpty(Single.error(exception))
272+
.toFlowable())
273+
.doOnError(
274+
error -> {
275+
span.setStatus(StatusCode.ERROR, error.getMessage());
276+
span.recordException(error);
277+
})
278+
.concatMap(
279+
llmResp ->
280+
handleAfterModelCallback(
281+
context, llmResp, eventForCallbackUsage)
282+
.toFlowable())
283+
.flatMap(
284+
llmResp ->
285+
postprocess(
286+
context,
287+
eventForCallbackUsage,
288+
finalLlmRequest,
289+
llmResp,
290+
callLlmContext)
291+
.doOnSubscribe(
292+
subscription ->
293+
traceCallLlm(
294+
span,
295+
context,
296+
eventForCallbackUsage.id(),
297+
finalLlmRequest,
298+
llmResp)));
299+
})))
300+
.compose(Tracing.withContext(spanContext));
301+
});
289302
}
290303

291304
/**
@@ -667,10 +680,12 @@ public void onError(Throwable e) {
667680
"Agent not found: " + event.actions().transferToAgent().get());
668681
}
669682
Flowable<Event> nextAgentEvents =
670-
nextAgent
671-
.get()
672-
.runLive(invocationContext)
673-
.compose(Tracing.withContext(spanContext));
683+
Flowable.defer(
684+
() -> {
685+
try (Scope scope = spanContext.makeCurrent()) {
686+
return nextAgent.get().runLive(invocationContext);
687+
}
688+
});
674689
events = Flowable.concat(events, nextAgentEvents);
675690
}
676691
return events;
@@ -693,11 +708,12 @@ public void onError(Throwable e) {
693708
});
694709

695710
return Tracing.traceFlowable(
696-
callLlmContext,
697-
span,
698-
() ->
699-
receiveFlow.takeWhile(
700-
event -> !event.actions().endInvocation().orElse(false)));
711+
callLlmContext,
712+
span,
713+
() ->
714+
receiveFlow.takeWhile(
715+
event -> !event.actions().endInvocation().orElse(false)))
716+
.compose(Tracing.withContext(spanContext));
701717
}));
702718
}
703719

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -485,9 +485,9 @@ protected Flowable<Event> runAsyncImpl(
485485
Preconditions.checkNotNull(session, "session cannot be null");
486486
Preconditions.checkNotNull(newMessage, "newMessage cannot be null");
487487
Preconditions.checkNotNull(runConfig, "runConfig cannot be null");
488-
Context capturedContext = Context.current();
489488
return Flowable.defer(
490489
() -> {
490+
Context capturedContext = Context.current();
491491
BaseAgent rootAgent = this.agent;
492492
String invocationId = InvocationContext.newInvocationContextId();
493493

core/src/main/java/com/google/adk/telemetry/Tracing.java

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -543,7 +543,8 @@ public Publisher<T> apply(Flowable<T> upstream) {
543543
return Flowable.defer(
544544
() -> {
545545
TracingLifecycle lifecycle = new TracingLifecycle();
546-
Flowable<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
546+
lifecycle.start();
547+
Flowable<T> pipeline = upstream;
547548
if (onSuccessConsumer != null) {
548549
pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t));
549550
}
@@ -556,7 +557,8 @@ public SingleSource<T> apply(Single<T> upstream) {
556557
return Single.defer(
557558
() -> {
558559
TracingLifecycle lifecycle = new TracingLifecycle();
559-
Single<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
560+
lifecycle.start();
561+
Single<T> pipeline = upstream;
560562
if (onSuccessConsumer != null) {
561563
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
562564
}
@@ -569,7 +571,8 @@ public MaybeSource<T> apply(Maybe<T> upstream) {
569571
return Maybe.defer(
570572
() -> {
571573
TracingLifecycle lifecycle = new TracingLifecycle();
572-
Maybe<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
574+
lifecycle.start();
575+
Maybe<T> pipeline = upstream;
573576
if (onSuccessConsumer != null) {
574577
pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
575578
}
@@ -582,7 +585,8 @@ public CompletableSource apply(Completable upstream) {
582585
return Completable.defer(
583586
() -> {
584587
TracingLifecycle lifecycle = new TracingLifecycle();
585-
return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
588+
lifecycle.start();
589+
return upstream.doFinally(lifecycle::end);
586590
});
587591
}
588592
}

core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java

Lines changed: 75 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
import java.util.List;
6363
import java.util.Map;
6464
import java.util.Optional;
65+
import java.util.concurrent.atomic.AtomicReference;
6566
import org.junit.After;
6667
import org.junit.Before;
6768
import org.junit.Rule;
@@ -215,6 +216,70 @@ public void testTraceTransformer() throws InterruptedException {
215216
assertTrue(transformerSpanData.hasEnded());
216217
}
217218

219+
@Test
220+
public void testTraceTransformerStartsSpanBeforeSubscribingToDeferredUpstream()
221+
throws InterruptedException {
222+
Span parentSpan = tracer.spanBuilder("parent").startSpan();
223+
AtomicReference<String> flowableSpanId = new AtomicReference<>();
224+
AtomicReference<String> singleSpanId = new AtomicReference<>();
225+
AtomicReference<String> maybeSpanId = new AtomicReference<>();
226+
AtomicReference<String> completableSpanId = new AtomicReference<>();
227+
228+
try (Scope s = parentSpan.makeCurrent()) {
229+
Flowable.defer(
230+
() -> {
231+
flowableSpanId.set(Span.current().getSpanContext().getSpanId());
232+
return Flowable.just(1);
233+
})
234+
.compose(Tracing.trace("flowable-transformer"))
235+
.test()
236+
.await()
237+
.assertComplete();
238+
239+
Single.defer(
240+
() -> {
241+
singleSpanId.set(Span.current().getSpanContext().getSpanId());
242+
return Single.just(1);
243+
})
244+
.compose(Tracing.trace("single-transformer"))
245+
.test()
246+
.await()
247+
.assertComplete();
248+
249+
Maybe.defer(
250+
() -> {
251+
maybeSpanId.set(Span.current().getSpanContext().getSpanId());
252+
return Maybe.just(1);
253+
})
254+
.compose(Tracing.trace("maybe-transformer"))
255+
.test()
256+
.await()
257+
.assertComplete();
258+
259+
Completable.defer(
260+
() -> {
261+
completableSpanId.set(Span.current().getSpanContext().getSpanId());
262+
return Completable.complete();
263+
})
264+
.compose(Tracing.trace("completable-transformer"))
265+
.test()
266+
.await()
267+
.assertComplete();
268+
} finally {
269+
parentSpan.end();
270+
}
271+
272+
SpanData parentSpanData = findSpanByName("parent");
273+
assertDeferredUpstreamSawTransformerSpan(
274+
parentSpanData, findSpanByName("flowable-transformer"), flowableSpanId);
275+
assertDeferredUpstreamSawTransformerSpan(
276+
parentSpanData, findSpanByName("single-transformer"), singleSpanId);
277+
assertDeferredUpstreamSawTransformerSpan(
278+
parentSpanData, findSpanByName("maybe-transformer"), maybeSpanId);
279+
assertDeferredUpstreamSawTransformerSpan(
280+
parentSpanData, findSpanByName("completable-transformer"), completableSpanId);
281+
}
282+
218283
@Test
219284
public void testTraceAgentInvocation() {
220285
Span span = tracer.spanBuilder("test").startSpan();
@@ -594,9 +659,9 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException {
594659
assertParent(agentASpan, agentACallLlm1);
595660
// ├── execute_tool transfer_to_agent
596661
assertParent(agentACallLlm1, executeTool);
597-
// └── invoke_agent AgentB
598-
assertParent(agentACallLlm1, agentBSpan);
599-
// └── call_llm 2
662+
// └── invoke_agent AgentB
663+
assertParent(agentASpan, agentBSpan);
664+
// └── call_llm 2
600665
assertParent(agentBSpan, agentBCallLlm);
601666
}
602667

@@ -645,6 +710,13 @@ private void assertParent(SpanData parent, SpanData child) {
645710
assertEquals(parent.getSpanContext().getSpanId(), child.getParentSpanContext().getSpanId());
646711
}
647712

713+
private void assertDeferredUpstreamSawTransformerSpan(
714+
SpanData parent, SpanData transformer, AtomicReference<String> observedSpanId) {
715+
assertParent(parent, transformer);
716+
assertTrue(transformer.hasEnded());
717+
assertEquals(transformer.getSpanContext().getSpanId(), observedSpanId.get());
718+
}
719+
648720
/**
649721
* Finds a span by name, polling multiple times.
650722
*

0 commit comments

Comments
 (0)