diff --git a/internal/test/nethttp/oapi_validate_prefix_test.go b/internal/test/nethttp/oapi_validate_prefix_test.go index b9b47af..057644f 100644 --- a/internal/test/nethttp/oapi_validate_prefix_test.go +++ b/internal/test/nethttp/oapi_validate_prefix_test.go @@ -3,6 +3,7 @@ package gorilla import ( "context" _ "embed" + "encoding/json" "net/http" "testing" @@ -284,3 +285,122 @@ components: assert.Equal(t, http.StatusOK, rec.Code) assert.True(t, called, "handler should have been called when auth passes") } + +// bodyReadableSpec defines a POST /resource with a required JSON body, +// used to test that the handler can still read the body after validation. +const bodyReadableSpec = ` +openapi: "3.0.0" +info: + version: 1.0.0 + title: TestServer +paths: + /resource: + post: + operationId: createResource + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - name + properties: + name: + type: string + additionalProperties: false + responses: + '204': + description: No content +` + +// TestPrefix_RequestBodyReadableByHandler_WithAndWithoutPrefix is a regression +// test for https://github.com/oapi-codegen/nethttp-middleware/issues/69. +// +// When Prefix is set, makeRequestForValidation used to clone the request via +// r.Clone(), which shallow-copies the Body. Validation then consumed the body +// on the clone, leaving the original body empty for the downstream handler. +func TestPrefix_RequestBodyReadableByHandler_WithAndWithoutPrefix(t *testing.T) { + spec, err := openapi3.NewLoader().LoadFromData([]byte(bodyReadableSpec)) + require.NoError(t, err) + spec.Servers = nil + + tests := []struct { + name string + prefix string + routePath string + requestPath string + }{ + { + name: "without prefix", + prefix: "", + routePath: "/resource", + requestPath: "http://example.com/resource", + }, + { + name: "with prefix", + prefix: "/api", + routePath: "/api/resource", + requestPath: "http://example.com/api/resource", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc(tc.routePath, func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "handler failed to decode body: "+err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "Jamie", payload.Name) + w.WriteHeader(http.StatusNoContent) + }) + + mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{ + Prefix: tc.prefix, + }) + server := mw(mux) + + body := map[string]string{"name": "Jamie"} + rec := doPost(t, server, tc.requestPath, body) + assert.Equal(t, http.StatusNoContent, rec.Code, "body: %s", rec.Body.String()) + }) + } +} + +// TestPrefix_RequestBodyReadableByHandler_ErrorHandlerWithOpts is the same +// regression test but exercising the ErrorHandlerWithOpts code path. +func TestPrefix_RequestBodyReadableByHandler_ErrorHandlerWithOpts(t *testing.T) { + spec, err := openapi3.NewLoader().LoadFromData([]byte(bodyReadableSpec)) + require.NoError(t, err) + spec.Servers = nil + + mux := http.NewServeMux() + mux.HandleFunc("/api/resource", func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "handler failed to decode body: "+err.Error(), http.StatusBadRequest) + return + } + assert.Equal(t, "Jamie", payload.Name) + w.WriteHeader(http.StatusNoContent) + }) + + mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{ + Prefix: "/api", + ErrorHandlerWithOpts: func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) { + http.Error(w, err.Error(), opts.StatusCode) + }, + }) + server := mw(mux) + + body := map[string]string{"name": "Jamie"} + rec := doPost(t, server, "http://example.com/api/resource", body) + assert.Equal(t, http.StatusNoContent, rec.Code, "body: %s", rec.Body.String()) +} diff --git a/oapi_validate.go b/oapi_validate.go index 026bd87..82d3cd7 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -169,19 +169,28 @@ func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseW errorHandler(w, err.Error(), statusCode) } -func makeRequestForValidation(r *http.Request, options *Options) *http.Request { +// withPrefixStripped temporarily strips the configured prefix from the +// request's path fields, calls fn, then restores the original values. +// This avoids cloning the request (which shallow-copies the Body and +// causes it to be consumed during validation, leaving the handler with +// an empty body — see https://github.com/oapi-codegen/nethttp-middleware/issues/69). +func withPrefixStripped(r *http.Request, options *Options, fn func()) { if options == nil || options.Prefix == "" { - return r + fn() + return } // Only strip the prefix when it matches on a path segment boundary: // the path must equal the prefix exactly, or the character immediately // after the prefix must be '/'. if !hasPathPrefix(r.URL.Path, options.Prefix) { - return r + fn() + return } - r = r.Clone(r.Context()) + origRequestURI := r.RequestURI + origPath := r.URL.Path + origRawPath := r.URL.RawPath r.RequestURI = stripPrefix(r.RequestURI, options.Prefix) r.URL.Path = stripPrefix(r.URL.Path, options.Prefix) @@ -189,7 +198,11 @@ func makeRequestForValidation(r *http.Request, options *Options) *http.Request { r.URL.RawPath = stripPrefix(r.URL.RawPath, options.Prefix) } - return r + fn() + + r.RequestURI = origRequestURI + r.URL.Path = origPath + r.URL.RawPath = origRawPath } // hasPathPrefix reports whether path starts with prefix on a segment boundary. @@ -210,19 +223,39 @@ func stripPrefix(s, prefix string) string { // Note that this is an inline-and-modified version of `validateRequest`, with a simplified control flow and providing full access to the `error` for the `ErrorHandlerWithOpts` function. func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) { - // Build a (possibly prefix-stripped) request for validation, but keep - // the original so the downstream handler sees the un-modified path. - validationReq := makeRequestForValidation(r, options) + var route *routers.Route + var pathParams map[string]string + var validationErr error + + // Temporarily strip the prefix for route finding and validation, + // then restore it so the handler and error handler see the original path. + withPrefixStripped(r, options, func() { + var err error + route, pathParams, err = router.FindRoute(r) + if err != nil { + validationErr = err + return + } - // Find route - route, pathParams, err := router.FindRoute(validationReq) - if err != nil { + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + + if options != nil { + requestValidationInput.Options = &options.Options + } + + validationErr = openapi3filter.ValidateRequest(r.Context(), requestValidationInput) + }) + + // Route not found + if route == nil && validationErr != nil { errOpts := ErrorHandlerOpts{ - // MatchedRoute will be nil, as we've not matched a route we know about StatusCode: http.StatusNotFound, } - - options.ErrorHandlerWithOpts(r.Context(), err, w, r, errOpts) + options.ErrorHandlerWithOpts(r.Context(), validationErr, w, r, errOpts) return } @@ -231,44 +264,26 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R Route: route, PathParams: pathParams, }, - // other options will be added before executing } - // Validate request - requestValidationInput := &openapi3filter.RequestValidationInput{ - Request: validationReq, - PathParams: pathParams, - Route: route, - } - - if options != nil { - requestValidationInput.Options = &options.Options - } - - err = openapi3filter.ValidateRequest(validationReq.Context(), requestValidationInput) - if err == nil { - // it's a valid request, so serve it with the original request + if validationErr == nil { next.ServeHTTP(w, r) return } var theErr error - switch e := err.(type) { + switch e := validationErr.(type) { case openapi3.MultiError: theErr = e errOpts.StatusCode = determineStatusCodeForMultiError(e) case *openapi3filter.RequestError: - // We've got a bad request theErr = e errOpts.StatusCode = http.StatusBadRequest case *openapi3filter.SecurityRequirementsError: theErr = e errOpts.StatusCode = http.StatusUnauthorized default: - // This should never happen today, but if our upstream code changes, - // we don't want to crash the server, so handle the unexpected error. - // return http.StatusInternalServerError, theErr = fmt.Errorf("error validating route: %w", e) errOpts.StatusCode = http.StatusInternalServerError } @@ -279,8 +294,19 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R // validateRequest is called from the middleware above and actually does the work // of validating a request. func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) { - r = makeRequestForValidation(r, options) + var statusCode int + var validationErr error + + withPrefixStripped(r, options, func() { + statusCode, validationErr = doValidateRequest(r, router, options) + }) + + return statusCode, validationErr +} +// doValidateRequest performs the actual validation, called within +// withPrefixStripped so the prefix is already removed from r's path. +func doValidateRequest(r *http.Request, router routers.Router, options *Options) (int, error) { // Find route route, pathParams, err := router.FindRoute(r) if err != nil {