diff --git a/oapi_validate.go b/oapi_validate.go index 5bbce40..30ccf79 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -4,6 +4,7 @@ package nethttpmiddleware import ( + "context" "errors" "fmt" "log" @@ -19,14 +20,27 @@ import ( // ErrorHandler is called when there is an error in validation type ErrorHandler func(w http.ResponseWriter, message string, statusCode int) +// ErrorHandlerWithOpts is called when there is an error in validation, if found in Options. +// Passes error-handling specific opts over and above the standard error handler. +type ErrorHandlerWithOpts func(w http.ResponseWriter, message string, statusCode int, opts ErrorHandlerOpts) + +// ErrorHandlerOpts is used with ErrorHandlerWithOpts() +type ErrorHandlerOpts struct { + *http.Request + *routers.Route + context.Context + Error error +} + // MultiErrorHandler is called when oapi returns a MultiError type type MultiErrorHandler func(openapi3.MultiError) (int, error) // Options to customize request validation, openapi3filter specified options will be passed through. type Options struct { - Options openapi3filter.Options - ErrorHandler ErrorHandler - MultiErrorHandler MultiErrorHandler + Options openapi3filter.Options + ErrorHandler ErrorHandler + ErrorHandlerWithOpts ErrorHandlerWithOpts + MultiErrorHandler MultiErrorHandler // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` SilenceServersWarning bool } @@ -51,12 +65,23 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) func return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // validate request - if statusCode, err := validateRequest(r, router, options); err != nil { - if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(w, err.Error(), statusCode) - } else { - http.Error(w, err.Error(), statusCode) + if statusCode, route, err := validateRequest(r, router, options); err != nil { + if options != nil { + if options.ErrorHandlerWithOpts != nil { + options.ErrorHandlerWithOpts(w, err.Error(), statusCode, ErrorHandlerOpts{ + Context: r.Context(), + Request: r, + Route: route, + Error: err, + }) + return + } + if options.ErrorHandler != nil { + options.ErrorHandler(w, err.Error(), statusCode) + return + } } + http.Error(w, err.Error(), statusCode) return } @@ -69,12 +94,12 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) func // validateRequest is called from the middleware above and actually does the work // of validating a request. -func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) { +func validateRequest(r *http.Request, router routers.Router, options *Options) (int, *routers.Route, error) { // Find route route, pathParams, err := router.FindRoute(r) if err != nil { - return http.StatusNotFound, err // We failed to find a matching route for the request. + return http.StatusNotFound, nil, err // We failed to find a matching route for the request. } // Validate request @@ -92,7 +117,8 @@ func validateRequest(r *http.Request, router routers.Router, options *Options) ( me := openapi3.MultiError{} if errors.As(err, &me) { errFunc := getMultiErrorHandlerFromOptions(options) - return errFunc(me) + status, err2 := errFunc(me) + return status, route, err2 } switch e := err.(type) { @@ -101,17 +127,17 @@ func validateRequest(r *http.Request, router routers.Router, options *Options) ( // Split up the verbose error by lines and return the first one // openapi errors seem to be multi-line with a decent message on the first errorLines := strings.Split(e.Error(), "\n") - return http.StatusBadRequest, fmt.Errorf(errorLines[0]) + return http.StatusBadRequest, route, fmt.Errorf(errorLines[0]) case *openapi3filter.SecurityRequirementsError: - return http.StatusUnauthorized, err + return http.StatusUnauthorized, route, err default: // This should never happen today, but if our upstream code changes, // we don't want to crash the server, so handle the unexpected error. - return http.StatusInternalServerError, fmt.Errorf("error validating route: %s", err.Error()) + return http.StatusInternalServerError, route, fmt.Errorf("error validating route: %s", err.Error()) } } - return http.StatusOK, nil + return http.StatusOK, route, nil } // attempt to get the MultiErrorHandler from the options. If it is not set,