diff --git a/pkg/webhook/errors.go b/pkg/webhook/errors.go index 53c47e9938..af7b6416b2 100644 --- a/pkg/webhook/errors.go +++ b/pkg/webhook/errors.go @@ -3,7 +3,11 @@ package webhook -import "fmt" +import ( + "errors" + "fmt" + "net/http" +) // WebhookError is the base error type for all webhook-related errors. // @@ -92,3 +96,10 @@ func NewInvalidResponseError(webhookName string, err error, statusCode int) *Inv StatusCode: statusCode, } } + +// IsAlwaysDenyError reports whether the webhook error should deny the request +// regardless of the configured failure policy. +func IsAlwaysDenyError(err error) bool { + var invalidRespErr *InvalidResponseError + return errors.As(err, &invalidRespErr) && invalidRespErr.StatusCode == http.StatusUnprocessableEntity +} diff --git a/pkg/webhook/errors_test.go b/pkg/webhook/errors_test.go index 7ce4847fa2..17b6f4baea 100644 --- a/pkg/webhook/errors_test.go +++ b/pkg/webhook/errors_test.go @@ -83,3 +83,41 @@ func TestWebhookErrorBaseType(t *testing.T) { assert.Equal(t, `webhook "base-test": some error`, err.Error()) assert.Equal(t, inner, err.Unwrap()) } + +func TestIsAlwaysDenyError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + { + name: "unprocessable entity invalid response", + err: NewInvalidResponseError("test", fmt.Errorf("unprocessable"), 422), + want: true, + }, + { + name: "other invalid response status", + err: NewInvalidResponseError("test", fmt.Errorf("bad request"), 400), + want: false, + }, + { + name: "invalid response without status", + err: NewInvalidResponseError("test", fmt.Errorf("decode error"), 0), + want: false, + }, + { + name: "non invalid response error", + err: NewNetworkError("test", fmt.Errorf("network")), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, IsAlwaysDenyError(tt.err)) + }) + } +} diff --git a/pkg/webhook/mutating/middleware.go b/pkg/webhook/mutating/middleware.go index ba7f116f3c..23fa6777ac 100644 --- a/pkg/webhook/mutating/middleware.go +++ b/pkg/webhook/mutating/middleware.go @@ -171,6 +171,12 @@ func executeSingleMutation( resp, err := exec.client.CallMutating(ctx, whReq) if err != nil { + if webhook.IsAlwaysDenyError(err) { + slog.Info("Mutating webhook denied request due to HTTP 422 response", "webhook", whName, "error", err) + sendErrorResponse(w, http.StatusUnprocessableEntity, "Request denied by webhook policy", msgID) + return nil, err + } + if exec.config.FailurePolicy == webhook.FailurePolicyIgnore { slog.Warn("Mutating webhook error ignored due to fail-open policy", "webhook", whName, "error", err) return currentBody, nil diff --git a/pkg/webhook/mutating/middleware_test.go b/pkg/webhook/mutating/middleware_test.go index 2a79216cef..1086f06a32 100644 --- a/pkg/webhook/mutating/middleware_test.go +++ b/pkg/webhook/mutating/middleware_test.go @@ -207,6 +207,47 @@ func TestMutatingMiddleware_WebhookError_IgnorePolicy(t *testing.T) { assert.JSONEq(t, reqBody, string(capturedBody)) } +//nolint:paralleltest // Uses httptest server. +func TestMutatingMiddleware_HTTP422AlwaysDenies(t *testing.T) { + tests := []struct { + name string + failurePolicy webhook.FailurePolicy + }{ + { + name: "fail policy", + failurePolicy: webhook.FailurePolicyFail, + }, + { + name: "ignore policy", + failurePolicy: webhook.FailurePolicyIgnore, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte("unprocessable request")) + })) + defer server.Close() + + cfg := makeConfig(server.URL, tt.failurePolicy) + mw := createMutatingHandler(makeExecutors(t, []webhook.Config{cfg}), "srv", "stdio") + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, makeMCPRequest(t, []byte(`{"jsonrpc":"2.0","id":1}`))) + + assert.False(t, nextCalled) + assert.Equal(t, http.StatusUnprocessableEntity, rr.Code) + assert.Contains(t, rr.Body.String(), "Request denied by webhook policy") + }) + } +} + func TestMutatingMiddleware_ScopeViolation_FailPolicy(t *testing.T) { t.Parallel() // Webhook tries to patch /principal/email — security violation. diff --git a/pkg/webhook/validating/middleware.go b/pkg/webhook/validating/middleware.go index 6a13b81720..582709796a 100644 --- a/pkg/webhook/validating/middleware.go +++ b/pkg/webhook/validating/middleware.go @@ -125,6 +125,13 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s resp, err := exec.client.Call(r.Context(), whReq) if err != nil { + if webhook.IsAlwaysDenyError(err) { + slog.Info("Validating webhook denied request due to HTTP 422 response", + "webhook", whName, "error", err) + sendErrorResponse(w, http.StatusForbidden, "Request denied by policy", parsedMCP.ID) + return + } + // Handle error based on failure policy if exec.config.FailurePolicy == webhook.FailurePolicyIgnore { slog.Warn("Validating webhook error ignored due to fail-open policy", diff --git a/pkg/webhook/validating/middleware_test.go b/pkg/webhook/validating/middleware_test.go index 0051345bed..5a9dc143c6 100644 --- a/pkg/webhook/validating/middleware_test.go +++ b/pkg/webhook/validating/middleware_test.go @@ -356,6 +356,66 @@ func TestCreateMiddleware(t *testing.T) { require.NoError(t, mw.Close()) } +//nolint:paralleltest // Uses httptest server. +func TestValidatingMiddleware_HTTP422AlwaysDenies(t *testing.T) { + tests := []struct { + name string + failurePolicy webhook.FailurePolicy + }{ + { + name: "fail policy", + failurePolicy: webhook.FailurePolicyFail, + }, + { + name: "ignore policy", + failurePolicy: webhook.FailurePolicyIgnore, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte("unprocessable request")) + })) + defer server.Close() + + cfg := webhook.Config{ + Name: "test-webhook", + URL: server.URL, + Timeout: webhook.DefaultTimeout, + FailurePolicy: tt.failurePolicy, + TLSConfig: &webhook.TLSConfig{ + InsecureSkipVerify: true, + }, + } + + client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil) + require.NoError(t, err) + + mw := createValidatingHandler([]clientExecutor{{client: client, config: cfg}}, "test-server", "stdio") + + reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) + ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1}) + req = req.WithContext(ctx) + + var nextCalled bool + nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + nextCalled = true + }) + + rr := httptest.NewRecorder() + mw(nextHandler).ServeHTTP(rr, req) + + assert.False(t, nextCalled) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "Request denied by policy") + }) + } +} + //nolint:paralleltest // Shares a mock HTTP server and lastRequest state func TestMultiWebhookChain(t *testing.T) { // Setup mock webhook servers