Skip to content

Commit 8c3a705

Browse files
committed
feat: reworked csh-auth for a v1 release
1 parent a33a9cd commit 8c3a705

5 files changed

Lines changed: 364 additions & 183 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,8 @@
1010
# Output of the go coverage tool, specifically when used with LiteIDE
1111
*.out
1212

13+
# IDE / editor files
14+
.idea/
15+
1316
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
1417
.glide/

auth.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
package csh_auth
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"encoding/json"
7+
"errors"
8+
"io"
9+
"net/http"
10+
"time"
11+
12+
"github.com/coreos/go-oidc"
13+
"github.com/gin-gonic/gin"
14+
"github.com/golang-jwt/jwt/v5"
15+
log "github.com/sirupsen/logrus"
16+
"golang.org/x/oauth2"
17+
)
18+
19+
const ContextKey = "cshauth"
20+
const CookieName = "Auth"
21+
const ProviderURI = "https://sso.csh.rit.edu/auth/realms/csh"
22+
23+
var StateLookup map[string]string
24+
25+
type Auth struct {
26+
// clientID is the OIDC client ID.
27+
clientID string
28+
// clientSecret is the OIDC client secret.
29+
clientSecret string
30+
// serverURL is the "base" URL that this service is hosted from, e.g. "http://localhost:8000"
31+
serverURL string
32+
// authenticateURL is the URL for users to start the OAuth flow and login.
33+
// Commonly, this is set to something like ServerHost+"/auth/login"
34+
authenticateURL string
35+
// callbackURL is the URL that users will be redirected to at the end of the OAuth flow.
36+
// Commonly, this is set to something like ServerHost+"/auth/callback"
37+
callbackURL string
38+
// secure will be set if the serverURL contains https
39+
secure bool
40+
oauth oauth2.Config
41+
ctx context.Context
42+
oidcProvider *oidc.Provider
43+
oidcCerts jwt.VerificationKeySet
44+
}
45+
46+
type UserInfo struct {
47+
Uuid string `json:"uuid"`
48+
Email string `json:"email"`
49+
Username string `json:"preferred_username"`
50+
FullName string `json:"name"`
51+
Groups []string `json:"groups"`
52+
}
53+
54+
type Claims struct {
55+
jwt.RegisteredClaims
56+
UserInfo
57+
}
58+
59+
func Init(oidcClientID string, oidcClientSecret string, serverURL string, authenticateURL string, callbackURL string, scopes []string) (Auth, error) {
60+
auth := Auth{
61+
clientID: oidcClientID,
62+
clientSecret: oidcClientSecret,
63+
serverURL: serverURL,
64+
authenticateURL: authenticateURL,
65+
callbackURL: callbackURL,
66+
ctx: context.Background(),
67+
}
68+
69+
auth.secure = serverURL[0:5] == "https"
70+
71+
auth.oidcCerts = getVerificationKeys()
72+
73+
var err error
74+
auth.oidcProvider, err = oidc.NewProvider(auth.ctx, ProviderURI)
75+
if err != nil {
76+
log.Error("Failed to create OIDC Provider")
77+
log.Error(err)
78+
return auth, err
79+
}
80+
scopes = append(scopes, oidc.ScopeOpenID)
81+
auth.oauth = oauth2.Config{
82+
ClientID: auth.clientID,
83+
ClientSecret: auth.clientSecret,
84+
Endpoint: auth.oidcProvider.Endpoint(),
85+
RedirectURL: auth.callbackURL,
86+
Scopes: scopes,
87+
}
88+
89+
StateLookup = make(map[string]string)
90+
91+
return auth, nil
92+
}
93+
94+
// Route functions
95+
96+
func (auth *Auth) HandleLogin(c *gin.Context) {
97+
auth.oauth.RedirectURL = auth.callbackURL + "?referer=" + c.Query("referer")
98+
state := rand.Text()
99+
ref := rand.Text()
100+
c.SetCookie("ref", ref, int(time.Minute), "", "", auth.secure, true)
101+
StateLookup[ref] = state
102+
c.Redirect(http.StatusFound, auth.oauth.AuthCodeURL(state))
103+
}
104+
105+
func (auth *Auth) HandleCallback(c *gin.Context) {
106+
ref, err := c.Cookie("ref")
107+
if err != nil {
108+
log.Error("no callback ref cookie")
109+
c.Redirect(http.StatusFound, auth.authenticateURL)
110+
return
111+
}
112+
state, ok := StateLookup[ref]
113+
if !ok {
114+
log.Error("callback ref not found")
115+
c.Redirect(http.StatusFound, auth.authenticateURL)
116+
return
117+
}
118+
if c.Query("state") != state {
119+
log.Error("state does not match")
120+
c.Redirect(http.StatusFound, auth.authenticateURL)
121+
return
122+
}
123+
124+
oauthJWT, err := auth.oauth.Exchange(auth.ctx, c.Query("code"))
125+
if err != nil {
126+
log.Error("failed to exchange token")
127+
return
128+
}
129+
130+
c.SetCookie(CookieName, oauthJWT.AccessToken, int(oauthJWT.ExpiresIn), "", "", false, true)
131+
c.Redirect(http.StatusFound, c.Query("referer"))
132+
}
133+
134+
// Middleware functions
135+
136+
func (auth *Auth) CookieMiddleware() gin.HandlerFunc {
137+
return func(c *gin.Context) {
138+
cookie, err := c.Cookie(CookieName)
139+
if err != nil {
140+
log.Error(CookieName, "cookie not found")
141+
c.Redirect(http.StatusFound, auth.authenticateURL+"?referer="+c.Request.URL.String())
142+
return
143+
}
144+
err = auth.setGinContext(c, cookie)
145+
if err != nil {
146+
log.Error("failed to set context")
147+
return
148+
}
149+
}
150+
}
151+
152+
func (auth *Auth) HeaderMiddleware() gin.HandlerFunc {
153+
return func(c *gin.Context) {
154+
header := c.Request.Header.Get("Authorization")
155+
if header == "" {
156+
c.Header("WWW-Authenticate", "Authentication Required")
157+
c.AbortWithStatus(http.StatusUnauthorized)
158+
return
159+
}
160+
if header[0:8] != "Bearer: " {
161+
c.Header("WWW-Authenticate", "Bad Authentication Header")
162+
c.AbortWithStatus(http.StatusUnauthorized)
163+
return
164+
}
165+
err := auth.setGinContext(c, header)
166+
if err != nil {
167+
log.Error("failed to set context")
168+
return
169+
}
170+
171+
}
172+
}
173+
174+
// Utility functions
175+
176+
func (auth *Auth) setGinContext(c *gin.Context, tokenString string) error {
177+
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
178+
return auth.oidcCerts, nil
179+
})
180+
if err != nil {
181+
log.Error("failed to parse token", err)
182+
return err
183+
}
184+
185+
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
186+
c.Set(ContextKey, claims)
187+
return nil
188+
}
189+
190+
log.Error("failed parsing JWT claims")
191+
return errors.New("failed parsing JWT claims")
192+
}
193+
194+
func getVerificationKeys() jwt.VerificationKeySet {
195+
client := http.DefaultClient
196+
res, err := client.Get(ProviderURI + "/protocol/openid-connect/certs")
197+
if err != nil {
198+
log.Error("Failed to get verification keys", err)
199+
return jwt.VerificationKeySet{}
200+
}
201+
data, err := io.ReadAll(res.Body)
202+
if err != nil {
203+
log.Error("Failed to read verification keys", err)
204+
return jwt.VerificationKeySet{}
205+
}
206+
res.Body.Close()
207+
ret := jwt.VerificationKeySet{}
208+
err = json.Unmarshal(data, &ret)
209+
if err != nil {
210+
log.Error("Failed to unmarshal verification keys", err)
211+
return jwt.VerificationKeySet{}
212+
}
213+
return ret
214+
}

0 commit comments

Comments
 (0)