diff --git a/intercept/messages/base.go b/intercept/messages/base.go
index 09372ec7..924c6f70 100644
--- a/intercept/messages/base.go
+++ b/intercept/messages/base.go
@@ -24,6 +24,7 @@ import (
 	"github.com/coder/aibridge/recorder"
 	"github.com/coder/aibridge/tracing"
 	"github.com/coder/quartz"
+	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 
 	"github.com/google/uuid"
@@ -310,8 +311,33 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco
 	return out, nil
 }
 
+const (
+	// defaultAdaptiveFallbackBudgetTokens is the default budget_tokens value used when
+	// converting adaptive thinking to enabled thinking for Bedrock models that don't
+	// support adaptive thinking, and the request does not contain a max_tokens value.
+	defaultAdaptiveFallbackBudgetTokens = 10000
+
+	// minBedrockBudgetTokens is the minimum budget_tokens value that Bedrock accepts.
+	minBedrockBudgetTokens = 1024
+)
+
+// bedrockModelSupportsAdaptiveThinking reports whether the given Bedrock model
+// identifier names Claude Opus 4.6 or Sonnet 4.6, the only models treated as
+// supporting the adaptive thinking type. Matching is a case-insensitive
+// substring search covering both the dash ("4-6") and dot ("4.6") spellings.
+func bedrockModelSupportsAdaptiveThinking(model string) bool {
+	lowered := strings.ToLower(model)
+	for _, marker := range []string{"opus-4-6", "sonnet-4-6", "opus-4.6", "sonnet-4.6"} {
+		if strings.Contains(lowered, marker) {
+			return true
+		}
+	}
+	return false
+}
+
 // augmentRequestForBedrock will change the model used for the request since AWS Bedrock doesn't support
-// Anthropics' model names.
+// Anthropics' model names. It also converts adaptive thinking to enabled thinking for Bedrock models
+// that don't support the adaptive thinking type.
func (i *interceptionBase) augmentRequestForBedrock() { if i.bedrockCfg == nil { return @@ -324,6 +346,79 @@ func (i *interceptionBase) augmentRequestForBedrock() { if err != nil { i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) } + + i.convertAdaptiveThinkingForBedrock() +} + +// convertAdaptiveThinkingForBedrock converts "thinking": {"type": "adaptive"} to +// "thinking": {"type": "enabled", "budget_tokens": N} when the target Bedrock model +// does not support adaptive thinking. +// +// The budget_tokens value is derived from the request's max_tokens to satisfy the +// API constraint that budget_tokens < max_tokens. If max_tokens is not set or too +// small, a safe default is used. +func (i *interceptionBase) convertAdaptiveThinkingForBedrock() { + thinkingType := gjson.GetBytes(i.payload, "thinking.type").Str + if thinkingType != "adaptive" { + return + } + + if bedrockModelSupportsAdaptiveThinking(i.Model()) { + return + } + + budgetTokens, ok := adaptiveFallbackBudgetTokens(i.payload) + + if !ok { + // max_tokens is too small to accommodate the Bedrock minimum budget. + // Leave the payload unchanged and let the request fail with a clear + // upstream error rather than sending a known-bad budget_tokens value. 
+ i.logger.Warn(context.Background(), + "cannot convert adaptive thinking for Bedrock: max_tokens is too small to fit minimum budget_tokens", + slog.F("model", i.Model()), + slog.F("min_budget_tokens", minBedrockBudgetTokens), + ) + return + } + + i.logger.Info(context.Background(), "converting adaptive thinking to enabled for Bedrock model", + slog.F("model", i.Model()), + slog.F("budget_tokens", budgetTokens), + ) + + var err error + i.payload, err = sjson.SetBytes(i.payload, "thinking", map[string]any{ + "type": "enabled", + "budget_tokens": budgetTokens, + }) + if err != nil { + i.logger.Warn(context.Background(), "failed to convert adaptive thinking for Bedrock", slog.Error(err)) + } +} + +// adaptiveFallbackBudgetTokens computes a safe budget_tokens value from the +// request payload. The Anthropic API requires budget_tokens < max_tokens and +// Bedrock enforces a minimum of 1024. +// +// Returns (budget, true) on success, or (0, false) when max_tokens is present +// but too small to accommodate the minimum — the caller should skip conversion +// in that case rather than write an invalid value. +func adaptiveFallbackBudgetTokens(payload []byte) (int64, bool) { + maxTokens := gjson.GetBytes(payload, "max_tokens").Int() + if maxTokens <= 0 { + return defaultAdaptiveFallbackBudgetTokens, true + } + + // budget_tokens must be strictly less than max_tokens. + budget := maxTokens * 80 / 100 // 80% of max_tokens + if budget < minBedrockBudgetTokens { + budget = minBedrockBudgetTokens + } + if budget >= maxTokens { + // max_tokens is too small to fit even the minimum budget; can't convert. + return 0, false + } + return budget, true } // writeUpstreamError marshals and writes a given error. 
diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go
index cca890e0..f491a365 100644
--- a/intercept/messages/base_test.go
+++ b/intercept/messages/base_test.go
@@ -2,8 +2,10 @@ package messages
 
 import (
 	"context"
+	"encoding/json"
 	"testing"
 
+	"cdr.dev/slog/v3"
 	"github.com/anthropics/anthropic-sdk-go"
 	"github.com/anthropics/anthropic-sdk-go/shared/constant"
 	"github.com/coder/aibridge/config"
@@ -11,6 +13,7 @@ import (
 	"github.com/coder/aibridge/utils"
 	mcpgo "github.com/mark3labs/mcp-go/mcp"
 	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
 )
 
 func TestScanForCorrelatingToolCallID(t *testing.T) {
@@ -701,6 +704,263 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) {
 	})
 }
 
+// TestBedrockModelSupportsAdaptiveThinking checks the model-name matching for
+// 4.6 models (dash and dot spellings, with prefixes) against older/unknown names.
+func TestBedrockModelSupportsAdaptiveThinking(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		model    string
+		expected bool
+	}{
+		{"opus 4.6 with version", "anthropic.claude-opus-4-6-v1", true},
+		{"sonnet 4.6", "anthropic.claude-sonnet-4-6", true},
+		{"us prefix opus 4.6", "us.anthropic.claude-opus-4-6-v1", true},
+		{"opus 4.6 dot notation", "claude-opus-4.6", true},
+		{"sonnet 4.6 dot notation", "claude-sonnet-4.6", true},
+		{"sonnet 4.5", "anthropic.claude-sonnet-4-5-20250929-v1:0", false},
+		{"opus 4.5", "anthropic.claude-opus-4-5-20251101-v1:0", false},
+		{"haiku 4.5", "anthropic.claude-haiku-4-5-20251001-v1:0", false},
+		{"sonnet 3.7", "anthropic.claude-3-7-sonnet-20250219-v1:0", false},
+		{"custom model name", "my-custom-model", false},
+		{"empty", "", false},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tc.expected, bedrockModelSupportsAdaptiveThinking(tc.model))
+		})
+	}
+}
+
+// TestConvertAdaptiveThinkingForBedrock exercises the payload rewrite: conversion,
+// budget derivation from max_tokens, and every no-op path.
+func TestConvertAdaptiveThinkingForBedrock(t *testing.T) {
+	t.Parallel()
+
+	// newBaseWithBedrock takes the calling (sub)test's t so that a marshal failure
+	// is reported on, and stops, the subtest that actually hit it.
+	newBaseWithBedrock := func(t *testing.T, model string, payload map[string]any) *interceptionBase {
+		t.Helper()
+		raw, err := json.Marshal(payload)
+		require.NoError(t, err)
+		return &interceptionBase{
+			req: &MessageNewParamsWrapper{
+				MessageNewParams: anthropic.MessageNewParams{
+					Model: anthropic.Model(model),
+				},
+			},
+			payload:    raw,
+			bedrockCfg: &config.AWSBedrock{Model: model, SmallFastModel: "haiku"},
+			logger:     slogtest(t),
+		}
+	}
+
+	t.Run("converts adaptive to enabled for non-4.6 model", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":      "claude-sonnet-4-5",
+			"max_tokens": 16000,
+			"thinking":   map[string]string{"type": "adaptive"},
+			"messages":   []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
+		// 80% of 16000 = 12800
+		require.Equal(t, int64(12800), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
+	})
+
+	t.Run("uses default when max_tokens is absent", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":    "claude-sonnet-4-5",
+			"thinking": map[string]string{"type": "adaptive"},
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
+		require.Equal(t, int64(defaultAdaptiveFallbackBudgetTokens), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
+	})
+
+	t.Run("clamps to minimum when max_tokens is small", func(t *testing.T) {
+		t.Parallel()
+
+		// max_tokens=1200: 80% = 960, below min 1024, so clamped to 1024.
+		// 1024 < 1200, so budget_tokens = 1024.
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":      "claude-sonnet-4-5",
+			"max_tokens": 1200,
+			"thinking":   map[string]string{"type": "adaptive"},
+			"messages":   []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
+		require.Equal(t, int64(1024), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
+	})
+
+	t.Run("no-op when max_tokens is too small for minimum budget", func(t *testing.T) {
+		t.Parallel()
+
+		// max_tokens=1024: 80% = 819, clamped to min 1024, but 1024 >= 1024 (max_tokens).
+		// No valid budget_tokens exists — conversion should be skipped entirely.
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":      "claude-sonnet-4-5",
+			"max_tokens": 1024,
+			"thinking":   map[string]string{"type": "adaptive"},
+			"messages":   []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		// thinking.type must remain "adaptive" — the conversion was skipped.
+		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
+		require.False(t, gjson.GetBytes(base.payload, "thinking.budget_tokens").Exists())
+	})
+
+	t.Run("preserves adaptive for opus 4.6", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-opus-4-6-v1", map[string]any{
+			"model":    "claude-opus-4-6",
+			"thinking": map[string]string{"type": "adaptive"},
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
+		require.False(t, gjson.GetBytes(base.payload, "thinking.budget_tokens").Exists())
+	})
+
+	t.Run("preserves adaptive for sonnet 4.6", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-6", map[string]any{
+			"model":    "claude-sonnet-4-6",
+			"thinking": map[string]string{"type": "adaptive"},
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
+	})
+
+	t.Run("no-op when thinking type is enabled", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":    "claude-sonnet-4-5",
+			"thinking": map[string]any{"type": "enabled", "budget_tokens": 5000},
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
+		require.Equal(t, int64(5000), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
+	})
+
+	t.Run("no-op when thinking is absent", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":    "claude-sonnet-4-5",
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.False(t, gjson.GetBytes(base.payload, "thinking").Exists())
+	})
+
+	t.Run("no-op when thinking type is disabled", func(t *testing.T) {
+		t.Parallel()
+
+		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
+			"model":    "claude-sonnet-4-5",
+			"thinking": map[string]string{"type": "disabled"},
+			"messages": []any{},
+		})
+
+		base.convertAdaptiveThinkingForBedrock()
+
+		require.Equal(t, "disabled", gjson.GetBytes(base.payload, "thinking.type").Str)
+	})
+}
+
+// TestAdaptiveFallbackBudgetTokens covers the default, 80%, clamped, and
+// too-small outcomes of the budget computation.
+func TestAdaptiveFallbackBudgetTokens(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name       string
+		payload    map[string]any
+		wantBudget int64
+		wantOK     bool
+	}{
+		{
+			name:       "no max_tokens uses default",
+			payload:    map[string]any{},
+			wantBudget: defaultAdaptiveFallbackBudgetTokens,
+			wantOK:     true,
+		},
+		{
+			name:       "normal max_tokens uses 80%",
+			payload:    map[string]any{"max_tokens": 16000},
+			wantBudget: 12800,
+			wantOK:     true,
+		},
+		{
+			name:       "small max_tokens clamped to minimum",
+			payload:    map[string]any{"max_tokens": 1200},
+			wantBudget: 1024,
+			wantOK:     true,
+		},
+		{
+			name:    "max_tokens too small returns false",
+			payload: map[string]any{"max_tokens": 1024},
+			wantOK:  false,
+		},
+		{
+			name:    "max_tokens below minimum returns false",
+			payload: map[string]any{"max_tokens": 500},
+			wantOK:  false,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			raw, err := json.Marshal(tc.payload)
+			require.NoError(t, err)
+
+			budget, ok := adaptiveFallbackBudgetTokens(raw)
+			require.Equal(t, tc.wantOK, ok)
+			if ok {
+				require.Equal(t, tc.wantBudget, budget)
+			}
+		})
+	}
+}
+
+// slogtest returns a zero-value slog.Logger (no sinks configured) for tests.
+func slogtest(t *testing.T) slog.Logger {
+	t.Helper()
+	return slog.Logger{}
+}
+
 // mockServerProxier is a test implementation of mcp.ServerProxier.
 type mockServerProxier struct {
 	tools []*mcp.Tool