From 8e2b5dad186f82fd7dc973a9b53699cec330536b Mon Sep 17 00:00:00 2001 From: dhernando Date: Mon, 16 Mar 2026 14:58:09 +0100 Subject: [PATCH] chore: refactor grpc tests to use a better mocking api --- internal/cli/config_load_test.go | 49 ++--- internal/cmd/backup/create_test.go | 55 ++--- internal/cmd/backup/delete_test.go | 20 +- internal/cmd/backup/describe_test.go | 43 ++-- internal/cmd/backup/list_test.go | 66 +++--- internal/cmd/cluster/cloud_provider_test.go | 23 +- internal/cmd/cluster/cloud_region_test.go | 26 +-- internal/cmd/cluster/completion_test.go | 97 ++++----- internal/cmd/cluster/create_test.go | 230 ++++++++------------ internal/cmd/cluster/key_create_test.go | 79 +++---- internal/cmd/cluster/key_delete_test.go | 18 +- internal/cmd/cluster/key_list_test.go | 62 +++--- internal/cmd/cluster/list_test.go | 198 +++++++---------- internal/cmd/cluster/package_test.go | 51 ++--- internal/cmd/cluster/restart_test.go | 128 +++++------ internal/cmd/cluster/suspend_test.go | 16 +- internal/cmd/cluster/unsuspend_test.go | 16 +- internal/cmd/cluster/update_test.go | 49 ++--- internal/cmd/cluster/version_test.go | 20 +- internal/cmd/cluster/wait_test.go | 73 +++---- internal/testutil/fake_backup.go | 69 +----- internal/testutil/fake_booking.go | 9 +- internal/testutil/fake_cluster.go | 63 +----- internal/testutil/fake_database_api_key.go | 21 +- internal/testutil/fake_platform.go | 15 +- 25 files changed, 586 insertions(+), 910 deletions(-) diff --git a/internal/cli/config_load_test.go b/internal/cli/config_load_test.go index d637fed..f129bc1 100644 --- a/internal/cli/config_load_test.go +++ b/internal/cli/config_load_test.go @@ -1,7 +1,6 @@ package cli_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -16,28 +15,25 @@ import ( // loaded via --config reaches the gRPC request. func TestConfigLoad_FlagSetsAccountID(t *testing.T) { env := testutil.NewBareTestEnv(t) - t.Cleanup(env.Cleanup) cfgPath := testutil.WriteConfigFile(t, t.TempDir(), map[string]any{ "account_id": "account-from-file", }) - var capturedAccountID string - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedAccountID = req.GetAccountId() - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "--config", cfgPath, "cluster", "list") require.NoError(t, err) - assert.Equal(t, "account-from-file", capturedAccountID) + + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + assert.Equal(t, "account-from-file", req.GetAccountId()) } // TestConfigLoad_EnvVarSetsAccountID verifies that QDRANT_CLOUD_CONFIG env var // is respected when no --config flag is given. func TestConfigLoad_EnvVarSetsAccountID(t *testing.T) { env := testutil.NewBareTestEnv(t) - t.Cleanup(env.Cleanup) cfgPath := testutil.WriteConfigFile(t, t.TempDir(), map[string]any{ "account_id": "account-from-envvar", @@ -45,22 +41,20 @@ func TestConfigLoad_EnvVarSetsAccountID(t *testing.T) { t.Setenv("QDRANT_CLOUD_CONFIG", cfgPath) - var capturedAccountID string - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedAccountID = req.GetAccountId() - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) - assert.Equal(t, "account-from-envvar", capturedAccountID) + + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + assert.Equal(t, "account-from-envvar", req.GetAccountId()) } // TestConfigLoad_FlagOverridesEnvVar verifies that --config flag takes // precedence over QDRANT_CLOUD_CONFIG env var. func TestConfigLoad_FlagOverridesEnvVar(t *testing.T) { env := testutil.NewBareTestEnv(t) - t.Cleanup(env.Cleanup) dir := t.TempDir() flagCfg := testutil.WriteConfigFile(t, dir, map[string]any{ @@ -73,34 +67,31 @@ func TestConfigLoad_FlagOverridesEnvVar(t *testing.T) { t.Setenv("QDRANT_CLOUD_CONFIG", envCfg) - var capturedAccountID string - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedAccountID = req.GetAccountId() - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "--config", flagCfg, "cluster", "list") require.NoError(t, err) - assert.Equal(t, "account-from-flag", capturedAccountID) + + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + assert.Equal(t, "account-from-flag", req.GetAccountId()) } // TestConfigLoad_WithAccountIDTakesPrecedence verifies that WithAccountID (Set) // takes precedence over a config file loaded via --config (Set > config file). func TestConfigLoad_WithAccountIDTakesPrecedence(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("explicit-id")) - t.Cleanup(env.Cleanup) cfgPath := testutil.WriteConfigFile(t, t.TempDir(), map[string]any{ "account_id": "account-from-file", }) - var capturedAccountID string - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedAccountID = req.GetAccountId() - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "--config", cfgPath, "cluster", "list") require.NoError(t, err) - assert.Equal(t, "explicit-id", capturedAccountID) + + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + assert.Equal(t, "explicit-id", req.GetAccountId()) } diff --git a/internal/cmd/backup/create_test.go b/internal/cmd/backup/create_test.go index 55ac0c6..8424673 100644 --- a/internal/cmd/backup/create_test.go +++ b/internal/cmd/backup/create_test.go @@ -1,7 +1,6 @@ package backup_test import ( - "context" "encoding/json" "testing" @@ -15,50 +14,47 @@ import ( func TestBackupCreate_Success(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.BackupServer.CreateBackupFunc = func(_ context.Context, req *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) { - assert.Equal(t, "test-account-id", req.GetBackup().GetAccountId()) - assert.Equal(t, "cluster-abc", req.GetBackup().GetClusterId()) - return &backupv1.CreateBackupResponse{ - Backup: &backupv1.Backup{Id: "backup-new", ClusterId: "cluster-abc"}, - }, nil - } + + env.BackupServer.CreateBackupCalls.Returns(&backupv1.CreateBackupResponse{ + Backup: &backupv1.Backup{Id: "backup-new", ClusterId: "cluster-abc"}, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "create", "--cluster-id=cluster-abc", "--retention-days=7") require.NoError(t, err) assert.Contains(t, stdout, "backup-new") assert.Contains(t, stdout, "cluster-abc") + + req, ok := env.BackupServer.CreateBackupCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetBackup().GetAccountId()) + assert.Equal(t, "cluster-abc", req.GetBackup().GetClusterId()) } func TestBackupCreate_WithRetention(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedRetention int64 - env.BackupServer.CreateBackupFunc = func(_ context.Context, req *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) { - if req.GetBackup().GetRetentionPeriod() != nil { - capturedRetention = int64(req.GetBackup().GetRetentionPeriod().AsDuration().Hours()) / 24 - } - return &backupv1.CreateBackupResponse{ - Backup: &backupv1.Backup{Id: "backup-ret", ClusterId: "cluster-abc"}, - }, nil - } + env.BackupServer.CreateBackupCalls.Returns(&backupv1.CreateBackupResponse{ + Backup: &backupv1.Backup{Id: "backup-ret", ClusterId: "cluster-abc"}, + }, nil) _, _, err := testutil.Exec(t, env, "backup", "create", "--cluster-id=cluster-abc", "--retention-days=7") require.NoError(t, err) + + req, ok := env.BackupServer.CreateBackupCalls.Last() + require.True(t, ok) + var capturedRetention int64 + if req.GetBackup().GetRetentionPeriod() != nil { + capturedRetention = int64(req.GetBackup().GetRetentionPeriod().AsDuration().Hours()) / 24 + } assert.Equal(t, int64(7), capturedRetention) } func TestBackupCreate_JSONOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.CreateBackupFunc = func(_ context.Context, _ *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) { - return &backupv1.CreateBackupResponse{ - Backup: &backupv1.Backup{Id: "backup-json", ClusterId: "cluster-123"}, - }, nil - } + env.BackupServer.CreateBackupCalls.Returns(&backupv1.CreateBackupResponse{ + Backup: &backupv1.Backup{Id: "backup-json", ClusterId: "cluster-123"}, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "create", "--cluster-id=cluster-123", "--retention-days=7", "--json") require.NoError(t, err) @@ -73,7 +69,6 @@ func TestBackupCreate_JSONOutput(t *testing.T) { func TestBackupCreate_InvalidRetention(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "backup", "create", "--cluster-id=cluster-abc", "--retention-days=0") require.Error(t, err) @@ -81,7 +76,6 @@ func TestBackupCreate_InvalidRetention(t *testing.T) { func TestBackupCreate_MissingClusterID(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "backup", "create") require.Error(t, err) @@ -89,11 +83,8 @@ func TestBackupCreate_MissingClusterID(t *testing.T) { func TestBackupCreate_APIError(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.CreateBackupFunc = func(_ context.Context, _ *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) { - return nil, assert.AnError - } + env.BackupServer.CreateBackupCalls.Returns(nil, assert.AnError) _, _, err := testutil.Exec(t, env, "backup", "create", "--cluster-id=cluster-abc", "--retention-days=7") require.Error(t, err) diff --git a/internal/cmd/backup/delete_test.go b/internal/cmd/backup/delete_test.go index dc24bcb..593792a 100644 --- a/internal/cmd/backup/delete_test.go +++ b/internal/cmd/backup/delete_test.go @@ -1,7 +1,6 @@ package backup_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,29 +13,24 @@ import ( func TestBackupDelete_WithForce(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedBackupID string - env.BackupServer.DeleteBackupFunc = func(_ context.Context, req *backupv1.DeleteBackupRequest) (*backupv1.DeleteBackupResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - capturedBackupID = req.GetBackupId() - return &backupv1.DeleteBackupResponse{}, nil - } + env.BackupServer.DeleteBackupCalls.Returns(&backupv1.DeleteBackupResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "backup", "delete", "backup-abc", "--force") require.NoError(t, err) - assert.Equal(t, "backup-abc", capturedBackupID) + + req, ok := env.BackupServer.DeleteBackupCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "backup-abc", req.GetBackupId()) assert.Contains(t, stdout, "backup-abc") assert.Contains(t, stdout, "deleted") } func TestBackupDelete_APIError(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.DeleteBackupFunc = func(_ context.Context, _ *backupv1.DeleteBackupRequest) (*backupv1.DeleteBackupResponse, error) { - return nil, assert.AnError - } + env.BackupServer.DeleteBackupCalls.Returns(nil, assert.AnError) _, _, err := testutil.Exec(t, env, "backup", "delete", "backup-abc", "--force") require.Error(t, err) diff --git a/internal/cmd/backup/describe_test.go b/internal/cmd/backup/describe_test.go index 5982981..3210767 100644 --- a/internal/cmd/backup/describe_test.go +++ b/internal/cmd/backup/describe_test.go @@ -1,7 +1,6 @@ package backup_test import ( - "context" "encoding/json" "testing" @@ -16,21 +15,16 @@ import ( func TestBackupDescribe_TextOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.GetBackupFunc = func(_ context.Context, req *backupv1.GetBackupRequest) (*backupv1.GetBackupResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - assert.Equal(t, "backup-abc", req.GetBackupId()) - return &backupv1.GetBackupResponse{ - Backup: &backupv1.Backup{ - Id: "backup-abc", - Name: "my-backup", - ClusterId: "cluster-123", - Status: backupv1.BackupStatus_BACKUP_STATUS_SUCCEEDED, - CreatedAt: timestamppb.Now(), - }, - }, nil - } + env.BackupServer.GetBackupCalls.Returns(&backupv1.GetBackupResponse{ + Backup: &backupv1.Backup{ + Id: "backup-abc", + Name: "my-backup", + ClusterId: "cluster-123", + Status: backupv1.BackupStatus_BACKUP_STATUS_SUCCEEDED, + CreatedAt: timestamppb.Now(), + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "describe", "backup-abc") require.NoError(t, err) @@ -38,17 +32,19 @@ func TestBackupDescribe_TextOutput(t *testing.T) { assert.Contains(t, stdout, "my-backup") assert.Contains(t, stdout, "cluster-123") assert.Contains(t, stdout, "SUCCEEDED") + + req, ok := env.BackupServer.GetBackupCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "backup-abc", req.GetBackupId()) } func TestBackupDescribe_JSONOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.GetBackupFunc = func(_ context.Context, _ *backupv1.GetBackupRequest) (*backupv1.GetBackupResponse, error) { - return &backupv1.GetBackupResponse{ - Backup: &backupv1.Backup{Id: "backup-json", ClusterId: "cluster-xyz"}, - }, nil - } + env.BackupServer.GetBackupCalls.Returns(&backupv1.GetBackupResponse{ + Backup: &backupv1.Backup{Id: "backup-json", ClusterId: "cluster-xyz"}, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "describe", "backup-json", "--json") require.NoError(t, err) @@ -64,11 +60,8 @@ func TestBackupDescribe_JSONOutput(t *testing.T) { func TestBackupDescribe_APIError(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.GetBackupFunc = func(_ context.Context, _ *backupv1.GetBackupRequest) (*backupv1.GetBackupResponse, error) { - return nil, assert.AnError - } + env.BackupServer.GetBackupCalls.Returns(nil, assert.AnError) _, _, err := testutil.Exec(t, env, "backup", "describe", "backup-abc") require.Error(t, err) diff --git a/internal/cmd/backup/list_test.go b/internal/cmd/backup/list_test.go index b608777..d658a21 100644 --- a/internal/cmd/backup/list_test.go +++ b/internal/cmd/backup/list_test.go @@ -1,7 +1,6 @@ package backup_test import ( - "context" "encoding/json" "testing" @@ -16,22 +15,18 @@ import ( func TestBackupList_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.BackupServer.ListBackupsFunc = func(_ context.Context, req *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - return &backupv1.ListBackupsResponse{ - Items: []*backupv1.Backup{ - { - Id: "backup-1", - Name: "my-backup", - ClusterId: "cluster-abc", - Status: backupv1.BackupStatus_BACKUP_STATUS_SUCCEEDED, - CreatedAt: timestamppb.Now(), - }, + + env.BackupServer.ListBackupsCalls.Returns(&backupv1.ListBackupsResponse{ + Items: []*backupv1.Backup{ + { + Id: "backup-1", + Name: "my-backup", + ClusterId: "cluster-abc", + Status: backupv1.BackupStatus_BACKUP_STATUS_SUCCEEDED, + CreatedAt: timestamppb.Now(), }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "list") require.NoError(t, err) @@ -44,19 +39,20 @@ func TestBackupList_TableOutput(t *testing.T) { assert.Contains(t, stdout, "my-backup") assert.Contains(t, stdout, "cluster-abc") assert.Contains(t, stdout, "SUCCEEDED") + + req, ok := env.BackupServer.ListBackupsCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) } func TestBackupList_JSONOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.ListBackupsFunc = func(_ context.Context, _ *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { - return &backupv1.ListBackupsResponse{ - Items: []*backupv1.Backup{ - {Id: "backup-json", ClusterId: "cluster-123"}, - }, - }, nil - } + env.BackupServer.ListBackupsCalls.Returns(&backupv1.ListBackupsResponse{ + Items: []*backupv1.Backup{ + {Id: "backup-json", ClusterId: "cluster-123"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "backup", "list", "--json") require.NoError(t, err) @@ -75,11 +71,8 @@ func TestBackupList_JSONOutput(t *testing.T) { func TestBackupList_EmptyResponse(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.ListBackupsFunc = func(_ context.Context, _ *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { - return &backupv1.ListBackupsResponse{}, nil - } + env.BackupServer.ListBackupsCalls.Returns(&backupv1.ListBackupsResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "backup", "list") require.NoError(t, err) @@ -89,26 +82,21 @@ func TestBackupList_EmptyResponse(t *testing.T) { func TestBackupList_ClusterIDFilter(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedClusterID string - env.BackupServer.ListBackupsFunc = func(_ context.Context, req *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { - capturedClusterID = req.GetClusterId() - return &backupv1.ListBackupsResponse{}, nil - } + env.BackupServer.ListBackupsCalls.Returns(&backupv1.ListBackupsResponse{}, nil) _, _, err := testutil.Exec(t, env, "backup", "list", "--cluster-id=my-cluster") require.NoError(t, err) - assert.Equal(t, "my-cluster", capturedClusterID) + + req, ok := env.BackupServer.ListBackupsCalls.Last() + require.True(t, ok) + assert.Equal(t, "my-cluster", req.GetClusterId()) } func TestBackupList_APIError(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BackupServer.ListBackupsFunc = func(_ context.Context, _ *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { - return nil, assert.AnError - } + env.BackupServer.ListBackupsCalls.Returns(nil, assert.AnError) _, _, err := testutil.Exec(t, env, "backup", "list") require.Error(t, err) diff --git a/internal/cmd/cluster/cloud_provider_test.go b/internal/cmd/cluster/cloud_provider_test.go index 167a4b7..a122357 100644 --- a/internal/cmd/cluster/cloud_provider_test.go +++ b/internal/cmd/cluster/cloud_provider_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,17 +13,13 @@ import ( func TestListCloudProviders_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - - env.PlatformServer.ListCloudProvidersFunc = func(_ context.Context, req *platformv1.ListCloudProvidersRequest) (*platformv1.ListCloudProvidersResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - return &platformv1.ListCloudProvidersResponse{ - Items: []*platformv1.CloudProvider{ - {Id: "aws", Name: "Amazon Web Services", Available: true}, - {Id: "gcp", Name: "Google Cloud", Available: false}, - }, - }, nil - } + + env.PlatformServer.ListCloudProvidersCalls.Returns(&platformv1.ListCloudProvidersResponse{ + Items: []*platformv1.CloudProvider{ + {Id: "aws", Name: "Amazon Web Services", Available: true}, + {Id: "gcp", Name: "Google Cloud", Available: false}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "cloud-provider", "list") require.NoError(t, err) @@ -37,4 +32,8 @@ func TestListCloudProviders_TableOutput(t *testing.T) { assert.Contains(t, stdout, "gcp") assert.Contains(t, stdout, "Google Cloud") assert.Contains(t, stdout, "false") + + req, ok := env.PlatformServer.ListCloudProvidersCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) } diff --git a/internal/cmd/cluster/cloud_region_test.go b/internal/cmd/cluster/cloud_region_test.go index 6882b0f..2626d67 100644 --- a/internal/cmd/cluster/cloud_region_test.go +++ b/internal/cmd/cluster/cloud_region_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,7 +13,6 @@ import ( func TestListCloudRegions_RequiresCloudProvider(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "cloud-region", "list") require.Error(t, err) @@ -23,18 +21,13 @@ func TestListCloudRegions_RequiresCloudProvider(t *testing.T) { func TestListCloudRegions_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - - env.PlatformServer.ListCloudProviderRegionsFunc = func(_ context.Context, req *platformv1.ListCloudProviderRegionsRequest) (*platformv1.ListCloudProviderRegionsResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - assert.Equal(t, "aws", req.GetCloudProviderId()) - return &platformv1.ListCloudProviderRegionsResponse{ - Items: []*platformv1.CloudProviderRegion{ - {Id: "us-east-1", Name: "US East (N. Virginia)", Provider: "aws", Available: true}, - {Id: "eu-west-1", Name: "Europe (Ireland)", Provider: "aws", Available: false}, - }, - }, nil - } + + env.PlatformServer.ListCloudProviderRegionsCalls.Returns(&platformv1.ListCloudProviderRegionsResponse{ + Items: []*platformv1.CloudProviderRegion{ + {Id: "us-east-1", Name: "US East (N. Virginia)", Provider: "aws", Available: true}, + {Id: "eu-west-1", Name: "Europe (Ireland)", Provider: "aws", Available: false}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "cloud-region", "list", "--cloud-provider", "aws") require.NoError(t, err) @@ -48,4 +41,9 @@ func TestListCloudRegions_TableOutput(t *testing.T) { assert.Contains(t, stdout, "Europe (Ireland)") assert.Contains(t, stdout, "true") assert.Contains(t, stdout, "false") + + req, ok := env.PlatformServer.ListCloudProviderRegionsCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "aws", req.GetCloudProviderId()) } diff --git a/internal/cmd/cluster/completion_test.go b/internal/cmd/cluster/completion_test.go index 810aca9..f30b374 100644 --- a/internal/cmd/cluster/completion_test.go +++ b/internal/cmd/cluster/completion_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -21,16 +20,13 @@ import ( // completion logic without a real shell. func TestClusterIDCompletion(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.ListClustersFunc = func(_ context.Context, _ *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{ - {Id: "cluster-abc", Name: "my-cluster"}, - {Id: "cluster-xyz", Name: "other-cluster"}, - }, - }, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{ + {Id: "cluster-abc", Name: "my-cluster"}, + {Id: "cluster-xyz", Name: "other-cluster"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "describe", "") require.NoError(t, err) @@ -41,16 +37,13 @@ func TestClusterIDCompletion(t *testing.T) { func TestCloudProviderCompletion(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.PlatformServer.ListCloudProvidersFunc = func(_ context.Context, _ *platformv1.ListCloudProvidersRequest) (*platformv1.ListCloudProvidersResponse, error) { - return &platformv1.ListCloudProvidersResponse{ - Items: []*platformv1.CloudProvider{ - {Id: "aws", Name: "Amazon Web Services"}, - {Id: "gcp", Name: "Google Cloud"}, - }, - }, nil - } + env.PlatformServer.ListCloudProvidersCalls.Returns(&platformv1.ListCloudProvidersResponse{ + Items: []*platformv1.CloudProvider{ + {Id: "aws", Name: "Amazon Web Services"}, + {Id: "gcp", Name: "Google Cloud"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "list", "--cloud-provider", "") require.NoError(t, err) @@ -61,15 +54,12 @@ func TestCloudProviderCompletion(t *testing.T) { func TestCloudRegionCompletion(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.PlatformServer.ListCloudProviderRegionsFunc = func(_ context.Context, req *platformv1.ListCloudProviderRegionsRequest) (*platformv1.ListCloudProviderRegionsResponse, error) { - return &platformv1.ListCloudProviderRegionsResponse{ - Items: []*platformv1.CloudProviderRegion{ - {Id: "us-east-1", Name: "US East (N. Virginia)"}, - }, - }, nil - } + env.PlatformServer.ListCloudProviderRegionsCalls.Returns(&platformv1.ListCloudProviderRegionsResponse{ + Items: []*platformv1.CloudProviderRegion{ + {Id: "us-east-1", Name: "US East (N. Virginia)"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "list", "--cloud-provider", "aws", "--cloud-region", "") require.NoError(t, err) @@ -79,7 +69,6 @@ func TestCloudRegionCompletion(t *testing.T) { func TestCloudRegionCompletion_NoProvider(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "list", "--cloud-region", "") require.NoError(t, err) @@ -89,16 +78,13 @@ func TestCloudRegionCompletion_NoProvider(t *testing.T) { func TestPackageCompletion(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - return &bookingv1.ListPackagesResponse{ - Items: []*bookingv1.Package{ - {Id: "pkg-uuid-1", Name: "startup-4"}, - {Id: "pkg-uuid-2", Name: "business-8"}, - }, - }, nil - } + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{ + Items: []*bookingv1.Package{ + {Id: "pkg-uuid-1", Name: "startup-4"}, + {Id: "pkg-uuid-2", Name: "business-8"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "create", "--cloud-provider", "aws", "--cloud-region", "us-east-1", "--package", "") require.NoError(t, err) @@ -108,7 +94,6 @@ func TestPackageCompletion(t *testing.T) { func TestPackageCompletion_NoProvider(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "create", "--package", "") require.NoError(t, err) @@ -117,19 +102,16 @@ func TestPackageCompletion_NoProvider(t *testing.T) { func TestVersionCompletion(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) remarks := "upgrade recommended" - env.Server.ListQdrantReleasesFunc = func(_ context.Context, _ *clusterv1.ListQdrantReleasesRequest) (*clusterv1.ListQdrantReleasesResponse, error) { - return &clusterv1.ListQdrantReleasesResponse{ - Items: []*clusterv1.QdrantRelease{ - {Version: "1.14.0", Default: true}, - {Version: "1.13.0", EndOfLife: true}, - {Version: "1.12.0", Unavailable: true}, - {Version: "1.11.0", Remarks: &remarks}, - }, - }, nil - } + env.Server.ListQdrantReleasesCalls.Returns(&clusterv1.ListQdrantReleasesResponse{ + Items: []*clusterv1.QdrantRelease{ + {Version: "1.14.0", Default: true}, + {Version: "1.13.0", EndOfLife: true}, + {Version: "1.12.0", Unavailable: true}, + {Version: "1.11.0", Remarks: &remarks}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "create", "--cloud-provider", "aws", "--cloud-region", "us-east-1", "--version", "") require.NoError(t, err) @@ -144,16 +126,13 @@ func TestVersionCompletion(t *testing.T) { func TestVersionCompletion_UnavailableExcluded(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.ListQdrantReleasesFunc = func(_ context.Context, _ *clusterv1.ListQdrantReleasesRequest) (*clusterv1.ListQdrantReleasesResponse, error) { - return &clusterv1.ListQdrantReleasesResponse{ - Items: []*clusterv1.QdrantRelease{ - {Version: "1.14.0"}, - {Version: "1.13.0", Unavailable: true}, - }, - }, nil - } + + env.Server.ListQdrantReleasesCalls.Returns(&clusterv1.ListQdrantReleasesResponse{ + Items: []*clusterv1.QdrantRelease{ + {Version: "1.14.0"}, + {Version: "1.13.0", Unavailable: true}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "__complete", "cluster", "create", "--cloud-provider", "aws", "--cloud-region", "us-east-1", "--version", "") require.NoError(t, err) diff --git a/internal/cmd/cluster/create_test.go b/internal/cmd/cluster/create_test.go index f674b80..f7a52c9 100644 --- a/internal/cmd/cluster/create_test.go +++ b/internal/cmd/cluster/create_test.go @@ -2,7 +2,6 @@ package cluster_test import ( "context" - "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -16,18 +15,10 @@ import ( func TestCreateCluster_WithLabels(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - var capturedLabels map[string]string - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - capturedLabels = make(map[string]string) - for _, kv := range req.GetCluster().GetLabels() { - capturedLabels[kv.GetKey()] = kv.GetValue() - } - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-labeled"}, - }, nil - } + + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-labeled"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -39,30 +30,32 @@ func TestCreateCluster_WithLabels(t *testing.T) { "--label", "team=platform", ) require.NoError(t, err) + + req, ok := env.Server.CreateClusterCalls.Last() + require.True(t, ok) + capturedLabels := make(map[string]string) + for _, kv := range req.GetCluster().GetLabels() { + capturedLabels[kv.GetKey()] = kv.GetValue() + } assert.Equal(t, map[string]string{"env": "prod", "team": "platform"}, capturedLabels) } func TestCreateCluster_NoWait(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var getCallCount int32 - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { + env.Server.CreateClusterCalls.Always(func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { return &clusterv1.CreateClusterResponse{ Cluster: &clusterv1.Cluster{ Id: "cluster-abc", Name: req.GetCluster().GetName(), }, }, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - atomic.AddInt32(&getCallCount, 1) - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, - }, - }, nil - } + }) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "create", @@ -73,43 +66,48 @@ func TestCreateCluster_NoWait(t *testing.T) { ) require.NoError(t, err) assert.Contains(t, stdout, "cluster-abc") - assert.EqualValues(t, 0, atomic.LoadInt32(&getCallCount), "GetCluster should not be called without --wait") + assert.Equal(t, 0, env.Server.GetClusterCalls.Count(), "GetCluster should not be called without --wait") } func TestCreateCluster_WaitSuccess(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { + env.Server.CreateClusterCalls.Always(func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { return &clusterv1.CreateClusterResponse{ Cluster: &clusterv1.Cluster{ Id: "cluster-xyz", Name: req.GetCluster().GetName(), }, }, nil - } - - var callCount int32 - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - n := atomic.AddInt32(&callCount, 1) - if n < 3 { + }) + env.Server.GetClusterCalls. + OnCall(0, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{ Id: "cluster-xyz", State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, }, }, nil - } - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-xyz", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, - Endpoint: &clusterv1.ClusterEndpoint{Url: "https://xyz.aws.cloud.qdrant.io"}, + }). + OnCall(1, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-xyz", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, }, - }, - }, nil - } + }, nil + }). + Always(func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-xyz", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, + Endpoint: &clusterv1.ClusterEndpoint{Url: "https://xyz.aws.cloud.qdrant.io"}, + }, + }, + }, nil + }) stdout, stderr, err := testutil.Exec(t, env, "cluster", "create", @@ -130,24 +128,19 @@ func TestCreateCluster_WaitSuccess(t *testing.T) { func TestCreateCluster_WaitFailure(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-fail"}, - }, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-fail", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_CREATE, - Reason: "quota exceeded", - }, + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-fail"}, + }, nil) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-fail", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_CREATE, + Reason: "quota exceeded", }, - }, nil - } + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -166,21 +159,16 @@ func TestCreateCluster_WaitFailure(t *testing.T) { func TestCreateCluster_WaitTimeout(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-slow"}, - }, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-slow", - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, - }, - }, nil - } + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-slow"}, + }, nil) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-slow", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -197,21 +185,11 @@ func TestCreateCluster_WaitTimeout(t *testing.T) { func TestCreateCluster_PackageByUUID(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var listCallCount int32 - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - atomic.AddInt32(&listCallCount, 1) - return &bookingv1.ListPackagesResponse{}, nil - } - - var capturedPackageID string - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - capturedPackageID = req.GetCluster().GetConfiguration().GetPackageId() - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-pkg-uuid"}, - }, nil - } + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{}, nil) + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-pkg-uuid"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -221,29 +199,24 @@ func TestCreateCluster_PackageByUUID(t *testing.T) { "--package", "550e8400-e29b-41d4-a716-446655440000", ) require.NoError(t, err) - assert.EqualValues(t, 0, atomic.LoadInt32(&listCallCount), "ListPackages should not be called for UUID input") - assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", capturedPackageID) + assert.Equal(t, 0, env.BookingServer.ListPackagesCalls.Count(), "ListPackages should not be called for UUID input") + + req, ok := env.Server.CreateClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", req.GetCluster().GetConfiguration().GetPackageId()) } func TestCreateCluster_PackageByName(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - return &bookingv1.ListPackagesResponse{ - Items: []*bookingv1.Package{ - {Id: "pkg-uuid-123", Name: "starter"}, - }, - }, nil - } - var capturedPackageID string - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - capturedPackageID = req.GetCluster().GetConfiguration().GetPackageId() - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-named-pkg"}, - }, nil - } + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{ + Items: []*bookingv1.Package{ + {Id: "pkg-uuid-123", Name: "starter"}, + }, + }, nil) + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-named-pkg"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -253,16 +226,16 @@ func TestCreateCluster_PackageByName(t *testing.T) { "--package", "starter", ) require.NoError(t, err) - assert.Equal(t, "pkg-uuid-123", capturedPackageID) + + req, ok := env.Server.CreateClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "pkg-uuid-123", req.GetCluster().GetConfiguration().GetPackageId()) } func TestCreateCluster_PackageNameNotFound(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - return &bookingv1.ListPackagesResponse{}, nil - } + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -277,19 +250,11 @@ func TestCreateCluster_PackageNameNotFound(t *testing.T) { func TestCreateCluster_AutoGeneratedName(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.SuggestClusterNameFunc = func(_ context.Context, _ *clusterv1.SuggestClusterNameRequest) (*clusterv1.SuggestClusterNameResponse, error) { - return &clusterv1.SuggestClusterNameResponse{Name: "eager-pelican"}, nil - } - var capturedName string - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { - capturedName = req.GetCluster().GetName() - return &clusterv1.CreateClusterResponse{ - Cluster: &clusterv1.Cluster{Id: "cluster-auto", Name: capturedName}, - }, nil - } + env.Server.SuggestClusterNameCalls.Returns(&clusterv1.SuggestClusterNameResponse{Name: "eager-pelican"}, nil) + env.Server.CreateClusterCalls.Returns(&clusterv1.CreateClusterResponse{ + Cluster: &clusterv1.Cluster{Id: "cluster-auto", Name: "eager-pelican"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -298,24 +263,21 @@ func TestCreateCluster_AutoGeneratedName(t *testing.T) { "--package", "00000000-0000-0000-0000-000000000001", ) require.NoError(t, err) - assert.Equal(t, "eager-pelican", capturedName) + + req, ok := env.Server.CreateClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "eager-pelican", req.GetCluster().GetName()) } func TestCreateCluster_ExplicitNameSkipsSuggest(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - var suggestCalled bool - env.Server.SuggestClusterNameFunc = func(_ context.Context, _ *clusterv1.SuggestClusterNameRequest) (*clusterv1.SuggestClusterNameResponse, error) { - suggestCalled = true - return &clusterv1.SuggestClusterNameResponse{Name: "should-not-use"}, nil - } - env.Server.CreateClusterFunc = func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { + env.Server.SuggestClusterNameCalls.Returns(&clusterv1.SuggestClusterNameResponse{Name: "should-not-use"}, nil) + env.Server.CreateClusterCalls.Always(func(_ context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { return &clusterv1.CreateClusterResponse{ Cluster: &clusterv1.Cluster{Id: "cluster-named", Name: req.GetCluster().GetName()}, }, nil - } + }) _, _, err := testutil.Exec(t, env, "cluster", "create", @@ -325,5 +287,5 @@ func TestCreateCluster_ExplicitNameSkipsSuggest(t *testing.T) { "--package", "00000000-0000-0000-0000-000000000001", ) require.NoError(t, err) - assert.False(t, suggestCalled, "SuggestClusterName should not be called when --name is provided") + assert.Equal(t, 0, env.Server.SuggestClusterNameCalls.Count(), "SuggestClusterName should not be called when --name is provided") } diff --git a/internal/cmd/cluster/key_create_test.go b/internal/cmd/cluster/key_create_test.go index efe38ed..c1665d6 100644 --- a/internal/cmd/cluster/key_create_test.go +++ b/internal/cmd/cluster/key_create_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,45 +13,42 @@ import ( func TestKeyCreate_Basic(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - - var capturedKey *clusterauthv2.DatabaseApiKey - env.DatabaseApiKeyServer.CreateDatabaseApiKeyFunc = func(_ context.Context, req *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) { - capturedKey = req.GetDatabaseApiKey() - return &clusterauthv2.CreateDatabaseApiKeyResponse{ - DatabaseApiKey: &clusterauthv2.DatabaseApiKey{ - Id: "key-new", - Name: req.GetDatabaseApiKey().GetName(), - Key: "secret-key-value", - }, - }, nil - } + + env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Returns(&clusterauthv2.CreateDatabaseApiKeyResponse{ + DatabaseApiKey: &clusterauthv2.DatabaseApiKey{ + Id: "key-new", + Key: "secret-key-value", + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "my-key") require.NoError(t, err) + assert.Contains(t, stdout, "key-new") + assert.Contains(t, stdout, "secret-key-value") + assert.Contains(t, stdout, "not be shown again") + + req, ok := env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Last() + require.True(t, ok) + capturedKey := req.GetDatabaseApiKey() assert.Equal(t, "test-account-id", capturedKey.GetAccountId()) assert.Equal(t, "cluster-123", capturedKey.GetClusterId()) assert.Equal(t, "my-key", capturedKey.GetName()) assert.Empty(t, capturedKey.GetAccessRules()) - assert.Contains(t, stdout, "key-new") - assert.Contains(t, stdout, "secret-key-value") - assert.Contains(t, stdout, "not be shown again") } func TestKeyCreate_WithManageAccessType(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedKey *clusterauthv2.DatabaseApiKey - env.DatabaseApiKeyServer.CreateDatabaseApiKeyFunc = func(_ context.Context, req *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) { - capturedKey = req.GetDatabaseApiKey() - return &clusterauthv2.CreateDatabaseApiKeyResponse{ - DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-manage"}, - }, nil - } + env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Returns(&clusterauthv2.CreateDatabaseApiKeyResponse{ + DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-manage"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "manage-key", "--access-type", "manage") require.NoError(t, err) + + req, ok := env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Last() + require.True(t, ok) + capturedKey := req.GetDatabaseApiKey() require.Len(t, capturedKey.GetAccessRules(), 1) globalAccess := capturedKey.GetAccessRules()[0].GetGlobalAccess() require.NotNil(t, globalAccess) @@ -61,18 +57,17 @@ func TestKeyCreate_WithManageAccessType(t *testing.T) { func TestKeyCreate_WithReadOnlyAccessType(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedKey *clusterauthv2.DatabaseApiKey - env.DatabaseApiKeyServer.CreateDatabaseApiKeyFunc = func(_ context.Context, req *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) { - capturedKey = req.GetDatabaseApiKey() - return &clusterauthv2.CreateDatabaseApiKeyResponse{ - DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-ro"}, - }, nil - } + env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Returns(&clusterauthv2.CreateDatabaseApiKeyResponse{ + DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-ro"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "ro-key", "--access-type", "read-only") require.NoError(t, err) + + req, ok := env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Last() + require.True(t, ok) + capturedKey := req.GetDatabaseApiKey() require.Len(t, capturedKey.GetAccessRules(), 1) globalAccess := capturedKey.GetAccessRules()[0].GetGlobalAccess() require.NotNil(t, globalAccess) @@ -81,7 +76,6 @@ func TestKeyCreate_WithReadOnlyAccessType(t *testing.T) { func TestKeyCreate_InvalidAccessType(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "bad-key", "--access-type", "superuser") require.Error(t, err) @@ -90,25 +84,23 @@ func TestKeyCreate_InvalidAccessType(t *testing.T) { func TestKeyCreate_WithExpires(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedKey *clusterauthv2.DatabaseApiKey - env.DatabaseApiKeyServer.CreateDatabaseApiKeyFunc = func(_ context.Context, req *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) { - capturedKey = req.GetDatabaseApiKey() - return &clusterauthv2.CreateDatabaseApiKeyResponse{ - DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-exp"}, - }, nil - } + env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Returns(&clusterauthv2.CreateDatabaseApiKeyResponse{ + DatabaseApiKey: &clusterauthv2.DatabaseApiKey{Id: "key-exp"}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "exp-key", "--expires", "2027-06-15") require.NoError(t, err) + + req, ok := env.DatabaseApiKeyServer.CreateDatabaseApiKeyCalls.Last() + require.True(t, ok) + capturedKey := req.GetDatabaseApiKey() require.NotNil(t, capturedKey.GetExpiresAt()) assert.Equal(t, "2027-06-15", capturedKey.GetExpiresAt().AsTime().UTC().Format("2006-01-02")) } func TestKeyCreate_InvalidExpires(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123", "--name", "bad-key", "--expires", "not-a-date") require.Error(t, err) @@ -117,7 +109,6 @@ func TestKeyCreate_InvalidExpires(t *testing.T) { func TestKeyCreate_MissingName(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "key", "create", "cluster-123") require.Error(t, err) diff --git a/internal/cmd/cluster/key_delete_test.go b/internal/cmd/cluster/key_delete_test.go index 94c0361..4380ece 100644 --- a/internal/cmd/cluster/key_delete_test.go +++ b/internal/cmd/cluster/key_delete_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,26 +13,23 @@ import ( func TestKeyDelete_WithForce(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - var capturedReq *clusterauthv2.DeleteDatabaseApiKeyRequest - env.DatabaseApiKeyServer.DeleteDatabaseApiKeyFunc = func(_ context.Context, req *clusterauthv2.DeleteDatabaseApiKeyRequest) (*clusterauthv2.DeleteDatabaseApiKeyResponse, error) { - capturedReq = req - return &clusterauthv2.DeleteDatabaseApiKeyResponse{}, nil - } + env.DatabaseApiKeyServer.DeleteDatabaseApiKeyCalls.Returns(&clusterauthv2.DeleteDatabaseApiKeyResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "key", "delete", "cluster-123", "key-abc", "--force") require.NoError(t, err) - assert.Equal(t, "test-account-id", capturedReq.GetAccountId()) - assert.Equal(t, "cluster-123", capturedReq.GetClusterId()) - assert.Equal(t, "key-abc", capturedReq.GetDatabaseApiKeyId()) assert.Contains(t, stdout, "key-abc") assert.Contains(t, stdout, "deleted") + + req, ok := env.DatabaseApiKeyServer.DeleteDatabaseApiKeyCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "cluster-123", req.GetClusterId()) + assert.Equal(t, "key-abc", req.GetDatabaseApiKeyId()) } func TestKeyDelete_MissingArgs(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "key", "delete", "cluster-123") require.Error(t, err) diff --git a/internal/cmd/cluster/key_list_test.go b/internal/cmd/cluster/key_list_test.go index 8503ab1..4c3a0eb 100644 --- a/internal/cmd/cluster/key_list_test.go +++ b/internal/cmd/cluster/key_list_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "encoding/json" "testing" "time" @@ -17,25 +16,19 @@ import ( func TestKeyList_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) expires := time.Date(2027, 1, 1, 0, 0, 0, 0, time.UTC) - - env.DatabaseApiKeyServer.ListDatabaseApiKeysFunc = func(_ context.Context, req *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) { - assert.Equal(t, "test-account-id", req.GetAccountId()) - assert.Equal(t, "cluster-123", req.GetClusterId()) - return &clusterauthv2.ListDatabaseApiKeysResponse{ - Items: []*clusterauthv2.DatabaseApiKey{ - { - Id: "key-abc", - Name: "my-key", - Postfix: "xyz", - CreatedAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), - ExpiresAt: timestamppb.New(expires), - }, + env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Returns(&clusterauthv2.ListDatabaseApiKeysResponse{ + Items: []*clusterauthv2.DatabaseApiKey{ + { + Id: "key-abc", + Name: "my-key", + Postfix: "xyz", + CreatedAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), + ExpiresAt: timestamppb.New(expires), }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "key", "list", "cluster-123") require.NoError(t, err) @@ -49,19 +42,21 @@ func TestKeyList_TableOutput(t *testing.T) { assert.Contains(t, stdout, "xyz") assert.Contains(t, stdout, "ago") assert.Contains(t, stdout, "2027-01-01") + + req, ok := env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "cluster-123", req.GetClusterId()) } func TestKeyList_JSONOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.DatabaseApiKeyServer.ListDatabaseApiKeysFunc = func(_ context.Context, _ *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) { - return &clusterauthv2.ListDatabaseApiKeysResponse{ - Items: []*clusterauthv2.DatabaseApiKey{ - {Id: "key-json", Name: "json-key"}, - }, - }, nil - } + env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Returns(&clusterauthv2.ListDatabaseApiKeysResponse{ + Items: []*clusterauthv2.DatabaseApiKey{ + {Id: "key-json", Name: "json-key"}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "key", "list", "cluster-123", "--json") require.NoError(t, err) @@ -80,11 +75,8 @@ func TestKeyList_JSONOutput(t *testing.T) { func TestKeyList_EmptyResponse(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.DatabaseApiKeyServer.ListDatabaseApiKeysFunc = func(_ context.Context, _ *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) { - return &clusterauthv2.ListDatabaseApiKeysResponse{}, nil - } + env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Returns(&clusterauthv2.ListDatabaseApiKeysResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "key", "list", "cluster-123") require.NoError(t, err) @@ -94,15 +86,13 @@ func TestKeyList_EmptyResponse(t *testing.T) { func TestKeyList_ClusterIDPassedToServer(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedClusterID string - env.DatabaseApiKeyServer.ListDatabaseApiKeysFunc = func(_ context.Context, req *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) { - capturedClusterID = req.GetClusterId() - return &clusterauthv2.ListDatabaseApiKeysResponse{}, nil - } + env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Returns(&clusterauthv2.ListDatabaseApiKeysResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "key", "list", "my-cluster-id") require.NoError(t, err) - assert.Equal(t, "my-cluster-id", capturedClusterID) + + req, ok := env.DatabaseApiKeyServer.ListDatabaseApiKeysCalls.Last() + require.True(t, ok) + assert.Equal(t, "my-cluster-id", req.GetClusterId()) } diff --git a/internal/cmd/cluster/list_test.go b/internal/cmd/cluster/list_test.go index 7ad6b21..a3f0843 100644 --- a/internal/cmd/cluster/list_test.go +++ b/internal/cmd/cluster/list_test.go @@ -16,29 +16,26 @@ import ( func TestListClusters_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{ - { - Id: "cluster-1", - Name: "my-cluster", - CloudProviderId: "aws", - CloudProviderRegionId: "us-east-1", - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY}, - Configuration: &clusterv1.ClusterConfiguration{Version: new("1.8.0")}, - CreatedAt: timestamppb.New(time.Now().Add(-3 * time.Hour)), - }, - { - Id: "cluster-2", - Name: "other-cluster", - CloudProviderId: "gcp", - CloudProviderRegionId: "europe-west1", - }, + + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{ + { + Id: "cluster-1", + Name: "my-cluster", + CloudProviderId: "aws", + CloudProviderRegionId: "us-east-1", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY}, + Configuration: &clusterv1.ClusterConfiguration{Version: new("1.8.0")}, + CreatedAt: timestamppb.New(time.Now().Add(-3 * time.Hour)), + }, + { + Id: "cluster-2", + Name: "other-cluster", + CloudProviderId: "gcp", + CloudProviderRegionId: "europe-west1", }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) @@ -60,18 +57,15 @@ func TestListClusters_TableOutput(t *testing.T) { func TestListClusters_JSONOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{ - { - Id: "json-cluster", - Name: "json-name", - }, + + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{ + { + Id: "json-cluster", + Name: "json-name", }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "list", "--json") require.NoError(t, err) @@ -83,11 +77,8 @@ func TestListClusters_JSONOutput(t *testing.T) { func TestListClusters_EmptyResponse(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) @@ -99,11 +90,8 @@ func TestListClusters_EmptyResponse(t *testing.T) { func TestListClusters_AuthMetadata(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAPIKey("my-secret-key")) - t.Cleanup(env.Cleanup) - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) @@ -117,11 +105,8 @@ func TestListClusters_AuthMetadata(t *testing.T) { func TestListClusters_UserAgent(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithVersion("1.2.3")) - t.Cleanup(env.Cleanup) - env.Server.ListClustersFunc = func(_ context.Context, _ *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) @@ -135,42 +120,38 @@ func TestListClusters_UserAgent(t *testing.T) { func TestListClusters_AccountIDPassedToServer(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedAccountID string - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedAccountID = req.GetAccountId() - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) - assert.Equal(t, "test-account-id", capturedAccountID) + + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) } func TestListClusters_AutoPaginateMultiplePages(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) token := "page-2-token" - callCount := 0 - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - callCount++ - if req.PageToken == nil || *req.PageToken == "" { + env.Server.ListClustersCalls. + OnCall(0, func(_ context.Context, _ *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { return &clusterv1.ListClustersResponse{ Items: []*clusterv1.Cluster{{Id: "cluster-1", Name: "first"}}, NextPageToken: &token, }, nil - } - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{{Id: "cluster-2", Name: "second"}}, - }, nil - } + }). + OnCall(1, func(_ context.Context, _ *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { + return &clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{{Id: "cluster-2", Name: "second"}}, + }, nil + }) stdout, _, err := testutil.Exec(t, env, "cluster", "list") require.NoError(t, err) - assert.Equal(t, 2, callCount) + assert.Equal(t, 2, env.Server.ListClustersCalls.Count()) assert.Contains(t, stdout, "cluster-1") assert.Contains(t, stdout, "cluster-2") // No next page token footer when auto-paginating. @@ -179,101 +160,80 @@ func TestListClusters_AutoPaginateMultiplePages(t *testing.T) { func TestListClusters_PageSizeFlagSingleRequest(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) token := "next-token" - var capturedPageSize *int32 - callCount := 0 - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - callCount++ - capturedPageSize = req.PageSize - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{{Id: "cluster-1"}}, - NextPageToken: &token, - }, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{{Id: "cluster-1"}}, + NextPageToken: &token, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "list", "--page-size", "1") require.NoError(t, err) - assert.Equal(t, 1, callCount) - require.NotNil(t, capturedPageSize) - assert.Equal(t, int32(1), *capturedPageSize) + assert.Equal(t, 1, env.Server.ListClustersCalls.Count()) + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + require.NotNil(t, req.PageSize) + assert.Equal(t, int32(1), *req.PageSize) assert.Contains(t, stdout, "Next page token: next-token") } func TestListClusters_PageTokenFlagSingleRequest(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - var capturedPageToken *string - callCount := 0 - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - callCount++ - capturedPageToken = req.PageToken - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{{Id: "cluster-2"}}, - }, nil - } + + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{{Id: "cluster-2"}}, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "list", "--page-token", "my-token") require.NoError(t, err) - assert.Equal(t, 1, callCount) - require.NotNil(t, capturedPageToken) - assert.Equal(t, "my-token", *capturedPageToken) + assert.Equal(t, 1, env.Server.ListClustersCalls.Count()) + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + require.NotNil(t, req.PageToken) + assert.Equal(t, "my-token", *req.PageToken) } func TestListClusters_CloudProviderFilter(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedReq *clusterv1.ListClustersRequest - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedReq = req - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list", "--cloud-provider", "aws") require.NoError(t, err) - require.NotNil(t, capturedReq) - require.NotNil(t, capturedReq.CloudProviderId) - assert.Equal(t, "aws", *capturedReq.CloudProviderId) - assert.Nil(t, capturedReq.CloudProviderRegionId) + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + require.NotNil(t, req.CloudProviderId) + assert.Equal(t, "aws", *req.CloudProviderId) + assert.Nil(t, req.CloudProviderRegionId) } func TestListClusters_CloudRegionFilter(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var capturedReq *clusterv1.ListClustersRequest - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - capturedReq = req - return &clusterv1.ListClustersResponse{}, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{}, nil) _, _, err := testutil.Exec(t, env, "cluster", "list", "--cloud-provider", "aws", "--cloud-region", "us-east-1") require.NoError(t, err) - require.NotNil(t, capturedReq) - require.NotNil(t, capturedReq.CloudProviderId) - assert.Equal(t, "aws", *capturedReq.CloudProviderId) - require.NotNil(t, capturedReq.CloudProviderRegionId) - assert.Equal(t, "us-east-1", *capturedReq.CloudProviderRegionId) + req, ok := env.Server.ListClustersCalls.Last() + require.True(t, ok) + require.NotNil(t, req.CloudProviderId) + assert.Equal(t, "aws", *req.CloudProviderId) + require.NotNil(t, req.CloudProviderRegionId) + assert.Equal(t, "us-east-1", *req.CloudProviderRegionId) } func TestListClusters_NextPageTokenPrintedAsFooter(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) token := "footer-token" - env.Server.ListClustersFunc = func(_ context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { - return &clusterv1.ListClustersResponse{ - Items: []*clusterv1.Cluster{{Id: "cluster-1"}}, - NextPageToken: &token, - }, nil - } + env.Server.ListClustersCalls.Returns(&clusterv1.ListClustersResponse{ + Items: []*clusterv1.Cluster{{Id: "cluster-1"}}, + NextPageToken: &token, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "list", "--page-size", "1") require.NoError(t, err) diff --git a/internal/cmd/cluster/package_test.go b/internal/cmd/cluster/package_test.go index 2193208..8590722 100644 --- a/internal/cmd/cluster/package_test.go +++ b/internal/cmd/cluster/package_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,26 +13,23 @@ import ( func TestListPackages_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - return &bookingv1.ListPackagesResponse{ - Items: []*bookingv1.Package{ - { - Id: "pkg-123", - Name: "starter", - Tier: bookingv1.PackageTier_PACKAGE_TIER_STANDARD, - ResourceConfiguration: &bookingv1.ResourceConfiguration{ - Ram: "1GiB", - Cpu: "0.5", - Disk: "10GiB", - }, - UnitIntPricePerHour: 5000, - Currency: "USD", + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{ + Items: []*bookingv1.Package{ + { + Id: "pkg-123", + Name: "starter", + Tier: bookingv1.PackageTier_PACKAGE_TIER_STANDARD, + ResourceConfiguration: &bookingv1.ResourceConfiguration{ + Ram: "1GiB", + Cpu: "0.5", + Disk: "10GiB", }, + UnitIntPricePerHour: 5000, + Currency: "USD", }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "package", "list", @@ -52,19 +48,16 @@ func TestListPackages_TableOutput(t *testing.T) { func TestListPackages_FreePackage(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.BookingServer.ListPackagesFunc = func(_ context.Context, _ *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { - return &bookingv1.ListPackagesResponse{ - Items: []*bookingv1.Package{ - { - Id: "pkg-free", - Name: "free", - UnitIntPricePerHour: 0, - }, + env.BookingServer.ListPackagesCalls.Returns(&bookingv1.ListPackagesResponse{ + Items: []*bookingv1.Package{ + { + Id: "pkg-free", + Name: "free", + UnitIntPricePerHour: 0, }, - }, nil - } + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "package", "list", diff --git a/internal/cmd/cluster/restart_test.go b/internal/cmd/cluster/restart_test.go index 85e6698..250a75c 100644 --- a/internal/cmd/cluster/restart_test.go +++ b/internal/cmd/cluster/restart_test.go @@ -2,7 +2,6 @@ package cluster_test import ( "context" - "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -15,25 +14,22 @@ import ( func TestRestart_WithForce(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - var capturedReq *clusterv1.RestartClusterRequest - env.Server.RestartClusterFunc = func(_ context.Context, req *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { - capturedReq = req - return &clusterv1.RestartClusterResponse{}, nil - } + env.Server.RestartClusterCalls.Returns(&clusterv1.RestartClusterResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "restart", "cluster-123", "--force") require.NoError(t, err) - assert.Equal(t, "test-account-id", capturedReq.GetAccountId()) - assert.Equal(t, "cluster-123", capturedReq.GetClusterId()) assert.Contains(t, stdout, "cluster-123") assert.Contains(t, stdout, "restarting") + + req, ok := env.Server.RestartClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "cluster-123", req.GetClusterId()) } func TestRestart_MissingArgs(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "restart") require.Error(t, err) @@ -41,57 +37,53 @@ func TestRestart_MissingArgs(t *testing.T) { func TestRestart_NoWait(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - var getCallCount int32 - env.Server.RestartClusterFunc = func(_ context.Context, _ *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { - return &clusterv1.RestartClusterResponse{}, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - atomic.AddInt32(&getCallCount, 1) - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, - }, - }, nil - } + + env.Server.RestartClusterCalls.Returns(&clusterv1.RestartClusterResponse{}, nil) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "restart", "cluster-123", "--force") require.NoError(t, err) assert.Contains(t, stdout, "restarting") - assert.EqualValues(t, 0, atomic.LoadInt32(&getCallCount), "GetCluster should not be called without --wait") + assert.Equal(t, 0, env.Server.GetClusterCalls.Count(), "GetCluster should not be called without --wait") } func TestRestart_WaitSuccess(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.RestartClusterFunc = func(_ context.Context, _ *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { - return &clusterv1.RestartClusterResponse{}, nil - } - - var callCount int32 - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - n := atomic.AddInt32(&callCount, 1) - if n < 3 { + env.Server.RestartClusterCalls.Returns(&clusterv1.RestartClusterResponse{}, nil) + env.Server.GetClusterCalls. + OnCall(0, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{ Id: "cluster-123", State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, }, }, nil - } - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-123", - Name: "my-cluster", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, - Endpoint: &clusterv1.ClusterEndpoint{Url: "https://cluster-123.aws.cloud.qdrant.io"}, + }). + OnCall(1, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-123", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, }, - }, - }, nil - } + }, nil + }). + Always(func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-123", + Name: "my-cluster", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, + Endpoint: &clusterv1.ClusterEndpoint{Url: "https://cluster-123.aws.cloud.qdrant.io"}, + }, + }, + }, nil + }) stdout, stderr, err := testutil.Exec(t, env, "cluster", "restart", "cluster-123", "--force", @@ -108,22 +100,17 @@ func TestRestart_WaitSuccess(t *testing.T) { func TestRestart_WaitFailure(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.RestartClusterFunc = func(_ context.Context, _ *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { - return &clusterv1.RestartClusterResponse{}, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-123", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_SYNC, - Reason: "sync failed", - }, + + env.Server.RestartClusterCalls.Returns(&clusterv1.RestartClusterResponse{}, nil) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-123", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_SYNC, + Reason: "sync failed", }, - }, nil - } + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "restart", "cluster-123", "--force", @@ -138,19 +125,14 @@ func TestRestart_WaitFailure(t *testing.T) { func TestRestart_WaitTimeout(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - - env.Server.RestartClusterFunc = func(_ context.Context, _ *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { - return &clusterv1.RestartClusterResponse{}, nil - } - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-123", - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, - }, - }, nil - } + + env.Server.RestartClusterCalls.Returns(&clusterv1.RestartClusterResponse{}, nil) + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-123", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_UPDATING}, + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "restart", "cluster-123", "--force", diff --git a/internal/cmd/cluster/suspend_test.go b/internal/cmd/cluster/suspend_test.go index c80041d..b46d0bd 100644 --- a/internal/cmd/cluster/suspend_test.go +++ b/internal/cmd/cluster/suspend_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,25 +13,22 @@ import ( func TestSuspend_WithForce(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - var capturedReq *clusterv1.SuspendClusterRequest - env.Server.SuspendClusterFunc = func(_ context.Context, req *clusterv1.SuspendClusterRequest) (*clusterv1.SuspendClusterResponse, error) { - capturedReq = req - return &clusterv1.SuspendClusterResponse{}, nil - } + env.Server.SuspendClusterCalls.Returns(&clusterv1.SuspendClusterResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "suspend", "cluster-123", "--force") require.NoError(t, err) - assert.Equal(t, "test-account-id", capturedReq.GetAccountId()) - assert.Equal(t, "cluster-123", capturedReq.GetClusterId()) + + req, ok := env.Server.SuspendClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "cluster-123", req.GetClusterId()) assert.Contains(t, stdout, "cluster-123") assert.Contains(t, stdout, "suspended") } func TestSuspend_MissingArgs(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "suspend") require.Error(t, err) diff --git a/internal/cmd/cluster/unsuspend_test.go b/internal/cmd/cluster/unsuspend_test.go index 431590f..41bb048 100644 --- a/internal/cmd/cluster/unsuspend_test.go +++ b/internal/cmd/cluster/unsuspend_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,25 +13,22 @@ import ( func TestUnsuspend(t *testing.T) { env := testutil.NewTestEnv(t, testutil.WithAccountID("test-account-id")) - t.Cleanup(env.Cleanup) - var capturedReq *clusterv1.UnsuspendClusterRequest - env.Server.UnsuspendClusterFunc = func(_ context.Context, req *clusterv1.UnsuspendClusterRequest) (*clusterv1.UnsuspendClusterResponse, error) { - capturedReq = req - return &clusterv1.UnsuspendClusterResponse{}, nil - } + env.Server.UnsuspendClusterCalls.Returns(&clusterv1.UnsuspendClusterResponse{}, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "unsuspend", "cluster-123") require.NoError(t, err) - assert.Equal(t, "test-account-id", capturedReq.GetAccountId()) - assert.Equal(t, "cluster-123", capturedReq.GetClusterId()) + + req, ok := env.Server.UnsuspendClusterCalls.Last() + require.True(t, ok) + assert.Equal(t, "test-account-id", req.GetAccountId()) + assert.Equal(t, "cluster-123", req.GetClusterId()) assert.Contains(t, stdout, "cluster-123") assert.Contains(t, stdout, "unsuspending") } func TestUnsuspend_MissingArgs(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) _, _, err := testutil.Exec(t, env, "cluster", "unsuspend") require.Error(t, err) diff --git a/internal/cmd/cluster/update_test.go b/internal/cmd/cluster/update_test.go index ba9deaf..4a701bc 100644 --- a/internal/cmd/cluster/update_test.go +++ b/internal/cmd/cluster/update_test.go @@ -15,24 +15,17 @@ import ( func TestUpdateCluster_SetLabels(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.GetClusterFunc = func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + env.Server.GetClusterCalls.Always(func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{Id: req.GetClusterId(), Name: "my-cluster"}, }, nil - } - - var capturedLabels map[string]string - env.Server.UpdateClusterFunc = func(_ context.Context, req *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { - capturedLabels = make(map[string]string) - for _, kv := range req.GetCluster().GetLabels() { - capturedLabels[kv.GetKey()] = kv.GetValue() - } + }) + env.Server.UpdateClusterCalls.Always(func(_ context.Context, req *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { return &clusterv1.UpdateClusterResponse{ Cluster: req.GetCluster(), }, nil - } + }) stdout, _, err := testutil.Exec(t, env, "cluster", "update", "cluster-abc", @@ -42,45 +35,47 @@ func TestUpdateCluster_SetLabels(t *testing.T) { require.NoError(t, err) assert.Contains(t, stdout, "cluster-abc") assert.Contains(t, stdout, "updated successfully") + + req, ok := env.Server.UpdateClusterCalls.Last() + require.True(t, ok) + capturedLabels := make(map[string]string) + for _, kv := range req.GetCluster().GetLabels() { + capturedLabels[kv.GetKey()] = kv.GetValue() + } assert.Equal(t, map[string]string{"env": "prod", "team": "platform"}, capturedLabels) } func TestUpdateCluster_ClearLabels(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.GetClusterFunc = func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + env.Server.GetClusterCalls.Always(func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{Id: req.GetClusterId(), Name: "my-cluster"}, }, nil - } - - var capturedLabelCount int - env.Server.UpdateClusterFunc = func(_ context.Context, req *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { - capturedLabelCount = len(req.GetCluster().GetLabels()) + }) + env.Server.UpdateClusterCalls.Always(func(_ context.Context, req *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { return &clusterv1.UpdateClusterResponse{ Cluster: req.GetCluster(), }, nil - } + }) _, _, err := testutil.Exec(t, env, "cluster", "update", "cluster-abc") require.NoError(t, err) - assert.Equal(t, 0, capturedLabelCount) + + req, ok := env.Server.UpdateClusterCalls.Last() + require.True(t, ok) + assert.Empty(t, req.GetCluster().GetLabels()) } func TestUpdateCluster_APIError(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.GetClusterFunc = func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + env.Server.GetClusterCalls.Always(func(_ context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{Id: req.GetClusterId()}, }, nil - } - - env.Server.UpdateClusterFunc = func(_ context.Context, _ *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { - return nil, fmt.Errorf("internal server error") - } + }) + env.Server.UpdateClusterCalls.Returns(nil, fmt.Errorf("internal server error")) _, _, err := testutil.Exec(t, env, "cluster", "update", "cluster-abc") require.Error(t, err) diff --git a/internal/cmd/cluster/version_test.go b/internal/cmd/cluster/version_test.go index bb76859..c8153d3 100644 --- a/internal/cmd/cluster/version_test.go +++ b/internal/cmd/cluster/version_test.go @@ -1,7 +1,6 @@ package cluster_test import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -14,19 +13,16 @@ import ( func TestListVersions_TableOutput(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) remarks := "upgrade recommended" - env.Server.ListQdrantReleasesFunc = func(_ context.Context, _ *clusterv1.ListQdrantReleasesRequest) (*clusterv1.ListQdrantReleasesResponse, error) { - return &clusterv1.ListQdrantReleasesResponse{ - Items: []*clusterv1.QdrantRelease{ - {Version: "1.14.0", Default: true}, - {Version: "1.13.0", EndOfLife: true}, - {Version: "1.12.0", Unavailable: true}, - {Version: "1.11.0", Remarks: &remarks}, - }, - }, nil - } + env.Server.ListQdrantReleasesCalls.Returns(&clusterv1.ListQdrantReleasesResponse{ + Items: []*clusterv1.QdrantRelease{ + {Version: "1.14.0", Default: true}, + {Version: "1.13.0", EndOfLife: true}, + {Version: "1.12.0", Unavailable: true}, + {Version: "1.11.0", Remarks: &remarks}, + }, + }, nil) stdout, _, err := testutil.Exec(t, env, "cluster", "version", "list") require.NoError(t, err) diff --git a/internal/cmd/cluster/wait_test.go b/internal/cmd/cluster/wait_test.go index 9166c68..ee64af6 100644 --- a/internal/cmd/cluster/wait_test.go +++ b/internal/cmd/cluster/wait_test.go @@ -2,7 +2,6 @@ package cluster_test import ( "context" - "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -15,30 +14,36 @@ import ( func TestWaitCluster_Success(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - var callCount int32 - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - n := atomic.AddInt32(&callCount, 1) - if n < 3 { + env.Server.GetClusterCalls. + OnCall(0, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { return &clusterv1.GetClusterResponse{ Cluster: &clusterv1.Cluster{ Id: "cluster-abc", State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, }, }, nil - } - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-abc", - Name: "my-cluster", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, - Endpoint: &clusterv1.ClusterEndpoint{Url: "https://abc.aws.cloud.qdrant.io"}, + }). + OnCall(1, func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-abc", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, }, - }, - }, nil - } + }, nil + }). + Always(func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { + return &clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-abc", + Name: "my-cluster", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_HEALTHY, + Endpoint: &clusterv1.ClusterEndpoint{Url: "https://abc.aws.cloud.qdrant.io"}, + }, + }, + }, nil + }) stdout, stderr, err := testutil.Exec(t, env, "cluster", "wait", "cluster-abc", @@ -53,19 +58,16 @@ func TestWaitCluster_Success(t *testing.T) { func TestWaitCluster_Failure(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-fail", - State: &clusterv1.ClusterState{ - Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_CREATE, - Reason: "quota exceeded", - }, + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-fail", + State: &clusterv1.ClusterState{ + Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_FAILED_TO_CREATE, + Reason: "quota exceeded", }, - }, nil - } + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "wait", "cluster-fail", @@ -79,16 +81,13 @@ func TestWaitCluster_Failure(t *testing.T) { func TestWaitCluster_Timeout(t *testing.T) { env := testutil.NewTestEnv(t) - t.Cleanup(env.Cleanup) - env.Server.GetClusterFunc = func(_ context.Context, _ *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { - return &clusterv1.GetClusterResponse{ - Cluster: &clusterv1.Cluster{ - Id: "cluster-slow", - State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, - }, - }, nil - } + env.Server.GetClusterCalls.Returns(&clusterv1.GetClusterResponse{ + Cluster: &clusterv1.Cluster{ + Id: "cluster-slow", + State: &clusterv1.ClusterState{Phase: clusterv1.ClusterPhase_CLUSTER_PHASE_CREATING}, + }, + }, nil) _, _, err := testutil.Exec(t, env, "cluster", "wait", "cluster-slow", diff --git a/internal/testutil/fake_backup.go b/internal/testutil/fake_backup.go index ed0fbf4..558e3d5 100644 --- a/internal/testutil/fake_backup.go +++ b/internal/testutil/fake_backup.go @@ -7,22 +7,10 @@ import ( ) // FakeBackupService is a test fake that implements BackupServiceServer. -// Set the function fields to control responses per test. +// Use the *Calls fields to configure responses and inspect captured requests. type FakeBackupService struct { backupv1.UnimplementedBackupServiceServer - ListBackupsFunc func(context.Context, *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) - GetBackupFunc func(context.Context, *backupv1.GetBackupRequest) (*backupv1.GetBackupResponse, error) - CreateBackupFunc func(context.Context, *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) - DeleteBackupFunc func(context.Context, *backupv1.DeleteBackupRequest) (*backupv1.DeleteBackupResponse, error) - ListBackupRestoresFunc func(context.Context, *backupv1.ListBackupRestoresRequest) (*backupv1.ListBackupRestoresResponse, error) - RestoreBackupFunc func(context.Context, *backupv1.RestoreBackupRequest) (*backupv1.RestoreBackupResponse, error) - ListBackupSchedulesFunc func(context.Context, *backupv1.ListBackupSchedulesRequest) (*backupv1.ListBackupSchedulesResponse, error) - GetBackupScheduleFunc func(context.Context, *backupv1.GetBackupScheduleRequest) (*backupv1.GetBackupScheduleResponse, error) - CreateBackupScheduleFunc func(context.Context, *backupv1.CreateBackupScheduleRequest) (*backupv1.CreateBackupScheduleResponse, error) - UpdateBackupScheduleFunc func(context.Context, *backupv1.UpdateBackupScheduleRequest) (*backupv1.UpdateBackupScheduleResponse, error) - DeleteBackupScheduleFunc func(context.Context, *backupv1.DeleteBackupScheduleRequest) (*backupv1.DeleteBackupScheduleResponse, error) - ListBackupsCalls MethodSpy[*backupv1.ListBackupsRequest, *backupv1.ListBackupsResponse] GetBackupCalls MethodSpy[*backupv1.GetBackupRequest, *backupv1.GetBackupResponse] CreateBackupCalls MethodSpy[*backupv1.CreateBackupRequest, *backupv1.CreateBackupResponse] @@ -36,101 +24,68 @@ type FakeBackupService struct { DeleteBackupScheduleCalls MethodSpy[*backupv1.DeleteBackupScheduleRequest, *backupv1.DeleteBackupScheduleResponse] } -// ListBackups delegates to ListBackupsFunc if set, otherwise dispatches via ListBackupsCalls. +// ListBackups records the call and dispatches via ListBackupsCalls. func (f *FakeBackupService) ListBackups(ctx context.Context, req *backupv1.ListBackupsRequest) (*backupv1.ListBackupsResponse, error) { f.ListBackupsCalls.record(req) - if f.ListBackupsFunc != nil { - return f.ListBackupsFunc(ctx, req) - } return f.ListBackupsCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.ListBackups) } -// GetBackup delegates to GetBackupFunc if set, otherwise dispatches via GetBackupCalls. +// GetBackup records the call and dispatches via GetBackupCalls. func (f *FakeBackupService) GetBackup(ctx context.Context, req *backupv1.GetBackupRequest) (*backupv1.GetBackupResponse, error) { f.GetBackupCalls.record(req) - if f.GetBackupFunc != nil { - return f.GetBackupFunc(ctx, req) - } return f.GetBackupCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.GetBackup) } -// CreateBackup delegates to CreateBackupFunc if set, otherwise dispatches via CreateBackupCalls. +// CreateBackup records the call and dispatches via CreateBackupCalls. func (f *FakeBackupService) CreateBackup(ctx context.Context, req *backupv1.CreateBackupRequest) (*backupv1.CreateBackupResponse, error) { f.CreateBackupCalls.record(req) - if f.CreateBackupFunc != nil { - return f.CreateBackupFunc(ctx, req) - } return f.CreateBackupCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.CreateBackup) } -// DeleteBackup delegates to DeleteBackupFunc if set, otherwise dispatches via DeleteBackupCalls. +// DeleteBackup records the call and dispatches via DeleteBackupCalls. func (f *FakeBackupService) DeleteBackup(ctx context.Context, req *backupv1.DeleteBackupRequest) (*backupv1.DeleteBackupResponse, error) { f.DeleteBackupCalls.record(req) - if f.DeleteBackupFunc != nil { - return f.DeleteBackupFunc(ctx, req) - } return f.DeleteBackupCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.DeleteBackup) } -// ListBackupRestores delegates to ListBackupRestoresFunc if set, otherwise dispatches via ListBackupRestoresCalls. +// ListBackupRestores records the call and dispatches via ListBackupRestoresCalls. func (f *FakeBackupService) ListBackupRestores(ctx context.Context, req *backupv1.ListBackupRestoresRequest) (*backupv1.ListBackupRestoresResponse, error) { f.ListBackupRestoresCalls.record(req) - if f.ListBackupRestoresFunc != nil { - return f.ListBackupRestoresFunc(ctx, req) - } return f.ListBackupRestoresCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.ListBackupRestores) } -// RestoreBackup delegates to RestoreBackupFunc if set, otherwise dispatches via RestoreBackupCalls. +// RestoreBackup records the call and dispatches via RestoreBackupCalls. func (f *FakeBackupService) RestoreBackup(ctx context.Context, req *backupv1.RestoreBackupRequest) (*backupv1.RestoreBackupResponse, error) { f.RestoreBackupCalls.record(req) - if f.RestoreBackupFunc != nil { - return f.RestoreBackupFunc(ctx, req) - } return f.RestoreBackupCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.RestoreBackup) } -// ListBackupSchedules delegates to ListBackupSchedulesFunc if set, otherwise dispatches via ListBackupSchedulesCalls. +// ListBackupSchedules records the call and dispatches via ListBackupSchedulesCalls. func (f *FakeBackupService) ListBackupSchedules(ctx context.Context, req *backupv1.ListBackupSchedulesRequest) (*backupv1.ListBackupSchedulesResponse, error) { f.ListBackupSchedulesCalls.record(req) - if f.ListBackupSchedulesFunc != nil { - return f.ListBackupSchedulesFunc(ctx, req) - } return f.ListBackupSchedulesCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.ListBackupSchedules) } -// GetBackupSchedule delegates to GetBackupScheduleFunc if set, otherwise dispatches via GetBackupScheduleCalls. +// GetBackupSchedule records the call and dispatches via GetBackupScheduleCalls. func (f *FakeBackupService) GetBackupSchedule(ctx context.Context, req *backupv1.GetBackupScheduleRequest) (*backupv1.GetBackupScheduleResponse, error) { f.GetBackupScheduleCalls.record(req) - if f.GetBackupScheduleFunc != nil { - return f.GetBackupScheduleFunc(ctx, req) - } return f.GetBackupScheduleCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.GetBackupSchedule) } -// CreateBackupSchedule delegates to CreateBackupScheduleFunc if set, otherwise dispatches via CreateBackupScheduleCalls. +// CreateBackupSchedule records the call and dispatches via CreateBackupScheduleCalls. func (f *FakeBackupService) CreateBackupSchedule(ctx context.Context, req *backupv1.CreateBackupScheduleRequest) (*backupv1.CreateBackupScheduleResponse, error) { f.CreateBackupScheduleCalls.record(req) - if f.CreateBackupScheduleFunc != nil { - return f.CreateBackupScheduleFunc(ctx, req) - } return f.CreateBackupScheduleCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.CreateBackupSchedule) } -// UpdateBackupSchedule delegates to UpdateBackupScheduleFunc if set, otherwise dispatches via UpdateBackupScheduleCalls. +// UpdateBackupSchedule records the call and dispatches via UpdateBackupScheduleCalls. func (f *FakeBackupService) UpdateBackupSchedule(ctx context.Context, req *backupv1.UpdateBackupScheduleRequest) (*backupv1.UpdateBackupScheduleResponse, error) { f.UpdateBackupScheduleCalls.record(req) - if f.UpdateBackupScheduleFunc != nil { - return f.UpdateBackupScheduleFunc(ctx, req) - } return f.UpdateBackupScheduleCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.UpdateBackupSchedule) } -// DeleteBackupSchedule delegates to DeleteBackupScheduleFunc if set, otherwise dispatches via DeleteBackupScheduleCalls. +// DeleteBackupSchedule records the call and dispatches via DeleteBackupScheduleCalls. func (f *FakeBackupService) DeleteBackupSchedule(ctx context.Context, req *backupv1.DeleteBackupScheduleRequest) (*backupv1.DeleteBackupScheduleResponse, error) { f.DeleteBackupScheduleCalls.record(req) - if f.DeleteBackupScheduleFunc != nil { - return f.DeleteBackupScheduleFunc(ctx, req) - } return f.DeleteBackupScheduleCalls.dispatch(ctx, req, f.UnimplementedBackupServiceServer.DeleteBackupSchedule) } diff --git a/internal/testutil/fake_booking.go b/internal/testutil/fake_booking.go index 88402b3..4a2d014 100644 --- a/internal/testutil/fake_booking.go +++ b/internal/testutil/fake_booking.go @@ -7,20 +7,15 @@ import ( ) // FakeBookingService is a test fake that implements BookingServiceServer. -// Set the function fields to control responses per test. +// Use the *Calls fields to configure responses and inspect captured requests. type FakeBookingService struct { bookingv1.UnimplementedBookingServiceServer - ListPackagesFunc func(context.Context, *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) - ListPackagesCalls MethodSpy[*bookingv1.ListPackagesRequest, *bookingv1.ListPackagesResponse] } -// ListPackages delegates to ListPackagesFunc if set, otherwise dispatches via ListPackagesCalls. +// ListPackages records the call and dispatches via ListPackagesCalls. func (f *FakeBookingService) ListPackages(ctx context.Context, req *bookingv1.ListPackagesRequest) (*bookingv1.ListPackagesResponse, error) { f.ListPackagesCalls.record(req) - if f.ListPackagesFunc != nil { - return f.ListPackagesFunc(ctx, req) - } return f.ListPackagesCalls.dispatch(ctx, req, f.UnimplementedBookingServiceServer.ListPackages) } diff --git a/internal/testutil/fake_cluster.go b/internal/testutil/fake_cluster.go index 6c5c563..3389054 100644 --- a/internal/testutil/fake_cluster.go +++ b/internal/testutil/fake_cluster.go @@ -7,21 +7,10 @@ import ( ) // FakeClusterService is a test fake that implements ClusterServiceServer. -// Set the function fields to control responses per test. +// Use the *Calls fields to configure responses and inspect captured requests. type FakeClusterService struct { clusterv1.UnimplementedClusterServiceServer - ListClustersFunc func(context.Context, *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) - GetClusterFunc func(context.Context, *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) - CreateClusterFunc func(context.Context, *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) - UpdateClusterFunc func(context.Context, *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) - DeleteClusterFunc func(context.Context, *clusterv1.DeleteClusterRequest) (*clusterv1.DeleteClusterResponse, error) - RestartClusterFunc func(context.Context, *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) - SuspendClusterFunc func(context.Context, *clusterv1.SuspendClusterRequest) (*clusterv1.SuspendClusterResponse, error) - UnsuspendClusterFunc func(context.Context, *clusterv1.UnsuspendClusterRequest) (*clusterv1.UnsuspendClusterResponse, error) - SuggestClusterNameFunc func(context.Context, *clusterv1.SuggestClusterNameRequest) (*clusterv1.SuggestClusterNameResponse, error) - ListQdrantReleasesFunc func(context.Context, *clusterv1.ListQdrantReleasesRequest) (*clusterv1.ListQdrantReleasesResponse, error) - ListClustersCalls MethodSpy[*clusterv1.ListClustersRequest, *clusterv1.ListClustersResponse] GetClusterCalls MethodSpy[*clusterv1.GetClusterRequest, *clusterv1.GetClusterResponse] CreateClusterCalls MethodSpy[*clusterv1.CreateClusterRequest, *clusterv1.CreateClusterResponse] @@ -34,92 +23,62 @@ type FakeClusterService struct { ListQdrantReleasesCalls MethodSpy[*clusterv1.ListQdrantReleasesRequest, *clusterv1.ListQdrantReleasesResponse] } -// ListClusters delegates to ListClustersFunc if set, otherwise dispatches via ListClustersCalls. +// ListClusters records the call and dispatches via ListClustersCalls. func (f *FakeClusterService) ListClusters(ctx context.Context, req *clusterv1.ListClustersRequest) (*clusterv1.ListClustersResponse, error) { f.ListClustersCalls.record(req) - if f.ListClustersFunc != nil { - return f.ListClustersFunc(ctx, req) - } return f.ListClustersCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.ListClusters) } -// GetCluster delegates to GetClusterFunc if set, otherwise dispatches via GetClusterCalls. +// GetCluster records the call and dispatches via GetClusterCalls. func (f *FakeClusterService) GetCluster(ctx context.Context, req *clusterv1.GetClusterRequest) (*clusterv1.GetClusterResponse, error) { f.GetClusterCalls.record(req) - if f.GetClusterFunc != nil { - return f.GetClusterFunc(ctx, req) - } return f.GetClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.GetCluster) } -// CreateCluster delegates to CreateClusterFunc if set, otherwise dispatches via CreateClusterCalls. +// CreateCluster records the call and dispatches via CreateClusterCalls. func (f *FakeClusterService) CreateCluster(ctx context.Context, req *clusterv1.CreateClusterRequest) (*clusterv1.CreateClusterResponse, error) { f.CreateClusterCalls.record(req) - if f.CreateClusterFunc != nil { - return f.CreateClusterFunc(ctx, req) - } return f.CreateClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.CreateCluster) } -// UpdateCluster delegates to UpdateClusterFunc if set, otherwise dispatches via UpdateClusterCalls. +// UpdateCluster records the call and dispatches via UpdateClusterCalls. func (f *FakeClusterService) UpdateCluster(ctx context.Context, req *clusterv1.UpdateClusterRequest) (*clusterv1.UpdateClusterResponse, error) { f.UpdateClusterCalls.record(req) - if f.UpdateClusterFunc != nil { - return f.UpdateClusterFunc(ctx, req) - } return f.UpdateClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.UpdateCluster) } -// DeleteCluster delegates to DeleteClusterFunc if set, otherwise dispatches via DeleteClusterCalls. +// DeleteCluster records the call and dispatches via DeleteClusterCalls. func (f *FakeClusterService) DeleteCluster(ctx context.Context, req *clusterv1.DeleteClusterRequest) (*clusterv1.DeleteClusterResponse, error) { f.DeleteClusterCalls.record(req) - if f.DeleteClusterFunc != nil { - return f.DeleteClusterFunc(ctx, req) - } return f.DeleteClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.DeleteCluster) } -// RestartCluster delegates to RestartClusterFunc if set, otherwise dispatches via RestartClusterCalls. +// RestartCluster records the call and dispatches via RestartClusterCalls. func (f *FakeClusterService) RestartCluster(ctx context.Context, req *clusterv1.RestartClusterRequest) (*clusterv1.RestartClusterResponse, error) { f.RestartClusterCalls.record(req) - if f.RestartClusterFunc != nil { - return f.RestartClusterFunc(ctx, req) - } return f.RestartClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.RestartCluster) } -// SuspendCluster delegates to SuspendClusterFunc if set, otherwise dispatches via SuspendClusterCalls. +// SuspendCluster records the call and dispatches via SuspendClusterCalls. func (f *FakeClusterService) SuspendCluster(ctx context.Context, req *clusterv1.SuspendClusterRequest) (*clusterv1.SuspendClusterResponse, error) { f.SuspendClusterCalls.record(req) - if f.SuspendClusterFunc != nil { - return f.SuspendClusterFunc(ctx, req) - } return f.SuspendClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.SuspendCluster) } -// UnsuspendCluster delegates to UnsuspendClusterFunc if set, otherwise dispatches via UnsuspendClusterCalls. +// UnsuspendCluster records the call and dispatches via UnsuspendClusterCalls. func (f *FakeClusterService) UnsuspendCluster(ctx context.Context, req *clusterv1.UnsuspendClusterRequest) (*clusterv1.UnsuspendClusterResponse, error) { f.UnsuspendClusterCalls.record(req) - if f.UnsuspendClusterFunc != nil { - return f.UnsuspendClusterFunc(ctx, req) - } return f.UnsuspendClusterCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.UnsuspendCluster) } -// SuggestClusterName delegates to SuggestClusterNameFunc if set, otherwise dispatches via SuggestClusterNameCalls. +// SuggestClusterName records the call and dispatches via SuggestClusterNameCalls. func (f *FakeClusterService) SuggestClusterName(ctx context.Context, req *clusterv1.SuggestClusterNameRequest) (*clusterv1.SuggestClusterNameResponse, error) { f.SuggestClusterNameCalls.record(req) - if f.SuggestClusterNameFunc != nil { - return f.SuggestClusterNameFunc(ctx, req) - } return f.SuggestClusterNameCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.SuggestClusterName) } -// ListQdrantReleases delegates to ListQdrantReleasesFunc if set, otherwise dispatches via ListQdrantReleasesCalls. +// ListQdrantReleases records the call and dispatches via ListQdrantReleasesCalls. func (f *FakeClusterService) ListQdrantReleases(ctx context.Context, req *clusterv1.ListQdrantReleasesRequest) (*clusterv1.ListQdrantReleasesResponse, error) { f.ListQdrantReleasesCalls.record(req) - if f.ListQdrantReleasesFunc != nil { - return f.ListQdrantReleasesFunc(ctx, req) - } return f.ListQdrantReleasesCalls.dispatch(ctx, req, f.UnimplementedClusterServiceServer.ListQdrantReleases) } diff --git a/internal/testutil/fake_database_api_key.go b/internal/testutil/fake_database_api_key.go index 052cd3d..d92e713 100644 --- a/internal/testutil/fake_database_api_key.go +++ b/internal/testutil/fake_database_api_key.go @@ -7,42 +7,29 @@ import ( ) // FakeDatabaseApiKeyService is a test fake that implements DatabaseApiKeyServiceServer. -// Set the function fields to control responses per test. +// Use the *Calls fields to configure responses and inspect captured requests. type FakeDatabaseApiKeyService struct { clusterauthv2.UnimplementedDatabaseApiKeyServiceServer - ListDatabaseApiKeysFunc func(context.Context, *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) - CreateDatabaseApiKeyFunc func(context.Context, *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) - DeleteDatabaseApiKeyFunc func(context.Context, *clusterauthv2.DeleteDatabaseApiKeyRequest) (*clusterauthv2.DeleteDatabaseApiKeyResponse, error) - ListDatabaseApiKeysCalls MethodSpy[*clusterauthv2.ListDatabaseApiKeysRequest, *clusterauthv2.ListDatabaseApiKeysResponse] CreateDatabaseApiKeyCalls MethodSpy[*clusterauthv2.CreateDatabaseApiKeyRequest, *clusterauthv2.CreateDatabaseApiKeyResponse] DeleteDatabaseApiKeyCalls MethodSpy[*clusterauthv2.DeleteDatabaseApiKeyRequest, *clusterauthv2.DeleteDatabaseApiKeyResponse] } -// ListDatabaseApiKeys delegates to ListDatabaseApiKeysFunc if set, otherwise dispatches via ListDatabaseApiKeysCalls. +// ListDatabaseApiKeys records the call and dispatches via ListDatabaseApiKeysCalls. func (f *FakeDatabaseApiKeyService) ListDatabaseApiKeys(ctx context.Context, req *clusterauthv2.ListDatabaseApiKeysRequest) (*clusterauthv2.ListDatabaseApiKeysResponse, error) { f.ListDatabaseApiKeysCalls.record(req) - if f.ListDatabaseApiKeysFunc != nil { - return f.ListDatabaseApiKeysFunc(ctx, req) - } return f.ListDatabaseApiKeysCalls.dispatch(ctx, req, f.UnimplementedDatabaseApiKeyServiceServer.ListDatabaseApiKeys) } -// CreateDatabaseApiKey delegates to CreateDatabaseApiKeyFunc if set, otherwise dispatches via CreateDatabaseApiKeyCalls. +// CreateDatabaseApiKey records the call and dispatches via CreateDatabaseApiKeyCalls. func (f *FakeDatabaseApiKeyService) CreateDatabaseApiKey(ctx context.Context, req *clusterauthv2.CreateDatabaseApiKeyRequest) (*clusterauthv2.CreateDatabaseApiKeyResponse, error) { f.CreateDatabaseApiKeyCalls.record(req) - if f.CreateDatabaseApiKeyFunc != nil { - return f.CreateDatabaseApiKeyFunc(ctx, req) - } return f.CreateDatabaseApiKeyCalls.dispatch(ctx, req, f.UnimplementedDatabaseApiKeyServiceServer.CreateDatabaseApiKey) } -// DeleteDatabaseApiKey delegates to DeleteDatabaseApiKeyFunc if set, otherwise dispatches via DeleteDatabaseApiKeyCalls. +// DeleteDatabaseApiKey records the call and dispatches via DeleteDatabaseApiKeyCalls. func (f *FakeDatabaseApiKeyService) DeleteDatabaseApiKey(ctx context.Context, req *clusterauthv2.DeleteDatabaseApiKeyRequest) (*clusterauthv2.DeleteDatabaseApiKeyResponse, error) { f.DeleteDatabaseApiKeyCalls.record(req) - if f.DeleteDatabaseApiKeyFunc != nil { - return f.DeleteDatabaseApiKeyFunc(ctx, req) - } return f.DeleteDatabaseApiKeyCalls.dispatch(ctx, req, f.UnimplementedDatabaseApiKeyServiceServer.DeleteDatabaseApiKey) } diff --git a/internal/testutil/fake_platform.go b/internal/testutil/fake_platform.go index a89c846..f5658a5 100644 --- a/internal/testutil/fake_platform.go +++ b/internal/testutil/fake_platform.go @@ -7,31 +7,22 @@ import ( ) // FakePlatformService is a test fake that implements PlatformServiceServer. -// Set the function fields to control responses per test. +// Use the *Calls fields to configure responses and inspect captured requests. type FakePlatformService struct { platformv1.UnimplementedPlatformServiceServer - ListCloudProvidersFunc func(context.Context, *platformv1.ListCloudProvidersRequest) (*platformv1.ListCloudProvidersResponse, error) - ListCloudProviderRegionsFunc func(context.Context, *platformv1.ListCloudProviderRegionsRequest) (*platformv1.ListCloudProviderRegionsResponse, error) - ListCloudProvidersCalls MethodSpy[*platformv1.ListCloudProvidersRequest, *platformv1.ListCloudProvidersResponse] ListCloudProviderRegionsCalls MethodSpy[*platformv1.ListCloudProviderRegionsRequest, *platformv1.ListCloudProviderRegionsResponse] } -// ListCloudProviders delegates to ListCloudProvidersFunc if set, otherwise dispatches via ListCloudProvidersCalls. +// ListCloudProviders records the call and dispatches via ListCloudProvidersCalls. func (f *FakePlatformService) ListCloudProviders(ctx context.Context, req *platformv1.ListCloudProvidersRequest) (*platformv1.ListCloudProvidersResponse, error) { f.ListCloudProvidersCalls.record(req) - if f.ListCloudProvidersFunc != nil { - return f.ListCloudProvidersFunc(ctx, req) - } return f.ListCloudProvidersCalls.dispatch(ctx, req, f.UnimplementedPlatformServiceServer.ListCloudProviders) } -// ListCloudProviderRegions delegates to ListCloudProviderRegionsFunc if set, otherwise dispatches via ListCloudProviderRegionsCalls. +// ListCloudProviderRegions records the call and dispatches via ListCloudProviderRegionsCalls. func (f *FakePlatformService) ListCloudProviderRegions(ctx context.Context, req *platformv1.ListCloudProviderRegionsRequest) (*platformv1.ListCloudProviderRegionsResponse, error) { f.ListCloudProviderRegionsCalls.record(req) - if f.ListCloudProviderRegionsFunc != nil { - return f.ListCloudProviderRegionsFunc(ctx, req) - } return f.ListCloudProviderRegionsCalls.dispatch(ctx, req, f.UnimplementedPlatformServiceServer.ListCloudProviderRegions) }