Skip to content

Commit 2474109

Browse files
committed
oauth: add HTMLSuccess + HTMLError fields for branded login pages
The htmlSuccess / htmlError templates are good defaults but shipped as unexported package vars, so callers can't brand the post-login landing page short of forking. Adds two optional fields on AuthorizationCodeTokenSource that override the defaults when set, and plumbs them through to authHandler via two matching unexported fields. When empty, the new fields are no-ops and the existing constants are served, so existing callers see no behaviour change. The custom error HTML supports the same $ERROR / $DETAILS substitution the built-in errorHTML does, so callers can reuse the documented format without having to learn two templating systems. Tests: five new cases in oauth/authcode_test.go covering default + custom success, default + custom error, and field independence (overriding success alone leaves the error path unaffected).
1 parent dd92a85 commit 2474109

2 files changed

Lines changed: 100 additions & 5 deletions

File tree

oauth/authcode.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,24 +177,37 @@ func getInput(input chan string) {
177177
}
178178

179179
// authHandler is an HTTP handler that takes a channel and sends the `code`
180-
// query param when it gets a request.
180+
// query param when it gets a request. The successHTML and errorHTML fields
181+
// allow callers to override the built-in HTML pages; an empty string falls
182+
// back to the package-level htmlSuccess / htmlError defaults.
181183
type authHandler struct {
182-
c chan string
184+
c chan string
185+
successHTML string
186+
errorHTML string
183187
}
184188

185189
func (h authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
186190
w.Header().Set("Content-Type", "text/html")
187191

192+
successBody := h.successHTML
193+
if successBody == "" {
194+
successBody = htmlSuccess
195+
}
196+
errorBody := h.errorHTML
197+
if errorBody == "" {
198+
errorBody = htmlError
199+
}
200+
188201
if err := r.URL.Query().Get("error"); err != "" {
189202
details := r.URL.Query().Get("error_description")
190-
rendered := strings.Replace(strings.Replace(htmlError, "$ERROR", err, 1), "$DETAILS", details, 1)
203+
rendered := strings.Replace(strings.Replace(errorBody, "$ERROR", err, 1), "$DETAILS", details, 1)
191204
w.Write([]byte(rendered))
192205
h.c <- ""
193206
return
194207
}
195208

196209
h.c <- r.URL.Query().Get("code")
197-
w.Write([]byte(htmlSuccess))
210+
w.Write([]byte(successBody))
198211
}
199212

200213
// AuthorizationCodeTokenSource with PKCE as described in:
@@ -212,6 +225,8 @@ type AuthorizationCodeTokenSource struct {
212225
RedirectURL string
213226
EndpointParams *url.Values
214227
Scopes []string
228+
HTMLSuccess string
229+
HTMLError string
215230
}
216231

217232
func (ac *AuthorizationCodeTokenSource) getRedirectUrl() string {
@@ -260,7 +275,9 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
260275
// Run server before opening the user's browser so we are ready for any redirect.
261276
codeChan := make(chan string)
262277
handler := authHandler{
263-
c: codeChan,
278+
c: codeChan,
279+
successHTML: ac.HTMLSuccess,
280+
errorHTML: ac.HTMLError,
264281
}
265282

266283
// strip protocol prefix from configured redirect url for local webserver

oauth/authcode_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package oauth
22

33
import (
4+
"io"
5+
"net/http"
6+
"net/http/httptest"
47
"strings"
58
"testing"
69

710
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
812
)
913

1014
func TestEncodeUrlWindowsSuccess(t *testing.T) {
@@ -18,3 +22,77 @@ func TestEncodeUrlWindowsSuccess(t *testing.T) {
1822
assert.False(t, strings.HasPrefix(r, "^&"))
1923
assert.False(t, strings.HasSuffix(r, "^&"))
2024
}
25+
26+
// runAuthHandler drives the authHandler with the given query string and
27+
// returns the response body, allowing assertions on what HTML was served.
28+
func runAuthHandler(t *testing.T, h authHandler, query string) string {
29+
t.Helper()
30+
31+
rec := httptest.NewRecorder()
32+
req := httptest.NewRequest(http.MethodGet, "http://localhost/callback?"+query, nil)
33+
34+
// authHandler sends on h.c synchronously; drain in a goroutine so the
35+
// handler can complete without blocking on an unbuffered channel.
36+
done := make(chan struct{})
37+
go func() {
38+
<-h.c
39+
close(done)
40+
}()
41+
42+
h.ServeHTTP(rec, req)
43+
<-done
44+
45+
body, err := io.ReadAll(rec.Body)
46+
require.NoError(t, err)
47+
48+
return string(body)
49+
}
50+
51+
func TestAuthHandlerServesDefaultSuccessHTML(t *testing.T) {
52+
h := authHandler{c: make(chan string)}
53+
54+
body := runAuthHandler(t, h, "code=abc123")
55+
56+
assert.Contains(t, body, "Login Successful!", "default success HTML should be served when successHTML is unset")
57+
}
58+
59+
func TestAuthHandlerServesCustomSuccessHTML(t *testing.T) {
60+
const custom = `<html><body><h1>Welcome to my-tool</h1></body></html>`
61+
h := authHandler{c: make(chan string), successHTML: custom}
62+
63+
body := runAuthHandler(t, h, "code=abc123")
64+
65+
assert.Equal(t, custom, body, "custom successHTML should be served verbatim")
66+
assert.NotContains(t, body, "Login Successful!", "default content should not leak through")
67+
}
68+
69+
func TestAuthHandlerServesDefaultErrorHTML(t *testing.T) {
70+
h := authHandler{c: make(chan string)}
71+
72+
body := runAuthHandler(t, h, "error=access_denied&error_description=user+denied")
73+
74+
assert.Contains(t, body, "access_denied", "$ERROR should be substituted in the default error HTML")
75+
assert.Contains(t, body, "user denied", "$DETAILS should be substituted in the default error HTML")
76+
}
77+
78+
func TestAuthHandlerServesCustomErrorHTML(t *testing.T) {
79+
const custom = `<html><body><h1>Login failed: $ERROR</h1><p>$DETAILS</p></body></html>`
80+
h := authHandler{c: make(chan string), errorHTML: custom}
81+
82+
body := runAuthHandler(t, h, "error=invalid_grant&error_description=code+expired")
83+
84+
assert.Equal(t, `<html><body><h1>Login failed: invalid_grant</h1><p>code expired</p></body></html>`, body,
85+
"custom errorHTML should be served with $ERROR and $DETAILS substituted")
86+
}
87+
88+
func TestAuthHandlerCustomSuccessDoesNotAffectErrorPath(t *testing.T) {
89+
// A caller that only overrides successHTML should still see the default
90+
// error HTML when an error comes through. Each field is independently
91+
// optional.
92+
h := authHandler{c: make(chan string), successHTML: "<html>custom success</html>"}
93+
94+
body := runAuthHandler(t, h, "error=server_error")
95+
96+
assert.Contains(t, body, "server_error", "default error HTML should still substitute $ERROR")
97+
assert.NotContains(t, body, "custom success")
98+
}

0 commit comments

Comments
 (0)