@@ -803,3 +803,202 @@ func TestExtractSourceWithHeaders(t *testing.T) {
803803 assert .Equal (t , "TestAgent/1.0" , source .Extra [SourceExtraKeyUserAgent ])
804804 assert .Equal (t , "req-12345" , source .Extra [SourceExtraKeyRequestID ])
805805}
806+
807+ func TestErrorDetectionBodyCapture (t * testing.T ) {
808+ t .Parallel ()
809+
810+ t .Run ("captures prefix when DetectApplicationErrors is enabled" , func (t * testing.T ) {
811+ t .Parallel ()
812+ detectErrors := true
813+ config := & Config {
814+ DetectApplicationErrors : & detectErrors ,
815+ }
816+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
817+ require .NoError (t , err )
818+
819+ rw := & responseWriter {
820+ ResponseWriter : httptest .NewRecorder (),
821+ statusCode : http .StatusOK ,
822+ auditor : auditor ,
823+ errorDetectionBody : & bytes.Buffer {},
824+ }
825+
826+ responseData := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"test error"}}`
827+ _ , err = rw .Write ([]byte (responseData ))
828+ require .NoError (t , err )
829+
830+ assert .Equal (t , responseData , rw .errorDetectionBody .String ())
831+ })
832+
833+ t .Run ("does not capture when DetectApplicationErrors is disabled" , func (t * testing.T ) {
834+ t .Parallel ()
835+ detectErrors := false
836+ config := & Config {
837+ DetectApplicationErrors : & detectErrors ,
838+ }
839+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
840+ require .NoError (t , err )
841+
842+ rw := & responseWriter {
843+ ResponseWriter : httptest .NewRecorder (),
844+ statusCode : http .StatusOK ,
845+ auditor : auditor ,
846+ // errorDetectionBody is nil when detection is disabled
847+ }
848+
849+ _ , err = rw .Write ([]byte (`{"error":{"code":-32603}}` ))
850+ require .NoError (t , err )
851+
852+ assert .Nil (t , rw .errorDetectionBody )
853+ })
854+
855+ t .Run ("truncates capture at buffer size limit" , func (t * testing.T ) {
856+ t .Parallel ()
857+ detectErrors := true
858+ config := & Config {
859+ DetectApplicationErrors : & detectErrors ,
860+ }
861+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
862+ require .NoError (t , err )
863+
864+ rw := & responseWriter {
865+ ResponseWriter : httptest .NewRecorder (),
866+ statusCode : http .StatusOK ,
867+ auditor : auditor ,
868+ errorDetectionBody : & bytes.Buffer {},
869+ }
870+
871+ // Write more than errorDetectionBufferSize bytes
872+ largeData := bytes .Repeat ([]byte ("x" ), errorDetectionBufferSize + 100 )
873+ _ , err = rw .Write (largeData )
874+ require .NoError (t , err )
875+
876+ assert .Equal (t , errorDetectionBufferSize , rw .errorDetectionBody .Len ())
877+ })
878+
879+ t .Run ("captures independently of IncludeResponseData" , func (t * testing.T ) {
880+ t .Parallel ()
881+ detectErrors := true
882+ config := & Config {
883+ IncludeResponseData : false ,
884+ DetectApplicationErrors : & detectErrors ,
885+ }
886+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
887+ require .NoError (t , err )
888+
889+ rw := & responseWriter {
890+ ResponseWriter : httptest .NewRecorder (),
891+ statusCode : http .StatusOK ,
892+ auditor : auditor ,
893+ errorDetectionBody : & bytes.Buffer {},
894+ // body is nil because IncludeResponseData is false
895+ }
896+
897+ responseData := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"unauthorized"}}`
898+ _ , err = rw .Write ([]byte (responseData ))
899+ require .NoError (t , err )
900+
901+ // errorDetectionBody should capture even though body is nil
902+ assert .Equal (t , responseData , rw .errorDetectionBody .String ())
903+ assert .Nil (t , rw .body )
904+ })
905+ }
906+
907+ func TestMiddlewareDetectsJSONRPCErrors (t * testing.T ) {
908+ t .Parallel ()
909+
910+ t .Run ("overrides outcome to application_error for JSON-RPC error response" , func (t * testing.T ) {
911+ t .Parallel ()
912+ var logBuf bytes.Buffer
913+ detectErrors := true
914+ config := & Config {
915+ DetectApplicationErrors : & detectErrors ,
916+ }
917+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
918+ require .NoError (t , err )
919+ auditor .auditLogger = NewAuditLogger (& logBuf )
920+
921+ errorResponse := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"GitLab API error: 401 Unauthorized"}}`
922+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
923+ w .WriteHeader (http .StatusOK )
924+ _ , err := w .Write ([]byte (errorResponse ))
925+ require .NoError (t , err )
926+ })
927+
928+ middleware := auditor .Middleware (handler )
929+ req := httptest .NewRequest ("POST" , "/mcp" , strings .NewReader (`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}` ))
930+ req .Header .Set ("Content-Type" , "application/json" )
931+ rr := httptest .NewRecorder ()
932+
933+ middleware .ServeHTTP (rr , req )
934+
935+ // The response should still be passed through unchanged
936+ assert .Equal (t , http .StatusOK , rr .Code )
937+ assert .Equal (t , errorResponse , rr .Body .String ())
938+
939+ // The audit log should contain application_error
940+ logOutput := logBuf .String ()
941+ assert .Contains (t , logOutput , OutcomeApplicationError )
942+ assert .Contains (t , logOutput , "jsonrpc_error_code" )
943+ })
944+
945+ t .Run ("keeps outcome=success for valid JSON-RPC result" , func (t * testing.T ) {
946+ t .Parallel ()
947+ var logBuf bytes.Buffer
948+ detectErrors := true
949+ config := & Config {
950+ DetectApplicationErrors : & detectErrors ,
951+ }
952+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
953+ require .NoError (t , err )
954+ auditor .auditLogger = NewAuditLogger (& logBuf )
955+
956+ successResponse := `{"jsonrpc":"2.0","id":"1","result":{"content":[{"type":"text","text":"hello"}]}}`
957+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
958+ w .WriteHeader (http .StatusOK )
959+ _ , err := w .Write ([]byte (successResponse ))
960+ require .NoError (t , err )
961+ })
962+
963+ middleware := auditor .Middleware (handler )
964+ req := httptest .NewRequest ("POST" , "/mcp" , strings .NewReader (`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}` ))
965+ req .Header .Set ("Content-Type" , "application/json" )
966+ rr := httptest .NewRecorder ()
967+
968+ middleware .ServeHTTP (rr , req )
969+
970+ assert .Equal (t , http .StatusOK , rr .Code )
971+
972+ logOutput := logBuf .String ()
973+ assert .NotContains (t , logOutput , OutcomeApplicationError )
974+ })
975+
976+ t .Run ("does not inspect body when DetectApplicationErrors is disabled" , func (t * testing.T ) {
977+ t .Parallel ()
978+ var logBuf bytes.Buffer
979+ detectErrors := false
980+ config := & Config {
981+ DetectApplicationErrors : & detectErrors ,
982+ }
983+ auditor , err := NewAuditorWithTransport (config , "streamable-http" )
984+ require .NoError (t , err )
985+ auditor .auditLogger = NewAuditLogger (& logBuf )
986+
987+ errorResponse := `{"jsonrpc":"2.0","id":"1","error":{"code":-32603,"message":"should not be detected"}}`
988+ handler := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
989+ w .WriteHeader (http .StatusOK )
990+ _ , err := w .Write ([]byte (errorResponse ))
991+ require .NoError (t , err )
992+ })
993+
994+ middleware := auditor .Middleware (handler )
995+ req := httptest .NewRequest ("POST" , "/mcp" , strings .NewReader (`{"jsonrpc":"2.0","id":"1","method":"tools/call","params":{"name":"test"}}` ))
996+ req .Header .Set ("Content-Type" , "application/json" )
997+ rr := httptest .NewRecorder ()
998+
999+ middleware .ServeHTTP (rr , req )
1000+
1001+ logOutput := logBuf .String ()
1002+ assert .NotContains (t , logOutput , OutcomeApplicationError )
1003+ })
1004+ }
0 commit comments