diff --git a/pkg/app/pipedv1/plugin/ecs/deployment/canary.go b/pkg/app/pipedv1/plugin/ecs/deployment/canary.go index 5ebc70c3cf..77b897e746 100644 --- a/pkg/app/pipedv1/plugin/ecs/deployment/canary.go +++ b/pkg/app/pipedv1/plugin/ecs/deployment/canary.go @@ -17,6 +17,7 @@ package deployment import ( "context" "encoding/json" + "errors" "fmt" "github.com/aws/aws-sdk-go-v2/service/ecs/types" @@ -147,3 +148,56 @@ func canaryRollout( lp.Successf("Successfully rolled out CANARY task set %s for service %s", *taskSet.TaskSetArn, *service.ServiceName) return taskSet, nil } + +func (p *ECSPlugin) executeECSCanaryCleanStage( + ctx context.Context, + input *sdk.ExecuteStageInput[ecsconfig.ECSApplicationSpec], + deployTarget *sdk.DeployTarget[ecsconfig.ECSDeployTargetConfig], +) sdk.StageStatus { + lp := input.Client.LogPersister() + + taskSetData, found, err := input.Client.GetDeploymentPluginMetadata(ctx, canaryTaskSetMetadataKey) + if err != nil { + lp.Errorf("Failed to retrieve canary task set from metadata store: %v", err) + return sdk.StageStatusFailure + } + if !found { + lp.Info("No canary task set found in metadata store, nothing to clean up") + return sdk.StageStatusSuccess + } + + var taskSet types.TaskSet + if err := json.Unmarshal([]byte(taskSetData), &taskSet); err != nil { + lp.Errorf("Failed to unmarshal canary task set from metadata store: %v", err) + return sdk.StageStatusFailure + } + + client, err := provider.DefaultRegistry().Client(deployTarget.Name, deployTarget.Config) + if err != nil { + lp.Errorf("Failed to get ECS client for deploy target %s: %v", deployTarget.Name, err) + return sdk.StageStatusFailure + } + + if err := canaryClean(ctx, lp, client, taskSet); err != nil { + lp.Errorf("Failed to clean up ECS canary task set: %v", err) + return sdk.StageStatusFailure + } + + return sdk.StageStatusSuccess +} + +// canaryClean deletes the canary task set +func canaryClean(ctx context.Context, lp sdk.StageLogPersister, client provider.Client, taskSet types.TaskSet) error { + lp.Infof("Deleting canary task set %s", *taskSet.TaskSetArn) + if err := client.DeleteTaskSet(ctx, taskSet); err != nil { + // If the task set is already gone, treat as success + var notFound *types.TaskSetNotFoundException + if errors.As(err, ¬Found) { + lp.Infof("Canary task set %s already deleted, skipping", *taskSet.TaskSetArn) + return nil + } + return fmt.Errorf("failed to delete canary task set %s: %w", *taskSet.TaskSetArn, err) + } + lp.Successf("Successfully deleted canary task set %s", *taskSet.TaskSetArn) + return nil +} diff --git a/pkg/app/pipedv1/plugin/ecs/deployment/canary_test.go b/pkg/app/pipedv1/plugin/ecs/deployment/canary_test.go index 7d3959be29..42e785dae2 100644 --- a/pkg/app/pipedv1/plugin/ecs/deployment/canary_test.go +++ b/pkg/app/pipedv1/plugin/ecs/deployment/canary_test.go @@ -25,6 +25,71 @@ import ( "github.com/stretchr/testify/require" ) +func TestCanaryClean(t *testing.T) { + t.Parallel() + + var ( + tsArn = "arn:aws:ecs:us-east-1:123456789012:task-set/my-cluster/my-service/ecs-svc:1" + taskSet = types.TaskSet{ + TaskSetArn: aws.String(tsArn), + } + ) + + testcases := []struct { + name string + taskSet types.TaskSet + client *mockECSClient + wantErr bool + wantErrMsg string + }{ + { + name: "success: canary task set is deleted", + taskSet: taskSet, + client: &mockECSClient{ + DeleteTaskSetFunc: func(_ context.Context, ts types.TaskSet) error { + assert.Equal(t, tsArn, aws.ToString(ts.TaskSetArn)) + return nil + }, + }, + }, + { + name: "success: task set already deleted (idempotent retry)", + taskSet: taskSet, + client: &mockECSClient{ + DeleteTaskSetFunc: func(_ context.Context, _ types.TaskSet) error { + return &types.TaskSetNotFoundException{} + }, + }, + }, + { + name: "fail: DeleteTaskSet error", + taskSet: taskSet, + client: &mockECSClient{ + DeleteTaskSetFunc: func(_ context.Context, _ types.TaskSet) error { + return errors.New("delete error") + }, + }, + wantErr: true, + wantErrMsg: "failed to delete canary task set", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := canaryClean(context.Background(), &fakeLogPersister{}, tc.client, tc.taskSet) + + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrMsg) + return + } + require.NoError(t, err) + }) + } +} + func TestCanaryRollout(t *testing.T) { t.Parallel() diff --git a/pkg/app/pipedv1/plugin/ecs/deployment/plugin.go b/pkg/app/pipedv1/plugin/ecs/deployment/plugin.go index 4f3b04ba11..d739c3f5ff 100644 --- a/pkg/app/pipedv1/plugin/ecs/deployment/plugin.go +++ b/pkg/app/pipedv1/plugin/ecs/deployment/plugin.go @@ -81,6 +81,10 @@ func (p *ECSPlugin) ExecuteStage( return &sdk.ExecuteStageResponse{ Status: p.executeECSCanaryRolloutStage(ctx, input, deployTargets[0]), }, nil + case StageECSCanaryClean: + return &sdk.ExecuteStageResponse{ + Status: p.executeECSCanaryCleanStage(ctx, input, deployTargets[0]), + }, nil case StageECSTrafficRouting: return &sdk.ExecuteStageResponse{ Status: p.executeECSTrafficRouting(ctx, input, deployTargets[0]),