Skip to content

Commit 5e8f903

Browse files
Allow localhost on u2m auth (#1233)
## What changes are proposed in this pull request? This PR relaxes the validation performed on the host URL to allow `http` URLs if the URL points to localhost. This is needed to add acceptance tests in the Databricks CLI for the `databricks auth login` command. For context: The acceptance test framework in the CLI has the ability to spawn a testserver on localhost and record API requests sent to it. We can use the testserver to validate that the correct API requests are being sent during `auth login`, and that the command does not fail. Prompted by databricks/cli#2988 (comment) where I almost let a big regression pass. NO_CHANGELOG=true ## How is this tested? Unit test
1 parent 3821c8e commit 5e8f903

3 files changed

Lines changed: 64 additions & 10 deletions

File tree

credentials/u2m/account_oauth_argument.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package u2m
22

33
import (
44
"fmt"
5-
"strings"
65
)
76

87
// AccountOAuthArgument is an interface that provides the necessary information
@@ -28,11 +27,8 @@ var _ AccountOAuthArgument = BasicAccountOAuthArgument{}
2827

2928
// NewBasicAccountOAuthArgument creates a new BasicAccountOAuthArgument.
3029
func NewBasicAccountOAuthArgument(accountsHost, accountID string) (BasicAccountOAuthArgument, error) {
31-
if !strings.HasPrefix(accountsHost, "https://") {
32-
return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must start with 'https://': %s", accountsHost)
33-
}
34-
if strings.HasSuffix(accountsHost, "/") {
35-
return BasicAccountOAuthArgument{}, fmt.Errorf("accountsHost must not have a trailing slash: %s", accountsHost)
30+
if err := validateHost(accountsHost); err != nil {
31+
return BasicAccountOAuthArgument{}, err
3632
}
3733
return BasicAccountOAuthArgument{accountHost: accountsHost, accountID: accountID}, nil
3834
}

credentials/u2m/workspace_oauth_argument.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,25 @@ type BasicWorkspaceOAuthArgument struct {
2222
host string
2323
}
2424

25-
// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument.
26-
func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) {
25+
func validateHost(host string) error {
26+
// Allow http for localhost. This is necessary for local end to end testing
27+
// of the `databricks auth login` command using a test server on localhost.
28+
if strings.HasPrefix(host, "http://127.0.0.1") {
29+
return nil
30+
}
2731
if !strings.HasPrefix(host, "https://") {
28-
return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must start with 'https://': %s", host)
32+
return fmt.Errorf("host must start with 'https://': %s", host)
2933
}
3034
if strings.HasSuffix(host, "/") {
31-
return BasicWorkspaceOAuthArgument{}, fmt.Errorf("host must not have a trailing slash: %s", host)
35+
return fmt.Errorf("host must not have a trailing slash: %s", host)
36+
}
37+
return nil
38+
}
39+
40+
// NewBasicWorkspaceOAuthArgument creates a new BasicWorkspaceOAuthArgument.
41+
func NewBasicWorkspaceOAuthArgument(host string) (BasicWorkspaceOAuthArgument, error) {
42+
if err := validateHost(host); err != nil {
43+
return BasicWorkspaceOAuthArgument{}, err
3244
}
3345
return BasicWorkspaceOAuthArgument{host: host}, nil
3446
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package u2m
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestValidateHost(t *testing.T) {
10+
tests := []struct {
11+
host string
12+
want string
13+
}{
14+
// Valid hosts
15+
{"https://some-host.com", ""},
16+
{"http://127.0.0.1", ""},
17+
{"http://127.0.0.1:5656", ""},
18+
19+
// Invalid hosts
20+
{"http://some-host.com", "host must start with 'https://': http://some-host.com"},
21+
{"https://some-host.com/", "host must not have a trailing slash: https://some-host.com/"},
22+
}
23+
24+
for _, test := range tests {
25+
err := validateHost(test.host)
26+
if test.want == "" {
27+
assert.NoError(t, err)
28+
} else {
29+
assert.EqualError(t, err, test.want)
30+
}
31+
32+
_, err = NewBasicWorkspaceOAuthArgument(test.host)
33+
if test.want == "" {
34+
assert.NoError(t, err)
35+
} else {
36+
assert.EqualError(t, err, test.want)
37+
}
38+
39+
_, err = NewBasicAccountOAuthArgument(test.host, "123")
40+
if test.want == "" {
41+
assert.NoError(t, err)
42+
} else {
43+
assert.EqualError(t, err, test.want)
44+
}
45+
}
46+
}

0 commit comments

Comments
 (0)