diff --git a/internal/branches/create/create.go b/internal/branches/create/create.go index 5ee14aacd9..50e731d803 100644 --- a/internal/branches/create/create.go +++ b/internal/branches/create/create.go @@ -30,6 +30,7 @@ func Run(ctx context.Context, body api.CreateBranchBody, fsys afero.Fs) error { if err != nil { return errors.Errorf("failed to create preview branch: %w", err) } else if resp.JSON201 == nil { + utils.SuggestUpgradeOnError(ctx, flags.ProjectRef, "branching_limit", resp.StatusCode()) return errors.Errorf("unexpected create branch status %d: %s", resp.StatusCode(), string(resp.Body)) } diff --git a/internal/branches/create/create_test.go b/internal/branches/create/create_test.go index 60dbfc5279..c08a10286c 100644 --- a/internal/branches/create/create_test.go +++ b/internal/branches/create/create_test.go @@ -77,4 +77,45 @@ func TestCreateCommand(t *testing.T) { // Check error assert.ErrorContains(t, err, "unexpected create branch status 503:") }) + + t.Run("suggests upgrade on payment required", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { utils.CmdSuggestion = "" }) + // Mock branches create returns 402 + gock.New(utils.DefaultApiHost). + Post("/v1/projects/" + flags.ProjectRef + "/branches"). + Reply(http.StatusPaymentRequired). + JSON(map[string]interface{}{"message": "branching requires a paid plan"}) + // Mock project lookup for SuggestUpgradeOnError + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + flags.ProjectRef). + Reply(http.StatusOK). + JSON(map[string]interface{}{ + "ref": flags.ProjectRef, + "organization_slug": "test-org", + "name": "test", + "region": "us-east-1", + "created_at": "2024-01-01T00:00:00Z", + "status": "ACTIVE_HEALTHY", + "database": map[string]interface{}{"host": "db.example.supabase.co", "version": "15.1.0.117"}, + }) + // Mock entitlements + gock.New(utils.DefaultApiHost). + Get("/v1/organizations/test-org/entitlements"). + Reply(http.StatusOK). + JSON(map[string]interface{}{ + "entitlements": []map[string]interface{}{ + { + "feature": map[string]interface{}{"key": "branching_limit", "type": "numeric"}, + "hasAccess": false, + "type": "numeric", + "config": map[string]interface{}{"enabled": false, "value": 0, "unlimited": false, "unit": "count"}, + }, + }, + }) + fsys := afero.NewMemMapFs() + err := Run(context.Background(), api.CreateBranchBody{Region: cast.Ptr("sin")}, fsys) + assert.ErrorContains(t, err, "unexpected create branch status 402") + assert.Contains(t, utils.CmdSuggestion, "/org/test-org/billing") + }) } diff --git a/internal/branches/update/update.go b/internal/branches/update/update.go index a467ae1d2a..8ad8c1e381 100644 --- a/internal/branches/update/update.go +++ b/internal/branches/update/update.go @@ -10,6 +10,7 @@ import ( "github.com/supabase/cli/internal/branches/list" "github.com/supabase/cli/internal/branches/pause" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" ) @@ -22,6 +23,7 @@ func Run(ctx context.Context, branchId string, body api.UpdateBranchBody, fsys a if err != nil { return errors.Errorf("failed to update preview branch: %w", err) } else if resp.JSON200 == nil { + utils.SuggestUpgradeOnError(ctx, flags.ProjectRef, "branching_persistent", resp.StatusCode()) return errors.Errorf("unexpected update branch status %d: %s", resp.StatusCode(), string(resp.Body)) } fmt.Fprintln(os.Stderr, "Updated preview branch:") diff --git a/internal/branches/update/update_test.go b/internal/branches/update/update_test.go index 18382e94e0..57548ce47d 100644 --- a/internal/branches/update/update_test.go +++ b/internal/branches/update/update_test.go @@ -106,4 +106,45 @@ func TestUpdateBranch(t *testing.T) { err := Run(context.Background(), flags.ProjectRef, api.UpdateBranchBody{}, nil) assert.ErrorContains(t, err, "unexpected update branch status 503:") }) + + t.Run("suggests upgrade on payment required for persistent", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { utils.CmdSuggestion = "" }) + // Mock branch update returns 402 + gock.New(utils.DefaultApiHost). + Patch("/v1/branches/" + flags.ProjectRef). + Reply(http.StatusPaymentRequired). + JSON(map[string]interface{}{"message": "Persistent branches are not available on your plan"}) + // Mock project lookup for SuggestUpgradeOnError + gock.New(utils.DefaultApiHost). + Get("/v1/projects/" + flags.ProjectRef). + Reply(http.StatusOK). + JSON(map[string]interface{}{ + "ref": flags.ProjectRef, + "organization_slug": "test-org", + "name": "test", + "region": "us-east-1", + "created_at": "2024-01-01T00:00:00Z", + "status": "ACTIVE_HEALTHY", + "database": map[string]interface{}{"host": "db.example.supabase.co", "version": "15.1.0.117"}, + }) + // Mock entitlements + gock.New(utils.DefaultApiHost). + Get("/v1/organizations/test-org/entitlements"). + Reply(http.StatusOK). + JSON(map[string]interface{}{ + "entitlements": []map[string]interface{}{ + { + "feature": map[string]interface{}{"key": "branching_persistent", "type": "boolean"}, + "hasAccess": false, + "type": "boolean", + "config": map[string]interface{}{"enabled": false}, + }, + }, + }) + persistent := true + err := Run(context.Background(), flags.ProjectRef, api.UpdateBranchBody{Persistent: &persistent}, nil) + assert.ErrorContains(t, err, "unexpected update branch status 402") + assert.Contains(t, utils.CmdSuggestion, "/org/test-org/billing") + }) } diff --git a/internal/utils/plan_gate.go b/internal/utils/plan_gate.go new file mode 100644 index 0000000000..f5c1842845 --- /dev/null +++ b/internal/utils/plan_gate.go @@ -0,0 +1,54 @@ +package utils + +import ( + "context" + "fmt" + "net/http" +) + +func GetOrgSlugFromProjectRef(ctx context.Context, projectRef string) (string, error) { + resp, err := GetSupabase().V1GetProjectWithResponse(ctx, projectRef) + if err != nil { + return "", fmt.Errorf("failed to get project: %w", err) + } + if resp.JSON200 == nil { + return "", fmt.Errorf("unexpected get project status %d: %s", resp.StatusCode(), string(resp.Body)) + } + return resp.JSON200.OrganizationSlug, nil +} + +func GetOrgBillingURL(orgSlug string) string { + return fmt.Sprintf("%s/org/%s/billing", GetSupabaseDashboardURL(), orgSlug) +} + +// SuggestUpgradeOnError checks if a failed API response is due to plan limitations +// and sets CmdSuggestion with a billing upgrade link. Best-effort: never returns errors. +// Only triggers on 402 Payment Required (not 403, which could be a permissions issue). +func SuggestUpgradeOnError(ctx context.Context, projectRef, featureKey string, statusCode int) { + if statusCode != http.StatusPaymentRequired { + return + } + + orgSlug, err := GetOrgSlugFromProjectRef(ctx, projectRef) + if err != nil { + CmdSuggestion = fmt.Sprintf("This feature may require a plan upgrade. Manage billing: %s", Bold(GetSupabaseDashboardURL())) + return + } + + billingURL := GetOrgBillingURL(orgSlug) + + resp, err := GetSupabase().V1GetOrganizationEntitlementsWithResponse(ctx, orgSlug) + if err != nil || resp.JSON200 == nil { + CmdSuggestion = fmt.Sprintf("This feature may require a plan upgrade. Manage billing: %s", Bold(billingURL)) + return + } + + for _, e := range resp.JSON200.Entitlements { + if string(e.Feature.Key) == featureKey && !e.HasAccess { + CmdSuggestion = fmt.Sprintf("Your organization does not have access to this feature. Upgrade your plan: %s", Bold(billingURL)) + return + } + } + + CmdSuggestion = fmt.Sprintf("This feature may require a plan upgrade. Manage billing: %s", Bold(billingURL)) +} diff --git a/internal/utils/plan_gate_test.go b/internal/utils/plan_gate_test.go new file mode 100644 index 0000000000..dee3ef7865 --- /dev/null +++ b/internal/utils/plan_gate_test.go @@ -0,0 +1,153 @@ +package utils + +import ( + "context" + "net/http" + "testing" + + "github.com/h2non/gock" + "github.com/stretchr/testify/assert" + "github.com/supabase/cli/internal/testing/apitest" +) + +var planGateProjectJSON = map[string]interface{}{ + "ref": "test-ref", + "organization_slug": "my-org", + "name": "test", + "region": "us-east-1", + "created_at": "2024-01-01T00:00:00Z", + "status": "ACTIVE_HEALTHY", + "database": map[string]interface{}{"host": "db.example.supabase.co", "version": "15.1.0.117"}, +} + +func TestGetOrgSlugFromProjectRef(t *testing.T) { + ref := apitest.RandomProjectRef() + + t.Run("returns org slug on success", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusOK). + JSON(planGateProjectJSON) + slug, err := GetOrgSlugFromProjectRef(context.Background(), ref) + assert.NoError(t, err) + assert.Equal(t, "my-org", slug) + }) + + t.Run("returns error on not found", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusNotFound) + _, err := GetOrgSlugFromProjectRef(context.Background(), ref) + assert.ErrorContains(t, err, "unexpected get project status 404") + }) + + t.Run("returns error on network failure", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + ReplyError(assert.AnError) + _, err := GetOrgSlugFromProjectRef(context.Background(), ref) + assert.ErrorContains(t, err, "failed to get project") + }) +} + +func TestGetOrgBillingURL(t *testing.T) { + url := GetOrgBillingURL("my-org") + assert.Equal(t, GetSupabaseDashboardURL()+"/org/my-org/billing", url) +} + +func entitlementsJSON(featureKey string, hasAccess bool) map[string]interface{} { + return map[string]interface{}{ + "entitlements": []map[string]interface{}{ + { + "feature": map[string]interface{}{"key": featureKey, "type": "numeric"}, + "hasAccess": hasAccess, + "type": "numeric", + "config": map[string]interface{}{"enabled": hasAccess, "value": 0, "unlimited": false, "unit": "count"}, + }, + }, + } +} + +func TestSuggestUpgradeOnError(t *testing.T) { + ref := apitest.RandomProjectRef() + + t.Run("sets specific suggestion on 402 with gated feature", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { CmdSuggestion = "" }) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusOK). + JSON(planGateProjectJSON) + gock.New(DefaultApiHost). + Get("/v1/organizations/my-org/entitlements"). + Reply(http.StatusOK). + JSON(entitlementsJSON("branching_limit", false)) + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusPaymentRequired) + assert.Contains(t, CmdSuggestion, "/org/my-org/billing") + assert.Contains(t, CmdSuggestion, "does not have access") + }) + + t.Run("sets generic suggestion when entitlements lookup fails", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { CmdSuggestion = "" }) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusOK). + JSON(planGateProjectJSON) + gock.New(DefaultApiHost). + Get("/v1/organizations/my-org/entitlements"). + Reply(http.StatusInternalServerError) + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusPaymentRequired) + assert.Contains(t, CmdSuggestion, "/org/my-org/billing") + assert.Contains(t, CmdSuggestion, "may require a plan upgrade") + }) + + t.Run("sets fallback suggestion when project lookup fails", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { CmdSuggestion = "" }) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusNotFound) + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusPaymentRequired) + assert.Contains(t, CmdSuggestion, "plan upgrade") + assert.Contains(t, CmdSuggestion, GetSupabaseDashboardURL()) + assert.NotContains(t, CmdSuggestion, "/org/") + }) + + t.Run("sets generic suggestion when feature has access", func(t *testing.T) { + t.Cleanup(apitest.MockPlatformAPI(t)) + t.Cleanup(func() { CmdSuggestion = "" }) + gock.New(DefaultApiHost). + Get("/v1/projects/" + ref). + Reply(http.StatusOK). + JSON(planGateProjectJSON) + gock.New(DefaultApiHost). + Get("/v1/organizations/my-org/entitlements"). + Reply(http.StatusOK). + JSON(entitlementsJSON("branching_limit", true)) + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusPaymentRequired) + assert.Contains(t, CmdSuggestion, "/org/my-org/billing") + assert.Contains(t, CmdSuggestion, "may require a plan upgrade") + }) + + t.Run("skips suggestion on 403 forbidden", func(t *testing.T) { + CmdSuggestion = "" + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusForbidden) + assert.Empty(t, CmdSuggestion) + }) + + t.Run("skips suggestion on non-billing status codes", func(t *testing.T) { + CmdSuggestion = "" + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusInternalServerError) + assert.Empty(t, CmdSuggestion) + }) + + t.Run("skips suggestion on success status codes", func(t *testing.T) { + CmdSuggestion = "" + SuggestUpgradeOnError(context.Background(), ref, "branching_limit", http.StatusOK) + assert.Empty(t, CmdSuggestion) + }) +}