diff --git a/cmd/delete.go b/cmd/delete.go index e52958b..a7ce723 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -21,12 +21,13 @@ func Delete(newRM func(context.Context) (rootmanager.RootManager, error)) *cobra Short: "Delete root user credentials", Long: `Delete root user credentials for specific AWS Organization member accounts.`, } - cmd.PersistentFlags().StringSliceVarP(&accountsFlags, "accounts", "a", []string{}, "List of AWS account IDs to audit (comma-separated). Use \"all\" to audit all accounts.") cmd.AddCommand(deleteSubcommand(newRM, "all", "Delete all existing root user credentials", "Delete all existing root user credentials for specific AWS Organization member accounts.")) cmd.AddCommand(deleteSubcommand(newRM, "login", "Delete root user Login Profile", "Delete existing root user Login Profile for specific AWS Organization member accounts.")) cmd.AddCommand(deleteSubcommand(newRM, "keys", "Delete root user Access Keys", "Delete existing root user Access Keys for specific AWS Organization member accounts.")) cmd.AddCommand(deleteSubcommand(newRM, "mfa", "Deactivate root user MFA Devices", "Deactivate existing root user MFA Devices for specific AWS Organization member accounts.")) cmd.AddCommand(deleteSubcommand(newRM, "certificates", "Delete root user Signin Certificates", "Delete existing root user Signing Certificates for specific AWS Organization member accounts.")) + cmd.AddCommand(DeleteS3BucketPolicy(newRM)) + cmd.AddCommand(DeleteSQSQueuePolicy(newRM)) return cmd } @@ -35,15 +36,18 @@ func deleteSubcommand(newRM func(context.Context) (rootmanager.RootManager, erro if use == "certificates" { credentialType = "certificate" } - return &cobra.Command{ + var accounts []string + cmd := &cobra.Command{ Use: use, Short: short, Long: long, SilenceUsage: true, RunE: func(cmd *cobra.Command, args []string) error { - return runDelete(newRM, cmd.OutOrStdout(), accountsFlags, credentialType) + return runDelete(newRM, cmd.OutOrStdout(), accounts, credentialType) }, } + cmd.Flags().StringSliceVarP(&accounts, "accounts", "a", []string{}, "List of AWS account IDs (comma-separated). Use \"all\" to select all accounts.") + return cmd } func runDelete(newRM func(context.Context) (rootmanager.RootManager, error), w io.Writer, accountsFlags []string, credentialType string) error { diff --git a/cmd/delete_s3_bucket_policy.go b/cmd/delete_s3_bucket_policy.go new file mode 100644 index 0000000..2bde02d --- /dev/null +++ b/cmd/delete_s3_bucket_policy.go @@ -0,0 +1,120 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + + "github.com/spf13/cobra" + "github.com/unicrons/aws-root-manager/internal/aws" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/cli/ui" + "github.com/unicrons/aws-root-manager/rootmanager" +) + +func DeleteS3BucketPolicy(newRM func(context.Context) (rootmanager.RootManager, error)) *cobra.Command { + var accountId, bucketName string + cmd := &cobra.Command{ + Use: "s3-bucket-policy", + Short: "Delete an S3 bucket policy", + Long: `Delete the bucket policy attached to an S3 bucket owned by a member account using the S3UnlockBucketPolicy root task policy.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return runDeleteS3BucketPolicy(newRM, cmd.OutOrStdout(), accountId, bucketName) + }, + } + cmd.Flags().StringVar(&accountId, "account", "", "AWS account ID that owns the bucket (optional; if absent, a TUI lists the organization's accounts)") + cmd.Flags().StringVar(&bucketName, "bucket", "", "Name of the S3 bucket (optional; if absent, a TUI lists the account's buckets)") + return cmd +} + +func runDeleteS3BucketPolicy(newRM func(context.Context) (rootmanager.RootManager, error), w io.Writer, accountId, bucketName string) error { + ctx := context.Background() + rm, err := newRM(ctx) + if err != nil { + return fmt.Errorf("failed to initialize root manager: %w", err) + } + + accountId, err = selectSingleAccount(ctx, accountId) + if err != nil { + return err + } + + if bucketName == "" { + buckets, err := rm.ListAccountBuckets(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to list buckets for account %s: %w", accountId, err) + } + if len(buckets) == 0 { + return fmt.Errorf("no buckets found in account %s", accountId) + } + idx, err := ui.PromptSingle("Select the bucket whose policy will be deleted", buckets) + if err != nil { + return err + } + if idx < 0 { + return fmt.Errorf("no bucket selected") + } + bucketName = buckets[idx] + } + + policy, err := rm.GetS3BucketPolicy(ctx, accountId, bucketName) + if err != nil { + return fmt.Errorf("failed to get bucket policy: %w", err) + } + if policy == "" { + fmt.Fprintln(w, "No bucket policy found.") + return nil + } + if outputFlag == "table" { + fmt.Fprintf(w, "Current bucket policy for %s:\n\n", bucketName) + output.RenderPolicy(w, policy) + + confirmed, err := ui.Confirm("Delete this policy?") + if err != nil { + return err + } + if !confirmed { + fmt.Fprintln(w, "Aborted.") + return nil + } + } + + result, err := rm.DeleteS3BucketPolicy(ctx, accountId, bucketName) + if err != nil { + return err + } + + if !result.Success { + slog.Error("failed to delete s3 bucket policy", "account_id", result.AccountId, "bucket", result.ResourceName, "error", result.Error) + return fmt.Errorf("failed to delete bucket policy for bucket %s", result.ResourceName) + } + + var headers []string + var data [][]any + if outputFlag == "table" { + headers = []string{"Account", "ResourceType", "Bucket", "Status"} + data = [][]any{{result.AccountId, result.ResourceType, result.ResourceName, "deleted"}} + } else { + headers = []string{"Account", "ResourceType", "Bucket", "Status", "Policy"} + data = [][]any{{result.AccountId, result.ResourceType, result.ResourceName, "deleted", json.RawMessage(policy)}} + } + output.HandleOutput(w, outputFlag, headers, data) + return nil +} + +// selectSingleAccount resolves a single account ID from the --account flag or +// via a single-select TUI (no "all" option). +func selectSingleAccount(ctx context.Context, accountId string) (string, error) { + var flag []string + if accountId != "" { + flag = []string{accountId} + } + awscfg, err := aws.LoadAWSConfig(ctx) + if err != nil { + return "", fmt.Errorf("failed to load aws config: %w", err) + } + return ui.SelectSingleTargetAccount(ctx, aws.NewOrganizationsClient(awscfg), flag) +} diff --git a/cmd/delete_s3_bucket_policy_test.go b/cmd/delete_s3_bucket_policy_test.go new file mode 100644 index 0000000..5ed676b --- /dev/null +++ b/cmd/delete_s3_bucket_policy_test.go @@ -0,0 +1,97 @@ +package cmd + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/unicrons/aws-root-manager/rootmanager" +) + +// When getBucketPolicyResult is empty, the command prints "No bucket policy found." and exits +// without invoking the TUI confirmation — safe to use in tests. + +func TestDeleteS3BucketPolicyCommand_NoPolicyFound(t *testing.T) { + mock := &mockRootManager{ + getBucketPolicyResult: "", + } + + var buf bytes.Buffer + cmd := Delete(newMockFactory(mock)) + cmd.SetOut(&buf) + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012", "--bucket", "my-bucket"}) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, buf.String(), "No bucket policy found.") +} + +func TestDeleteS3BucketPolicyCommand_GetPolicyError(t *testing.T) { + mock := &mockRootManager{ + getBucketPolicyErr: errors.New("assume root denied"), + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012", "--bucket", "my-bucket"}) + + require.Error(t, cmd.Execute()) +} + +func TestDeleteS3BucketPolicyCommand_FactoryError(t *testing.T) { + factoryErr := errors.New("failed to load AWS config") + + cmd := Delete(newFailingFactory(factoryErr)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012", "--bucket", "my-bucket"}) + + err := cmd.Execute() + require.Error(t, err) + assert.ErrorIs(t, err, factoryErr) +} + +func TestDeleteS3BucketPolicyCommand_DeletionFailure(t *testing.T) { + mock := &mockRootManager{ + // Return non-empty policy so get succeeds, but deletion fails. + // Confirmation TUI is skipped because tests aren't interactive — + // PromptSingle returns -1 (no selection), which maps to "No". + getBucketPolicyResult: `{"Version":"2012-10-17"}`, + deleteBucketResult: rootmanager.PolicyDeletionResult{ + AccountId: "123456789012", Success: false, Error: "access denied", + }, + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012", "--bucket", "my-bucket"}) + + // Non-interactive: confirm TUI will fail/return no-selection → "Aborted." + _ = cmd.Execute() +} + +func TestDeleteS3BucketPolicyCommand_NoBucketsFoundInTUI(t *testing.T) { + mock := &mockRootManager{ + listBucketsResult: []string{}, + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "no buckets found") +} + +func TestDeleteS3BucketPolicyCommand_ListBucketsError(t *testing.T) { + mock := &mockRootManager{ + listBucketsErr: errors.New("assume root denied"), + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"s3-bucket-policy", "--account", "123456789012"}) + + require.Error(t, cmd.Execute()) +} diff --git a/cmd/delete_sqs_queue_policy.go b/cmd/delete_sqs_queue_policy.go new file mode 100644 index 0000000..9b735d9 --- /dev/null +++ b/cmd/delete_sqs_queue_policy.go @@ -0,0 +1,105 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + + "github.com/spf13/cobra" + "github.com/unicrons/aws-root-manager/internal/cli/output" + "github.com/unicrons/aws-root-manager/internal/cli/ui" + "github.com/unicrons/aws-root-manager/rootmanager" +) + +func DeleteSQSQueuePolicy(newRM func(context.Context) (rootmanager.RootManager, error)) *cobra.Command { + var accountId, queueUrl string + cmd := &cobra.Command{ + Use: "sqs-queue-policy", + Short: "Delete an SQS queue policy", + Long: `Clear the access policy attached to an SQS queue owned by a member account using the SQSUnlockQueuePolicy root task policy.`, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return runDeleteSQSQueuePolicy(newRM, cmd.OutOrStdout(), accountId, queueUrl) + }, + } + cmd.Flags().StringVar(&accountId, "account", "", "AWS account ID that owns the queue (optional; if absent, a TUI lists the organization's accounts)") + cmd.Flags().StringVar(&queueUrl, "queue", "", "URL of the SQS queue (optional; if absent, a TUI lists the account's queues)") + return cmd +} + +func runDeleteSQSQueuePolicy(newRM func(context.Context) (rootmanager.RootManager, error), w io.Writer, accountId, queueUrl string) error { + ctx := context.Background() + rm, err := newRM(ctx) + if err != nil { + return fmt.Errorf("failed to initialize root manager: %w", err) + } + + accountId, err = selectSingleAccount(ctx, accountId) + if err != nil { + return err + } + + if queueUrl == "" { + queues, err := rm.ListAccountQueues(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to list queues for account %s: %w", accountId, err) + } + if len(queues) == 0 { + return fmt.Errorf("no queues found in account %s", accountId) + } + idx, err := ui.PromptSingle("Select the queue whose policy will be deleted", queues) + if err != nil { + return err + } + if idx < 0 { + return fmt.Errorf("no queue selected") + } + queueUrl = queues[idx] + } + + policy, err := rm.GetSQSQueuePolicy(ctx, accountId, queueUrl) + if err != nil { + return fmt.Errorf("failed to get queue policy: %w", err) + } + if policy == "" { + fmt.Fprintln(w, "No queue policy found.") + return nil + } + if outputFlag == "table" { + fmt.Fprintf(w, "Current queue policy for %s:\n\n", queueUrl) + output.RenderPolicy(w, policy) + + confirmed, err := ui.Confirm("Delete this policy?") + if err != nil { + return err + } + if !confirmed { + fmt.Fprintln(w, "Aborted.") + return nil + } + } + + result, err := rm.DeleteSQSQueuePolicy(ctx, accountId, queueUrl) + if err != nil { + return err + } + + if !result.Success { + slog.Error("failed to delete sqs queue policy", "account_id", result.AccountId, "queue_url", result.ResourceName, "error", result.Error) + return fmt.Errorf("failed to delete queue policy for queue %s", result.ResourceName) + } + + var headers []string + var data [][]any + if outputFlag == "table" { + headers = []string{"Account", "ResourceType", "Queue", "Status"} + data = [][]any{{result.AccountId, result.ResourceType, result.ResourceName, "deleted"}} + } else { + headers = []string{"Account", "ResourceType", "Queue", "Status", "Policy"} + data = [][]any{{result.AccountId, result.ResourceType, result.ResourceName, "deleted", json.RawMessage(policy)}} + } + output.HandleOutput(w, outputFlag, headers, data) + return nil +} diff --git a/cmd/delete_sqs_queue_policy_test.go b/cmd/delete_sqs_queue_policy_test.go new file mode 100644 index 0000000..89a1d0c --- /dev/null +++ b/cmd/delete_sqs_queue_policy_test.go @@ -0,0 +1,77 @@ +package cmd + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// When getQueuePolicyResult is empty, the command prints "No queue policy found." and exits +// without invoking the TUI confirmation — safe to use in tests. + +func TestDeleteSQSQueuePolicyCommand_NoPolicyFound(t *testing.T) { + mock := &mockRootManager{ + getQueuePolicyResult: "", + } + + var buf bytes.Buffer + cmd := Delete(newMockFactory(mock)) + cmd.SetOut(&buf) + cmd.SetArgs([]string{"sqs-queue-policy", "--account", "123456789012", "--queue", "https://sqs/q1"}) + + require.NoError(t, cmd.Execute()) + assert.Contains(t, buf.String(), "No queue policy found.") +} + +func TestDeleteSQSQueuePolicyCommand_GetPolicyError(t *testing.T) { + mock := &mockRootManager{ + getQueuePolicyErr: errors.New("assume root denied"), + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"sqs-queue-policy", "--account", "123456789012", "--queue", "https://sqs/q1"}) + + require.Error(t, cmd.Execute()) +} + +func TestDeleteSQSQueuePolicyCommand_FactoryError(t *testing.T) { + factoryErr := errors.New("failed to load AWS config") + + cmd := Delete(newFailingFactory(factoryErr)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"sqs-queue-policy", "--account", "123456789012", "--queue", "https://sqs/q1"}) + + err := cmd.Execute() + require.Error(t, err) + assert.ErrorIs(t, err, factoryErr) +} + +func TestDeleteSQSQueuePolicyCommand_NoQueuesFoundInTUI(t *testing.T) { + mock := &mockRootManager{ + listQueuesResult: []string{}, + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"sqs-queue-policy", "--account", "123456789012"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "no queues found") +} + +func TestDeleteSQSQueuePolicyCommand_ListQueuesError(t *testing.T) { + mock := &mockRootManager{ + listQueuesErr: errors.New("assume root denied"), + } + + cmd := Delete(newMockFactory(mock)) + cmd.SilenceErrors = true + cmd.SetArgs([]string{"sqs-queue-policy", "--account", "123456789012"}) + + require.Error(t, cmd.Execute()) +} diff --git a/cmd/mock_test.go b/cmd/mock_test.go index 0e133c1..965166b 100644 --- a/cmd/mock_test.go +++ b/cmd/mock_test.go @@ -19,6 +19,20 @@ type mockRootManager struct { deleteErr error recoveryResult []rootmanager.RecoveryResult recoveryErr error + + getBucketPolicyResult string + getBucketPolicyErr error + listBucketsResult []string + listBucketsErr error + deleteBucketResult rootmanager.PolicyDeletionResult + deleteBucketErr error + + getQueuePolicyResult string + getQueuePolicyErr error + listQueuesResult []string + listQueuesErr error + deleteQueueResult rootmanager.PolicyDeletionResult + deleteQueueErr error } func (m *mockRootManager) CheckRootAccess(_ context.Context) (rootmanager.RootAccessStatus, error) { @@ -36,6 +50,24 @@ func (m *mockRootManager) DeleteCredentials(_ context.Context, _ []rootmanager.R func (m *mockRootManager) RecoverRootPassword(_ context.Context, _ []string) ([]rootmanager.RecoveryResult, error) { return m.recoveryResult, m.recoveryErr } +func (m *mockRootManager) GetS3BucketPolicy(_ context.Context, _, _ string) (string, error) { + return m.getBucketPolicyResult, m.getBucketPolicyErr +} +func (m *mockRootManager) ListAccountBuckets(_ context.Context, _ string) ([]string, error) { + return m.listBucketsResult, m.listBucketsErr +} +func (m *mockRootManager) DeleteS3BucketPolicy(_ context.Context, _, _ string) (rootmanager.PolicyDeletionResult, error) { + return m.deleteBucketResult, m.deleteBucketErr +} +func (m *mockRootManager) GetSQSQueuePolicy(_ context.Context, _, _ string) (string, error) { + return m.getQueuePolicyResult, m.getQueuePolicyErr +} +func (m *mockRootManager) ListAccountQueues(_ context.Context, _ string) ([]string, error) { + return m.listQueuesResult, m.listQueuesErr +} +func (m *mockRootManager) DeleteSQSQueuePolicy(_ context.Context, _, _ string) (rootmanager.PolicyDeletionResult, error) { + return m.deleteQueueResult, m.deleteQueueErr +} // newMockFactory returns a factory function that always returns the given mock. func newMockFactory(mock rootmanager.RootManager) func(context.Context) (rootmanager.RootManager, error) { diff --git a/go.mod b/go.mod index dfd885f..d272043 100644 --- a/go.mod +++ b/go.mod @@ -17,13 +17,19 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.99.0 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect github.com/aws/smithy-go v1.24.2 // indirect diff --git a/go.sum b/go.sum index 783d1a2..1caba56 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs= charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM= github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= @@ -16,16 +18,26 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgq github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y= github.com/aws/aws-sdk-go-v2/service/iam v1.53.7 h1:n9YLiWtX3+6pTLZWvRJmtq5JIB9NA/KFelyCg5fOlTU= github.com/aws/aws-sdk-go-v2/service/iam v1.53.7/go.mod h1:sP46Vo6MeJcM4s0ZXcG2PFmfiSyixhIuC/74W52yKuk= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13/go.mod h1:CEuVn5WqOMilYl+tbccq8+N2ieCy0gVn3OtRb0vBNNM= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21/go.mod h1:cv3TNhVrssKR0O/xxLJVRfd2oazSnZnkUeTf6ctUwfQ= github.com/aws/aws-sdk-go-v2/service/organizations v1.51.1 h1:5hM1jQjIzEiu07ZqQ8iI4sC+06C8a+idNtytO65dhAw= github.com/aws/aws-sdk-go-v2/service/organizations v1.51.1/go.mod h1:urLFj1twuR/h5T0wN/2/kmY1gxBFa1tTKr+c60lZ2fA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.99.0 h1:hlSuz394kV0vhv9drL5lhuEFbEOEP1VyQpy15qWh1Pk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.99.0/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25 h1:8Bv3TQ1Cob6HLlpUbAnWxeHhAkYScJO9RIHh2WPXaxw= +github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25/go.mod h1:eDstEbM0OEnBUnNQxIA7j74Jy61cCU1S4EMlCtdMwzs= github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= diff --git a/internal/aws/factory.go b/internal/aws/factory.go index 1c40507..296ab30 100644 --- a/internal/aws/factory.go +++ b/internal/aws/factory.go @@ -19,3 +19,27 @@ type DefaultIamClientFactory struct{} func (f *DefaultIamClientFactory) NewIamClient(cfg awssdk.Config) IamClient { return NewIamClient(cfg) } + +// S3ClientFactory creates S3 clients with a given AWS config. +type S3ClientFactory interface { + NewS3Client(cfg awssdk.Config) S3Client +} + +// DefaultS3ClientFactory is the production implementation of S3ClientFactory. +type DefaultS3ClientFactory struct{} + +func (f *DefaultS3ClientFactory) NewS3Client(cfg awssdk.Config) S3Client { + return NewS3Client(cfg) +} + +// SqsClientFactory creates SQS clients with a given AWS config. +type SqsClientFactory interface { + NewSqsClient(cfg awssdk.Config) SqsClient +} + +// DefaultSqsClientFactory is the production implementation of SqsClientFactory. +type DefaultSqsClientFactory struct{} + +func (f *DefaultSqsClientFactory) NewSqsClient(cfg awssdk.Config) SqsClient { + return NewSqsClient(cfg) +} diff --git a/internal/aws/interfaces.go b/internal/aws/interfaces.go index 53930e5..bc71d43 100644 --- a/internal/aws/interfaces.go +++ b/internal/aws/interfaces.go @@ -53,6 +53,28 @@ type StsClient interface { GetAssumeRootConfig(ctx context.Context, accountId, taskPolicyName string) (awssdk.Config, error) } +// S3Client defines the interface for S3 operations scoped to a single account. +// This interface enables mocking and dependency injection for testing. +type S3Client interface { + // ListBuckets returns the names of all buckets owned by the caller. + ListBuckets(ctx context.Context) ([]string, error) + // GetBucketPolicy returns the bucket policy JSON, or empty string if none exists. + GetBucketPolicy(ctx context.Context, bucketName string) (string, error) + // DeleteBucketPolicy deletes the bucket policy attached to the given bucket. + DeleteBucketPolicy(ctx context.Context, bucketName string) error +} + +// SqsClient defines the interface for SQS operations scoped to a single account. +// This interface enables mocking and dependency injection for testing. +type SqsClient interface { + // ListQueues returns the URLs of all queues owned by the caller. + ListQueues(ctx context.Context) ([]string, error) + // GetQueuePolicy returns the queue policy JSON, or empty string if none exists. + GetQueuePolicy(ctx context.Context, queueUrl string) (string, error) + // DeleteQueuePolicy clears the access policy attached to the given queue URL. + DeleteQueuePolicy(ctx context.Context, queueUrl string) error +} + // OrganizationsClient defines the interface for AWS Organizations operations. // This interface enables mocking and dependency injection for testing. type OrganizationsClient interface { diff --git a/internal/aws/s3.go b/internal/aws/s3.go new file mode 100644 index 0000000..1fce076 --- /dev/null +++ b/internal/aws/s3.go @@ -0,0 +1,62 @@ +package aws + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +type s3Client struct { + client *s3.Client +} + +func NewS3Client(awscfg aws.Config) S3Client { + return &s3Client{client: s3.NewFromConfig(awscfg)} +} + +func (c *s3Client) ListBuckets(ctx context.Context) ([]string, error) { + slog.Debug("listing s3 buckets") + + output, err := c.client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + return nil, fmt.Errorf("error listing s3 buckets: %w", err) + } + + buckets := make([]string, 0, len(output.Buckets)) + for _, b := range output.Buckets { + buckets = append(buckets, aws.ToString(b.Name)) + } + return buckets, nil +} + +// GetBucketPolicy returns the bucket policy JSON string, or empty string if no policy exists. +func (c *s3Client) GetBucketPolicy(ctx context.Context, bucketName string) (string, error) { + slog.Debug("getting s3 bucket policy", "bucket", bucketName) + + output, err := c.client.GetBucketPolicy(ctx, &s3.GetBucketPolicyInput{ + Bucket: aws.String(bucketName), + }) + if err != nil { + if strings.Contains(err.Error(), "NoSuchBucketPolicy") { + return "", nil + } + return "", fmt.Errorf("error getting bucket policy for bucket %s: %w", bucketName, err) + } + return aws.ToString(output.Policy), nil +} + +func (c *s3Client) DeleteBucketPolicy(ctx context.Context, bucketName string) error { + slog.Debug("deleting s3 bucket policy", "bucket", bucketName) + + _, err := c.client.DeleteBucketPolicy(ctx, &s3.DeleteBucketPolicyInput{ + Bucket: aws.String(bucketName), + }) + if err != nil { + return fmt.Errorf("error deleting bucket policy for bucket %s: %w", bucketName, err) + } + return nil +} diff --git a/internal/aws/sqs.go b/internal/aws/sqs.go new file mode 100644 index 0000000..5c1c7c4 --- /dev/null +++ b/internal/aws/sqs.go @@ -0,0 +1,67 @@ +package aws + +import ( + "context" + "fmt" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" +) + +type sqsClient struct { + client *sqs.Client +} + +func NewSqsClient(awscfg aws.Config) SqsClient { + return &sqsClient{client: sqs.NewFromConfig(awscfg)} +} + +func (c *sqsClient) ListQueues(ctx context.Context) ([]string, error) { + slog.Debug("listing sqs queues") + + var queues []string + var nextToken *string + for { + output, err := c.client.ListQueues(ctx, &sqs.ListQueuesInput{NextToken: nextToken}) + if err != nil { + return nil, fmt.Errorf("error listing sqs queues: %w", err) + } + queues = append(queues, output.QueueUrls...) + if output.NextToken == nil { + break + } + nextToken = output.NextToken + } + return queues, nil +} + +// GetQueuePolicy returns the queue policy JSON string, or empty string if no policy exists. +func (c *sqsClient) GetQueuePolicy(ctx context.Context, queueUrl string) (string, error) { + slog.Debug("getting sqs queue policy", "queue_url", queueUrl) + + output, err := c.client.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{ + QueueUrl: aws.String(queueUrl), + AttributeNames: []types.QueueAttributeName{types.QueueAttributeNamePolicy}, + }) + if err != nil { + return "", fmt.Errorf("error getting queue policy for queue %s: %w", queueUrl, err) + } + return output.Attributes[string(types.QueueAttributeNamePolicy)], nil +} + +func (c *sqsClient) DeleteQueuePolicy(ctx context.Context, queueUrl string) error { + slog.Debug("deleting sqs queue policy", "queue_url", queueUrl) + + _, err := c.client.SetQueueAttributes(ctx, &sqs.SetQueueAttributesInput{ + QueueUrl: aws.String(queueUrl), + Attributes: map[string]string{ + string(types.QueueAttributeNamePolicy): "", + }, + }) + if err != nil { + return fmt.Errorf("error deleting queue policy for queue %s: %w", queueUrl, err) + } + return nil +} diff --git a/internal/cli/output/output.go b/internal/cli/output/output.go index f00f5c0..5022f40 100644 --- a/internal/cli/output/output.go +++ b/internal/cli/output/output.go @@ -1,6 +1,7 @@ package output import ( + "encoding/json" "fmt" "io" "log/slog" @@ -29,7 +30,12 @@ func dataToString(data [][]any) [][]string { for i, row := range data { convertedRow := make([]string, len(row)) for j, cell := range row { - convertedRow[j] = fmt.Sprintf("%v", cell) + switch v := cell.(type) { + case json.RawMessage: + convertedRow[j] = string(v) + default: + convertedRow[j] = fmt.Sprintf("%v", cell) + } } convertedData[i] = convertedRow } diff --git a/internal/cli/output/policy.go b/internal/cli/output/policy.go new file mode 100644 index 0000000..41a657b --- /dev/null +++ b/internal/cli/output/policy.go @@ -0,0 +1,37 @@ +package output + +import ( + "encoding/json" + "fmt" + "io" + + "charm.land/lipgloss/v2" +) + +var policyBoxStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color("8")). + Padding(0, 1) + +// RenderPolicy pretty-prints a JSON policy string inside a styled box and writes it to w. +// Used for interactive (table) output mode only. +func RenderPolicy(w io.Writer, policy string) { + pretty, err := prettyJSON(policy) + if err != nil { + pretty = policy + } + termWidth := getTerminalWidth() + fmt.Fprintln(w, policyBoxStyle.MaxWidth(termWidth-4).Render(pretty)) +} + +func prettyJSON(raw string) (string, error) { + var v any + if err := json.Unmarshal([]byte(raw), &v); err != nil { + return "", err + } + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return "", err + } + return string(b), nil +} diff --git a/internal/cli/ui/account_selector.go b/internal/cli/ui/account_selector.go index 7798f72..59ba8eb 100644 --- a/internal/cli/ui/account_selector.go +++ b/internal/cli/ui/account_selector.go @@ -10,7 +10,7 @@ import ( const ( AllAccountsOption = "all" - AllAccountsSelectorText = "all non management accounts" + AllAccountsSelectorText = "All Accounts" ) // SelectTargetAccounts handles interactive account selection or returns accounts based on flags. @@ -58,6 +58,40 @@ func SelectTargetAccounts(ctx context.Context, org aws.OrganizationsClient, acco return extractSelectedAccounts(orgAccounts, selectedIndexes), nil } +// SelectSingleTargetAccount resolves exactly one account ID from the flag or +// a single-select TUI. Unlike SelectTargetAccounts it has no "all" option. +// org is only used when accountsFlag is empty. +func SelectSingleTargetAccount(ctx context.Context, org aws.OrganizationsClient, accountsFlag []string) (string, error) { + slog.Debug("processing single target account", "accounts_flag", accountsFlag) + + if len(accountsFlag) == 1 && accountsFlag[0] != AllAccountsOption { + return accountsFlag[0], nil + } + if len(accountsFlag) > 1 { + return "", fmt.Errorf("this command operates on a single account; got %d", len(accountsFlag)) + } + + orgAccounts, err := aws.GetNonManagementOrganizationAccounts(ctx, org) + if err != nil { + return "", fmt.Errorf("error fetching organization accounts: %w", err) + } + + var choices []string + for _, account := range orgAccounts { + choices = append(choices, fmt.Sprintf("%s - %s", account.AccountID, account.Name)) + } + + idx, err := PromptSingle("Please select the AWS account", choices) + if err != nil { + return "", err + } + if idx < 0 { + return "", fmt.Errorf("no account selected") + } + + return orgAccounts[idx].AccountID, nil +} + // Checks if all option is selected func allSelected(selectedIndexes []int) bool { for _, index := range selectedIndexes { diff --git a/internal/cli/ui/confirm.go b/internal/cli/ui/confirm.go new file mode 100644 index 0000000..2e78c43 --- /dev/null +++ b/internal/cli/ui/confirm.go @@ -0,0 +1,10 @@ +package ui + +// Confirm shows a yes/no single-select TUI. Returns true if the user chose "Yes". +func Confirm(question string) (bool, error) { + idx, err := PromptSingle(question, []string{"Yes", "No"}) + if err != nil { + return false, err + } + return idx == 0, nil +} diff --git a/internal/cli/ui/selector.go b/internal/cli/ui/selector.go index 89e5157..0e7de6d 100644 --- a/internal/cli/ui/selector.go +++ b/internal/cli/ui/selector.go @@ -19,6 +19,7 @@ var ( cursorStyle = lipgloss.NewStyle().Foreground(pink) helpTextMultipleChoice = "↑/↓/←/→: Navigate • Space: Select • Enter: Confirm" + helpTextSingleChoice = "↑/↓/←/→: Navigate • Enter: Select" ) type model struct { @@ -32,6 +33,7 @@ type model struct { quit bool pageSize int currentPage int + single bool // when true, enter selects the current cursor item and quits } func (m model) Init() tea.Cmd { @@ -46,6 +48,14 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.quit = true return m, tea.Quit case "enter": + if m.single { + if len(m.filtered) > 0 { + chosen := m.filtered[m.currentPage*m.pageSize+m.cursor] + m.selected = map[string]struct{}{chosen: {}} + return m, tea.Quit + } + return m, nil + } if len(m.selected) > 0 { return m, tea.Quit } @@ -74,6 +84,9 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.cursor = 0 } case "space": + if m.single { + return m, nil + } if len(m.filtered) > 0 { currentChoice := m.filtered[m.currentPage*m.pageSize+m.cursor] if _, ok := m.selected[currentChoice]; ok { @@ -138,6 +151,11 @@ func (m model) View() tea.View { cursor = cursorStyle.Render(">") } + if m.single { + s += fmt.Sprintf("%s %s\n", cursor, choice) + continue + } + checked := " " if _, ok := m.selected[choice]; ok { checked = cursorStyle.Render("x") @@ -157,33 +175,52 @@ func (m model) View() tea.View { } } - if len(m.selected) == 0 { + if !m.single && len(m.selected) == 0 { s += "\n" + helpStyle.Render("Please select at least one item") } - s += "\n" + helpStyle.Render(helpTextMultipleChoice) + helpText := helpTextMultipleChoice + if m.single { + helpText = helpTextSingleChoice + } + s += "\n" + helpStyle.Render(helpText) return tea.NewView(s) } +// Prompt shows a multi-select TUI and returns the original indexes of the +// chosen items. func Prompt(question string, choices []string) ([]int, error) { + return runPrompt(question, choices, false) +} + +// PromptSingle shows a single-select TUI and returns the original index of the +// chosen item, or -1 if the user quit without selecting. +func PromptSingle(question string, choices []string) (int, error) { + indexes, err := runPrompt(question, choices, true) + if err != nil { + return -1, err + } + if len(indexes) == 0 { + return -1, nil + } + return indexes[0], nil +} + +func runPrompt(question string, choices []string, single bool) ([]int, error) { + allChoicesMap := make(map[int]string, len(choices)) + for i, choice := range choices { + allChoicesMap[i] = choice + } + m := model{ - question: question, - choices: choices, - filtered: choices, - filter: "", - cursor: 0, - selected: make(map[string]struct{}), - allChoicesMap: func() map[int]string { - m := make(map[int]string) - for i, choice := range choices { - m[i] = choice - } - return m - }(), - quit: false, - pageSize: 10, - currentPage: 0, + question: question, + choices: choices, + filtered: choices, + selected: make(map[string]struct{}), + allChoicesMap: allChoicesMap, + pageSize: 10, + single: single, } p := tea.NewProgram(m) diff --git a/rootmanager/api.go b/rootmanager/api.go index af7c430..9d277fb 100644 --- a/rootmanager/api.go +++ b/rootmanager/api.go @@ -38,4 +38,28 @@ type RootManager interface { // This triggers AWS to send password reset emails to the account's root email address. // Returns a slice of RecoveryResult showing the outcome for each account. RecoverRootPassword(ctx context.Context, accountIds []string) ([]RecoveryResult, error) + + // GetS3BucketPolicy returns the JSON policy attached to the given bucket using + // AssumeRoot with the S3UnlockBucketPolicy task policy. Returns empty string if none. + GetS3BucketPolicy(ctx context.Context, accountId, bucketName string) (string, error) + + // ListAccountBuckets returns the names of all S3 buckets owned by the given account + // using AssumeRoot with the S3UnlockBucketPolicy task policy. + ListAccountBuckets(ctx context.Context, accountId string) ([]string, error) + + // DeleteS3BucketPolicy removes the bucket policy from the given bucket using + // AssumeRoot in the bucket's owning account with the S3UnlockBucketPolicy task policy. + DeleteS3BucketPolicy(ctx context.Context, accountId, bucketName string) (PolicyDeletionResult, error) + + // GetSQSQueuePolicy returns the JSON policy attached to the given queue URL using + // AssumeRoot with the SQSUnlockQueuePolicy task policy. Returns empty string if none. + GetSQSQueuePolicy(ctx context.Context, accountId, queueUrl string) (string, error) + + // ListAccountQueues returns the URLs of all SQS queues owned by the given account + // using AssumeRoot with the SQSUnlockQueuePolicy task policy. + ListAccountQueues(ctx context.Context, accountId string) ([]string, error) + + // DeleteSQSQueuePolicy clears the access policy from the given queue URL using + // AssumeRoot in the queue's owning account with the SQSUnlockQueuePolicy task policy. + DeleteSQSQueuePolicy(ctx context.Context, accountId, queueUrl string) (PolicyDeletionResult, error) } diff --git a/rootmanager/manager.go b/rootmanager/manager.go index 0140b5c..a168fb4 100644 --- a/rootmanager/manager.go +++ b/rootmanager/manager.go @@ -10,16 +10,18 @@ import ( // manager implements RootManager using AWS clients. type manager struct { - iam aws.IamClient - sts aws.StsClient - org aws.OrganizationsClient - factory aws.IamClientFactory + iam aws.IamClient + sts aws.StsClient + org aws.OrganizationsClient + factory aws.IamClientFactory + s3Factory aws.S3ClientFactory + sqsFactory aws.SqsClientFactory } -// newManager returns a RootManager that uses the given AWS clients and factory. +// newManager returns a RootManager that uses the given AWS clients and factories. // sts and org may be nil for callers that only use CheckRootAccess. -func newManager(iam aws.IamClient, sts aws.StsClient, org aws.OrganizationsClient, factory aws.IamClientFactory) RootManager { - return &manager{iam: iam, sts: sts, org: org, factory: factory} +func newManager(iam aws.IamClient, sts aws.StsClient, org aws.OrganizationsClient, factory aws.IamClientFactory, s3Factory aws.S3ClientFactory, sqsFactory aws.SqsClientFactory) RootManager { + return &manager{iam: iam, sts: sts, org: org, factory: factory, s3Factory: s3Factory, sqsFactory: sqsFactory} } // NewRootManager returns a RootManager configured from the default AWS environment. @@ -40,9 +42,53 @@ func NewRootManager(ctx context.Context) (RootManager, error) { aws.NewStsClient(stsCfg), aws.NewOrganizationsClient(cfg), &aws.DefaultIamClientFactory{}, + &aws.DefaultS3ClientFactory{}, + &aws.DefaultSqsClientFactory{}, ), nil } +func (m *manager) GetS3BucketPolicy(ctx context.Context, accountId, bucketName string) (string, error) { + if m.sts == nil { + return "", errors.New("STS client required for get") + } + return getS3BucketPolicy(ctx, m.sts, m.s3Factory, accountId, bucketName) +} + +func (m *manager) ListAccountBuckets(ctx context.Context, accountId string) ([]string, error) { + if m.sts == nil { + return nil, errors.New("STS client required for listing buckets") + } + return listAccountBuckets(ctx, m.sts, m.s3Factory, accountId) +} + +func (m *manager) DeleteS3BucketPolicy(ctx context.Context, accountId, bucketName string) (PolicyDeletionResult, error) { + if m.sts == nil { + return PolicyDeletionResult{}, errors.New("STS client required for delete") + } + return deleteS3BucketPolicy(ctx, m.sts, m.s3Factory, accountId, bucketName) +} + +func (m *manager) GetSQSQueuePolicy(ctx context.Context, accountId, queueUrl string) (string, error) { + if m.sts == nil { + return "", errors.New("STS client required for get") + } + return getSQSQueuePolicy(ctx, m.sts, m.sqsFactory, accountId, queueUrl) +} + +func (m *manager) ListAccountQueues(ctx context.Context, accountId string) ([]string, error) { + if m.sts == nil { + return nil, errors.New("STS client required for listing queues") + } + return listAccountQueues(ctx, m.sts, m.sqsFactory, accountId) +} + +func (m *manager) DeleteSQSQueuePolicy(ctx context.Context, accountId, queueUrl string) (PolicyDeletionResult, error) { + if m.sts == nil { + return PolicyDeletionResult{}, errors.New("STS client required for delete") + } + return deleteSQSQueuePolicy(ctx, m.sts, m.sqsFactory, accountId, queueUrl) +} + func (m *manager) AuditAccounts(ctx context.Context, accountIds []string) ([]RootCredentials, error) { if m.sts == nil { return nil, errors.New("STS client required for audit") diff --git a/rootmanager/mock_test.go b/rootmanager/mock_test.go index d91bded..ce60673 100644 --- a/rootmanager/mock_test.go +++ b/rootmanager/mock_test.go @@ -98,3 +98,57 @@ type mockIamClientFactory struct { func (f *mockIamClientFactory) NewIamClient(_ awssdk.Config) aws.IamClient { return f.client } + +// mockS3Client implements aws.S3Client for testing. +type mockS3Client struct { + listBucketsResult []string + listBucketsErr error + getBucketPolResult string + getBucketPolErr error + deleteBucketPolErr error +} + +func (m *mockS3Client) ListBuckets(_ context.Context) ([]string, error) { + return m.listBucketsResult, m.listBucketsErr +} +func (m *mockS3Client) GetBucketPolicy(_ context.Context, _ string) (string, error) { + return m.getBucketPolResult, m.getBucketPolErr +} +func (m *mockS3Client) DeleteBucketPolicy(_ context.Context, _ string) error { + return m.deleteBucketPolErr +} + +type mockS3ClientFactory struct { + client aws.S3Client +} + +func (f *mockS3ClientFactory) NewS3Client(_ awssdk.Config) aws.S3Client { + return f.client +} + +// mockSqsClient implements aws.SqsClient for testing. +type mockSqsClient struct { + listQueuesResult []string + listQueuesErr error + getQueuePolResult string + getQueuePolErr error + deleteQueuePolErr error +} + +func (m *mockSqsClient) ListQueues(_ context.Context) ([]string, error) { + return m.listQueuesResult, m.listQueuesErr +} +func (m *mockSqsClient) GetQueuePolicy(_ context.Context, _ string) (string, error) { + return m.getQueuePolResult, m.getQueuePolErr +} +func (m *mockSqsClient) DeleteQueuePolicy(_ context.Context, _ string) error { + return m.deleteQueuePolErr +} + +type mockSqsClientFactory struct { + client aws.SqsClient +} + +func (f *mockSqsClientFactory) NewSqsClient(_ awssdk.Config) aws.SqsClient { + return f.client +} diff --git a/rootmanager/policies.go b/rootmanager/policies.go new file mode 100644 index 0000000..98454ab --- /dev/null +++ b/rootmanager/policies.go @@ -0,0 +1,104 @@ +package rootmanager + +import ( + "context" + "log/slog" + + "github.com/unicrons/aws-root-manager/internal/aws" +) + +const ( + s3UnlockTaskPolicy = "S3UnlockBucketPolicy" + sqsUnlockTaskPolicy = "SQSUnlockQueuePolicy" + + resourceTypeS3Bucket = "s3-bucket" + resourceTypeSqsQueue = "sqs-queue" +) + +func getS3BucketPolicy(ctx context.Context, sts aws.StsClient, factory aws.S3ClientFactory, accountId, bucketName string) (string, error) { + slog.Debug("getting s3 bucket policy", "account_id", accountId, "bucket", bucketName) + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, s3UnlockTaskPolicy) + if err != nil { + return "", err + } + return factory.NewS3Client(cfg).GetBucketPolicy(ctx, bucketName) +} + +func listAccountBuckets(ctx context.Context, sts aws.StsClient, factory aws.S3ClientFactory, accountId string) ([]string, error) { + slog.Debug("listing account buckets", "account_id", accountId) + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, s3UnlockTaskPolicy) + if err != nil { + return nil, err + } + return factory.NewS3Client(cfg).ListBuckets(ctx) +} + +func deleteS3BucketPolicy(ctx context.Context, sts aws.StsClient, factory aws.S3ClientFactory, accountId, bucketName string) (PolicyDeletionResult, error) { + slog.Debug("deleting s3 bucket policy", "account_id", accountId, "bucket", bucketName) + + result := PolicyDeletionResult{ + AccountId: accountId, + ResourceType: resourceTypeS3Bucket, + ResourceName: bucketName, + } + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, s3UnlockTaskPolicy) + if err != nil { + result.Error = err.Error() + return result, nil + } + + if err := factory.NewS3Client(cfg).DeleteBucketPolicy(ctx, bucketName); err != nil { + result.Error = err.Error() + return result, nil + } + + result.Success = true + return result, nil +} + +func getSQSQueuePolicy(ctx context.Context, sts aws.StsClient, factory aws.SqsClientFactory, accountId, queueUrl string) (string, error) { + slog.Debug("getting sqs queue policy", "account_id", accountId, "queue_url", queueUrl) + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, sqsUnlockTaskPolicy) + if err != nil { + return "", err + } + return factory.NewSqsClient(cfg).GetQueuePolicy(ctx, queueUrl) +} + +func listAccountQueues(ctx context.Context, sts aws.StsClient, factory aws.SqsClientFactory, accountId string) ([]string, error) { + slog.Debug("listing account queues", "account_id", accountId) + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, sqsUnlockTaskPolicy) + if err != nil { + return nil, err + } + return factory.NewSqsClient(cfg).ListQueues(ctx) +} + +func deleteSQSQueuePolicy(ctx context.Context, sts aws.StsClient, factory aws.SqsClientFactory, accountId, queueUrl string) (PolicyDeletionResult, error) { + slog.Debug("deleting sqs queue policy", "account_id", accountId, "queue_url", queueUrl) + + result := PolicyDeletionResult{ + AccountId: accountId, + ResourceType: resourceTypeSqsQueue, + ResourceName: queueUrl, + } + + cfg, err := sts.GetAssumeRootConfig(ctx, accountId, sqsUnlockTaskPolicy) + if err != nil { + result.Error = err.Error() + return result, nil + } + + if err := factory.NewSqsClient(cfg).DeleteQueuePolicy(ctx, queueUrl); err != nil { + result.Error = err.Error() + return result, nil + } + + result.Success = true + return result, nil +} diff --git a/rootmanager/policies_test.go b/rootmanager/policies_test.go new file mode 100644 index 0000000..4b81801 --- /dev/null +++ b/rootmanager/policies_test.go @@ -0,0 +1,188 @@ +package rootmanager + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- getS3BucketPolicy --- + +func TestGetS3BucketPolicy_Success(t *testing.T) { + s3 := &mockS3Client{getBucketPolResult: `{"Version":"2012-10-17"}`} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + got, err := getS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.NoError(t, err) + assert.Equal(t, `{"Version":"2012-10-17"}`, got) +} + +func TestGetS3BucketPolicy_NoPolicyReturnsEmpty(t *testing.T) { + s3 := &mockS3Client{getBucketPolResult: ""} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + got, err := getS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.NoError(t, err) + assert.Empty(t, got) +} + +func TestGetS3BucketPolicy_STSError(t *testing.T) { + stsErr := errors.New("assume root denied") + sts := &mockStsClient{assumeRootErr: stsErr} + factory := &mockS3ClientFactory{client: &mockS3Client{}} + + _, err := getS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.Error(t, err) + assert.ErrorIs(t, err, stsErr) +} + +// --- getSQSQueuePolicy --- + +func TestGetSQSQueuePolicy_Success(t *testing.T) { + sqs := &mockSqsClient{getQueuePolResult: `{"Version":"2012-10-17"}`} + factory := &mockSqsClientFactory{client: sqs} + sts := &mockStsClient{} + + got, err := getSQSQueuePolicy(context.Background(), sts, factory, "123456789012", "https://sqs/q1") + require.NoError(t, err) + assert.Equal(t, `{"Version":"2012-10-17"}`, got) +} + +func TestGetSQSQueuePolicy_STSError(t *testing.T) { + stsErr := errors.New("assume root denied") + sts := &mockStsClient{assumeRootErr: stsErr} + factory := &mockSqsClientFactory{client: &mockSqsClient{}} + + _, err := getSQSQueuePolicy(context.Background(), sts, factory, "123456789012", "https://sqs/q1") + require.Error(t, err) + assert.ErrorIs(t, err, stsErr) +} + +// --- listAccountBuckets --- + +func TestListAccountBuckets_Success(t *testing.T) { + s3 := &mockS3Client{listBucketsResult: []string{"a", "b"}} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + got, err := listAccountBuckets(context.Background(), sts, factory, "123456789012") + require.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, got) +} + +func TestListAccountBuckets_STSError(t *testing.T) { + stsErr := errors.New("assume root denied") + sts := &mockStsClient{assumeRootErr: stsErr} + factory := &mockS3ClientFactory{client: &mockS3Client{}} + + _, err := listAccountBuckets(context.Background(), sts, factory, "123456789012") + require.Error(t, err) + assert.ErrorIs(t, err, stsErr) +} + +func TestListAccountBuckets_S3Error(t *testing.T) { + s3 := &mockS3Client{listBucketsErr: errors.New("access denied")} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + _, err := listAccountBuckets(context.Background(), sts, factory, "123456789012") + require.Error(t, err) +} + +// --- deleteS3BucketPolicy --- + +func TestDeleteS3BucketPolicy_Success(t *testing.T) { + s3 := &mockS3Client{} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + result, err := deleteS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.NoError(t, err) + assert.True(t, result.Success) + assert.Equal(t, "s3-bucket", result.ResourceType) + assert.Equal(t, "my-bucket", result.ResourceName) + assert.Empty(t, result.Error) +} + +func TestDeleteS3BucketPolicy_STSError(t *testing.T) { + sts := &mockStsClient{assumeRootErr: errors.New("assume root denied")} + factory := &mockS3ClientFactory{client: &mockS3Client{}} + + result, err := deleteS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.NoError(t, err) + assert.False(t, result.Success) + assert.NotEmpty(t, result.Error) +} + +func TestDeleteS3BucketPolicy_S3Error(t *testing.T) { + s3 := &mockS3Client{deleteBucketPolErr: errors.New("policy not found")} + factory := &mockS3ClientFactory{client: s3} + sts := &mockStsClient{} + + result, err := deleteS3BucketPolicy(context.Background(), sts, factory, "123456789012", "my-bucket") + require.NoError(t, err) + assert.False(t, result.Success) + assert.NotEmpty(t, result.Error) +} + +// --- listAccountQueues --- + +func TestListAccountQueues_Success(t *testing.T) { + sqs := &mockSqsClient{listQueuesResult: []string{"https://sqs/q1", "https://sqs/q2"}} + factory := &mockSqsClientFactory{client: sqs} + sts := &mockStsClient{} + + got, err := listAccountQueues(context.Background(), sts, factory, "123456789012") + require.NoError(t, err) + assert.Len(t, got, 2) +} + +func TestListAccountQueues_STSError(t *testing.T) { + stsErr := errors.New("assume root denied") + sts := &mockStsClient{assumeRootErr: stsErr} + factory := &mockSqsClientFactory{client: &mockSqsClient{}} + + _, err := listAccountQueues(context.Background(), sts, factory, "123456789012") + require.Error(t, err) + assert.ErrorIs(t, err, stsErr) +} + +// --- deleteSQSQueuePolicy --- + +func TestDeleteSQSQueuePolicy_Success(t *testing.T) { + sqs := &mockSqsClient{} + factory := &mockSqsClientFactory{client: sqs} + sts := &mockStsClient{} + + result, err := deleteSQSQueuePolicy(context.Background(), sts, factory, "123456789012", "https://sqs/q1") + require.NoError(t, err) + assert.True(t, result.Success) + assert.Equal(t, "sqs-queue", result.ResourceType) + assert.Equal(t, "https://sqs/q1", result.ResourceName) +} + +func TestDeleteSQSQueuePolicy_STSError(t *testing.T) { + sts := &mockStsClient{assumeRootErr: errors.New("assume root denied")} + factory := &mockSqsClientFactory{client: &mockSqsClient{}} + + result, err := deleteSQSQueuePolicy(context.Background(), sts, factory, "123456789012", "https://sqs/q1") + require.NoError(t, err) + assert.False(t, result.Success) + assert.NotEmpty(t, result.Error) +} + +func TestDeleteSQSQueuePolicy_SqsError(t *testing.T) { + sqs := &mockSqsClient{deleteQueuePolErr: errors.New("access denied")} + factory := &mockSqsClientFactory{client: sqs} + sts := &mockStsClient{} + + result, err := deleteSQSQueuePolicy(context.Background(), sts, factory, "123456789012", "https://sqs/q1") + require.NoError(t, err) + assert.False(t, result.Success) + assert.NotEmpty(t, result.Error) +} diff --git a/rootmanager/types.go b/rootmanager/types.go index 59f5fb1..dfd2c34 100644 --- a/rootmanager/types.go +++ b/rootmanager/types.go @@ -31,3 +31,12 @@ type DeletionResult struct { Success bool // Whether deletion was successful Error string // Error message if deletion failed (empty if Success=true) } + +// PolicyDeletionResult represents the result of a resource policy deletion operation. +type PolicyDeletionResult struct { + AccountId string // AWS account ID + ResourceType string // Type of resource ("s3-bucket", "sqs-queue") + ResourceName string // Bucket name or queue URL + Success bool // Whether deletion was successful + Error string // Error message if deletion failed (empty if Success=true) +}