Skip to content

Commit d39f6e6

Browse files
Jamie TannapebomikeschinkelMattiasMartens
committed
feat: add error handler with more configuration
As a means to **??**o This has been a long-standing issue **??** With thanks to Per, Mike and MattiasMartens who have **??**, as well as many others in the past who have **??** Closes #11, #27. Co-authored-by: Per Bockman <per.bockman@gmail.com> Co-authored-by: Mike Schinkel <mike@newclarity.net> Co-authored-by: MattiasMartens <33037520+MattiasMartens@users.noreply.github.com>
1 parent b73ed97 commit d39f6e6

3 files changed

Lines changed: 729 additions & 14 deletions

File tree

oapi_validate.go

Lines changed: 176 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package nethttpmiddleware
99

1010
import (
11+
"context"
1112
"errors"
1213
"fmt"
1314
"log"
@@ -21,8 +22,42 @@ import (
2122
)
2223

2324
// ErrorHandler is called when there is an error in validation
25+
//
26+
// NOTE that you likely want to use ErrorHandlerWithOpts, as it provides more **??** and access to the underlying `error`
27+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence
2428
type ErrorHandler func(w http.ResponseWriter, message string, statusCode int)
2529

30+
// ErrorHandlerWithOpts provides
31+
// NOTE that this should ideally be used instead of ErrorHandler
32+
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence
33+
type ErrorHandlerWithOpts func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts)
34+
35+
type ErrorHandlerOptsMatchedRoute struct {
36+
// Route indicates the Route that this error is received by, providing full context into the OpenAPI Spec TODO
37+
// TODO
38+
// NOTE nil when not found
39+
Route *routers.Route
40+
41+
// TODO
42+
// NOTE nil/empty when not found
43+
PathParams map[string]string
44+
}
45+
46+
// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of a validation error being returned by the middleware
47+
type ErrorHandlerOpts struct {
48+
// TODO
49+
// NOTE that this will be nil if there is no matched route (i.e. it's a **??**)
50+
MatchedRoute *ErrorHandlerOptsMatchedRoute
51+
52+
// Error is the underlying error that triggered this error handler to be executed.
53+
Error error
54+
55+
// StatusCode indicates the HTTP Status Code that TODO
56+
//
57+
// NOTE that this is a suggestion, and can be ignored if believed to be TODO
58+
StatusCode int
59+
}
60+
2661
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
2762
type MultiErrorHandler func(openapi3.MultiError) (int, error)
2863

@@ -32,11 +67,18 @@ type Options struct {
3267
Options openapi3filter.Options
3368
// ErrorHandler is called when a validation error occurs.
3469
//
70+
// TODO
71+
//
3572
// If not provided, `http.Error` will be called
3673
ErrorHandler ErrorHandler
74+
75+
// TODO
76+
ErrorHandlerWithOpts ErrorHandlerWithOpts
77+
3778
// MultiErrorHandler is called when there is an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) returned by the `openapi3filter`.
3879
//
3980
// If not provided `defaultMultiErrorHandler` will be used.
81+
// NOTE nto called if ErrorHandlerWithOpts
4082
MultiErrorHandler MultiErrorHandler
4183
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
4284
SilenceServersWarning bool
@@ -62,27 +104,115 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
62104

63105
return func(next http.Handler) http.Handler {
64106
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65-
// validate request
66-
statusCode, err := validateRequest(r, router, options)
67-
if err == nil {
68-
// serve
69-
next.ServeHTTP(w, r)
70-
return
71-
}
72-
73107
if options == nil {
74-
http.Error(w, err.Error(), statusCode)
75-
return
76-
}
77-
78-
if options.ErrorHandler != nil {
79-
options.ErrorHandler(w, err.Error(), statusCode)
108+
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
109+
} else if options.ErrorHandlerWithOpts != nil {
110+
performRequestValidationForErrorHandlerWithOpts(next, w, r, router, options)
111+
} else if options.ErrorHandler != nil {
112+
performRequestValidationForErrorHandler(next, w, r, router, options, options.ErrorHandler)
113+
} else {
114+
// NOTE that this shouldn't happen, but let's be sure that we always end up calling the default error handler if no other handler is defined
115+
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
80116
}
81117
})
82118
}
83119

84120
}
85121

122+
func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) {
123+
// validate request
124+
statusCode, err := validateRequest(r, router, options)
125+
if err == nil {
126+
// serve
127+
next.ServeHTTP(w, r)
128+
return
129+
}
130+
131+
errorHandler(w, err.Error(), statusCode)
132+
}
133+
134+
// **??**
135+
// Note that this is an inline-and-modified version of `validateRequest` that **??**.
136+
func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) {
137+
// Find route
138+
route, pathParams, err := router.FindRoute(r)
139+
if err != nil {
140+
errOpts := ErrorHandlerOpts{
141+
// MatchedRoute will be nil, as we've not matched a route we know about
142+
Error: err,
143+
StatusCode: http.StatusNotFound,
144+
}
145+
146+
options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
147+
return
148+
}
149+
150+
errOpts := ErrorHandlerOpts{
151+
MatchedRoute: &ErrorHandlerOptsMatchedRoute{
152+
Route: route,
153+
PathParams: pathParams,
154+
},
155+
// other options will be added before executing
156+
}
157+
158+
// Validate request
159+
requestValidationInput := &openapi3filter.RequestValidationInput{
160+
Request: r,
161+
PathParams: pathParams,
162+
Route: route,
163+
}
164+
165+
if options != nil {
166+
requestValidationInput.Options = &options.Options
167+
}
168+
169+
err = openapi3filter.ValidateRequest(r.Context(), requestValidationInput)
170+
if err == nil {
171+
// it's a valid request, so serve it
172+
next.ServeHTTP(w, r)
173+
return
174+
}
175+
176+
switch e := err.(type) {
177+
case openapi3.MultiError:
178+
errOpts.Error = e
179+
errOpts.StatusCode = determineStatusCodeForMultiError(e)
180+
181+
// for _, err := range e {
182+
// errs = append(errs, err)
183+
// }
184+
//
185+
// errOpts.Error = errors.Join(errs...)
186+
//
187+
// fmt.Printf("e: %#v\n", e)
188+
// // there's no point returning a MultiError if there's a singleton error
189+
// if len(e) == 1 {
190+
// for _, ee := range e {
191+
// errOpts.Error = ee
192+
// fmt.Printf("ee: %#v\n", ee)
193+
// }
194+
//
195+
// }
196+
197+
case *openapi3filter.RequestError:
198+
// We've got a bad request
199+
errOpts.Error = e
200+
errOpts.StatusCode = http.StatusBadRequest
201+
case *openapi3filter.SecurityRequirementsError:
202+
// return http.StatusUnauthorized, err
203+
errOpts.Error = e
204+
errOpts.StatusCode = http.StatusUnauthorized
205+
default:
206+
// This should never happen today, but if our upstream code changes,
207+
// we don't want to crash the server, so handle the unexpected error.
208+
// return http.StatusInternalServerError,
209+
errOpts.Error = fmt.Errorf("error validating route: %w", e)
210+
errOpts.StatusCode = http.StatusUnauthorized
211+
}
212+
213+
options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
214+
}
215+
86216
// validateRequest is called from the middleware above and actually does the work
87217
// of validating a request.
88218
func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
@@ -150,3 +280,35 @@ func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
150280
func defaultMultiErrorHandler(me openapi3.MultiError) (int, error) {
151281
return http.StatusBadRequest, me
152282
}
283+
284+
func determineStatusCodeForMultiError(errs openapi3.MultiError) int {
285+
numRequestErrors := 0
286+
numSecurityRequirementsErrors := 0
287+
288+
for _, err := range errs {
289+
switch err.(type) {
290+
case *openapi3filter.RequestError:
291+
numRequestErrors++
292+
case *openapi3filter.SecurityRequirementsError:
293+
numSecurityRequirementsErrors++
294+
default:
295+
// if we have /any/ unknown error types, we should suggest returning an HTTP 500 Internal Server Error
296+
return http.StatusInternalServerError
297+
}
298+
}
299+
300+
if numRequestErrors > 0 && numSecurityRequirementsErrors > 0 {
301+
return http.StatusInternalServerError
302+
}
303+
304+
if numRequestErrors > 0 {
305+
return http.StatusBadRequest
306+
}
307+
308+
if numSecurityRequirementsErrors > 0 {
309+
return http.StatusUnauthorized
310+
}
311+
312+
// we shouldn't hit this, but to be safe, return an HTTP 500 Internal Server Error if we don't have any cases above
313+
return http.StatusInternalServerError
314+
}

0 commit comments

Comments
 (0)