Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 22 additions & 69 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/tracing"
"github.com/coder/quartz"
"github.com/tidwall/sjson"

"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
Expand All @@ -34,9 +33,8 @@ import (
)

type interceptionBase struct {
id uuid.UUID
req *MessageNewParamsWrapper
payload []byte
id uuid.UUID
reqPayload MessagesRequestPayload

cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock
Expand All @@ -63,22 +61,11 @@ func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder,
}

func (i *interceptionBase) CorrelatingToolCallID() *string {
if len(i.req.Messages) == 0 {
return nil
}
content := i.req.Messages[len(i.req.Messages)-1].Content
for idx := len(content) - 1; idx >= 0; idx-- {
block := content[idx]
if block.OfToolResult == nil {
continue
}
return &block.OfToolResult.ToolUseID
}
return nil
return i.reqPayload.correlatingToolCallID()
}

func (i *interceptionBase) Model() string {
if i.req == nil {
if len(i.reqPayload) == 0 {
return "coder-aibridge-unknown"
}

Expand All @@ -90,7 +77,7 @@ func (i *interceptionBase) Model() string {
return model
}

return string(i.req.Model)
return i.reqPayload.model()
}

func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
Expand All @@ -106,7 +93,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
}

func (i *interceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() {
if i.mcpProxy == nil || !i.hasInjectableTools() {
return
}

Expand All @@ -131,46 +118,23 @@ func (i *interceptionBase) injectTools() {
// Prepend the injected tools in order to maintain any configured cache breakpoints.
// The order of injected tools is expected to be stable, and therefore will not cause
// any cache invalidation when prepended.
i.req.Tools = append(injectedTools, i.req.Tools...)

var err error
i.payload, err = sjson.SetBytes(i.payload, "tools", i.req.Tools)
updated, err := i.reqPayload.injectTools(injectedTools)
if err != nil {
i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err))
return
}
i.reqPayload = updated
}

func (i *interceptionBase) disableParallelToolCalls() {
// Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches.
// https://github.com/coder/aibridge/issues/2
toolChoiceType := i.req.ToolChoice.GetType()
var toolChoiceTypeStr string
if toolChoiceType != nil {
toolChoiceTypeStr = *toolChoiceType
}

switch toolChoiceTypeStr {
// If no tool_choice was defined, assume auto.
// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use.
case "", string(constant.ValueOf[constant.Auto]()):
// We only set OfAuto if no tool_choice was provided (the default).
// "auto" is the default when a zero value is provided, so we can safely disable parallel checks on it.
if i.req.ToolChoice.OfAuto == nil {
i.req.ToolChoice.OfAuto = &anthropic.ToolChoiceAutoParam{}
}
i.req.ToolChoice.OfAuto.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.Any]()):
i.req.ToolChoice.OfAny.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.Tool]()):
i.req.ToolChoice.OfTool.DisableParallelToolUse = anthropic.Bool(true)
case string(constant.ValueOf[constant.None]()):
// No-op; if tool_choice=none then tools are not used at all.
}
var err error
i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice)
updated, err := i.reqPayload.disableParallelToolCalls()
if err != nil {
i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err))
return
}
i.reqPayload = updated
}

// extractModelThoughts returns any thinking blocks that were returned in the response.
Expand Down Expand Up @@ -201,7 +165,7 @@ func (i *interceptionBase) extractModelThoughts(msg *anthropic.Message) []*recor
// See `ANTHROPIC_SMALL_FAST_MODEL`: https://docs.anthropic.com/en/docs/claude-code/settings#environment-variables
// https://docs.claude.com/en/docs/claude-code/costs#background-token-usage
func (i *interceptionBase) isSmallFastModel() bool {
return strings.Contains(string(i.req.Model), "haiku")
return strings.Contains(i.reqPayload.model(), "haiku")
}

func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
Expand Down Expand Up @@ -244,23 +208,12 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
return anthropic.NewMessageService(opts...), nil
}

// withBody returns a per-request option that sends the current i.payload as the
// request body. This is called for each API request so that the latest payload (including
// any messages appended during the agentic tool loop) is always sent.
// withBody returns a per-request option that sends the current raw request
// payload as the request body. This is called for each API request so that the
// latest payload (including any messages appended during the agentic tool loop)
// is always sent.
func (i *interceptionBase) withBody() option.RequestOption {
return option.WithRequestBody("application/json", i.payload)
}

// syncPayloadMessages updates the raw payload's "messages" field to match the given messages.
// This must be called before the next API request in the agentic loop so that
// withBody() picks up the updated messages.
func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error {
var err error
i.payload, err = sjson.SetBytes(i.payload, "messages", messages)
if err != nil {
return fmt.Errorf("sync payload messages: %w", err)
}
return nil
return option.WithRequestBody("application/json", []byte(i.reqPayload))
}

func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
Expand Down Expand Up @@ -317,13 +270,13 @@ func (i *interceptionBase) augmentRequestForBedrock() {
return
}

i.req.MessageNewParams.Model = anthropic.Model(i.Model())

var err error
i.payload, err = sjson.SetBytes(i.payload, "model", i.Model())
updated, err := i.reqPayload.withModel(i.Model())
if err != nil {
i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err))
return
}

i.reqPayload = updated
}

// writeUpstreamError marshals and writes a given error.
Expand Down
Loading
Loading