|
| 1 | +package agent |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "crypto/rand" |
| 6 | + "crypto/subtle" |
| 7 | + "encoding/hex" |
| 8 | + "fmt" |
| 9 | + |
| 10 | + "github.com/threatwinds/go-sdk/catcher" |
| 11 | + "github.com/utmstack/UTMStack/agent-manager/models" |
| 12 | + "google.golang.org/grpc/codes" |
| 13 | + "google.golang.org/grpc/status" |
| 14 | + "gorm.io/gorm" |
| 15 | +) |
| 16 | + |
| 17 | +func (s *AgentService) loadConnectionKey() error { |
| 18 | + var rows []models.ConnectionKey |
| 19 | + if _, err := s.DBConnection.GetAll(&rows, ""); err != nil { |
| 20 | + return fmt.Errorf("failed to load connection key: %v", err) |
| 21 | + } |
| 22 | + |
| 23 | + if len(rows) > 0 { |
| 24 | + s.connKeyMutex.Lock() |
| 25 | + s.connKeyID = rows[0].ID |
| 26 | + s.connKey = rows[0].Key |
| 27 | + s.connKeyMutex.Unlock() |
| 28 | + return nil |
| 29 | + } |
| 30 | + |
| 31 | + key, err := generateConnectionKey() |
| 32 | + if err != nil { |
| 33 | + return err |
| 34 | + } |
| 35 | + row := models.ConnectionKey{Key: key} |
| 36 | + if err := s.DBConnection.Create(&row); err != nil { |
| 37 | + return fmt.Errorf("failed to create connection key: %v", err) |
| 38 | + } |
| 39 | + s.connKeyMutex.Lock() |
| 40 | + s.connKeyID = row.ID |
| 41 | + s.connKey = row.Key |
| 42 | + s.connKeyMutex.Unlock() |
| 43 | + catcher.Info("Generated agent connection key", map[string]any{"process": "agent-manager"}) |
| 44 | + return nil |
| 45 | +} |
| 46 | + |
| 47 | +// ValidateConnectionKey reports whether the presented key matches the current |
| 48 | +// connection key. |
| 49 | +func (s *AgentService) ValidateConnectionKey(key string) bool { |
| 50 | + s.connKeyMutex.RLock() |
| 51 | + defer s.connKeyMutex.RUnlock() |
| 52 | + if s.connKey == "" { |
| 53 | + return false |
| 54 | + } |
| 55 | + return subtle.ConstantTimeCompare([]byte(key), []byte(s.connKey)) == 1 |
| 56 | +} |
| 57 | + |
| 58 | +func (s *AgentService) GetConnectionKey(ctx context.Context, req *ConnectionKeyRequest) (*ConnectionKeyResponse, error) { |
| 59 | + s.connKeyMutex.RLock() |
| 60 | + key := s.connKey |
| 61 | + s.connKeyMutex.RUnlock() |
| 62 | + return &ConnectionKeyResponse{ConnectionKey: key}, nil |
| 63 | +} |
| 64 | + |
| 65 | +func (s *AgentService) RotateConnectionKey(ctx context.Context, req *ConnectionKeyRequest) (*ConnectionKeyResponse, error) { |
| 66 | + key, err := generateConnectionKey() |
| 67 | + if err != nil { |
| 68 | + return nil, status.Error(codes.Internal, fmt.Sprintf("failed to generate connection key: %v", err)) |
| 69 | + } |
| 70 | + |
| 71 | + s.connKeyMutex.RLock() |
| 72 | + id := s.connKeyID |
| 73 | + s.connKeyMutex.RUnlock() |
| 74 | + |
| 75 | + if err := s.DBConnection.Upsert(&models.ConnectionKey{Model: gorm.Model{ID: id}, Key: key}, "id = ?", map[string]interface{}{"key": key}, id); err != nil { |
| 76 | + return nil, status.Error(codes.Internal, fmt.Sprintf("failed to persist connection key: %v", err)) |
| 77 | + } |
| 78 | + |
| 79 | + s.connKeyMutex.Lock() |
| 80 | + s.connKey = key |
| 81 | + s.connKeyMutex.Unlock() |
| 82 | + |
| 83 | + catcher.Info("Rotated agent connection key", map[string]any{"process": "agent-manager"}) |
| 84 | + return &ConnectionKeyResponse{ConnectionKey: key}, nil |
| 85 | +} |
| 86 | + |
| 87 | +func generateConnectionKey() (string, error) { |
| 88 | + b := make([]byte, 32) |
| 89 | + if _, err := rand.Read(b); err != nil { |
| 90 | + return "", fmt.Errorf("failed to read random bytes: %v", err) |
| 91 | + } |
| 92 | + return hex.EncodeToString(b), nil |
| 93 | +} |
0 commit comments