Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions pkg/inference/models/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,28 @@ import (
"github.com/docker/model-runner/pkg/middleware"
)

// parseBoolQueryParam parses a boolean query parameter from the request.
// Returns the parsed value, or false if the parameter is absent or unparseable
// (logging a warning in the latter case). Treats presence of the key with an
// empty value (e.g., `?force`) as true.
func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
	values := r.URL.Query()
	// Absent key: the caller gets the zero value.
	if !values.Has(name) {
		return false
	}
	raw := values.Get(name)
	// A bare key with no value (e.g., `?force`) counts as true.
	if raw == "" {
		return true
	}
	parsed, err := strconv.ParseBool(raw)
	if err == nil {
		return parsed
	}
	// Unparseable values are reported but otherwise treated as false.
	log.Warn("error while parsing query parameter", "param", name, "value", raw, "error", err)
	return false
}

// HTTPHandler manages inference model pulls and storage.
type HTTPHandler struct {
// log is the associated logger.
Expand Down Expand Up @@ -195,16 +217,7 @@ func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
}

func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
// Parse remote query parameter
remote := false
if r.URL.Query().Has("remote") {
val, err := strconv.ParseBool(r.URL.Query().Get("remote"))
if err != nil {
h.log.Warn("error while parsing remote query parameter", "error", err)
} else {
remote = val
}
}
remote := parseBoolQueryParam(r, h.log, "remote")

var (
apiModel *Model
Expand Down Expand Up @@ -309,14 +322,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request)

modelRef := r.PathValue("name")

var force bool
if r.URL.Query().Has("force") {
if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
h.log.Warn("error while parsing force query parameter", "error", err)
} else {
force = val
}
}
force := parseBoolQueryParam(r, h.log, "force")

// First try to delete without normalization (as ID), then with normalization if not found
resp, err := h.manager.Delete(modelRef, force)
Expand Down
58 changes: 26 additions & 32 deletions pkg/inference/scheduling/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ import (

type contextKey bool

// readRequestBody reads up to maxSize bytes from the request body and writes
// an appropriate HTTP error if reading fails. Returns (body, true) on success
// or (nil, false) after writing the error response.
func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
return nil, false
}
return body, true
}

const preloadOnlyKey contextKey = false

// HTTPHandler handles HTTP requests for the scheduler.
Expand Down Expand Up @@ -132,14 +149,8 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque

// Read the entire request body. We put some basic size constraints in place
// to avoid DoS attacks. We do this early to avoid client write timeouts.
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand Down Expand Up @@ -338,14 +349,8 @@ func (h *HTTPHandler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
// Unload unloads the specified runners (backend, model) from the backend.
// Currently, this doesn't work for runners that are handling an OpenAI request.
func (h *HTTPHandler) Unload(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand All @@ -371,14 +376,8 @@ type installBackendRequest struct {
// InstallBackend handles POST <inference-prefix>/install-backend requests.
// It triggers on-demand installation of a deferred backend.
func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand All @@ -404,6 +403,7 @@ func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
// Determine the requested backend and ensure that it's valid.
var backend inference.Backend
var err error
if b := r.PathValue("backend"); b == "" {
backend = h.scheduler.defaultBackend
} else {
Expand All @@ -414,14 +414,8 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
return
}

body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand Down
Loading