From 9109638ec82ec27998b5acfc0f98edefcf774c31 Mon Sep 17 00:00:00 2001 From: Vanshaj Singhania Date: Sun, 22 Mar 2026 11:55:35 -0700 Subject: [PATCH] perform node.IsTagged() check in /authorize Signed-off-by: Vanshaj Singhania --- server/authorize.go | 7 +- server/authorize_test.go | 160 ++++++++++++++++++++++++--------------- server/helpers_test.go | 43 +++++++++++ server/server.go | 1 + server/token_test.go | 12 +++ 5 files changed, 160 insertions(+), 63 deletions(-) diff --git a/server/authorize.go b/server/authorize.go index e71c81f28..fe66c9981 100644 --- a/server/authorize.go +++ b/server/authorize.go @@ -86,6 +86,11 @@ func (s *IDPServer) serveAuthorize(w http.ResponseWriter, r *http.Request) { return } + if who.Node.View().IsTagged() { + redirectAuthError(w, r, redirectURI, ecAccessDenied, "tagged node doesn't have a user identity", state) + return + } + // Generate and save a code and Auth Request code := rands.HexString(32) ar := &AuthRequest{ @@ -104,7 +109,7 @@ func (s *IDPServer) serveAuthorize(w http.ResponseWriter, r *http.Request) { // Validate scopes validatedScopes, err := s.validateScopes(ar.Scopes) if err != nil { - redirectAuthError(w, r, redirectURI, "invalid_scope", fmt.Sprintf("invalid scope: %v", err), state) + redirectAuthError(w, r, redirectURI, ecInvalidScope, fmt.Sprintf("invalid scope: %v", err), state) return } ar.Scopes = validatedScopes diff --git a/server/authorize_test.go b/server/authorize_test.go index 26652c5bd..bb40ad0c0 100644 --- a/server/authorize_test.go +++ b/server/authorize_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" ) @@ -534,13 +535,11 @@ func TestServeAuthorize(t *testing.T) { redirectURI string state string nonce string - setupClient bool clientRedirect string useFunnel bool // whether to simulate funnel request mockWhoIsError bool // whether to make WhoIs return an error - expectError bool + isTaggedNode bool // whether to make IsTagged return true expectCode int - expectRedirect bool }{ // Security boundary test: funnel rejection { @@ -549,10 +548,8 @@ func TestServeAuthorize(t *testing.T) { redirectURI: "https://rp.example.com/callback", state: "random-state", nonce: "random-nonce", - setupClient: true, clientRedirect: "https://rp.example.com/callback", useFunnel: true, - expectError: true, expectCode: http.StatusUnauthorized, }, @@ -562,7 +559,6 @@ func TestServeAuthorize(t *testing.T) { clientID: "", redirectURI: "https://rp.example.com/callback", useFunnel: false, - expectError: true, expectCode: http.StatusBadRequest, }, { @@ -570,7 +566,6 @@ func TestServeAuthorize(t *testing.T) { clientID: "test-client", redirectURI: "", useFunnel: false, - expectError: true, expectCode: http.StatusBadRequest, }, @@ -579,29 +574,61 @@ func TestServeAuthorize(t *testing.T) { name: "invalid client_id", clientID: "invalid-client", redirectURI: "https://rp.example.com/callback", - setupClient: false, useFunnel: false, - expectError: true, expectCode: http.StatusBadRequest, }, { name: "redirect_uri mismatch", clientID: "test-client", redirectURI: "https://wrong.example.com/callback", - setupClient: true, clientRedirect: "https://rp.example.com/callback", useFunnel: false, - expectError: true, expectCode: http.StatusBadRequest, }, + + // other cases + { + name: "WhoIs error blocks flow", + clientID: "test-client", + redirectURI: "https://rp.example.com/callback", + clientRedirect: "https://rp.example.com/callback", + mockWhoIsError: true, + expectCode: http.StatusInternalServerError, + }, + { + name: "tagged node is not allowed", + clientID: "test-client", + redirectURI: "https://rp.example.com/callback", + clientRedirect: "https://rp.example.com/callback", + isTaggedNode: true, + expectCode: http.StatusFound, + }, + { + name: "successfully issues auth code", + clientID: "test-client", + redirectURI: "https://rp.example.com/callback", + clientRedirect: "https://rp.example.com/callback", + expectCode: http.StatusFound, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - srv := setupTestServer(t, nil) + whoisResponse := &apitype.WhoIsResponse{ + Node: &tailcfg.Node{}, + } + if tt.isTaggedNode { + whoisResponse.Node.Tags = append(whoisResponse.Node.Tags, "tag:authorize-test-tag") + } + + var lc *local.Client + if tt.mockWhoIsError { + lc = newTestWhoIsClient(t, nil, true) + } else { + lc = newTestWhoIsClient(t, whoisResponse, false) + } - // For non-funnel tests, we'll test the parameter validation logic - // without needing to mock WhoIs, since the validation happens before WhoIs calls + srv := setupTestServer(t, lc) // Setup client if needed srv.funnelClients["test-client"] = &FunnelClient{ @@ -640,59 +667,68 @@ func TestServeAuthorize(t *testing.T) { rr := httptest.NewRecorder() srv.serveAuthorize(rr, req) - if tt.expectError { - if rr.Code != tt.expectCode { - t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + + // For any other code, the error check above is the last step + if tt.expectCode != http.StatusFound { + return + } + + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected Location header in redirect response") + return + } + + // Parse the redirect URL to verify it contains a code + redirectURL, err := url.Parse(location) + if err != nil { + t.Errorf("failed to parse redirect URL: %v", err) + return + } + + // For a tagged node, we expect an `access_denied` error as the last step + if tt.isTaggedNode { + errCode := redirectURL.Query().Get("error") + if errCode != ecAccessDenied { + t.Error("expected 'error' parameter in redirect URL to be 'access_denied'") } - } else if tt.expectRedirect { - if rr.Code != http.StatusFound { - t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String()) + return + } + + code := redirectURL.Query().Get("code") + if code == "" { + t.Error("expected 'code' parameter in redirect URL") + } + + // Verify state is preserved if provided + if tt.state != "" { + returnedState := redirectURL.Query().Get("state") + if returnedState != tt.state { + t.Errorf("expected state '%s', got '%s'", tt.state, returnedState) } + } - location := rr.Header().Get("Location") - if location == "" { - t.Error("expected Location header in redirect response") - } else { - // Parse the redirect URL to verify it contains a code - redirectURL, err := url.Parse(location) - if err != nil { - t.Errorf("failed to parse redirect URL: %v", err) - } else { - code := redirectURL.Query().Get("code") - if code == "" { - t.Error("expected 'code' parameter in redirect URL") - } + // Verify the auth request was stored + srv.mu.Lock() + ar, ok := srv.code[code] + srv.mu.Unlock() - // Verify state is preserved if provided - if tt.state != "" { - returnedState := redirectURL.Query().Get("state") - if returnedState != tt.state { - t.Errorf("expected state '%s', got '%s'", tt.state, returnedState) - } - } + if !ok { + t.Error("expected authorization request to be stored") + return + } - // Verify the auth request was stored - srv.mu.Lock() - ar, ok := srv.code[code] - srv.mu.Unlock() - - if !ok { - t.Error("expected authorization request to be stored") - } else { - if ar.ClientID != tt.clientID { - t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.ClientID) - } - if ar.RedirectURI != tt.redirectURI { - t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.RedirectURI) - } - if ar.Nonce != tt.nonce { - t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.Nonce) - } - } - } - } - } else { - t.Errorf("unexpected test case: not expecting error or redirect") + if ar.ClientID != tt.clientID { + t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.ClientID) + } + if ar.RedirectURI != tt.redirectURI { + t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.RedirectURI) + } + if ar.Nonce != tt.nonce { + t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.Nonce) } }) } diff --git a/server/helpers_test.go b/server/helpers_test.go index c1d7b02b5..e8b0fe549 100644 --- a/server/helpers_test.go +++ b/server/helpers_test.go @@ -4,15 +4,19 @@ package server import ( + "bytes" "crypto/rand" "crypto/rsa" "encoding/json" "fmt" + "io" + "net/http" "sort" "testing" "gopkg.in/square/go-jose.v2" "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" ) @@ -45,6 +49,45 @@ func setupTestServer(t *testing.T, lc *local.Client) *IDPServer { return srv } +// whoisRoundTripper is an http.RoundTripper that returns a canned WhoIs +// response. It is used to test code that calls local.Client.WhoIs without +// needing a running tailscaled. +type whoisRoundTripper struct { + response *apitype.WhoIsResponse + err bool // if true, return HTTP 500 +} + +func (rt *whoisRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/localapi/v0/whois" { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewReader(nil)), + }, nil + } + if rt.err { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(bytes.NewBufferString("whois error")), + }, nil + } + b, _ := json.Marshal(rt.response) + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": {"application/json"}}, + Body: io.NopCloser(bytes.NewReader(b)), + }, nil +} + +// newTestWhoIsClient returns a *local.Client whose WhoIs calls return the +// given response or an error. This uses local.Client's Transport field to +// intercept HTTP requests without needing a running tailscaled. +func newTestWhoIsClient(t *testing.T, whoisResponse *apitype.WhoIsResponse, whoisErr bool) *local.Client { + t.Helper() + return &local.Client{ + Transport: &whoisRoundTripper{response: whoisResponse, err: whoisErr}, + } +} + func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage { t.Helper() b, err := json.Marshal(v) diff --git a/server/server.go b/server/server.go index c6fe9e198..af76648d1 100644 --- a/server/server.go +++ b/server/server.go @@ -156,6 +156,7 @@ const ( ecInvalidRequest = "invalid_request" ecInvalidClient = "invalid_client" ecInvalidGrant = "invalid_grant" + ecInvalidScope = "invalid_scope" ecServerError = "server_error" ecNotFound = "not_found" ecUnsupportedGrant = "unsupported_grant_type" diff --git a/server/token_test.go b/server/token_test.go index 8c165fce8..6e6795f8e 100644 --- a/server/token_test.go +++ b/server/token_test.go @@ -1132,6 +1132,7 @@ func TestServeToken(t *testing.T) { tests := []struct { name string caps tailcfg.PeerCapMap + tags []string method string grantType string code string @@ -1188,6 +1189,16 @@ func TestServeToken(t *testing.T) { remoteAddr: "192.168.0.1:12345", expectError: true, }, + { + name: "tagged nodes are not allowed", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "192.168.0.1:12345", + tags: []string{"tag:mytag"}, + expectError: true, + }, { name: "extra claim included", method: "POST", @@ -1249,6 +1260,7 @@ func TestServeToken(t *testing.T) { Key: key.NodePublic{}, Cap: 1, DiscoKey: key.DiscoPublic{}, + Tags: tt.tags, } remoteUser := &apitype.WhoIsResponse{