Skip to content

Commit 44bda75

Browse files
committed
test: oidc tests
It's not easy to test this automatically without a real oidc server.
1 parent 83895b5 commit 44bda75

1 file changed

Lines changed: 234 additions & 0 deletions

File tree

api/oidc_test.go

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
package api
2+
3+
import (
4+
"net/http/httptest"
5+
"strings"
6+
"testing"
7+
"time"
8+
9+
"github.com/gin-gonic/gin"
10+
"github.com/gotify/server/v2/mode"
11+
"github.com/gotify/server/v2/test"
12+
"github.com/gotify/server/v2/test/testdb"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/suite"
15+
"github.com/zitadel/oidc/v3/pkg/oidc"
16+
)
17+
18+
var origGenClientToken = generateClientToken
19+
20+
func TestOIDCSuite(t *testing.T) {
21+
suite.Run(t, new(OIDCSuite))
22+
}
23+
24+
type OIDCSuite struct {
25+
suite.Suite
26+
db *testdb.Database
27+
a *OIDCAPI
28+
ctx *gin.Context
29+
recorder *httptest.ResponseRecorder
30+
notified bool
31+
}
32+
33+
func (s *OIDCSuite) BeforeTest(suiteName, testName string) {
34+
mode.Set(mode.TestDev)
35+
s.recorder = httptest.NewRecorder()
36+
s.ctx, _ = gin.CreateTestContext(s.recorder)
37+
s.db = testdb.NewDB(s.T())
38+
s.notified = false
39+
notifier := new(UserChangeNotifier)
40+
notifier.OnUserAdded(func(uint) error {
41+
s.notified = true
42+
return nil
43+
})
44+
s.a = &OIDCAPI{
45+
DB: s.db.GormDatabase,
46+
UserChangeNotifier: notifier,
47+
UsernameClaim: "preferred_username",
48+
AutoRegister: true,
49+
pendingSessions: make(map[string]*pendingOIDCSession),
50+
}
51+
}
52+
53+
func (s *OIDCSuite) AfterTest(suiteName, testName string) {
54+
s.db.Close()
55+
}
56+
57+
func (s *OIDCSuite) Test_GenerateState_Unique() {
58+
s1, _ := s.a.generateState("app")
59+
s2, _ := s.a.generateState("app")
60+
assert.NotEqual(s.T(), s1, s2)
61+
}
62+
63+
func (s *OIDCSuite) Test_StoreAndPopPendingSession() {
64+
session := &pendingOIDCSession{RedirectURI: "app://cb", ClientName: "phone", CreatedAt: time.Now()}
65+
s.a.storePendingSession("state1", session)
66+
67+
got, ok := s.a.popPendingSession("state1")
68+
assert.True(s.T(), ok)
69+
assert.Equal(s.T(), "app://cb", got.RedirectURI)
70+
71+
// second pop returns nothing (consumed)
72+
_, ok = s.a.popPendingSession("state1")
73+
assert.False(s.T(), ok)
74+
}
75+
76+
func (s *OIDCSuite) Test_PopPendingSession_UnknownState() {
77+
_, ok := s.a.popPendingSession("doesnotexist")
78+
assert.False(s.T(), ok)
79+
}
80+
81+
func (s *OIDCSuite) Test_PopPendingSession_Expired() {
82+
session := &pendingOIDCSession{RedirectURI: "x", ClientName: "x", CreatedAt: time.Now().Add(-11 * time.Minute)}
83+
s.a.storePendingSession("old", session)
84+
85+
_, ok := s.a.popPendingSession("old")
86+
assert.False(s.T(), ok)
87+
}
88+
89+
func (s *OIDCSuite) Test_StorePendingSession_PrunesExpired() {
90+
expired := &pendingOIDCSession{CreatedAt: time.Now().Add(-11 * time.Minute)}
91+
s.a.pendingSessions["stale"] = expired
92+
93+
fresh := &pendingOIDCSession{CreatedAt: time.Now()}
94+
s.a.storePendingSession("fresh", fresh)
95+
96+
s.a.pendingSessionsMu.Lock()
97+
_, staleExists := s.a.pendingSessions["stale"]
98+
_, freshExists := s.a.pendingSessions["fresh"]
99+
s.a.pendingSessionsMu.Unlock()
100+
101+
assert.False(s.T(), staleExists)
102+
assert.True(s.T(), freshExists)
103+
}
104+
105+
func (s *OIDCSuite) Test_ResolveUser_ExistingUser() {
106+
s.db.NewUserWithName(1, "alice")
107+
108+
info := &oidc.UserInfo{Claims: map[string]any{"preferred_username": "alice"}}
109+
user, status, err := s.a.resolveUser(info)
110+
111+
assert.NoError(s.T(), err)
112+
assert.Equal(s.T(), 0, status)
113+
assert.Equal(s.T(), "alice", user.Name)
114+
assert.Equal(s.T(), uint(1), user.ID)
115+
assert.False(s.T(), s.notified)
116+
}
117+
118+
func (s *OIDCSuite) Test_ResolveUser_AutoRegister() {
119+
info := &oidc.UserInfo{Claims: map[string]any{"preferred_username": "newuser"}}
120+
user, status, err := s.a.resolveUser(info)
121+
122+
assert.NoError(s.T(), err)
123+
assert.Equal(s.T(), 0, status)
124+
assert.Equal(s.T(), "newuser", user.Name)
125+
assert.False(s.T(), user.Admin)
126+
assert.True(s.T(), s.notified)
127+
128+
// verify persisted
129+
dbUser, err := s.db.GetUserByName("newuser")
130+
assert.NoError(s.T(), err)
131+
assert.NotNil(s.T(), dbUser)
132+
}
133+
134+
func (s *OIDCSuite) Test_ResolveUser_AutoRegisterDisabled() {
135+
s.a.AutoRegister = false
136+
info := &oidc.UserInfo{Claims: map[string]any{"preferred_username": "newuser"}}
137+
138+
_, status, err := s.a.resolveUser(info)
139+
140+
assert.Error(s.T(), err)
141+
assert.Equal(s.T(), 403, status)
142+
s.db.AssertUsernameNotExist("newuser")
143+
}
144+
145+
func (s *OIDCSuite) Test_ResolveUser_MissingClaim() {
146+
info := &oidc.UserInfo{Claims: map[string]any{}}
147+
148+
_, status, err := s.a.resolveUser(info)
149+
150+
assert.Error(s.T(), err)
151+
assert.Equal(s.T(), 500, status)
152+
assert.Contains(s.T(), err.Error(), "preferred_username")
153+
}
154+
155+
func (s *OIDCSuite) Test_ResolveUser_EmptyClaim() {
156+
info := &oidc.UserInfo{Claims: map[string]any{"preferred_username": ""}}
157+
158+
_, status, err := s.a.resolveUser(info)
159+
160+
assert.Error(s.T(), err)
161+
assert.Equal(s.T(), 500, status)
162+
}
163+
164+
func (s *OIDCSuite) Test_ResolveUser_NilClaim() {
165+
info := &oidc.UserInfo{Claims: map[string]any{"preferred_username": nil}}
166+
167+
_, status, err := s.a.resolveUser(info)
168+
169+
assert.Error(s.T(), err)
170+
assert.Equal(s.T(), 500, status)
171+
}
172+
173+
func (s *OIDCSuite) Test_ResolveUser_CustomClaim() {
174+
s.a.UsernameClaim = "email"
175+
s.db.NewUserWithName(1, "alice@example.com")
176+
177+
info := &oidc.UserInfo{Claims: map[string]any{"email": "alice@example.com"}}
178+
user, _, err := s.a.resolveUser(info)
179+
180+
assert.NoError(s.T(), err)
181+
assert.Equal(s.T(), "alice@example.com", user.Name)
182+
}
183+
184+
// --- createClient ---
185+
186+
func (s *OIDCSuite) Test_CreateClient() {
187+
generateClientToken = test.Tokens("Ctesttoken00001")
188+
defer func() { generateClientToken = origGenClientToken }()
189+
190+
s.db.NewUser(1)
191+
client, err := s.a.createClient("MyPhone", 1)
192+
193+
assert.NoError(s.T(), err)
194+
assert.Equal(s.T(), "MyPhone", client.Name)
195+
assert.Equal(s.T(), "Ctesttoken00001", client.Token)
196+
assert.Equal(s.T(), uint(1), client.UserID)
197+
198+
dbClient, err := s.db.GetClientByToken("Ctesttoken00001")
199+
assert.NoError(s.T(), err)
200+
assert.NotNil(s.T(), dbClient)
201+
}
202+
203+
// --- ExternalAuthorizeHandler ---
204+
205+
func (s *OIDCSuite) Test_ExternalAuthorizeHandler_MissingFields() {
206+
s.ctx.Request = httptest.NewRequest("POST", "/auth/oidc/external/authorize", strings.NewReader(`{}`))
207+
s.ctx.Request.Header.Set("Content-Type", "application/json")
208+
209+
s.a.ExternalAuthorizeHandler(s.ctx)
210+
211+
assert.Equal(s.T(), 400, s.recorder.Code)
212+
}
213+
214+
// --- ExternalTokenHandler ---
215+
216+
func (s *OIDCSuite) Test_ExternalTokenHandler_InvalidJSON() {
217+
s.ctx.Request = httptest.NewRequest("POST", "/auth/oidc/external/token", strings.NewReader(`{bad`))
218+
s.ctx.Request.Header.Set("Content-Type", "application/json")
219+
220+
s.a.ExternalTokenHandler(s.ctx)
221+
222+
assert.Equal(s.T(), 400, s.recorder.Code)
223+
}
224+
225+
func (s *OIDCSuite) Test_ExternalTokenHandler_UnknownState() {
226+
s.ctx.Request = httptest.NewRequest("POST", "/auth/oidc/external/token", strings.NewReader(
227+
`{"code":"abc","state":"bogus:1234","code_verifier":"v"}`,
228+
))
229+
s.ctx.Request.Header.Set("Content-Type", "application/json")
230+
231+
s.a.ExternalTokenHandler(s.ctx)
232+
233+
assert.Equal(s.T(), 400, s.recorder.Code)
234+
}

0 commit comments

Comments
 (0)