Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions derp/derphttp/derphttp_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,31 @@ import (
// Send/Recv will completely re-establish the connection (unless Close
// has been called).
type Client struct {
Header http.Header
Header http.Header

// GetHeaders, if non-nil, returns a fresh set of HTTP headers to send
// on every (re)connect to the DERP server. When non-nil it takes
// precedence over Header. This is useful when a caller needs to inject
// short-lived authentication tokens (e.g. for an authenticating
// reverse proxy in front of DERP) that must be refreshed on each
// reconnect, rather than captured once at startup. The same pattern
// is already used by netcheck.Client.GetDERPHeaders.
//
// Implementations must be cheap and non-blocking: GetHeaders is invoked
// from connect() while the Client's internal mutex is held, so it must
// not call back into this Client (Send, Close, etc.) or acquire the
// lock of any caller that may, in turn, call into this Client (notably
// magicsock.Conn.mu). It should also avoid blocking I/O on the hot
// reconnect path; cache and refresh in the background where possible.
//
// Returning a nil http.Header is treated the same as a missing static
// Header: the connect request goes out without those caller-supplied
// headers (i.e. without auth). Implementations that fail to obtain
// fresh credentials should generally return the most recent known-good
// value rather than nil, so the server can return a clear 401 instead
// of a generic auth failure.
GetHeaders func() http.Header

TLSConfig *tls.Config // optional; nil means default
DNSCache *dnscache.Resolver // optional; nil means no caching
MeshKey string // optional; for trusted clients
Expand Down Expand Up @@ -113,6 +137,17 @@ func (c *Client) String() string {
return fmt.Sprintf("<derphttp_client.Client %s url=%s>", c.serverPubKey.ShortString(), c.url)
}

// headers returns the HTTP headers to send on the next DERP connection
// attempt. If GetHeaders is set, it is invoked on every call (so callers can
// refresh short-lived tokens). Otherwise the static Header field is used.
// Either may be nil.
func (c *Client) headers() http.Header {
if c.GetHeaders != nil {
return c.GetHeaders()
}
return c.Header
}

// NewRegionClient returns a new DERP-over-HTTP client. It connects lazily.
// To trigger a connection, use Connect.
// The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
Expand Down Expand Up @@ -430,7 +465,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
tlsConfig = c.tlsConfig(nil)
}
c.logf("%s: connecting websocket to %v", caller, urlStr)
conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig, c.Header)
conn, err := dialWebsocketFunc(ctx, urlStr, tlsConfig, c.headers())
if err != nil {
c.logf("%s: websocket to %v error: %v", caller, urlStr, err)
return nil, 0, err
Expand Down Expand Up @@ -533,8 +568,8 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
if err != nil {
return nil, 0, err
}
if c.Header != nil {
req.Header = c.Header.Clone()
if h := c.headers(); h != nil {
req.Header = h.Clone()
}
req.Header.Set("Upgrade", "DERP")
req.Header.Set("Connection", "Upgrade")
Expand Down
33 changes: 33 additions & 0 deletions derp/derphttp/derphttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,36 @@ func TestForceWebsockets(t *testing.T) {

c.Close()
}

func TestClientHeaders(t *testing.T) {
t.Run("nil when neither set", func(t *testing.T) {
c := &Client{}
if got := c.headers(); got != nil {
t.Fatalf("expected nil headers, got %v", got)
}
})
t.Run("returns Header when GetHeaders is nil", func(t *testing.T) {
want := http.Header{"X-Test": []string{"static"}}
c := &Client{Header: want}
if got := c.headers().Get("X-Test"); got != "static" {
t.Fatalf("expected static header, got %q", got)
}
})
t.Run("GetHeaders takes precedence and is invoked on every call", func(t *testing.T) {
var calls int
c := &Client{
Header: http.Header{"X-Test": []string{"static"}},
GetHeaders: func() http.Header {
calls++
return http.Header{"X-Test": []string{"dynamic"}}
},
}
if got := c.headers().Get("X-Test"); got != "dynamic" {
t.Fatalf("expected dynamic header, got %q", got)
}
_ = c.headers()
if calls != 2 {
t.Fatalf("expected GetHeaders to be invoked on every call, got %d calls", calls)
}
})
}
3 changes: 3 additions & 0 deletions wgengine/magicsock/derp.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ func (c *Conn) derpWriteChanOfAddr(addr netip.AddrPort, peer key.NodePublic) cha
if header != nil {
dc.Header = header.Clone()
}
if getHeaders := c.derpGetHeaders.Load(); getHeaders != nil {
dc.GetHeaders = *getHeaders
}
dc.ForceWebsockets = c.derpForceWebsockets.Load()
dialer := c.derpRegionDialer.Load()
if dialer != nil {
Expand Down
19 changes: 19 additions & 0 deletions wgengine/magicsock/magicsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ type Conn struct {
// headers that are passed to the DERP HTTP client
derpHeader atomic.Pointer[http.Header]

// derpGetHeaders, if non-nil, is called by the DERP HTTP client on every
// (re)connect to obtain a fresh set of HTTP headers. When non-nil it
// takes precedence over derpHeader. See derphttp.Client.GetHeaders.
derpGetHeaders atomic.Pointer[func() http.Header]

// whether websocket is always used by the DERP HTTP client
derpForceWebsockets atomic.Bool

Expand Down Expand Up @@ -462,6 +467,9 @@ func NewConn(opts Options) (*Conn, error) {
PortMapper: c.portMapper,
UseDNSCache: true,
GetDERPHeaders: func() http.Header {
if getHeaders := c.derpGetHeaders.Load(); getHeaders != nil {
return (*getHeaders)()
}
h := c.derpHeader.Load()
if h == nil {
return nil
Expand Down Expand Up @@ -1759,6 +1767,17 @@ func (c *Conn) SetDERPHeader(header http.Header) {
c.derpHeader.Store(&header)
}

// SetDERPGetHeaders sets a callback invoked by the DERP HTTP client on every
// (re)connect to obtain a fresh set of HTTP headers. When non-nil it takes
// precedence over the value set with SetDERPHeader. Pass nil to clear.
func (c *Conn) SetDERPGetHeaders(getHeaders func() http.Header) {
if getHeaders == nil {
c.derpGetHeaders.Store(nil)
return
}
c.derpGetHeaders.Store(&getHeaders)
}

func (c *Conn) SetDERPForceWebsockets(v bool) {
c.derpForceWebsockets.Store(v)
}
Expand Down
Loading