Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions packages/a2a-server/src/agent/executor.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { CoderAgentExecutor } from './executor.js';
import type {
ExecutionEventBus,
RequestContext,
TaskStore,
} from '@a2a-js/sdk/server';
import { EventEmitter } from 'node:events';
import { requestStorage } from '../http/requestStorage.js';

// Mocks for constructor dependencies.
// Each factory below returns the full replacement module object; only the
// exports listed here exist on the mocked modules.

vi.mock('../config/config.js', () => {
  // Minimal config object exposing just the three getters stubbed here.
  const stubConfig = {
    getSessionId: () => 'test-session',
    getTargetDir: () => '/tmp',
    getCheckpointingEnabled: () => false,
  };

  const loadConfig = vi.fn().mockReturnValue(stubConfig);
  const loadEnvironment = vi.fn();
  const setTargetDir = vi.fn().mockReturnValue('/tmp');

  return { loadConfig, loadEnvironment, setTargetDir };
});

vi.mock('../config/settings.js', () => {
  // Settings load as an empty object.
  const loadSettings = vi.fn().mockReturnValue({});
  return { loadSettings };
});

vi.mock('../config/extension.js', () => {
  // No extensions are loaded.
  const loadExtensions = vi.fn().mockReturnValue([]);
  return { loadExtensions };
});

vi.mock('../http/requestStorage.js', () => {
  // Tests set the return value of getStore per case to inject a fake socket.
  const getStore = vi.fn();
  return { requestStorage: { getStore } };
});

// Replaces the real Task module with a test double.
//
// The fake `acceptUserMessage` is the piece the tests rely on:
//   - a message without a 'confirmation' part stays pending until the abort
//     signal passed to it fires, which keeps that execution "in flight";
//   - a message with a 'confirmation' part yields immediately.
// Every other method is a bare vi.fn() stub.
vi.mock('./task.js', () => {
  const mockTaskInstance = (taskId: string, contextId: string) => ({
    id: taskId,
    contextId,
    taskState: 'working',
    acceptUserMessage: vi
      .fn()
      .mockImplementation(async function* (context, aborted) {
        const isConfirmation = (
          context.userMessage.parts as Array<{ kind: string }>
        ).some((p) => p.kind === 'confirmation');
        // Hang only for main user messages (text), allow confirmations to
        // finish quickly. An already-aborted signal never dispatches 'abort'
        // again, so waiting on it would hang forever; skip the wait then.
        if (!isConfirmation && aborted && !aborted.aborted) {
          await new Promise<void>((resolve) => {
            aborted.addEventListener('abort', () => resolve(), {
              once: true,
            });
          });
        }
        yield { type: 'content', value: 'hello' };
      }),
    acceptAgentMessage: vi.fn().mockResolvedValue(undefined),
    scheduleToolCalls: vi.fn().mockResolvedValue(undefined),
    waitForPendingTools: vi.fn().mockResolvedValue(undefined),
    getAndClearCompletedTools: vi.fn().mockReturnValue([]),
    addToolResponsesToHistory: vi.fn(),
    sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}),
    cancelPendingTools: vi.fn(),
    setTaskStateAndPublishUpdate: vi.fn(),
    dispose: vi.fn(),
    getMetadata: vi.fn().mockResolvedValue({}),
    geminiClient: {
      initialize: vi.fn().mockResolvedValue(undefined),
    },
    toSDKTask: () => ({
      id: taskId,
      contextId,
      kind: 'task',
      status: { state: 'working', timestamp: new Date().toISOString() },
      metadata: {},
      history: [],
      artifacts: [],
    }),
  });

  // The mock is usable both as a constructor-style function and through the
  // static async `create` factory; both build the same fake instance shape.
  // Object.assign attaches `create` without a double type assertion.
  const MockTask = Object.assign(
    vi.fn().mockImplementation(mockTaskInstance),
    {
      create: vi
        .fn()
        .mockImplementation(async (taskId: string, contextId: string) =>
          mockTaskInstance(taskId, contextId),
        ),
    },
  );

  return { Task: MockTask };
});

/**
 * Tests how CoderAgentExecutor handles overlapping executions of one task
 * and when it evicts and disposes a task.
 *
 * The Task module is mocked: a text message keeps its execution pending until
 * the request socket emits 'end', while a confirmation message completes
 * immediately.
 */
describe('CoderAgentExecutor', () => {
  let executor: CoderAgentExecutor;
  let mockTaskStore: TaskStore;
  let mockEventBus: ExecutionEventBus;

  /** Reads the executor's private `executingTasks` set for the given task. */
  const isExecuting = (taskId: string): boolean =>
    (
      executor as unknown as { executingTasks: Set<string> }
    ).executingTasks.has(taskId);

  /** Builds the minimal request context shape these tests pass to execute(). */
  const buildRequestContext = (
    messageId: string,
    taskId: string,
    contextId: string,
    parts: Array<Record<string, unknown>>,
  ): RequestContext =>
    ({
      userMessage: {
        messageId,
        taskId,
        contextId,
        parts,
        metadata: {
          coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
        },
      },
    }) as unknown as RequestContext;

  /**
   * Creates a fake request socket and makes requestStorage return it for
   * subsequent execute() calls.
   */
  const useRequestSocket = (): EventEmitter => {
    const socket = new EventEmitter();
    (requestStorage.getStore as Mock).mockReturnValue({ req: { socket } });
    return socket;
  };

  /**
   * Polls until the executor has marked the task as executing and registered
   * it, instead of sleeping for a fixed interval and hoping that was enough.
   */
  const waitForPrimaryToStart = (taskId: string): Promise<void> =>
    vi.waitFor(() => {
      expect(isExecuting(taskId)).toBe(true);
      expect(executor.getTask(taskId)).toBeDefined();
    });

  /** Returns the registered task wrapper, failing the test if it is missing. */
  const requireTask = (taskId: string) => {
    const wrapper = executor.getTask(taskId);
    expect(wrapper).toBeDefined();
    if (wrapper == null) {
      throw new Error(`Task ${taskId} is not registered with the executor`);
    }
    return wrapper;
  };

  beforeEach(() => {
    vi.clearAllMocks();
    mockTaskStore = {
      save: vi.fn().mockResolvedValue(undefined),
      load: vi.fn().mockResolvedValue(undefined),
      delete: vi.fn().mockResolvedValue(undefined),
      list: vi.fn().mockResolvedValue([]),
    } as unknown as TaskStore;

    mockEventBus = new EventEmitter() as unknown as ExecutionEventBus;
    mockEventBus.publish = vi.fn();
    mockEventBus.finished = vi.fn();

    executor = new CoderAgentExecutor(mockTaskStore);
  });

  it('should distinguish between primary and secondary execution', async () => {
    const taskId = 'test-task';
    const contextId = 'test-context';

    // First execution (Primary): a text message, which the Task mock keeps
    // pending until the primary socket ends.
    const primarySocket = useRequestSocket();
    const primaryPromise = executor.execute(
      buildRequestContext('msg-1', taskId, contextId, [
        { kind: 'text', text: 'hi' },
      ]),
      mockEventBus,
    );

    await waitForPrimaryToStart(taskId);

    expect(isExecuting(taskId)).toBe(true);
    const wrapper = requireTask(taskId);

    // Second execution (Secondary): a confirmation for the same task, arriving
    // on its own socket while the primary is still running.
    const secondarySocket = useRequestSocket();
    const secondaryPromise = executor.execute(
      buildRequestContext('msg-2', taskId, contextId, [
        { kind: 'confirmation', callId: '1', outcome: 'proceed' },
      ]),
      mockEventBus,
    );

    // Secondary execution should NOT add to executingTasks (already there)
    // and should return early after its loop
    await secondaryPromise;

    // Task should still be in executingTasks and NOT disposed
    expect(isExecuting(taskId)).toBe(true);
    expect(wrapper.task.dispose).not.toHaveBeenCalled();

    // Now simulate secondary socket closure - it should NOT affect primary
    secondarySocket.emit('end');
    expect(isExecuting(taskId)).toBe(true);
    expect(wrapper.task.dispose).not.toHaveBeenCalled();

    // Set to terminal state to verify disposal on finish
    wrapper.task.taskState = 'completed';

    // Now close primary socket
    primarySocket.emit('end');

    await primaryPromise;

    expect(isExecuting(taskId)).toBe(false);
    expect(wrapper.task.dispose).toHaveBeenCalled();
  });

  it('should evict task from cache when it reaches terminal state', async () => {
    const taskId = 'test-task-terminal';
    const contextId = 'test-context';

    const primarySocket = useRequestSocket();
    const primaryPromise = executor.execute(
      buildRequestContext('msg-1', taskId, contextId, [
        { kind: 'text', text: 'hi' },
      ]),
      mockEventBus,
    );

    await waitForPrimaryToStart(taskId);

    const wrapper = requireTask(taskId);
    // Simulate terminal state
    wrapper.task.taskState = 'completed';

    // Finish primary execution
    primarySocket.emit('end');
    await primaryPromise;

    expect(executor.getTask(taskId)).toBeUndefined();
    expect(wrapper.task.dispose).toHaveBeenCalled();
  });
});
71 changes: 55 additions & 16 deletions packages/a2a-server/src/agent/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ export class CoderAgentExecutor implements AgentExecutor {
);
await this.taskStore?.save(wrapper.toSDKTask());
logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`);

// Cleanup listener subscriptions to avoid memory leaks.
wrapper.task.dispose();
this.tasks.delete(taskId);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : 'Unknown error';
Expand Down Expand Up @@ -320,23 +324,26 @@ export class CoderAgentExecutor implements AgentExecutor {
if (store) {
// Grab the raw socket from the request object
const socket = store.req.socket;
const onClientEnd = () => {
const onSocketEnd = () => {
logger.info(
`[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`,
`[CoderAgentExecutor] Socket ended for message ${userMessage.messageId} (task ${taskId}). Aborting execution loop.`,
);
if (!abortController.signal.aborted) {
abortController.abort();
}
// Clean up the listener to prevent memory leaks
socket.removeListener('close', onClientEnd);
socket.removeListener('end', onSocketEnd);
};

// Listen on the socket's 'end' event (remote closed the connection)
socket.on('end', onClientEnd);
socket.on('end', onSocketEnd);
socket.once('close', () => {
socket.removeListener('end', onSocketEnd);
});

// It's also good practice to remove the listener if the task completes successfully
abortSignal.addEventListener('abort', () => {
socket.removeListener('end', onClientEnd);
socket.removeListener('end', onSocketEnd);
});
logger.info(
`[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`,
Expand Down Expand Up @@ -457,6 +464,26 @@ export class CoderAgentExecutor implements AgentExecutor {
return;
}

// Check if this is the primary/initial execution for this task
const isPrimaryExecution = !this.executingTasks.has(taskId);

if (!isPrimaryExecution) {
logger.info(
`[CoderAgentExecutor] Primary execution already active for task ${taskId}. Starting secondary loop for message ${userMessage.messageId}.`,
);
currentTask.eventBus = eventBus;
for await (const _ of currentTask.acceptUserMessage(
requestContext,
abortController.signal,
)) {
logger.info(
`[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`,
);
}
// End this execution-- the original/source will be resumed.
return;
}

logger.info(
`[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`,
);
Expand Down Expand Up @@ -598,18 +625,30 @@ export class CoderAgentExecutor implements AgentExecutor {
}
}
} finally {
this.executingTasks.delete(taskId);
logger.info(
`[CoderAgentExecutor] Saving final state for task ${taskId}.`,
);
try {
await this.taskStore?.save(wrapper.toSDKTask());
logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
} catch (saveError) {
logger.error(
`[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
saveError,
if (isPrimaryExecution) {
this.executingTasks.delete(taskId);
logger.info(
`[CoderAgentExecutor] Saving final state for task ${taskId}.`,
);
try {
await this.taskStore?.save(wrapper.toSDKTask());
logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
} catch (saveError) {
logger.error(
`[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
saveError,
);
}

if (
['canceled', 'failed', 'completed'].includes(currentTask.taskState)
) {
logger.info(
`[CoderAgentExecutor] Task ${taskId} reached terminal state ${currentTask.taskState}. Evicting and disposing.`,
);
wrapper.task.dispose();
this.tasks.delete(taskId);
}
}
}
}
Expand Down
Loading
Loading