Skip to content

Commit 6420ed9

Browse files
fix(ec2): address TOCTOU race and heartbeat false-positive for EC2 tasks
1. TOCTOU race in instance selection: after tagging an instance as busy, re-describe it to verify that our task-id stuck. If another orchestrator won the race, try the next idle candidate instead of double-dispatching.
2. Heartbeat false-positive: EC2/ECS tasks invoke run_task() directly and may not send continuous heartbeats. Suppress sessionUnhealthy checks when compute-level crash detection (pollSession) is active, preventing premature task failure after ~6 minutes.
3. SSM Cancelling status: map to 'running' (transient) instead of 'failed' to avoid premature failure while the cancel propagates.
4. Fix Babel parse errors in test mocks (remove `: unknown` annotations from jest.mock factory callbacks).
1 parent a41fcb6 commit 6420ed9

3 files changed

Lines changed: 108 additions & 27 deletions

File tree

cdk/src/handlers/orchestrate-task.ts

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -231,15 +231,22 @@ const durableHandler: DurableExecutionHandler<OrchestrateTaskEvent, void> = asyn
231231
}
232232
}
233233

234-
return { ...ddbState, consecutiveComputePollFailures, consecutiveComputeCompletedPolls };
234+
// For ECS/EC2 tasks, suppress heartbeat-based sessionUnhealthy since those
235+
// backends have compute-level crash detection and may not send heartbeats.
236+
const suppressHeartbeat = computeStrategy ? { sessionUnhealthy: false } : {};
237+
return { ...ddbState, ...suppressHeartbeat, consecutiveComputePollFailures, consecutiveComputeCompletedPolls };
235238
},
236239
{
237240
initialState: { attempts: 0 },
238241
waitStrategy: (state: PollState) => {
239242
if (state.lastStatus && TERMINAL_STATUSES.includes(state.lastStatus)) {
240243
return { shouldContinue: false };
241244
}
242-
if (state.sessionUnhealthy) {
245+
// Heartbeat-based health checks only apply to AgentCore tasks.
246+
// ECS/EC2 tasks have compute-level crash detection (pollSession) in the
247+
// poll callback, so stale heartbeats should not terminate polling early
248+
// — the agent entrypoint on those backends may not send continuous heartbeats.
249+
if (state.sessionUnhealthy && !computeStrategy) {
243250
return { shouldContinue: false };
244251
}
245252
if (state.attempts >= MAX_POLL_ATTEMPTS) {

cdk/src/handlers/shared/strategies/ec2-strategy.ts

Lines changed: 42 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,10 @@ export class Ec2ComputeStrategy implements ComputeStrategy {
8080
ContentType: 'application/json',
8181
}));
8282

83-
// 2. Find an idle instance
83+
// 2. Find an idle instance and claim it via best-effort tag-then-verify (not atomic: CreateTags is last-writer-wins).
84+
// Multiple orchestrators may race for the same instance, so after tagging
85+
// we re-describe to confirm our task-id stuck. If another invocation
86+
// overwrote the tag, we try the next candidate.
8487
const describeResult = await getEc2Client().send(new DescribeInstancesCommand({
8588
Filters: [
8689
{ Name: `tag:${EC2_FLEET_TAG_KEY}`, Values: [EC2_FLEET_TAG_VALUE] },
@@ -89,21 +92,47 @@ export class Ec2ComputeStrategy implements ComputeStrategy {
8992
],
9093
}));
9194

92-
const instances = (describeResult.Reservations ?? []).flatMap(r => r.Instances ?? []);
93-
if (instances.length === 0 || !instances[0]?.InstanceId) {
95+
const candidates = (describeResult.Reservations ?? []).flatMap(r => r.Instances ?? []);
96+
if (candidates.length === 0) {
9497
throw new Error('No idle EC2 instances available in fleet');
9598
}
9699

97-
const instanceId = instances[0].InstanceId;
100+
let instanceId: string | undefined;
101+
for (const candidate of candidates) {
102+
const candidateId = candidate.InstanceId;
103+
if (!candidateId) continue;
98104

99-
// 3. Tag instance as busy
100-
await getEc2Client().send(new CreateTagsCommand({
101-
Resources: [instanceId],
102-
Tags: [
103-
{ Key: 'bgagent:status', Value: 'busy' },
104-
{ Key: 'bgagent:task-id', Value: taskId },
105-
],
106-
}));
105+
// 3a. Tag instance as busy with our task-id
106+
await getEc2Client().send(new CreateTagsCommand({
107+
Resources: [candidateId],
108+
Tags: [
109+
{ Key: 'bgagent:status', Value: 'busy' },
110+
{ Key: 'bgagent:task-id', Value: taskId },
111+
],
112+
}));
113+
114+
// 3b. Re-describe to verify we won the race
115+
const verifyResult = await getEc2Client().send(new DescribeInstancesCommand({
116+
InstanceIds: [candidateId],
117+
}));
118+
const verifiedInstance = verifyResult.Reservations?.[0]?.Instances?.[0];
119+
const taskIdTag = verifiedInstance?.Tags?.find(t => t.Key === 'bgagent:task-id');
120+
121+
if (taskIdTag?.Value === taskId) {
122+
instanceId = candidateId;
123+
break;
124+
}
125+
126+
logger.warn('Lost instance claim race, trying next candidate', {
127+
task_id: taskId,
128+
instance_id: candidateId,
129+
claimed_by: taskIdTag?.Value,
130+
});
131+
}
132+
133+
if (!instanceId) {
134+
throw new Error('No idle EC2 instances available in fleet (all candidates claimed by other tasks)');
135+
}
107136

108137
// 4. Build the boot script
109138
// All task data is read from the S3 payload at runtime to avoid shell
@@ -209,13 +238,13 @@ export class Ec2ComputeStrategy implements ComputeStrategy {
209238
case 'InProgress':
210239
case 'Pending':
211240
case 'Delayed':
241+
case 'Cancelling': // transient — command still running while cancel propagates
212242
return { status: 'running' };
213243
case 'Success':
214244
return { status: 'completed' };
215245
case 'Failed':
216246
case 'Cancelled':
217247
case 'TimedOut':
218-
case 'Cancelling':
219248
return { status: 'failed', error: result.StatusDetails ?? `SSM command ${status}` };
220249
default:
221250
// Covers any unexpected status values — treat as running to avoid

cdk/test/handlers/shared/strategies/ec2-strategy.test.ts

Lines changed: 57 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -33,23 +33,23 @@ process.env.ECR_IMAGE_URI = ECR_IMAGE;
3333
const mockEc2Send = jest.fn();
3434
jest.mock('@aws-sdk/client-ec2', () => ({
3535
EC2Client: jest.fn(() => ({ send: mockEc2Send })),
36-
DescribeInstancesCommand: jest.fn((input: unknown) => ({ _type: 'DescribeInstances', input })),
37-
CreateTagsCommand: jest.fn((input: unknown) => ({ _type: 'CreateTags', input })),
38-
DeleteTagsCommand: jest.fn((input: unknown) => ({ _type: 'DeleteTags', input })),
36+
DescribeInstancesCommand: jest.fn((input) => ({ _type: 'DescribeInstances', input })),
37+
CreateTagsCommand: jest.fn((input) => ({ _type: 'CreateTags', input })),
38+
DeleteTagsCommand: jest.fn((input) => ({ _type: 'DeleteTags', input })),
3939
}));
4040

4141
const mockSsmSend = jest.fn();
4242
jest.mock('@aws-sdk/client-ssm', () => ({
4343
SSMClient: jest.fn(() => ({ send: mockSsmSend })),
44-
SendCommandCommand: jest.fn((input: unknown) => ({ _type: 'SendCommand', input })),
45-
GetCommandInvocationCommand: jest.fn((input: unknown) => ({ _type: 'GetCommandInvocation', input })),
46-
CancelCommandCommand: jest.fn((input: unknown) => ({ _type: 'CancelCommand', input })),
44+
SendCommandCommand: jest.fn((input) => ({ _type: 'SendCommand', input })),
45+
GetCommandInvocationCommand: jest.fn((input) => ({ _type: 'GetCommandInvocation', input })),
46+
CancelCommandCommand: jest.fn((input) => ({ _type: 'CancelCommand', input })),
4747
}));
4848

4949
const mockS3Send = jest.fn();
5050
jest.mock('@aws-sdk/client-s3', () => ({
5151
S3Client: jest.fn(() => ({ send: mockS3Send })),
52-
PutObjectCommand: jest.fn((input: unknown) => ({ _type: 'PutObject', input })),
52+
PutObjectCommand: jest.fn((input) => ({ _type: 'PutObject', input })),
5353
}));
5454

5555
import { Ec2ComputeStrategy } from '../../../../src/handlers/shared/strategies/ec2-strategy';
@@ -65,7 +65,7 @@ describe('Ec2ComputeStrategy', () => {
6565
});
6666

6767
describe('startSession', () => {
68-
test('finds idle instance, tags as busy, uploads to S3, sends SSM command, returns handle', async () => {
68+
test('finds idle instance, tags as busy, verifies claim, uploads to S3, sends SSM command, returns handle', async () => {
6969
// S3 upload
7070
mockS3Send.mockResolvedValueOnce({});
7171
// DescribeInstances — return one idle instance
@@ -74,6 +74,10 @@ describe('Ec2ComputeStrategy', () => {
7474
});
7575
// CreateTags (mark busy)
7676
mockEc2Send.mockResolvedValueOnce({});
77+
// DescribeInstances — verify claim (tag matches our task-id)
78+
mockEc2Send.mockResolvedValueOnce({
79+
Reservations: [{ Instances: [{ InstanceId: INSTANCE_ID, Tags: [{ Key: 'bgagent:task-id', Value: 'TASK001' }] }] }],
80+
});
7781
// SSM SendCommand
7882
mockSsmSend.mockResolvedValueOnce({
7983
Command: { CommandId: COMMAND_ID },
@@ -98,8 +102,8 @@ describe('Ec2ComputeStrategy', () => {
98102
expect(s3Call.input.Bucket).toBe(PAYLOAD_BUCKET);
99103
expect(s3Call.input.Key).toBe('tasks/TASK001/payload.json');
100104

101-
// Verify DescribeInstances filter
102-
expect(mockEc2Send).toHaveBeenCalledTimes(2);
105+
// Verify EC2 calls: DescribeInstances (find idle), CreateTags (claim), DescribeInstances (verify)
106+
expect(mockEc2Send).toHaveBeenCalledTimes(3);
103107
const describeCall = mockEc2Send.mock.calls[0][0];
104108
expect(describeCall.input.Filters).toEqual(expect.arrayContaining([
105109
expect.objectContaining({ Name: `tag:${FLEET_TAG_KEY}`, Values: [FLEET_TAG_VALUE] }),
@@ -123,6 +127,43 @@ describe('Ec2ComputeStrategy', () => {
123127
expect(ssmCall.input.TimeoutSeconds).toBe(32400);
124128
});
125129

130+
test('tries next candidate when race is lost on first instance', async () => {
131+
const INSTANCE_ID_2 = 'i-0987654321fedcba0';
132+
// S3 upload
133+
mockS3Send.mockResolvedValueOnce({});
134+
// DescribeInstances — return two idle instances
135+
mockEc2Send.mockResolvedValueOnce({
136+
Reservations: [{ Instances: [{ InstanceId: INSTANCE_ID }, { InstanceId: INSTANCE_ID_2 }] }],
137+
});
138+
// CreateTags on first instance
139+
mockEc2Send.mockResolvedValueOnce({});
140+
// Verify first instance — another task claimed it
141+
mockEc2Send.mockResolvedValueOnce({
142+
Reservations: [{ Instances: [{ InstanceId: INSTANCE_ID, Tags: [{ Key: 'bgagent:task-id', Value: 'OTHER_TASK' }] }] }],
143+
});
144+
// CreateTags on second instance
145+
mockEc2Send.mockResolvedValueOnce({});
146+
// Verify second instance — our task-id stuck
147+
mockEc2Send.mockResolvedValueOnce({
148+
Reservations: [{ Instances: [{ InstanceId: INSTANCE_ID_2, Tags: [{ Key: 'bgagent:task-id', Value: 'TASK001' }] }] }],
149+
});
150+
// SSM SendCommand
151+
mockSsmSend.mockResolvedValueOnce({
152+
Command: { CommandId: COMMAND_ID },
153+
});
154+
155+
const strategy = new Ec2ComputeStrategy();
156+
const handle = await strategy.startSession({
157+
taskId: 'TASK001',
158+
payload: { repo_url: 'org/repo' },
159+
blueprintConfig: { compute_type: 'ec2', runtime_arn: '' },
160+
});
161+
162+
const ec2Handle = handle as Extract<typeof handle, { strategyType: 'ec2' }>;
163+
expect(ec2Handle.instanceId).toBe(INSTANCE_ID_2);
164+
expect(mockEc2Send).toHaveBeenCalledTimes(5); // describe + 2*(tag + verify)
165+
});
166+
126167
test('throws when no idle instances available', async () => {
127168
// S3 upload
128169
mockS3Send.mockResolvedValueOnce({});
@@ -148,6 +189,10 @@ describe('Ec2ComputeStrategy', () => {
148189
});
149190
// CreateTags
150191
mockEc2Send.mockResolvedValueOnce({});
192+
// DescribeInstances — verify claim
193+
mockEc2Send.mockResolvedValueOnce({
194+
Reservations: [{ Instances: [{ InstanceId: INSTANCE_ID, Tags: [{ Key: 'bgagent:task-id', Value: 'TASK001' }] }] }],
195+
});
151196
// SSM SendCommand — return no CommandId
152197
mockSsmSend.mockResolvedValueOnce({ Command: {} });
153198

@@ -226,12 +271,12 @@ describe('Ec2ComputeStrategy', () => {
226271
expect(result).toEqual({ status: 'failed', error: 'Command timed out' });
227272
});
228273

229-
test('returns failed for Cancelling status', async () => {
274+
test('returns running for Cancelling status (transient)', async () => {
230275
mockSsmSend.mockResolvedValueOnce({ Status: 'Cancelling', StatusDetails: 'Command is being cancelled' });
231276

232277
const strategy = new Ec2ComputeStrategy();
233278
const result = await strategy.pollSession(makeHandle());
234-
expect(result).toEqual({ status: 'failed', error: 'Command is being cancelled' });
279+
expect(result).toEqual({ status: 'running' });
235280
});
236281

237282
test('returns running for unknown status (default case)', async () => {

0 commit comments

Comments
 (0)