Skip to content

Commit f6d9ee2

Browse files
author
Jamie Tanna
committed
sq
1 parent 8519112 commit f6d9ee2

3 files changed

Lines changed: 133 additions & 10 deletions

File tree

oapi_validate.go

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type ErrorHandlerOptsMatchedRoute struct {
4646
// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of a validation error being returned by the middleware
4747
type ErrorHandlerOpts struct {
4848
// TODO
49+
// NOTE that this will be nil if there is no matched route (i.e. it's a **??**)
4950
MatchedRoute *ErrorHandlerOptsMatchedRoute
5051

5152
// Error is the underlying error that triggered this error handler to be executed.
@@ -77,6 +78,7 @@ type Options struct {
7778
// MultiErrorHandler is called when there is an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) returned by the `openapi3filter`.
7879
//
7980
// If not provided `defaultMultiErrorHandler` will be used.
81+
// NOTE nto called if ErrorHandlerWithOpts
8082
MultiErrorHandler MultiErrorHandler
8183
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
8284
SilenceServersWarning bool
@@ -171,16 +173,27 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
171173
return
172174
}
173175

174-
// // TODO
175-
// me := openapi3.MultiError{}
176-
// if errors.As(err, &me) {
177-
// fmt.Printf("me: %v\n", me)
178-
// // errFunc := getMultiErrorHandlerFromOptions(options)
179-
// // return errFunc(me)
180-
// }
181-
// // TODO
182-
183176
switch e := err.(type) {
177+
case openapi3.MultiError:
178+
errOpts.Error = e
179+
errOpts.StatusCode = determineStatusCodeForMultiError(e)
180+
181+
// for _, err := range e {
182+
// errs = append(errs, err)
183+
// }
184+
//
185+
// errOpts.Error = errors.Join(errs...)
186+
//
187+
// fmt.Printf("e: %#v\n", e)
188+
// // there's no point returning a MultiError if there's a singleton error
189+
// if len(e) == 1 {
190+
// for _, ee := range e {
191+
// errOpts.Error = ee
192+
// fmt.Printf("ee: %#v\n", ee)
193+
// }
194+
//
195+
// }
196+
184197
case *openapi3filter.RequestError:
185198
// We've got a bad request
186199
errOpts.Error = e
@@ -267,3 +280,35 @@ func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
267280
func defaultMultiErrorHandler(me openapi3.MultiError) (int, error) {
268281
return http.StatusBadRequest, me
269282
}
283+
284+
func determineStatusCodeForMultiError(errs openapi3.MultiError) int {
285+
numRequestErrors := 0
286+
numSecurityRequirementsErrors := 0
287+
288+
for _, err := range errs {
289+
switch err.(type) {
290+
case *openapi3filter.RequestError:
291+
numRequestErrors++
292+
case *openapi3filter.SecurityRequirementsError:
293+
numSecurityRequirementsErrors++
294+
default:
295+
// if we have /any/ unknown error types, we should suggest returning an HTTP 500 Internal Server Error
296+
return http.StatusInternalServerError
297+
}
298+
}
299+
300+
if numRequestErrors > 0 && numSecurityRequirementsErrors > 0 {
301+
return http.StatusInternalServerError
302+
}
303+
304+
if numRequestErrors > 0 {
305+
return http.StatusBadRequest
306+
}
307+
308+
if numSecurityRequirementsErrors > 0 {
309+
return http.StatusUnauthorized
310+
}
311+
312+
// we shouldn't hit this, but to be safe, return an HTTP 500 Internal Server Error if we don't have any cases above
313+
return http.StatusInternalServerError
314+
}

oapi_validate_example_test.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"io"
910
"net/http"
@@ -729,6 +730,8 @@ paths:
729730
if childErr := e.Unwrap(); childErr != nil {
730731
out += "There was a child error, which was "
731732
switch e := childErr.(type) {
733+
case openapi3.MultiError:
734+
out += "a MultiError" + errors.Join(e).Error()
732735
case *openapi3.SchemaError:
733736
out += "a SchemaError, which failed to validate on the " + e.SchemaField + " field"
734737
default:
@@ -740,6 +743,16 @@ paths:
740743

741744
http.Error(w, "A bad request was made - but I'm not going to tell you where or how", opts.StatusCode)
742745
return
746+
// NOTE that when it's a MultiError, there's more work needed here
747+
case openapi3.MultiError:
748+
for _, eee := range e {
749+
fmt.Printf("eee: %v\n", eee)
750+
}
751+
http.Error(w, fmt.Sprintf("There were %d errors with that request - try again?", len(e)), opts.StatusCode)
752+
return
753+
// TODO
754+
// http.Error(w, "MULTI"+errors.Join(e).Error(), opts.StatusCode)
755+
// return
743756
}
744757

745758
http.Error(w, err.Error(), opts.StatusCode)
@@ -803,7 +816,7 @@ paths:
803816
// Received an HTTP 400 response. Expected HTTP 400
804817
// Response body: A bad request was made - but I'm not going to tell you where or how
805818
//
806-
// # A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body), and is then logged by the ErrorHandlerWithOpts
819+
// # A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body, with multiple issues), and is then logged by the ErrorHandlerWithOpts
807820
// ErrorHandlerWithOpts: A RequestError was returned when attempting to validate the request to POST /resource: request body has an error: doesn't match schema: Error at "/id": minimum string length is 100
808821
// Schema:
809822
// {

oapi_validate_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package nethttpmiddleware
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/getkin/kin-openapi/openapi3filter"
8+
)
9+
10+
func Test_determineStatusCodeForMultiError(t *testing.T) {
11+
t.Run("returns HTTP 400 Bad Request when only `RequestError`s", func(t *testing.T) {
12+
errs := []error{
13+
&openapi3filter.RequestError{},
14+
&openapi3filter.RequestError{},
15+
}
16+
17+
expected := 400
18+
actual := determineStatusCodeForMultiError(errs)
19+
20+
if expected != actual {
21+
t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual)
22+
}
23+
})
24+
25+
t.Run("returns HTTP 401 Unauthorized when only `SecurityRequirementsError`s", func(t *testing.T) {
26+
errs := []error{
27+
&openapi3filter.SecurityRequirementsError{},
28+
&openapi3filter.SecurityRequirementsError{},
29+
}
30+
31+
expected := 401
32+
actual := determineStatusCodeForMultiError(errs)
33+
34+
if expected != actual {
35+
t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual)
36+
}
37+
})
38+
39+
t.Run("returns HTTP 500 Internal Server Error when mixed error types", func(t *testing.T) {
40+
errs := []error{
41+
&openapi3filter.RequestError{},
42+
&openapi3filter.SecurityRequirementsError{},
43+
}
44+
45+
expected := 500
46+
actual := determineStatusCodeForMultiError(errs)
47+
48+
if expected != actual {
49+
t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual)
50+
}
51+
})
52+
53+
t.Run("returns HTTP 500 Internal Server Error when unknown error type(s) are seen", func(t *testing.T) {
54+
errs := []error{
55+
fmt.Errorf("this isn't a known error type"),
56+
}
57+
58+
expected := 500
59+
actual := determineStatusCodeForMultiError(errs)
60+
61+
if expected != actual {
62+
t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual)
63+
}
64+
})
65+
}

0 commit comments

Comments
 (0)