diff --git a/cmd/root/api.go b/cmd/root/api.go index 478c6ce81e..e158d37074 100644 --- a/cmd/root/api.go +++ b/cmd/root/api.go @@ -1,6 +1,7 @@ package root import ( + "cmp" "errors" "fmt" "log/slog" @@ -13,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config" pathx "github.com/docker/docker-agent/pkg/path" + "github.com/docker/docker-agent/pkg/profiling" "github.com/docker/docker-agent/pkg/server" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/telemetry" @@ -25,6 +27,7 @@ type apiFlags struct { fakeResponses string recordPath string authToken string + pprofAddr string runConfig config.RuntimeConfig } @@ -44,6 +47,8 @@ func newAPICmd() *cobra.Command { cmd.PersistentFlags().StringVar(&flags.fakeResponses, "fake", "", "Replay AI responses from cassette file (for testing)") cmd.PersistentFlags().StringVar(&flags.recordPath, "record", "", "Record AI API interactions to cassette file") cmd.PersistentFlags().StringVar(&flags.authToken, "auth-token", "", "Bearer token required for API requests (empty = no authentication)") + cmd.PersistentFlags().StringVar(&flags.pprofAddr, "pprof-addr", "", "TCP host:port to expose Go pprof endpoints at /debug/pprof/ (e.g. 127.0.0.1:6060); also set via CAGENT_PPROF_ADDR") + _ = cmd.PersistentFlags().MarkHidden("pprof-addr") cmd.MarkFlagsMutuallyExclusive("fake", "record") addRuntimeConfigFlags(cmd, &flags.runConfig) @@ -89,6 +94,12 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr return errors.New("--pull-interval flag can only be used with OCI or URL references, not local files") } + if pprofAddr := cmp.Or(f.pprofAddr, os.Getenv("CAGENT_PPROF_ADDR")); pprofAddr != "" { + if err := profiling.StartPprofServer(ctx, pprofAddr); err != nil { + return err + } + } + ln, lnCleanup, err := newListener(ctx, f.listenAddr) if err != nil { return err diff --git a/pkg/profiling/pprof_server.go b/pkg/profiling/pprof_server.go new file mode 100644 index 0000000000..1d9a40da5c --- /dev/null +++ b/pkg/profiling/pprof_server.go @@ -0,0 +1,69 @@ +package profiling + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "net/http/pprof" + "time" +) + +// StartPprofServer starts an HTTP server exposing Go runtime profiling endpoints +// at /debug/pprof/ on the given addr. It binds the listener synchronously and +// returns an error if the address is unavailable. The server runs in a background +// goroutine and shuts down when ctx is cancelled. +// addr must be a TCP host:port address (e.g. "127.0.0.1:6060"); unix://, npipe://, +// and fd:// schemes are not supported. Prefer a loopback address over a bare port +// (":6060") — the latter binds all interfaces, exposing process memory and arguments +// to the network. +func StartPprofServer(ctx context.Context, addr string) error { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + + ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("pprof: listen on %s: %w", addr, err) + } + + // ReadHeaderTimeout guards against slow-loris connections on the debug port. + // WriteTimeout is intentionally omitted: profile/trace captures legitimately + // run for tens of seconds and would be truncated by a short write deadline. + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} + + slog.InfoContext(ctx, "pprof server listening", "addr", ln.Addr().String()) + if tcpAddr, ok := ln.Addr().(*net.TCPAddr); ok && !tcpAddr.IP.IsLoopback() { + slog.WarnContext(ctx, "pprof server is listening on a non-loopback address — "+ + "/debug/pprof/cmdline and heap profiles are network-reachable without authentication", + "addr", tcpAddr.String()) + } + + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Warn("pprof server error", "error", err) + } + }() + + go func() { + <-ctx.Done() + // 5s grace: favors prompt process exit over draining in-flight profile + // captures. CPU/trace profiles run up to 30s by default; callers should + // cancel their requests before the process exits rather than relying on + // this timeout to drain them. + // context.WithoutCancel preserves ctx values (e.g. trace IDs) without + // inheriting the cancellation, so the shutdown timeout is not pre-expired. + shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + slog.WarnContext(shutdownCtx, "pprof server shutdown error", "error", err) + } + }() + + return nil +}