Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 87 additions & 1 deletion pkg/audit/auditor.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,26 @@
return a.transportType == types.TransportTypeSSE.String()
}

// errorDetectionBufferSize is the maximum number of bytes buffered from the
// response body for JSON-RPC error detection. JSON-RPC error responses have
// the "error" field near the top of the object, so a small prefix is
// sufficient. This buffer is allocated independently of IncludeResponseData.
const errorDetectionBufferSize = 512

// maxAuditErrorMessageLength caps the JSON-RPC error message length stored
// in audit event metadata to keep log entries compact.
const maxAuditErrorMessageLength = 256

// responseWriter wraps http.ResponseWriter to capture response data and status.
type responseWriter struct {
http.ResponseWriter
statusCode int
body *bytes.Buffer
auditor *Auditor
// errorDetectionBody is a small prefix buffer used exclusively for
// JSON-RPC error detection. It is allocated when DetectApplicationErrors
// is true, independent of IncludeResponseData.
errorDetectionBody *bytes.Buffer
auditor *Auditor
}

func (rw *responseWriter) WriteHeader(statusCode int) {
Expand All @@ -127,6 +141,15 @@
rw.body.Write(data)
}
}
// Capture a small prefix for JSON-RPC error detection
if rw.errorDetectionBody != nil && rw.errorDetectionBody.Len() < errorDetectionBufferSize {
remaining := errorDetectionBufferSize - rw.errorDetectionBody.Len()
if len(data) <= remaining {
rw.errorDetectionBody.Write(data)
} else {
rw.errorDetectionBody.Write(data[:remaining])
}
}
return rw.ResponseWriter.Write(data)
}

Expand Down Expand Up @@ -201,6 +224,13 @@
rw.body = &bytes.Buffer{}
}

// Allocate a small prefix buffer for JSON-RPC error detection,
// independent of IncludeResponseData. When IncludeResponseData
// is already true, we reuse rw.body instead of double-buffering.
if a.config.ShouldDetectApplicationErrors() && !a.config.IncludeResponseData {
rw.errorDetectionBody = &bytes.Buffer{}
}

// Process the request
next.ServeHTTP(rw, r)

Expand All @@ -213,13 +243,36 @@
}

// logAuditEvent creates and logs an audit event for the HTTP request.
func (a *Auditor) logAuditEvent(r *http.Request, rw *responseWriter, requestData []byte, duration time.Duration) {

Check failure on line 246 in pkg/audit/auditor.go

View workflow job for this annotation

GitHub Actions / Linting / Lint Go Code

cyclomatic complexity 16 of func `(*Auditor).logAuditEvent` is high (> 15) (gocyclo)
// Determine event type based on the request
eventType := a.determineEventType(r)

// Determine outcome based on status code
outcome := a.determineOutcome(rw.statusCode)

// When HTTP status indicates success, check the response body for
// JSON-RPC errors (e.g., expired tokens wrapped inside HTTP 200).
// Reuse rw.body when IncludeResponseData is on to avoid double-buffering.
var mcpResponse *mcp.ParsedMCPResponse
if outcome == OutcomeSuccess && a.config.ShouldDetectApplicationErrors() {
var prefix []byte
if rw.body != nil && rw.body.Len() > 0 {
prefix = rw.body.Bytes()
if len(prefix) > errorDetectionBufferSize {
prefix = prefix[:errorDetectionBufferSize]
}
} else if rw.errorDetectionBody != nil && rw.errorDetectionBody.Len() > 0 {
prefix = rw.errorDetectionBody.Bytes()
}
// Only attempt JSON parse if the prefix looks like a JSON object
if len(prefix) > 0 && prefix[0] == '{' {
mcpResponse = mcp.ParseMCPResponse(prefix)
if mcpResponse.HasError {
outcome = OutcomeApplicationError
}
}
}

// Check if we should audit this event
if !a.config.ShouldAuditEvent(eventType) {
return
Expand All @@ -246,6 +299,20 @@
// Add metadata
a.addMetadata(event, r, duration, rw)

// Attach JSON-RPC error details so operators can see the error code
// and message without enabling full response data capture.
if outcome == OutcomeApplicationError {
if event.Metadata.Extra == nil {
event.Metadata.Extra = make(map[string]any)
}
event.Metadata.Extra["jsonrpc_error_code"] = mcpResponse.ErrorCode
msg := mcpResponse.ErrorMessage
if len(msg) > maxAuditErrorMessageLength {
msg = msg[:maxAuditErrorMessageLength]
}
event.Metadata.Extra["jsonrpc_error_message"] = msg
}

// Add request/response data if configured
a.addEventData(event, r, rw, requestData)

Expand Down Expand Up @@ -321,6 +388,25 @@
}
}

// detectApplicationError inspects the captured response body prefix for a
// JSON-RPC error field. It reuses rw.body when IncludeResponseData is
// enabled to avoid double-buffering.
func (*Auditor) detectApplicationError(rw *responseWriter) *mcp.ParsedMCPResponse {

Check failure on line 394 in pkg/audit/auditor.go

View workflow job for this annotation

GitHub Actions / Linting / Lint Go Code

func (*Auditor).detectApplicationError is unused (unused)
var prefix []byte
if rw.body != nil && rw.body.Len() > 0 {
prefix = rw.body.Bytes()
if len(prefix) > errorDetectionBufferSize {
prefix = prefix[:errorDetectionBufferSize]
}
} else if rw.errorDetectionBody != nil && rw.errorDetectionBody.Len() > 0 {
prefix = rw.errorDetectionBody.Bytes()
}
if len(prefix) > 0 && prefix[0] == '{' {
return mcp.ParseMCPResponse(prefix)
}
return nil
}

// extractSource extracts source information from the HTTP request.
func (a *Auditor) extractSource(r *http.Request) EventSource {
// Get the client IP address
Expand Down
199 changes: 199 additions & 0 deletions pkg/audit/auditor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,202 @@ func TestExtractSourceWithHeaders(t *testing.T) {
assert.Equal(t, "TestAgent/1.0", source.Extra[SourceExtraKeyUserAgent])
assert.Equal(t, "req-12345", source.Extra[SourceExtraKeyRequestID])
}

func TestErrorDetectionBodyCapture(t *testing.T) {
t.Parallel()

t.Run("captures prefix when DetectApplicationErrors is enabled", func(t *testing.T) {
t.Parallel()
detectErrors := true
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)

rw := &responseWriter{
ResponseWriter: httptest.NewRecorder(),
statusCode: http.StatusOK,
auditor: auditor,
errorDetectionBody: &bytes.Buffer{},
}

responseData := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"test error"}}`
_, err = rw.Write([]byte(responseData))
require.NoError(t, err)

assert.Equal(t, responseData, rw.errorDetectionBody.String())
})

t.Run("does not capture when DetectApplicationErrors is disabled", func(t *testing.T) {
t.Parallel()
detectErrors := false
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)

rw := &responseWriter{
ResponseWriter: httptest.NewRecorder(),
statusCode: http.StatusOK,
auditor: auditor,
// errorDetectionBody is nil when detection is disabled
}

_, err = rw.Write([]byte(`{"error":{"code":-32603}}`))
require.NoError(t, err)

assert.Nil(t, rw.errorDetectionBody)
})

t.Run("truncates capture at buffer size limit", func(t *testing.T) {
t.Parallel()
detectErrors := true
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)

rw := &responseWriter{
ResponseWriter: httptest.NewRecorder(),
statusCode: http.StatusOK,
auditor: auditor,
errorDetectionBody: &bytes.Buffer{},
}

// Write more than errorDetectionBufferSize bytes
largeData := bytes.Repeat([]byte("x"), errorDetectionBufferSize+100)
_, err = rw.Write(largeData)
require.NoError(t, err)

assert.Equal(t, errorDetectionBufferSize, rw.errorDetectionBody.Len())
})

t.Run("captures independently of IncludeResponseData", func(t *testing.T) {
t.Parallel()
detectErrors := true
config := &Config{
IncludeResponseData: false,
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)

rw := &responseWriter{
ResponseWriter: httptest.NewRecorder(),
statusCode: http.StatusOK,
auditor: auditor,
errorDetectionBody: &bytes.Buffer{},
// body is nil because IncludeResponseData is false
}

responseData := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"unauthorized"}}`
_, err = rw.Write([]byte(responseData))
require.NoError(t, err)

// errorDetectionBody should capture even though body is nil
assert.Equal(t, responseData, rw.errorDetectionBody.String())
assert.Nil(t, rw.body)
})
}

func TestMiddlewareDetectsJSONRPCErrors(t *testing.T) {
t.Parallel()

t.Run("overrides outcome to application_error for JSON-RPC error response", func(t *testing.T) {
t.Parallel()
var logBuf bytes.Buffer
detectErrors := true
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)
auditor.auditLogger = NewAuditLogger(&logBuf)

errorResponse := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"GitLab API error: 401 Unauthorized"}}`
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(errorResponse))
require.NoError(t, err)
})

middleware := auditor.Middleware(handler)
req := httptest.NewRequest("POST", "/mcp", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}`))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()

middleware.ServeHTTP(rr, req)

// The response should still be passed through unchanged
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, errorResponse, rr.Body.String())

// The audit log should contain application_error
logOutput := logBuf.String()
assert.Contains(t, logOutput, OutcomeApplicationError)
assert.Contains(t, logOutput, "jsonrpc_error_code")
})

t.Run("keeps outcome=success for valid JSON-RPC result", func(t *testing.T) {
t.Parallel()
var logBuf bytes.Buffer
detectErrors := true
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)
auditor.auditLogger = NewAuditLogger(&logBuf)

successResponse := `{"jsonrpc":"2.0","id":"1","result":{"content":[{"type":"text","text":"hello"}]}}`
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(successResponse))
require.NoError(t, err)
})

middleware := auditor.Middleware(handler)
req := httptest.NewRequest("POST", "/mcp", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}`))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()

middleware.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)

logOutput := logBuf.String()
assert.NotContains(t, logOutput, OutcomeApplicationError)
})

t.Run("does not inspect body when DetectApplicationErrors is disabled", func(t *testing.T) {
t.Parallel()
var logBuf bytes.Buffer
detectErrors := false
config := &Config{
DetectApplicationErrors: &detectErrors,
}
auditor, err := NewAuditorWithTransport(config, "streamable-http")
require.NoError(t, err)
auditor.auditLogger = NewAuditLogger(&logBuf)

errorResponse := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"should not be detected"}}`
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte(errorResponse))
require.NoError(t, err)
})

middleware := auditor.Middleware(handler)
req := httptest.NewRequest("POST", "/mcp", strings.NewReader(`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}`))
req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()

middleware.ServeHTTP(rr, req)

logOutput := logBuf.String()
assert.NotContains(t, logOutput, OutcomeApplicationError)
})
}
Loading
Loading