Skip to content

Commit 2ef218c

Browse files
committed
fix: address review feedback and fix lint
Signed-off-by: Abhinav Singh <abhinavsingh717073@gmail.com>
1 parent d598f1c commit 2ef218c

14 files changed

Lines changed: 410 additions & 83 deletions

File tree

pkg/common/types/sandbox.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ type SandboxInfo struct {
3838
// metav1.Duration marshals as a human-readable string (e.g. "15m0s") rather than
3939
// a raw nanosecond integer, making the persisted JSON unambiguous.
4040
IdleTimeout metav1.Duration `json:"idleTimeout,omitempty"`
41+
// SessionPrivateKey is the unique RSA private key generated for this specific session.
42+
// It is used by the Router to sign JWT tokens for requests forwarded to PicoD.
43+
SessionPrivateKey string `json:"sessionPrivateKey,omitempty"`
4144
// LastActivityAt is populated transiently from the store's last-activity sorted set
4245
// during ListInactiveSandboxes. It is intentionally excluded from JSON serialization.
4346
LastActivityAt time.Time `json:"-"`
@@ -57,11 +60,12 @@ type CreateSandboxRequest struct {
5760
}
5861

5962
type CreateSandboxResponse struct {
60-
Kind string `json:"kind"`
61-
SessionID string `json:"sessionId"`
62-
SandboxID string `json:"sandboxId"`
63-
SandboxName string `json:"sandboxName"`
64-
EntryPoints []SandboxEntryPoint `json:"entryPoints"`
63+
Kind string `json:"kind"`
64+
SessionID string `json:"sessionId"`
65+
SandboxID string `json:"sandboxId"`
66+
SandboxName string `json:"sandboxName"`
67+
EntryPoints []SandboxEntryPoint `json:"entryPoints"`
68+
SessionPrivateKey string `json:"sessionPrivateKey,omitempty"`
6569
}
6670

6771
func (car *CreateSandboxRequest) Validate() error {

pkg/picod/auth.go

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"crypto/rsa"
2121
"crypto/x509"
2222
"encoding/pem"
23+
"errors"
2324
"fmt"
2425
"net/http"
2526
"os"
@@ -32,6 +33,11 @@ import (
3233
"k8s.io/klog/v2"
3334
)
3435

36+
var (
37+
// ErrAlreadyInitialized is returned when attempting to initialize PicoD session key again
38+
ErrAlreadyInitialized = errors.New("session has already been initialized")
39+
)
40+
3541
const (
3642
// MaxBodySize limits request body size to prevent memory exhaustion
3743
MaxBodySize = 32 << 20 // 32 MB
@@ -55,6 +61,26 @@ func NewAuthManager() *AuthManager {
5561
return &AuthManager{}
5662
}
5763

64+
// parseRSAPublicKeyFromPEM parses an RSA public key from a PEM string
65+
func parseRSAPublicKeyFromPEM(keyData string) (*rsa.PublicKey, error) {
66+
block, _ := pem.Decode([]byte(keyData))
67+
if block == nil {
68+
return nil, fmt.Errorf("failed to decode PEM block")
69+
}
70+
71+
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
72+
if err != nil {
73+
return nil, fmt.Errorf("failed to parse public key: %w", err)
74+
}
75+
76+
rsaPub, ok := pub.(*rsa.PublicKey)
77+
if !ok {
78+
return nil, fmt.Errorf("key is not an RSA public key")
79+
}
80+
81+
return rsaPub, nil
82+
}
83+
5884
// LoadBootstrapPublicKey loads the bootstrap public key from environment variable.
5985
// The key should be in PEM format.
6086
func (am *AuthManager) LoadBootstrapPublicKey() error {
@@ -66,19 +92,9 @@ func (am *AuthManager) LoadBootstrapPublicKey() error {
6692
return fmt.Errorf("environment variable %s is not set", BootstrapPublicKeyEnvVar)
6793
}
6894

69-
block, _ := pem.Decode([]byte(keyData))
70-
if block == nil {
71-
return fmt.Errorf("failed to decode PEM block from %s", BootstrapPublicKeyEnvVar)
72-
}
73-
74-
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
95+
rsaPub, err := parseRSAPublicKeyFromPEM(keyData)
7596
if err != nil {
76-
return fmt.Errorf("failed to parse public key: %w", err)
77-
}
78-
79-
rsaPub, ok := pub.(*rsa.PublicKey)
80-
if !ok {
81-
return fmt.Errorf("key is not an RSA public key")
97+
return fmt.Errorf("failed to parse bootstrap public key: %w", err)
8298
}
8399

84100
am.bootstrapPublicKey = rsaPub
@@ -91,21 +107,16 @@ func (am *AuthManager) SetSessionPublicKey(keyData string) error {
91107
am.mutex.Lock()
92108
defer am.mutex.Unlock()
93109

94-
block, _ := pem.Decode([]byte(keyData))
95-
if block == nil {
96-
return fmt.Errorf("failed to decode session public key PEM block")
110+
if am.initialized {
111+
klog.Warning("Attempted to re-initialize an already initialized session")
112+
return ErrAlreadyInitialized
97113
}
98114

99-
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
115+
rsaPub, err := parseRSAPublicKeyFromPEM(keyData)
100116
if err != nil {
101117
return fmt.Errorf("failed to parse session public key: %w", err)
102118
}
103119

104-
rsaPub, ok := pub.(*rsa.PublicKey)
105-
if !ok {
106-
return fmt.Errorf("session key is not an RSA public key")
107-
}
108-
109120
am.sessionPublicKey = rsaPub
110121
am.initialized = true
111122
klog.Info("Session public key successfully registered via /init")

pkg/picod/server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package picod
1818

1919
import (
20+
"errors"
2021
"fmt"
2122
"net/http"
2223
"os"
@@ -139,6 +140,10 @@ func (s *Server) InitHandler(c *gin.Context) {
139140
}
140141

141142
if err := s.authManager.SetSessionPublicKey(sessionPubKey); err != nil {
143+
if errors.Is(err, ErrAlreadyInitialized) {
144+
c.JSON(http.StatusConflict, gin.H{"error": "session already initialized", "detail": err.Error()})
145+
return
146+
}
142147
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize session key", "detail": err.Error()})
143148
return
144149
}

pkg/picod/server_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package picod
1818

1919
import (
20+
"bytes"
2021
"crypto/rand"
2122
"crypto/rsa"
2223
"crypto/x509"
@@ -31,6 +32,7 @@ import (
3132
"time"
3233

3334
"github.com/gin-gonic/gin"
35+
"github.com/golang-jwt/jwt/v5"
3436
"github.com/stretchr/testify/assert"
3537
"github.com/stretchr/testify/require"
3638
)
@@ -368,3 +370,117 @@ func TestNewServer_DifferentPorts(t *testing.T) {
368370
})
369371
}
370372
}
373+
374+
func TestInitHandler(t *testing.T) {
375+
tmpDir, err := os.MkdirTemp("", "picod-server-test-*")
376+
require.NoError(t, err)
377+
defer os.RemoveAll(tmpDir)
378+
379+
// Generate bootstrap keys
380+
bootstrapPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
381+
require.NoError(t, err)
382+
383+
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&bootstrapPrivKey.PublicKey)
384+
require.NoError(t, err)
385+
386+
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
387+
Type: "PUBLIC KEY",
388+
Bytes: pubKeyBytes,
389+
})
390+
391+
os.Setenv(BootstrapPublicKeyEnvVar, string(pubKeyPEM))
392+
defer os.Unsetenv(BootstrapPublicKeyEnvVar)
393+
394+
config := Config{
395+
Port: 8080,
396+
Workspace: tmpDir,
397+
}
398+
server := NewServer(config)
399+
400+
// Helper to generate a token signed by bootstrap private key
401+
generateToken := func(claims jwt.MapClaims) string {
402+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
403+
tokenStr, err := token.SignedString(bootstrapPrivKey)
404+
require.NoError(t, err)
405+
return tokenStr
406+
}
407+
408+
sessionPubPEM := generateTestPublicKeyPEM(t)
409+
410+
t.Run("invalid request format", func(t *testing.T) {
411+
w := httptest.NewRecorder()
412+
c, _ := gin.CreateTestContext(w)
413+
c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBufferString("{invalid-json}"))
414+
c.Request.Header.Set("Content-Type", "application/json")
415+
416+
server.InitHandler(c)
417+
418+
assert.Equal(t, http.StatusBadRequest, w.Code)
419+
var resp map[string]string
420+
err := json.Unmarshal(w.Body.Bytes(), &resp)
421+
require.NoError(t, err)
422+
assert.Contains(t, resp["error"], "invalid request format")
423+
})
424+
425+
t.Run("invalid token", func(t *testing.T) {
426+
w := httptest.NewRecorder()
427+
c, _ := gin.CreateTestContext(w)
428+
body, _ := json.Marshal(map[string]string{"token": "invalid.token.here"})
429+
c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body))
430+
c.Request.Header.Set("Content-Type", "application/json")
431+
432+
server.InitHandler(c)
433+
434+
assert.Equal(t, http.StatusUnauthorized, w.Code)
435+
var resp map[string]string
436+
err := json.Unmarshal(w.Body.Bytes(), &resp)
437+
require.NoError(t, err)
438+
assert.Contains(t, resp["error"], "invalid bootstrap token")
439+
})
440+
441+
t.Run("successful initialization", func(t *testing.T) {
442+
claims := jwt.MapClaims{
443+
"exp": time.Now().Add(time.Minute).Unix(),
444+
"iat": time.Now().Unix(),
445+
"session_public_key": sessionPubPEM,
446+
}
447+
token := generateToken(claims)
448+
449+
w := httptest.NewRecorder()
450+
c, _ := gin.CreateTestContext(w)
451+
body, _ := json.Marshal(map[string]string{"token": token})
452+
c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body))
453+
c.Request.Header.Set("Content-Type", "application/json")
454+
455+
server.InitHandler(c)
456+
457+
assert.Equal(t, http.StatusOK, w.Code)
458+
var resp map[string]string
459+
err := json.Unmarshal(w.Body.Bytes(), &resp)
460+
require.NoError(t, err)
461+
assert.Equal(t, "initialized successfully", resp["status"])
462+
})
463+
464+
t.Run("already initialized", func(t *testing.T) {
465+
claims := jwt.MapClaims{
466+
"exp": time.Now().Add(time.Minute).Unix(),
467+
"iat": time.Now().Unix(),
468+
"session_public_key": sessionPubPEM,
469+
}
470+
token := generateToken(claims)
471+
472+
w := httptest.NewRecorder()
473+
c, _ := gin.CreateTestContext(w)
474+
body, _ := json.Marshal(map[string]string{"token": token})
475+
c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body))
476+
c.Request.Header.Set("Content-Type", "application/json")
477+
478+
server.InitHandler(c)
479+
480+
assert.Equal(t, http.StatusConflict, w.Code)
481+
var resp map[string]string
482+
err := json.Unmarshal(w.Body.Bytes(), &resp)
483+
require.NoError(t, err)
484+
assert.Equal(t, "session already initialized", resp["error"])
485+
})
486+
}

pkg/router/handlers.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,13 @@ func (s *Server) generateSandboxJWT(c *gin.Context, sandbox *types.SandboxInfo)
240240
claims := map[string]interface{}{
241241
"session_id": sandbox.SessionID,
242242
}
243-
token, err := s.jwtManager.GenerateToken(claims)
243+
var token string
244+
var err error
245+
if sandbox.SessionPrivateKey != "" {
246+
token, err = s.jwtManager.GenerateTokenWithKey(claims, sandbox.SessionPrivateKey)
247+
} else {
248+
token, err = s.jwtManager.GenerateToken(claims)
249+
}
244250
if err != nil {
245251
klog.Errorf("Failed to generate JWT token (session: %s): %v", sandbox.SessionID, err)
246252
c.JSON(http.StatusInternalServerError, gin.H{

pkg/router/jwt.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,37 @@ func (jm *JWTManager) GenerateToken(claims map[string]interface{}) (string, erro
103103
return tokenString, nil
104104
}
105105

106+
// GenerateTokenWithKey generates a JWT token signed with a specific PEM-encoded private key
107+
func (jm *JWTManager) GenerateTokenWithKey(claims map[string]interface{}, privateKeyPEM string) (string, error) {
108+
privKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
109+
if err != nil {
110+
return "", fmt.Errorf("failed to parse private key: %w", err)
111+
}
112+
113+
// Create JWT claims
114+
jwtClaims := jwt.MapClaims{
115+
"exp": time.Now().Add(jwtExpiration).Unix(),
116+
"iat": time.Now().Unix(),
117+
"iss": "agentcube-router",
118+
}
119+
120+
// Add custom claims
121+
for k, v := range claims {
122+
jwtClaims[k] = v
123+
}
124+
125+
// Create token
126+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwtClaims)
127+
128+
// Sign token with private key
129+
tokenString, err := token.SignedString(privKey)
130+
if err != nil {
131+
return "", fmt.Errorf("failed to sign JWT token: %w", err)
132+
}
133+
134+
return tokenString, nil
135+
}
136+
106137
// GetPublicKeyPEM returns the public key in PEM format
107138
func (jm *JWTManager) GetPublicKeyPEM() ([]byte, error) {
108139
pubKeyBytes, err := x509.MarshalPKIXPublicKey(jm.publicKey)

pkg/router/session_manager.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,13 @@ func (m *manager) createSandbox(ctx context.Context, namespace string, name stri
185185

186186
// Construct Sandbox Info from response
187187
sandbox := &types.SandboxInfo{
188-
Kind: res.Kind,
189-
SandboxNamespace: namespace,
190-
SandboxID: res.SandboxID,
191-
Name: res.SandboxName,
192-
SessionID: res.SessionID,
193-
EntryPoints: res.EntryPoints,
188+
Kind: res.Kind,
189+
SandboxNamespace: namespace,
190+
SandboxID: res.SandboxID,
191+
Name: res.SandboxName,
192+
SessionID: res.SessionID,
193+
EntryPoints: res.EntryPoints,
194+
SessionPrivateKey: res.SessionPrivateKey,
194195
}
195196

196197
return sandbox, nil

pkg/workloadmanager/bootstrap_auth.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,34 @@ func GetBootstrapPublicKeyPEM() string {
6262
return bootstrapPublicKeyPEM
6363
}
6464

65+
// GenerateSessionKeyPair generates an ephemeral 2048-bit RSA key pair for a specific session.
66+
// It returns the private key PEM and public key PEM strings.
67+
func GenerateSessionKeyPair() (string, string, error) {
68+
klog.Info("Generating unique Session Key Pair")
69+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
70+
if err != nil {
71+
return "", "", fmt.Errorf("failed to generate RSA key: %w", err)
72+
}
73+
74+
privASN1 := x509.MarshalPKCS1PrivateKey(privateKey)
75+
privPEM := pem.EncodeToMemory(&pem.Block{
76+
Type: "RSA PRIVATE KEY",
77+
Bytes: privASN1,
78+
})
79+
80+
pubASN1, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
81+
if err != nil {
82+
return "", "", fmt.Errorf("failed to marshal public key: %w", err)
83+
}
84+
85+
pubPEM := pem.EncodeToMemory(&pem.Block{
86+
Type: "PUBLIC KEY",
87+
Bytes: pubASN1,
88+
})
89+
90+
return string(privPEM), string(pubPEM), nil
91+
}
92+
6593
// GenerateInitJWT creates a JWT signed by the Bootstrap Private Key
6694
// containing the session_public_key as a custom claim.
6795
func GenerateInitJWT(sandboxID, sessionPublicKey string) (string, error) {

0 commit comments

Comments
 (0)