Skip to content

Commit 50883dd

Browse files
committed
Merge remote-tracking branch 'upstream/main' into feat/picod-two-stage-init
# Conflicts: # pkg/picod/auth.go # pkg/picod/auth_test.go # pkg/picod/server_test.go
2 parents d22a581 + 40ec9ae commit 50883dd

7 files changed

Lines changed: 75 additions & 47 deletions

File tree

pkg/picod/auth.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ var (
3939
)
4040

4141
const (
42-
// MaxBodySize limits request body size to prevent memory exhaustion
43-
MaxBodySize = 32 << 20 // 32 MB
44-
4542
// BootstrapPublicKeyEnvVar is the environment variable name for the bootstrap public key
4643
BootstrapPublicKeyEnvVar = "PICOD_BOOTSTRAP_PUBLIC_KEY"
4744
)
@@ -216,9 +213,6 @@ func (am *AuthManager) AuthMiddleware() gin.HandlerFunc {
216213
return
217214
}
218215

219-
// Enforce maximum body size to prevent memory exhaustion
220-
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, MaxBodySize)
221-
222216
c.Next()
223217
}
224218
}

pkg/picod/auth_test.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -305,34 +305,3 @@ func TestAuthMiddleware_TokenValidation(t *testing.T) {
305305
})
306306
}
307307
}
308-
309-
func TestAuthMiddleware_MaxBodySize(t *testing.T) {
310-
privateKey, pubKeyPEM, err := generateTestRSAKeyPair()
311-
require.NoError(t, err)
312-
313-
os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM)
314-
defer os.Unsetenv(BootstrapPublicKeyEnvVar)
315-
316-
manager := NewAuthManager()
317-
err = manager.LoadBootstrapPublicKey()
318-
require.NoError(t, err)
319-
err = manager.SetSessionPublicKey(pubKeyPEM)
320-
require.NoError(t, err)
321-
322-
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
323-
"exp": time.Now().Add(time.Hour).Unix(),
324-
"iat": time.Now().Unix(),
325-
})
326-
tokenString, err := token.SignedString(privateKey)
327-
require.NoError(t, err)
328-
329-
w := httptest.NewRecorder()
330-
c, _ := gin.CreateTestContext(w)
331-
c.Request, _ = http.NewRequest("POST", "/api/execute", nil)
332-
c.Request.Header.Set("Authorization", "Bearer "+tokenString)
333-
334-
handler := manager.AuthMiddleware()
335-
handler(c)
336-
337-
assert.NotNil(t, c.Request.Body)
338-
}

pkg/picod/server.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ import (
2727
"k8s.io/klog/v2"
2828
)
2929

30+
const (
31+
// MaxBodySize limits request body size to prevent memory exhaustion
32+
MaxBodySize = 32 << 20 // 32 MB
33+
)
34+
3035
// Config defines server configuration
3136
type Config struct {
3237
Port int `json:"port"`
@@ -73,6 +78,23 @@ func NewServer(config Config) *Server {
7378
// Global middleware
7479
engine.Use(gin.Logger()) // Request logging
7580
engine.Use(gin.Recovery()) // Crash recovery
81+
// Limit request body size to prevent DoS attacks.
82+
// First reject requests whose Content-Length already exceeds the limit,
83+
// then wrap the body with MaxBytesReader as a safety net for chunked
84+
// transfers or requests without Content-Length.
85+
engine.Use(func(c *gin.Context) {
86+
if c.Request.ContentLength > MaxBodySize {
87+
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
88+
"error": "request body too large",
89+
"detail": fmt.Sprintf("maximum allowed size is %d bytes", MaxBodySize),
90+
})
91+
c.Abort()
92+
return
93+
}
94+
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, MaxBodySize)
95+
c.Next()
96+
})
97+
engine.MaxMultipartMemory = MaxBodySize
7698

7799
// Load bootstrap public key from environment variable (required)
78100
if err := s.authManager.LoadBootstrapPublicKey(); err != nil {

pkg/picod/server_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ import (
2424
"encoding/json"
2525
"encoding/pem"
2626
"fmt"
27+
"io"
2728
"net/http"
2829
"net/http/httptest"
2930
"os"
3031
"path/filepath"
32+
"strings"
3133
"testing"
3234
"time"
3335

@@ -486,3 +488,35 @@ func TestInitHandler(t *testing.T) {
486488
assert.Equal(t, "session already initialized", resp["error"])
487489
})
488490
}
491+
492+
func TestServer_MaxBodySizeMiddleware(t *testing.T) {
493+
tmpDir, err := os.MkdirTemp("", "picod-server-test-*")
494+
require.NoError(t, err)
495+
defer os.RemoveAll(tmpDir)
496+
497+
pubKeyPEM := generateTestPublicKeyPEM(t)
498+
os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM)
499+
defer os.Unsetenv(BootstrapPublicKeyEnvVar)
500+
501+
server := NewServer(Config{
502+
Port: 8080,
503+
Workspace: tmpDir,
504+
})
505+
506+
ts := httptest.NewServer(server.engine)
507+
defer ts.Close()
508+
509+
// When Content-Length exceeds MaxBodySize, the global body-size limiter
510+
// middleware should reject the request with 413 before any other
511+
// middleware (auth, handler) gets a chance to run.
512+
oversizedBody := strings.NewReader(strings.Repeat("x", int(MaxBodySize)+1))
513+
resp, err := http.Post(ts.URL+"/api/execute", "application/json", oversizedBody)
514+
require.NoError(t, err)
515+
defer resp.Body.Close()
516+
517+
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
518+
519+
body, err := io.ReadAll(resp.Body)
520+
require.NoError(t, err)
521+
assert.Contains(t, string(body), "request body too large")
522+
}

pkg/router/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (s *Server) concurrencyLimitMiddleware() gin.HandlerFunc {
119119
}()
120120
c.Next()
121121
default:
122-
// No slots available, return 503 Service Unavailable
122+
// No slots available, return 429 Too Many Requests
123123
c.JSON(http.StatusTooManyRequests, gin.H{
124124
"error": "server overloaded, please try again later",
125125
"code": "SERVER_OVERLOADED",

pkg/workloadmanager/client_cache.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,25 @@ func NewTokenCache(maxSize int, ttl time.Duration) *TokenCache {
230230
// Returns found status, authenticated status, and username
231231
// If found is false, the token was not in cache or expired
232232
func (c *TokenCache) Get(token string) (found bool, authenticated bool, username string) {
233-
c.mu.RLock()
234-
defer c.mu.RUnlock()
233+
c.mu.Lock()
234+
defer c.mu.Unlock()
235235

236236
entry, exists := c.cache[token]
237237
if !exists {
238238
return false, false, ""
239239
}
240240

241-
// Check if entry is expired
241+
// Check if entry is expired; evict stale entries to prevent memory leak.
242+
// This mirrors ClientCache.Get which correctly removes expired entries.
242243
if time.Since(entry.lastAccess) > c.ttl {
244+
c.lruList.Remove(entry.element)
245+
delete(c.cache, token)
243246
return false, false, ""
244247
}
245248

249+
// Move to front on access for proper LRU ordering and sliding TTL.
250+
entry.lastAccess = time.Now()
251+
c.lruList.MoveToFront(entry.element)
246252
return true, entry.authenticated, entry.username
247253
}
248254

pkg/workloadmanager/client_cache_test.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,8 @@ func TestTokenCache_Get_Expired(t *testing.T) {
432432
found, authenticated, _ = cache.Get(token)
433433
assert.False(t, found)
434434
assert.False(t, authenticated)
435+
// Expired entry should be evicted from cache, not just hidden
436+
assert.Equal(t, 0, cache.Size(), "Expired entry should be evicted from cache")
435437
}
436438

437439
func TestTokenCache_UpdateExisting(t *testing.T) {
@@ -486,18 +488,19 @@ func TestTokenCache_LRUBehavior(t *testing.T) {
486488
cache.Set(token, true, "user"+string(rune('0'+i)))
487489
}
488490

489-
// Access first entry (Get doesn't update LRU, only Set does)
491+
// Access first entry (Get promotes to front of LRU list)
490492
cache.Get("token0")
491493

492-
// Add new entry - should evict oldest (token0, since Get doesn't update LRU)
494+
// Add new entry - should evict token1 (now the least recently used
495+
// because token0 was promoted by Get)
493496
cache.Set("token3", true, "user3")
494497

495-
// token0 should be evicted (oldest in LRU list)
498+
// token0 should still be present (was promoted by Get)
496499
found, _, _ := cache.Get("token0")
497-
assert.False(t, found)
498-
// token1 should be present
499-
found, _, _ = cache.Get("token1")
500500
assert.True(t, found)
501+
// token1 should be evicted (least recently used)
502+
found, _, _ = cache.Get("token1")
503+
assert.False(t, found)
501504
// token2 should be present
502505
found, _, _ = cache.Get("token2")
503506
assert.True(t, found)

0 commit comments

Comments
 (0)