Skip to content

Commit e13598b

Browse files
committed
tests: add tests for context middleware
1 parent 4d3860f commit e13598b

2 files changed

Lines changed: 415 additions & 164 deletions

File tree

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
package middleware_test
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"net/http"
7+
"net/http/httptest"
8+
"path"
9+
"testing"
10+
"time"
11+
12+
"github.com/gin-gonic/gin"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
16+
"github.com/tinyauthapp/tinyauth/internal/middleware"
17+
"github.com/tinyauthapp/tinyauth/internal/model"
18+
"github.com/tinyauthapp/tinyauth/internal/repository"
19+
"github.com/tinyauthapp/tinyauth/internal/service"
20+
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
21+
)
22+
23+
func TestContextMiddleware(t *testing.T) {
24+
tlog.NewTestLogger().Init()
25+
tempDir := t.TempDir()
26+
27+
authServiceCfg := service.AuthServiceConfig{
28+
LocalUsers: []model.LocalUser{
29+
{
30+
Username: "testuser",
31+
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
32+
},
33+
{
34+
Username: "totpuser",
35+
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
36+
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
37+
},
38+
},
39+
SessionExpiry: 10, // 10 seconds, useful for testing
40+
CookieDomain: "example.com",
41+
LoginTimeout: 10, // 10 seconds, useful for testing
42+
LoginMaxRetries: 3,
43+
SessionCookieName: "tinyauth-session",
44+
}
45+
46+
middlewareCfg := middleware.ContextMiddlewareConfig{
47+
CookieDomain: "example.com",
48+
SessionCookieName: "tinyauth-session",
49+
}
50+
51+
basicAuthHeader := func(username, password string) string {
52+
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
53+
}
54+
55+
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
56+
t.Helper()
57+
_, err := queries.CreateSession(context.Background(), params)
58+
require.NoError(t, err)
59+
}
60+
61+
type runArgs struct {
62+
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
63+
queries *repository.Queries
64+
}
65+
66+
type testCase struct {
67+
description string
68+
run func(t *testing.T, args runArgs)
69+
}
70+
71+
tests := []testCase{
72+
{
73+
description: "Skip path bypasses auth processing",
74+
run: func(t *testing.T, args runArgs) {
75+
req := httptest.NewRequest("GET", "/api/healthz", nil)
76+
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
77+
userCtx, _ := args.do(req)
78+
79+
assert.Nil(t, userCtx)
80+
},
81+
},
82+
{
83+
description: "No credentials yields no context",
84+
run: func(t *testing.T, args runArgs) {
85+
req := httptest.NewRequest("GET", "/api/test", nil)
86+
userCtx, _ := args.do(req)
87+
88+
assert.Nil(t, userCtx)
89+
},
90+
},
91+
{
92+
description: "Valid session cookie sets authenticated local context",
93+
run: func(t *testing.T, args runArgs) {
94+
uuid := "session-valid-local"
95+
seedSession(t, args.queries, repository.CreateSessionParams{
96+
UUID: uuid,
97+
Username: "testuser",
98+
Provider: "local",
99+
Expiry: time.Now().Add(10 * time.Second).Unix(),
100+
CreatedAt: time.Now().Unix(),
101+
})
102+
103+
req := httptest.NewRequest("GET", "/api/test", nil)
104+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
105+
userCtx, _ := args.do(req)
106+
107+
require.NotNil(t, userCtx)
108+
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
109+
assert.Equal(t, "testuser", userCtx.GetUsername())
110+
assert.True(t, userCtx.Authenticated)
111+
require.NotNil(t, userCtx.Local)
112+
assert.False(t, userCtx.Local.TOTPEnabled)
113+
},
114+
},
115+
{
116+
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
117+
run: func(t *testing.T, args runArgs) {
118+
uuid := "session-totp-pending"
119+
seedSession(t, args.queries, repository.CreateSessionParams{
120+
UUID: uuid,
121+
Username: "totpuser",
122+
Provider: "local",
123+
TotpPending: true,
124+
Expiry: time.Now().Add(60 * time.Second).Unix(),
125+
CreatedAt: time.Now().Unix(),
126+
})
127+
128+
req := httptest.NewRequest("GET", "/api/test", nil)
129+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
130+
userCtx, _ := args.do(req)
131+
132+
require.NotNil(t, userCtx)
133+
assert.Equal(t, "totpuser", userCtx.GetUsername())
134+
assert.False(t, userCtx.Authenticated)
135+
require.NotNil(t, userCtx.Local)
136+
assert.True(t, userCtx.Local.TOTPPending)
137+
assert.True(t, userCtx.Local.TOTPEnabled)
138+
},
139+
},
140+
{
141+
description: "Unknown session cookie yields no context",
142+
run: func(t *testing.T, args runArgs) {
143+
req := httptest.NewRequest("GET", "/api/test", nil)
144+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
145+
userCtx, _ := args.do(req)
146+
147+
assert.Nil(t, userCtx)
148+
},
149+
},
150+
{
151+
description: "Session for missing local user yields no context",
152+
run: func(t *testing.T, args runArgs) {
153+
uuid := "session-deleted-user"
154+
seedSession(t, args.queries, repository.CreateSessionParams{
155+
UUID: uuid,
156+
Username: "ghostuser",
157+
Provider: "local",
158+
Expiry: time.Now().Add(10 * time.Second).Unix(),
159+
CreatedAt: time.Now().Unix(),
160+
})
161+
162+
req := httptest.NewRequest("GET", "/api/test", nil)
163+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
164+
userCtx, _ := args.do(req)
165+
166+
assert.Nil(t, userCtx)
167+
},
168+
},
169+
{
170+
description: "Expired session cookie yields no context",
171+
run: func(t *testing.T, args runArgs) {
172+
uuid := "session-expired"
173+
seedSession(t, args.queries, repository.CreateSessionParams{
174+
UUID: uuid,
175+
Username: "testuser",
176+
Provider: "local",
177+
Expiry: time.Now().Add(-1 * time.Second).Unix(),
178+
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
179+
})
180+
181+
req := httptest.NewRequest("GET", "/api/test", nil)
182+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
183+
userCtx, _ := args.do(req)
184+
185+
assert.Nil(t, userCtx)
186+
},
187+
},
188+
{
189+
description: "Valid basic auth sets authenticated local context",
190+
run: func(t *testing.T, args runArgs) {
191+
req := httptest.NewRequest("GET", "/api/test", nil)
192+
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
193+
userCtx, _ := args.do(req)
194+
195+
require.NotNil(t, userCtx)
196+
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
197+
assert.Equal(t, "testuser", userCtx.GetUsername())
198+
assert.True(t, userCtx.Authenticated)
199+
},
200+
},
201+
{
202+
description: "Invalid basic auth password yields no context",
203+
run: func(t *testing.T, args runArgs) {
204+
req := httptest.NewRequest("GET", "/api/test", nil)
205+
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
206+
userCtx, _ := args.do(req)
207+
208+
assert.Nil(t, userCtx)
209+
},
210+
},
211+
{
212+
description: "Basic auth is rejected for users with totp",
213+
run: func(t *testing.T, args runArgs) {
214+
req := httptest.NewRequest("GET", "/api/test", nil)
215+
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
216+
userCtx, _ := args.do(req)
217+
218+
assert.Nil(t, userCtx)
219+
},
220+
},
221+
{
222+
description: "Locked account on basic auth sets lock headers",
223+
run: func(t *testing.T, args runArgs) {
224+
for range 3 {
225+
req := httptest.NewRequest("GET", "/api/test", nil)
226+
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
227+
args.do(req)
228+
}
229+
230+
req := httptest.NewRequest("GET", "/api/test", nil)
231+
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
232+
userCtx, recorder := args.do(req)
233+
234+
assert.Nil(t, userCtx)
235+
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
236+
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
237+
},
238+
},
239+
{
240+
description: "Cookie auth takes precedence over basic auth",
241+
run: func(t *testing.T, args runArgs) {
242+
uuid := "session-precedence"
243+
seedSession(t, args.queries, repository.CreateSessionParams{
244+
UUID: uuid,
245+
Username: "testuser",
246+
Provider: "local",
247+
Expiry: time.Now().Add(10 * time.Second).Unix(),
248+
CreatedAt: time.Now().Unix(),
249+
})
250+
251+
req := httptest.NewRequest("GET", "/api/test", nil)
252+
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
253+
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
254+
userCtx, _ := args.do(req)
255+
256+
require.NotNil(t, userCtx)
257+
assert.Equal(t, "testuser", userCtx.GetUsername())
258+
assert.True(t, userCtx.Authenticated)
259+
},
260+
},
261+
}
262+
263+
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
264+
265+
app := bootstrap.NewBootstrapApp(model.Config{})
266+
267+
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
268+
require.NoError(t, err)
269+
270+
queries := repository.New(db)
271+
272+
ldap := service.NewLdapService(service.LdapServiceConfig{})
273+
err = ldap.Init()
274+
require.NoError(t, err)
275+
276+
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
277+
err = broker.Init()
278+
require.NoError(t, err)
279+
280+
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
281+
err = authService.Init()
282+
require.NoError(t, err)
283+
284+
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
285+
err = contextMiddleware.Init()
286+
require.NoError(t, err)
287+
288+
for _, test := range tests {
289+
authService.ClearRateLimitsTestingOnly()
290+
t.Run(test.description, func(t *testing.T) {
291+
gin.SetMode(gin.TestMode)
292+
293+
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
294+
var captured *model.UserContext
295+
router := gin.New()
296+
router.Use(contextMiddleware.Middleware())
297+
handler := func(c *gin.Context) {
298+
if val, exists := c.Get("context"); exists {
299+
captured, _ = val.(*model.UserContext)
300+
}
301+
}
302+
router.GET("/api/test", handler)
303+
router.GET("/api/healthz", handler)
304+
305+
recorder := httptest.NewRecorder()
306+
router.ServeHTTP(recorder, req)
307+
return captured, recorder
308+
}
309+
310+
test.run(t, runArgs{do: do, queries: queries})
311+
})
312+
}
313+
314+
t.Cleanup(func() {
315+
err = db.Close()
316+
require.NoError(t, err)
317+
})
318+
}

0 commit comments

Comments
 (0)