Skip to content

Commit 170c83b

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Make BigQueryAgentAnalyticsPlugin state per-invocation
This change introduces per-invocation instances of BatchProcessor and TraceManager, managed by ConcurrentHashMaps keyed by invocation ID. This ensures that analytics and tracing data are isolated for each concurrent invocation. BatchProcessors and TraceManagers are created lazily on the first event for a given invocation and are cleaned up when the invocation completes. PiperOrigin-RevId: 897370846
1 parent 78766c1 commit 170c83b

4 files changed

Lines changed: 421 additions & 149 deletions

File tree

core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java

Lines changed: 91 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import static com.google.adk.plugins.agentanalytics.JsonFormatter.convertToJsonNode;
2222
import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate;
2323
import static com.google.adk.plugins.agentanalytics.JsonFormatter.toJavaObject;
24-
import static java.util.concurrent.TimeUnit.MILLISECONDS;
2524

2625
import com.google.adk.agents.BaseAgent;
2726
import com.google.adk.agents.CallbackContext;
@@ -41,8 +40,6 @@
4140
import com.google.adk.tools.ToolContext;
4241
import com.google.adk.tools.mcp.AbstractMcpTool;
4342
import com.google.adk.utils.AgentEnums.AgentOrigin;
44-
import com.google.api.gax.core.FixedCredentialsProvider;
45-
import com.google.api.gax.retrying.RetrySettings;
4643
import com.google.auth.oauth2.GoogleCredentials;
4744
import com.google.cloud.bigquery.BigQuery;
4845
import com.google.cloud.bigquery.BigQueryException;
@@ -53,11 +50,7 @@
5350
import com.google.cloud.bigquery.Table;
5451
import com.google.cloud.bigquery.TableId;
5552
import com.google.cloud.bigquery.TableInfo;
56-
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
57-
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
58-
import com.google.cloud.bigquery.storage.v1.StreamWriter;
5953
import com.google.common.annotations.VisibleForTesting;
60-
import com.google.common.base.VerifyException;
6154
import com.google.common.collect.ImmutableList;
6255
import com.google.common.collect.ImmutableMap;
6356
import com.google.genai.types.Content;
@@ -70,10 +63,6 @@
7063
import java.util.HashMap;
7164
import java.util.Map;
7265
import java.util.Optional;
73-
import java.util.concurrent.Executors;
74-
import java.util.concurrent.ScheduledExecutorService;
75-
import java.util.concurrent.ThreadFactory;
76-
import java.util.concurrent.atomic.AtomicLong;
7766
import java.util.logging.Level;
7867
import java.util.logging.Logger;
7968
import org.jspecify.annotations.Nullable;
@@ -88,7 +77,6 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
8877
Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName());
8978
private static final ImmutableList<String> DEFAULT_AUTH_SCOPES =
9079
ImmutableList.of("https://www.googleapis.com/auth/cloud-platform");
91-
private static final AtomicLong threadCounter = new AtomicLong(0);
9280
private static final ImmutableMap<String, String> HITL_EVENT_TYPES =
9381
ImmutableMap.of(
9482
"adk_request_credential",
@@ -100,11 +88,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
10088

10189
private final BigQueryLoggerConfig config;
10290
private final BigQuery bigQuery;
103-
private final BigQueryWriteClient writeClient;
104-
private final ScheduledExecutorService executor;
10591
private final Object tableEnsuredLock = new Object();
106-
@VisibleForTesting final BatchProcessor batchProcessor;
107-
@VisibleForTesting final TraceManager traceManager;
92+
private final PluginState state;
10893
private volatile boolean tableEnsured = false;
10994

11095
public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException {
@@ -113,28 +98,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept
11398

11499
public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery)
115100
throws IOException {
101+
this(config, bigQuery, new PluginState(config));
102+
}
103+
104+
BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery, PluginState state) {
116105
super("bigquery_agent_analytics");
117106
this.config = config;
118107
this.bigQuery = bigQuery;
119-
ThreadFactory threadFactory =
120-
r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement());
121-
this.executor = Executors.newScheduledThreadPool(1, threadFactory);
122-
this.writeClient = createWriteClient(config);
123-
this.traceManager = createTraceManager();
124-
125-
if (config.enabled()) {
126-
StreamWriter writer = createWriter(config);
127-
this.batchProcessor =
128-
new BatchProcessor(
129-
writer,
130-
config.batchSize(),
131-
config.batchFlushInterval(),
132-
config.queueMaxSize(),
133-
executor);
134-
this.batchProcessor.start();
135-
} else {
136-
this.batchProcessor = null;
137-
}
108+
this.state = state;
138109
}
139110

140111
private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException {
@@ -194,7 +165,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) {
194165

195166
try {
196167
if (config.createViews()) {
197-
var unused = executor.submit(() -> createAnalyticsViews(bigQuery, config));
168+
var unused = state.getExecutor().submit(() -> createAnalyticsViews(bigQuery, config));
198169
}
199170
} catch (RuntimeException e) {
200171
logger.log(Level.WARNING, "Failed to create/update BigQuery views for table: " + tableId, e);
@@ -209,48 +180,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) {
209180
}
210181
}
211182

212-
protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException {
213-
if (config.credentials() != null) {
214-
return BigQueryWriteClient.create(
215-
BigQueryWriteSettings.newBuilder()
216-
.setCredentialsProvider(FixedCredentialsProvider.create(config.credentials()))
217-
.build());
218-
}
219-
return BigQueryWriteClient.create();
220-
}
221-
222-
protected String getStreamName(BigQueryLoggerConfig config) {
223-
return String.format(
224-
"projects/%s/datasets/%s/tables/%s/streams/_default",
225-
config.projectId(), config.datasetId(), config.tableName());
226-
}
227-
228-
protected StreamWriter createWriter(BigQueryLoggerConfig config) {
229-
BigQueryLoggerConfig.RetryConfig retryConfig = config.retryConfig();
230-
RetrySettings retrySettings =
231-
RetrySettings.newBuilder()
232-
.setMaxAttempts(retryConfig.maxRetries())
233-
.setInitialRetryDelay(
234-
org.threeten.bp.Duration.ofMillis(retryConfig.initialDelay().toMillis()))
235-
.setRetryDelayMultiplier(retryConfig.multiplier())
236-
.setMaxRetryDelay(org.threeten.bp.Duration.ofMillis(retryConfig.maxDelay().toMillis()))
237-
.build();
238-
239-
String streamName = getStreamName(config);
240-
try {
241-
return StreamWriter.newBuilder(streamName, writeClient)
242-
.setRetrySettings(retrySettings)
243-
.setWriterSchema(BigQuerySchema.getArrowSchema())
244-
.build();
245-
} catch (Exception e) {
246-
throw new VerifyException("Failed to create StreamWriter for " + streamName, e);
247-
}
248-
}
249-
250-
protected TraceManager createTraceManager() {
251-
return new TraceManager();
252-
}
253-
254183
private void logEvent(
255184
String eventType,
256185
InvocationContext invocationContext,
@@ -265,7 +194,7 @@ private void logEvent(
265194
Object content,
266195
boolean isContentTruncated,
267196
Optional<EventData> eventData) {
268-
if (!config.enabled() || batchProcessor == null) {
197+
if (!config.enabled()) {
269198
return;
270199
}
271200
if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) {
@@ -274,6 +203,11 @@ private void logEvent(
274203
if (config.eventDenylist().contains(eventType)) {
275204
return;
276205
}
206+
if (state.isProcessed(invocationContext.invocationId())) {
207+
return;
208+
}
209+
String invocationId = invocationContext.invocationId();
210+
BatchProcessor processor = state.getBatchProcessor(invocationId);
277211
// Ensure table exists before logging.
278212
ensureTableExistsOnce();
279213
// Log common fields
@@ -301,11 +235,12 @@ private void logEvent(
301235
row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext)));
302236

303237
addTraceDetails(row, invocationContext, eventData);
304-
batchProcessor.append(row);
238+
processor.append(row);
305239
}
306240

307241
private void addTraceDetails(
308242
Map<String, Object> row, InvocationContext invocationContext, Optional<EventData> eventData) {
243+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
309244
String traceId =
310245
eventData
311246
.flatMap(EventData::traceIdOverride)
@@ -336,7 +271,7 @@ private void addTraceDetails(
336271
private Map<String, Object> getAttributes(
337272
EventData eventData, InvocationContext invocationContext) {
338273
Map<String, Object> attributes = new HashMap<>(eventData.extraAttributes());
339-
274+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
340275
attributes.put("root_agent_name", traceManager.getRootAgentName());
341276
eventData.model().ifPresent(m -> attributes.put("model", m));
342277
eventData.modelVersion().ifPresent(mv -> attributes.put("model_version", mv));
@@ -375,25 +310,17 @@ private Map<String, Object> getAttributes(
375310

376311
@Override
377312
public Completable close() {
378-
if (batchProcessor != null) {
379-
batchProcessor.close();
380-
}
381-
if (writeClient != null) {
382-
writeClient.close();
383-
}
384-
try {
385-
executor.shutdown();
386-
if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) {
387-
executor.shutdownNow();
388-
}
389-
} catch (InterruptedException e) {
390-
executor.shutdownNow();
391-
Thread.currentThread().interrupt();
392-
}
313+
state.close();
393314
return Completable.complete();
394315
}
395316

317+
@VisibleForTesting
318+
PluginState getState() {
319+
return state;
320+
}
321+
396322
private Optional<EventData> getCompletedEventData(InvocationContext invocationContext) {
323+
TraceManager traceManager = state.getTraceManager(invocationContext.invocationId());
397324
String traceId = traceManager.getTraceId(invocationContext);
398325
// Pop the invocation span from the trace manager.
399326
Optional<RecordData> popped = traceManager.popSpan();
@@ -426,7 +353,12 @@ public Maybe<Content> onUserMessageCallback(
426353
InvocationContext invocationContext, Content userMessage) {
427354
return Maybe.fromAction(
428355
() -> {
429-
traceManager.ensureInvocationSpan(invocationContext);
356+
if (state.isProcessed(invocationContext.invocationId())) {
357+
return;
358+
}
359+
state
360+
.getTraceManager(invocationContext.invocationId())
361+
.ensureInvocationSpan(invocationContext);
430362
logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty());
431363
if (userMessage.parts().isPresent()) {
432364
for (Part part : userMessage.parts().get()) {
@@ -454,6 +386,9 @@ public Maybe<Content> onUserMessageCallback(
454386
public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event event) {
455387
return Maybe.fromAction(
456388
() -> {
389+
if (state.isProcessed(invocationContext.invocationId())) {
390+
return;
391+
}
457392
EventData.Builder eventDataBuilder =
458393
EventData.builder()
459394
.setExtraAttributes(
@@ -510,9 +445,16 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
510445

511446
@Override
512447
public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
513-
traceManager.ensureInvocationSpan(invocationContext);
514448
return Maybe.fromAction(
515-
() -> logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()));
449+
() -> {
450+
if (state.isProcessed(invocationContext.invocationId())) {
451+
return;
452+
}
453+
state
454+
.getTraceManager(invocationContext.invocationId())
455+
.ensureInvocationSpan(invocationContext);
456+
logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty());
457+
});
516458
}
517459

518460
@Override
@@ -524,16 +466,30 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
524466
invocationContext,
525467
null,
526468
getCompletedEventData(invocationContext));
527-
batchProcessor.flush();
528-
traceManager.clearStack();
469+
// Mark invocation ID as processed to avoid memory leaks.
470+
state.markProcessed(invocationContext.invocationId());
471+
BatchProcessor processor = state.removeProcessor(invocationContext.invocationId());
472+
if (processor != null) {
473+
processor.flush();
474+
processor.close();
475+
}
476+
TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId());
477+
if (traceManager != null) {
478+
traceManager.clearStack();
479+
}
529480
});
530481
}
531482

532483
@Override
533484
public Maybe<Content> beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) {
534485
return Maybe.fromAction(
535486
() -> {
536-
traceManager.pushSpan("agent:" + agent.name());
487+
if (state.isProcessed(callbackContext.invocationContext().invocationId())) {
488+
return;
489+
}
490+
state
491+
.getTraceManager(callbackContext.invocationContext().invocationId())
492+
.pushSpan("agent:" + agent.name());
537493
logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty());
538494
});
539495
}
@@ -563,6 +519,9 @@ public Maybe<LlmResponse> beforeModelCallback(
563519
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
564520
return Maybe.fromAction(
565521
() -> {
522+
if (state.isProcessed(callbackContext.invocationContext().invocationId())) {
523+
return;
524+
}
566525
Map<String, Object> attributes = new HashMap<>();
567526
Map<String, Object> llmConfig = new HashMap<>();
568527
LlmRequest req = llmRequest.build();
@@ -622,7 +581,9 @@ public Maybe<LlmResponse> beforeModelCallback(
622581
.setModel(req.model().orElse(""))
623582
.setExtraAttributes(attributes)
624583
.build();
625-
traceManager.pushSpan("llm_request");
584+
state
585+
.getTraceManager(callbackContext.invocationContext().invocationId())
586+
.pushSpan("llm_request");
626587
logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData));
627588
});
628589
}
@@ -632,6 +593,11 @@ public Maybe<LlmResponse> afterModelCallback(
632593
CallbackContext callbackContext, LlmResponse llmResponse) {
633594
return Maybe.fromAction(
634595
() -> {
596+
if (state.isProcessed(callbackContext.invocationContext().invocationId())) {
597+
return;
598+
}
599+
TraceManager traceManager =
600+
state.getTraceManager(callbackContext.invocationContext().invocationId());
635601
// TODO(b/495809488): Add formatting of the content
636602
ParsedContent parsedContent =
637603
JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength());
@@ -728,6 +694,11 @@ public Maybe<LlmResponse> onModelErrorCallback(
728694
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
729695
return Maybe.fromAction(
730696
() -> {
697+
if (state.isProcessed(callbackContext.invocationContext().invocationId())) {
698+
return;
699+
}
700+
TraceManager traceManager =
701+
state.getTraceManager(callbackContext.invocationContext().invocationId());
731702
InvocationContext invocationContext = callbackContext.invocationContext();
732703
Optional<RecordData> popped = traceManager.popSpan();
733704
String spanId = popped.map(RecordData::spanId).orElse(null);
@@ -758,11 +729,14 @@ public Maybe<Map<String, Object>> beforeToolCallback(
758729
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext) {
759730
return Maybe.fromAction(
760731
() -> {
732+
if (state.isProcessed(toolContext.invocationContext().invocationId())) {
733+
return;
734+
}
761735
TruncationResult res = smartTruncate(toolArgs, config.maxContentLength());
762736
ImmutableMap<String, Object> contentMap =
763737
ImmutableMap.of(
764738
"tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node());
765-
traceManager.pushSpan("tool");
739+
state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool");
766740
logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty());
767741
});
768742
}
@@ -775,6 +749,14 @@ public Maybe<Map<String, Object>> afterToolCallback(
775749
Map<String, Object> result) {
776750
return Maybe.fromAction(
777751
() -> {
752+
if (state.isProcessed(toolContext.invocationContext().invocationId())) {
753+
return;
754+
}
755+
state
756+
.getTraceManager(toolContext.invocationContext().invocationId())
757+
.ensureInvocationSpan(toolContext.invocationContext());
758+
TraceManager traceManager =
759+
state.getTraceManager(toolContext.invocationContext().invocationId());
778760
Optional<RecordData> popped = traceManager.popSpan();
779761
TruncationResult truncationResult = smartTruncate(result, config.maxContentLength());
780762
ImmutableMap<String, Object> contentMap =
@@ -812,6 +794,11 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
812794
BaseTool tool, Map<String, Object> toolArgs, ToolContext toolContext, Throwable error) {
813795
return Maybe.fromAction(
814796
() -> {
797+
if (state.isProcessed(toolContext.invocationContext().invocationId())) {
798+
return;
799+
}
800+
TraceManager traceManager =
801+
state.getTraceManager(toolContext.invocationContext().invocationId());
815802
Optional<RecordData> popped = traceManager.popSpan();
816803
TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength());
817804
String toolOrigin = getToolOrigin(tool);

0 commit comments

Comments
 (0)