Skip to content

Commit 8a80dc8

Browse files
cocosheng-gabhipatel12adamfweidman
authored andcommitted
feat(a2a): switch from callback-based to event-driven tool scheduler (#21467)
Co-authored-by: Abhi <abhipatel@google.com> Co-authored-by: Adam Weidman <adamfweidman@google.com> # Conflicts: # packages/a2a-server/src/config/config.ts
1 parent 8948a90 commit 8a80dc8

10 files changed

Lines changed: 1329 additions & 60 deletions

File tree

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Google LLC
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
8+
import { CoderAgentExecutor } from './executor.js';
9+
import type {
10+
ExecutionEventBus,
11+
RequestContext,
12+
TaskStore,
13+
} from '@a2a-js/sdk/server';
14+
import { EventEmitter } from 'node:events';
15+
import { requestStorage } from '../http/requestStorage.js';
16+
17+
// Mocks for constructor dependencies
18+
vi.mock('../config/config.js', () => ({
19+
loadConfig: vi.fn().mockReturnValue({
20+
getSessionId: () => 'test-session',
21+
getTargetDir: () => '/tmp',
22+
getCheckpointingEnabled: () => false,
23+
}),
24+
loadEnvironment: vi.fn(),
25+
setTargetDir: vi.fn().mockReturnValue('/tmp'),
26+
}));
27+
28+
vi.mock('../config/settings.js', () => ({
29+
loadSettings: vi.fn().mockReturnValue({}),
30+
}));
31+
32+
vi.mock('../config/extension.js', () => ({
33+
loadExtensions: vi.fn().mockReturnValue([]),
34+
}));
35+
36+
vi.mock('../http/requestStorage.js', () => ({
37+
requestStorage: {
38+
getStore: vi.fn(),
39+
},
40+
}));
41+
42+
vi.mock('./task.js', () => {
43+
const mockTaskInstance = (taskId: string, contextId: string) => ({
44+
id: taskId,
45+
contextId,
46+
taskState: 'working',
47+
acceptUserMessage: vi
48+
.fn()
49+
.mockImplementation(async function* (context, aborted) {
50+
const isConfirmation = (
51+
context.userMessage.parts as Array<{ kind: string }>
52+
).some((p) => p.kind === 'confirmation');
53+
// Hang only for main user messages (text), allow confirmations to finish quickly
54+
if (!isConfirmation && aborted) {
55+
await new Promise((resolve) => {
56+
aborted.addEventListener('abort', resolve, { once: true });
57+
});
58+
}
59+
yield { type: 'content', value: 'hello' };
60+
}),
61+
acceptAgentMessage: vi.fn().mockResolvedValue(undefined),
62+
scheduleToolCalls: vi.fn().mockResolvedValue(undefined),
63+
waitForPendingTools: vi.fn().mockResolvedValue(undefined),
64+
getAndClearCompletedTools: vi.fn().mockReturnValue([]),
65+
addToolResponsesToHistory: vi.fn(),
66+
sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}),
67+
cancelPendingTools: vi.fn(),
68+
setTaskStateAndPublishUpdate: vi.fn(),
69+
dispose: vi.fn(),
70+
getMetadata: vi.fn().mockResolvedValue({}),
71+
geminiClient: {
72+
initialize: vi.fn().mockResolvedValue(undefined),
73+
},
74+
toSDKTask: () => ({
75+
id: taskId,
76+
contextId,
77+
kind: 'task',
78+
status: { state: 'working', timestamp: new Date().toISOString() },
79+
metadata: {},
80+
history: [],
81+
artifacts: [],
82+
}),
83+
});
84+
85+
const MockTask = vi.fn().mockImplementation(mockTaskInstance);
86+
(MockTask as unknown as { create: Mock }).create = vi
87+
.fn()
88+
.mockImplementation(async (taskId: string, contextId: string) =>
89+
mockTaskInstance(taskId, contextId),
90+
);
91+
92+
return { Task: MockTask };
93+
});
94+
95+
describe('CoderAgentExecutor', () => {
96+
let executor: CoderAgentExecutor;
97+
let mockTaskStore: TaskStore;
98+
let mockEventBus: ExecutionEventBus;
99+
100+
beforeEach(() => {
101+
vi.clearAllMocks();
102+
mockTaskStore = {
103+
save: vi.fn().mockResolvedValue(undefined),
104+
load: vi.fn().mockResolvedValue(undefined),
105+
delete: vi.fn().mockResolvedValue(undefined),
106+
list: vi.fn().mockResolvedValue([]),
107+
} as unknown as TaskStore;
108+
109+
mockEventBus = new EventEmitter() as unknown as ExecutionEventBus;
110+
mockEventBus.publish = vi.fn();
111+
mockEventBus.finished = vi.fn();
112+
113+
executor = new CoderAgentExecutor(mockTaskStore);
114+
});
115+
116+
it('should distinguish between primary and secondary execution', async () => {
117+
const taskId = 'test-task';
118+
const contextId = 'test-context';
119+
120+
const mockSocket = new EventEmitter();
121+
const requestContext = {
122+
userMessage: {
123+
messageId: 'msg-1',
124+
taskId,
125+
contextId,
126+
parts: [{ kind: 'text', text: 'hi' }],
127+
metadata: {
128+
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
129+
},
130+
},
131+
} as unknown as RequestContext;
132+
133+
// Mock requestStorage for primary
134+
(requestStorage.getStore as Mock).mockReturnValue({
135+
req: { socket: mockSocket },
136+
});
137+
138+
// First execution (Primary)
139+
const primaryPromise = executor.execute(requestContext, mockEventBus);
140+
141+
// Give it enough time to reach line 490 in executor.ts
142+
await new Promise((resolve) => setTimeout(resolve, 50));
143+
144+
expect(
145+
(
146+
executor as unknown as { executingTasks: Set<string> }
147+
).executingTasks.has(taskId),
148+
).toBe(true);
149+
const wrapper = executor.getTask(taskId);
150+
expect(wrapper).toBeDefined();
151+
152+
// Mock requestStorage for secondary
153+
const secondarySocket = new EventEmitter();
154+
(requestStorage.getStore as Mock).mockReturnValue({
155+
req: { socket: secondarySocket },
156+
});
157+
158+
const secondaryRequestContext = {
159+
userMessage: {
160+
messageId: 'msg-2',
161+
taskId,
162+
contextId,
163+
parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }],
164+
metadata: {
165+
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
166+
},
167+
},
168+
} as unknown as RequestContext;
169+
170+
const secondaryPromise = executor.execute(
171+
secondaryRequestContext,
172+
mockEventBus,
173+
);
174+
175+
// Secondary execution should NOT add to executingTasks (already there)
176+
// and should return early after its loop
177+
await secondaryPromise;
178+
179+
// Task should still be in executingTasks and NOT disposed
180+
expect(
181+
(
182+
executor as unknown as { executingTasks: Set<string> }
183+
).executingTasks.has(taskId),
184+
).toBe(true);
185+
expect(wrapper?.task.dispose).not.toHaveBeenCalled();
186+
187+
// Now simulate secondary socket closure - it should NOT affect primary
188+
secondarySocket.emit('end');
189+
expect(
190+
(
191+
executor as unknown as { executingTasks: Set<string> }
192+
).executingTasks.has(taskId),
193+
).toBe(true);
194+
expect(wrapper?.task.dispose).not.toHaveBeenCalled();
195+
196+
// Set to terminal state to verify disposal on finish
197+
wrapper!.task.taskState = 'completed';
198+
199+
// Now close primary socket
200+
mockSocket.emit('end');
201+
202+
await primaryPromise;
203+
204+
expect(
205+
(
206+
executor as unknown as { executingTasks: Set<string> }
207+
).executingTasks.has(taskId),
208+
).toBe(false);
209+
expect(wrapper?.task.dispose).toHaveBeenCalled();
210+
});
211+
212+
it('should evict task from cache when it reaches terminal state', async () => {
213+
const taskId = 'test-task-terminal';
214+
const contextId = 'test-context';
215+
216+
const mockSocket = new EventEmitter();
217+
(requestStorage.getStore as Mock).mockReturnValue({
218+
req: { socket: mockSocket },
219+
});
220+
221+
const requestContext = {
222+
userMessage: {
223+
messageId: 'msg-1',
224+
taskId,
225+
contextId,
226+
parts: [{ kind: 'text', text: 'hi' }],
227+
metadata: {
228+
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
229+
},
230+
},
231+
} as unknown as RequestContext;
232+
233+
const primaryPromise = executor.execute(requestContext, mockEventBus);
234+
await new Promise((resolve) => setTimeout(resolve, 50));
235+
236+
const wrapper = executor.getTask(taskId)!;
237+
expect(wrapper).toBeDefined();
238+
// Simulate terminal state
239+
wrapper.task.taskState = 'completed';
240+
241+
// Finish primary execution
242+
mockSocket.emit('end');
243+
await primaryPromise;
244+
245+
expect(executor.getTask(taskId)).toBeUndefined();
246+
expect(wrapper.task.dispose).toHaveBeenCalled();
247+
});
248+
});

packages/a2a-server/src/agent/executor.ts

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ export class CoderAgentExecutor implements AgentExecutor {
252252
);
253253
await this.taskStore?.save(wrapper.toSDKTask());
254254
logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`);
255+
256+
// Cleanup listener subscriptions to avoid memory leaks.
257+
wrapper.task.dispose();
258+
this.tasks.delete(taskId);
255259
} catch (error) {
256260
const errorMessage =
257261
error instanceof Error ? error.message : 'Unknown error';
@@ -320,23 +324,26 @@ export class CoderAgentExecutor implements AgentExecutor {
320324
if (store) {
321325
// Grab the raw socket from the request object
322326
const socket = store.req.socket;
323-
const onClientEnd = () => {
327+
const onSocketEnd = () => {
324328
logger.info(
325-
`[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`,
329+
`[CoderAgentExecutor] Socket ended for message ${userMessage.messageId} (task ${taskId}). Aborting execution loop.`,
326330
);
327331
if (!abortController.signal.aborted) {
328332
abortController.abort();
329333
}
330334
// Clean up the listener to prevent memory leaks
331-
socket.removeListener('close', onClientEnd);
335+
socket.removeListener('end', onSocketEnd);
332336
};
333337

334338
// Listen on the socket's 'end' event (remote closed the connection)
335-
socket.on('end', onClientEnd);
339+
socket.on('end', onSocketEnd);
340+
socket.once('close', () => {
341+
socket.removeListener('end', onSocketEnd);
342+
});
336343

337344
// It's also good practice to remove the listener if the task completes successfully
338345
abortSignal.addEventListener('abort', () => {
339-
socket.removeListener('end', onClientEnd);
346+
socket.removeListener('end', onSocketEnd);
340347
});
341348
logger.info(
342349
`[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`,
@@ -457,6 +464,26 @@ export class CoderAgentExecutor implements AgentExecutor {
457464
return;
458465
}
459466

467+
// Check if this is the primary/initial execution for this task
468+
const isPrimaryExecution = !this.executingTasks.has(taskId);
469+
470+
if (!isPrimaryExecution) {
471+
logger.info(
472+
`[CoderAgentExecutor] Primary execution already active for task ${taskId}. Starting secondary loop for message ${userMessage.messageId}.`,
473+
);
474+
currentTask.eventBus = eventBus;
475+
for await (const _ of currentTask.acceptUserMessage(
476+
requestContext,
477+
abortController.signal,
478+
)) {
479+
logger.info(
480+
`[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`,
481+
);
482+
}
483+
// End this execution-- the original/source will be resumed.
484+
return;
485+
}
486+
460487
logger.info(
461488
`[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`,
462489
);
@@ -598,18 +625,30 @@ export class CoderAgentExecutor implements AgentExecutor {
598625
}
599626
}
600627
} finally {
601-
this.executingTasks.delete(taskId);
602-
logger.info(
603-
`[CoderAgentExecutor] Saving final state for task ${taskId}.`,
604-
);
605-
try {
606-
await this.taskStore?.save(wrapper.toSDKTask());
607-
logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
608-
} catch (saveError) {
609-
logger.error(
610-
`[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
611-
saveError,
628+
if (isPrimaryExecution) {
629+
this.executingTasks.delete(taskId);
630+
logger.info(
631+
`[CoderAgentExecutor] Saving final state for task ${taskId}.`,
612632
);
633+
try {
634+
await this.taskStore?.save(wrapper.toSDKTask());
635+
logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
636+
} catch (saveError) {
637+
logger.error(
638+
`[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
639+
saveError,
640+
);
641+
}
642+
643+
if (
644+
['canceled', 'failed', 'completed'].includes(currentTask.taskState)
645+
) {
646+
logger.info(
647+
`[CoderAgentExecutor] Task ${taskId} reached terminal state ${currentTask.taskState}. Evicting and disposing.`,
648+
);
649+
wrapper.task.dispose();
650+
this.tasks.delete(taskId);
651+
}
613652
}
614653
}
615654
}

0 commit comments

Comments
 (0)