Skip to content

Commit 96b28bd

Browse files
committed
feat: add ability to skip validation checks
1 parent 229a92f commit 96b28bd

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

oapi_validate.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ type ErrorHandlerOptsMatchedRoute struct {
7474
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
7575
type MultiErrorHandler func(openapi3.MultiError) (int, error)
7676

77+
// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware
78+
//
79+
// Return `true` if the request should be skipped
80+
type Skipper func(r *http.Request) bool
81+
7782
// Options allows configuring the OapiRequestValidator.
7883
type Options struct {
7984
// Options contains any configuration for the underlying `openapi3filter`
@@ -100,6 +105,9 @@ type Options struct {
100105
SilenceServersWarning bool
101106
// DoNotValidateServers ensures that there is no Host validation performed (see `SilenceServersWarning` and https://github.com/deepmap/oapi-codegen/issues/882 for more details)
102107
DoNotValidateServers bool
108+
109+
// Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware
110+
Skipper Skipper
103111
}
104112

105113
// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
@@ -126,6 +134,15 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
126134

127135
return func(next http.Handler) http.Handler {
128136
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
137+
if options != nil && options.Skipper != nil {
138+
// copy to make sure the body can't be tampered with
139+
r2 := r.Clone(r.Context())
140+
if options.Skipper(r2) {
141+
next.ServeHTTP(w, r)
142+
return
143+
}
144+
}
145+
129146
if options == nil {
130147
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
131148
} else if options.ErrorHandlerWithOpts != nil {

oapi_validate_example_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,146 @@ paths:
839839
// Received an HTTP 400 response. Expected HTTP 400
840840
// Response body: There was a bad request
841841
}
842+
843+
func ExampleOapiRequestValidatorWithOptions_withSkipper() {
844+
rawSpec := `
845+
openapi: "3.0.0"
846+
info:
847+
version: 1.0.0
848+
title: TestServer
849+
servers:
850+
- url: http://example.com/
851+
paths:
852+
# we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests
853+
/resource:
854+
post:
855+
operationId: createResource
856+
responses:
857+
'204':
858+
description: No content
859+
requestBody:
860+
required: true
861+
content:
862+
text/plain: {}
863+
`
864+
865+
must := func(err error) {
866+
if err != nil {
867+
panic(err)
868+
}
869+
}
870+
871+
use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler {
872+
var s http.Handler
873+
s = r
874+
875+
for _, mw := range middlewares {
876+
s = mw(s)
877+
}
878+
879+
return s
880+
}
881+
882+
logResponseBody := func(rr *httptest.ResponseRecorder) {
883+
if rr.Result().Body != nil {
884+
data, _ := io.ReadAll(rr.Result().Body)
885+
if len(data) > 0 {
886+
fmt.Printf("Response body: %s", data)
887+
}
888+
}
889+
}
890+
891+
spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec))
892+
must(err)
893+
894+
// NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec
895+
// See also: Options#SilenceServersWarning
896+
spec.Servers = nil
897+
898+
router := http.NewServeMux()
899+
900+
router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
901+
fmt.Printf("%s /resource was called\n", r.Method)
902+
903+
if r.Method == http.MethodPost {
904+
data, err := io.ReadAll(r.Body)
905+
if err != nil {
906+
w.WriteHeader(http.StatusInternalServerError)
907+
return
908+
}
909+
fmt.Printf("Request body: %s\n", data)
910+
w.WriteHeader(http.StatusNoContent)
911+
return
912+
}
913+
914+
w.WriteHeader(http.StatusMethodNotAllowed)
915+
})
916+
917+
router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
918+
w.WriteHeader(http.StatusOK)
919+
})
920+
921+
authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error {
922+
fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName)
923+
return fmt.Errorf("this check always fails - don't let anyone in!")
924+
}
925+
926+
skipperFunc := func(r *http.Request) bool {
927+
// always consume the request body, because we're not following best practices
928+
_, _ = io.ReadAll(r.Body)
929+
930+
// skip the undocumented healthcheck endpoint
931+
if r.URL.Path == "/healthz" {
932+
return true
933+
}
934+
935+
return false
936+
}
937+
938+
// create middleware
939+
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
940+
Options: openapi3filter.Options{
941+
AuthenticationFunc: authenticationFunc,
942+
},
943+
Skipper: skipperFunc,
944+
})
945+
946+
// then wire it in
947+
server := use(router, mw)
948+
949+
// ================================================================================
950+
fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated")
951+
952+
req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody)
953+
must(err)
954+
955+
rr := httptest.NewRecorder()
956+
957+
server.ServeHTTP(rr, req)
958+
959+
fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code)
960+
logResponseBody(rr)
961+
962+
// ================================================================================
963+
fmt.Println("# The skipper cannot consume the request body")
964+
965+
req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there")))
966+
must(err)
967+
req.Header.Set("Content-Type", "text/plain")
968+
969+
rr = httptest.NewRecorder()
970+
971+
server.ServeHTTP(rr, req)
972+
973+
fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code)
974+
logResponseBody(rr)
975+
976+
// Output:
977+
// # A request that is made to the undocumented healthcheck endpoint does not get validated
978+
// Received an HTTP 200 response. Expected HTTP 200
979+
// # The skipper cannot consume the request body
980+
// POST /resource was called
981+
// Request body: Hello there
982+
// Received an HTTP 204 response. Expected HTTP 204
983+
// Response body:
984+
}

0 commit comments

Comments
 (0)