@@ -18,16 +18,18 @@ package ollama
 
 import (
 	"encoding/json"
+	"fmt"
+	"github.com/google-gemini/proxy-to-gemini/internal"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/gorilla/mux"
+	"google.golang.org/api/iterator"
 	"io"
 	"net/http"
 	"strings"
 	"time"
-
-	"github.com/google-gemini/proxy-to-gemini/internal"
-	"github.com/google/generative-ai-go/genai"
-	"github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
 	client *genai.Client
 }
@@ -36,6 +38,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
 	handlers := &handlers{client: client}
 	r.HandleFunc("/api/generate", handlers.generateHandler)
 	r.HandleFunc("/api/embed", handlers.embedHandler)
+	r.HandleFunc("/api/chat", handlers.chatHandler)
 }
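+
+// For illustration only, the new route accepts Ollama-style chat requests;
+// the model name below is hypothetical:
+//
+//	POST /api/chat
+//	{"model": "gemini-1.5-flash",
+//	 "messages": [{"role": "user", "content": "Hello"}],
+//	 "stream": false}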
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
@@ -142,6 +145,245 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
+type ChatRequest struct {
+	Model    string          `json:"model,omitempty"`
+	Messages []ChatMessage   `json:"messages,omitempty"`
+	Stream   bool            `json:"stream,omitempty"`
+	Format   json.RawMessage `json:"format,omitempty"`
+	Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
+type ChatResponse struct {
+	Model              string      `json:"model,omitempty"`
+	CreatedAt          time.Time   `json:"created_at,omitempty"`
+	Message            ChatMessage `json:"message,omitempty"`
+	Done               bool        `json:"done,omitempty"`
+	PromptEvalCount    int32       `json:"prompt_eval_count"`
+	EvalCount          int32       `json:"eval_count"`
+	TotalDuration      int64       `json:"total_duration"`
+	LoadDuration       int64       `json:"load_duration,omitempty"`
+	PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+	EvalDuration       int64       `json:"eval_duration"`
+}
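+
+// For illustration, a non-streaming response serializes roughly as follows
+// (all values are made up):
+//
+//	{
+//	  "model": "gemini-1.5-flash",
+//	  "created_at": "2024-01-01T12:00:00Z",
+//	  "message": {"role": "model", "content": "Hello!"},
+//	  "done": true,
+//	  "prompt_eval_count": 5,
+//	  "eval_count": 7,
+//	  "total_duration": 1200000000,
+//	  "eval_duration": 1200000000
+//	}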
+
+// sanitizeJson strips newlines and markdown code fences from a model reply so
+// the remaining text can be returned as bare JSON.
+func sanitizeJson(s string) string {
+	s = strings.ReplaceAll(s, "\n", "")
+	s = strings.TrimPrefix(s, "```json")
+	s = strings.TrimSuffix(s, "```")
+	s = strings.ReplaceAll(s, "'", "\\'")
+	return s
+}
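+
+// For example, sanitizeJson turns a fenced reply such as
+//
+//	```json
+//	{"answer": 42}
+//	```
+//
+// into `{"answer": 42}`: newlines are removed first so the fence markers sit
+// at the ends of the string, then both markers are trimmed.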
+
+// chatHandler handles POST /api/chat requests.
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+		return
+	}
+	defer r.Body.Close()
+
+	var req ChatRequest
+	if err := json.Unmarshal(body, &req); err != nil {
+		internal.ErrorHandler(w, r, http.StatusBadRequest, "failed to unmarshal chat request: %v", err)
+		return
+	}
+
+	// Handle advanced format parameter: JSON mode or JSON schema enforcement.
+	expectJson := false
+	if len(req.Format) > 0 {
+		expectJson = true
+		var formatVal interface{}
+		if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+			internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+			return
+		}
+		var instr string
+		switch v := formatVal.(type) {
+		case string:
+			if v == "json" {
+				instr = "Please respond with valid JSON."
+			} else {
+				instr = fmt.Sprintf("Please respond with format: %s.", v)
+			}
+		default:
+			schemaBytes, err := json.MarshalIndent(v, "", " ")
+			if err != nil {
+				schemaBytes = req.Format
+			}
+			instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+		}
+		// Integrate with the existing system message if present.
+		found := false
+		for i, m := range req.Messages {
+			if m.Role == "system" {
+				req.Messages[i].Content = m.Content + "\n\n" + instr
+				found = true
+				break
+			}
+		}
+		if !found {
+			req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+		}
+	}
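+	// For example, `"format": "json"` becomes the instruction "Please respond
+	// with valid JSON.", while an object such as
+	// {"type": "object", "properties": {"age": {"type": "integer"}}} is
+	// embedded verbatim in a schema-following instruction, since the schema
+	// itself is not passed to the Gemini API below.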
+
+	model := h.client.GenerativeModel(req.Model)
+	model.GenerationConfig = genai.GenerationConfig{
+		Temperature:     req.Options.Temperature,
+		MaxOutputTokens: req.Options.NumPredict,
+		TopK:            req.Options.TopK,
+		TopP:            req.Options.TopP,
+	}
+	if req.Options.Stop != nil {
+		model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+	}
+
+	chat := model.StartChat()
+	var lastPart genai.Part
+	for i, m := range req.Messages {
+		if m.Role == "system" {
+			model.SystemInstruction = &genai.Content{
+				Role:  m.Role,
+				Parts: []genai.Part{genai.Text(m.Content)},
+			}
+			continue
+		}
+		if i == len(req.Messages)-1 {
+			lastPart = genai.Text(m.Content)
+			break
+		}
+		// The Gemini API accepts only "user" and "model" roles in chat
+		// history, so translate Ollama's "assistant" role.
+		role := m.Role
+		if role == "assistant" {
+			role = "model"
+		}
+		chat.History = append(chat.History, &genai.Content{
+			Role:  role,
+			Parts: []genai.Part{genai.Text(m.Content)},
+		})
+	}
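+	// At this point, for messages [system, user, assistant, user], the system
+	// message has become SystemInstruction, the middle turns are chat history,
+	// and lastPart holds the final user message to be sent below.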
+
+	if req.Stream {
+		h.streamingChatHandler(w, r, req.Model, chat, lastPart)
+		return
+	}
+	// Measure time spent generating the chat response.
+	start := time.Now()
+
+	gresp, err := chat.SendMessage(r.Context(), lastPart)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+		return
+	}
+	if len(gresp.Candidates) == 0 {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "no candidates returned from model")
+		return
+	}
+	var builder strings.Builder
+	for _, part := range gresp.Candidates[0].Content.Parts {
+		if txt, ok := part.(genai.Text); ok {
+			builder.WriteString(string(txt))
+		}
+	}
+
+	content := builder.String()
+	if expectJson {
+		content = sanitizeJson(content)
+	}
+	resp := ChatResponse{
+		Model:     req.Model,
+		CreatedAt: time.Now(),
+		Message: ChatMessage{
+			Role:    gresp.Candidates[0].Content.Role,
+			Content: content,
+		},
+		Done: true,
+	}
+
+	if gresp.UsageMetadata != nil {
+		resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+		// Compute the number of tokens in the response.
+		if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+			resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+		} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+			// Fallback: use total tokens minus prompt tokens.
+			resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+		}
+	}
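+	// For example, with TotalTokenCount=30 and PromptTokenCount=12 but no
+	// CandidatesTokenCount, the fallback reports EvalCount = 30 - 12 = 18.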
+	// Populate duration metadata (in nanoseconds).
+	elapsed := time.Since(start).Nanoseconds()
+	resp.TotalDuration = elapsed
+	resp.LoadDuration = 0
+	resp.PromptEvalDuration = 0
+	resp.EvalDuration = elapsed
+	if err := json.NewEncoder(w).Encode(&resp); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+		return
+	}
+}
+
+// streamingChatHandler handles streaming chat responses for /api/chat.
+func (h *handlers) streamingChatHandler(w http.ResponseWriter, r *http.Request, modelName string, chat *genai.ChatSession, lastPart genai.Part) {
+	// Measure total elapsed time for streaming.
+	start := time.Now()
+	iter := chat.SendMessageStream(r.Context(), lastPart)
+	for {
+		gresp, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to stream chat response: %v", err)
+			return
+		}
+		if len(gresp.Candidates) == 0 {
+			continue
+		}
+		var builder strings.Builder
+		for _, part := range gresp.Candidates[0].Content.Parts {
+			if txt, ok := part.(genai.Text); ok {
+				builder.WriteString(string(txt))
+			}
+		}
+		// Build a streaming chunk with duration metadata.
+		elapsed := time.Since(start).Nanoseconds()
+		chunk := ChatResponse{
+			Model:     modelName,
+			CreatedAt: time.Now(),
+			Message: ChatMessage{
+				Role:    gresp.Candidates[0].Content.Role,
+				Content: builder.String(),
+			},
+			Done:               false,
+			TotalDuration:      elapsed,
+			LoadDuration:       0,
+			PromptEvalDuration: 0,
+			EvalDuration:       elapsed,
+		}
+		data, err := json.Marshal(chunk)
+		if err != nil {
+			internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to marshal chat chunk: %v", err)
+			return
+		}
+		fmt.Fprintf(w, "data: %s\n", data)
+		// Flush so the client sees each chunk as it is produced.
+		if f, ok := w.(http.Flusher); ok {
+			f.Flush()
+		}
+	}
+	fmt.Fprint(w, "data: [DONE]\n")
+}
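+
+// For illustration, a client reading the stream sees one prefixed JSON line
+// per chunk followed by a terminator (payloads abbreviated and made up):
+//
+//	data: {"model":"gemini-1.5-flash","message":{"role":"model","content":"Hel"},"done":false}
+//	data: {"model":"gemini-1.5-flash","message":{"role":"model","content":"lo"},"done":false}
+//	data: [DONE]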
+
 type GenerateRequest struct {
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`