diff --git a/cmd/workload-manager/main.go b/cmd/workload-manager/main.go index 65f1d80c..bd5eceb0 100644 --- a/cmd/workload-manager/main.go +++ b/cmd/workload-manager/main.go @@ -80,21 +80,6 @@ func main() { os.Exit(1) } - sandboxReconciler := &workloadmanager.SandboxReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - } - - codeInterpreterReconciler := &workloadmanager.CodeInterpreterReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - } - - if err := setupControllers(mgr, sandboxReconciler, codeInterpreterReconciler); err != nil { - fmt.Fprintf(os.Stderr, "unable to setup controllers: %v\n", err) - os.Exit(1) - } - // Create API server configuration config := &workloadmanager.Config{ Port: *port, @@ -105,12 +90,28 @@ func main() { EnableAuth: *enableAuth, } + sandboxReconciler := &workloadmanager.SandboxReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + } + // Create and initialize API server server, err := workloadmanager.NewServer(config, sandboxReconciler) if err != nil { klog.Fatalf("Failed to create API server: %v", err) } + codeInterpreterReconciler := &workloadmanager.CodeInterpreterReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + BootstrapPublicKeyPEM: server.GetBootstrapPublicKeyPEM(), + } + + if err := setupControllers(mgr, sandboxReconciler, codeInterpreterReconciler); err != nil { + fmt.Fprintf(os.Stderr, "unable to setup controllers: %v\n", err) + os.Exit(1) + } + // Setup signal handling ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/docs/design/PicoD-Plain-Authentication-Design.md b/docs/design/PicoD-Plain-Authentication-Design.md index e3bb2a2a..29eebe99 100644 --- a/docs/design/PicoD-Plain-Authentication-Design.md +++ b/docs/design/PicoD-Plain-Authentication-Design.md @@ -12,6 +12,13 @@ However, emerging use cases require a more flexible architecture where the clien The existing self-signed key-pair model is incompatible with this centralized management flow, as it bypasses the Router's ability to mediate access. To address this, we propose a new **Plain Authentication** mechanism for `picod`. This design enables the Router/Gateway to manage credentials and connection security, simplifying the client-side workflow while maintaining robust access control. +> [!WARNING] +> **Migration Note (Two-Stage Secure Initialization)** +> +> The flow described in this document was updated by PR #352 to address cross-sandbox token replay vulnerabilities. The original `PICOD_AUTH_PUBLIC_KEY` environment variable has been renamed to `PICOD_BOOTSTRAP_PUBLIC_KEY`. +> +> While `PICOD_AUTH_PUBLIC_KEY` is still supported as a fallback for backwards compatibility, deployments should migrate to `PICOD_BOOTSTRAP_PUBLIC_KEY`. Under the new model, this key is only used to verify the bootstrap payload during the `/init` handshake, which establishes a unique session keypair for subsequent requests. + ## Use Cases ### Gateway-Managed Sandbox Access diff --git a/pkg/common/types/sandbox.go b/pkg/common/types/sandbox.go index b829f780..9d7b6439 100644 --- a/pkg/common/types/sandbox.go +++ b/pkg/common/types/sandbox.go @@ -38,6 +38,10 @@ type SandboxInfo struct { // metav1.Duration marshals as a human-readable string (e.g. "15m0s") rather than // a raw nanosecond integer, making the persisted JSON unambiguous. IdleTimeout metav1.Duration `json:"idleTimeout,omitempty"` + // SessionPrivateKey is intentionally excluded from JSON serialization so it + // is never persisted to the KV store or exposed via list/get APIs. + // It is only populated transiently in the WM→Router HTTP response path. + SessionPrivateKey string `json:"-"` // LastActivityAt is populated transiently from the store's last-activity sorted set // during ListInactiveSandboxes. It is intentionally excluded from JSON serialization. LastActivityAt time.Time `json:"-"` diff --git a/pkg/picod/auth.go b/pkg/picod/auth.go index eae10094..9365b650 100644 --- a/pkg/picod/auth.go +++ b/pkg/picod/auth.go @@ -20,6 +20,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "fmt" "net/http" "os" @@ -32,16 +33,27 @@ import ( "k8s.io/klog/v2" ) +var ( + // ErrAlreadyInitialized is returned when attempting to initialize PicoD session key again + ErrAlreadyInitialized = errors.New("session has already been initialized") +) + const ( - // PublicKeyEnvVar is the environment variable name for the public key - PublicKeyEnvVar = "PICOD_AUTH_PUBLIC_KEY" + + // BootstrapPublicKeyEnvVar is the environment variable name for the bootstrap public key + BootstrapPublicKeyEnvVar = "PICOD_BOOTSTRAP_PUBLIC_KEY" + // LegacyBootstrapPublicKeyEnvVar is the deprecated environment variable name for the bootstrap public key + LegacyBootstrapPublicKeyEnvVar = "PICOD_AUTH_PUBLIC_KEY" ) // AuthManager manages RSA public key authentication -// The public key is loaded from environment variable at startup +// The bootstrap public key is loaded from environment variable at startup +// The session public key is set dynamically via the /init endpoint type AuthManager struct { - publicKey *rsa.PublicKey - mutex sync.RWMutex + bootstrapPublicKey *rsa.PublicKey + sessionPublicKey *rsa.PublicKey + initialized bool + mutex sync.RWMutex } // NewAuthManager creates a new auth manager @@ -49,41 +61,124 @@ func NewAuthManager() *AuthManager { return &AuthManager{} } -// LoadPublicKeyFromEnv loads the public key from environment variable. +// parseRSAPublicKeyFromPEM parses an RSA public key from a PEM string +func parseRSAPublicKeyFromPEM(keyData string) (*rsa.PublicKey, error) { + block, _ := pem.Decode([]byte(keyData)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("key is not an RSA public key") + } + + return rsaPub, nil +} + +// LoadBootstrapPublicKey loads the bootstrap public key from environment variable. // The key should be in PEM format. -func (am *AuthManager) LoadPublicKeyFromEnv() error { +func (am *AuthManager) LoadBootstrapPublicKey() error { am.mutex.Lock() defer am.mutex.Unlock() - keyData := os.Getenv(PublicKeyEnvVar) + keyData := os.Getenv(BootstrapPublicKeyEnvVar) if keyData == "" { - return fmt.Errorf("environment variable %s is not set", PublicKeyEnvVar) + keyData = os.Getenv(LegacyBootstrapPublicKeyEnvVar) + if keyData != "" { + klog.Warningf("Using deprecated environment variable %s, please migrate to %s", LegacyBootstrapPublicKeyEnvVar, BootstrapPublicKeyEnvVar) + } } - block, _ := pem.Decode([]byte(keyData)) - if block == nil { - return fmt.Errorf("failed to decode PEM block from %s", PublicKeyEnvVar) + if keyData == "" { + return fmt.Errorf("environment variable %s (or deprecated %s) is not set", BootstrapPublicKeyEnvVar, LegacyBootstrapPublicKeyEnvVar) } - pub, err := x509.ParsePKIXPublicKey(block.Bytes) + rsaPub, err := parseRSAPublicKeyFromPEM(keyData) if err != nil { - return fmt.Errorf("failed to parse public key: %w", err) + return fmt.Errorf("failed to parse bootstrap public key: %w", err) } - rsaPub, ok := pub.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("key is not an RSA public key") + am.bootstrapPublicKey = rsaPub + klog.Info("Bootstrap public key loaded successfully from environment variable") + return nil +} + +// SetSessionPublicKey parses and stores the ephemeral session public key +func (am *AuthManager) SetSessionPublicKey(keyData string) error { + am.mutex.Lock() + defer am.mutex.Unlock() + + if am.initialized { + klog.Warning("Attempted to re-initialize an already initialized session") + return ErrAlreadyInitialized } - am.publicKey = rsaPub - klog.Info("Public key loaded successfully from environment variable") + rsaPub, err := parseRSAPublicKeyFromPEM(keyData) + if err != nil { + return fmt.Errorf("failed to parse session public key: %w", err) + } + + am.sessionPublicKey = rsaPub + am.initialized = true + klog.Info("Session public key successfully registered via /init") return nil } +// VerifyBootstrapJWT verifies the init token against the bootstrap public key and returns the session_public_key claim +func (am *AuthManager) VerifyBootstrapJWT(tokenStr string) (string, error) { + token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + am.mutex.RLock() + defer am.mutex.RUnlock() + if am.bootstrapPublicKey == nil { + return nil, fmt.Errorf("bootstrap public key is not loaded") + } + return am.bootstrapPublicKey, nil + }, jwt.WithExpirationRequired(), jwt.WithIssuedAt(), jwt.WithLeeway(time.Minute), jwt.WithIssuer("agentcube-workload-manager")) + + if err != nil || !token.Valid { + return "", fmt.Errorf("JWT verification failed: %w", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return "", fmt.Errorf("invalid token claims") + } + + sessionPubKey, ok := claims["session_public_key"].(string) + if !ok || sessionPubKey == "" { + return "", fmt.Errorf("missing or invalid session_public_key claim") + } + + return sessionPubKey, nil +} + // AuthMiddleware creates authentication middleware with JWT verification -// Note: Public key must be loaded at startup (via LoadPublicKeyFromEnv), so we don't check here +// Note: Requires the daemon to be initialized with a session public key. func (am *AuthManager) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + // Check if initialized + am.mutex.RLock() + isInit := am.initialized + am.mutex.RUnlock() + if !isInit { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": "Daemon not initialized", + "code": http.StatusServiceUnavailable, + "detail": "PicoD is waiting for Workload Manager to initialize the session", + }) + c.Abort() + return + } + authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{ @@ -108,14 +203,14 @@ func (am *AuthManager) AuthMiddleware() gin.HandlerFunc { tokenString := parts[1] - // Parse and validate JWT using the public key + // Parse and validate JWT using the session public key token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } am.mutex.RLock() defer am.mutex.RUnlock() - return am.publicKey, nil + return am.sessionPublicKey, nil }, jwt.WithExpirationRequired(), jwt.WithIssuedAt(), jwt.WithLeeway(time.Minute)) if err != nil || !token.Valid { diff --git a/pkg/picod/auth_test.go b/pkg/picod/auth_test.go index 96a31673..e9165f39 100644 --- a/pkg/picod/auth_test.go +++ b/pkg/picod/auth_test.go @@ -64,7 +64,7 @@ func TestNewAuthManager(t *testing.T) { assert.NotNil(t, manager) } -func TestLoadPublicKeyFromEnv(t *testing.T) { +func TestLoadBootstrapPublicKey(t *testing.T) { tests := []struct { name string setupEnv func() string @@ -120,14 +120,14 @@ MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAinvalid t.Run(tt.name, func(t *testing.T) { envValue := tt.setupEnv() if envValue != "" { - os.Setenv(PublicKeyEnvVar, envValue) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, envValue) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) } else { - os.Unsetenv(PublicKeyEnvVar) + os.Unsetenv(BootstrapPublicKeyEnvVar) } manager := NewAuthManager() - err := manager.LoadPublicKeyFromEnv() + err := manager.LoadBootstrapPublicKey() if tt.wantErr { assert.Error(t, err) @@ -145,11 +145,13 @@ func TestAuthMiddleware_HeaderValidation(t *testing.T) { _, pubKeyPEM, err := generateTestRSAKeyPair() require.NoError(t, err) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) manager := NewAuthManager() - err = manager.LoadPublicKeyFromEnv() + err = manager.LoadBootstrapPublicKey() + require.NoError(t, err) + err = manager.SetSessionPublicKey(pubKeyPEM) require.NoError(t, err) tests := []struct { @@ -212,11 +214,13 @@ func TestAuthMiddleware_TokenValidation(t *testing.T) { privateKey, pubKeyPEM, err := generateTestRSAKeyPair() require.NoError(t, err) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) manager := NewAuthManager() - err = manager.LoadPublicKeyFromEnv() + err = manager.LoadBootstrapPublicKey() + require.NoError(t, err) + err = manager.SetSessionPublicKey(pubKeyPEM) require.NoError(t, err) tests := []struct { @@ -302,4 +306,3 @@ func TestAuthMiddleware_TokenValidation(t *testing.T) { } } - diff --git a/pkg/picod/execute_test.go b/pkg/picod/execute_test.go index d6cc4bdc..ad19affb 100644 --- a/pkg/picod/execute_test.go +++ b/pkg/picod/execute_test.go @@ -53,7 +53,7 @@ func setupExecuteTestServer(t *testing.T) (*Server, string) { Bytes: pubKeyBytes, }) - os.Setenv(PublicKeyEnvVar, string(pubKeyPEM)) + os.Setenv(BootstrapPublicKeyEnvVar, string(pubKeyPEM)) tmpDir, err := os.MkdirTemp("", "picod-execute-test-*") require.NoError(t, err) @@ -64,13 +64,15 @@ func setupExecuteTestServer(t *testing.T) (*Server, string) { } server := NewServer(config) + err = server.authManager.SetSessionPublicKey(string(pubKeyPEM)) + require.NoError(t, err) return server, tmpDir } func TestExecuteHandler_RequestValidation(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) tests := []struct { name string @@ -118,7 +120,7 @@ func TestExecuteHandler_RequestValidation(t *testing.T) { func TestExecuteHandler_TimeoutFormats(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) tests := []struct { name string @@ -196,7 +198,7 @@ func TestExecuteHandler_TimeoutFormats(t *testing.T) { func TestExecuteHandler_WorkingDirectory(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) subDir := filepath.Join(tmpDir, "subdir") require.NoError(t, os.Mkdir(subDir, 0755)) @@ -259,7 +261,7 @@ func TestExecuteHandler_WorkingDirectory(t *testing.T) { func TestExecuteHandler_WorkingDirectory_SymlinkEscape(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) // Plant a symlink inside the workspace that points outside it. outsideDir := t.TempDir() @@ -287,7 +289,7 @@ func TestExecuteHandler_WorkingDirectory_SymlinkEscape(t *testing.T) { func TestExecuteHandler_DefaultsToWorkspace(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) // No WorkingDir set — command should run inside the workspace directory. req := ExecuteRequest{ @@ -322,7 +324,7 @@ func TestExecuteHandler_DefaultsToWorkspace(t *testing.T) { func TestExecuteHandler_ExitCodes(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) tests := []struct { name string @@ -371,7 +373,7 @@ func TestExecuteHandler_ExitCodes(t *testing.T) { func TestExecuteHandler_TimeoutHandling(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"sleep", "1"}, @@ -398,7 +400,7 @@ func TestExecuteHandler_TimeoutHandling(t *testing.T) { func TestExecuteHandler_EnvironmentVariables(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"sh", "-c", "echo $TEST_VAR"}, @@ -426,7 +428,7 @@ func TestExecuteHandler_EnvironmentVariables(t *testing.T) { func TestExecuteHandler_ResponseStructure(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"echo", "hello", "world"}, @@ -462,7 +464,7 @@ func TestExecuteHandler_ResponseStructure(t *testing.T) { func TestExecuteHandler_StderrCapture(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"sh", "-c", "echo 'error message' >&2"}, @@ -487,7 +489,7 @@ func TestExecuteHandler_StderrCapture(t *testing.T) { func TestExecuteHandler_CommandWithArguments(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"sh", "-c", "echo arg1 arg2 arg3"}, @@ -512,7 +514,7 @@ func TestExecuteHandler_CommandWithArguments(t *testing.T) { func TestExecuteHandler_EmptyEnvVars(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"echo", "test"}, @@ -533,7 +535,7 @@ func TestExecuteHandler_EmptyEnvVars(t *testing.T) { func TestExecuteHandler_MultipleEnvVars(t *testing.T) { server, tmpDir := setupExecuteTestServer(t) defer os.RemoveAll(tmpDir) - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) req := ExecuteRequest{ Command: []string{"sh", "-c", "echo $VAR1 $VAR2 $VAR3"}, diff --git a/pkg/picod/picod_test.go b/pkg/picod/picod_test.go index ea116d1b..f2f17a19 100644 --- a/pkg/picod/picod_test.go +++ b/pkg/picod/picod_test.go @@ -67,8 +67,8 @@ func setupTestServer(t *testing.T, pubPEM string) (*Server, *httptest.Server, st tmpDir, err := os.MkdirTemp("", "picod_test") require.NoError(t, err) - // Set the public key environment variable - os.Setenv(PublicKeyEnvVar, pubPEM) + // Set the bootstrap public key environment variable + os.Setenv(BootstrapPublicKeyEnvVar, pubPEM) config := Config{ Port: 0, @@ -76,6 +76,8 @@ func setupTestServer(t *testing.T, pubPEM string) (*Server, *httptest.Server, st } server := NewServer(config) + err = server.authManager.SetSessionPublicKey(pubPEM) + require.NoError(t, err) ts := httptest.NewServer(server.engine) return server, ts, tmpDir @@ -89,7 +91,7 @@ func TestPicoD_EndToEnd(t *testing.T) { _, ts, tmpDir := setupTestServer(t, routerPubStr) defer os.RemoveAll(tmpDir) defer ts.Close() - defer os.Unsetenv(PublicKeyEnvVar) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) // Switch to temp dir for relative path tests originalWd, err := os.Getwd() @@ -342,8 +344,8 @@ func TestPicoD_DefaultWorkspace(t *testing.T) { // Set public key env _, pubStr := generateRSAKeys(t) - os.Setenv(PublicKeyEnvVar, pubStr) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubStr) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) // Initialize server with empty workspace config := Config{ diff --git a/pkg/picod/server.go b/pkg/picod/server.go index 29372ecd..378eaff8 100644 --- a/pkg/picod/server.go +++ b/pkg/picod/server.go @@ -17,6 +17,7 @@ limitations under the License. package picod import ( + "errors" "fmt" "net/http" "os" @@ -97,9 +98,9 @@ func NewServer(config Config) *Server { engine.MaxMultipartMemory = MaxBodySize engine.Use(gzip.Gzip(gzip.BestSpeed, gzip.WithExcludedPaths([]string{"/health"}))) // Response compression - // Load public key from environment variable (required) - if err := s.authManager.LoadPublicKeyFromEnv(); err != nil { - klog.Fatalf("Failed to load public key from environment: %v", err) + // Load bootstrap public key from environment variable (required) + if err := s.authManager.LoadBootstrapPublicKey(); err != nil { + klog.Fatalf("Failed to load bootstrap public key from environment: %v", err) } // API route group (Authenticated) @@ -112,6 +113,9 @@ func NewServer(config Config) *Server { api.GET("/files/*path", s.DownloadFileHandler) } + // Initialization endpoint (requires JWT signed by bootstrap key) + engine.POST("/init", s.InitHandler) + // Health check (no authentication required) engine.GET("/health", s.HealthCheckHandler) @@ -142,3 +146,33 @@ func (s *Server) HealthCheckHandler(c *gin.Context) { "uptime": time.Since(s.startTime).String(), }) } + +// InitHandler processes the initial POST /init request to set the session public key +func (s *Server) InitHandler(c *gin.Context) { + var req struct { + Token string `json:"token" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request format"}) + return + } + + sessionPubKey, err := s.authManager.VerifyBootstrapJWT(req.Token) + if err != nil { + klog.Errorf("bootstrap token verification failed for /init: %v", err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "bootstrap token verification failed"}) + return + } + + if err := s.authManager.SetSessionPublicKey(sessionPubKey); err != nil { + if errors.Is(err, ErrAlreadyInitialized) { + c.JSON(http.StatusConflict, gin.H{"error": "session already initialized"}) + return + } + klog.Errorf("failed to set session public key: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize session key"}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "initialized successfully"}) +} diff --git a/pkg/picod/server_test.go b/pkg/picod/server_test.go index d38fcc86..3e6f9cf0 100644 --- a/pkg/picod/server_test.go +++ b/pkg/picod/server_test.go @@ -17,6 +17,7 @@ limitations under the License. package picod import ( + "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -33,6 +34,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,8 +60,8 @@ func generateTestPublicKeyPEM(t *testing.T) string { func TestNewServer_WorkspaceConfiguration(t *testing.T) { pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) tests := []struct { name string @@ -169,8 +171,8 @@ func TestNewServer_RoutesRegistered(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) config := Config{ Port: 8080, @@ -189,10 +191,10 @@ func TestNewServer_RoutesRegistered(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) resp.Body.Close() - // API routes should require auth (will return 401) + // API routes should require initialization and auth (will return 503 since not initialized) resp, err = http.Post(ts.URL+"/api/execute", "application/json", nil) require.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) resp.Body.Close() } @@ -202,15 +204,15 @@ func TestNewServer_PublicKeyRequired(t *testing.T) { defer os.RemoveAll(tmpDir) // Don't set public key environment variable - os.Unsetenv(PublicKeyEnvVar) + os.Unsetenv(BootstrapPublicKeyEnvVar) // NewServer should fail (calls klog.Fatalf which will panic in tests) // We can't easily test klog.Fatalf without mocking, but we can verify - // that LoadPublicKeyFromEnv would fail + // that LoadBootstrapPublicKey would fail authManager := NewAuthManager() - err = authManager.LoadPublicKeyFromEnv() + err = authManager.LoadBootstrapPublicKey() assert.Error(t, err) - assert.Contains(t, err.Error(), PublicKeyEnvVar) + assert.Contains(t, err.Error(), BootstrapPublicKeyEnvVar) } func TestHealthCheckHandler(t *testing.T) { @@ -219,8 +221,8 @@ func TestHealthCheckHandler(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) config := Config{ Port: 8080, @@ -268,8 +270,8 @@ func TestHealthCheckHandler_MultipleCalls(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) config := Config{ Port: 8080, @@ -306,8 +308,8 @@ func TestNewServer_EngineConfiguration(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) config := Config{ Port: 8080, @@ -328,8 +330,8 @@ func TestNewServer_AuthManagerInitialized(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) config := Config{ Port: 8080, @@ -351,8 +353,8 @@ func TestNewServer_DifferentPorts(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) ports := []int{8080, 9090, 3000, 0} @@ -377,8 +379,8 @@ func TestServer_GzipMiddleware_CompressesResponse(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) server := NewServer(Config{ Port: 8080, @@ -417,8 +419,8 @@ func TestServer_GzipMiddleware_ExcludesHealthEndpoint(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) server := NewServer(Config{ Port: 8080, @@ -454,8 +456,8 @@ func TestServer_MaxBodySizeMiddleware(t *testing.T) { defer os.RemoveAll(tmpDir) pubKeyPEM := generateTestPublicKeyPEM(t) - os.Setenv(PublicKeyEnvVar, pubKeyPEM) - defer os.Unsetenv(PublicKeyEnvVar) + os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) server := NewServer(Config{ Port: 8080, @@ -479,3 +481,119 @@ func TestServer_MaxBodySizeMiddleware(t *testing.T) { require.NoError(t, err) assert.Contains(t, string(body), "request body too large") } + +func TestInitHandler(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "picod-server-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Generate bootstrap keys + bootstrapPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&bootstrapPrivKey.PublicKey) + require.NoError(t, err) + + pubKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubKeyBytes, + }) + + os.Setenv(BootstrapPublicKeyEnvVar, string(pubKeyPEM)) + defer os.Unsetenv(BootstrapPublicKeyEnvVar) + + config := Config{ + Port: 8080, + Workspace: tmpDir, + } + server := NewServer(config) + + // Helper to generate a token signed by bootstrap private key + generateToken := func(claims jwt.MapClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenStr, err := token.SignedString(bootstrapPrivKey) + require.NoError(t, err) + return tokenStr + } + + sessionPubPEM := generateTestPublicKeyPEM(t) + + t.Run("invalid request format", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBufferString("{invalid-json}")) + c.Request.Header.Set("Content-Type", "application/json") + + server.InitHandler(c) + + assert.Equal(t, http.StatusBadRequest, w.Code) + var resp map[string]string + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Contains(t, resp["error"], "invalid request format") + }) + + t.Run("invalid token", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + body, _ := json.Marshal(map[string]string{"token": "invalid.token.here"}) + c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body)) + c.Request.Header.Set("Content-Type", "application/json") + + server.InitHandler(c) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + var resp map[string]string + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Contains(t, resp["error"], "bootstrap token verification failed") + }) + + t.Run("successful initialization", func(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "agentcube-workload-manager", + "exp": time.Now().Add(time.Minute).Unix(), + "iat": time.Now().Unix(), + "session_public_key": sessionPubPEM, + } + token := generateToken(claims) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + body, _ := json.Marshal(map[string]string{"token": token}) + c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body)) + c.Request.Header.Set("Content-Type", "application/json") + + server.InitHandler(c) + + assert.Equal(t, http.StatusOK, w.Code) + var resp map[string]string + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "initialized successfully", resp["status"]) + }) + + t.Run("already initialized", func(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "agentcube-workload-manager", + "exp": time.Now().Add(time.Minute).Unix(), + "iat": time.Now().Unix(), + "session_public_key": sessionPubPEM, + } + token := generateToken(claims) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + body, _ := json.Marshal(map[string]string{"token": token}) + c.Request, _ = http.NewRequest("POST", "/init", bytes.NewBuffer(body)) + c.Request.Header.Set("Content-Type", "application/json") + + server.InitHandler(c) + + assert.Equal(t, http.StatusConflict, w.Code) + var resp map[string]string + err := json.Unmarshal(w.Body.Bytes(), &resp) + require.NoError(t, err) + assert.Equal(t, "session already initialized", resp["error"]) + }) +} diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go index adbd0fdd..6e9a8cb5 100644 --- a/pkg/router/handlers.go +++ b/pkg/router/handlers.go @@ -18,6 +18,7 @@ package router import ( "context" + "errors" "fmt" "net" "net/http" @@ -31,6 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/volcano-sh/agentcube/pkg/common/types" + "github.com/volcano-sh/agentcube/pkg/store" ) // handleHealthLive handles liveness probe @@ -240,7 +242,22 @@ func (s *Server) generateSandboxJWT(c *gin.Context, sandbox *types.SandboxInfo) claims := map[string]interface{}{ "session_id": sandbox.SessionID, } - token, err := s.jwtManager.GenerateToken(claims) + var token string + var err error + + privKeyPEM, getErr := s.storeClient.GetSessionPrivateKey(c.Request.Context(), sandbox.SessionID) + if getErr == nil { + token, err = s.jwtManager.GenerateTokenWithKey(claims, privKeyPEM) + } else if errors.Is(getErr, store.ErrNotFound) { + token, err = s.jwtManager.GenerateToken(claims) + } else { + klog.Errorf("Failed to retrieve session private key from store (session: %s): %v", sandbox.SessionID, getErr) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "failed to sign request", + "code": "SESSION_KEY_RETRIEVAL_FAILED", + }) + return "", false + } if err != nil { klog.Errorf("Failed to generate JWT token (session: %s): %v", sandbox.SessionID, err) c.JSON(http.StatusInternalServerError, gin.H{ diff --git a/pkg/router/jwt.go b/pkg/router/jwt.go index d7b04626..dd1de638 100644 --- a/pkg/router/jwt.go +++ b/pkg/router/jwt.go @@ -24,6 +24,7 @@ import ( "encoding/pem" "fmt" "os" + "sync" "time" "github.com/golang-jwt/jwt/v5" @@ -62,6 +63,8 @@ type JWTManager struct { privateKey *rsa.PrivateKey publicKey *rsa.PublicKey clientset kubernetes.Interface + keyCache map[string]*rsa.PrivateKey + cacheMu sync.RWMutex } // NewJWTManager creates a new JWT manager with a fresh RSA key pair @@ -74,6 +77,7 @@ func NewJWTManager() (*JWTManager, error) { return &JWTManager{ privateKey: privateKey, publicKey: &privateKey.PublicKey, + keyCache: make(map[string]*rsa.PrivateKey), }, nil } @@ -103,6 +107,53 @@ func (jm *JWTManager) GenerateToken(claims map[string]interface{}) (string, erro return tokenString, nil } +const keyCacheMaxSize = 1000 + +// GenerateTokenWithKey generates a JWT token signed with a specific PEM-encoded +// private key. The parsed key is cached to avoid repeated RSA parsing overhead. +// Eviction is random-based: a random entry is removed when the cache is full, +// avoiding the thundering-herd problem of wiping the entire cache at once. +func (jm *JWTManager) GenerateTokenWithKey(claims map[string]interface{}, privateKeyPEM string) (string, error) { + jm.cacheMu.RLock() + privKey, ok := jm.keyCache[privateKeyPEM] + jm.cacheMu.RUnlock() + + if !ok { + parsedKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM)) + if err != nil { + return "", fmt.Errorf("failed to parse private key: %w", err) + } + privKey = parsedKey + + jm.cacheMu.Lock() + // Evict a single random entry when at capacity instead of clearing all. + if len(jm.keyCache) >= keyCacheMaxSize { + for k := range jm.keyCache { + delete(jm.keyCache, k) + break + } + } + jm.keyCache[privateKeyPEM] = privKey + jm.cacheMu.Unlock() + } + + jwtClaims := jwt.MapClaims{ + "exp": jwt.NewNumericDate(time.Now().Add(jwtExpiration)), + "iat": jwt.NewNumericDate(time.Now()), + "iss": "agentcube-router", + } + for k, v := range claims { + jwtClaims[k] = v + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwtClaims) + tokenString, err := token.SignedString(privKey) + if err != nil { + return "", fmt.Errorf("failed to sign JWT token: %w", err) + } + return tokenString, nil +} + // GetPublicKeyPEM returns the public key in PEM format func (jm *JWTManager) GetPublicKeyPEM() ([]byte, error) { pubKeyBytes, err := x509.MarshalPKIXPublicKey(jm.publicKey) diff --git a/pkg/router/server.go b/pkg/router/server.go index 5bd19472..0abe96f6 100644 --- a/pkg/router/server.go +++ b/pkg/router/server.go @@ -73,8 +73,9 @@ func NewServer(config *Config) (*Server, error) { // Create a reusable HTTP transport for connection pooling httpTransport := &http.Transport{ - IdleConnTimeout: 0, - DisableCompression: false, + IdleConnTimeout: 85 * time.Second, + MaxIdleConnsPerHost: 100, + DisableCompression: false, } server := &Server{ diff --git a/pkg/router/session_manager_test.go b/pkg/router/session_manager_test.go index 8740813b..dc172d67 100644 --- a/pkg/router/session_manager_test.go +++ b/pkg/router/session_manager_test.go @@ -98,6 +98,14 @@ func (f *fakeStoreClient) Close() error { return nil } +func (f *fakeStoreClient) StoreSessionPrivateKey(_ context.Context, _ string, _ string) error { + return nil +} + +func (f *fakeStoreClient) GetSessionPrivateKey(_ context.Context, _ string) (string, error) { + return "", nil +} + // ---- tests: GetSandboxBySession ---- func TestGetSandboxBySession_Success(t *testing.T) { diff --git a/pkg/store/interface.go b/pkg/store/interface.go index 72e74749..5e24c2bd 100644 --- a/pkg/store/interface.go +++ b/pkg/store/interface.go @@ -40,6 +40,10 @@ type Store interface { ListInactiveSandboxes(ctx context.Context, before time.Time, limit int64) ([]*types.SandboxInfo, error) // UpdateSessionLastActivity updates the last-activity index for the given session UpdateSessionLastActivity(ctx context.Context, sessionID string, at time.Time) error + // StoreSessionPrivateKey stores the session private key associated with the session ID. + StoreSessionPrivateKey(ctx context.Context, sessionID string, privateKey string) error + // GetSessionPrivateKey retrieves the session private key associated with the session ID. + GetSessionPrivateKey(ctx context.Context, sessionID string) (string, error) // Close releases all resources held by the store (e.g. connection pools) Close() error } diff --git a/pkg/store/store_redis.go b/pkg/store/store_redis.go index 13880f05..bfaf0639 100644 --- a/pkg/store/store_redis.go +++ b/pkg/store/store_redis.go @@ -215,9 +215,11 @@ func (rs *redisStore) UpdateSandbox(ctx context.Context, sandboxRedis *types.San func (rs *redisStore) DeleteSandboxBySessionID(ctx context.Context, sessionID string) error { sessionKey := rs.sessionKey(sessionID) + sessionKeyKey := "session_key:" + sessionID pipe := rs.cli.Pipeline() pipe.Del(ctx, sessionKey) + pipe.Del(ctx, sessionKeyKey) pipe.ZRem(ctx, rs.expiryIndexKey, sessionID) pipe.ZRem(ctx, rs.lastActivityIndexKey, sessionID) @@ -325,3 +327,30 @@ func (rs *redisStore) UpdateSessionLastActivity(ctx context.Context, sessionID s return nil } + +func (rs *redisStore) StoreSessionPrivateKey(ctx context.Context, sessionID string, privateKey string) error { + if sessionID == "" { + return errors.New("StoreSessionPrivateKey: sessionID is empty") + } + key := "session_key:" + sessionID + err := rs.cli.Set(ctx, key, privateKey, 0).Err() + if err != nil { + return fmt.Errorf("StoreSessionPrivateKey: redis SET %s failed: %w", key, err) + } + return nil +} + +func (rs *redisStore) GetSessionPrivateKey(ctx context.Context, sessionID string) (string, error) { + if sessionID == "" { + return "", errors.New("GetSessionPrivateKey: sessionID is empty") + } + key := "session_key:" + sessionID + val, err := rs.cli.Get(ctx, key).Result() + if errors.Is(err, redisv9.Nil) { + return "", ErrNotFound + } + if err != nil { + return "", fmt.Errorf("GetSessionPrivateKey: redis GET %s failed: %w", key, err) + } + return val, nil +} diff --git a/pkg/store/store_valkey.go b/pkg/store/store_valkey.go index 5e68b20d..5c7c980f 100644 --- a/pkg/store/store_valkey.go +++ b/pkg/store/store_valkey.go @@ -227,9 +227,11 @@ func (vs *valkeyStore) UpdateSandbox(ctx context.Context, sandboxStore *types.Sa // DeleteSandboxBySessionID delete sandbox by session ID func (vs *valkeyStore) DeleteSandboxBySessionID(ctx context.Context, sessionID string) error { sessionKey := vs.sessionKey(sessionID) + sessionKeyKey := "session_key:" + sessionID - commands := make(valkey.Commands, 0, 4) + commands := make(valkey.Commands, 0, 5) commands = append(commands, vs.cli.B().Del().Key(sessionKey).Build()) + commands = append(commands, vs.cli.B().Del().Key(sessionKeyKey).Build()) commands = append(commands, vs.cli.B().Zrem().Key(vs.expiryIndexKey).Member(sessionID).Build()) commands = append(commands, vs.cli.B().Zrem().Key(vs.lastActivityIndexKey).Member(sessionID).Build()) @@ -327,3 +329,30 @@ func (vs *valkeyStore) UpdateSessionLastActivity(ctx context.Context, sessionID } return nil } + +func (vs *valkeyStore) StoreSessionPrivateKey(ctx context.Context, sessionID string, privateKey string) error { + if sessionID == "" { + return errors.New("StoreSessionPrivateKey: sessionID is empty") + } + key := "session_key:" + sessionID + err := vs.cli.Do(ctx, vs.cli.B().Set().Key(key).Value(privateKey).Build()).Error() + if err != nil { + return fmt.Errorf("StoreSessionPrivateKey: valkey SET %s failed: %w", key, err) + } + return nil +} + +func (vs *valkeyStore) GetSessionPrivateKey(ctx context.Context, sessionID string) (string, error) { + if sessionID == "" { + return "", errors.New("GetSessionPrivateKey: sessionID is empty") + } + key := "session_key:" + sessionID + val, err := vs.cli.Do(ctx, vs.cli.B().Get().Key(key).Build()).ToString() + if err != nil { + if valkey.IsValkeyNil(err) { + return "", ErrNotFound + } + return "", fmt.Errorf("GetSessionPrivateKey: valkey GET %s: %w", key, err) + } + return val, nil +} diff --git a/pkg/workloadmanager/bootstrap_auth.go b/pkg/workloadmanager/bootstrap_auth.go new file mode 100644 index 00000000..b98cf6c8 --- /dev/null +++ b/pkg/workloadmanager/bootstrap_auth.go @@ -0,0 +1,196 @@ +/* +Copyright The Volcano Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package workloadmanager + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/klog/v2" +) + +const ( + bootstrapKeySize = 2048 + bootstrapSecretName = "agentcube-bootstrap-identity" + bootstrapPrivKeyKey = "bootstrap-private.pem" + bootstrapPubKeyKey = "bootstrap-public.pem" + bootstrapJWTExpiration = 2 * time.Minute + bootstrapJWTIssuer = "agentcube-workload-manager" +) + +// BootstrapAuthManager owns the bootstrap keypair used to sign /init JWTs. +// It must be instantiated via NewBootstrapAuthManager and held on Server — +// never shared as a package-level global, which would break parallel tests +// and multi-instance deployments. +type BootstrapAuthManager struct { + privateKey *rsa.PrivateKey + publicKeyPEM string + namespace string +} + +// NewBootstrapAuthManager loads the bootstrap keypair from a Kubernetes Secret, +// or generates and persists a new one if the Secret does not exist yet. +// Persisting the keypair means it survives Workload Manager restarts, preventing +// stranded PicoD pods whose containers already received the old public key via +// PICOD_BOOTSTRAP_PUBLIC_KEY. +func NewBootstrapAuthManager(ctx context.Context, clientset kubernetes.Interface, namespace string) (*BootstrapAuthManager, error) { + m := &BootstrapAuthManager{namespace: namespace} + + secret, err := clientset.CoreV1().Secrets(namespace).Get(ctx, bootstrapSecretName, metav1.GetOptions{}) + if err != nil && !apierrors.IsNotFound(err) { + return nil, fmt.Errorf("failed to get bootstrap secret: %w", err) + } + + if err == nil { + // Secret already exists — load the persisted keypair. + privPEM, ok := secret.Data[bootstrapPrivKeyKey] + if !ok { + return nil, fmt.Errorf("bootstrap secret %s/%s is missing key %q", + namespace, bootstrapSecretName, bootstrapPrivKeyKey) + } + pubPEM, ok := secret.Data[bootstrapPubKeyKey] + if !ok { + return nil, fmt.Errorf("bootstrap secret %s/%s is missing key %q", + namespace, bootstrapSecretName, bootstrapPubKeyKey) + } + if err := m.loadPrivKeyPEM(privPEM); err != nil { + return nil, fmt.Errorf("failed to parse bootstrap private key from secret: %w", err) + } + m.publicKeyPEM = string(pubPEM) + klog.Infof("Loaded bootstrap keypair from existing secret %s/%s", namespace, bootstrapSecretName) + return m, nil + } + + // Secret does not exist — generate a new keypair and persist it. + privKey, err := rsa.GenerateKey(rand.Reader, bootstrapKeySize) + if err != nil { + return nil, fmt.Errorf("failed to generate bootstrap RSA key: %w", err) + } + m.privateKey = privKey + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal bootstrap public key: %w", err) + } + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKeyBytes}) + m.publicKeyPEM = string(pubPEM) + + privPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privKey), + }) + + newSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: bootstrapSecretName, + Namespace: namespace, + Labels: map[string]string{"app": "agentcube", "component": "workload-manager"}, + }, + Type: corev1.SecretTypeOpaque, + Data: map[string][]byte{ + bootstrapPrivKeyKey: privPEM, + bootstrapPubKeyKey: pubPEM, + }, + } + if _, createErr := clientset.CoreV1().Secrets(namespace).Create(ctx, newSecret, metav1.CreateOptions{}); createErr != nil { + // Another replica won the race and created the secret first — load from it. + if apierrors.IsAlreadyExists(createErr) { + return NewBootstrapAuthManager(ctx, clientset, namespace) + } + return nil, fmt.Errorf("failed to create bootstrap secret: %w", createErr) + } + + klog.Infof("Generated new bootstrap keypair and persisted to secret %s/%s", namespace, bootstrapSecretName) + return m, nil +} + + +// PublicKeyPEM returns the bootstrap public key in PEM format. +// Inject this value as PICOD_BOOTSTRAP_PUBLIC_KEY into PicoD container +// environments so PicoD can verify /init JWTs at startup. +func (m *BootstrapAuthManager) PublicKeyPEM() string { + return m.publicKeyPEM +} + + +// GenerateInitJWT creates a short-lived JWT signed by the bootstrap private key. +// The JWT carries the session_public_key claim so PicoD can store it and use it +// to verify all subsequent user-request JWTs for this sandbox session. +// The "sub" claim is set to sessionID so PicoD can bind the key to the correct session. +func (m *BootstrapAuthManager) GenerateInitJWT(sessionID, sessionPubPEM string) (string, error) { + claims := jwt.MapClaims{ + "iss": bootstrapJWTIssuer, + "sub": sessionID, + "exp": jwt.NewNumericDate(time.Now().Add(bootstrapJWTExpiration)), + "iat": jwt.NewNumericDate(time.Now()), + "session_public_key": sessionPubPEM, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signed, err := token.SignedString(m.privateKey) + if err != nil { + return "", fmt.Errorf("failed to sign init JWT: %w", err) + } + return signed, nil +} + +// loadPrivKeyPEM parses a PKCS#1 PEM-encoded RSA private key into m.privateKey. +func (m *BootstrapAuthManager) loadPrivKeyPEM(data []byte) error { + block, _ := pem.Decode(data) + if block == nil { + return fmt.Errorf("failed to decode PEM block from bootstrap private key data") + } + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse PKCS1 private key: %w", err) + } + m.privateKey = key + return nil +} + +// GenerateSessionKeyPair generates a unique 2048-bit RSA keypair for one sandbox session. +// The private key is held only in Router memory and is never persisted to the KV store. +// The public key is delivered to PicoD via the /init JWT so PicoD can verify +// all subsequent user-request JWTs signed by the Router for this session. +func GenerateSessionKeyPair() (privPEM string, pubPEM string, err error) { + key, err := rsa.GenerateKey(rand.Reader, bootstrapKeySize) + if err != nil { + return "", "", fmt.Errorf("failed to generate session RSA key: %w", err) + } + + privBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + pubBytes, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + if err != nil { + return "", "", fmt.Errorf("failed to marshal session public key: %w", err) + } + pubPEMBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}) + + return string(privBytes), string(pubPEMBytes), nil +} \ No newline at end of file diff --git a/pkg/workloadmanager/bootstrap_auth_test.go b/pkg/workloadmanager/bootstrap_auth_test.go new file mode 100644 index 00000000..8e21bb87 --- /dev/null +++ b/pkg/workloadmanager/bootstrap_auth_test.go @@ -0,0 +1,89 @@ +/* +Copyright The Volcano Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package workloadmanager + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateSessionKeyPair(t *testing.T) { + privPEM, pubPEM, err := GenerateSessionKeyPair() + require.NoError(t, err) + assert.NotEmpty(t, privPEM) + assert.NotEmpty(t, pubPEM) + + // Verify private key parses + blockPriv, _ := pem.Decode([]byte(privPEM)) + require.NotNil(t, blockPriv) + assert.Equal(t, "RSA PRIVATE KEY", blockPriv.Type) + + privKey, err := x509.ParsePKCS1PrivateKey(blockPriv.Bytes) + require.NoError(t, err) + assert.Equal(t, 2048, privKey.N.BitLen()) + + // Verify public key parses + blockPub, _ := pem.Decode([]byte(pubPEM)) + require.NotNil(t, blockPub) + assert.Equal(t, "PUBLIC KEY", blockPub.Type) + + pubKey, err := x509.ParsePKIXPublicKey(blockPub.Bytes) + require.NoError(t, err) + _, ok := pubKey.(*rsa.PublicKey) + assert.True(t, ok) +} + +func TestGenerateInitJWT(t *testing.T) { + m := newTestBootstrapAuth(t) + + sandboxID := "test-session-id" + sessionPublicKey := "test-session-public-key" + + tokenStr, err := m.GenerateInitJWT(sandboxID, sessionPublicKey) + require.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + // Verify the JWT using the generated Bootstrap Public Key + bootstrapPubPEM := m.PublicKeyPEM() + block, _ := pem.Decode([]byte(bootstrapPubPEM)) + require.NotNil(t, block) + + pubKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + require.NoError(t, err) + bootstrapPubKey, ok := pubKeyInterface.(*rsa.PublicKey) + require.True(t, ok) + + token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, assert.AnError + } + return bootstrapPubKey, nil + }) + require.NoError(t, err) + assert.True(t, token.Valid) + + claims, ok := token.Claims.(jwt.MapClaims) + assert.True(t, ok) + assert.Equal(t, sandboxID, claims["sub"]) + assert.Equal(t, sessionPublicKey, claims["session_public_key"]) +} diff --git a/pkg/workloadmanager/codeinterpreter_controller.go b/pkg/workloadmanager/codeinterpreter_controller.go index da010cbb..365dfc59 100644 --- a/pkg/workloadmanager/codeinterpreter_controller.go +++ b/pkg/workloadmanager/codeinterpreter_controller.go @@ -43,7 +43,8 @@ import ( // CodeInterpreterReconciler reconciles a CodeInterpreter object type CodeInterpreterReconciler struct { client.Client - Scheme *runtime.Scheme + Scheme *runtime.Scheme + BootstrapPublicKeyPEM string } // Reconcile is part of the main kubernetes reconciliation loop which aims to @@ -300,8 +301,8 @@ func (r *CodeInterpreterReconciler) convertToPodTemplate(template *runtimev1alph // Only inject public key for picod auth mode (default behavior) if ci.Spec.AuthMode != runtimev1alpha1.AuthModeNone { envVars = append(envVars, corev1.EnvVar{ - Name: "PICOD_AUTH_PUBLIC_KEY", - Value: GetCachedPublicKey(), + Name: "PICOD_BOOTSTRAP_PUBLIC_KEY", + Value: r.BootstrapPublicKeyPEM, }) } diff --git a/pkg/workloadmanager/codeinterpreter_controller_test.go b/pkg/workloadmanager/codeinterpreter_controller_test.go index 1404c861..9a5145be 100644 --- a/pkg/workloadmanager/codeinterpreter_controller_test.go +++ b/pkg/workloadmanager/codeinterpreter_controller_test.go @@ -164,16 +164,16 @@ func TestConvertToPodTemplate_AuthMode(t *testing.T) { foundPublicKey := false for _, env := range envVars { - if env.Name == "PICOD_AUTH_PUBLIC_KEY" { + if env.Name == "PICOD_BOOTSTRAP_PUBLIC_KEY" { foundPublicKey = true break } } if tt.expectPublicKeyVar { - assert.True(t, foundPublicKey, "PICOD_AUTH_PUBLIC_KEY should be injected") + assert.True(t, foundPublicKey, "PICOD_BOOTSTRAP_PUBLIC_KEY should be injected") } else { - assert.False(t, foundPublicKey, "PICOD_AUTH_PUBLIC_KEY should not be injected") + assert.False(t, foundPublicKey, "PICOD_BOOTSTRAP_PUBLIC_KEY should not be injected") } }) } diff --git a/pkg/workloadmanager/garbage_collection_test.go b/pkg/workloadmanager/garbage_collection_test.go index 531e9908..33462c18 100644 --- a/pkg/workloadmanager/garbage_collection_test.go +++ b/pkg/workloadmanager/garbage_collection_test.go @@ -52,6 +52,12 @@ func (nopStore) ListInactiveSandboxes(_ context.Context, _ time.Time, _ int64) ( func (nopStore) UpdateSessionLastActivity(_ context.Context, _ string, _ time.Time) error { return nil } +func (nopStore) StoreSessionPrivateKey(_ context.Context, _ string, _ string) error { + return nil +} +func (nopStore) GetSessionPrivateKey(_ context.Context, _ string) (string, error) { + return "", nil +} func (nopStore) Close() error { return nil } // gcFakeStore is a controllable store for GC tests. diff --git a/pkg/workloadmanager/handlers.go b/pkg/workloadmanager/handlers.go index 7d417c6e..919916e3 100644 --- a/pkg/workloadmanager/handlers.go +++ b/pkg/workloadmanager/handlers.go @@ -107,7 +107,7 @@ func (s *Server) handleSandboxCreate(c *gin.Context, kind string) { case types.AgentRuntimeKind: sandbox, sandboxEntry, err = buildSandboxByAgentRuntime(sandboxReq.Namespace, sandboxReq.Name, s.informers) case types.CodeInterpreterKind: - sandbox, sandboxClaim, sandboxEntry, err = buildSandboxByCodeInterpreter(sandboxReq.Namespace, sandboxReq.Name, s.informers) + sandbox, sandboxClaim, sandboxEntry, err = buildSandboxByCodeInterpreter(sandboxReq.Namespace, sandboxReq.Name, s.informers, s.GetBootstrapPublicKeyPEM()) } if err != nil { @@ -195,14 +195,21 @@ func (s *Server) createK8sResources(ctx context.Context, dynamicClient dynamic.I return nil } +// wrapInternalError returns err unchanged if it is a context error; otherwise +// it wraps msg into an API internal error. This collapses the repetitive +// isContextError guard that appears in many call sites. +func wrapInternalError(err error, msg string) error { + if isContextError(err) { + return err + } + return api.NewInternalError(fmt.Errorf("%s: %w", msg, err)) +} + // createSandbox performs sandbox creation and returns the response payload or an error with an HTTP status code. func (s *Server) createSandbox(ctx context.Context, dynamicClient dynamic.Interface, sandbox *sandboxv1alpha1.Sandbox, sandboxClaim *extensionsv1alpha1.SandboxClaim, sandboxEntry *sandboxEntry, resultChan <-chan SandboxStatusUpdate) (*types.CreateSandboxResponse, error) { placeholder := buildSandboxPlaceHolder(sandbox, sandboxEntry) if err := s.storeClient.StoreSandbox(ctx, placeholder); err != nil { - if isContextError(err) { - return nil, err - } - return nil, api.NewInternalError(fmt.Errorf("store sandbox placeholder failed: %w", err)) + return nil, wrapInternalError(err, "store sandbox placeholder failed") } // Register rollback right after the placeholder is stored so that a K8s @@ -219,16 +226,54 @@ func (s *Server) createSandbox(ctx context.Context, dynamicClient dynamic.Interf return nil, err } + createdSandbox, err := s.waitForSandboxReady(ctx, sandbox, resultChan) + if err != nil { + return nil, err + } + + podIP, err := s.prepareSandbox(ctx, sandbox, createdSandbox, sandboxEntry) + if err != nil { + return nil, err + } + + if sandboxEntry.SessionPrivateKey != "" { + if err := s.storeClient.StoreSessionPrivateKey(ctx, sandboxEntry.SessionID, sandboxEntry.SessionPrivateKey); err != nil { + return nil, wrapInternalError(err, "failed to persist session private key") + } + } + + storeCacheInfo := buildSandboxInfo(createdSandbox, podIP, sandboxEntry) + + response := &types.CreateSandboxResponse{ + Kind: storeCacheInfo.Kind, + SessionID: sandboxEntry.SessionID, + SandboxID: storeCacheInfo.SandboxID, + SandboxName: sandbox.Name, + EntryPoints: storeCacheInfo.EntryPoints, + } + + if err := s.storeClient.UpdateSandbox(ctx, storeCacheInfo); err != nil { + return nil, wrapInternalError(err, "update store cache failed") + } + + needRollbackSandbox = false + klog.V(2).Infof("init sandbox %s/%s successfully, kind: %s, sessionID: %s", createdSandbox.Namespace, + createdSandbox.Name, createdSandbox.Kind, sandboxEntry.SessionID) + return response, nil +} + +// waitForSandboxReady blocks until the sandbox reports ready, the context is +// canceled, or the creation timeout fires. +func (s *Server) waitForSandboxReady(ctx context.Context, sandbox *sandboxv1alpha1.Sandbox, resultChan <-chan SandboxStatusUpdate) (*sandboxv1alpha1.Sandbox, error) { // Use NewTimer so we can stop it explicitly when another branch wins, // preventing the runtime from retaining the timer until it fires. timer := time.NewTimer(2 * time.Minute) // consistent with router settings - var createdSandbox *sandboxv1alpha1.Sandbox select { case result := <-resultChan: timer.Stop() - createdSandbox = result.Sandbox - klog.V(2).Infof("sandbox %s/%s reported ready, verifying entrypoints", createdSandbox.Namespace, createdSandbox.Name) + klog.V(2).Infof("sandbox %s/%s reported ready, verifying entrypoints", result.Sandbox.Namespace, result.Sandbox.Name) + return result.Sandbox, nil case <-ctx.Done(): timer.Stop() klog.Warningf("sandbox %s/%s wait canceled: %v", sandbox.Namespace, sandbox.Name, ctx.Err()) @@ -237,10 +282,13 @@ func (s *Server) createSandbox(ctx context.Context, dynamicClient dynamic.Interf klog.Warningf("sandbox %s/%s create timed out", sandbox.Namespace, sandbox.Name) return nil, errSandboxCreationTimeout } +} - // agent-sandbox create pod with same name as sandbox if no warmpool is used - // so here we try to get pod IP by sandbox name first - // if warmpool is used, the pod name is stored in sandbox's annotation `agents.x-k8s.io/sandbox-pod-name` +// prepareSandbox resolves the pod IP, probes entrypoints, and initializes PicoD +// for a sandbox that has been reported ready by the controller. +func (s *Server) prepareSandbox(ctx context.Context, sandbox *sandboxv1alpha1.Sandbox, createdSandbox *sandboxv1alpha1.Sandbox, sandboxEntry *sandboxEntry) (string, error) { + // agent-sandbox creates pod with same name as sandbox if no warmpool is used. + // If warmpool is used, the pod name is stored in the sandbox's annotation. // https://github.com/kubernetes-sigs/agent-sandbox/blob/3ab7fbcd85ad0d75c6e632ecd14bcaeda5e76e1e/controllers/sandbox_controller.go#L465 sandboxPodName := sandbox.Name if podName, exists := createdSandbox.Annotations[controllers.SandboxPodNameAnnotation]; exists { @@ -249,39 +297,16 @@ func (s *Server) createSandbox(ctx context.Context, dynamicClient dynamic.Interf podIP, err := s.k8sClient.GetSandboxPodIP(ctx, sandbox.Namespace, sandbox.Name, sandboxPodName) if err != nil { - if isContextError(err) { - return nil, err - } - return nil, api.NewInternalError(fmt.Errorf("failed to get sandbox %s/%s pod IP: %w", sandbox.Namespace, sandbox.Name, err)) + return "", wrapInternalError(err, fmt.Sprintf("failed to get sandbox %s/%s pod IP", sandbox.Namespace, sandbox.Name)) } if err := s.waitForSandboxEntryPointsReady(ctx, podIP, sandboxEntry); err != nil { - if isContextError(err) { - return nil, err - } - return nil, api.NewInternalError(fmt.Errorf("failed to verify sandbox %s/%s entrypoints: %w", sandbox.Namespace, sandbox.Name, err)) + return "", wrapInternalError(err, fmt.Sprintf("failed to verify sandbox %s/%s entrypoints", sandbox.Namespace, sandbox.Name)) } - - storeCacheInfo := buildSandboxInfo(createdSandbox, podIP, sandboxEntry) - - response := &types.CreateSandboxResponse{ - Kind: storeCacheInfo.Kind, - SessionID: sandboxEntry.SessionID, - SandboxID: storeCacheInfo.SandboxID, - SandboxName: sandbox.Name, - EntryPoints: storeCacheInfo.EntryPoints, + if err := s.initializePicoD(ctx, podIP, sandboxEntry); err != nil { + return "", wrapInternalError(err, fmt.Sprintf("failed to initialize PicoD on sandbox %s/%s", sandbox.Namespace, sandbox.Name)) } - if err := s.storeClient.UpdateSandbox(ctx, storeCacheInfo); err != nil { - if isContextError(err) { - return nil, err - } - return nil, api.NewInternalError(fmt.Errorf("update store cache failed: %w", err)) - } - - needRollbackSandbox = false - klog.V(2).Infof("init sandbox %s/%s successfully, kind: %s, sessionID: %s", createdSandbox.Namespace, - createdSandbox.Name, createdSandbox.Kind, sandboxEntry.SessionID) - return response, nil + return podIP, nil } // rollbackSandboxCreation deletes the sandbox (or sandbox claim) and its store diff --git a/pkg/workloadmanager/handlers_test.go b/pkg/workloadmanager/handlers_test.go index a71aa124..42bc6b7b 100644 --- a/pkg/workloadmanager/handlers_test.go +++ b/pkg/workloadmanager/handlers_test.go @@ -72,6 +72,12 @@ func (f *fakeStore) ListInactiveSandboxes(_ context.Context, _ time.Time, _ int6 func (f *fakeStore) UpdateSessionLastActivity(_ context.Context, _ string, _ time.Time) error { return nil } +func (f *fakeStore) StoreSessionPrivateKey(_ context.Context, _ string, _ string) error { + return nil +} +func (f *fakeStore) GetSessionPrivateKey(_ context.Context, _ string) (string, error) { + return "", nil +} func (f *fakeStore) Close() error { return nil } func readySandbox() *sandboxv1alpha1.Sandbox { @@ -431,7 +437,7 @@ func TestHandleSandboxCreate(t *testing.T) { return sb, entry, nil }) - patches.ApplyFunc(buildSandboxByCodeInterpreter, func(_, _ string, _ *Informers) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxEntry, error) { + patches.ApplyFunc(buildSandboxByCodeInterpreter, func(_, _ string, _ *Informers, _ string) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxEntry, error) { if tc.kind != types.CodeInterpreterKind { return nil, nil, nil, errors.New("unexpected kind") } diff --git a/pkg/workloadmanager/k8s_client.go b/pkg/workloadmanager/k8s_client.go index 15cc47bd..3917cd16 100644 --- a/pkg/workloadmanager/k8s_client.go +++ b/pkg/workloadmanager/k8s_client.go @@ -74,10 +74,12 @@ type K8sClient struct { } type sandboxEntry struct { - Kind string - SessionID string - Ports []runtimev1alpha1.TargetPort - IdleTimeout time.Duration + Kind string + SessionID string + Ports []runtimev1alpha1.TargetPort + IdleTimeout time.Duration + AuthMode runtimev1alpha1.AuthModeType + SessionPrivateKey string } // NewK8sClient creates a new Kubernetes client diff --git a/pkg/workloadmanager/sandbox_helper.go b/pkg/workloadmanager/sandbox_helper.go index 322b5d8d..b2e7c206 100644 --- a/pkg/workloadmanager/sandbox_helper.go +++ b/pkg/workloadmanager/sandbox_helper.go @@ -17,12 +17,17 @@ limitations under the License. package workloadmanager import ( + "bytes" "context" + "encoding/json" "fmt" "net" + "net/http" "strconv" "time" + "k8s.io/klog/v2" + runtimev1alpha1 "github.com/volcano-sh/agentcube/pkg/apis/runtime/v1alpha1" "github.com/volcano-sh/agentcube/pkg/common/types" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -33,6 +38,7 @@ const ( defaultSandboxReadyProbeTimeout = 15 * time.Second defaultSandboxReadyProbeInterval = 1 * time.Second defaultSandboxReadyDialTimeout = 1 * time.Second + defaultPicoInitTimeout = 5 * time.Second sandboxStatusReady = "ready" sandboxStatusNotReady = "not-ready" @@ -88,16 +94,17 @@ func buildSandboxInfo(sandbox *sandboxv1alpha1.Sandbox, podIP string, entry *san idleTimeout = DefaultSandboxIdleTimeout } return &types.SandboxInfo{ - Kind: entry.Kind, - SandboxID: string(sandbox.GetUID()), - Name: sandbox.GetName(), - SandboxNamespace: sandbox.GetNamespace(), - EntryPoints: accesses, - SessionID: entry.SessionID, - CreatedAt: createdAt, - ExpiresAt: expiresAt, - Status: getSandboxStatus(sandbox), - IdleTimeout: metav1.Duration{Duration: idleTimeout}, + Kind: entry.Kind, + SandboxID: string(sandbox.GetUID()), + Name: sandbox.GetName(), + SandboxNamespace: sandbox.GetNamespace(), + EntryPoints: accesses, + SessionID: entry.SessionID, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + Status: getSandboxStatus(sandbox), + IdleTimeout: metav1.Duration{Duration: idleTimeout}, + SessionPrivateKey: entry.SessionPrivateKey, } } @@ -161,3 +168,71 @@ func probeSandboxEntryPoints(ctx context.Context, podIP string, ports []runtimev return nil } + + +// picodInitPortName is the well-known port name used for PicoD's management HTTP API. +// The /init call is always sent here, regardless of what other ports are present. +const picodInitPortName = "picod" + +// findPicoDInitPort returns the port to use for the /init call. +// It prefers a port whose name matches picodInitPortName; falls back to port 8080. +func findPicoDInitPort(ports []runtimev1alpha1.TargetPort) uint32 { + for _, p := range ports { + if p.Name == picodInitPortName { + return p.Port + } + } + // Fallback: use 8080. The absence of a named port is a misconfiguration — + // log a warning so operators can fix the runtime spec. + klog.Warningf("no port named %q found in sandbox entry; using fallback port 8080 for /init", picodInitPortName) + return 8080 +} + +func (s *Server) initializePicoD(ctx context.Context, podIP string, entry *sandboxEntry) error { + if entry == nil || entry.AuthMode != runtimev1alpha1.AuthModePicoD { + return nil + } + + port := findPicoDInitPort(entry.Ports) + endpoint := fmt.Sprintf("http://%s:%d/init", podIP, port) + + privPEM, pubPEM, err := GenerateSessionKeyPair() + if err != nil { + return fmt.Errorf("failed to generate session key pair: %w", err) + } + + // Use the struct-based manager (not a global function) so tests can inject + // an isolated BootstrapAuthManager instance. + token, err := s.bootstrapAuth.GenerateInitJWT(entry.SessionID, pubPEM) + if err != nil { + return fmt.Errorf("failed to generate init JWT: %w", err) + } + + payload := map[string]string{"token": token} + bodyBytes, err := json.Marshal(payload) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.httpClient.Do(req) + if err != nil { + return fmt.Errorf("POST /init failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("POST /init returned status: %s", resp.Status) + } + + // Assign SessionPrivateKey ONLY after confirmed 200 OK. + // This prevents a partial-init state where the Router holds a key + // that PicoD never accepted. + entry.SessionPrivateKey = privPEM + return nil +} diff --git a/pkg/workloadmanager/sandbox_helper_test.go b/pkg/workloadmanager/sandbox_helper_test.go index 5106678b..8a927ff9 100644 --- a/pkg/workloadmanager/sandbox_helper_test.go +++ b/pkg/workloadmanager/sandbox_helper_test.go @@ -17,437 +17,124 @@ limitations under the License. package workloadmanager import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/http/httptest" "testing" - "time" "github.com/stretchr/testify/assert" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - sandboxv1alpha1 "sigs.k8s.io/agent-sandbox/api/v1alpha1" - + "github.com/stretchr/testify/require" runtimev1alpha1 "github.com/volcano-sh/agentcube/pkg/apis/runtime/v1alpha1" - "github.com/volcano-sh/agentcube/pkg/common/types" ) -const sandboxHelperTestPodIP = "10.0.0.1" - -func TestBuildSandboxPlaceHolder_TableDriven(t *testing.T) { - now := time.Now() +// newTestBootstrapAuth creates an in-memory BootstrapAuthManager without K8s, +// suitable for unit tests. +func newTestBootstrapAuth(t *testing.T) *BootstrapAuthManager { + t.Helper() + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + require.NoError(t, err) + pubPEM := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubKeyBytes}) + + return &BootstrapAuthManager{ + privateKey: privKey, + publicKeyPEM: string(pubPEM), + namespace: "test", + } +} - tests := []struct { - name string - setupSandbox func() *sandboxv1alpha1.Sandbox - entry *sandboxEntry - validate func(t *testing.T, result *types.SandboxInfo) - }{ - { - name: "no ShutdownTime falls back to DefaultSandboxTTL", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - }, - } - }, - entry: &sandboxEntry{ - Kind: types.SandboxKind, - SessionID: "session-123", - }, - validate: func(t *testing.T, result *types.SandboxInfo) { - expected := now.Add(DefaultSandboxTTL) - assert.WithinDuration(t, expected, result.ExpiresAt, 2*time.Second) - assert.Equal(t, "creating", result.Status) - assert.Equal(t, "session-123", result.SessionID) - }, - }, - { - name: "ShutdownTime set to 24h is used as ExpiresAt", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - shutdownTime := now.Add(24 * time.Hour) - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - }, - Spec: sandboxv1alpha1.SandboxSpec{ - Lifecycle: sandboxv1alpha1.Lifecycle{ - ShutdownTime: &metav1.Time{Time: shutdownTime}, - }, - }, - } - }, - entry: &sandboxEntry{ - Kind: types.SandboxKind, - SessionID: "session-456", - }, - validate: func(t *testing.T, result *types.SandboxInfo) { - expected := now.Add(24 * time.Hour) - assert.Equal(t, expected, result.ExpiresAt) - }, - }, - { - name: "ShutdownTime set to 30m overrides DefaultSandboxTTL", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - shutdownTime := now.Add(30 * time.Minute) - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "short-sandbox", - Namespace: "default", - }, - Spec: sandboxv1alpha1.SandboxSpec{ - Lifecycle: sandboxv1alpha1.Lifecycle{ - ShutdownTime: &metav1.Time{Time: shutdownTime}, - }, - }, - } - }, - entry: &sandboxEntry{ - Kind: types.SandboxClaimsKind, - SessionID: "session-789", - }, - validate: func(t *testing.T, result *types.SandboxInfo) { - expected := now.Add(30 * time.Minute) - assert.Equal(t, expected, result.ExpiresAt) - // Must NOT be 8h (DefaultSandboxTTL) - assert.True(t, result.ExpiresAt.Before(now.Add(DefaultSandboxTTL)), - "ExpiresAt should be 30m, not the 8h default") - }, - }, - { - name: "warm-pool path: ShutdownTime set on simpleSandbox reflects MaxSessionDuration", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - // Simulates the simpleSandbox built by the warm-pool CodeInterpreter path - // after the fix in workload_builder.go sets ShutdownTime from MaxSessionDuration. - shutdownTime := now.Add(24 * time.Hour) - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: "default", - Name: "ci-warmpool-abc", - Labels: map[string]string{ - SessionIdLabelKey: "session-wp-001", - }, - }, - Spec: sandboxv1alpha1.SandboxSpec{ - Lifecycle: sandboxv1alpha1.Lifecycle{ - ShutdownTime: &metav1.Time{Time: shutdownTime}, - }, - }, - } - }, - entry: &sandboxEntry{ - Kind: types.SandboxClaimsKind, - SessionID: "session-wp-001", - }, - validate: func(t *testing.T, result *types.SandboxInfo) { - expected := now.Add(24 * time.Hour) - assert.Equal(t, expected, result.ExpiresAt, - "warm-pool placeholder ExpiresAt must reflect MaxSessionDuration, not the 8h default") - assert.Equal(t, "creating", result.Status) - assert.Equal(t, types.SandboxClaimsKind, result.Kind) - }, +func TestInitializePicoD_SessionKeyOnlySetOnSuccess(t *testing.T) { + // Arrange: mock PicoD server that returns 200 OK + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/init", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + var body map[string]string + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.NotEmpty(t, body["token"]) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + entry := &sandboxEntry{ + SessionID: "sess-1", + AuthMode: runtimev1alpha1.AuthModePicoD, + Ports: []runtimev1alpha1.TargetPort{ + {Name: "picod", Port: tsPort(t, ts.URL)}, }, } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sandbox := tt.setupSandbox() - result := buildSandboxPlaceHolder(sandbox, tt.entry) - tt.validate(t, result) - }) + s := &Server{ + httpClient: ts.Client(), + bootstrapAuth: newTestBootstrapAuth(t), } + + err := s.initializePicoD(context.Background(), "127.0.0.1", entry) + require.NoError(t, err) + assert.NotEmpty(t, entry.SessionPrivateKey, "key must be set after 200 OK") } -func TestBuildSandboxInfo_TableDriven(t *testing.T) { - now := time.Now() +func TestInitializePicoD_SessionKeyNotSetOnFailure(t *testing.T) { + // Arrange: mock PicoD server that returns 500 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() - tests := []struct { - name string - setupSandbox func() *sandboxv1alpha1.Sandbox - podIP string - entry *sandboxEntry - validateResult func(t *testing.T, result *types.SandboxInfo) - }{ - { - name: "basic sandbox with ports", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - UID: "test-uid-123", - CreationTimestamp: metav1.NewTime(now), - }, - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - } - }, - podIP: sandboxHelperTestPodIP, - entry: &sandboxEntry{ - Kind: types.AgentRuntimeKind, - SessionID: "test-session-123", - Ports: []runtimev1alpha1.TargetPort{ - { - Port: 8080, - Protocol: runtimev1alpha1.ProtocolTypeHTTP, - PathPrefix: "/api", - }, - { - Port: 9090, - Protocol: runtimev1alpha1.ProtocolTypeHTTP, - PathPrefix: "/metrics", - }, - }, - }, - validateResult: func(t *testing.T, result *types.SandboxInfo) { - assert.Equal(t, "ready", result.Status) - assert.Len(t, result.EntryPoints, 2) - assert.Equal(t, "/api", result.EntryPoints[0].Path) - assert.Equal(t, sandboxHelperTestPodIP+":8080", result.EntryPoints[0].Endpoint) - assert.Equal(t, "/metrics", result.EntryPoints[1].Path) - assert.Equal(t, sandboxHelperTestPodIP+":9090", result.EntryPoints[1].Endpoint) - }, - }, - { - name: "sandbox with shutdown time", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - shutdownTime := now.Add(2 * time.Hour) - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - UID: "test-uid-123", - CreationTimestamp: metav1.NewTime(now), - }, - Spec: sandboxv1alpha1.SandboxSpec{ - Lifecycle: sandboxv1alpha1.Lifecycle{ - ShutdownTime: &metav1.Time{Time: shutdownTime}, - }, - }, - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - } - }, - podIP: sandboxHelperTestPodIP, - entry: &sandboxEntry{ - Kind: types.AgentRuntimeKind, - SessionID: "test-session-123", - Ports: []runtimev1alpha1.TargetPort{}, - }, - validateResult: func(t *testing.T, result *types.SandboxInfo) { - // ShutdownTime is now + 2h in setupSandbox - expectedShutdown := now.Add(2 * time.Hour) - assert.WithinDuration(t, expectedShutdown, result.ExpiresAt, 1*time.Second) - }, - }, - { - name: "sandbox with no ports", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - UID: "test-uid-123", - CreationTimestamp: metav1.NewTime(now), - }, - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - } - }, - podIP: sandboxHelperTestPodIP, - entry: &sandboxEntry{ - Kind: types.AgentRuntimeKind, - SessionID: "test-session-123", - Ports: []runtimev1alpha1.TargetPort{}, - }, - validateResult: func(t *testing.T, result *types.SandboxInfo) { - assert.Empty(t, result.EntryPoints) - }, - }, - { - name: "sandbox with empty pod IP", - setupSandbox: func() *sandboxv1alpha1.Sandbox { - return &sandboxv1alpha1.Sandbox{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-sandbox", - Namespace: "default", - UID: "test-uid-123", - CreationTimestamp: metav1.NewTime(now), - }, - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - } - }, - podIP: "", - entry: &sandboxEntry{ - Kind: types.AgentRuntimeKind, - SessionID: "test-session-123", - Ports: []runtimev1alpha1.TargetPort{ - { - Port: 8080, - Protocol: runtimev1alpha1.ProtocolTypeHTTP, - PathPrefix: "/api", - }, - }, - }, - validateResult: func(t *testing.T, result *types.SandboxInfo) { - assert.Equal(t, ":8080", result.EntryPoints[0].Endpoint) - }, + entry := &sandboxEntry{ + SessionID: "sess-2", + AuthMode: runtimev1alpha1.AuthModePicoD, + Ports: []runtimev1alpha1.TargetPort{ + {Name: "picod", Port: tsPort(t, ts.URL)}, }, } + s := &Server{ + httpClient: ts.Client(), + bootstrapAuth: newTestBootstrapAuth(t), + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sandbox := tt.setupSandbox() - result := buildSandboxInfo(sandbox, tt.podIP, tt.entry) - tt.validateResult(t, result) - }) + err := s.initializePicoD(context.Background(), "127.0.0.1", entry) + require.Error(t, err) + assert.Empty(t, entry.SessionPrivateKey, + "key must NOT be set when /init returns a non-200 status") +} + +func TestInitializePicoD_SkipsNonPicoDMode(t *testing.T) { + entry := &sandboxEntry{ + SessionID: "sess-3", + AuthMode: runtimev1alpha1.AuthModeNone, } + s := &Server{bootstrapAuth: newTestBootstrapAuth(t)} + err := s.initializePicoD(context.Background(), "127.0.0.1", entry) + require.NoError(t, err) + assert.Empty(t, entry.SessionPrivateKey) } -func TestGetSandboxStatus_TableDriven(t *testing.T) { - tests := []struct { - name string - sandbox *sandboxv1alpha1.Sandbox - expected string - }{ - { - name: "ready condition true", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - }, - expected: "ready", - }, - { - name: "ready condition false without reason", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionFalse, - }, - }, - }, - }, - expected: "not-ready", - }, - { - name: "ready condition false with reason is not-ready", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionFalse, - Reason: "ErrImagePull", - Message: "Back-off pulling image", - }, - }, - }, - }, - expected: "not-ready", - }, - { - name: "ready condition unknown", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionUnknown, - }, - }, - }, - }, - expected: "not-ready", - }, - { - name: "no conditions", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{}, - }, - }, - expected: "not-ready", - }, - { - name: "nil conditions", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: nil, - }, - }, - expected: "not-ready", - }, - { - name: "other condition type", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: "OtherCondition", - Status: metav1.ConditionTrue, - }, - }, - }, - }, - expected: "not-ready", - }, - { - name: "multiple conditions with ready true", - sandbox: &sandboxv1alpha1.Sandbox{ - Status: sandboxv1alpha1.SandboxStatus{ - Conditions: []metav1.Condition{ - { - Type: "OtherCondition", - Status: metav1.ConditionFalse, - }, - { - Type: string(sandboxv1alpha1.SandboxConditionReady), - Status: metav1.ConditionTrue, - }, - }, - }, - }, - expected: "ready", - }, +func TestFindPicoDInitPort_NamedPort(t *testing.T) { + ports := []runtimev1alpha1.TargetPort{ + {Name: "http", Port: 9000}, + {Name: "picod", Port: 8080}, } + assert.Equal(t, uint32(8080), findPicoDInitPort(ports)) +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := getSandboxStatus(tt.sandbox) - assert.Equal(t, tt.expected, result) - }) +func TestFindPicoDInitPort_Fallback(t *testing.T) { + ports := []runtimev1alpha1.TargetPort{ + {Name: "http", Port: 9000}, } + assert.Equal(t, uint32(8080), findPicoDInitPort(ports)) +} + +// tsPort extracts the port number from a httptest.Server URL string. +func tsPort(t *testing.T, rawURL string) uint32 { + t.Helper() + var port uint32 + _, err := fmt.Sscanf(rawURL, "http://127.0.0.1:%d", &port) + require.NoError(t, err) + return port } diff --git a/pkg/workloadmanager/server.go b/pkg/workloadmanager/server.go index 53b10dfb..8290730c 100644 --- a/pkg/workloadmanager/server.go +++ b/pkg/workloadmanager/server.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/http" + "os" "sync" "time" @@ -41,6 +42,8 @@ type Server struct { tokenCache *TokenCache informers *Informers storeClient store.Store + httpClient *http.Client + bootstrapAuth *BootstrapAuthManager wg sync.WaitGroup } @@ -62,6 +65,9 @@ type Config struct { SandboxReadyProbeTimeout time.Duration // SandboxReadyProbeInterval is the retry interval for sandbox entrypoint probes. SandboxReadyProbeInterval time.Duration + // PicoInitTimeout is the timeout for the POST /init call to PicoD during + // two-stage session initialization. + PicoInitTimeout time.Duration } // NewServer creates a new API server instance @@ -75,6 +81,9 @@ func NewServer(config *Config, sandboxController *SandboxReconciler) (*Server, e if config.SandboxReadyProbeInterval <= 0 { config.SandboxReadyProbeInterval = defaultSandboxReadyProbeInterval } + if config.PicoInitTimeout <= 0 { + config.PicoInitTimeout = defaultPicoInitTimeout + } // Create Kubernetes client k8sClient, err := NewK8sClient() @@ -82,6 +91,18 @@ func NewServer(config *Config, sandboxController *SandboxReconciler) (*Server, e return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) } + // Determine namespace + ns := "default" + if v := os.Getenv("AGENTCUBE_NAMESPACE"); v != "" { + ns = v + } + + // Persist-or-load the bootstrap keypair; survives WM restarts. + bootstrapAuth, err := NewBootstrapAuthManager(context.Background(), k8sClient.clientset, ns) + if err != nil { + return nil, fmt.Errorf("failed to initialize bootstrap auth manager: %w", err) + } + // Initialize public key cache from Router's Secret in background // This will retry until successful (handles case where Router isn't ready yet) InitPublicKeyCache(k8sClient.clientset) @@ -96,6 +117,8 @@ func NewServer(config *Config, sandboxController *SandboxReconciler) (*Server, e tokenCache: tokenCache, informers: NewInformers(k8sClient), storeClient: store.Storage(), + httpClient: &http.Client{Timeout: config.PicoInitTimeout}, + bootstrapAuth: bootstrapAuth, } // Setup routes @@ -206,6 +229,14 @@ func (s *Server) CloseStore() error { return nil } +// GetBootstrapPublicKeyPEM returns the bootstrap public key in PEM format. +func (s *Server) GetBootstrapPublicKeyPEM() string { + if s.bootstrapAuth == nil { + return "" + } + return s.bootstrapAuth.PublicKeyPEM() +} + // loggingMiddleware logs each request (except /health) func (s *Server) loggingMiddleware(c *gin.Context) { start := time.Now() diff --git a/pkg/workloadmanager/workload_builder.go b/pkg/workloadmanager/workload_builder.go index 4abe59de..5fc04168 100644 --- a/pkg/workloadmanager/workload_builder.go +++ b/pkg/workloadmanager/workload_builder.go @@ -297,25 +297,26 @@ func buildSandboxByAgentRuntime(namespace string, name string, ifm *Informers) ( Ports: agentRuntimeObj.Spec.Ports, SessionID: sessionID, IdleTimeout: idleTimeout, + AuthMode: runtimev1alpha1.AuthModeNone, // AgentRuntime doesn't explicitly have AuthMode defined in the CRD, but we default to None. } return sandbox, entry, nil } // buildCodeInterpreterEnvVars copies the template env vars and injects the // public key when authMode is picod. -func buildCodeInterpreterEnvVars(templateEnv []corev1.EnvVar, authMode runtimev1alpha1.AuthModeType) []corev1.EnvVar { +func buildCodeInterpreterEnvVars(templateEnv []corev1.EnvVar, authMode runtimev1alpha1.AuthModeType, bootstrapPubKey string) []corev1.EnvVar { envVars := make([]corev1.EnvVar, len(templateEnv)) copy(envVars, templateEnv) if authMode == runtimev1alpha1.AuthModePicoD { envVars = append(envVars, corev1.EnvVar{ - Name: "PICOD_AUTH_PUBLIC_KEY", - Value: GetCachedPublicKey(), + Name: "PICOD_BOOTSTRAP_PUBLIC_KEY", + Value: bootstrapPubKey, }) } return envVars } -func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, informer *Informers) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxEntry, error) { +func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, informer *Informers, bootstrapPubKey string) (*sandboxv1alpha1.Sandbox, *extensionsv1alpha1.SandboxClaim, *sandboxEntry, error) { codeInterpreterKey := namespace + "/" + codeInterpreterName // TODO(hzxuzhonghu): make use of typed informer, so we don't need to do type conversion below runtimeObj, exists, err := informer.CodeInterpreterInformer.GetStore().GetByKey(codeInterpreterKey) @@ -354,6 +355,7 @@ func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, Ports: codeInterpreterObj.Spec.Ports, SessionID: sessionID, IdleTimeout: idleTimeout, + AuthMode: codeInterpreterObj.Spec.AuthMode, } // Set default port for code interpreter if not configured @@ -404,7 +406,7 @@ func buildSandboxByCodeInterpreter(namespace string, codeInterpreterName string, runtimeClassName = nil } - envVars := buildCodeInterpreterEnvVars(codeInterpreterObj.Spec.Template.Environment, codeInterpreterObj.Spec.AuthMode) + envVars := buildCodeInterpreterEnvVars(codeInterpreterObj.Spec.Template.Environment, codeInterpreterObj.Spec.AuthMode, bootstrapPubKey) podSpec := corev1.PodSpec{ ImagePullSecrets: codeInterpreterObj.Spec.Template.ImagePullSecrets, diff --git a/test/e2e/run_e2e.sh b/test/e2e/run_e2e.sh index 3f6c053f..0eb4ffd7 100755 --- a/test/e2e/run_e2e.sh +++ b/test/e2e/run_e2e.sh @@ -347,6 +347,10 @@ run_setup() { --set-json "router.extraEnv=${ROUTER_EXTRA_ENV}" \ --wait + echo "Rollout restarting deployments to pick up new images..." + kubectl rollout restart deployment/workloadmanager -n "${AGENTCUBE_NAMESPACE}" + kubectl rollout restart deployment/agentcube-router -n "${AGENTCUBE_NAMESPACE}" + step "Waiting for deployments..." kubectl -n "${AGENTCUBE_NAMESPACE}" rollout status deployment/workloadmanager --timeout=300s kubectl -n "${AGENTCUBE_NAMESPACE}" rollout status deployment/agentcube-router --timeout=300s