Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 3 additions & 109 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -20,7 +19,6 @@ import (
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/metrics"
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/tracing"
"github.com/coder/quartz"
Expand Down Expand Up @@ -48,9 +46,8 @@ type responsesInterceptionBase struct {
recorder recorder.Recorder
mcpProxy mcp.ServerProxier

logger slog.Logger
metrics metrics.Metrics
tracer trace.Tracer
logger slog.Logger
tracer trace.Tracer
}

func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
Expand Down Expand Up @@ -96,27 +93,7 @@ func (i *responsesInterceptionBase) Model() string {
}

func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
items := gjson.GetBytes(i.reqPayload, "input")
if !items.IsArray() {
return nil
}

arr := items.Array()
if len(arr) == 0 {
return nil
}

last := arr[len(arr)-1]
if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) {
return nil
}

callID := last.Get("call_id").String()
if callID == "" {
return nil
}

return &callID
return i.reqPayload.correlatingToolCallID()
}

func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
Expand Down Expand Up @@ -178,89 +155,6 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
return opts
}

// lastUserPrompt returns input text with "user" role from last input item
// or string input value if it is present + bool indicating if input was found or not.
// If no such input was found empty string + false is returned.
func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) {
if i == nil {
return "", false, errors.New("cannot get last user prompt: nil struct")
}
if i.reqPayload == nil {
return "", false, errors.New("cannot get last user prompt: nil request struct")
}

// 'input' can be either a string or an array of input items:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input
inputItems := gjson.GetBytes(i.reqPayload, "input")
if !inputItems.Exists() || inputItems.Type == gjson.Null {
return "", false, nil
}

// String variant: treat the whole input as the user prompt.
if inputItems.Type == gjson.String {
return inputItems.String(), true, nil
}

// Array variant: checking only the last input item
if !inputItems.IsArray() {
return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type)
}

inputItemsArr := inputItems.Array()
if len(inputItemsArr) == 0 {
return "", false, nil
}

lastItem := inputItemsArr[len(inputItemsArr)-1]
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
// Request was likely not initiated by a prompt but is an iteration of agentic loop.
return "", false, nil
}

// Message content can be either a string or an array of typed content items:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
if !content.Exists() || content.Type == gjson.Null {
return "", false, nil
}

// String variant: use it directly as the prompt.
if content.Type == gjson.String {
return content.Str, true, nil
}

if !content.IsArray() {
return "", false, fmt.Errorf("unexpected input content type: %s", content.Type)
}

var sb strings.Builder
promptExists := false
for _, c := range content.Array() {
// Ignore non-text content blocks such as images or files.
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
continue
}

text := c.Get(string(constant.ValueOf[constant.Text]()))
if text.Type != gjson.String {
i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type))
continue
}

if promptExists {
sb.WriteByte('\n')
}
promptExists = true
sb.WriteString(text.Str)
}

if !promptExists {
return "", false, nil
}

return sb.String(), true, nil
}

func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {
if responseID == "" {
i.logger.Warn(ctx, "got empty response ID, skipping prompt recording")
Expand Down
216 changes: 0 additions & 216 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,229 +6,13 @@ import (
"time"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/recorder"
"github.com/coder/aibridge/utils"
"github.com/google/uuid"
oairesponses "github.com/openai/openai-go/v3/responses"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestScanForCorrelatingToolCallID(t *testing.T) {
t.Parallel()

tests := []struct {
name string
payload []byte
wantCall *string
}{
{
name: "no input",
payload: []byte(`{"model":"gpt-4o"}`),
},
{
name: "empty input array",
payload: []byte(`{"model":"gpt-4o","input":[]}`),
},
{
name: "no function_call_output items",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
},
{
name: "single function_call_output",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
wantCall: utils.PtrTo("call_abc"),
},
{
name: "multiple function_call_outputs returns last",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
wantCall: utils.PtrTo("call_second"),
},
{
name: "last input is not a tool result",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
},
{
name: "missing call id",
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

rp, err := NewResponsesRequestPayload(tc.payload)
require.NoError(t, err)
base := &responsesInterceptionBase{
reqPayload: rp,
}

callID := base.CorrelatingToolCallID()
assert.Equal(t, tc.wantCall, callID)
})
}
}

func TestLastUserPrompt(t *testing.T) {
t.Parallel()

tests := []struct {
name string
reqPayload []byte
expect string
}{
{
name: "input_empty_string",
reqPayload: []byte(`{"input": ""}`),
expect: "",
},
{
name: "input_array_content_empty_string",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
expect: "",
},
{
name: "input_array_content_array_empty_string",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`),
},
{
name: "input_array_content_array_multiple_inputs",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`),
expect: "a\nb",
},
{
name: "simple_string_input",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
expect: "tell me a joke",
},
{
name: "array_single_input_string",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
},
{
name: "array_multiple_items_content_objects",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex),
expect: "hello",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

rp, err := NewResponsesRequestPayload(tc.reqPayload)
require.NoError(t, err)
base := &responsesInterceptionBase{
reqPayload: rp,
}

prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.NoError(t, err)
require.Equal(t, tc.expect, prompt)
require.True(t, promptFound)
})
}
}

func TestLastUserPromptNotFound(t *testing.T) {
t.Parallel()

t.Run("nil_struct", func(t *testing.T) {
t.Parallel()

var base *responsesInterceptionBase
prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.Error(t, err)
require.Empty(t, prompt)
require.False(t, promptFound)
require.Contains(t, "cannot get last user prompt: nil struct", err.Error())
})

t.Run("nil_request", func(t *testing.T) {
t.Parallel()

base := responsesInterceptionBase{}
prompt, promptFound, err := base.lastUserPrompt(t.Context())
require.Error(t, err)
require.Empty(t, prompt)
require.False(t, promptFound)
require.Contains(t, "cannot get last user prompt: nil request struct", err.Error())
})

// Cases where the user prompt is not found / wrong format.
tests := []struct {
name string
reqPayload []byte
expectErr string
}{
{
name: "non_existing_input",
reqPayload: []byte(`{"model": "gpt-4o"}`),
},
{
name: "input_empty_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
},
{
name: "input_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": 123}`),
expectErr: "unexpected input type",
},
{
name: "no_user_role",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
},
{
name: "user_with_empty_content_array",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
},
{
name: "input_array_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": 123}]}`),
expectErr: "unexpected input content type",
},
{
name: "user_with_non_input_text_content",
reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
},
{
name: "user_content_not_last",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`),
},
{
name: "input_array_content_array_integer",
reqPayload: []byte(`{"model": "gpt-4o", "input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

rp, err := NewResponsesRequestPayload(tc.reqPayload)
require.NoError(t, err)

base := &responsesInterceptionBase{
reqPayload: rp,
}

prompt, promptFound, err := base.lastUserPrompt(t.Context())
if tc.expectErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
require.Empty(t, prompt)
require.False(t, promptFound)
})
}
}

func TestRecordPrompt(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 2 additions & 2 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type BlockingResponsesInterceptor struct {

func NewBlockingInterceptor(
id uuid.UUID,
reqPayload []byte,
reqPayload ResponsesRequestPayload,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
Expand Down Expand Up @@ -74,7 +74,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
firstResponseID string
)

prompt, promptFound, err := i.lastUserPrompt(ctx)
prompt, promptFound, err := i.reqPayload.lastUserPrompt(ctx, i.logger)
if err != nil {
i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err))
}
Expand Down
Loading
Loading