Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 61 additions & 55 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,19 @@ const (
)

type responsesInterceptionBase struct {
id uuid.UUID
req *ResponsesNewParamsWrapper
reqPayload []byte
cfg config.OpenAI
model string

id uuid.UUID
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
authHeaderName string
reqPayload ResponsesRequestPayload

cfg config.OpenAI
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
logger slog.Logger
metrics metrics.Metrics
tracer trace.Tracer

logger slog.Logger
metrics metrics.Metrics
tracer trace.Tracer
}

func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
Expand Down Expand Up @@ -88,26 +86,37 @@ func (i *responsesInterceptionBase) ID() uuid.UUID {
}

func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger.With(slog.F("model", i.model))
i.logger = logger.With(slog.F("model", i.Model()))
i.recorder = recorder
i.mcpProxy = mcpProxy
}

func (i *responsesInterceptionBase) Model() string {
return i.model
return i.reqPayload.model()
}

func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
if len(i.req.Input.OfInputItemList) == 0 {
items := gjson.GetBytes(i.reqPayload, "input")
if !items.IsArray() {
return nil
}

arr := items.Array()
if len(arr) == 0 {
return nil
}

last := arr[len(arr)-1]
if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) {
return nil
}

// The tool result should be the last input message.
item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1]
if item.OfFunctionCallOutput == nil {
callID := last.Get("call_id").String()
if callID == "" {
return nil
}
return &item.OfFunctionCallOutput.CallID

return &callID
}

func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
Expand All @@ -122,13 +131,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami
}

func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error {
if i.req == nil {
err := errors.New("developer error: req is nil")
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
return err
}

if i.req.Background.Value {
if i.reqPayload.background() {
err := fmt.Errorf("background requests are currently not supported by AI Bridge")
i.sendCustomErr(ctx, w, http.StatusNotImplemented, err)
return err
Expand Down Expand Up @@ -161,15 +164,15 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
// eg. Codex CLI produces requests without ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id
// when re-encoded, ID field is set to empty string which results
// in bad request while not sending ID field at all somehow works.
option.WithRequestBody("application/json", i.reqPayload),
option.WithRequestBody("application/json", []byte(i.reqPayload)),

// copyMiddleware copies body of original response body to the buffer in responseCopier,
// also reference to headers and status code is kept responseCopier.
// responseCopier is used by interceptors to forward response as it was received,
// eliminating any possibility of JSON re-encoding issues.
option.WithMiddleware(respCopy.copyMiddleware),
}
if !i.req.Stream {
if !i.reqPayload.Stream() {
opts = append(opts, option.WithRequestTimeout(requestTimeout))
}
return opts
Expand All @@ -182,77 +185,80 @@ func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string,
if i == nil {
return "", false, errors.New("cannot get last user prompt: nil struct")
}
if i.req == nil {
if i.reqPayload == nil {
return "", false, errors.New("cannot get last user prompt: nil request struct")
}

// 'input' field can be a string or array of objects:
// 'input' can be either a string or an array of input items:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input

// Check string variant
if i.req.Input.OfString.Valid() {
return i.req.Input.OfString.Value, true, nil
inputItems := gjson.GetBytes(i.reqPayload, "input")
if !inputItems.Exists() || inputItems.Type == gjson.Null {
return "", false, nil
}

// Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field.
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList'
// It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message
// example: fixtures/openai/responses/blocking/builtin_tool.txtar
inputItems := gjson.GetBytes(i.reqPayload, "input")
// String variant: treat the whole input as the user prompt.
if inputItems.Type == gjson.String {
return inputItems.String(), true, nil
}

// Array variant: checking only the last input item
if !inputItems.IsArray() {
if inputItems.Type == gjson.Null {
return "", false, nil
}
return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String())
return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type)
}

inputItemsArr := inputItems.Array()
if len(inputItemsArr) == 0 {
return "", false, nil
}
lastItem := inputItemsArr[len(inputItemsArr)-1]

// Request was likely not human-initiated.
lastItem := inputItemsArr[len(inputItemsArr)-1]
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
// Request was likely not initiated by a prompt but is an iteration of agentic loop.
return "", false, nil
}

// content can be a string or array of objects:
// Message content can be either a string or an array of typed content items:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
if !content.Exists() || content.Type == gjson.Null {
return "", false, nil
}

// String variant: use it directly as the prompt.
if content.Type == gjson.String {
return content.Str, true, nil
}

// non array case, should be string
if !content.IsArray() {
if content.Type == gjson.String {
return content.Str, true, nil
}
return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String())
return "", false, fmt.Errorf("unexpected input content type: %s", content.Type)
}

var sb strings.Builder
promptExists := false
for _, c := range content.Array() {
// ignore inputs of not `input_text` type
// Ignore non-text content blocks such as images or files.
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
continue
}

text := c.Get(string(constant.ValueOf[constant.Text]()))
if text.Type == gjson.String {
promptExists = true
sb.WriteString(text.Str + "\n")
} else {
if text.Type != gjson.String {
i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type))
continue
}

if promptExists {
sb.WriteByte('\n')
}
promptExists = true
sb.WriteString(text.Str)
}

if !promptExists {
return "", false, nil
}

prompt := strings.TrimSuffix(sb.String(), "\n")
return prompt, true, nil
return sb.String(), true, nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {
Expand Down
110 changes: 32 additions & 78 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/coder/aibridge/utils"
"github.com/google/uuid"
oairesponses "github.com/openai/openai-go/v3/responses"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -20,95 +21,53 @@ func TestScanForCorrelatingToolCallID(t *testing.T) {

tests := []struct {
name string
input []oairesponses.ResponseInputItemUnionParam
expected *string
payload []byte
wantCall *string
}{
{
name: "no input items",
input: nil,
expected: nil,
name: "no input",
payload: []byte(`{"model":"gpt-4o"}`),
},
{
name: "no function_call_output items",
input: []oairesponses.ResponseInputItemUnionParam{
{
OfMessage: &oairesponses.EasyInputMessageParam{
Role: "user",
},
},
},
expected: nil,
name: "empty input array",
payload: []byte(`{"model":"gpt-4o","input":[]}`),
},
{
name: "single function_call_output",
input: []oairesponses.ResponseInputItemUnionParam{
{
OfMessage: &oairesponses.EasyInputMessageParam{
Role: "user",
},
},
{
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
CallID: "call_abc",
},
},
},
expected: utils.PtrTo("call_abc"),
name: "no function_call_output items",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
},
{
name: "multiple function_call_outputs returns last",
input: []oairesponses.ResponseInputItemUnionParam{
{
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
CallID: "call_first",
},
},
{
OfMessage: &oairesponses.EasyInputMessageParam{
Role: "user",
},
},
{
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
CallID: "call_second",
},
},
},
expected: utils.PtrTo("call_second"),
name: "single function_call_output",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
wantCall: utils.PtrTo("call_abc"),
},
{
name: "last input is not a tool result",
input: []oairesponses.ResponseInputItemUnionParam{
{
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
CallID: "call_first",
},
},
{
OfMessage: &oairesponses.EasyInputMessageParam{
Role: "user",
},
},
},
expected: nil,
name: "multiple function_call_outputs returns last",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
wantCall: utils.PtrTo("call_second"),
},
{
name: "last input is not a tool result",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
},
{
name: "missing call id",
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

rp, err := NewResponsesRequestPayload(tc.payload)
require.NoError(t, err)
base := &responsesInterceptionBase{
req: &ResponsesNewParamsWrapper{
ResponseNewParams: oairesponses.ResponseNewParams{
Input: oairesponses.ResponseNewParamsInputUnion{
OfInputItemList: tc.input,
},
},
},
reqPayload: rp,
}

require.Equal(t, tc.expected, base.CorrelatingToolCallID())
callID := base.CorrelatingToolCallID()
assert.Equal(t, tc.wantCall, callID)
})
}
}
Expand Down Expand Up @@ -161,13 +120,10 @@ func TestLastUserPrompt(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

req := &ResponsesNewParamsWrapper{}
err := req.UnmarshalJSON(tc.reqPayload)
rp, err := NewResponsesRequestPayload(tc.reqPayload)
require.NoError(t, err)

base := &responsesInterceptionBase{
req: req,
reqPayload: tc.reqPayload,
reqPayload: rp,
}

prompt, promptFound, err := base.lastUserPrompt(t.Context())
Expand Down Expand Up @@ -253,13 +209,11 @@ func TestLastUserPromptNotFound(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

req := &ResponsesNewParamsWrapper{}
err := req.UnmarshalJSON(tc.reqPayload)
rp, err := NewResponsesRequestPayload(tc.reqPayload)
require.NoError(t, err)

base := &responsesInterceptionBase{
req: req,
reqPayload: tc.reqPayload,
reqPayload: rp,
}

prompt, promptFound, err := base.lastUserPrompt(t.Context())
Expand Down
Loading
Loading