diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index fe6b8c24d..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,36 +0,0 @@ -# Command - -Typical go commands to build, run and test - -- Run server (using Makefile): make - -# Stack - -Go, Echo (HTTP framework), GORM, PostgreSQL - -# Architecture - -- /cmd — DevGuard CLI tools and commands. Also devguard scanner logic can be found here -- /controllers - controller functions which are called by the router.go files. -- /services - Service functions to outsource business logic from controller function. -- /database - Everything related to the database of the projects like models, repository functions and migrations -- /database/repositories - Repository functions which interact with the database using GORM. -- /database/models - Here are all our models defined which reflect the tables in our database using GORM. Also object specific functions can be found here occasionally -- /daemons - All the background jobs which are triggered automatically -- /accesscontrol - All the for DevGuards access control -- /dtos - All the additional data types used across the project are defined and managed here. -- /integrations - All the logic for integration external entities like gitlab, github or jira -- /normalize - All the logic used to normalize core aspects of DevGuard like purls, SBOMs and so on -- /router - Defines all the accessible routes of the backend api server. Also defines what middlewares are used on the routes -- /shared - Common helper functions and definitions which are used almost everywhere in the project. -- /tests - All the test files, including unit tests as well as integrations tests for the whole code base -- /transformer - Transforms one dto to another to avoid import cycles -- /utils - utility functions -- /vulndb - All the logic for importing and updating the vulnerability database - -# Code Style - -- Add comments only where you need to explain the why behind the code -- use CamelCase for all variable names -- Always write code according to the repository pattern — routers call controllers, controllers call services, services call repositories, repositories use models as reflections -- Mocks in /mocks are auto-generated by mockery — never edit manually, run `mockery` to regenerate diff --git a/controllers/webhook_controller.go b/controllers/webhook_controller.go index d4bb0a1e6..ef1d3337e 100644 --- a/controllers/webhook_controller.go +++ b/controllers/webhook_controller.go @@ -303,8 +303,9 @@ func (w *WebhookController) HandleEvent(ctx context.Context, event any) error { //send sbom if err := client.SendSBOM(ctx, *event.SBOM, event.Org, event.Project, event.Asset, event.AssetVersion, event.Artifact); err != nil { slog.Error("failed to send SBOM to webhook", "webhookID", webhook.ID, "err", err) + } else { + slog.Info("webhook sent", "eventType", "sbom", "webhookID", webhook.ID, "org", event.Org.Name) } - slog.Info("SBOM sent to webhook", "webhookID", webhook.ID) } } case shared.FirstPartyVulnsDetectedEvent: @@ -323,8 +324,9 @@ func (w *WebhookController) HandleEvent(ctx context.Context, event any) error { //send vulnerability if err := client.SendFirstPartyVulnerabilities(ctx, vulns, event.Org, event.Project, event.Asset, event.AssetVersion); err != nil { slog.Error("failed to send vulnerability to webhook", "webhookID", webhook.ID, "err", err) + } else { + slog.Info("webhook sent", "eventType", "firstPartyVulnerabilities", "webhookID", webhook.ID, "org", event.Org.Name) } - slog.Info("Vulnerability sent to webhook", "webhookID", webhook.ID) } } @@ -344,8 +346,9 @@ func (w *WebhookController) HandleEvent(ctx context.Context, event any) error { //send vulnerability if err := client.SendDependencyVulnerabilities(ctx, vulns, event.Org, event.Project, event.Asset, event.AssetVersion, event.Artifact); err != nil { slog.Error("failed to send vulnerability to webhook", "webhookID", webhook.ID, "err", err) + } else { + slog.Info("webhook sent", "eventType", "dependencyVulnerabilities", "webhookID", webhook.ID, "org", event.Org.Name) } - slog.Info("Vulnerability sent to webhook", "webhookID", webhook.ID) } } } diff --git a/integrations/providers.go b/integrations/providers.go index 26c41ef3e..430401e50 100644 --- a/integrations/providers.go +++ b/integrations/providers.go @@ -16,6 +16,7 @@ package integrations import ( + "github.com/l3montree-dev/devguard/controllers" "github.com/l3montree-dev/devguard/integrations/githubint" "github.com/l3montree-dev/devguard/integrations/gitlabint" "github.com/l3montree-dev/devguard/integrations/jiraint" @@ -38,8 +39,8 @@ var Module = fx.Options( // Aggregated Third Party Integration fx.Provide(fx.Annotate( - func(externalUserRepository shared.ExternalUserRepository, gitlabIntegration *gitlabint.GitlabIntegration, githubIntegration *githubint.GithubIntegration, jiraIntegration *jiraint.JiraIntegration) shared.IntegrationAggregate { - return NewThirdPartyIntegrations(externalUserRepository, githubIntegration, jiraIntegration, gitlabIntegration) + func(externalUserRepository shared.ExternalUserRepository, gitlabIntegration *gitlabint.GitlabIntegration, githubIntegration *githubint.GithubIntegration, jiraIntegration *jiraint.JiraIntegration, webhookIntegration *controllers.WebhookController) shared.IntegrationAggregate { + return NewThirdPartyIntegrations(externalUserRepository, githubIntegration, jiraIntegration, gitlabIntegration, webhookIntegration) }, fx.As(new(shared.IntegrationAggregate)), )), diff --git a/services/webhook_service.go b/services/webhook_service.go index 3482a5df3..3ab6a6168 100644 --- a/services/webhook_service.go +++ b/services/webhook_service.go @@ -51,16 +51,18 @@ const ( ) type webhookClient struct { - URL string - Secret *string - httpClient *http.Client + URL string + Secret *string + httpClient *http.Client + retryDelays []time.Duration } func NewWebhookService(url string, secret *string) *webhookClient { return &webhookClient{ - URL: url, - Secret: secret, - httpClient: &http.Client{Transport: utils.EgressTransport}, + URL: url, + Secret: secret, + httpClient: &http.Client{Transport: utils.EgressTransport}, + retryDelays: []time.Duration{1 * time.Second, 5 * time.Second, 10 * time.Second}, } } @@ -73,56 +75,77 @@ func (c *webhookClient) CreateRequest(ctx context.Context, method, url string, b ctx, cancel := context.WithTimeout(ctx, 120*time.Second) defer cancel() - // Retry logic with delays: 1s, 5s, 10s - retryDelays := []time.Duration{1 * time.Second, 5 * time.Second, 10 * time.Second} - - var resp *http.Response + var ( + resp *http.Response + lastErr error + ) + + for i, delay := range c.retryDelays { + // Drain and close the previous iteration's body so the connection can be reused. + if resp != nil { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + resp = nil + } - for i, delay := range retryDelays { req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(bodyBytes)) if err != nil { return nil, err } - if c.Secret != nil { req.Header.Set("X-Webhook-Secret", *c.Secret) } - req.Header.Set("Content-Type", "application/json") - resp, err = c.httpClient.Do(req) + resp, lastErr = c.httpClient.Do(req) - if err == nil && resp != nil && resp.StatusCode >= 200 && resp.StatusCode < 300 { + // Don't retry on 2xx or permanent 4xx — only 408/429 are retryable in the 4xx range. + if lastErr == nil && resp.StatusCode < 500 && + resp.StatusCode != http.StatusRequestTimeout && + resp.StatusCode != http.StatusTooManyRequests { return resp, nil } - if i == len(retryDelays)-1 { - return nil, fmt.Errorf("webhook request failed with no response") + if i == len(c.retryDelays)-1 { + break } - time.Sleep(delay) + select { + case <-ctx.Done(): + if resp != nil { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + return nil, ctx.Err() + case <-time.After(delay): + } } - // This should never be reached due to the break condition above - return nil, fmt.Errorf("unexpected end of retry loop") - + if lastErr != nil { + // http.Client.Do can return a non-nil response together with an error + // (e.g. CheckRedirect failures). Drain and close so the connection isn't leaked. + if resp != nil { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + return nil, lastErr + } + return resp, nil } -func (c *webhookClient) SendSBOM(ctx context.Context, SBOM cdx.BOM, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject, artifact shared.ArtifactObject) error { - +func (c *webhookClient) send(ctx context.Context, webhookType WebhookType, payload any, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject, artifact shared.ArtifactObject) error { body := WebhookStruct{ Organization: org, Project: project, Asset: asset, AssetVersion: assetVersion, - Payload: SBOM, - Type: WebhookTypeSBOM, + Payload: payload, + Type: webhookType, Artifact: artifact, } var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(body) - if err != nil { + if err := json.NewEncoder(&buf).Encode(body); err != nil { return err } @@ -130,141 +153,45 @@ func (c *webhookClient) SendSBOM(ctx context.Context, SBOM cdx.BOM, org shared.O if err != nil { return err } - if resp == nil { - return fmt.Errorf("received nil response when sending SBOM") - } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - return fmt.Errorf("failed to send SBOM, status: %s", resp.Status) - } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("webhook %s failed, status: %s", webhookType, resp.Status) + } return nil } -func (c *webhookClient) SendFirstPartyVulnerabilities(ctx context.Context, vuln []dtos.FirstPartyVulnDTO, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject) error { - return nil - - /*body := WebhookStruct{ - Organization: org, - Project: project, - Asset: asset, - AssetVersion: assetVersion, - Payload: vuln, - Type: WebhookTypeFirstPartyVulnerabilities, - } - - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(body) - if err != nil { - return err - } - - resp, err := c.CreateRequest("POST", c.URL, &buf) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to send vulnerability, status: %s,", resp.Status) - } +func (c *webhookClient) SendSBOM(ctx context.Context, SBOM cdx.BOM, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject, artifact shared.ArtifactObject) error { + return c.send(ctx, WebhookTypeSBOM, SBOM, org, project, asset, assetVersion, artifact) +} - return nil*/ +func (c *webhookClient) SendFirstPartyVulnerabilities(ctx context.Context, vuln []dtos.FirstPartyVulnDTO, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject) error { + return c.send(ctx, WebhookTypeFirstPartyVulnerabilities, vuln, org, project, asset, assetVersion, shared.ArtifactObject{}) } func (c *webhookClient) SendDependencyVulnerabilities(ctx context.Context, vuln []dtos.DependencyVulnDTO, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject, artifact shared.ArtifactObject) error { - - body := WebhookStruct{ - Organization: org, - Project: project, - Asset: asset, - AssetVersion: assetVersion, - Payload: vuln, - Artifact: artifact, - Type: WebhookTypeDependencyVulnerabilities, - } - - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(body) - if err != nil { - return err - } - - resp, err := c.CreateRequest(ctx, "POST", c.URL, &buf) - if err != nil { - return err - } - if resp == nil { - return fmt.Errorf("received nil response when sending dependency vulnerabilities") - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to send vulnerability, status: %s", resp.Status) - } - - return nil + return c.send(ctx, WebhookTypeDependencyVulnerabilities, vuln, org, project, asset, assetVersion, artifact) } func (c *webhookClient) SendTest(ctx context.Context, org shared.OrgObject, project shared.ProjectObject, asset shared.AssetObject, assetVersion shared.AssetVersionObject, payloadType TestPayloadType) error { + payload, webhookType := testPayload(payloadType) + return c.send(ctx, webhookType, payload, org, project, asset, assetVersion, shared.ArtifactObject{}) +} - var payload any - var webhookType WebhookType - +func testPayload(payloadType TestPayloadType) (any, WebhookType) { switch payloadType { - case TestPayloadTypeEmpty: - payload = map[string]any{ - "message": "This is a test webhook from DevGuard", - "timestamp": time.Now().UTC().Format(time.RFC3339), - } - webhookType = WebhookTypeTest - case TestPayloadTypeSampleSBOM: - payload = createSampleSBOM() - webhookType = WebhookTypeSBOM - + return createSampleSBOM(), WebhookTypeSBOM case TestPayloadTypeSampleDependencyVulns: - payload = createSampleDependencyVulns() - webhookType = WebhookTypeDependencyVulnerabilities - + return createSampleDependencyVulns(), WebhookTypeDependencyVulnerabilities case TestPayloadTypeSampleFirstPartyVulns: - payload = createSampleFirstPartyVulns() - webhookType = WebhookTypeFirstPartyVulnerabilities - + return createSampleFirstPartyVulns(), WebhookTypeFirstPartyVulnerabilities default: - payload = map[string]any{ + return map[string]any{ "message": "This is a test webhook from DevGuard", "timestamp": time.Now().UTC().Format(time.RFC3339), - } - webhookType = WebhookTypeTest + }, WebhookTypeTest } - - body := WebhookStruct{ - Organization: org, - Project: project, - Asset: asset, - AssetVersion: assetVersion, - Payload: payload, - Type: webhookType, - } - - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(body) - if err != nil { - return err - } - - resp, err := c.CreateRequest(ctx, "POST", c.URL, &buf) - if err != nil { - return err - } - if resp == nil { - return fmt.Errorf("received nil response when sending test webhook") - } - defer resp.Body.Close() - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return nil // Success - } - - return fmt.Errorf("failed to send test webhook, status: %s", resp.Status) } func createSampleSBOM() cdx.BOM { diff --git a/services/webhook_service_test.go b/services/webhook_service_test.go index aa7914c86..ee440170e 100644 --- a/services/webhook_service_test.go +++ b/services/webhook_service_test.go @@ -9,10 +9,18 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func newTestWebhookService(url string) *webhookClient { + webhookClient := NewWebhookService(url, nil) + webhookClient.retryDelays = []time.Duration{0, 0, 0} + return webhookClient +} + func TestWebhookClient_CreateRequest_RetryLogic(t *testing.T) { t.Run("should succeed on first attempt when request is successful", func(t *testing.T) { attemptCount := 0 @@ -25,7 +33,7 @@ func TestWebhookClient_CreateRequest_RetryLogic(t *testing.T) { })) defer server.Close() - client := NewWebhookService(server.URL, nil) + client := newTestWebhookService(server.URL) body := strings.NewReader(`{"test": "data"}`) resp, err := client.CreateRequest(context.Background(), "POST", server.URL, body) @@ -37,24 +45,66 @@ func TestWebhookClient_CreateRequest_RetryLogic(t *testing.T) { resp.Body.Close() }) - t.Run("should make exactly 3 attempts when requests fail", func(t *testing.T) { + t.Run("should retry 3 times on 5xx and return the last response", func(t *testing.T) { attemptCount := 0 - // Setup test server that always fails server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attemptCount++ w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() - client := NewWebhookService(server.URL, nil) + client := newTestWebhookService(server.URL) + body := strings.NewReader(`{"test": "data"}`) + + resp, err := client.CreateRequest(context.Background(), "POST", server.URL, body) + + assert.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Equal(t, 3, attemptCount, "Should make exactly 3 attempts on 5xx") + }) + + t.Run("should not retry on 4xx client errors", func(t *testing.T) { + attemptCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + w.WriteHeader(http.StatusBadRequest) + })) + defer server.Close() + + client := newTestWebhookService(server.URL) body := strings.NewReader(`{"test": "data"}`) resp, err := client.CreateRequest(context.Background(), "POST", server.URL, body) - assert.Error(t, err) - assert.Nil(t, resp) - assert.Equal(t, 3, attemptCount, "Should make exactly 3 attempts") - assert.Contains(t, err.Error(), "webhook request failed with no response") + assert.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.Equal(t, 1, attemptCount, "Should not retry on 4xx") + }) + + t.Run("should retry on 429 Too Many Requests", func(t *testing.T) { + attemptCount := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + w.WriteHeader(http.StatusTooManyRequests) + })) + defer server.Close() + + client := newTestWebhookService(server.URL) + body := strings.NewReader(`{"test": "data"}`) + + resp, err := client.CreateRequest(context.Background(), "POST", server.URL, body) + + assert.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, 3, attemptCount, "Should retry on 429") }) }