diff --git a/config/config.go b/config/config.go index f4107e42..af16502b 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,10 @@ type Anthropic struct { CircuitBreaker *CircuitBreaker SendActorHeaders bool ExtraHeaders map[string]string + // BYOKBearerToken is set in BYOK mode when the user authenticates + // with an OAuth token (e.g. Claude Max/Pro subscription). When set, + // the SDK uses Authorization: Bearer instead of X-Api-Key. + BYOKBearerToken string } type AWSBedrock struct { diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 09372ec7..9d134bc0 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -23,6 +23,7 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" + "github.com/coder/aibridge/utils" "github.com/coder/quartz" "github.com/tidwall/sjson" @@ -205,7 +206,19 @@ func (i *interceptionBase) isSmallFastModel() bool { } func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { - opts = append(opts, option.WithAPIKey(i.cfg.Key)) + // BYOK with OAuth token (Claude Max/Pro) uses Authorization: Bearer. + // Otherwise use X-Api-Key (centralized or BYOK with personal API key). + if i.cfg.BYOKBearerToken != "" { + i.logger.Debug(ctx, "using byok oauth bearer auth", + slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)), + ) + opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken)) + } else { + i.logger.Debug(ctx, "using api key auth", + slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)), + ) + opts = append(opts, option.WithAPIKey(i.cfg.Key)) + } opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) // Add extra headers if configured. diff --git a/passthrough.go b/passthrough.go index c6b59edd..dc5af371 100644 --- a/passthrough.go +++ b/passthrough.go @@ -12,6 +12,7 @@ import ( "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" + "github.com/coder/aibridge/utils" "github.com/coder/quartz" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -96,6 +97,16 @@ func newPassthroughRouter(provider provider.Provider, logger slog.Logger, m *met // Inject provider auth. provider.InjectAuthHeader(&req.Header) + + if authz := req.Header.Get("Authorization"); authz != "" { + logger.Debug(ctx, "passthrough using oauth bearer auth", + slog.F("bearer_hint", utils.MaskSecret(authz)), + ) + } else { + logger.Debug(ctx, "passthrough using api key auth", + slog.F("api_key_hint", utils.MaskSecret(req.Header.Get("X-Api-Key"))), + ) + } }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) { logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path)) diff --git a/provider/anthropic.go b/provider/anthropic.go index 4a79cd42..a51bd419 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -110,11 +110,35 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr cfg := p.cfg cfg.ExtraHeaders = extractAnthropicHeaders(r) + // coder/aibridged strips all headers that may carry the Coder + // session token before passing the request here, so this code + // sees only legitimate LLM credentials. + // + // In centralized mode neither Authorization nor X-Api-Key is + // present, so cfg keeps the centralized key unchanged. + // + // In BYOK mode the user's LLM credentials survive intact. + // Either Authorization or X-Api-Key will be present, but not + // both. If Authorization is present it means the user has a + // Claude Max/Pro subscription and authenticated via OAuth; + // in this case set BYOKBearerToken so the SDK uses WithAuthToken() + // and clear the centralized key. If X-Api-Key is present it means the user + // has a personal API key; overwrite the centralized key with it. + authHeaderName := p.AuthHeader() + if bearer := r.Header.Get("Authorization"); bearer != "" { + cfg.BYOKBearerToken = strings.TrimPrefix(bearer, "Bearer ") + cfg.Key = "" + authHeaderName = "Authorization" + } else if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" { + cfg.Key = apiKey + authHeaderName = "X-Api-Key" + } + var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, p.AuthHeader(), tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil @@ -137,6 +161,12 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } + // BYOK: if the request already carries user-supplied credentials, + // do not overwrite them with the centralized key. + if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" { + return + } + headers.Set(p.AuthHeader(), p.cfg.Key) } diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index d34fd029..484589cc 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -86,9 +86,8 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) req.Header.Set("Anthropic-Beta", betaHeader) - // Simulate a client sending its own auth credential, which must be replaced - // by aibridge with the configured provider key. - req.Header.Set("Authorization", "Bearer fake-client-bearer") + // Simulate BYOK: the client sends its own bearer token for upstream auth. + req.Header.Set("Authorization", "Bearer user-oauth-token") w := httptest.NewRecorder() interceptor, err := provider.CreateInterceptor(w, req, testTracer) @@ -105,11 +104,77 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { // Verify the full Anthropic-Beta header (all betas) was forwarded unchanged. assert.Equal(t, betaHeader, receivedHeaders.Get("Anthropic-Beta"), "Anthropic-Beta header must be forwarded unchanged to upstream") - // Verify aibridge's configured key was used and the client's auth credential was not forwarded. - assert.Equal(t, "test-key", receivedHeaders.Get("X-Api-Key"), "upstream must receive configured provider key") - assert.Empty(t, receivedHeaders.Get("Authorization"), "client Authorization header must not reach upstream") + // The client sent Authorization: Bearer, so BYOK bearer mode is active. + // The SDK uses Authorization (not X-Api-Key) for bearer auth. + assert.Empty(t, receivedHeaders.Get("X-Api-Key"), "X-Api-Key must not be set in BYOK bearer mode") + assert.Equal(t, "Bearer user-oauth-token", receivedHeaders.Get("Authorization"), "upstream must receive the client's bearer token") }) + byokTests := []struct { + name string + setHeaders map[string]string + wantXApiKey string + wantAuthorization string + }{ + { + name: "Messages_BYOK_BearerToken", + setHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"}, + wantAuthorization: "Bearer user-oauth-token", + }, + { + name: "Messages_BYOK_APIKey", + setHeaders: map[string]string{"X-Api-Key": "user-api-key"}, + wantXApiKey: "user-api-key", + }, + { + name: "Messages_Centralized_UsesCentralizedKey", + setHeaders: map[string]string{}, + wantXApiKey: "test-key", + }, + } + + for _, tc := range byokTests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var receivedHeaders http.Header + + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg-123","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + })) + t.Cleanup(mockUpstream.Close) + + provider := NewAnthropic(config.Anthropic{ + BaseURL: mockUpstream.URL, + Key: "test-key", + }, nil) + + body := `{"model": "claude-opus-4-5", "max_tokens": 1024, "messages": [{"role": "user", "content": "hello"}], "stream": false}` + req := httptest.NewRequest(http.MethodPost, routeMessages, bytes.NewBufferString(body)) + for k, v := range tc.setHeaders { + req.Header.Set(k, v) + } + w := httptest.NewRecorder() + + interceptor, err := provider.CreateInterceptor(w, req, testTracer) + require.NoError(t, err) + require.NotNil(t, interceptor) + + logger := slog.Make() + interceptor.Setup(logger, &testutil.MockRecorder{}, nil) + + processReq := httptest.NewRequest(http.MethodPost, routeMessages, nil) + err = interceptor.ProcessRequest(w, processReq) + require.NoError(t, err) + + assert.Equal(t, tc.wantXApiKey, receivedHeaders.Get("X-Api-Key")) + assert.Equal(t, tc.wantAuthorization, receivedHeaders.Get("Authorization")) + }) + } + t.Run("UnknownRoute", func(t *testing.T) { t.Parallel() @@ -124,6 +189,52 @@ func TestAnthropic_CreateInterceptor(t *testing.T) { }) } +func TestAnthropic_InjectAuthHeader_BYOK(t *testing.T) { + t.Parallel() + + provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil) + + tests := []struct { + name string + presetHeaders map[string]string + wantXApiKey string + wantAuthorization string + }{ + { + name: "no pre-existing auth headers injects centralized key", + presetHeaders: map[string]string{}, + wantXApiKey: "centralized-key", + }, + { + name: "pre-existing X-Api-Key is not overwritten", + presetHeaders: map[string]string{"X-Api-Key": "user-api-key"}, + wantXApiKey: "user-api-key", + }, + { + name: "pre-existing Authorization prevents centralized key injection", + presetHeaders: map[string]string{"Authorization": "Bearer user-oauth-token"}, + wantXApiKey: "", + wantAuthorization: "Bearer user-oauth-token", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + headers := http.Header{} + for k, v := range tc.presetHeaders { + headers.Set(k, v) + } + + provider.InjectAuthHeader(&headers) + + assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key")) + assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) + }) + } +} + func TestExtractAnthropicHeaders(t *testing.T) { t.Parallel() diff --git a/utils/mask.go b/utils/mask.go new file mode 100644 index 00000000..72de7eec --- /dev/null +++ b/utils/mask.go @@ -0,0 +1,10 @@ +package utils + +// MaskSecret returns the first 4 and last 4 characters of s +// separated by "...", or the full string if 8 characters or fewer. +func MaskSecret(s string) string { + if len(s) <= 8 { + return s + } + return s[:4] + "..." + s[len(s)-4:] +} diff --git a/utils/mask_test.go b/utils/mask_test.go new file mode 100644 index 00000000..a8cf2f01 --- /dev/null +++ b/utils/mask_test.go @@ -0,0 +1,30 @@ +package utils_test + +import ( + "testing" + + "github.com/coder/aibridge/utils" + "github.com/stretchr/testify/assert" +) + +func TestMaskSecret(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"short", "short"}, + {"exactly8", "exactly8"}, + {"sk-ant-api03-abcdefgh", "sk-a...efgh"}, + {"sk-ant-oat01-abcdefghijklmnop", "sk-a...mnop"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, utils.MaskSecret(tc.input)) + }) + } +}