2121import static com .google .adk .plugins .agentanalytics .JsonFormatter .convertToJsonNode ;
2222import static com .google .adk .plugins .agentanalytics .JsonFormatter .smartTruncate ;
2323import static com .google .adk .plugins .agentanalytics .JsonFormatter .toJavaObject ;
24- import static java .util .concurrent .TimeUnit .MILLISECONDS ;
2524
2625import com .google .adk .agents .BaseAgent ;
2726import com .google .adk .agents .CallbackContext ;
4140import com .google .adk .tools .ToolContext ;
4241import com .google .adk .tools .mcp .AbstractMcpTool ;
4342import com .google .adk .utils .AgentEnums .AgentOrigin ;
44- import com .google .api .gax .core .FixedCredentialsProvider ;
45- import com .google .api .gax .retrying .RetrySettings ;
4643import com .google .auth .oauth2 .GoogleCredentials ;
4744import com .google .cloud .bigquery .BigQuery ;
4845import com .google .cloud .bigquery .BigQueryException ;
5350import com .google .cloud .bigquery .Table ;
5451import com .google .cloud .bigquery .TableId ;
5552import com .google .cloud .bigquery .TableInfo ;
56- import com .google .cloud .bigquery .storage .v1 .BigQueryWriteClient ;
57- import com .google .cloud .bigquery .storage .v1 .BigQueryWriteSettings ;
58- import com .google .cloud .bigquery .storage .v1 .StreamWriter ;
5953import com .google .common .annotations .VisibleForTesting ;
60- import com .google .common .base .VerifyException ;
6154import com .google .common .collect .ImmutableList ;
6255import com .google .common .collect .ImmutableMap ;
6356import com .google .genai .types .Content ;
7063import java .util .HashMap ;
7164import java .util .Map ;
7265import java .util .Optional ;
73- import java .util .concurrent .Executors ;
74- import java .util .concurrent .ScheduledExecutorService ;
75- import java .util .concurrent .ThreadFactory ;
76- import java .util .concurrent .atomic .AtomicLong ;
7766import java .util .logging .Level ;
7867import java .util .logging .Logger ;
7968import org .jspecify .annotations .Nullable ;
@@ -88,7 +77,6 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
8877 Logger .getLogger (BigQueryAgentAnalyticsPlugin .class .getName ());
8978 private static final ImmutableList <String > DEFAULT_AUTH_SCOPES =
9079 ImmutableList .of ("https://www.googleapis.com/auth/cloud-platform" );
91- private static final AtomicLong threadCounter = new AtomicLong (0 );
9280 private static final ImmutableMap <String , String > HITL_EVENT_TYPES =
9381 ImmutableMap .of (
9482 "adk_request_credential" ,
@@ -100,11 +88,8 @@ public class BigQueryAgentAnalyticsPlugin extends BasePlugin {
10088
10189 private final BigQueryLoggerConfig config ;
10290 private final BigQuery bigQuery ;
103- private final BigQueryWriteClient writeClient ;
104- private final ScheduledExecutorService executor ;
10591 private final Object tableEnsuredLock = new Object ();
106- @ VisibleForTesting final BatchProcessor batchProcessor ;
107- @ VisibleForTesting final TraceManager traceManager ;
92+ private final PluginState state ;
10893 private volatile boolean tableEnsured = false ;
10994
11095 public BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config ) throws IOException {
@@ -113,28 +98,14 @@ public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOExcept
11398
11499 public BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config , BigQuery bigQuery )
115100 throws IOException {
101+ this (config , bigQuery , new PluginState (config ));
102+ }
103+
104+ BigQueryAgentAnalyticsPlugin (BigQueryLoggerConfig config , BigQuery bigQuery , PluginState state ) {
116105 super ("bigquery_agent_analytics" );
117106 this .config = config ;
118107 this .bigQuery = bigQuery ;
119- ThreadFactory threadFactory =
120- r -> new Thread (r , "bq-analytics-plugin-" + threadCounter .getAndIncrement ());
121- this .executor = Executors .newScheduledThreadPool (1 , threadFactory );
122- this .writeClient = createWriteClient (config );
123- this .traceManager = createTraceManager ();
124-
125- if (config .enabled ()) {
126- StreamWriter writer = createWriter (config );
127- this .batchProcessor =
128- new BatchProcessor (
129- writer ,
130- config .batchSize (),
131- config .batchFlushInterval (),
132- config .queueMaxSize (),
133- executor );
134- this .batchProcessor .start ();
135- } else {
136- this .batchProcessor = null ;
137- }
108+ this .state = state ;
138109 }
139110
140111 private static BigQuery createBigQuery (BigQueryLoggerConfig config ) throws IOException {
@@ -194,7 +165,7 @@ private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) {
194165
195166 try {
196167 if (config .createViews ()) {
197- var unused = executor .submit (() -> createAnalyticsViews (bigQuery , config ));
168+ var unused = state . getExecutor () .submit (() -> createAnalyticsViews (bigQuery , config ));
198169 }
199170 } catch (RuntimeException e ) {
200171 logger .log (Level .WARNING , "Failed to create/update BigQuery views for table: " + tableId , e );
@@ -209,48 +180,6 @@ private void processBigQueryException(BigQueryException e, String logMessage) {
209180 }
210181 }
211182
212- protected BigQueryWriteClient createWriteClient (BigQueryLoggerConfig config ) throws IOException {
213- if (config .credentials () != null ) {
214- return BigQueryWriteClient .create (
215- BigQueryWriteSettings .newBuilder ()
216- .setCredentialsProvider (FixedCredentialsProvider .create (config .credentials ()))
217- .build ());
218- }
219- return BigQueryWriteClient .create ();
220- }
221-
222- protected String getStreamName (BigQueryLoggerConfig config ) {
223- return String .format (
224- "projects/%s/datasets/%s/tables/%s/streams/_default" ,
225- config .projectId (), config .datasetId (), config .tableName ());
226- }
227-
228- protected StreamWriter createWriter (BigQueryLoggerConfig config ) {
229- BigQueryLoggerConfig .RetryConfig retryConfig = config .retryConfig ();
230- RetrySettings retrySettings =
231- RetrySettings .newBuilder ()
232- .setMaxAttempts (retryConfig .maxRetries ())
233- .setInitialRetryDelay (
234- org .threeten .bp .Duration .ofMillis (retryConfig .initialDelay ().toMillis ()))
235- .setRetryDelayMultiplier (retryConfig .multiplier ())
236- .setMaxRetryDelay (org .threeten .bp .Duration .ofMillis (retryConfig .maxDelay ().toMillis ()))
237- .build ();
238-
239- String streamName = getStreamName (config );
240- try {
241- return StreamWriter .newBuilder (streamName , writeClient )
242- .setRetrySettings (retrySettings )
243- .setWriterSchema (BigQuerySchema .getArrowSchema ())
244- .build ();
245- } catch (Exception e ) {
246- throw new VerifyException ("Failed to create StreamWriter for " + streamName , e );
247- }
248- }
249-
250- protected TraceManager createTraceManager () {
251- return new TraceManager ();
252- }
253-
254183 private void logEvent (
255184 String eventType ,
256185 InvocationContext invocationContext ,
@@ -265,7 +194,7 @@ private void logEvent(
265194 Object content ,
266195 boolean isContentTruncated ,
267196 Optional <EventData > eventData ) {
268- if (!config .enabled () || batchProcessor == null ) {
197+ if (!config .enabled ()) {
269198 return ;
270199 }
271200 if (!config .eventAllowlist ().isEmpty () && !config .eventAllowlist ().contains (eventType )) {
@@ -274,6 +203,11 @@ private void logEvent(
274203 if (config .eventDenylist ().contains (eventType )) {
275204 return ;
276205 }
206+ if (state .isProcessed (invocationContext .invocationId ())) {
207+ return ;
208+ }
209+ String invocationId = invocationContext .invocationId ();
210+ BatchProcessor processor = state .getBatchProcessor (invocationId );
277211 // Ensure table exists before logging.
278212 ensureTableExistsOnce ();
279213 // Log common fields
@@ -301,11 +235,12 @@ private void logEvent(
301235 row .put ("attributes" , convertToJsonNode (getAttributes (data , invocationContext )));
302236
303237 addTraceDetails (row , invocationContext , eventData );
304- batchProcessor .append (row );
238+ processor .append (row );
305239 }
306240
307241 private void addTraceDetails (
308242 Map <String , Object > row , InvocationContext invocationContext , Optional <EventData > eventData ) {
243+ TraceManager traceManager = state .getTraceManager (invocationContext .invocationId ());
309244 String traceId =
310245 eventData
311246 .flatMap (EventData ::traceIdOverride )
@@ -336,7 +271,7 @@ private void addTraceDetails(
336271 private Map <String , Object > getAttributes (
337272 EventData eventData , InvocationContext invocationContext ) {
338273 Map <String , Object > attributes = new HashMap <>(eventData .extraAttributes ());
339-
274+ TraceManager traceManager = state . getTraceManager ( invocationContext . invocationId ());
340275 attributes .put ("root_agent_name" , traceManager .getRootAgentName ());
341276 eventData .model ().ifPresent (m -> attributes .put ("model" , m ));
342277 eventData .modelVersion ().ifPresent (mv -> attributes .put ("model_version" , mv ));
@@ -375,25 +310,17 @@ private Map<String, Object> getAttributes(
375310
376311 @ Override
377312 public Completable close () {
378- if (batchProcessor != null ) {
379- batchProcessor .close ();
380- }
381- if (writeClient != null ) {
382- writeClient .close ();
383- }
384- try {
385- executor .shutdown ();
386- if (!executor .awaitTermination (config .shutdownTimeout ().toMillis (), MILLISECONDS )) {
387- executor .shutdownNow ();
388- }
389- } catch (InterruptedException e ) {
390- executor .shutdownNow ();
391- Thread .currentThread ().interrupt ();
392- }
313+ state .close ();
393314 return Completable .complete ();
394315 }
395316
317+ @ VisibleForTesting
318+ PluginState getState () {
319+ return state ;
320+ }
321+
396322 private Optional <EventData > getCompletedEventData (InvocationContext invocationContext ) {
323+ TraceManager traceManager = state .getTraceManager (invocationContext .invocationId ());
397324 String traceId = traceManager .getTraceId (invocationContext );
398325 // Pop the invocation span from the trace manager.
399326 Optional <RecordData > popped = traceManager .popSpan ();
@@ -426,7 +353,12 @@ public Maybe<Content> onUserMessageCallback(
426353 InvocationContext invocationContext , Content userMessage ) {
427354 return Maybe .fromAction (
428355 () -> {
429- traceManager .ensureInvocationSpan (invocationContext );
356+ if (state .isProcessed (invocationContext .invocationId ())) {
357+ return ;
358+ }
359+ state
360+ .getTraceManager (invocationContext .invocationId ())
361+ .ensureInvocationSpan (invocationContext );
430362 logEvent ("USER_MESSAGE_RECEIVED" , invocationContext , userMessage , Optional .empty ());
431363 if (userMessage .parts ().isPresent ()) {
432364 for (Part part : userMessage .parts ().get ()) {
@@ -454,6 +386,9 @@ public Maybe<Content> onUserMessageCallback(
454386 public Maybe <Event > onEventCallback (InvocationContext invocationContext , Event event ) {
455387 return Maybe .fromAction (
456388 () -> {
389+ if (state .isProcessed (invocationContext .invocationId ())) {
390+ return ;
391+ }
457392 EventData .Builder eventDataBuilder =
458393 EventData .builder ()
459394 .setExtraAttributes (
@@ -510,9 +445,16 @@ public Maybe<Event> onEventCallback(InvocationContext invocationContext, Event e
510445
511446 @ Override
512447 public Maybe <Content > beforeRunCallback (InvocationContext invocationContext ) {
513- traceManager .ensureInvocationSpan (invocationContext );
514448 return Maybe .fromAction (
515- () -> logEvent ("INVOCATION_STARTING" , invocationContext , null , Optional .empty ()));
449+ () -> {
450+ if (state .isProcessed (invocationContext .invocationId ())) {
451+ return ;
452+ }
453+ state
454+ .getTraceManager (invocationContext .invocationId ())
455+ .ensureInvocationSpan (invocationContext );
456+ logEvent ("INVOCATION_STARTING" , invocationContext , null , Optional .empty ());
457+ });
516458 }
517459
518460 @ Override
@@ -524,16 +466,30 @@ public Completable afterRunCallback(InvocationContext invocationContext) {
524466 invocationContext ,
525467 null ,
526468 getCompletedEventData (invocationContext ));
527- batchProcessor .flush ();
528- traceManager .clearStack ();
469+ // Mark invocation ID as processed to avoid memory leaks.
470+ state .markProcessed (invocationContext .invocationId ());
471+ BatchProcessor processor = state .removeProcessor (invocationContext .invocationId ());
472+ if (processor != null ) {
473+ processor .flush ();
474+ processor .close ();
475+ }
476+ TraceManager traceManager = state .removeTraceManager (invocationContext .invocationId ());
477+ if (traceManager != null ) {
478+ traceManager .clearStack ();
479+ }
529480 });
530481 }
531482
532483 @ Override
533484 public Maybe <Content > beforeAgentCallback (BaseAgent agent , CallbackContext callbackContext ) {
534485 return Maybe .fromAction (
535486 () -> {
536- traceManager .pushSpan ("agent:" + agent .name ());
487+ if (state .isProcessed (callbackContext .invocationContext ().invocationId ())) {
488+ return ;
489+ }
490+ state
491+ .getTraceManager (callbackContext .invocationContext ().invocationId ())
492+ .pushSpan ("agent:" + agent .name ());
537493 logEvent ("AGENT_STARTING" , callbackContext .invocationContext (), null , Optional .empty ());
538494 });
539495 }
@@ -563,6 +519,9 @@ public Maybe<LlmResponse> beforeModelCallback(
563519 CallbackContext callbackContext , LlmRequest .Builder llmRequest ) {
564520 return Maybe .fromAction (
565521 () -> {
522+ if (state .isProcessed (callbackContext .invocationContext ().invocationId ())) {
523+ return ;
524+ }
566525 Map <String , Object > attributes = new HashMap <>();
567526 Map <String , Object > llmConfig = new HashMap <>();
568527 LlmRequest req = llmRequest .build ();
@@ -622,7 +581,9 @@ public Maybe<LlmResponse> beforeModelCallback(
622581 .setModel (req .model ().orElse ("" ))
623582 .setExtraAttributes (attributes )
624583 .build ();
625- traceManager .pushSpan ("llm_request" );
584+ state
585+ .getTraceManager (callbackContext .invocationContext ().invocationId ())
586+ .pushSpan ("llm_request" );
626587 logEvent ("LLM_REQUEST" , callbackContext .invocationContext (), req , Optional .of (eventData ));
627588 });
628589 }
@@ -632,6 +593,11 @@ public Maybe<LlmResponse> afterModelCallback(
632593 CallbackContext callbackContext , LlmResponse llmResponse ) {
633594 return Maybe .fromAction (
634595 () -> {
596+ if (state .isProcessed (callbackContext .invocationContext ().invocationId ())) {
597+ return ;
598+ }
599+ TraceManager traceManager =
600+ state .getTraceManager (callbackContext .invocationContext ().invocationId ());
635601 // TODO(b/495809488): Add formatting of the content
636602 ParsedContent parsedContent =
637603 JsonFormatter .parse (llmResponse .content ().orElse (null ), config .maxContentLength ());
@@ -728,6 +694,11 @@ public Maybe<LlmResponse> onModelErrorCallback(
728694 CallbackContext callbackContext , LlmRequest .Builder llmRequest , Throwable error ) {
729695 return Maybe .fromAction (
730696 () -> {
697+ if (state .isProcessed (callbackContext .invocationContext ().invocationId ())) {
698+ return ;
699+ }
700+ TraceManager traceManager =
701+ state .getTraceManager (callbackContext .invocationContext ().invocationId ());
731702 InvocationContext invocationContext = callbackContext .invocationContext ();
732703 Optional <RecordData > popped = traceManager .popSpan ();
733704 String spanId = popped .map (RecordData ::spanId ).orElse (null );
@@ -758,11 +729,14 @@ public Maybe<Map<String, Object>> beforeToolCallback(
758729 BaseTool tool , Map <String , Object > toolArgs , ToolContext toolContext ) {
759730 return Maybe .fromAction (
760731 () -> {
732+ if (state .isProcessed (toolContext .invocationContext ().invocationId ())) {
733+ return ;
734+ }
761735 TruncationResult res = smartTruncate (toolArgs , config .maxContentLength ());
762736 ImmutableMap <String , Object > contentMap =
763737 ImmutableMap .of (
764738 "tool_origin" , getToolOrigin (tool ), "tool" , tool .name (), "args" , res .node ());
765- traceManager .pushSpan ("tool" );
739+ state . getTraceManager ( toolContext . invocationContext (). invocationId ()) .pushSpan ("tool" );
766740 logEvent ("TOOL_STARTING" , toolContext .invocationContext (), contentMap , Optional .empty ());
767741 });
768742 }
@@ -775,6 +749,14 @@ public Maybe<Map<String, Object>> afterToolCallback(
775749 Map <String , Object > result ) {
776750 return Maybe .fromAction (
777751 () -> {
752+ if (state .isProcessed (toolContext .invocationContext ().invocationId ())) {
753+ return ;
754+ }
755+ state
756+ .getTraceManager (toolContext .invocationContext ().invocationId ())
757+ .ensureInvocationSpan (toolContext .invocationContext ());
758+ TraceManager traceManager =
759+ state .getTraceManager (toolContext .invocationContext ().invocationId ());
778760 Optional <RecordData > popped = traceManager .popSpan ();
779761 TruncationResult truncationResult = smartTruncate (result , config .maxContentLength ());
780762 ImmutableMap <String , Object > contentMap =
@@ -812,6 +794,11 @@ public Maybe<Map<String, Object>> onToolErrorCallback(
812794 BaseTool tool , Map <String , Object > toolArgs , ToolContext toolContext , Throwable error ) {
813795 return Maybe .fromAction (
814796 () -> {
797+ if (state .isProcessed (toolContext .invocationContext ().invocationId ())) {
798+ return ;
799+ }
800+ TraceManager traceManager =
801+ state .getTraceManager (toolContext .invocationContext ().invocationId ());
815802 Optional <RecordData > popped = traceManager .popSpan ();
816803 TruncationResult truncationResult = smartTruncate (toolArgs , config .maxContentLength ());
817804 String toolOrigin = getToolOrigin (tool );
0 commit comments