diff --git a/cli/daemon/db.go b/cli/daemon/db.go index a9b5a93ef2..f08ace9897 100644 --- a/cli/daemon/db.go +++ b/cli/daemon/db.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "strconv" + "strings" "time" "github.com/rs/zerolog/log" @@ -50,6 +51,9 @@ func (s *Server) DBConnect(ctx context.Context, req *daemonpb.DBConnectRequest) } else if appID == "" { return nil, errNotLinked } + if err := validateEnvName(ctx, appID, req.EnvName); err != nil { + return nil, err + } port, passwd, err := sqldb.OneshotProxy(appID, req.EnvName, toRoleType(req.Role)) if err != nil { return nil, err @@ -176,6 +180,11 @@ func (s *Server) DBProxy(params *daemonpb.DBProxyRequest, stream daemonpb.Daemon } else if appID == "" && params.EnvName != "local" { return errNotLinked } + if params.EnvName != "local" { + if err := validateEnvName(ctx, appID, params.EnvName); err != nil { + return err + } + } ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:"+strconv.Itoa(int(params.Port))) if err != nil { @@ -397,6 +406,35 @@ func serveProxy(ctx context.Context, ln net.Listener, handler func(context.Conte } } +// validateEnvName checks that the given environment slug exists for the app. +// If not, it returns an error listing the available environment names. +func validateEnvName(ctx context.Context, appSlug, envName string) error { + envs, err := platform.ListEnvs(ctx, appSlug) + if err != nil { + // If we can't list environments, skip validation and let the + // connection attempt proceed; it will fail with its own error. + return nil + } + + for _, env := range envs { + if env.Slug == envName { + return nil + } + } + + var slugs []string + for _, env := range envs { + slugs = append(slugs, env.Slug) + } + if len(slugs) > 0 { + return status.Errorf(codes.NotFound, + "environment %q not found; available environments for this app are: %s", + envName, strings.Join(slugs, ", ")) + } + return status.Errorf(codes.NotFound, + "environment %q not found; this app has no environments", envName) +} + func getClusterType(req interface{ GetClusterType() daemonpb.DBClusterType }) sqldb.ClusterType { switch req.GetClusterType() { case daemonpb.DBClusterType_DB_CLUSTER_TYPE_RUN: