Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pkg/webhook/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

package webhook

import "fmt"
import (
"errors"
"fmt"
"net/http"
)

// WebhookError is the base error type for all webhook-related errors.
//
Expand Down Expand Up @@ -92,3 +96,10 @@ func NewInvalidResponseError(webhookName string, err error, statusCode int) *Inv
StatusCode: statusCode,
}
}

// IsAlwaysDenyError reports whether the webhook error should deny the request
// regardless of the configured failure policy.
func IsAlwaysDenyError(err error) bool {
var invalidRespErr *InvalidResponseError
return errors.As(err, &invalidRespErr) && invalidRespErr.StatusCode == http.StatusUnprocessableEntity
}
38 changes: 38 additions & 0 deletions pkg/webhook/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,41 @@ func TestWebhookErrorBaseType(t *testing.T) {
assert.Equal(t, `webhook "base-test": some error`, err.Error())
assert.Equal(t, inner, err.Unwrap())
}

func TestIsAlwaysDenyError(t *testing.T) {
t.Parallel()

tests := []struct {
name string
err error
want bool
}{
{
name: "unprocessable entity invalid response",
err: NewInvalidResponseError("test", fmt.Errorf("unprocessable"), 422),
want: true,
},
{
name: "other invalid response status",
err: NewInvalidResponseError("test", fmt.Errorf("bad request"), 400),
want: false,
},
{
name: "invalid response without status",
err: NewInvalidResponseError("test", fmt.Errorf("decode error"), 0),
want: false,
},
{
name: "non invalid response error",
err: NewNetworkError("test", fmt.Errorf("network")),
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, IsAlwaysDenyError(tt.err))
})
}
}
6 changes: 6 additions & 0 deletions pkg/webhook/mutating/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ func executeSingleMutation(

resp, err := exec.client.CallMutating(ctx, whReq)
if err != nil {
if webhook.IsAlwaysDenyError(err) {
slog.Info("Mutating webhook denied request due to HTTP 422 response", "webhook", whName, "error", err)
sendErrorResponse(w, http.StatusUnprocessableEntity, "Request denied by webhook policy", msgID)
return nil, err
}

if exec.config.FailurePolicy == webhook.FailurePolicyIgnore {
slog.Warn("Mutating webhook error ignored due to fail-open policy", "webhook", whName, "error", err)
return currentBody, nil
Expand Down
41 changes: 41 additions & 0 deletions pkg/webhook/mutating/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,47 @@ func TestMutatingMiddleware_WebhookError_IgnorePolicy(t *testing.T) {
assert.JSONEq(t, reqBody, string(capturedBody))
}

//nolint:paralleltest // Uses httptest server.
func TestMutatingMiddleware_HTTP422AlwaysDenies(t *testing.T) {
tests := []struct {
name string
failurePolicy webhook.FailurePolicy
}{
{
name: "fail policy",
failurePolicy: webhook.FailurePolicyFail,
},
{
name: "ignore policy",
failurePolicy: webhook.FailurePolicyIgnore,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnprocessableEntity)
_, _ = w.Write([]byte("unprocessable request"))
}))
defer server.Close()

cfg := makeConfig(server.URL, tt.failurePolicy)
mw := createMutatingHandler(makeExecutors(t, []webhook.Config{cfg}), "srv", "stdio")

var nextCalled bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { nextCalled = true })

rr := httptest.NewRecorder()
mw(nextHandler).ServeHTTP(rr, makeMCPRequest(t, []byte(`{"jsonrpc":"2.0","id":1}`)))

assert.False(t, nextCalled)
assert.Equal(t, http.StatusUnprocessableEntity, rr.Code)
assert.Contains(t, rr.Body.String(), "Request denied by webhook policy")
})
}
}

func TestMutatingMiddleware_ScopeViolation_FailPolicy(t *testing.T) {
t.Parallel()
// Webhook tries to patch /principal/email — security violation.
Expand Down
7 changes: 7 additions & 0 deletions pkg/webhook/validating/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ func createValidatingHandler(executors []clientExecutor, serverName, transport s

resp, err := exec.client.Call(r.Context(), whReq)
if err != nil {
if webhook.IsAlwaysDenyError(err) {
slog.Info("Validating webhook denied request due to HTTP 422 response",
"webhook", whName, "error", err)
sendErrorResponse(w, http.StatusForbidden, "Request denied by policy", parsedMCP.ID)
return
}

// Handle error based on failure policy
if exec.config.FailurePolicy == webhook.FailurePolicyIgnore {
slog.Warn("Validating webhook error ignored due to fail-open policy",
Expand Down
60 changes: 60 additions & 0 deletions pkg/webhook/validating/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,66 @@ func TestCreateMiddleware(t *testing.T) {
require.NoError(t, mw.Close())
}

//nolint:paralleltest // Uses httptest server.
func TestValidatingMiddleware_HTTP422AlwaysDenies(t *testing.T) {
tests := []struct {
name string
failurePolicy webhook.FailurePolicy
}{
{
name: "fail policy",
failurePolicy: webhook.FailurePolicyFail,
},
{
name: "ignore policy",
failurePolicy: webhook.FailurePolicyIgnore,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnprocessableEntity)
_, _ = w.Write([]byte("unprocessable request"))
}))
defer server.Close()

cfg := webhook.Config{
Name: "test-webhook",
URL: server.URL,
Timeout: webhook.DefaultTimeout,
FailurePolicy: tt.failurePolicy,
TLSConfig: &webhook.TLSConfig{
InsecureSkipVerify: true,
},
}

client, err := webhook.NewClient(cfg, webhook.TypeValidating, nil)
require.NoError(t, err)

mw := createValidatingHandler([]clientExecutor{{client: client, config: cfg}}, "test-server", "stdio")

reqBody := []byte(`{"jsonrpc":"2.0","method":"tools/call","id":1}`)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, &mcp.ParsedMCPRequest{Method: "tools/call", ID: 1})
req = req.WithContext(ctx)

var nextCalled bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
nextCalled = true
})

rr := httptest.NewRecorder()
mw(nextHandler).ServeHTTP(rr, req)

assert.False(t, nextCalled)
assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "Request denied by policy")
})
}
}

//nolint:paralleltest // Shares a mock HTTP server and lastRequest state
func TestMultiWebhookChain(t *testing.T) {
// Setup mock webhook servers
Expand Down
Loading