Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/tracing"
"github.com/coder/quartz"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/google/uuid"
Expand Down Expand Up @@ -310,8 +311,29 @@ func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibco
return out, nil
}

// Budget limits applied when adaptive thinking has to be rewritten as enabled
// thinking for a Bedrock model.
const (
	// minBedrockBudgetTokens is the smallest budget_tokens value that Bedrock
	// accepts; computed budgets are never allowed to drop below it.
	minBedrockBudgetTokens = 1024

	// defaultAdaptiveFallbackBudgetTokens is the budget_tokens value used for
	// the adaptive -> enabled conversion when the request carries no usable
	// max_tokens value to derive a budget from.
	defaultAdaptiveFallbackBudgetTokens = 10000
)

// bedrockModelSupportsAdaptiveThinking reports whether the given Bedrock model
// identifier is treated as supporting the "adaptive" thinking type.
//
// The check is a case-insensitive substring match: a model qualifies when its
// name contains an Opus 4.6 or Sonnet 4.6 marker, written with either a dash
// ("4-6") or a dot ("4.6"). Every other name, including the empty string,
// yields false.
func bedrockModelSupportsAdaptiveThinking(model string) bool {
	lowered := strings.ToLower(model)

	// Accepted spellings of the model families that support adaptive thinking.
	markers := []string{"opus-4-6", "sonnet-4-6", "opus-4.6", "sonnet-4.6"}
	for _, marker := range markers {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}

// augmentRequestForBedrock changes the model used for the request, since AWS Bedrock doesn't support
// Anthropic's model names. It also converts adaptive thinking to enabled thinking for Bedrock models
// that don't support the adaptive thinking type.
func (i *interceptionBase) augmentRequestForBedrock() {
if i.bedrockCfg == nil {
return
Expand All @@ -324,6 +346,79 @@ func (i *interceptionBase) augmentRequestForBedrock() {
if err != nil {
i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err))
}

i.convertAdaptiveThinkingForBedrock()
}

// convertAdaptiveThinkingForBedrock rewrites "thinking": {"type": "adaptive"}
// in the request payload to "thinking": {"type": "enabled", "budget_tokens": N}
// when the target Bedrock model does not support adaptive thinking.
//
// It is a no-op when thinking.type is anything other than "adaptive", or when
// bedrockModelSupportsAdaptiveThinking reports support for i.Model().
//
// The budget_tokens value comes from adaptiveFallbackBudgetTokens, which
// derives it from the request's max_tokens so that budget_tokens < max_tokens.
// When no valid budget exists the payload is left unchanged and a warning is
// logged.
//
// The payload is replaced only after the rewrite has succeeded; if the rewrite
// fails, a warning is logged and the original payload is kept intact.
func (i *interceptionBase) convertAdaptiveThinkingForBedrock() {
	if gjson.GetBytes(i.payload, "thinking.type").Str != "adaptive" {
		return
	}

	if bedrockModelSupportsAdaptiveThinking(i.Model()) {
		return
	}

	budgetTokens, ok := adaptiveFallbackBudgetTokens(i.payload)
	if !ok {
		// max_tokens is too small to accommodate the Bedrock minimum budget.
		// Leave the payload unchanged and let the request fail with a clear
		// upstream error rather than sending a known-bad budget_tokens value.
		i.logger.Warn(context.Background(),
			"cannot convert adaptive thinking for Bedrock: max_tokens is too small to fit minimum budget_tokens",
			slog.F("model", i.Model()),
			slog.F("min_budget_tokens", minBedrockBudgetTokens),
		)
		return
	}

	i.logger.Info(context.Background(), "converting adaptive thinking to enabled for Bedrock model",
		slog.F("model", i.Model()),
		slog.F("budget_tokens", budgetTokens),
	)

	// Write into a temporary first: assigning the result straight to i.payload
	// would discard the original request body whenever SetBytes fails.
	updated, err := sjson.SetBytes(i.payload, "thinking", map[string]any{
		"type":          "enabled",
		"budget_tokens": budgetTokens,
	})
	if err != nil {
		i.logger.Warn(context.Background(), "failed to convert adaptive thinking for Bedrock", slog.Error(err))
		return
	}
	i.payload = updated
}

// adaptiveFallbackBudgetTokens picks the budget_tokens value to use when
// adaptive thinking is rewritten as enabled thinking. It reads max_tokens from
// the JSON request payload and honors two constraints: budget_tokens must be
// strictly less than max_tokens, and it must be at least
// minBedrockBudgetTokens.
//
// Results:
//   - max_tokens missing or not positive: (defaultAdaptiveFallbackBudgetTokens, true).
//   - otherwise 80% of max_tokens, raised to minBedrockBudgetTokens if lower:
//     (budget, true) when that value is still below max_tokens.
//   - (0, false) when max_tokens is too small to leave room for the minimum
//     budget; the caller should skip the conversion instead of writing an
//     invalid value.
func adaptiveFallbackBudgetTokens(payload []byte) (int64, bool) {
	limit := gjson.GetBytes(payload, "max_tokens").Int()
	if limit <= 0 {
		return defaultAdaptiveFallbackBudgetTokens, true
	}

	// Aim for 80% of max_tokens, but never go below the Bedrock minimum.
	candidate := limit * 80 / 100
	if candidate < minBedrockBudgetTokens {
		candidate = minBedrockBudgetTokens
	}

	// budget_tokens must stay strictly below max_tokens.
	if candidate < limit {
		return candidate, true
	}

	// Even the minimum budget does not fit under max_tokens; can't convert.
	return 0, false
}

// writeUpstreamError marshals and writes a given error.
Expand Down
251 changes: 251 additions & 0 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package messages

import (
"context"
"encoding/json"
"testing"

"cdr.dev/slog/v3"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/utils"
mcpgo "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)

func TestScanForCorrelatingToolCallID(t *testing.T) {
Expand Down Expand Up @@ -701,6 +704,254 @@ func TestInjectTools_ParallelToolCalls(t *testing.T) {
})
}

func TestBedrockModelSupportsAdaptiveThinking(t *testing.T) {
t.Parallel()

tests := []struct {
name string
model string
expected bool
}{
{"opus 4.6 with version", "anthropic.claude-opus-4-6-v1", true},
{"sonnet 4.6", "anthropic.claude-sonnet-4-6", true},
{"us prefix opus 4.6", "us.anthropic.claude-opus-4-6-v1", true},
{"opus 4.6 dot notation", "claude-opus-4.6", true},
{"sonnet 4.6 dot notation", "claude-sonnet-4.6", true},
{"sonnet 4.5", "anthropic.claude-sonnet-4-5-20250929-v1:0", false},
{"opus 4.5", "anthropic.claude-opus-4-5-20251101-v1:0", false},
{"haiku 4.5", "anthropic.claude-haiku-4-5-20251001-v1:0", false},
{"sonnet 3.7", "anthropic.claude-3-7-sonnet-20250219-v1:0", false},
{"custom model name", "my-custom-model", false},
{"empty", "", false},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.expected, bedrockModelSupportsAdaptiveThinking(tc.model))
})
}
}

// TestConvertAdaptiveThinkingForBedrock exercises
// interceptionBase.convertAdaptiveThinkingForBedrock against payloads with
// different thinking settings, max_tokens values and target models, checking
// when the payload is rewritten to enabled thinking and when it is left alone.
func TestConvertAdaptiveThinkingForBedrock(t *testing.T) {
	t.Parallel()

	// newBaseWithBedrock builds an interceptionBase configured for Bedrock with
	// the given model and JSON-encoded payload. It takes the calling subtest's
	// *testing.T so that a setup failure is reported on, and stops, the subtest
	// that is actually running rather than the parent test.
	newBaseWithBedrock := func(t *testing.T, model string, payload map[string]any) *interceptionBase {
		t.Helper()

		raw, err := json.Marshal(payload)
		require.NoError(t, err)
		return &interceptionBase{
			req: &MessageNewParamsWrapper{
				MessageNewParams: anthropic.MessageNewParams{
					Model: anthropic.Model(model),
				},
			},
			payload:    raw,
			bedrockCfg: &config.AWSBedrock{Model: model, SmallFastModel: "haiku"},
			logger:     slogtest(t),
		}
	}

	t.Run("converts adaptive to enabled for non-4.6 model", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":      "claude-sonnet-4-5",
			"max_tokens": 16000,
			"thinking":   map[string]string{"type": "adaptive"},
			"messages":   []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
		// 80% of 16000 = 12800
		require.Equal(t, int64(12800), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
	})

	t.Run("uses default when max_tokens is absent", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":    "claude-sonnet-4-5",
			"thinking": map[string]string{"type": "adaptive"},
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
		require.Equal(t, int64(defaultAdaptiveFallbackBudgetTokens), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
	})

	t.Run("clamps to minimum when max_tokens is small", func(t *testing.T) {
		t.Parallel()

		// max_tokens=1200: 80% = 960, below min 1024, so clamped to 1024.
		// 1024 < 1200, so budget_tokens = 1024.
		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":      "claude-sonnet-4-5",
			"max_tokens": 1200,
			"thinking":   map[string]string{"type": "adaptive"},
			"messages":   []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
		require.Equal(t, int64(1024), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
	})

	t.Run("no-op when max_tokens is too small for minimum budget", func(t *testing.T) {
		t.Parallel()

		// max_tokens=1024: 80% = 819, clamped to min 1024, but 1024 >= 1024 (max_tokens).
		// No valid budget_tokens exists, so the conversion should be skipped entirely.
		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":      "claude-sonnet-4-5",
			"max_tokens": 1024,
			"thinking":   map[string]string{"type": "adaptive"},
			"messages":   []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		// thinking.type must remain "adaptive": the conversion was skipped.
		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
		require.False(t, gjson.GetBytes(base.payload, "thinking.budget_tokens").Exists())
	})

	t.Run("preserves adaptive for opus 4.6", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-opus-4-6-v1", map[string]any{
			"model":    "claude-opus-4-6",
			"thinking": map[string]string{"type": "adaptive"},
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
		require.False(t, gjson.GetBytes(base.payload, "thinking.budget_tokens").Exists())
	})

	t.Run("preserves adaptive for sonnet 4.6", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-6", map[string]any{
			"model":    "claude-sonnet-4-6",
			"thinking": map[string]string{"type": "adaptive"},
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "adaptive", gjson.GetBytes(base.payload, "thinking.type").Str)
		// Same expectation as the opus 4.6 case: no budget is injected.
		require.False(t, gjson.GetBytes(base.payload, "thinking.budget_tokens").Exists())
	})

	t.Run("no-op when thinking type is enabled", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":    "claude-sonnet-4-5",
			"thinking": map[string]any{"type": "enabled", "budget_tokens": 5000},
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "enabled", gjson.GetBytes(base.payload, "thinking.type").Str)
		require.Equal(t, int64(5000), gjson.GetBytes(base.payload, "thinking.budget_tokens").Int())
	})

	t.Run("no-op when thinking is absent", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":    "claude-sonnet-4-5",
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.False(t, gjson.GetBytes(base.payload, "thinking").Exists())
	})

	t.Run("no-op when thinking type is disabled", func(t *testing.T) {
		t.Parallel()

		base := newBaseWithBedrock(t, "anthropic.claude-sonnet-4-5-20250929-v1:0", map[string]any{
			"model":    "claude-sonnet-4-5",
			"thinking": map[string]string{"type": "disabled"},
			"messages": []any{},
		})

		base.convertAdaptiveThinkingForBedrock()

		require.Equal(t, "disabled", gjson.GetBytes(base.payload, "thinking.type").Str)
	})
}

// TestAdaptiveFallbackBudgetTokens is a table test for
// adaptiveFallbackBudgetTokens. It covers the default used when max_tokens is
// missing, the 80% rule, clamping to the minimum budget, and the cases where
// max_tokens is too small for any valid budget.
func TestAdaptiveFallbackBudgetTokens(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name       string
		payload    map[string]any
		wantBudget int64
		wantOK     bool
	}

	scenarios := []scenario{
		{name: "no max_tokens uses default", payload: map[string]any{}, wantBudget: defaultAdaptiveFallbackBudgetTokens, wantOK: true},
		{name: "normal max_tokens uses 80%", payload: map[string]any{"max_tokens": 16000}, wantBudget: 12800, wantOK: true},
		{name: "small max_tokens clamped to minimum", payload: map[string]any{"max_tokens": 1200}, wantBudget: 1024, wantOK: true},
		{name: "max_tokens too small returns false", payload: map[string]any{"max_tokens": 1024}, wantOK: false},
		{name: "max_tokens below minimum returns false", payload: map[string]any{"max_tokens": 500}, wantOK: false},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			body, err := json.Marshal(sc.payload)
			require.NoError(t, err)

			got, ok := adaptiveFallbackBudgetTokens(body)
			require.Equal(t, sc.wantOK, ok)
			if !ok {
				// The budget is only checked when a conversion is possible.
				return
			}
			require.Equal(t, sc.wantBudget, got)
		})
	}
}

// slogtest hands back a zero-value slog.Logger for use in tests, intended to
// act as a no-op logger. It is marked as a test helper via t.Helper.
func slogtest(t *testing.T) slog.Logger {
	t.Helper()

	var logger slog.Logger
	return logger
}

// mockServerProxier is a test implementation of mcp.ServerProxier.
type mockServerProxier struct {
tools []*mcp.Tool
Expand Down
Loading