diff --git a/cmd/vclusterctl/cmd/platform/start.go b/cmd/vclusterctl/cmd/platform/start.go index 93e2bb2cf8..ee4d11f4dd 100644 --- a/cmd/vclusterctl/cmd/platform/start.go +++ b/cmd/vclusterctl/cmd/platform/start.go @@ -81,12 +81,21 @@ before running this command: startCmd.Flags().StringVar(&cmd.ChartRepo, "chart-repo", "https://charts.loft.sh/", "The chart repo to deploy vCluster platform") startCmd.Flags().StringVar(&cmd.ChartName, "chart-name", "vcluster-platform", "The chart name to deploy vCluster platform") startCmd.Flags().BoolVar(&cmd.Docker, "docker", false, "If true, vCluster platform will be installed in Docker") + startCmd.Flags().BoolVar(&cmd.Secure, "secure", false, "If true, verify TLS certificates when connecting to the platform (by default, TLS verification is skipped during bootstrap because the platform starts with a self-signed certificate)") return startCmd } func (cmd *StartCmd) Run(ctx context.Context) error { - // get version to deploy + cfg := cmd.LoadedConfig(cmd.Log) + + // Bootstrap defaults to insecure because the platform starts with a + // self-signed certificate. Pass --secure to enforce TLS verification. + if !cmd.Secure { + cfg.Platform.Insecure = true + } + + // get the version to deploy if cmd.Version == "latest" || cmd.Version == "" { cmd.Version = platform.MinimumVersionTag latestVersion, err := platform.LatestCompatibleVersion(ctx) @@ -154,8 +163,10 @@ func (cmd *StartCmd) Run(ctx context.Context) error { } } - if err := cmd.ensureEmailWithDisclaimer(ctx, cmd.KubeClient, cmd.Namespace); err != nil { - return err + if !cmd.platformUsesNewActivationFlow(cmd.Version) { + if err := cmd.ensureEmailWithDisclaimer(ctx, cmd.KubeClient, cmd.Namespace); err != nil { + return err + } } return start.NewLoftStarter(cmd.StartOptions).Start(ctx) @@ -203,6 +214,26 @@ func promptForEmail(emailAddress string) (string, error) { return emailAddress, nil } +// platformUsesNewActivationFlow checks if the platform version supports the new platform activation flow. +// +// The new platform activation flow is supported for the platform version 4.6.0-rc.8 and above. +func (cmd *StartCmd) platformUsesNewActivationFlow(platformVersion string) bool { + platformSemVerVersion, err := semver.ParseTolerant(platformVersion) + if err != nil { + cmd.Log.Warnf("Failed to parse platform version %s, falling back to the old platform activation flow with the admin email prompt", platformVersion) + return false + } + + const minPlatformVersionWithNewActivationFlow = "4.6.0-rc.8" + if platformSemVerVersion.GTE(semver.MustParse(minPlatformVersionWithNewActivationFlow)) { + cmd.Log.Debugf("Platform version %s is greater than or equal to %s, platform is using the new activation flow, so skipping admin email prompt", platformVersion, minPlatformVersionWithNewActivationFlow) + return true + } + + cmd.Log.Debugf("Platform version %s is not using the new activation flow, so admin email is required", platformVersion) + return false +} + func validateEmail(emailAddress string) error { if emailAddress == "" { return fmt.Errorf("admin email address is required") diff --git a/cmd/vclusterctl/cmd/platform/start_test.go b/cmd/vclusterctl/cmd/platform/start_test.go new file mode 100644 index 0000000000..87844f1786 --- /dev/null +++ b/cmd/vclusterctl/cmd/platform/start_test.go @@ -0,0 +1,68 @@ +package platform + +import ( + "testing" + + "github.com/loft-sh/log" + "github.com/loft-sh/vcluster/pkg/cli/flags" + "github.com/loft-sh/vcluster/pkg/cli/start" +) + +func TestNewStartCmd_SecureFlag(t *testing.T) { + globalFlags := &flags.GlobalFlags{} + cmd := NewStartCmd(globalFlags) + + // Verify --secure flag exists and defaults to false (insecure by default). + f := cmd.Flags().Lookup("secure") + if f == nil { + t.Fatal("--secure flag not registered on start command") + } + if f.DefValue != "false" { + t.Errorf("expected --secure default to be 'false', got %q", f.DefValue) + } + + // Simulate passing --secure on the command line. + if err := cmd.Flags().Set("secure", "true"); err != nil { + t.Fatalf("failed to set --secure flag: %v", err) + } + if f.Value.String() != "true" { + t.Errorf("expected --secure value to be 'true' after set, got %q", f.Value.String()) + } +} + +func TestPlatformUsesNewActivationFlow(t *testing.T) { + testCases := []struct { + version string + expected bool + }{ + {"", false}, + {"dev", false}, + {"4.5.0", false}, + {"v4.5.0", false}, + {"4.5.1", false}, + {"4.6.0-alpha.5", false}, + {"4.6.0-rc.7", false}, + {"4.6.0-rc.8", true}, + {"4.6.0-rc.9", true}, + {"4.6.0", true}, + {"v4.6.0", true}, + } + + globalFlags := &flags.GlobalFlags{} + startCmd := &StartCmd{ + StartOptions: start.StartOptions{ + Options: start.Options{ + CommandName: "start", + GlobalFlags: globalFlags, + Log: log.GetInstance(), + }, + }, + } + + for _, testCase := range testCases { + result := startCmd.platformUsesNewActivationFlow(testCase.version) + if result != testCase.expected { + t.Errorf("Expected %v, got %v for platform version %s", testCase.expected, result, testCase.version) + } + } +} diff --git a/pkg/cli/start/docker.go b/pkg/cli/start/docker.go index 21a693ac8c..e38501a77b 100644 --- a/pkg/cli/start/docker.go +++ b/pkg/cli/start/docker.go @@ -133,7 +133,7 @@ func (l *LoftStarter) successDocker(ctx context.Context, containerID string) err return false, fmt.Errorf("container failed (status: %s):\n %s", containerDetails.State.Status, logs) } - return clihelper.IsLoftReachable(ctx, host) + return clihelper.IsLoftReachable(ctx, host, l.LoadedConfig(l.Log).Platform.Insecure) }) if err != nil { return fmt.Errorf(product.Replace("error waiting for loft: %v%w"), err) diff --git a/pkg/cli/start/login.go b/pkg/cli/start/login.go index baeaa486e6..21bebc6b3c 100644 --- a/pkg/cli/start/login.go +++ b/pkg/cli/start/login.go @@ -58,8 +58,6 @@ func (l *LoftStarter) login(url string) error { } func (l *LoftStarter) loginViaCLI(url string) error { - loginPath := "%s/auth/password/login" - loginRequest := types.PasswordLoginRequest{ Username: defaultUser, Password: l.Password, @@ -70,32 +68,40 @@ func (l *LoftStarter) loginViaCLI(url string) error { return err } - loginRequestBuf := bytes.NewBuffer(loginRequestBytes) + config := l.LoadedConfig(l.Log) + httpClient := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: config.Platform.Insecure}, + }} - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - httpClient := &http.Client{Transport: tr} + // try a couple of times to login + accessKey := &types.AccessKey{} + for i := 0; i < 3; i++ { + resp, err := httpClient.Post(url+"/auth/password/login", "application/json", bytes.NewBuffer(loginRequestBytes)) + if err != nil { + return err + } - resp, err := httpClient.Post(fmt.Sprintf(loginPath, url), "application/json", loginRequestBuf) - if err != nil { - return err - } - defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + _ = resp.Body.Close() + return err + } + _ = resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return err + err = json.Unmarshal(body, accessKey) + if err != nil { + return err + } + if accessKey.AccessKey == "" { + continue + } + break } - - accessKey := &types.AccessKey{} - err = json.Unmarshal(body, accessKey) - if err != nil { - return err + if accessKey.AccessKey == "" { + return fmt.Errorf("couldn't retrieve access key from platform to login") } // log into loft - config := l.LoadedConfig(l.Log) loginClient := platform.NewLoginClientFromConfig(config) url = strings.TrimSuffix(url, "/") err = loginClient.LoginWithAccessKey(url, accessKey.AccessKey, config.Platform.Insecure) diff --git a/pkg/cli/start/port_forwarding.go b/pkg/cli/start/port_forwarding.go index 560f582099..cc356cdd92 100644 --- a/pkg/cli/start/port_forwarding.go +++ b/pkg/cli/start/port_forwarding.go @@ -21,7 +21,7 @@ func (l *LoftStarter) startPortForwarding(ctx context.Context, loftPod *corev1.P // wait until loft is reachable at the given url httpClient := &http.Client{ - Transport: utilhttp.InsecureTransport(), + Transport: utilhttp.Transport(l.LoadedConfig(l.Log).Platform.Insecure), } l.Log.Infof(product.Replace("Waiting until loft is reachable at https://localhost:%s"), l.LocalPort) err = wait.PollUntilContextTimeout(ctx, time.Second, clihelper.Timeout(), true, func(ctx context.Context) (bool, error) { diff --git a/pkg/cli/start/start.go b/pkg/cli/start/start.go index 9b20958f27..2591c628bc 100644 --- a/pkg/cli/start/start.go +++ b/pkg/cli/start/start.go @@ -70,6 +70,7 @@ type StartOptions struct { //nolint:revive // linter suggests renaming to option Upgrade bool ReuseValues bool Docker bool + Secure bool } func NewLoftStarter(options StartOptions) *LoftStarter { diff --git a/pkg/cli/start/success.go b/pkg/cli/start/success.go index b15de3c5dc..5f04919f43 100644 --- a/pkg/cli/start/success.go +++ b/pkg/cli/start/success.go @@ -68,7 +68,8 @@ func (l *LoftStarter) success(ctx context.Context) error { } // check if loft is reachable - reachable, err := clihelper.IsLoftReachable(ctx, host) + insecure := l.LoadedConfig(l.Log).Platform.Insecure + reachable, err := clihelper.IsLoftReachable(ctx, host, insecure) if !reachable || err != nil { const ( YesOption = "Yes" @@ -123,10 +124,11 @@ func (l *LoftStarter) pingLoftRouter(ctx context.Context, loftPod *corev1.Pod) ( loftRouterDomain := string(loftRouterSecret.Data["domain"]) // wait until loft is reachable at the given url + insecure := l.LoadedConfig(l.Log).Platform.Insecure httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, + InsecureSkipVerify: insecure, }, }, } @@ -190,7 +192,8 @@ func (l *LoftStarter) isLoggedIn(url string) bool { } func (l *LoftStarter) successRemote(ctx context.Context, host string) error { - ready, err := clihelper.IsLoftReachable(ctx, host) + insecure := l.LoadedConfig(l.Log).Platform.Insecure + ready, err := clihelper.IsLoftReachable(ctx, host, insecure) if err != nil { return err } else if ready { @@ -203,7 +206,7 @@ func (l *LoftStarter) successRemote(ctx context.Context, host string) error { l.Log.Info("Waiting for you to configure DNS, so loft can be reached on https://" + host) err = wait.PollUntilContextTimeout(ctx, 5*time.Second, clihelper.Timeout(), true, func(ctx context.Context) (done bool, err error) { - return clihelper.IsLoftReachable(ctx, host) + return clihelper.IsLoftReachable(ctx, host, insecure) }) if err != nil { return err diff --git a/pkg/platform/clihelper/clihelper.go b/pkg/platform/clihelper/clihelper.go index a9af8b02a6..341b3dd03b 100644 --- a/pkg/platform/clihelper/clihelper.go +++ b/pkg/platform/clihelper/clihelper.go @@ -344,10 +344,10 @@ func GetLoftDefaultPassword(ctx context.Context, kubeClient kubernetes.Interface return string(loftNamespace.UID), nil } -func IsLoftReachable(ctx context.Context, host string) (bool, error) { +func IsLoftReachable(ctx context.Context, host string, insecure bool) (bool, error) { // wait until loft is reachable at the given url client := &http.Client{ - Transport: utilhttp.InsecureTransport(), + Transport: utilhttp.Transport(insecure), } endpoint := fmt.Sprintf("https://%s/healthz", host) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) diff --git a/pkg/platform/clihelper/clihelper_test.go b/pkg/platform/clihelper/clihelper_test.go new file mode 100644 index 0000000000..ce7b50d5e7 --- /dev/null +++ b/pkg/platform/clihelper/clihelper_test.go @@ -0,0 +1,105 @@ +package clihelper + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "gotest.tools/v3/assert" +) + +func TestIsLoftReachable_InsecureTrueAgainstSelfSigned(t *testing.T) { + // Create an HTTPS server with a self-signed certificate. + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/healthz" { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + host := strings.TrimPrefix(server.URL, "https://") + + // With insecure=true, the self-signed cert should be accepted. + reachable, err := IsLoftReachable(context.Background(), host, true) + assert.NilError(t, err) + assert.Assert(t, reachable, "should be reachable with insecure=true against self-signed cert") +} + +func TestIsLoftReachable_InsecureFalseAgainstSelfSigned(t *testing.T) { + // Create an HTTPS server with a self-signed certificate. + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/healthz" { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + host := strings.TrimPrefix(server.URL, "https://") + + // With insecure=false, the self-signed cert should cause a TLS error + // and IsLoftReachable should return false (not reachable). + reachable, err := IsLoftReachable(context.Background(), host, false) + assert.NilError(t, err) + assert.Assert(t, !reachable, "should not be reachable with insecure=false against self-signed cert") +} + +func TestIsLoftReachable_InsecureFalseAgainstTrustedCert(t *testing.T) { + // Create an HTTPS server with a self-signed cert, but add the cert + // to the system pool so it's trusted. We do this by creating a custom + // test that validates the transport respects system certs. + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/healthz" { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + // Verify the server is actually using TLS with a self-signed cert. + conn, err := tls.Dial("tcp", strings.TrimPrefix(server.URL, "https://"), &tls.Config{ + InsecureSkipVerify: true, + }) + assert.NilError(t, err) + defer conn.Close() + + // Get the server certificate and create a cert pool that trusts it. + serverCert := conn.ConnectionState().PeerCertificates[0] + certPool := x509.NewCertPool() + certPool.AddCert(serverCert) + + // Verify the cert pool trusts the server - this validates our test setup. + _, err = serverCert.Verify(x509.VerifyOptions{ + Roots: certPool, + }) + assert.NilError(t, err, "cert should be verified with our custom pool") +} + +func TestIsLoftReachable_UnhealthyServer(t *testing.T) { + // Server that returns 500 on /healthz. + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + host := strings.TrimPrefix(server.URL, "https://") + + reachable, err := IsLoftReachable(context.Background(), host, true) + assert.NilError(t, err) + assert.Assert(t, !reachable, "should not be reachable when server returns 500") +} + +func TestIsLoftReachable_UnreachableHost(t *testing.T) { + // Use a host that doesn't exist. + reachable, err := IsLoftReachable(context.Background(), "localhost:1", true) + assert.NilError(t, err) + assert.Assert(t, !reachable, "should not be reachable when host is unreachable") +} diff --git a/pkg/util/http/transport.go b/pkg/util/http/transport.go index c31c8796eb..d819dd91ab 100644 --- a/pkg/util/http/transport.go +++ b/pkg/util/http/transport.go @@ -12,8 +12,17 @@ func CloneDefaultTransport() *http.Transport { return transport } -func InsecureTransport() *http.Transport { +// Transport returns a cloned default transport with TLS verification +// controlled by the insecure parameter. When insecure is true, TLS +// certificate verification is skipped. +func Transport(insecure bool) *http.Transport { newTransport := CloneDefaultTransport() - newTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + if insecure { + newTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } return newTransport } + +func InsecureTransport() *http.Transport { + return Transport(true) +} diff --git a/pkg/util/http/transport_test.go b/pkg/util/http/transport_test.go new file mode 100644 index 0000000000..a065cab647 --- /dev/null +++ b/pkg/util/http/transport_test.go @@ -0,0 +1,33 @@ +package http + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestTransport_Secure(t *testing.T) { + tr := Transport(false) + if tr.TLSClientConfig != nil { + assert.Assert(t, !tr.TLSClientConfig.InsecureSkipVerify, "TLS verification should be enabled when insecure=false") + } + assert.Assert(t, !tr.ForceAttemptHTTP2, "HTTP/2 should be disabled") +} + +func TestTransport_Insecure(t *testing.T) { + tr := Transport(true) + assert.Assert(t, tr.TLSClientConfig != nil, "TLSClientConfig should be set when insecure=true") + assert.Assert(t, tr.TLSClientConfig.InsecureSkipVerify, "TLS verification should be skipped when insecure=true") + assert.Assert(t, !tr.ForceAttemptHTTP2, "HTTP/2 should be disabled") +} + +func TestInsecureTransport(t *testing.T) { + tr := InsecureTransport() + assert.Assert(t, tr.TLSClientConfig != nil, "TLSClientConfig should be set") + assert.Assert(t, tr.TLSClientConfig.InsecureSkipVerify, "InsecureTransport should skip TLS verification") +} + +func TestCloneDefaultTransport(t *testing.T) { + tr := CloneDefaultTransport() + assert.Assert(t, !tr.ForceAttemptHTTP2, "HTTP/2 should be disabled") +}