Skip to content

Commit 6f8668e

Browse files
authored
fix: enforce header nav access control for public modules (#4889)
1 parent 8a10ded commit 6f8668e

17 files changed

Lines changed: 687 additions & 149 deletions

File tree

controller/rankings.go

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,11 @@ package controller
33
import (
44
"net/http"
55

6-
"github.com/QuantumNous/new-api/common"
76
"github.com/QuantumNous/new-api/service"
87
"github.com/gin-gonic/gin"
98
)
109

11-
func isRankingsEnabled() bool {
12-
common.OptionMapRWMutex.RLock()
13-
raw := common.OptionMap["HeaderNavModules"]
14-
common.OptionMapRWMutex.RUnlock()
15-
16-
if raw == "" {
17-
return true
18-
}
19-
20-
var parsed map[string]interface{}
21-
if err := common.Unmarshal([]byte(raw), &parsed); err != nil {
22-
return true
23-
}
24-
rankings, ok := parsed["rankings"]
25-
if !ok {
26-
return true
27-
}
28-
switch v := rankings.(type) {
29-
case bool:
30-
return v
31-
case map[string]interface{}:
32-
if enabled, ok := v["enabled"]; ok {
33-
if b, ok := enabled.(bool); ok {
34-
return b
35-
}
36-
}
37-
return true
38-
}
39-
return true
40-
}
41-
4210
func GetRankings(c *gin.Context) {
43-
if !isRankingsEnabled() {
44-
c.JSON(http.StatusForbidden, gin.H{
45-
"success": false,
46-
"message": "rankings is disabled",
47-
})
48-
return
49-
}
50-
5111
result, err := service.GetRankingsSnapshot(c.DefaultQuery("period", "week"))
5212
if err != nil {
5313
c.JSON(http.StatusBadRequest, gin.H{

middleware/header_nav.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package middleware
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/QuantumNous/new-api/common"
9+
"github.com/gin-gonic/gin"
10+
)
11+
12+
type headerNavAccess struct {
13+
Enabled bool
14+
RequireAuth bool
15+
}
16+
17+
func getHeaderNavAccess(module string) headerNavAccess {
18+
fallback := headerNavAccess{
19+
Enabled: true,
20+
RequireAuth: false,
21+
}
22+
23+
common.OptionMapRWMutex.RLock()
24+
raw := common.OptionMap["HeaderNavModules"]
25+
common.OptionMapRWMutex.RUnlock()
26+
27+
if strings.TrimSpace(raw) == "" {
28+
return fallback
29+
}
30+
31+
var parsed map[string]any
32+
if err := common.Unmarshal([]byte(raw), &parsed); err != nil {
33+
return fallback
34+
}
35+
36+
return parseHeaderNavAccess(parsed[module], fallback)
37+
}
38+
39+
func parseHeaderNavAccess(raw any, fallback headerNavAccess) headerNavAccess {
40+
switch value := raw.(type) {
41+
case bool:
42+
return headerNavAccess{
43+
Enabled: value,
44+
RequireAuth: fallback.RequireAuth,
45+
}
46+
case string:
47+
return headerNavAccess{
48+
Enabled: parseHeaderNavBool(value, fallback.Enabled),
49+
RequireAuth: fallback.RequireAuth,
50+
}
51+
case float64:
52+
return headerNavAccess{
53+
Enabled: parseHeaderNavBool(value, fallback.Enabled),
54+
RequireAuth: fallback.RequireAuth,
55+
}
56+
case map[string]any:
57+
access := fallback
58+
if enabled, ok := value["enabled"]; ok {
59+
access.Enabled = parseHeaderNavBool(enabled, fallback.Enabled)
60+
}
61+
if requireAuth, ok := value["requireAuth"]; ok {
62+
access.RequireAuth = parseHeaderNavBool(requireAuth, fallback.RequireAuth)
63+
}
64+
return access
65+
default:
66+
return fallback
67+
}
68+
}
69+
70+
func parseHeaderNavBool(value any, fallback bool) bool {
71+
switch v := value.(type) {
72+
case bool:
73+
return v
74+
case string:
75+
switch strings.ToLower(strings.TrimSpace(v)) {
76+
case "true", "1":
77+
return true
78+
case "false", "0":
79+
return false
80+
default:
81+
return fallback
82+
}
83+
case float64:
84+
if v == 1 {
85+
return true
86+
}
87+
if v == 0 {
88+
return false
89+
}
90+
return fallback
91+
case int:
92+
if v == 1 {
93+
return true
94+
}
95+
if v == 0 {
96+
return false
97+
}
98+
return fallback
99+
default:
100+
return fallback
101+
}
102+
}
103+
104+
func HeaderNavModuleAuth(module string) gin.HandlerFunc {
105+
return func(c *gin.Context) {
106+
access := getHeaderNavAccess(module)
107+
if !access.Enabled {
108+
c.JSON(http.StatusForbidden, gin.H{
109+
"success": false,
110+
"message": fmt.Sprintf("%s is disabled", module),
111+
})
112+
c.Abort()
113+
return
114+
}
115+
116+
if access.RequireAuth {
117+
UserAuth()(c)
118+
return
119+
}
120+
121+
TryUserAuth()(c)
122+
}
123+
}
124+
125+
func HeaderNavModulePublicOrUserAuth(module string) gin.HandlerFunc {
126+
return func(c *gin.Context) {
127+
access := getHeaderNavAccess(module)
128+
if !access.Enabled || access.RequireAuth {
129+
UserAuth()(c)
130+
return
131+
}
132+
133+
TryUserAuth()(c)
134+
}
135+
}

middleware/header_nav_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/QuantumNous/new-api/common"
9+
"github.com/gin-contrib/sessions"
10+
"github.com/gin-contrib/sessions/cookie"
11+
"github.com/gin-gonic/gin"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func withHeaderNavModules(t *testing.T, raw string) {
16+
t.Helper()
17+
18+
common.OptionMapRWMutex.Lock()
19+
if common.OptionMap == nil {
20+
common.OptionMap = map[string]string{}
21+
}
22+
previous, hadPrevious := common.OptionMap["HeaderNavModules"]
23+
common.OptionMap["HeaderNavModules"] = raw
24+
common.OptionMapRWMutex.Unlock()
25+
26+
t.Cleanup(func() {
27+
common.OptionMapRWMutex.Lock()
28+
defer common.OptionMapRWMutex.Unlock()
29+
if hadPrevious {
30+
common.OptionMap["HeaderNavModules"] = previous
31+
return
32+
}
33+
delete(common.OptionMap, "HeaderNavModules")
34+
})
35+
}
36+
37+
func performHeaderNavRequest(t *testing.T, handler gin.HandlerFunc, authenticated bool) *httptest.ResponseRecorder {
38+
t.Helper()
39+
40+
gin.SetMode(gin.TestMode)
41+
router := gin.New()
42+
router.Use(sessions.Sessions("session", cookie.NewStore([]byte("header-nav-test"))))
43+
router.GET("/login", func(c *gin.Context) {
44+
session := sessions.Default(c)
45+
session.Set("username", "tester")
46+
session.Set("role", common.RoleCommonUser)
47+
session.Set("id", 1)
48+
session.Set("status", common.UserStatusEnabled)
49+
session.Set("group", "default")
50+
if err := session.Save(); err != nil {
51+
c.JSON(http.StatusInternalServerError, gin.H{"success": false})
52+
return
53+
}
54+
c.Status(http.StatusNoContent)
55+
})
56+
router.GET("/api/test", handler, func(c *gin.Context) {
57+
c.JSON(http.StatusOK, gin.H{"success": true})
58+
})
59+
60+
var cookies []*http.Cookie
61+
if authenticated {
62+
loginRecorder := httptest.NewRecorder()
63+
loginRequest := httptest.NewRequest(http.MethodGet, "/login", nil)
64+
router.ServeHTTP(loginRecorder, loginRequest)
65+
require.Equal(t, http.StatusNoContent, loginRecorder.Code)
66+
cookies = loginRecorder.Result().Cookies()
67+
}
68+
69+
recorder := httptest.NewRecorder()
70+
request := httptest.NewRequest(http.MethodGet, "/api/test", nil)
71+
if authenticated {
72+
request.Header.Set("New-Api-User", "1")
73+
for _, cookie := range cookies {
74+
request.AddCookie(cookie)
75+
}
76+
}
77+
router.ServeHTTP(recorder, request)
78+
return recorder
79+
}
80+
81+
func TestHeaderNavModuleAuthAllowsDefaultPublicAccess(t *testing.T) {
82+
withHeaderNavModules(t, "")
83+
84+
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
85+
86+
require.Equal(t, http.StatusOK, recorder.Code)
87+
}
88+
89+
func TestHeaderNavModuleAuthRejectsDisabledPricing(t *testing.T) {
90+
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
91+
withHeaderNavModules(t, raw)
92+
93+
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
94+
95+
require.Equal(t, http.StatusForbidden, recorder.Code)
96+
}
97+
98+
func TestHeaderNavModuleAuthRequiresLoginForPricing(t *testing.T) {
99+
raw := `{"pricing":{"enabled":true,"requireAuth":true}}`
100+
withHeaderNavModules(t, raw)
101+
102+
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false)
103+
104+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
105+
}
106+
107+
func TestHeaderNavModuleAuthRequiresLoginForRankings(t *testing.T) {
108+
raw := `{"rankings":{"enabled":true,"requireAuth":true}}`
109+
withHeaderNavModules(t, raw)
110+
111+
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false)
112+
113+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
114+
}
115+
116+
func TestHeaderNavModuleAuthRejectsLegacyDisabledModule(t *testing.T) {
117+
raw := `{"rankings":false}`
118+
withHeaderNavModules(t, raw)
119+
120+
recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false)
121+
122+
require.Equal(t, http.StatusForbidden, recorder.Code)
123+
}
124+
125+
func TestHeaderNavModulePublicOrUserAuthAllowsDefaultPublicAccess(t *testing.T) {
126+
withHeaderNavModules(t, "")
127+
128+
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
129+
130+
require.Equal(t, http.StatusOK, recorder.Code)
131+
}
132+
133+
func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenDisabled(t *testing.T) {
134+
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
135+
withHeaderNavModules(t, raw)
136+
137+
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
138+
139+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
140+
}
141+
142+
func TestHeaderNavModulePublicOrUserAuthAllowsLoggedInWhenDisabled(t *testing.T) {
143+
raw := `{"pricing":{"enabled":false,"requireAuth":false}}`
144+
withHeaderNavModules(t, raw)
145+
146+
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), true)
147+
148+
require.Equal(t, http.StatusOK, recorder.Code)
149+
}
150+
151+
func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenRequireAuth(t *testing.T) {
152+
raw := `{"pricing":{"enabled":true,"requireAuth":true}}`
153+
withHeaderNavModules(t, raw)
154+
155+
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
156+
157+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
158+
}
159+
160+
func TestHeaderNavModulePublicOrUserAuthRequiresLoginForLegacyDisabledModule(t *testing.T) {
161+
raw := `{"pricing":false}`
162+
withHeaderNavModules(t, raw)
163+
164+
recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false)
165+
166+
require.Equal(t, http.StatusUnauthorized, recorder.Code)
167+
}

router/api-router.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ func SetApiRouter(router *gin.Engine) {
3030
apiRouter.GET("/about", controller.GetAbout)
3131
//apiRouter.GET("/midjourney", controller.GetMidjourney)
3232
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
33-
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
33+
apiRouter.GET("/pricing", middleware.HeaderNavModuleAuth("pricing"), controller.GetPricing)
3434
perfMetricsRoute := apiRouter.Group("/perf-metrics")
35-
perfMetricsRoute.Use(middleware.TryUserAuth())
35+
perfMetricsRoute.Use(middleware.HeaderNavModulePublicOrUserAuth("pricing"))
3636
{
3737
perfMetricsRoute.GET("/summary", controller.GetPerfMetricsSummary)
3838
perfMetricsRoute.GET("", controller.GetPerfMetrics)
3939
}
40-
apiRouter.GET("/rankings", controller.GetRankings)
40+
apiRouter.GET("/rankings", middleware.HeaderNavModuleAuth("rankings"), controller.GetRankings)
4141
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
4242
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
4343
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)

0 commit comments

Comments
 (0)