Skip to content

Commit 08dfb43

Browse files
authored
Merge pull request #66 from oapi-codegen/feat/skipper
feat: add ability to skip validation checks
2 parents b16f749 + d2a226f commit 08dfb43

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

oapi_validate.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ type ErrorHandlerOptsMatchedRoute struct {
7474
// MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
7575
type MultiErrorHandler func(openapi3.MultiError) (int, error)
7676

77+
// Skipper is a function that runs before any validation middleware, and determines whether the given request should skip any validation middleware
78+
//
79+
// Return `true` if the request should be skipped
80+
type Skipper func(r *http.Request) bool
81+
7782
// Options allows configuring the OapiRequestValidator.
7883
type Options struct {
7984
// Options contains any configuration for the underlying `openapi3filter`
@@ -103,6 +108,9 @@ type Options struct {
103108
// Prefix allows (optionally) trimming a prefix from the API path.
104109
// This may be useful if your API is routed to an internal path that is different from the OpenAPI specification.
105110
Prefix string
111+
112+
// Skipper allows writing a function that runs before any middleware and determines whether the given request should skip any validation middleware
113+
Skipper Skipper
106114
}
107115

108116
// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
@@ -129,6 +137,11 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne
129137

130138
return func(next http.Handler) http.Handler {
131139
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
140+
if options != nil && options.Skipper != nil && options.Skipper(r) {
141+
next.ServeHTTP(w, r)
142+
return
143+
}
144+
132145
if options == nil {
133146
performRequestValidationForErrorHandler(next, w, r, router, options, http.Error)
134147
} else if options.ErrorHandlerWithOpts != nil {

oapi_validate_example_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,3 +957,142 @@ paths:
957957
// POST /public-api/v1/resource was called
958958
// Received an HTTP 204 response. Expected HTTP 204
959959
}
960+
961+
func ExampleOapiRequestValidatorWithOptions_withSkipper() {
962+
rawSpec := `
963+
openapi: "3.0.0"
964+
info:
965+
version: 1.0.0
966+
title: TestServer
967+
servers:
968+
- url: http://example.com/
969+
paths:
970+
# we also have a /healthz, but it's not externally documented, so the middleware CANNOT run against it, or it'll block requests
971+
/resource:
972+
post:
973+
operationId: createResource
974+
responses:
975+
'204':
976+
description: No content
977+
requestBody:
978+
required: true
979+
content:
980+
text/plain: {}
981+
`
982+
983+
must := func(err error) {
984+
if err != nil {
985+
panic(err)
986+
}
987+
}
988+
989+
use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler {
990+
var s http.Handler
991+
s = r
992+
993+
for _, mw := range middlewares {
994+
s = mw(s)
995+
}
996+
997+
return s
998+
}
999+
1000+
logResponseBody := func(rr *httptest.ResponseRecorder) {
1001+
if rr.Result().Body != nil {
1002+
data, _ := io.ReadAll(rr.Result().Body)
1003+
if len(data) > 0 {
1004+
fmt.Printf("Response body: %s", data)
1005+
}
1006+
}
1007+
}
1008+
1009+
spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec))
1010+
must(err)
1011+
1012+
// NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec
1013+
// See also: Options#SilenceServersWarning
1014+
spec.Servers = nil
1015+
1016+
router := http.NewServeMux()
1017+
1018+
router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
1019+
fmt.Printf("%s /resource was called\n", r.Method)
1020+
1021+
if r.Method == http.MethodPost {
1022+
data, err := io.ReadAll(r.Body)
1023+
if err != nil {
1024+
w.WriteHeader(http.StatusInternalServerError)
1025+
return
1026+
}
1027+
fmt.Printf("Request body: %s\n", data)
1028+
w.WriteHeader(http.StatusNoContent)
1029+
return
1030+
}
1031+
1032+
w.WriteHeader(http.StatusMethodNotAllowed)
1033+
})
1034+
1035+
router.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
1036+
w.WriteHeader(http.StatusOK)
1037+
})
1038+
1039+
authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error {
1040+
fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName)
1041+
return fmt.Errorf("this check always fails - don't let anyone in!")
1042+
}
1043+
1044+
skipperFunc := func(r *http.Request) bool {
1045+
// skip the undocumented healthcheck endpoint
1046+
if r.URL.Path == "/healthz" {
1047+
return true
1048+
}
1049+
1050+
return false
1051+
}
1052+
1053+
// create middleware
1054+
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
1055+
Options: openapi3filter.Options{
1056+
AuthenticationFunc: authenticationFunc,
1057+
},
1058+
Skipper: skipperFunc,
1059+
})
1060+
1061+
// then wire it in
1062+
server := use(router, mw)
1063+
1064+
// ================================================================================
1065+
fmt.Println("# A request that is made to the undocumented healthcheck endpoint does not get validated")
1066+
1067+
req, err := http.NewRequest(http.MethodGet, "/healthz", http.NoBody)
1068+
must(err)
1069+
1070+
rr := httptest.NewRecorder()
1071+
1072+
server.ServeHTTP(rr, req)
1073+
1074+
fmt.Printf("Received an HTTP %d response. Expected HTTP 200\n", rr.Code)
1075+
logResponseBody(rr)
1076+
1077+
// ================================================================================
1078+
fmt.Println("# A request that is well-formed is passed through to the Handler")
1079+
1080+
req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader([]byte("Hello there")))
1081+
must(err)
1082+
req.Header.Set("Content-Type", "text/plain")
1083+
1084+
rr = httptest.NewRecorder()
1085+
1086+
server.ServeHTTP(rr, req)
1087+
1088+
fmt.Printf("Received an HTTP %d response. Expected HTTP 204\n", rr.Code)
1089+
logResponseBody(rr)
1090+
1091+
// Output:
1092+
// # A request that is made to the undocumented healthcheck endpoint does not get validated
1093+
// Received an HTTP 200 response. Expected HTTP 200
1094+
// # A request that is well-formed is passed through to the Handler
1095+
// POST /resource was called
1096+
// Request body: Hello there
1097+
// Received an HTTP 204 response. Expected HTTP 204
1098+
}

0 commit comments

Comments
 (0)