diff --git a/pkg/inference/runtime_flags.go b/pkg/inference/runtime_flags.go index d0712c3e..bb9e1f6b 100644 --- a/pkg/inference/runtime_flags.go +++ b/pkg/inference/runtime_flags.go @@ -5,7 +5,33 @@ import ( "strings" ) -// ValidateRuntimeFlags ensures runtime flags don't contain paths (forward slash "/" or backslash "\") +// ValidateRuntimeFlags validates runtime flags against the backend's allowlist +// and checks for path characters as defense-in-depth. +func ValidateRuntimeFlags(backendName string, flags []string) error { + // Get allowlist for this backend + allowedFlags := GetAllowedFlags(backendName) + + // Check each flag against allowlist + for _, flag := range flags { + flagKey := ParseFlagKey(flag) + if flagKey == "" { + continue // Skip values, only validate flag keys + } + if !allowedFlags[flagKey] { + return fmt.Errorf("runtime flag %q is not allowed for backend %q", flagKey, backendName) + } + } + + // Check for flags in values (e.g., --seed=--log-file=foo or --seed=-l) + if err := validateNoFlagInjection(flags); err != nil { + return err + } + + // still check for path characters in values + return validatePathSafety(flags) +} + +// validatePathSafety ensures runtime flags don't contain paths (forward slash "/" or backslash "\") // to prevent malicious users from overwriting host files via arguments like // --log-file /some/path, --output-file /etc/passwd, or --log-file C:\Windows\file. // @@ -17,7 +43,7 @@ import ( // - UNC paths: \\network\share\file // // Returns an error if any flag contains a forward slash or backslash. -func ValidateRuntimeFlags(flags []string) error { +func validatePathSafety(flags []string) error { for _, flag := range flags { if strings.Contains(flag, "/") || strings.Contains(flag, "\\") { return fmt.Errorf("invalid runtime flag %q: paths are not allowed (contains '/' or '\\\\')", flag) @@ -25,3 +51,27 @@ func ValidateRuntimeFlags(flags []string) error { } return nil } + +// validateNoFlagInjection checks for flags in values when using the = format. +// This prevents attacks like --seed=--log-file=foo or --seed=-l where disallowed flags +// are embedded as values. +// Values starting with - are only allowed if followed by a digit (negative numbers like -1, -0.5). +func validateNoFlagInjection(flags []string) error { + for _, flag := range flags { + if idx := strings.Index(flag, "="); idx != -1 { + value := flag[idx+1:] + if strings.HasPrefix(value, "-") { + // Allow negative numbers (-1, -0.5) but reject flags (-l, --flag) + if len(value) < 2 || !isDigit(value[1]) { + return fmt.Errorf("invalid flag %q: value cannot start with '-' unless followed by a digit", flag) + } + } + } + } + return nil +} + +// isDigit returns true if the byte is an ASCII digit (0-9) +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} diff --git a/pkg/inference/runtime_flags_allowlist.go b/pkg/inference/runtime_flags_allowlist.go new file mode 100644 index 00000000..fe9fa43e --- /dev/null +++ b/pkg/inference/runtime_flags_allowlist.go @@ -0,0 +1,236 @@ +package inference + +import "strings" + +// LlamaCppAllowedFlags contains safe flags for llama.cpp server. +// This list is based on llama.cpp server documentation. +// Flags involving file paths are intentionally excluded for security. +var LlamaCppAllowedFlags = map[string]bool{ + // Threading and CPU control + "-t": true, "--threads": true, + "-tb": true, "--threads-batch": true, + "-C": true, "--cpu-mask": true, + "-Cr": true, "--cpu-range": true, + "--cpu-strict": true, + "--prio": true, + "--poll": true, + "-Cb": true, "--cpu-mask-batch": true, + "-Crb": true, "--cpu-range-batch": true, + "--cpu-strict-batch": true, + "--prio-batch": true, + "--poll-batch": true, + + // Context and prediction + "-c": true, "--ctx-size": true, + "-n": true, "--predict": true, "--n-predict": true, + "--keep": true, + + // Batching and performance + "-b": true, "--batch-size": true, + "-ub": true, "--ubatch-size": true, + "--swa-full": true, + "-fa": true, "--flash-attn": true, + "--perf": true, "--no-perf": true, + + // Sampling parameters + "--samplers": true, + "-s": true, "--seed": true, + "--temp": true, "--temperature": true, + "--top-k": true, + "--top-p": true, + "--min-p": true, + "--top-nsigma": true, + "--xtc-probability": true, + "--xtc-threshold": true, + "--typical": true, + "--repeat-last-n": true, + "--repeat-penalty": true, + "--presence-penalty": true, + "--frequency-penalty": true, + "--dry-multiplier": true, + "--dry-base": true, + "--dry-allowed-length": true, + "--dry-penalty-last-n": true, + "--mirostat": true, + "--mirostat-lr": true, + "--mirostat-ent": true, + "--ignore-eos": true, + "--dynatemp-range": true, + "--dynatemp-exp": true, + + // GPU and device management + "-dev": true, "--device": true, + "-ngl": true, "--gpu-layers": true, "--n-gpu-layers": true, + "-sm": true, "--split-mode": true, + "-ts": true, "--tensor-split": true, + "-mg": true, "--main-gpu": true, + "-fit": true, "--fit": true, + "-fitt": true, "--fit-target": true, + "-fitc": true, "--fit-ctx": true, + + // Memory and caching + "-kvo": true, "--kv-offload": true, + "-nkvo": true, "--no-kv-offload": true, + "--repack": true, "-nr": true, "--no-repack": true, + "--no-host": true, + "-ctk": true, "--cache-type-k": true, + "-ctv": true, "--cache-type-v": true, + "--mlock": true, + "--mmap": true, "--no-mmap": true, + "-dio": true, "--direct-io": true, + "-ndio": true, "--no-direct-io": true, + "-cram": true, "--cache-ram": true, + "-kvu": true, "--kv-unified": true, + "--context-shift": true, "--no-context-shift": true, + + // RoPE scaling + "--rope-scaling": true, + "--rope-scale": true, + "--rope-freq-base": true, + "--rope-freq-scale": true, + "--yarn-orig-ctx": true, + "--yarn-ext-factor": true, + "--yarn-attn-factor": true, + "--yarn-beta-slow": true, + "--yarn-beta-fast": true, + + // Server configuration + "-np": true, "--parallel": true, + "-cb": true, "--cont-batching": true, + "-nocb": true, "--no-cont-batching": true, + "--warmup": true, "--no-warmup": true, + "-to": true, "--timeout": true, + "--threads-http": true, + "--cache-prompt": true, + "--no-cache-prompt": true, + "--cache-reuse": true, + "--sleep-idle-seconds": true, + + // Multimodal (safe flags only - no file paths) + "--mmproj-auto": true, "--no-mmproj": true, "--no-mmproj-auto": true, + "--mmproj-offload": true, "--no-mmproj-offload": true, + "--image-min-tokens": true, + "--image-max-tokens": true, + "--spm-infill": true, + + // Speculative decoding (safe flags only - no file paths) + "--draft": true, "--draft-n": true, "--draft-max": true, + "--draft-min": true, "--draft-n-min": true, + "--draft-p-min": true, + "-cd": true, "--ctx-size-draft": true, + "-devd": true, "--device-draft": true, + "-ngld": true, "--gpu-layers-draft": true, "--n-gpu-layers-draft": true, + "-td": true, "--threads-draft": true, + "-tbd": true, "--threads-batch-draft": true, + + // LoRA (safe flags only - no file paths) + "--lora-init-without-apply": true, + + // Control vectors (safe flags only - no file paths) + "--control-vector-layer-range": true, + + // Grammar and constraints (safe flags only - no file paths) + "--grammar": true, + "-j": true, "--json-schema": true, + "-bs": true, "--backend-sampling": true, + + // Template and format control (safe flags only - no file paths) + "--chat-template": true, + "--chat-template-kwargs": true, + "--jinja": true, "--no-jinja": true, + "--pooling": true, + "--reasoning-format": true, + "--reasoning-budget": true, + "--prefill-assistant": true, + "--no-prefill-assistant": true, + + // Web interface and API (safe flags only - no file paths) + "--api-prefix": true, + "--webui": true, "--no-webui": true, + "--webui-config": true, + "--api-key": true, + "--metrics": true, + "--no-metrics": true, + "--props": true, + "--slots": true, "--no-slots": true, + + // Embedding and specialized + "--embedding": true, "--embeddings": true, + "--rerank": true, "--reranking": true, + "-sps": true, "--slot-prompt-similarity": true, + + // Tensor and computation (safe flags only) + "-cmoe": true, "--cpu-moe": true, + "-ncmoe": true, "--n-cpu-moe": true, + "--check-tensors": true, + "--op-offload": true, "--no-op-offload": true, + + // Verbose/debug + "-v": true, "--verbose": true, +} + +// VLLMAllowedFlags contains safe flags for vLLM engine. +// Flags involving file paths are intentionally excluded for security. +var VLLMAllowedFlags = map[string]bool{ + // Parallelism + "--tensor-parallel-size": true, "-tp": true, + "--pipeline-parallel-size": true, "-pp": true, + + // Model configuration + "--max-model-len": true, + "--max-num-batched-tokens": true, + "--max-num-seqs": true, + "--block-size": true, + "--swap-space": true, + "--seed": true, + + // Data types and quantization + "--dtype": true, + "--quantization": true, + "-q": true, + "--kv-cache-dtype": true, + + // Performance flags + "--enforce-eager": true, + "--enable-prefix-caching": true, + "--enable-chunked-prefill": true, + "--disable-custom-all-reduce": true, + "--use-v2-block-manager": true, + + // Tokenizer + "--tokenizer-mode": true, + "--trust-remote-code": true, + "--max-logprobs": true, + + // Misc + "--revision": true, + "--load-format": true, + "--disable-log-stats": true, + "--served-model-name": true, + + // GPU memory + "--gpu-memory-utilization": true, +} + +// AllowedFlags maps backend names to their allowed flag keys +var AllowedFlags = map[string]map[string]bool{ + "llama.cpp": LlamaCppAllowedFlags, + "vllm": VLLMAllowedFlags, +} + +// ParseFlagKey extracts the flag key from a flag string. +// "--threads=4" -> "--threads", "-t" -> "-t", "4" -> "" +func ParseFlagKey(flag string) string { + if !strings.HasPrefix(flag, "-") { + return "" // Not a flag, it's a value + } + if idx := strings.Index(flag, "="); idx != -1 { + return flag[:idx] + } + return flag +} + +// GetAllowedFlags returns the allowlist for a backend, or nil if unknown +func GetAllowedFlags(backendName string) map[string]bool { + return AllowedFlags[backendName] +} diff --git a/pkg/inference/runtime_flags_allowlist_test.go b/pkg/inference/runtime_flags_allowlist_test.go new file mode 100644 index 00000000..d32ece67 --- /dev/null +++ b/pkg/inference/runtime_flags_allowlist_test.go @@ -0,0 +1,267 @@ +package inference + +import ( + "testing" +) + +func TestParseFlagKey(t *testing.T) { + tests := []struct { + name string + flag string + expected string + }{ + { + name: "long flag", + flag: "--threads", + expected: "--threads", + }, + { + name: "short flag", + flag: "-t", + expected: "-t", + }, + { + name: "long flag with equals", + flag: "--threads=4", + expected: "--threads", + }, + { + name: "short flag with equals", + flag: "-t=4", + expected: "-t", + }, + { + name: "value only (number)", + flag: "4", + expected: "", + }, + { + name: "value only (string)", + flag: "some-value", + expected: "", + }, + { + name: "empty string", + flag: "", + expected: "", + }, + { + name: "long flag with complex value", + flag: "--model-name=llama-3.2-1b", + expected: "--model-name", + }, + { + name: "flag with multiple equals", + flag: "--config=key=value", + expected: "--config", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseFlagKey(tt.flag) + if result != tt.expected { + t.Errorf("ParseFlagKey(%q) = %q, want %q", tt.flag, result, tt.expected) + } + }) + } +} + +func TestGetAllowedFlags(t *testing.T) { + tests := []struct { + name string + backend string + expectNil bool + checkFlags []string // flags that should be in the allowlist + }{ + { + name: "llama.cpp backend", + backend: "llama.cpp", + expectNil: false, + checkFlags: []string{"--threads", "-t", "--ctx-size", "-ngl", "--verbose", "-v", "--cache-type-k", "--cache-type-v"}, + }, + { + name: "vllm backend", + backend: "vllm", + expectNil: false, + checkFlags: []string{"--tensor-parallel-size", "-tp", "--max-model-len", "--dtype", "--gpu-memory-utilization"}, + }, + { + name: "unknown backend", + backend: "unknown", + expectNil: true, + }, + { + name: "empty backend name", + backend: "", + expectNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAllowedFlags(tt.backend) + + if tt.expectNil { + if result != nil { + t.Errorf("GetAllowedFlags(%q) expected nil, got %v", tt.backend, result) + } + return + } + + if result == nil { + t.Fatalf("GetAllowedFlags(%q) returned nil, expected non-nil", tt.backend) + } + + for _, flag := range tt.checkFlags { + if !result[flag] { + t.Errorf("GetAllowedFlags(%q) missing expected flag %q", tt.backend, flag) + } + } + }) + } +} + +func TestLlamaCppAllowedFlags_Categories(t *testing.T) { + // Test that key flags from each category are present + categories := map[string][]string{ + "threading": {"-t", "--threads", "-tb", "--threads-batch", "-C", "--cpu-mask", "--prio"}, + "context": {"-c", "--ctx-size", "-n", "--n-predict", "--keep"}, + "batching": {"-b", "--batch-size", "-ub", "--ubatch-size", "-fa", "--flash-attn"}, + "sampling": { + "--samplers", "-s", "--seed", "--temp", "--temperature", + "--top-k", "--top-p", "--min-p", "--typical", + "--repeat-last-n", "--repeat-penalty", + "--presence-penalty", "--frequency-penalty", + "--mirostat", "--mirostat-lr", "--mirostat-ent", + "--dynatemp-range", "--dynatemp-exp", + }, + "gpu": { + "-ngl", "--gpu-layers", "--n-gpu-layers", + "-sm", "--split-mode", "-ts", "--tensor-split", + "-mg", "--main-gpu", "-dev", "--device", + }, + "memory": { + "--mlock", "--mmap", "--no-mmap", + "-ctk", "--cache-type-k", "-ctv", "--cache-type-v", + "-kvo", "--kv-offload", "-nkvo", "--no-kv-offload", + "-cram", "--cache-ram", + }, + "rope": { + "--rope-scaling", "--rope-scale", + "--rope-freq-base", "--rope-freq-scale", + "--yarn-orig-ctx", "--yarn-ext-factor", + }, + "server": { + "-np", "--parallel", "-to", "--timeout", + "-cb", "--cont-batching", "--cache-prompt", + "--threads-http", "--warmup", "--no-warmup", + }, + "mode": { + "--embeddings", "--embedding", "--reranking", "--rerank", + "--metrics", "--no-metrics", "--jinja", "--no-jinja", + }, + "speculative": { + "--draft", "--draft-max", "--draft-min", + "-cd", "--ctx-size-draft", + "-ngld", "--gpu-layers-draft", + }, + } + + for category, flags := range categories { + t.Run(category, func(t *testing.T) { + for _, flag := range flags { + if !LlamaCppAllowedFlags[flag] { + t.Errorf("LlamaCppAllowedFlags missing %s flag %q", category, flag) + } + } + }) + } +} + +func TestVLLMAllowedFlags_Categories(t *testing.T) { + categories := map[string][]string{ + "parallelism": {"--tensor-parallel-size", "-tp", "--pipeline-parallel-size", "-pp"}, + "model": {"--max-model-len", "--max-num-batched-tokens", "--max-num-seqs", "--block-size", "--swap-space", "--seed"}, + "dtype": {"--dtype", "--quantization", "-q", "--kv-cache-dtype"}, + "performance": {"--enforce-eager", "--enable-prefix-caching", "--enable-chunked-prefill"}, + "tokenizer": {"--tokenizer-mode", "--trust-remote-code", "--max-logprobs"}, + "misc": {"--revision", "--load-format", "--disable-log-stats", "--served-model-name", "--gpu-memory-utilization"}, + } + + for category, flags := range categories { + t.Run(category, func(t *testing.T) { + for _, flag := range flags { + if !VLLMAllowedFlags[flag] { + t.Errorf("VLLMAllowedFlags missing %s flag %q", category, flag) + } + } + }) + } +} + +func TestDangerousFlagsNotAllowed(t *testing.T) { + // Ensure dangerous flags involving file paths are NOT in the allowlists + dangerousFlags := []string{ + // File path flags + "--log-file", + "--output-file", + "--model-path", + "--config-file", + "--lora-path", + "--grammar-file", + "--prompt-file", + // llama.cpp specific path flags + "--slot-save-path", + "-mm", "--mmproj", + "-mmu", "--mmproj-url", + "-jf", "--json-schema-file", + "--chat-template-file", + "--path", + "--webui-config-file", + "--api-key-file", + "--ssl-key-file", + "--ssl-cert-file", + "--models-dir", + "--models-preset", + "-md", "--model-draft", + "--lora", + "--lora-scaled", + "--control-vector", + "--control-vector-scaled", + } + + for _, flag := range dangerousFlags { + if LlamaCppAllowedFlags[flag] { + t.Errorf("Dangerous flag %q should not be in LlamaCppAllowedFlags", flag) + } + if VLLMAllowedFlags[flag] { + t.Errorf("Dangerous flag %q should not be in VLLMAllowedFlags", flag) + } + } +} + +func TestIssue515Flags(t *testing.T) { + // Verify all flags from GitHub issue #515 are allowed + issue515Flags := []string{ + "--n-gpu-layers", + "--no-mmap", + "--flash-attn", + "--jinja", + "--top-p", + "--top-k", + "--temp", + "--min-p", + "--presence-penalty", + "--cache-type-k", + "--cache-type-v", + "--n-predict", + "--threads", + } + + for _, flag := range issue515Flags { + if !LlamaCppAllowedFlags[flag] { + t.Errorf("Flag %q from issue #515 should be in LlamaCppAllowedFlags", flag) + } + } +} diff --git a/pkg/inference/runtime_flags_test.go b/pkg/inference/runtime_flags_test.go index 0f2543aa..78fb84a3 100644 --- a/pkg/inference/runtime_flags_test.go +++ b/pkg/inference/runtime_flags_test.go @@ -1,165 +1,247 @@ package inference import ( + "strings" "testing" ) func TestValidateRuntimeFlags(t *testing.T) { tests := []struct { name string + backend string flags []string expectError bool description string }{ + // Tests for llama.cpp backend with allowlist { - name: "empty flags", + name: "llama.cpp: empty flags", + backend: "llama.cpp", flags: []string{}, expectError: false, description: "Empty array should pass validation", }, { - name: "nil flags", + name: "llama.cpp: nil flags", + backend: "llama.cpp", flags: nil, expectError: false, description: "Nil array should pass validation", }, { - name: "valid flags without paths", - flags: []string{"--verbose", "--debug", "--threads", "4"}, + name: "llama.cpp: valid allowed flags", + backend: "llama.cpp", + flags: []string{"--verbose", "--threads", "4"}, expectError: false, - description: "Simple flags without paths should pass", + description: "Allowed flags should pass", }, { - name: "valid single character flags", - flags: []string{"-v", "-d", "-t", "4"}, + name: "llama.cpp: valid single character flags", + backend: "llama.cpp", + flags: []string{"-v", "-t", "4"}, expectError: false, - description: "Single character flags should pass", + description: "Single character allowed flags should pass", }, { - name: "valid flags with numbers and hyphens", - flags: []string{"--gpu-memory-utilization", "0.9", "--max-tokens", "1024"}, + name: "llama.cpp: flag with equals format", + backend: "llama.cpp", + flags: []string{"--threads=4", "--ctx-size=2048"}, expectError: false, - description: "Flags with hyphens and numeric values should pass", + description: "Flags with = format should pass", }, { - name: "reject absolute path in value", - flags: []string{"--log-file", "/var/log/model.log"}, + name: "llama.cpp: reject non-allowed flag", + backend: "llama.cpp", + flags: []string{"--log-file", "test.log"}, expectError: true, - description: "Absolute paths should be rejected", + description: "Non-allowed flags should be rejected", }, { - name: "reject absolute path in flag=value format", - flags: []string{"--log-file=/var/log/model.log"}, + name: "llama.cpp: reject path in allowed flag value", + backend: "llama.cpp", + flags: []string{"--threads", "/etc/passwd"}, expectError: true, - description: "Paths in flag=value format should be rejected", + description: "Paths in values should be rejected", }, { - name: "reject relative path with parent directory", - flags: []string{"--output", "../file.txt"}, + name: "llama.cpp: reject path in flag=value format", + backend: "llama.cpp", + flags: []string{"--threads=/var/log/test"}, expectError: true, - description: "Relative paths with ../ should be rejected", + description: "Paths in flag=value format should be rejected", }, { - name: "reject relative path with current directory", - flags: []string{"--config", "./config.yaml"}, - expectError: true, - description: "Relative paths with ./ should be rejected", + name: "llama.cpp: multiple allowed flags", + backend: "llama.cpp", + flags: []string{"--threads", "4", "--ctx-size", "2048", "--verbose", "--flash-attn"}, + expectError: false, + description: "Multiple allowed flags should pass", + }, + { + name: "llama.cpp: GPU flags allowed", + backend: "llama.cpp", + flags: []string{"-ngl", "99", "--main-gpu", "0"}, + expectError: false, + description: "GPU-related flags should be allowed", + }, + { + name: "llama.cpp: sampling flags allowed", + backend: "llama.cpp", + flags: []string{"--temp", "0.7", "--top-p", "0.9", "--seed", "42"}, + expectError: false, + description: "Sampling flags should be allowed", + }, + { + name: "llama.cpp: real-world flags from issue 515", + backend: "llama.cpp", + flags: []string{ + "--n-gpu-layers", "99", + "--jinja", + "--top-p", "0.8", + "--top-k", "20", + "--temp", "0.7", + "--min-p", "0.0", + "--presence-penalty", "1.5", + "--no-mmap", + "--flash-attn", + "--cache-type-k", "q8_0", + "--cache-type-v", "q8_0", + }, + expectError: false, + description: "Real-world flags from GitHub issue 515 should be allowed", + }, + + // Tests for vLLM backend with allowlist + { + name: "vllm: valid allowed flags", + backend: "vllm", + flags: []string{"--tensor-parallel-size", "2", "--max-model-len", "4096"}, + expectError: false, + description: "Allowed vLLM flags should pass", }, { - name: "reject Windows-style path with forward slash", - flags: []string{"--file", "C:/Users/file.txt"}, + name: "vllm: reject non-allowed flag", + backend: "vllm", + flags: []string{"--output-file", "test.log"}, expectError: true, - description: "Windows-style paths with forward slash should be rejected", + description: "Non-allowed flags should be rejected for vLLM", }, { - name: "reject Windows-style path with backslash", - flags: []string{"--file", "C:\\Users\\file.txt"}, + name: "vllm: short flags allowed", + backend: "vllm", + flags: []string{"-tp", "2", "-q", "awq"}, + expectError: false, + description: "Short vLLM flags should be allowed", + }, + + // Tests for unknown backend + { + name: "unknown backend: valid flags without paths", + backend: "unknown-backend", + flags: []string{"--verbose", "--debug", "--threads", "4"}, expectError: true, - description: "Windows-style paths with backslash should be rejected", + description: "Unknown backend should reject all flags (no allowlist)", }, + + // Path safety tests (defense-in-depth) { - name: "reject Windows relative path with backslash", - flags: []string{"--config", "..\\config.yaml"}, + name: "llama.cpp: reject relative path with parent directory", + backend: "llama.cpp", + flags: []string{"--threads", "../file.txt"}, expectError: true, - description: "Windows relative paths with backslash should be rejected", + description: "Relative paths with ../ should be rejected", }, { - name: "reject Windows current directory path", - flags: []string{"--output", ".\\output.txt"}, + name: "llama.cpp: reject relative path with current directory", + backend: "llama.cpp", + flags: []string{"--threads", "./config.yaml"}, expectError: true, - description: "Windows current directory paths should be rejected", + description: "Relative paths with ./ should be rejected", }, { - name: "reject UNC network path", - flags: []string{"--share", "\\\\server\\share\\file.txt"}, + name: "llama.cpp: reject Windows-style path with backslash", + backend: "llama.cpp", + flags: []string{"--threads", "C:\\Users\\file.txt"}, expectError: true, - description: "UNC network paths should be rejected", + description: "Windows-style paths with backslash should be rejected", }, { - name: "reject Windows system path", - flags: []string{"--log", "C:\\Windows\\System32\\log.txt"}, + name: "llama.cpp: reject UNC network path", + backend: "llama.cpp", + flags: []string{"--threads", "\\\\server\\share\\file.txt"}, expectError: true, - description: "Windows system paths should be rejected", + description: "UNC network paths should be rejected", }, { - name: "reject URL with http", - flags: []string{"--endpoint", "http://example.com/api"}, + name: "llama.cpp: reject URL with http", + backend: "llama.cpp", + flags: []string{"--threads", "http://example.com/api"}, expectError: true, description: "URLs should be rejected (conservative approach)", }, { - name: "reject URL with https", - flags: []string{"--api-url", "https://api.example.com/v1"}, - expectError: true, - description: "HTTPS URLs should be rejected (conservative approach)", + name: "llama.cpp: valid flag with special characters except slash", + backend: "llama.cpp", + flags: []string{"--temp", "0.7", "--seed", "42"}, + expectError: false, + description: "Flags with dots and numbers (no slash) should pass", }, + + // Flag injection tests (smuggling flags via = separator) { - name: "reject path in middle of flag list", - flags: []string{"--verbose", "--log-file", "/tmp/log.txt", "--debug"}, + name: "llama.cpp: reject long flag injection via equals", + backend: "llama.cpp", + flags: []string{"--seed=--log-file=container-to-host.log"}, expectError: true, - description: "Path anywhere in flag list should be rejected", + description: "Smuggled long flags via = separator should be rejected", }, { - name: "reject multiple paths", - flags: []string{"--input", "/path/to/input", "--output", "/path/to/output"}, + name: "llama.cpp: reject short flag injection via equals", + backend: "llama.cpp", + flags: []string{"--seed=-l"}, expectError: true, - description: "Multiple paths should be rejected", + description: "Smuggled short flags via = separator should be rejected", }, { - name: "reject path traversal attempt", - flags: []string{"--file", "../../etc/passwd"}, + name: "llama.cpp: reject dash-only value via equals", + backend: "llama.cpp", + flags: []string{"--seed=-"}, expectError: true, - description: "Path traversal attempts should be rejected", + description: "Single dash as value should be rejected", }, { - name: "reject root directory", - flags: []string{"--root", "/"}, + name: "llama.cpp: reject dash-dot value via equals", + backend: "llama.cpp", + flags: []string{"--temp=-.5"}, expectError: true, - description: "Root directory should be rejected", + description: "Dash followed by non-digit should be rejected", }, { - name: "reject home directory path", - flags: []string{"--home", "/home/user/.config"}, - expectError: true, - description: "Home directory paths should be rejected", + name: "llama.cpp: allow negative integer via equals", + backend: "llama.cpp", + flags: []string{"--threads=-1"}, + expectError: false, + description: "Negative integer values should be allowed", }, { - name: "valid flag with special characters except slash", - flags: []string{"--model-name", "llama-3.2-1b", "--temperature", "0.7"}, + name: "llama.cpp: allow negative float via equals", + backend: "llama.cpp", + flags: []string{"--temp=-0.5"}, expectError: false, - description: "Flags with dots, hyphens, and numbers (no slash) should pass", + description: "Negative float values should be allowed", }, { - name: "valid flag with underscore", - flags: []string{"--max_tokens", "512", "--use_cache"}, + name: "llama.cpp: allow zero via equals", + backend: "llama.cpp", + flags: []string{"--seed=0"}, expectError: false, - description: "Flags with underscores should pass", + description: "Zero value should be allowed", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateRuntimeFlags(tt.flags) + err := ValidateRuntimeFlags(tt.backend, tt.flags) if tt.expectError { if err == nil { @@ -174,36 +256,85 @@ func TestValidateRuntimeFlags(t *testing.T) { } } -func TestValidateRuntimeFlags_ErrorMessage(t *testing.T) { - // Test that error messages are helpful - flags := []string{"--log-file", "/var/log/test.log"} - err := ValidateRuntimeFlags(flags) +func TestValidateRuntimeFlags_ErrorMessages(t *testing.T) { + // Test that allowlist error messages are helpful + t.Run("allowlist rejection message", func(t *testing.T) { + err := ValidateRuntimeFlags("llama.cpp", []string{"--log-file", "test.log"}) + if err == nil { + t.Fatal("Expected error but got none") + } + + errMsg := err.Error() + if !strings.Contains(errMsg, "--log-file") { + t.Errorf("Error message should contain the offending flag, got: %s", errMsg) + } + if !strings.Contains(errMsg, "not allowed") { + t.Errorf("Error message should explain rejection, got: %s", errMsg) + } + if !strings.Contains(errMsg, "llama.cpp") { + t.Errorf("Error message should mention the backend, got: %s", errMsg) + } + }) - if err == nil { - t.Fatal("Expected error but got none") - } + // Test that path safety error messages are helpful + t.Run("path rejection message", func(t *testing.T) { + err := ValidateRuntimeFlags("llama.cpp", []string{"--threads", "/var/log/test.log"}) + if err == nil { + t.Fatal("Expected error but got none") + } - errMsg := err.Error() - if !contains(errMsg, "/var/log/test.log") { - t.Errorf("Error message should contain the offending flag value, got: %s", errMsg) - } - if !contains(errMsg, "paths are not allowed") { - t.Errorf("Error message should explain why it failed, got: %s", errMsg) - } + errMsg := err.Error() + if !strings.Contains(errMsg, "/var/log/test.log") { + t.Errorf("Error message should contain the offending value, got: %s", errMsg) + } + if !strings.Contains(errMsg, "paths are not allowed") { + t.Errorf("Error message should explain why it failed, got: %s", errMsg) + } + }) } -// contains is a helper function to check if a string contains a substring -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || substr == "" || - (s != "" && indexOf(s, substr) >= 0)) -} +func TestValidatePathSafety(t *testing.T) { + tests := []struct { + name string + flags []string + expectError bool + }{ + { + name: "no paths", + flags: []string{"--verbose", "--threads", "4"}, + expectError: false, + }, + { + name: "forward slash", + flags: []string{"--file", "/etc/passwd"}, + expectError: true, + }, + { + name: "backslash", + flags: []string{"--file", "C:\\Windows\\file"}, + expectError: true, + }, + { + name: "relative path forward", + flags: []string{"../file"}, + expectError: true, + }, + { + name: "relative path backward", + flags: []string{"..\\file"}, + expectError: true, + }, + } -// indexOf returns the index of substr in s, or -1 if not found -func indexOf(s, substr string) int { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return i - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePathSafety(tt.flags) + if tt.expectError && err == nil { + t.Error("expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) } - return -1 } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 631b8218..33db3834 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -246,8 +246,8 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe } } - // Validate runtime flags to prevent path-based security issues - if err := inference.ValidateRuntimeFlags(runtimeFlags); err != nil { + // Validate runtime flags against backend allowlist and path safety + if err := inference.ValidateRuntimeFlags(backend.Name(), runtimeFlags); err != nil { return nil, err }