Skip to content

Commit b442366

Browse files
committed
refactor: extract readRequestBody and parseBoolQueryParam helpers
Deduplicate the MaxBytesReader+ReadAll error-handling block (repeated 4x in scheduling/http_handler.go) into a readRequestBody helper, and the bool query-param parsing block (repeated 2x in models/http_handler.go) into a parseBoolQueryParam helper. Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 3c1be71 commit b442366

2 files changed

Lines changed: 43 additions & 51 deletions

File tree

pkg/inference/models/http_handler.go

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,21 @@ import (
2323
"github.com/docker/model-runner/pkg/middleware"
2424
)
2525

26+
// parseBoolQueryParam parses a boolean query parameter from the request.
27+
// Returns the parsed value, or false if the parameter is absent or unparseable
28+
// (logging a warning in the latter case).
29+
func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
30+
if !r.URL.Query().Has(name) {
31+
return false
32+
}
33+
val, err := strconv.ParseBool(r.URL.Query().Get(name))
34+
if err != nil {
35+
log.Warn("error while parsing query parameter", "param", name, "error", err)
36+
return false
37+
}
38+
return val
39+
}
40+
2641
// HTTPHandler manages inference model pulls and storage.
2742
type HTTPHandler struct {
2843
// log is the associated logger.
@@ -195,16 +210,7 @@ func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
195210
}
196211

197212
func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
198-
// Parse remote query parameter
199-
remote := false
200-
if r.URL.Query().Has("remote") {
201-
val, err := strconv.ParseBool(r.URL.Query().Get("remote"))
202-
if err != nil {
203-
h.log.Warn("error while parsing remote query parameter", "error", err)
204-
} else {
205-
remote = val
206-
}
207-
}
213+
remote := parseBoolQueryParam(r, h.log, "remote")
208214

209215
var (
210216
apiModel *Model
@@ -309,14 +315,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request)
309315

310316
modelRef := r.PathValue("name")
311317

312-
var force bool
313-
if r.URL.Query().Has("force") {
314-
if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
315-
h.log.Warn("error while parsing force query parameter", "error", err)
316-
} else {
317-
force = val
318-
}
319-
}
318+
force := parseBoolQueryParam(r, h.log, "force")
320319

321320
// First try to delete without normalization (as ID), then with normalization if not found
322321
resp, err := h.manager.Delete(modelRef, force)

pkg/inference/scheduling/http_handler.go

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,23 @@ import (
2323

2424
type contextKey bool
2525

26+
// readRequestBody reads up to maxSize bytes from the request body and writes
27+
// an appropriate HTTP error if reading fails. Returns (body, true) on success
28+
// or (nil, false) after writing the error response.
29+
func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
30+
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
31+
if err != nil {
32+
var maxBytesError *http.MaxBytesError
33+
if errors.As(err, &maxBytesError) {
34+
http.Error(w, "request too large", http.StatusBadRequest)
35+
} else {
36+
http.Error(w, "failed to read request body", http.StatusInternalServerError)
37+
}
38+
return nil, false
39+
}
40+
return body, true
41+
}
42+
2643
const preloadOnlyKey contextKey = false
2744

2845
// HTTPHandler handles HTTP requests for the scheduler.
@@ -132,14 +149,8 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque
132149

133150
// Read the entire request body. We put some basic size constraints in place
134151
// to avoid DoS attacks. We do this early to avoid client write timeouts.
135-
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
136-
if err != nil {
137-
var maxBytesError *http.MaxBytesError
138-
if errors.As(err, &maxBytesError) {
139-
http.Error(w, "request too large", http.StatusBadRequest)
140-
} else {
141-
http.Error(w, "failed to read request body", http.StatusInternalServerError)
142-
}
152+
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
153+
if !ok {
143154
return
144155
}
145156

@@ -338,14 +349,8 @@ func (h *HTTPHandler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
338349
// Unload unloads the specified runners (backend, model) from the backend.
339350
// Currently, this doesn't work for runners that are handling an OpenAI request.
340351
func (h *HTTPHandler) Unload(w http.ResponseWriter, r *http.Request) {
341-
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
342-
if err != nil {
343-
var maxBytesError *http.MaxBytesError
344-
if errors.As(err, &maxBytesError) {
345-
http.Error(w, "request too large", http.StatusBadRequest)
346-
} else {
347-
http.Error(w, "failed to read request body", http.StatusInternalServerError)
348-
}
352+
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
353+
if !ok {
349354
return
350355
}
351356

@@ -371,14 +376,8 @@ type installBackendRequest struct {
371376
// InstallBackend handles POST <inference-prefix>/install-backend requests.
372377
// It triggers on-demand installation of a deferred backend.
373378
func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
374-
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
375-
if err != nil {
376-
var maxBytesError *http.MaxBytesError
377-
if errors.As(err, &maxBytesError) {
378-
http.Error(w, "request too large", http.StatusBadRequest)
379-
} else {
380-
http.Error(w, "failed to read request body", http.StatusInternalServerError)
381-
}
379+
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
380+
if !ok {
382381
return
383382
}
384383

@@ -414,14 +413,8 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
414413
return
415414
}
416415

417-
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
418-
if err != nil {
419-
var maxBytesError *http.MaxBytesError
420-
if errors.As(err, &maxBytesError) {
421-
http.Error(w, "request too large", http.StatusBadRequest)
422-
} else {
423-
http.Error(w, "failed to read request body", http.StatusInternalServerError)
424-
}
416+
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
417+
if !ok {
425418
return
426419
}
427420

@@ -433,7 +426,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
433426
return
434427
}
435428

436-
backend, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
429+
backend, err := h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
437430
if err != nil {
438431
if errors.Is(err, errRunnerAlreadyActive) {
439432
http.Error(w, err.Error(), http.StatusConflict)

0 commit comments

Comments
 (0)