Skip to content

Commit 43ce0a6

Browse files
Rodriguespnclaude
andcommitted
feat(db): add supabase db advisors command for checking security and performance issues
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0351926 commit 43ce0a6

File tree

4 files changed

+1995
-0
lines changed

4 files changed

+1995
-0
lines changed

cmd/db.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/spf13/afero"
99
"github.com/spf13/cobra"
1010
"github.com/spf13/viper"
11+
"github.com/supabase/cli/internal/db/advisors"
1112
"github.com/supabase/cli/internal/db/diff"
1213
"github.com/supabase/cli/internal/db/dump"
1314
"github.com/supabase/cli/internal/db/lint"
@@ -291,6 +292,44 @@ without the envelope.`,
291292
return query.RunLocal(cmd.Context(), sql, flags.DbConfig, outputFormat, agentMode, os.Stdout)
292293
},
293294
}
295+
296+
advisorType = utils.EnumFlag{
297+
Allowed: advisors.AllowedTypes,
298+
Value: advisors.AllowedTypes[0],
299+
}
300+
301+
advisorLevel = utils.EnumFlag{
302+
Allowed: advisors.AllowedLevels,
303+
Value: advisors.AllowedLevels[1],
304+
}
305+
306+
advisorFailOn = utils.EnumFlag{
307+
Allowed: append([]string{"none"}, advisors.AllowedLevels...),
308+
Value: "none",
309+
}
310+
311+
dbAdvisorsCmd = &cobra.Command{
312+
Use: "advisors",
313+
Short: "Checks database for security and performance issues",
314+
Long: "Inspects the database for common security and performance issues such as missing RLS policies, unindexed foreign keys, exposed auth.users, and more.",
315+
PreRunE: func(cmd *cobra.Command, args []string) error {
316+
if flag := cmd.Flags().Lookup("linked"); flag != nil && flag.Changed {
317+
fsys := afero.NewOsFs()
318+
if _, err := utils.LoadAccessTokenFS(fsys); err != nil {
319+
utils.CmdSuggestion = fmt.Sprintf("Run %s first.", utils.Aqua("supabase login"))
320+
return err
321+
}
322+
return flags.LoadProjectRef(fsys)
323+
}
324+
return nil
325+
},
326+
RunE: func(cmd *cobra.Command, args []string) error {
327+
if flag := cmd.Flags().Lookup("linked"); flag != nil && flag.Changed {
328+
return advisors.RunLinked(cmd.Context(), advisorType.Value, advisorLevel.Value, advisorFailOn.Value, flags.ProjectRef)
329+
}
330+
return advisors.RunLocal(cmd.Context(), advisorType.Value, advisorLevel.Value, advisorFailOn.Value, flags.DbConfig)
331+
},
332+
}
294333
)
295334

296335
func init() {
@@ -409,5 +448,15 @@ func init() {
409448
queryFlags.StringVarP(&queryFile, "file", "f", "", "Path to a SQL file to execute.")
410449
queryFlags.VarP(&queryOutput, "output", "o", "Output format: table, json, or csv.")
411450
dbCmd.AddCommand(dbQueryCmd)
451+
// Build advisors command
452+
advisorsFlags := dbAdvisorsCmd.Flags()
453+
advisorsFlags.String("db-url", "", "Checks the database specified by the connection string (must be percent-encoded).")
454+
advisorsFlags.Bool("linked", false, "Checks the linked project for issues.")
455+
advisorsFlags.Bool("local", true, "Checks the local database for issues.")
456+
dbAdvisorsCmd.MarkFlagsMutuallyExclusive("db-url", "linked", "local")
457+
advisorsFlags.Var(&advisorType, "type", "Type of advisors to check: all, security, performance.")
458+
advisorsFlags.Var(&advisorLevel, "level", "Minimum issue level to display: info, warn, error.")
459+
advisorsFlags.Var(&advisorFailOn, "fail-on", "Issue level to exit with non-zero status: none, info, warn, error.")
460+
dbCmd.AddCommand(dbAdvisorsCmd)
412461
rootCmd.AddCommand(dbCmd)
413462
}

internal/db/advisors/advisors.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
package advisors
2+
3+
import (
4+
"context"
5+
_ "embed"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"os"
10+
11+
"github.com/go-errors/errors"
12+
"github.com/jackc/pgconn"
13+
"github.com/jackc/pgx/v4"
14+
"github.com/supabase/cli/internal/utils"
15+
"github.com/supabase/cli/pkg/api"
16+
)
17+
18+
var (
19+
AllowedLevels = []string{
20+
"info",
21+
"warn",
22+
"error",
23+
}
24+
25+
AllowedTypes = []string{
26+
"all",
27+
"security",
28+
"performance",
29+
}
30+
31+
//go:embed templates/lints.sql
32+
lintsSQL string
33+
)
34+
35+
type LintLevel int
36+
37+
func toEnum(level string) LintLevel {
38+
switch level {
39+
case "INFO", "info":
40+
return 0
41+
case "WARN", "warn":
42+
return 1
43+
case "ERROR", "error":
44+
return 2
45+
}
46+
return -1
47+
}
48+
49+
type Lint struct {
50+
Name string `json:"name"`
51+
Title string `json:"title"`
52+
Level string `json:"level"`
53+
Facing string `json:"facing"`
54+
Categories []string `json:"categories"`
55+
Description string `json:"description"`
56+
Detail string `json:"detail"`
57+
Remediation string `json:"remediation"`
58+
Metadata *json.RawMessage `json:"metadata,omitempty"`
59+
CacheKey string `json:"cache_key"`
60+
}
61+
62+
func RunLocal(ctx context.Context, advisorType string, level string, failOn string, config pgconn.Config, options ...func(*pgx.ConnConfig)) error {
63+
conn, err := utils.ConnectByConfig(ctx, config, options...)
64+
if err != nil {
65+
return err
66+
}
67+
defer conn.Close(context.Background())
68+
69+
lints, err := queryLints(ctx, conn)
70+
if err != nil {
71+
return err
72+
}
73+
74+
filtered := filterLints(lints, advisorType, level)
75+
return outputAndCheck(filtered, failOn, os.Stdout)
76+
}
77+
78+
func RunLinked(ctx context.Context, advisorType string, level string, failOn string, projectRef string) error {
79+
var lints []Lint
80+
81+
if advisorType == "all" || advisorType == "security" {
82+
securityLints, err := fetchSecurityAdvisors(ctx, projectRef)
83+
if err != nil {
84+
return err
85+
}
86+
lints = append(lints, securityLints...)
87+
}
88+
89+
if advisorType == "all" || advisorType == "performance" {
90+
perfLints, err := fetchPerformanceAdvisors(ctx, projectRef)
91+
if err != nil {
92+
return err
93+
}
94+
lints = append(lints, perfLints...)
95+
}
96+
97+
filtered := filterByLevel(lints, level)
98+
return outputAndCheck(filtered, failOn, os.Stdout)
99+
}
100+
101+
func queryLints(ctx context.Context, conn *pgx.Conn) ([]Lint, error) {
102+
tx, err := conn.Begin(ctx)
103+
if err != nil {
104+
return nil, errors.Errorf("failed to begin transaction: %w", err)
105+
}
106+
defer func() {
107+
if err := tx.Rollback(context.Background()); err != nil {
108+
fmt.Fprintln(os.Stderr, err)
109+
}
110+
}()
111+
112+
rows, err := tx.Query(ctx, lintsSQL)
113+
if err != nil {
114+
return nil, errors.Errorf("failed to query lints: %w", err)
115+
}
116+
defer rows.Close()
117+
118+
var lints []Lint
119+
for rows.Next() {
120+
var l Lint
121+
var metadata []byte
122+
if err := rows.Scan(
123+
&l.Name,
124+
&l.Title,
125+
&l.Level,
126+
&l.Facing,
127+
&l.Categories,
128+
&l.Description,
129+
&l.Detail,
130+
&l.Remediation,
131+
&metadata,
132+
&l.CacheKey,
133+
); err != nil {
134+
return nil, errors.Errorf("failed to scan lint row: %w", err)
135+
}
136+
if len(metadata) > 0 {
137+
raw := json.RawMessage(metadata)
138+
l.Metadata = &raw
139+
}
140+
lints = append(lints, l)
141+
}
142+
if err := rows.Err(); err != nil {
143+
return nil, errors.Errorf("failed to parse lint rows: %w", err)
144+
}
145+
return lints, nil
146+
}
147+
148+
func fetchSecurityAdvisors(ctx context.Context, projectRef string) ([]Lint, error) {
149+
resp, err := utils.GetSupabase().V1GetSecurityAdvisorsWithResponse(ctx, projectRef, &api.V1GetSecurityAdvisorsParams{})
150+
if err != nil {
151+
return nil, errors.Errorf("failed to fetch security advisors: %w", err)
152+
}
153+
if resp.JSON200 == nil {
154+
return nil, errors.Errorf("unexpected security advisors status %d: %s", resp.StatusCode(), string(resp.Body))
155+
}
156+
return apiResponseToLints(resp.JSON200), nil
157+
}
158+
159+
func fetchPerformanceAdvisors(ctx context.Context, projectRef string) ([]Lint, error) {
160+
resp, err := utils.GetSupabase().V1GetPerformanceAdvisorsWithResponse(ctx, projectRef)
161+
if err != nil {
162+
return nil, errors.Errorf("failed to fetch performance advisors: %w", err)
163+
}
164+
if resp.JSON200 == nil {
165+
return nil, errors.Errorf("unexpected performance advisors status %d: %s", resp.StatusCode(), string(resp.Body))
166+
}
167+
return apiResponseToLints(resp.JSON200), nil
168+
}
169+
170+
func apiResponseToLints(resp *api.V1ProjectAdvisorsResponse) []Lint {
171+
var lints []Lint
172+
for _, l := range resp.Lints {
173+
lint := Lint{
174+
Name: string(l.Name),
175+
Title: l.Title,
176+
Level: string(l.Level),
177+
Facing: string(l.Facing),
178+
Description: l.Description,
179+
Detail: l.Detail,
180+
Remediation: l.Remediation,
181+
CacheKey: l.CacheKey,
182+
}
183+
for _, c := range l.Categories {
184+
lint.Categories = append(lint.Categories, string(c))
185+
}
186+
if l.Metadata != nil {
187+
data, err := json.Marshal(l.Metadata)
188+
if err == nil {
189+
raw := json.RawMessage(data)
190+
lint.Metadata = &raw
191+
}
192+
}
193+
lints = append(lints, lint)
194+
}
195+
return lints
196+
}
197+
198+
func filterLints(lints []Lint, advisorType string, level string) []Lint {
199+
var filtered []Lint
200+
for _, l := range lints {
201+
if !matchesType(l, advisorType) {
202+
continue
203+
}
204+
if toEnum(l.Level) < toEnum(level) {
205+
continue
206+
}
207+
filtered = append(filtered, l)
208+
}
209+
return filtered
210+
}
211+
212+
func filterByLevel(lints []Lint, level string) []Lint {
213+
minLevel := toEnum(level)
214+
var filtered []Lint
215+
for _, l := range lints {
216+
if toEnum(l.Level) >= minLevel {
217+
filtered = append(filtered, l)
218+
}
219+
}
220+
return filtered
221+
}
222+
223+
func matchesType(l Lint, advisorType string) bool {
224+
if advisorType == "all" {
225+
return true
226+
}
227+
for _, c := range l.Categories {
228+
switch {
229+
case advisorType == "security" && c == "SECURITY":
230+
return true
231+
case advisorType == "performance" && c == "PERFORMANCE":
232+
return true
233+
}
234+
}
235+
return false
236+
}
237+
238+
func outputAndCheck(lints []Lint, failOn string, stdout io.Writer) error {
239+
if len(lints) == 0 {
240+
fmt.Fprintln(os.Stderr, "No issues found")
241+
return nil
242+
}
243+
244+
enc := json.NewEncoder(stdout)
245+
enc.SetIndent("", " ")
246+
if err := enc.Encode(lints); err != nil {
247+
return errors.Errorf("failed to print result json: %w", err)
248+
}
249+
250+
failOnLevel := toEnum(failOn)
251+
if failOnLevel >= 0 {
252+
for _, l := range lints {
253+
if toEnum(l.Level) >= failOnLevel {
254+
return fmt.Errorf("fail-on is set to %s, non-zero exit", failOn)
255+
}
256+
}
257+
}
258+
return nil
259+
}

0 commit comments

Comments
 (0)