Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ type ErrorHandlerOptsMatchedRoute struct {
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
type MultiErrorHandler func(openapi3.MultiError) (int, error)

// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware
//
// Return `true` if the request should be skipped
type Skipper func(r *http.Request) bool

// Options allows configuring the OapiRequestValidator.
type Options struct {
// Options contains any configuration for the underlying `openapi3filter`
Expand Down Expand Up @@ -103,6 +108,9 @@ type Options struct {
// Prefix allows (optionally) trimming a prefix from the API path.
// This may be useful if your API is routed to an internal path that is different from the OpenAPI specification.
Prefix string

// Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware
Skipper Skipper
}

// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
Expand All @@ -129,6 +137,11 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if options != nil && options.Skipper != nil && options.Skipper(r) {
next.ServeHTTP(w, r)
return
}

if options == nil {
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
} else if options.ErrorHandlerWithOpts != nil {
Expand Down
139 changes: 139 additions & 0 deletions oapi_validate_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,3 +957,142 @@ paths:
// POST /public-api/v1/resource was called
// Received an HTTP 204 response. Expected HTTP 204
}

func ExampleOapiRequestValidatorWithOptions_withSkipper() {
rawSpec := `
openapi: "3.0.0"
info:
version: 1.0.0
title: TestServer
servers:
- url: http://example.com/
paths:
# we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests
/resource:
post:
operationId: createResource
responses:
'204':
description: No content
requestBody:
required: true
content:
text/plain: {}
`

must := func(err error) {
if err != nil {
panic(err)
}
}

use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler {
var s http.Handler
s = r

for _, mw := range middlewares {
s = mw(s)
}

return s
}

logResponseBody := func(rr *httptest.ResponseRecorder) {
if rr.Result().Body != nil {
data, _ := io.ReadAll(rr.Result().Body)
if len(data) > 0 {
fmt.Printf("Response body: %s", data)
}
}
}

spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec))
must(err)

// NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec
// See also: Options#SilenceServersWarning
spec.Servers = nil

router := http.NewServeMux()

router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("%s /resource was called\n", r.Method)

if r.Method == http.MethodPost {
data, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Printf("Request body: %s\n", data)
w.WriteHeader(http.StatusNoContent)
return
}

w.WriteHeader(http.StatusMethodNotAllowed)
})

router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error {
fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName)
return fmt.Errorf("this check always fails - don't let anyone in!")
}

skipperFunc := func(r *http.Request) bool {
// skip the undocumented healthcheck endpoint
if r.URL.Path == "/healthz" {
return true
}

return false
}

// create middleware
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Options: openapi3filter.Options{
AuthenticationFunc: authenticationFunc,
},
Skipper: skipperFunc,
})

// then wire it in
server := use(router, mw)

// ================================================================================
fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated")

req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody)
must(err)

rr := httptest.NewRecorder()

server.ServeHTTP(rr, req)

fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code)
logResponseBody(rr)

// ================================================================================
fmt.Println("# A request that is well-formed is passed through to the Handler")

req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there")))
must(err)
req.Header.Set("Content-Type", "text/plain")

rr = httptest.NewRecorder()

server.ServeHTTP(rr, req)

fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code)
logResponseBody(rr)

// Output:
// # A request that is made to the undocumented healthcheck endpoint does not get validated
// Received an HTTP 200 response. Expected HTTP 200
// # A request that is well-formed is passed through to the Handler
// POST /resource was called
// Request body: Hello there
// Received an HTTP 204 response. Expected HTTP 204
}
Loading