Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion server/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ func (s *IDPServer) serveAuthorize(w http.ResponseWriter, r *http.Request) {
return
}

if who.Node.View().IsTagged() {
redirectAuthError(w, r, redirectURI, ecAccessDenied, "tagged node doesn't have a user identity", state)
return
}

// Generate and save a code and Auth Request
code := rands.HexString(32)
ar := &AuthRequest{
Expand All @@ -104,7 +109,7 @@ func (s *IDPServer) serveAuthorize(w http.ResponseWriter, r *http.Request) {
// Validate scopes
validatedScopes, err := s.validateScopes(ar.Scopes)
if err != nil {
redirectAuthError(w, r, redirectURI, "invalid_scope", fmt.Sprintf("invalid scope: %v", err), state)
redirectAuthError(w, r, redirectURI, ecInvalidScope, fmt.Sprintf("invalid scope: %v", err), state)
return
}
ar.Scopes = validatedScopes
Expand Down
160 changes: 98 additions & 62 deletions server/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
)
Expand Down Expand Up @@ -534,13 +535,11 @@ func TestServeAuthorize(t *testing.T) {
redirectURI string
state string
nonce string
setupClient bool
clientRedirect string
useFunnel bool // whether to simulate funnel request
mockWhoIsError bool // whether to make WhoIs return an error
expectError bool
isTaggedNode bool // whether to make IsTagged return true
expectCode int
expectRedirect bool
}{
// Security boundary test: funnel rejection
{
Expand All @@ -549,10 +548,8 @@ func TestServeAuthorize(t *testing.T) {
redirectURI: "https://rp.example.com/callback",
state: "random-state",
nonce: "random-nonce",
setupClient: true,
clientRedirect: "https://rp.example.com/callback",
useFunnel: true,
expectError: true,
expectCode: http.StatusUnauthorized,
},

Expand All @@ -562,15 +559,13 @@ func TestServeAuthorize(t *testing.T) {
clientID: "",
redirectURI: "https://rp.example.com/callback",
useFunnel: false,
expectError: true,
expectCode: http.StatusBadRequest,
},
{
name: "missing redirect_uri",
clientID: "test-client",
redirectURI: "",
useFunnel: false,
expectError: true,
expectCode: http.StatusBadRequest,
},

Expand All @@ -579,29 +574,61 @@ func TestServeAuthorize(t *testing.T) {
name: "invalid client_id",
clientID: "invalid-client",
redirectURI: "https://rp.example.com/callback",
setupClient: false,
useFunnel: false,
expectError: true,
expectCode: http.StatusBadRequest,
},
{
name: "redirect_uri mismatch",
clientID: "test-client",
redirectURI: "https://wrong.example.com/callback",
setupClient: true,
clientRedirect: "https://rp.example.com/callback",
useFunnel: false,
expectError: true,
expectCode: http.StatusBadRequest,
},

// other cases
{
name: "WhoIs error blocks flow",
clientID: "test-client",
redirectURI: "https://rp.example.com/callback",
clientRedirect: "https://rp.example.com/callback",
mockWhoIsError: true,
expectCode: http.StatusInternalServerError,
},
{
name: "tagged node is not allowed",
clientID: "test-client",
redirectURI: "https://rp.example.com/callback",
clientRedirect: "https://rp.example.com/callback",
isTaggedNode: true,
expectCode: http.StatusFound,
},
{
name: "successfully issues auth code",
clientID: "test-client",
redirectURI: "https://rp.example.com/callback",
clientRedirect: "https://rp.example.com/callback",
expectCode: http.StatusFound,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := setupTestServer(t, nil)
whoisResponse := &apitype.WhoIsResponse{
Node: &tailcfg.Node{},
}
if tt.isTaggedNode {
whoisResponse.Node.Tags = append(whoisResponse.Node.Tags, "tag:authorize-test-tag")
}

var lc *local.Client
if tt.mockWhoIsError {
lc = newTestWhoIsClient(t, nil, true)
} else {
lc = newTestWhoIsClient(t, whoisResponse, false)
}

// For non-funnel tests, we'll test the parameter validation logic
// without needing to mock WhoIs, since the validation happens before WhoIs calls
srv := setupTestServer(t, lc)

// Setup client if needed
srv.funnelClients["test-client"] = &FunnelClient{
Expand Down Expand Up @@ -640,59 +667,68 @@ func TestServeAuthorize(t *testing.T) {
rr := httptest.NewRecorder()
srv.serveAuthorize(rr, req)

if tt.expectError {
if rr.Code != tt.expectCode {
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
if rr.Code != tt.expectCode {
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
}

// For any other code, the error check above is the last step
if tt.expectCode != http.StatusFound {
return
}

location := rr.Header().Get("Location")
if location == "" {
t.Error("expected Location header in redirect response")
return
}

// Parse the redirect URL to verify it contains a code
redirectURL, err := url.Parse(location)
if err != nil {
t.Errorf("failed to parse redirect URL: %v", err)
return
}

// For a tagged node, we expect an `access_denied` error as the last step
if tt.isTaggedNode {
errCode := redirectURL.Query().Get("error")
if errCode != ecAccessDenied {
t.Error("expected 'error' parameter in redirect URL to be 'access_denied'")
}
} else if tt.expectRedirect {
if rr.Code != http.StatusFound {
t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String())
return
}

code := redirectURL.Query().Get("code")
if code == "" {
t.Error("expected 'code' parameter in redirect URL")
}

// Verify state is preserved if provided
if tt.state != "" {
returnedState := redirectURL.Query().Get("state")
if returnedState != tt.state {
t.Errorf("expected state '%s', got '%s'", tt.state, returnedState)
}
}

location := rr.Header().Get("Location")
if location == "" {
t.Error("expected Location header in redirect response")
} else {
// Parse the redirect URL to verify it contains a code
redirectURL, err := url.Parse(location)
if err != nil {
t.Errorf("failed to parse redirect URL: %v", err)
} else {
code := redirectURL.Query().Get("code")
if code == "" {
t.Error("expected 'code' parameter in redirect URL")
}
// Verify the auth request was stored
srv.mu.Lock()
ar, ok := srv.code[code]
srv.mu.Unlock()

// Verify state is preserved if provided
if tt.state != "" {
returnedState := redirectURL.Query().Get("state")
if returnedState != tt.state {
t.Errorf("expected state '%s', got '%s'", tt.state, returnedState)
}
}
if !ok {
t.Error("expected authorization request to be stored")
return
}

// Verify the auth request was stored
srv.mu.Lock()
ar, ok := srv.code[code]
srv.mu.Unlock()

if !ok {
t.Error("expected authorization request to be stored")
} else {
if ar.ClientID != tt.clientID {
t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.ClientID)
}
if ar.RedirectURI != tt.redirectURI {
t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.RedirectURI)
}
if ar.Nonce != tt.nonce {
t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.Nonce)
}
}
}
}
} else {
t.Errorf("unexpected test case: not expecting error or redirect")
if ar.ClientID != tt.clientID {
t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.ClientID)
}
if ar.RedirectURI != tt.redirectURI {
t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.RedirectURI)
}
if ar.Nonce != tt.nonce {
t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.Nonce)
}
})
}
Expand Down
43 changes: 43 additions & 0 deletions server/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
package server

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"testing"

"gopkg.in/square/go-jose.v2"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
)

Expand Down Expand Up @@ -45,6 +49,45 @@ func setupTestServer(t *testing.T, lc *local.Client) *IDPServer {
return srv
}

// whoisRoundTripper is an http.RoundTripper that returns a canned WhoIs
// response. It is used to test code that calls local.Client.WhoIs without
// needing a running tailscaled.
type whoisRoundTripper struct {
response *apitype.WhoIsResponse
err bool // if true, return HTTP 500
}

func (rt *whoisRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/localapi/v0/whois" {
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
if rt.err {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("whois error")),
}, nil
}
b, _ := json.Marshal(rt.response)
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": {"application/json"}},
Body: io.NopCloser(bytes.NewReader(b)),
}, nil
}

// newTestWhoIsClient returns a *local.Client whose WhoIs calls return the
// given response or an error. This uses local.Client's Transport field to
// intercept HTTP requests without needing a running tailscaled.
func newTestWhoIsClient(t *testing.T, whoisResponse *apitype.WhoIsResponse, whoisErr bool) *local.Client {
t.Helper()
return &local.Client{
Transport: &whoisRoundTripper{response: whoisResponse, err: whoisErr},
}
}

Comment thread
itsvs marked this conversation as resolved.
func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage {
t.Helper()
b, err := json.Marshal(v)
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ const (
ecInvalidRequest = "invalid_request"
ecInvalidClient = "invalid_client"
ecInvalidGrant = "invalid_grant"
ecInvalidScope = "invalid_scope"
ecServerError = "server_error"
ecNotFound = "not_found"
ecUnsupportedGrant = "unsupported_grant_type"
Expand Down
12 changes: 12 additions & 0 deletions server/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ func TestServeToken(t *testing.T) {
tests := []struct {
name string
caps tailcfg.PeerCapMap
tags []string
method string
grantType string
code string
Expand Down Expand Up @@ -1188,6 +1189,16 @@ func TestServeToken(t *testing.T) {
remoteAddr: "192.168.0.1:12345",
expectError: true,
},
{
name: "tagged nodes are not allowed",
method: "POST",
grantType: "authorization_code",
redirectURI: "https://rp.example.com/callback",
code: "valid-code",
remoteAddr: "192.168.0.1:12345",
tags: []string{"tag:mytag"},
expectError: true,
},
{
name: "extra claim included",
method: "POST",
Expand Down Expand Up @@ -1249,6 +1260,7 @@ func TestServeToken(t *testing.T) {
Key: key.NodePublic{},
Cap: 1,
DiscoKey: key.DiscoPublic{},
Tags: tt.tags,
}

remoteUser := &apitype.WhoIsResponse{
Expand Down
Loading