Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions internal/test/nethttp/oapi_validate_prefix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorilla
import (
"context"
_ "embed"
"encoding/json"
"net/http"
"testing"

Expand Down Expand Up @@ -284,3 +285,122 @@ components:
assert.Equal(t, http.StatusOK, rec.Code)
assert.True(t, called, "handler should have been called when auth passes")
}

// bodyReadableSpec defines a POST /resource with a required JSON body,
// used to test that the handler can still read the body after validation.
const bodyReadableSpec = `
openapi: "3.0.0"
info:
version: 1.0.0
title: TestServer
paths:
/resource:
post:
operationId: createResource
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- name
properties:
name:
type: string
additionalProperties: false
responses:
'204':
description: No content
`

// TestPrefix_RequestBodyReadableByHandler_WithAndWithoutPrefix is a regression
// test for https://github.com/oapi-codegen/nethttp-middleware/issues/69.
//
// When Prefix is set, makeRequestForValidation used to clone the request via
// r.Clone(), which shallow-copies the Body. Validation then consumed the body
// on the clone, leaving the original body empty for the downstream handler.
func TestPrefix_RequestBodyReadableByHandler_WithAndWithoutPrefix(t *testing.T) {
spec, err := openapi3.NewLoader().LoadFromData([]byte(bodyReadableSpec))
require.NoError(t, err)
spec.Servers = nil

tests := []struct {
name string
prefix string
routePath string
requestPath string
}{
{
name: "without prefix",
prefix: "",
routePath: "/resource",
requestPath: "http://example.com/resource",
},
{
name: "with prefix",
prefix: "/api",
routePath: "/api/resource",
requestPath: "http://example.com/api/resource",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc(tc.routePath, func(w http.ResponseWriter, r *http.Request) {
var payload struct {
Name string `json:"name"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
http.Error(w, "handler failed to decode body: "+err.Error(), http.StatusBadRequest)
return
}
assert.Equal(t, "Jamie", payload.Name)
w.WriteHeader(http.StatusNoContent)
})

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: tc.prefix,
})
server := mw(mux)

body := map[string]string{"name": "Jamie"}
rec := doPost(t, server, tc.requestPath, body)
assert.Equal(t, http.StatusNoContent, rec.Code, "body: %s", rec.Body.String())
})
}
}

// TestPrefix_RequestBodyReadableByHandler_ErrorHandlerWithOpts is the same
// regression test but exercising the ErrorHandlerWithOpts code path.
func TestPrefix_RequestBodyReadableByHandler_ErrorHandlerWithOpts(t *testing.T) {
spec, err := openapi3.NewLoader().LoadFromData([]byte(bodyReadableSpec))
require.NoError(t, err)
spec.Servers = nil

mux := http.NewServeMux()
mux.HandleFunc("/api/resource", func(w http.ResponseWriter, r *http.Request) {
var payload struct {
Name string `json:"name"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
http.Error(w, "handler failed to decode body: "+err.Error(), http.StatusBadRequest)
return
}
assert.Equal(t, "Jamie", payload.Name)
w.WriteHeader(http.StatusNoContent)
})

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api",
ErrorHandlerWithOpts: func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
http.Error(w, err.Error(), opts.StatusCode)
},
})
server := mw(mux)

body := map[string]string{"name": "Jamie"}
rec := doPost(t, server, "http://example.com/api/resource", body)
assert.Equal(t, http.StatusNoContent, rec.Code, "body: %s", rec.Body.String())
}
96 changes: 61 additions & 35 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,27 +169,40 @@ func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseW
errorHandler(w, err.Error(), statusCode)
}

func makeRequestForValidation(r *http.Request, options *Options) *http.Request {
// withPrefixStripped temporarily strips the configured prefix from the
// request's path fields, calls fn, then restores the original values.
// This avoids cloning the request (which shallow-copies the Body and
// causes it to be consumed during validation, leaving the handler with
// an empty body — see https://github.com/oapi-codegen/nethttp-middleware/issues/69).
func withPrefixStripped(r *http.Request, options *Options, fn func()) {
if options == nil || options.Prefix == "" {
return r
fn()
return
}

// Only strip the prefix when it matches on a path segment boundary:
// the path must equal the prefix exactly, or the character immediately
// after the prefix must be '/'.
if !hasPathPrefix(r.URL.Path, options.Prefix) {
return r
fn()
return
}

r = r.Clone(r.Context())
origRequestURI := r.RequestURI
origPath := r.URL.Path
origRawPath := r.URL.RawPath

r.RequestURI = stripPrefix(r.RequestURI, options.Prefix)
r.URL.Path = stripPrefix(r.URL.Path, options.Prefix)
if r.URL.RawPath != "" {
r.URL.RawPath = stripPrefix(r.URL.RawPath, options.Prefix)
}

return r
fn()

r.RequestURI = origRequestURI
r.URL.Path = origPath
r.URL.RawPath = origRawPath
}

// hasPathPrefix reports whether path starts with prefix on a segment boundary.
Expand All @@ -210,19 +223,39 @@ func stripPrefix(s, prefix string) string {

// Note that this is an inline-and-modified version of `validateRequest`, with a simplified control flow and providing full access to the `error` for the `ErrorHandlerWithOpts` function.
func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) {
// Build a (possibly prefix-stripped) request for validation, but keep
// the original so the downstream handler sees the un-modified path.
validationReq := makeRequestForValidation(r, options)
var route *routers.Route
var pathParams map[string]string
var validationErr error

// Temporarily strip the prefix for route finding and validation,
// then restore it so the handler and error handler see the original path.
withPrefixStripped(r, options, func() {
var err error
route, pathParams, err = router.FindRoute(r)
if err != nil {
validationErr = err
return
}

// Find route
route, pathParams, err := router.FindRoute(validationReq)
if err != nil {
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: r,
PathParams: pathParams,
Route: route,
}

if options != nil {
requestValidationInput.Options = &options.Options
}

validationErr = openapi3filter.ValidateRequest(r.Context(), requestValidationInput)
})

// Route not found
if route == nil && validationErr != nil {
errOpts := ErrorHandlerOpts{
// MatchedRoute will be nil, as we've not matched a route we know about
StatusCode: http.StatusNotFound,
}

options.ErrorHandlerWithOpts(r.Context(), err, w, r, errOpts)
options.ErrorHandlerWithOpts(r.Context(), validationErr, w, r, errOpts)
return
}

Expand All @@ -231,44 +264,26 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
Route: route,
PathParams: pathParams,
},
// other options will be added before executing
}

// Validate request
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: validationReq,
PathParams: pathParams,
Route: route,
}

if options != nil {
requestValidationInput.Options = &options.Options
}

err = openapi3filter.ValidateRequest(validationReq.Context(), requestValidationInput)
if err == nil {
// it's a valid request, so serve it with the original request
if validationErr == nil {
next.ServeHTTP(w, r)
return
}

var theErr error

switch e := err.(type) {
switch e := validationErr.(type) {
case openapi3.MultiError:
theErr = e
errOpts.StatusCode = determineStatusCodeForMultiError(e)
case *openapi3filter.RequestError:
// We've got a bad request
theErr = e
errOpts.StatusCode = http.StatusBadRequest
case *openapi3filter.SecurityRequirementsError:
theErr = e
errOpts.StatusCode = http.StatusUnauthorized
default:
// This should never happen today, but if our upstream code changes,
// we don't want to crash the server, so handle the unexpected error.
// return http.StatusInternalServerError,
theErr = fmt.Errorf("error validating route: %w", e)
errOpts.StatusCode = http.StatusInternalServerError
}
Expand All @@ -279,8 +294,19 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
// validateRequest is called from the middleware above and actually does the work
// of validating a request.
func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
r = makeRequestForValidation(r, options)
var statusCode int
var validationErr error

withPrefixStripped(r, options, func() {
statusCode, validationErr = doValidateRequest(r, router, options)
})

return statusCode, validationErr
}

// doValidateRequest performs the actual validation, called within
// withPrefixStripped so the prefix is already removed from r's path.
func doValidateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
// Find route
route, pathParams, err := router.FindRoute(r)
if err != nil {
Expand Down
Loading