diff --git a/cmd/deployment-tracker/main.go b/cmd/deployment-tracker/main.go index 05c4747..b047782 100644 --- a/cmd/deployment-tracker/main.go +++ b/cmd/deployment-tracker/main.go @@ -3,6 +3,7 @@ package main import ( "context" "flag" + "fmt" "log" "log/slog" "net/http" @@ -44,6 +45,13 @@ func main() { flag.StringVar(&metricsPort, "metrics-port", "9090", "port to listen to for metrics") flag.Parse() + // Validate worker count + if workers < 1 || workers > 100 { + slog.Error("Invalid worker count, must be between 1 and 100", + "workers", workers) + os.Exit(1) + } + // init logging log.SetFlags(log.LstdFlags | log.Lshortfile | log.LUTC) opts := slog.HandlerOptions{Level: slog.LevelInfo} @@ -59,6 +67,13 @@ func main() { Organization: os.Getenv("GITHUB_ORG"), } + if !controller.ValidTemplate(cntrlCfg.Template) { + slog.Error("Template must contain at least one placeholder", + "template", cntrlCfg.Template, + "valid_placeholders", []string{controller.TmplNS, controller.TmplDN, controller.TmplCN}) + os.Exit(1) + } + if cntrlCfg.LogicalEnvironment == "" { slog.Error("Logical environment is required") os.Exit(1) @@ -87,20 +102,20 @@ func main() { } // Start the metrics server + var promSrv = &http.Server{ + Addr: ":" + metricsPort, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + Handler: http.NewServeMux(), + } + promSrv.Handler.(*http.ServeMux).Handle("/metrics", promhttp.Handler()) + go func() { - var mm = http.NewServeMux() - mm.Handle("/metrics", promhttp.Handler()) - - var promSrv = &http.Server{ - Addr: ":" + metricsPort, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - ReadHeaderTimeout: 10 * time.Second, - Handler: mm, - } slog.Info("starting Prometheus metrics server", "url", promSrv.Addr) - if err := promSrv.ListenAndServe(); err != nil { + if err := promSrv.ListenAndServe(); err != nil && err != http.ErrServerClosed { slog.Error("failed to start metrics server", "error", err) } @@ -113,10 +128,24 @@ func main() { go func() { <-sigCh slog.Info("Shutting down...") + + // Gracefully shutdown the metrics server + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := promSrv.Shutdown(shutdownCtx); err != nil { + slog.Error("failed to shutdown metrics server gracefully", + "error", err) + } + cancel() }() - cntrl := controller.New(clientset, namespace, &cntrlCfg) + cntrl, err := controller.New(clientset, namespace, &cntrlCfg) + if err != nil { + slog.Error("Failed to create controller", + "error", err) + os.Exit(1) + } slog.Info("Starting deployment-tracker controller") if err := cntrl.Run(ctx, workers); err != nil { @@ -144,6 +173,9 @@ func createK8sConfig(kubeconfig string) (*rest.Config, error) { } // Fall back to default kubeconfig location - homeDir, _ := os.UserHomeDir() + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get user home directory: %w", err) + } return clientcmd.BuildConfigFromFlags("", homeDir+"/.kube/config") } diff --git a/go.mod b/go.mod index 87b0f09..2b41c71 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect - golang.org/x/time v0.9.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/protobuf v1.36.8 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index dcf9eea..da0a6ea 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= diff --git a/internal/controller/config.go b/internal/controller/config.go index 7aa8295..a5cb060 100644 --- a/internal/controller/config.go +++ b/internal/controller/config.go @@ -1,5 +1,9 @@ package controller +import ( + "strings" +) + const ( // TmplNS is the meta variable for the k8s namespace. TmplNS = "{{namespace}}" @@ -19,3 +23,13 @@ type Config struct { BaseURL string Organization string } + +// ValidTemplate verifies that at least one placeholder is present +// in the provided template t. +func ValidTemplate(t string) bool { + hasPlaceholder := strings.Contains(t, TmplNS) || + strings.Contains(t, TmplDN) || + strings.Contains(t, TmplCN) + + return hasPlaceholder +} diff --git a/internal/controller/config_test.go b/internal/controller/config_test.go new file mode 100644 index 0000000..3ec2b8b --- /dev/null +++ b/internal/controller/config_test.go @@ -0,0 +1,113 @@ +package controller + +import ( + "testing" +) + +func TestValidTemplate(t *testing.T) { + tests := []struct { + name string + template string + expected bool + }{ + { + name: "empty string", + template: "", + expected: false, + }, + { + name: "static string without placeholders", + template: "static-deployment-name", + expected: false, + }, + { + name: "namespace placeholder only", + template: "{{namespace}}", + expected: true, + }, + { + name: "deployment name placeholder only", + template: "{{deploymentName}}", + expected: true, + }, + { + name: "container name placeholder only", + template: "{{containerName}}", + expected: true, + }, + { + name: "all three placeholders", + template: "{{namespace}}/{{deploymentName}}/{{containerName}}", + expected: true, + }, + { + name: "namespace and deployment name", + template: "{{namespace}}-{{deploymentName}}", + expected: true, + }, + { + name: "mixed static and placeholders", + template: "prefix-{{namespace}}-suffix", + expected: true, + }, + { + name: "placeholder with surrounding text", + template: "app/{{containerName}}/prod", + expected: true, + }, + { + name: "similar but invalid placeholder", + template: "{{namespaces}}", + expected: false, + }, + { + name: "partial placeholder - missing closing braces", + template: "{{namespace", + expected: false, + }, + { + name: "partial placeholder - missing opening braces", + template: "namespace}}", + expected: false, + }, + { + name: "wrong case placeholder", + template: "{{Namespace}}", + expected: false, + }, + { + name: "placeholder with extra space", + template: "{{ namespace }}", + expected: false, + }, + { + name: "default template format", + template: TmplNS + "/" + TmplDN + "/" + TmplCN, + expected: true, + }, + { + name: "complex valid template", + template: "org/{{namespace}}/env/{{deploymentName}}/container/{{containerName}}", + expected: true, + }, + { + name: "whitespace only", + template: " ", + expected: false, + }, + { + name: "special characters without placeholders", + template: "app-name_v1.2.3", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ValidTemplate(tt.template) + if result != tt.expected { + t.Errorf("ValidTemplate(%q) = %v, expected %v", tt.template, result, tt.expected) + } + }) + } +} diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 97ed8ec..e73dbe5 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -3,6 +3,7 @@ package controller import ( "context" "errors" + "fmt" "log/slog" "strings" "time" @@ -36,7 +37,7 @@ type Controller struct { } // New creates a new deployment tracker controller. -func New(clientset kubernetes.Interface, namespace string, cfg *Config) *Controller { +func New(clientset kubernetes.Interface, namespace string, cfg *Config) (*Controller, error) { // Create informer factory var factory informers.SharedInformerFactory if namespace == "" { @@ -63,11 +64,14 @@ func New(clientset kubernetes.Interface, namespace string, cfg *Config) *Control if cfg.APIToken != "" { clientOpts = append(clientOpts, deploymentrecord.WithAPIToken(cfg.APIToken)) } - apiClient := deploymentrecord.NewClient( + apiClient, err := deploymentrecord.NewClient( cfg.BaseURL, cfg.Organization, clientOpts..., ) + if err != nil { + return nil, fmt.Errorf("failed to create API client: %w", err) + } cntrl := &Controller{ clientset: clientset, @@ -78,7 +82,7 @@ func New(clientset kubernetes.Interface, namespace string, cfg *Config) *Control } // Add event handlers to the informer - _, err := podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + _, err = podInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj any) { pod, ok := obj.(*corev1.Pod) if !ok { @@ -172,11 +176,10 @@ func New(clientset kubernetes.Interface, namespace string, cfg *Config) *Control }) if err != nil { - slog.Error("Failed to add event handlers", - "error", err) + return nil, fmt.Errorf("failed to add event handlers: %w", err) } - return cntrl + return cntrl, nil } // Run starts the controller. @@ -304,6 +307,13 @@ func (c *Controller) recordContainer(ctx context.Context, pod *corev1.Pod, conta digest := getContainerDigest(pod, container.Name) if dn == "" || digest == "" { + slog.Debug("Skipping container: missing deployment name or digest", + "namespace", pod.Namespace, + "pod", pod.Name, + "container", container.Name, + "deployment_name", dn, + "has_digest", digest != "", + ) return nil } diff --git a/pkg/deploymentrecord/client.go b/pkg/deploymentrecord/client.go index 4c2dd93..6eae994 100644 --- a/pkg/deploymentrecord/client.go +++ b/pkg/deploymentrecord/client.go @@ -6,28 +6,60 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" + "math" + "math/rand/v2" "net/http" + "regexp" + "strings" "time" "github.com/github/deployment-tracker/pkg/metrics" + "golang.org/x/time/rate" ) // ClientOption is a function that configures the Client. type ClientOption func(*Client) +// validOrgPattern validates organization names (alphanumeric, hyphens, +// underscores). +var validOrgPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + // Client is an API client for posting deployment records. type Client struct { - baseURL string - org string - httpClient *http.Client - retries int - apiToken string + baseURL string + org string + httpClient *http.Client + retries int + apiToken string + rateLimiter *rate.Limiter } // NewClient creates a new API client with the given base URL and -// organization. -func NewClient(baseURL, org string, opts ...ClientOption) *Client { +// organization. Returns an error if the base URL is not HTTPS for +// non-local hosts. +func NewClient(baseURL, org string, opts ...ClientOption) (*Client, error) { + // Check if URL is local (allowed to use HTTP) + isLocal := strings.HasPrefix(baseURL, "http://localhost") || + strings.HasPrefix(baseURL, "http://127.0.0.1") || + strings.Contains(baseURL, ".svc.cluster.local") + + // Reject non-HTTPS URLs for non-local hosts + if strings.HasPrefix(baseURL, "http://") && !isLocal { + return nil, fmt.Errorf("insecure URL not allowed: %s (use HTTPS for non-local hosts)", baseURL) + } + + // Add https:// prefix if no scheme is provided + if !strings.HasPrefix(baseURL, "https://") && !strings.HasPrefix(baseURL, "http://") { + baseURL = "https://" + baseURL + } + + // Validate organization name to prevent URL injection + if !validOrgPattern.MatchString(org) { + return nil, fmt.Errorf("invalid organization name: %s (must be alphanumeric, hyphens, or underscores)", org) + } + c := &Client{ baseURL: baseURL, org: org, @@ -35,13 +67,15 @@ func NewClient(baseURL, org string, opts ...ClientOption) *Client { Timeout: 5 * time.Second, }, retries: 3, + // 20 req/sec with burst of 50 + rateLimiter: rate.NewLimiter(rate.Limit(20), 50), } for _, opt := range opts { opt(c) } - return c + return c, nil } // WithTimeout sets the HTTP client timeout in seconds. @@ -65,6 +99,13 @@ func WithAPIToken(token string) ClientOption { } } +// WithRateLimiter sets a custom rate limiter for API calls. +func WithRateLimiter(rps float64, burst int) ClientOption { + return func(c *Client) { + c.rateLimiter = rate.NewLimiter(rate.Limit(rps), burst) + } +} + // ClientError represents a client error that can not be retried. type ClientError struct { err error @@ -85,6 +126,11 @@ func (c *Client) PostOne(ctx context.Context, record *DeploymentRecord) error { return errors.New("record cannot be nil") } + // Wait for rate limiter + if err := c.rateLimiter.Wait(ctx); err != nil { + return fmt.Errorf("rate limiter wait failed: %w", err) + } + url := fmt.Sprintf("%s/orgs/%s/artifacts/metadata/deployment-record", c.baseURL, c.org) body, err := json.Marshal(record) @@ -98,9 +144,22 @@ func (c *Client) PostOne(ctx context.Context, record *DeploymentRecord) error { // The first attempt is not a retry! for attempt := range c.retries + 1 { if attempt > 0 { - // Wait before retry with exponential backoff - time.Sleep(time.Duration(attempt*100) * - time.Millisecond) + backoff := time.Duration(math.Pow(2, + float64(attempt))) * 100 * time.Millisecond + //nolint:gosec + jitter := time.Duration(rand.Int64N(50)) * time.Millisecond + delay := backoff + jitter + + if delay > 5*time.Second { + delay = 5 * time.Second + } + + // Wait with context cancellation support + select { + case <-time.After(delay): + case <-ctx.Done(): + return fmt.Errorf("context cancelled during retry backoff: %w", ctx.Err()) + } } // Reset reader position for retries @@ -130,6 +189,9 @@ func (c *Client) PostOne(ctx context.Context, record *DeploymentRecord) error { metrics.PostDeploymentRecordSoftFail.Inc() continue } + + // Drain and close response body to enable connection reuse + _, _ = io.Copy(io.Discard, resp.Body) resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { diff --git a/pkg/deploymentrecord/client_test.go b/pkg/deploymentrecord/client_test.go new file mode 100644 index 0000000..97e4bba --- /dev/null +++ b/pkg/deploymentrecord/client_test.go @@ -0,0 +1,266 @@ +package deploymentrecord + +import ( + "strings" + "testing" + "time" +) + +func TestNewClient(t *testing.T) { + tests := []struct { + name string + baseURL string + org string + wantErr bool + errContains string + wantBaseURL string + }{ + { + name: "valid HTTPS URL", + baseURL: "https://api.github.com", + org: "my-org", + wantErr: false, + wantBaseURL: "https://api.github.com", + }, + { + name: "URL without scheme gets HTTPS prefix", + baseURL: "api.github.com", + org: "my-org", + wantErr: false, + wantBaseURL: "https://api.github.com", + }, + { + name: "HTTP URL rejected for non-local host", + baseURL: "http://api.github.com", + org: "my-org", + wantErr: true, + errContains: "insecure URL not allowed", + }, + { + name: "HTTP localhost allowed", + baseURL: "http://localhost:8080", + org: "my-org", + wantErr: false, + wantBaseURL: "http://localhost:8080", + }, + { + name: "HTTP localhost without port allowed", + baseURL: "http://localhost", + org: "my-org", + wantErr: false, + wantBaseURL: "http://localhost", + }, + { + name: "HTTP 127.0.0.1 allowed", + baseURL: "http://127.0.0.1:9090", + org: "my-org", + wantErr: false, + wantBaseURL: "http://127.0.0.1:9090", + }, + { + name: "HTTP Kubernetes service allowed", + baseURL: "http://my-service.my-namespace.svc.cluster.local:8080", + org: "my-org", + wantErr: false, + wantBaseURL: "http://my-service.my-namespace.svc.cluster.local:8080", + }, + { + name: "HTTPS Kubernetes service allowed", + baseURL: "https://my-service.my-namespace.svc.cluster.local", + org: "my-org", + wantErr: false, + wantBaseURL: "https://my-service.my-namespace.svc.cluster.local", + }, + { + name: "valid org with hyphens", + baseURL: "https://api.github.com", + org: "my-org-name", + wantErr: false, + wantBaseURL: "https://api.github.com", + }, + { + name: "valid org with underscores", + baseURL: "https://api.github.com", + org: "my_org_name", + wantErr: false, + wantBaseURL: "https://api.github.com", + }, + { + name: "valid org alphanumeric", + baseURL: "https://api.github.com", + org: "MyOrg123", + wantErr: false, + wantBaseURL: "https://api.github.com", + }, + { + name: "invalid org with spaces", + baseURL: "https://api.github.com", + org: "my org", + wantErr: true, + errContains: "invalid organization name", + }, + { + name: "invalid org with slash", + baseURL: "https://api.github.com", + org: "my-org/../other", + wantErr: true, + errContains: "invalid organization name", + }, + { + name: "invalid org with special characters", + baseURL: "https://api.github.com", + org: "my@org!", + wantErr: true, + errContains: "invalid organization name", + }, + { + name: "empty org", + baseURL: "https://api.github.com", + org: "", + wantErr: true, + errContains: "invalid organization name", + }, + { + name: "HTTP with external IP rejected", + baseURL: "http://192.168.1.1:8080", + org: "my-org", + wantErr: true, + errContains: "insecure URL not allowed", + }, + { + name: "HTTP with domain rejected", + baseURL: "http://example.com", + org: "my-org", + wantErr: true, + errContains: "insecure URL not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.baseURL, tt.org) + + if tt.wantErr { + if err == nil { + t.Errorf("NewClient(%q, %q) expected error containing %q, got nil", + tt.baseURL, tt.org, tt.errContains) + return + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("NewClient(%q, %q) error = %q, want error containing %q", + tt.baseURL, tt.org, err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("NewClient(%q, %q) unexpected error: %v", + tt.baseURL, tt.org, err) + return + } + + if client.baseURL != tt.wantBaseURL { + t.Errorf("NewClient(%q, %q) baseURL = %q, want %q", + tt.baseURL, tt.org, client.baseURL, tt.wantBaseURL) + } + + if client.org != tt.org { + t.Errorf("NewClient(%q, %q) org = %q, want %q", + tt.baseURL, tt.org, client.org, tt.org) + } + }) + } +} + +func TestNewClientWithOptions(t *testing.T) { + t.Run("WithTimeout option", func(t *testing.T) { + client, err := NewClient("https://api.github.com", "my-org", + WithTimeout(30)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client.httpClient.Timeout != 30*time.Second { + t.Errorf("timeout = %v, want %v", client.httpClient.Timeout, 30*time.Second) + } + }) + + t.Run("WithRetries option", func(t *testing.T) { + client, err := NewClient("https://api.github.com", "my-org", + WithRetries(5)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client.retries != 5 { + t.Errorf("retries = %d, want %d", client.retries, 5) + } + }) + + t.Run("WithAPIToken option", func(t *testing.T) { + client, err := NewClient("https://api.github.com", "my-org", + WithAPIToken("test-token")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client.apiToken != "test-token" { + t.Errorf("apiToken = %q, want %q", client.apiToken, "test-token") + } + }) + + t.Run("multiple options", func(t *testing.T) { + client, err := NewClient("https://api.github.com", "my-org", + WithTimeout(60), + WithRetries(10), + WithAPIToken("multi-token")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client.httpClient.Timeout != 60*time.Second { + t.Errorf("timeout = %v, want %v", client.httpClient.Timeout, 60*time.Second) + } + if client.retries != 10 { + t.Errorf("retries = %d, want %d", client.retries, 10) + } + if client.apiToken != "multi-token" { + t.Errorf("apiToken = %q, want %q", client.apiToken, "multi-token") + } + }) +} + +func TestValidOrgPattern(t *testing.T) { + validOrgs := []string{ + "github", + "my-org", + "my_org", + "MyOrg123", + "org-with-many-hyphens", + "org_with_many_underscores", + "MixedCase-and_underscores-123", + "a", + "A", + "1", + } + + for _, org := range validOrgs { + if !validOrgPattern.MatchString(org) { + t.Errorf("validOrgPattern should match %q", org) + } + } + + invalidOrgs := []string{ + "", + "has space", + "has/slash", + "has\\backslash", + "has@symbol", + "has!exclaim", + "has.dot", + "../traversal", + "org/../../../etc/passwd", + } + + for _, org := range invalidOrgs { + if validOrgPattern.MatchString(org) { + t.Errorf("validOrgPattern should not match %q", org) + } + } +}