Commit ceda355

Implement proxy for Ollama chat API
1 parent c82c356 commit ceda355

2 files changed: +238 −5 lines

README.md

Lines changed: 45 additions & 1 deletion
@@ -145,10 +145,54 @@ $ curl http://127.0.0.1:5555/api/embed \
 {"model":"text-embedding-004","embeddings":[[0.04824496,0.0117766075,-0.011552069,-0.018164534,-0.0026110192,0.05092675,0.08172899,0.007869772,0.054475933,0.026131334,-0.06593486,-0.002256868,0.038781915,...]]}
 ```
 
+Create chat completions:
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [
+      { "role": "user", "content": "Hello, world!" }
+    ]
+  }'
+```
+
+Chat response example:
+
+```json
+{
+  "model": "gemini-1.5-pro",
+  "created_at": "2024-07-28T15:00:00Z",
+  "message": { "role": "assistant", "content": "Hello back to you!" },
+  "done": true,
+  "total_duration": 123456789,
+  "prompt_eval_count": 5,
+  "eval_count": 10
+}
+```
+
+Advanced usage:
+
+- Use the `format` parameter to request JSON or enforce a JSON schema.
+- Use the `options` parameter to set model generation parameters: `temperature`, `num_predict`, `top_k`, and `top_p`.
+
+```sh
+$ curl http://127.0.0.1:5555/api/chat \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemini-1.5-pro",
+    "messages": [...],
+    "format": "json",
+    "options": { "temperature": 0.5, "num_predict": 100 }
+  }'
+```
+
 ### Known Ollama Limitations
 * Streaming is not yet supported.
 * Images are not supported.
-* Response format is not supported.
+* Response format is supported only for chat API.
+* Tools are not supported for chat API.
 * Model parameters not supported by Gemini are ignored.
 
 ## Notes
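
The `format` bullet in the README diff above also covers schema enforcement: when `format` is a JSON object rather than the string `"json"`, the handler added in this commit forwards it to the model as a JSON schema to follow. As a rough sketch (the schema and prompt below are illustrative and not taken from the repository), such a request could look like:

```sh
$ curl http://127.0.0.1:5555/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemini-1.5-pro",
    "messages": [
      { "role": "user", "content": "List two primary colors." }
    ],
    "format": {
      "type": "object",
      "properties": {
        "colors": { "type": "array", "items": { "type": "string" } }
      },
      "required": ["colors"]
    }
  }'
```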

ollama/ollama.go

Lines changed: 193 additions & 4 deletions
@@ -18,16 +18,17 @@ package ollama
 
 import (
 	"encoding/json"
+	"fmt"
+	"github.com/google-gemini/proxy-to-gemini/internal"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/gorilla/mux"
 	"io"
 	"net/http"
 	"strings"
 	"time"
-
-	"github.com/google-gemini/proxy-to-gemini/internal"
-	"github.com/google/generative-ai-go/genai"
-	"github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
 	client *genai.Client
 }
@@ -36,6 +37,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
 	handlers := &handlers{client: client}
 	r.HandleFunc("/api/generate", handlers.generateHandler)
 	r.HandleFunc("/api/embed", handlers.embedHandler)
+	r.HandleFunc("/api/chat", handlers.chatHandler)
 }
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +144,193 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
+type ChatRequest struct {
+	Model    string          `json:"model,omitempty"`
+	Messages []ChatMessage   `json:"messages,omitempty"`
+	Format   json.RawMessage `json:"format,omitempty"`
+	Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
+type ChatResponse struct {
+	Model              string      `json:"model,omitempty"`
+	CreatedAt          time.Time   `json:"created_at,omitempty"`
+	Message            ChatMessage `json:"message,omitempty"`
+	Done               bool        `json:"done,omitempty"`
+	PromptEvalCount    int32       `json:"prompt_eval_count"`
+	EvalCount          int32       `json:"eval_count"`
+	TotalDuration      int64       `json:"total_duration"`
+	LoadDuration       int64       `json:"load_duration,omitempty"`
+	PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+	EvalDuration       int64       `json:"eval_duration"`
+}
+
+func sanitizeJson(s string) string {
+	s = strings.ReplaceAll(s, "\n", "")
+	s = strings.TrimPrefix(s, "```json")
+	s = strings.TrimSuffix(s, "```")
+	s = strings.ReplaceAll(s, "'", "\\'")
+	return s
+}
+
+// chatHandler handles POST /api/chat requests.
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+		return
+	}
+	defer r.Body.Close()
+
+	var req ChatRequest
+	if err := json.Unmarshal(body, &req); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to unmarshal chat request: %v", err)
+		return
+	}
+
+	// Handle advanced format parameter: JSON mode or JSON schema enforcement
+	expectJson := false
+	if len(req.Format) > 0 {
+		expectJson = true
+		var formatVal interface{}
+		if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+			internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+			return
+		}
+		var instr string
+		switch v := formatVal.(type) {
+		case string:
+			if v == "json" {
+				instr = "Please respond with valid JSON."
+			} else {
+				instr = fmt.Sprintf("Please respond with format: %s.", v)
+			}
+		default:
+			schemaBytes, err := json.MarshalIndent(v, "", " ")
+			if err != nil {
+				schemaBytes = req.Format
+			}
+			instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+		}
+		// Integrate with existing system message if present
+		found := false
+		for i, m := range req.Messages {
+			if m.Role == "system" {
+				req.Messages[i].Content = m.Content + "\n\n" + instr
+				found = true
+				break
+			}
+		}
+		if !found {
+			req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+		}
+	}
+
+	model := h.client.GenerativeModel(req.Model)
+	model.GenerationConfig = genai.GenerationConfig{
+		Temperature:     req.Options.Temperature,
+		MaxOutputTokens: req.Options.NumPredict,
+		TopK:            req.Options.TopK,
+		TopP:            req.Options.TopP,
+	}
+	if req.Options.Stop != nil {
+		model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+	}
+
+	chat := model.StartChat()
+	var lastPart genai.Part
+	for i, m := range req.Messages {
+		if m.Role == "system" {
+			model.SystemInstruction = &genai.Content{
+				Role:  m.Role,
+				Parts: []genai.Part{genai.Text(m.Content)},
+			}
+			continue
+		}
+		if i == len(req.Messages)-1 {
+			lastPart = genai.Text(m.Content)
+			break
+		}
+		chat.History = append(chat.History, &genai.Content{
+			Role:  m.Role,
+			Parts: []genai.Part{genai.Text(m.Content)},
+		})
+	}
+
+	// Measure time spent generating the chat response
+	start := time.Now()
+
+	gresp, err := chat.SendMessage(r.Context(), lastPart)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+		return
+	}
+	var builder strings.Builder
+	if len(gresp.Candidates) > 0 {
+		for _, part := range gresp.Candidates[0].Content.Parts {
+			if txt, ok := part.(genai.Text); ok {
+				builder.WriteString(string(txt))
+			}
+		}
+	}
+
+	var resp ChatResponse
+	if expectJson {
+		resp = ChatResponse{
+			Model:     req.Model,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: sanitizeJson(builder.String()),
+			},
+			Done: true,
+		}
+	} else {
+		resp = ChatResponse{
+			Model:     req.Model,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: builder.String(),
+			},
+			Done: true,
+		}
+	}
+
+	if gresp.UsageMetadata != nil {
+		resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+		// Compute number of tokens in the response.
+		if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+			resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+		} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+			// Fallback: use total tokens minus prompt tokens
+			resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+		}
+	}
+	// Populate duration metadata (in nanoseconds)
+	elapsed := time.Since(start).Nanoseconds()
+	resp.TotalDuration = elapsed
+	resp.LoadDuration = 0
+	resp.PromptEvalDuration = 0
+	resp.EvalDuration = elapsed
+	if err := json.NewEncoder(w).Encode(&resp); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+		return
+	}
+}
+
 type GenerateRequest struct {
 	Model string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`
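
For orientation, here is a minimal Go client exercising the new `/api/chat` route. This is a sketch, not part of the commit: it assumes the proxy is listening on the README's default address `127.0.0.1:5555`, and it mirrors only the fields it needs from the `ChatRequest`, `ChatMessage`, and `ChatResponse` types added above.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// chatMessage mirrors the subset of the proxy's ChatMessage type used here.
type chatMessage struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

// chatRequest mirrors the subset of ChatRequest needed for a basic call.
type chatRequest struct {
	Model    string        `json:"model,omitempty"`
	Messages []chatMessage `json:"messages,omitempty"`
}

// chatResponse decodes only the fields this example prints.
type chatResponse struct {
	Message   chatMessage `json:"message"`
	Done      bool        `json:"done"`
	EvalCount int32       `json:"eval_count"`
}

func main() {
	// Build a one-turn chat request for the proxy (assumed to run at the README's default port).
	payload, err := json.Marshal(chatRequest{
		Model: "gemini-1.5-pro",
		Messages: []chatMessage{
			{Role: "user", Content: "Hello, world!"},
		},
	})
	if err != nil {
		log.Fatalf("marshal request: %v", err)
	}

	resp, err := http.Post("http://127.0.0.1:5555/api/chat", "application/json", bytes.NewReader(payload))
	if err != nil {
		log.Fatalf("post chat request: %v", err)
	}
	defer resp.Body.Close()

	// Decode the non-streaming response and print the assistant message.
	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatalf("decode chat response: %v", err)
	}
	fmt.Printf("assistant (%d tokens): %s\n", out.EvalCount, out.Message.Content)
}
```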
