Skip to content

Commit 2674f72

Browse files
committed
fix: use decay map
1 parent 3d12ddc commit 2674f72

4 files changed

Lines changed: 59 additions & 110 deletions

File tree

api/oidc.go

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ import (
88
"fmt"
99
"log"
1010
"net/http"
11-
"sync"
1211
"time"
1312

1413
"github.com/gin-gonic/gin"
1514
"github.com/gotify/server/v2/auth"
1615
"github.com/gotify/server/v2/config"
1716
"github.com/gotify/server/v2/database"
17+
"github.com/gotify/server/v2/decaymap"
1818
"github.com/gotify/server/v2/model"
1919
"github.com/zitadel/oidc/v3/pkg/client/rp"
2020
httphelper "github.com/zitadel/oidc/v3/pkg/http"
@@ -60,7 +60,7 @@ func NewOIDC(conf *config.Configuration, db *database.GormDatabase, userChangeNo
6060
PasswordStrength: conf.PassStrength,
6161
SecureCookie: conf.Server.SecureCookie,
6262
AutoRegister: conf.OIDC.AutoRegister,
63-
pendingSessions: make(map[string]*pendingOIDCSession),
63+
pendingSessions: decaymap.NewDecayMap[string, *pendingOIDCSession](time.Now(), pendingSessionMaxAge),
6464
}
6565
}
6666

@@ -81,32 +81,7 @@ type OIDCAPI struct {
8181
PasswordStrength int
8282
SecureCookie bool
8383
AutoRegister bool
84-
pendingSessions map[string]*pendingOIDCSession
85-
pendingSessionsMu sync.Mutex
86-
}
87-
88-
func (a *OIDCAPI) storePendingSession(state string, session *pendingOIDCSession) {
89-
a.pendingSessionsMu.Lock()
90-
defer a.pendingSessionsMu.Unlock()
91-
for s, sess := range a.pendingSessions {
92-
if time.Since(sess.CreatedAt) > pendingSessionMaxAge {
93-
delete(a.pendingSessions, s)
94-
}
95-
}
96-
a.pendingSessions[state] = session
97-
}
98-
99-
func (a *OIDCAPI) popPendingSession(state string) (*pendingOIDCSession, bool) {
100-
a.pendingSessionsMu.Lock()
101-
session, ok := a.pendingSessions[state]
102-
if ok {
103-
delete(a.pendingSessions, state)
104-
}
105-
a.pendingSessionsMu.Unlock()
106-
if !ok || time.Since(session.CreatedAt) > pendingSessionMaxAge {
107-
return nil, false
108-
}
109-
return session, true
84+
pendingSessions *decaymap.DecayMap[string, *pendingOIDCSession]
11085
}
11186

11287
// swagger:operation GET /auth/oidc/login oidc oidcLogin
@@ -142,7 +117,7 @@ func (a *OIDCAPI) LoginHandler() gin.HandlerFunc {
142117
http.Error(w, fmt.Sprintf("failed to generate state: %v", err), http.StatusInternalServerError)
143118
return
144119
}
145-
a.storePendingSession(state, &pendingOIDCSession{ClientName: clientName, CreatedAt: time.Now()})
120+
a.pendingSessions.Set(time.Now(), state, &pendingOIDCSession{ClientName: clientName, CreatedAt: time.Now()})
146121
rp.AuthURLHandler(func() string { return state }, a.Provider)(w, r)
147122
})
148123
}
@@ -237,7 +212,7 @@ func (a *OIDCAPI) ExternalAuthorizeHandler(ctx *gin.Context) {
237212
ctx.AbortWithError(http.StatusInternalServerError, err)
238213
return
239214
}
240-
a.storePendingSession(state, &pendingOIDCSession{
215+
a.pendingSessions.Set(time.Now(), state, &pendingOIDCSession{
241216
RedirectURI: req.RedirectURI, ClientName: req.Name, CreatedAt: time.Now(),
242217
})
243218
authOpts := []rp.AuthURLOpt{
@@ -364,3 +339,11 @@ func (a *OIDCAPI) createClient(name string, userID uint) (*model.Client, error)
364339
}
365340
return client, a.DB.CreateClient(client)
366341
}
342+
343+
func (a *OIDCAPI) popPendingSession(key string) (*pendingOIDCSession, bool) {
344+
session, ok := a.pendingSessions.Pop(key)
345+
if ok && time.Since(session.CreatedAt) < pendingSessionMaxAge {
346+
return session, true
347+
}
348+
return nil, false
349+
}

api/oidc_test.go

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"time"
88

99
"github.com/gin-gonic/gin"
10+
"github.com/gotify/server/v2/decaymap"
1011
"github.com/gotify/server/v2/mode"
1112
"github.com/gotify/server/v2/test"
1213
"github.com/gotify/server/v2/test/testdb"
@@ -46,7 +47,7 @@ func (s *OIDCSuite) BeforeTest(suiteName, testName string) {
4647
UserChangeNotifier: notifier,
4748
UsernameClaim: "preferred_username",
4849
AutoRegister: true,
49-
pendingSessions: make(map[string]*pendingOIDCSession),
50+
pendingSessions: decaymap.NewDecayMap[string, *pendingOIDCSession](time.Now(), pendingSessionMaxAge),
5051
}
5152
}
5253

@@ -60,48 +61,6 @@ func (s *OIDCSuite) Test_GenerateState_Unique() {
6061
assert.NotEqual(s.T(), s1, s2)
6162
}
6263

63-
func (s *OIDCSuite) Test_StoreAndPopPendingSession() {
64-
session := &pendingOIDCSession{RedirectURI: "app://cb", ClientName: "phone", CreatedAt: time.Now()}
65-
s.a.storePendingSession("state1", session)
66-
67-
got, ok := s.a.popPendingSession("state1")
68-
assert.True(s.T(), ok)
69-
assert.Equal(s.T(), "app://cb", got.RedirectURI)
70-
71-
// second pop returns nothing (consumed)
72-
_, ok = s.a.popPendingSession("state1")
73-
assert.False(s.T(), ok)
74-
}
75-
76-
func (s *OIDCSuite) Test_PopPendingSession_UnknownState() {
77-
_, ok := s.a.popPendingSession("doesnotexist")
78-
assert.False(s.T(), ok)
79-
}
80-
81-
func (s *OIDCSuite) Test_PopPendingSession_Expired() {
82-
session := &pendingOIDCSession{RedirectURI: "x", ClientName: "x", CreatedAt: time.Now().Add(-11 * time.Minute)}
83-
s.a.storePendingSession("old", session)
84-
85-
_, ok := s.a.popPendingSession("old")
86-
assert.False(s.T(), ok)
87-
}
88-
89-
func (s *OIDCSuite) Test_StorePendingSession_PrunesExpired() {
90-
expired := &pendingOIDCSession{CreatedAt: time.Now().Add(-11 * time.Minute)}
91-
s.a.pendingSessions["stale"] = expired
92-
93-
fresh := &pendingOIDCSession{CreatedAt: time.Now()}
94-
s.a.storePendingSession("fresh", fresh)
95-
96-
s.a.pendingSessionsMu.Lock()
97-
_, staleExists := s.a.pendingSessions["stale"]
98-
_, freshExists := s.a.pendingSessions["fresh"]
99-
s.a.pendingSessionsMu.Unlock()
100-
101-
assert.False(s.T(), staleExists)
102-
assert.True(s.T(), freshExists)
103-
}
104-
10564
func (s *OIDCSuite) Test_ResolveUser_ExistingUser() {
10665
s.db.NewUserWithName(1, "alice")
10766

@@ -224,7 +183,7 @@ func (s *OIDCSuite) Test_ExternalTokenHandler_InvalidJSON() {
224183

225184
func (s *OIDCSuite) Test_ExternalTokenHandler_UnknownState() {
226185
s.ctx.Request = httptest.NewRequest("POST", "/auth/oidc/external/token", strings.NewReader(
227-
`{"code":"abc","state":"bogus:1234","code_verifier":"v"}`,
186+
`{"code":"abc","state":"bogus","code_verifier":"v"}`,
228187
))
229188
s.ctx.Request.Header.Set("Content-Type", "application/json")
230189

decaymap/decaymap.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77

88
// DecayMap is a coarse-grained paired map that bulk-frees outdated entries.
99
type DecayMap[K comparable, V any] struct {
10-
mu sync.RWMutex
10+
mu sync.Mutex
1111
maps [2]map[K]V
1212
epoch time.Time
1313
lastInsert time.Time
@@ -24,16 +24,21 @@ func NewDecayMap[K comparable, V any](epoch time.Time, period time.Duration) *De
2424
}
2525
}
2626

27-
// Attempts to retrieve a value from the store, does not guarantee the value is unexpired.
28-
func (m *DecayMap[K, V]) Get(key K) (res V, ok bool) {
29-
m.mu.RLock()
30-
defer m.mu.RUnlock()
31-
res, ok = m.maps[0][key]
27+
// Attempts to retrieve and remove a value from the store, does not guarantee the value is unexpired.
28+
func (m *DecayMap[K, V]) Pop(key K) (V, bool) {
29+
m.mu.Lock()
30+
defer m.mu.Unlock()
31+
res, ok := m.maps[0][key]
3232
if ok {
33-
return
33+
delete(m.maps[0], key)
34+
return res, ok
3435
}
3536
res, ok = m.maps[1][key]
36-
return
37+
if ok {
38+
delete(m.maps[1], key)
39+
return res, ok
40+
}
41+
return res, ok
3742
}
3843

3944
// Sets a value in the store, overwriting the existing value if exists.

decaymap/decaymap_test.go

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
1-
21
package decaymap
32

43
import (
5-
"fmt"
64
"testing"
75
"time"
86

97
"github.com/stretchr/testify/assert"
108
)
119

12-
func TestDecayMap2(t *testing.T) {
10+
func TestDecayMap(t *testing.T) {
1311
epoch := time.Now()
14-
dm := NewDecayMap[string, string](epoch, 10*time.Millisecond)
15-
16-
now := epoch
17-
for ts := 0; ts < 100; ts++ {
18-
dm.Set(now, fmt.Sprintf("key%d", ts), fmt.Sprintf("value%d", ts))
19-
for backts := 0; backts < ts; backts++ {
20-
res, ok := dm.Get(fmt.Sprintf("key%d", backts))
21-
if ts-backts <= 10 {
22-
assert.True(t, ok)
23-
assert.Equal(t, fmt.Sprintf("value%d", backts), res)
24-
} else if ts-backts >= 20 {
25-
assert.False(t, ok)
26-
assert.Equal(t, "", res)
27-
}
28-
}
29-
now = now.Add(1 * time.Millisecond)
30-
}
31-
32-
now = now.Add(20 * time.Millisecond)
33-
dm.Set(now, "dummy", "dummy") // rachet internal state
34-
for ts := 0; ts < 100; ts++ {
35-
res, ok := dm.Get(fmt.Sprintf("key%d", ts))
36-
assert.False(t, ok)
37-
assert.Equal(t, "", res)
38-
}
12+
dm := NewDecayMap[string, string](epoch, 10*time.Second)
13+
14+
dm.Set(epoch.Add(1*time.Second), "11", "value11")
15+
dm.Set(epoch.Add(2*time.Second), "12", "value12")
16+
dm.Set(epoch.Add(3*time.Second), "13", "value13")
17+
18+
value, ok := dm.Pop("11")
19+
assert.True(t, ok)
20+
assert.Equal(t, "value11", value)
21+
22+
_, ok = dm.Pop("11")
23+
assert.False(t, ok)
24+
25+
dm.Set(epoch.Add(11*time.Second), "21", "value21")
26+
dm.Set(epoch.Add(12*time.Second), "22", "value22")
27+
dm.Set(epoch.Add(13*time.Second), "23", "value23")
28+
29+
value, ok = dm.Pop("21")
30+
assert.True(t, ok)
31+
assert.Equal(t, "value21", value)
32+
33+
value, ok = dm.Pop("12")
34+
assert.True(t, ok)
35+
assert.Equal(t, "value12", value)
36+
37+
dm.Set(epoch.Add(21*time.Second), "31", "value31")
38+
39+
_, ok = dm.Pop("13")
40+
assert.False(t, ok)
3941
}

0 commit comments

Comments
 (0)