Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions cmd/workload-manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,6 @@ func main() {
os.Exit(1)
}

sandboxReconciler := &workloadmanager.SandboxReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
}

codeInterpreterReconciler := &workloadmanager.CodeInterpreterReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
}

if err := setupControllers(mgr, sandboxReconciler, codeInterpreterReconciler); err != nil {
fmt.Fprintf(os.Stderr, "unable to setup controllers: %v\n", err)
os.Exit(1)
}

// Create API server configuration
config := &workloadmanager.Config{
Port: *port,
Expand All @@ -105,12 +90,28 @@ func main() {
EnableAuth: *enableAuth,
}

sandboxReconciler := &workloadmanager.SandboxReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
}

// Create and initialize API server
server, err := workloadmanager.NewServer(config, sandboxReconciler)
if err != nil {
klog.Fatalf("Failed to create API server: %v", err)
}

codeInterpreterReconciler := &workloadmanager.CodeInterpreterReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
BootstrapPublicKeyPEM: server.GetBootstrapPublicKeyPEM(),
}

if err := setupControllers(mgr, sandboxReconciler, codeInterpreterReconciler); err != nil {
fmt.Fprintf(os.Stderr, "unable to setup controllers: %v\n", err)
os.Exit(1)
}

// Setup signal handling
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
7 changes: 7 additions & 0 deletions docs/design/PicoD-Plain-Authentication-Design.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ However, emerging use cases require a more flexible architecture where the clien

The existing self-signed key-pair model is incompatible with this centralized management flow, as it bypasses the Router's ability to mediate access. To address this, we propose a new **Plain Authentication** mechanism for `picod`. This design enables the Router/Gateway to manage credentials and connection security, simplifying the client-side workflow while maintaining robust access control.

> [!WARNING]
> **Migration Note (Two-Stage Secure Initialization)**
>
> The flow described in this document was updated by PR #352 to address cross-sandbox token replay vulnerabilities. The original `PICOD_AUTH_PUBLIC_KEY` environment variable has been renamed to `PICOD_BOOTSTRAP_PUBLIC_KEY`.
>
> `PICOD_AUTH_PUBLIC_KEY` is still read as a fallback when `PICOD_BOOTSTRAP_PUBLIC_KEY` is not set, for backwards compatibility, but deployments should migrate to `PICOD_BOOTSTRAP_PUBLIC_KEY`. Under the new model, the bootstrap key is used only to verify the bootstrap payload during the `/init` handshake, which establishes a unique session key pair for subsequent requests.

## Use Cases

### Gateway-Managed Sandbox Access
Expand Down
4 changes: 4 additions & 0 deletions pkg/common/types/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type SandboxInfo struct {
// metav1.Duration marshals as a human-readable string (e.g. "15m0s") rather than
// a raw nanosecond integer, making the persisted JSON unambiguous.
IdleTimeout metav1.Duration `json:"idleTimeout,omitempty"`
// SessionPrivateKey is intentionally excluded from JSON serialization so it
// is never persisted to the KV store or exposed via list/get APIs.
// It is only populated transiently in the WM→Router HTTP response path.
SessionPrivateKey string `json:"-"`
// LastActivityAt is populated transiently from the store's last-activity sorted set
// during ListInactiveSandboxes. It is intentionally excluded from JSON serialization.
LastActivityAt time.Time `json:"-"`
Expand Down
139 changes: 117 additions & 22 deletions pkg/picod/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"os"
Expand All @@ -32,58 +33,152 @@ import (
"k8s.io/klog/v2"
)

// ErrAlreadyInitialized is the sentinel error reported when a PicoD session
// key is registered a second time. Compare against it with errors.Is.
var ErrAlreadyInitialized = errors.New("session has already been initialized")

const (
// PublicKeyEnvVar is the environment variable name for the public key
PublicKeyEnvVar = "PICOD_AUTH_PUBLIC_KEY"

// BootstrapPublicKeyEnvVar is the environment variable name for the bootstrap public key
BootstrapPublicKeyEnvVar = "PICOD_BOOTSTRAP_PUBLIC_KEY"
// LegacyBootstrapPublicKeyEnvVar is the deprecated environment variable name for the bootstrap public key
LegacyBootstrapPublicKeyEnvVar = "PICOD_AUTH_PUBLIC_KEY"
)

// AuthManager manages RSA public key authentication
// The public key is loaded from environment variable at startup
// The bootstrap public key is loaded from environment variable at startup
// The session public key is set dynamically via the /init endpoint
type AuthManager struct {
publicKey *rsa.PublicKey
mutex sync.RWMutex
bootstrapPublicKey *rsa.PublicKey
sessionPublicKey *rsa.PublicKey
initialized bool
mutex sync.RWMutex
}

// NewAuthManager creates a new auth manager
func NewAuthManager() *AuthManager {
return &AuthManager{}
}

// LoadPublicKeyFromEnv loads the public key from environment variable.
// parseRSAPublicKeyFromPEM decodes the first PEM block found in keyData and
// interprets its bytes as a PKIX-encoded RSA public key.
//
// It returns an error when no PEM block can be decoded, when the block does
// not hold a parseable PKIX public key, or when the parsed key is not RSA.
func parseRSAPublicKeyFromPEM(keyData string) (*rsa.PublicKey, error) {
	// Any data following the first PEM block is ignored.
	pemBlock, _ := pem.Decode([]byte(keyData))
	if pemBlock == nil {
		return nil, fmt.Errorf("failed to decode PEM block")
	}

	parsed, parseErr := x509.ParsePKIXPublicKey(pemBlock.Bytes)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse public key: %w", parseErr)
	}

	// PKIX can carry several key algorithms; only RSA is accepted here.
	switch key := parsed.(type) {
	case *rsa.PublicKey:
		return key, nil
	default:
		return nil, fmt.Errorf("key is not an RSA public key")
	}
}

// LoadBootstrapPublicKey loads the bootstrap public key from environment variable.
// The key should be in PEM format.
func (am *AuthManager) LoadPublicKeyFromEnv() error {
func (am *AuthManager) LoadBootstrapPublicKey() error {
am.mutex.Lock()
defer am.mutex.Unlock()

keyData := os.Getenv(PublicKeyEnvVar)
keyData := os.Getenv(BootstrapPublicKeyEnvVar)
if keyData == "" {
return fmt.Errorf("environment variable %s is not set", PublicKeyEnvVar)
keyData = os.Getenv(LegacyBootstrapPublicKeyEnvVar)
if keyData != "" {
klog.Warningf("Using deprecated environment variable %s, please migrate to %s", LegacyBootstrapPublicKeyEnvVar, BootstrapPublicKeyEnvVar)
}
}

block, _ := pem.Decode([]byte(keyData))
if block == nil {
return fmt.Errorf("failed to decode PEM block from %s", PublicKeyEnvVar)
if keyData == "" {
return fmt.Errorf("environment variable %s (or deprecated %s) is not set", BootstrapPublicKeyEnvVar, LegacyBootstrapPublicKeyEnvVar)
}

pub, err := x509.ParsePKIXPublicKey(block.Bytes)
rsaPub, err := parseRSAPublicKeyFromPEM(keyData)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
return fmt.Errorf("failed to parse bootstrap public key: %w", err)
}

rsaPub, ok := pub.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("key is not an RSA public key")
am.bootstrapPublicKey = rsaPub
klog.Info("Bootstrap public key loaded successfully from environment variable")
return nil
}

// SetSessionPublicKey parses keyData as a PEM-encoded RSA public key and stores
// it as the session public key.
//
// The key can be set only once. If the manager is already marked initialized,
// ErrAlreadyInitialized is returned and the stored key is left untouched. If
// keyData fails to parse, a wrapped error is returned and the manager stays
// uninitialized, so a later call with a valid key can still succeed.
func (am *AuthManager) SetSessionPublicKey(keyData string) error {
// The write lock is held for the whole call so the initialized check and the
// assignment below happen as one step relative to other lock holders.
am.mutex.Lock()
defer am.mutex.Unlock()

// Refuse to replace a session key that has already been registered.
if am.initialized {
klog.Warning("Attempted to re-initialize an already initialized session")
return ErrAlreadyInitialized
}

rsaPub, err := parseRSAPublicKeyFromPEM(keyData)
if err != nil {
return fmt.Errorf("failed to parse session public key: %w", err)
}

// Mark the manager initialized only after the key parsed successfully.
am.sessionPublicKey = rsaPub
am.initialized = true
klog.Info("Session public key successfully registered via /init")
return nil
// NOTE(review): the next two lines look like review-UI residue from the page
// capture rather than Go source — confirm against the real file.
Comment thread
Abhinav-kodes marked this conversation as resolved.
}

// VerifyBootstrapJWT verifies the init token against the bootstrap public key
// and returns the value of its session_public_key claim.
//
// The token must use an RSA signing method, carry an expiration claim, and
// name "agentcube-workload-manager" as its issuer. One minute of clock leeway
// is allowed, and issued-at validation is enabled through jwt.WithIssuedAt.
//
// An error is returned when the bootstrap key has not been loaded, when
// parsing or verification fails (the underlying error is wrapped), when the
// token is reported as not valid, or when the session_public_key claim is
// missing, empty, or not a string.
func (am *AuthManager) VerifyBootstrapJWT(tokenStr string) (string, error) {
	token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
		// Reject any non-RSA algorithm before handing out the key.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		am.mutex.RLock()
		defer am.mutex.RUnlock()
		if am.bootstrapPublicKey == nil {
			return nil, fmt.Errorf("bootstrap public key is not loaded")
		}
		return am.bootstrapPublicKey, nil
	}, jwt.WithExpirationRequired(), jwt.WithIssuedAt(), jwt.WithLeeway(time.Minute), jwt.WithIssuer("agentcube-workload-manager"))

	if err != nil {
		return "", fmt.Errorf("JWT verification failed: %w", err)
	}
	// Handled separately from the error case: wrapping a nil error with %w
	// would yield a malformed message and nothing for errors.Is/As to match.
	if token == nil || !token.Valid {
		return "", fmt.Errorf("JWT verification failed: token is not valid")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", fmt.Errorf("invalid token claims")
	}

	sessionPubKey, ok := claims["session_public_key"].(string)
	if !ok || sessionPubKey == "" {
		return "", fmt.Errorf("missing or invalid session_public_key claim")
	}

	return sessionPubKey, nil
}

// AuthMiddleware creates authentication middleware with JWT verification
// Note: Public key must be loaded at startup (via LoadPublicKeyFromEnv), so we don't check here
// Note: Requires the daemon to be initialized with a session public key.
func (am *AuthManager) AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Check if initialized
am.mutex.RLock()
isInit := am.initialized
am.mutex.RUnlock()
if !isInit {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "Daemon not initialized",
"code": http.StatusServiceUnavailable,
"detail": "PicoD is waiting for Workload Manager to initialize the session",
})
c.Abort()
return
}

authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
Expand All @@ -108,14 +203,14 @@ func (am *AuthManager) AuthMiddleware() gin.HandlerFunc {

tokenString := parts[1]

// Parse and validate JWT using the public key
// Parse and validate JWT using the session public key
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
am.mutex.RLock()
defer am.mutex.RUnlock()
return am.publicKey, nil
return am.sessionPublicKey, nil
}, jwt.WithExpirationRequired(), jwt.WithIssuedAt(), jwt.WithLeeway(time.Minute))
Comment thread
Abhinav-kodes marked this conversation as resolved.

if err != nil || !token.Valid {
Expand Down
27 changes: 15 additions & 12 deletions pkg/picod/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestNewAuthManager(t *testing.T) {
assert.NotNil(t, manager)
}

func TestLoadPublicKeyFromEnv(t *testing.T) {
func TestLoadBootstrapPublicKey(t *testing.T) {
tests := []struct {
name string
setupEnv func() string
Expand Down Expand Up @@ -120,14 +120,14 @@ MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAinvalid
t.Run(tt.name, func(t *testing.T) {
envValue := tt.setupEnv()
if envValue != "" {
os.Setenv(PublicKeyEnvVar, envValue)
defer os.Unsetenv(PublicKeyEnvVar)
os.Setenv(BootstrapPublicKeyEnvVar, envValue)
defer os.Unsetenv(BootstrapPublicKeyEnvVar)
} else {
os.Unsetenv(PublicKeyEnvVar)
os.Unsetenv(BootstrapPublicKeyEnvVar)
}

manager := NewAuthManager()
err := manager.LoadPublicKeyFromEnv()
err := manager.LoadBootstrapPublicKey()

if tt.wantErr {
assert.Error(t, err)
Expand All @@ -145,11 +145,13 @@ func TestAuthMiddleware_HeaderValidation(t *testing.T) {
_, pubKeyPEM, err := generateTestRSAKeyPair()
require.NoError(t, err)

os.Setenv(PublicKeyEnvVar, pubKeyPEM)
defer os.Unsetenv(PublicKeyEnvVar)
os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM)
defer os.Unsetenv(BootstrapPublicKeyEnvVar)

manager := NewAuthManager()
err = manager.LoadPublicKeyFromEnv()
err = manager.LoadBootstrapPublicKey()
require.NoError(t, err)
err = manager.SetSessionPublicKey(pubKeyPEM)
require.NoError(t, err)

tests := []struct {
Expand Down Expand Up @@ -212,11 +214,13 @@ func TestAuthMiddleware_TokenValidation(t *testing.T) {
privateKey, pubKeyPEM, err := generateTestRSAKeyPair()
require.NoError(t, err)

os.Setenv(PublicKeyEnvVar, pubKeyPEM)
defer os.Unsetenv(PublicKeyEnvVar)
os.Setenv(BootstrapPublicKeyEnvVar, pubKeyPEM)
defer os.Unsetenv(BootstrapPublicKeyEnvVar)

manager := NewAuthManager()
err = manager.LoadPublicKeyFromEnv()
err = manager.LoadBootstrapPublicKey()
require.NoError(t, err)
err = manager.SetSessionPublicKey(pubKeyPEM)
require.NoError(t, err)

tests := []struct {
Expand Down Expand Up @@ -302,4 +306,3 @@ func TestAuthMiddleware_TokenValidation(t *testing.T) {
}
}


Loading
Loading