@@ -18,16 +18,17 @@ package ollama
 
 import (
 	"encoding/json"
+	"fmt"
+	"github.com/google-gemini/proxy-to-gemini/internal"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/gorilla/mux"
 	"io"
 	"net/http"
 	"strings"
 	"time"
-
-	"github.com/google-gemini/proxy-to-gemini/internal"
-	"github.com/google/generative-ai-go/genai"
-	"github.com/gorilla/mux"
 )
 
+// handlers provides HTTP handlers for the Ollama proxy API.
 type handlers struct {
 	client *genai.Client
 }
@@ -36,6 +37,7 @@ func RegisterHandlers(r *mux.Router, client *genai.Client) {
 	handlers := &handlers{client: client}
 	r.HandleFunc("/api/generate", handlers.generateHandler)
 	r.HandleFunc("/api/embed", handlers.embedHandler)
+	r.HandleFunc("/api/chat", handlers.chatHandler)
 }
 
 func (h *handlers) generateHandler(w http.ResponseWriter, r *http.Request) {
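The new route slots in alongside the existing generate and embed handlers. For orientation, here is a minimal, hypothetical sketch of how a caller might wire these handlers into a server; the import path for this package and the API-key variable name are assumptions, not taken from the repository.

```go
package main

import (
	"context"
	"log"
	"net/http"
	"os"

	"github.com/google/generative-ai-go/genai"
	"github.com/gorilla/mux"
	"google.golang.org/api/option"

	// Assumed import path for the package shown in this diff.
	"github.com/google-gemini/proxy-to-gemini/internal/ollama"
)

func main() {
	ctx := context.Background()
	// GEMINI_API_KEY is an assumed environment variable name.
	client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY")))
	if err != nil {
		log.Fatalf("failed to create genai client: %v", err)
	}
	defer client.Close()

	r := mux.NewRouter()
	ollama.RegisterHandlers(r, client)
	// 11434 is Ollama's default port, so existing Ollama clients can point here.
	log.Fatal(http.ListenAndServe(":11434", r))
}
```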
@@ -142,6 +144,193 @@ func (h *handlers) embedHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+// ChatRequest represents a chat completion request for the Ollama API.
+type ChatRequest struct {
+	Model    string          `json:"model,omitempty"`
+	Messages []ChatMessage   `json:"messages,omitempty"`
+	Format   json.RawMessage `json:"format,omitempty"`
+	Options  Options         `json:"options,omitempty"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role    string   `json:"role,omitempty"`
+	Content string   `json:"content,omitempty"`
+	Images  []string `json:"images,omitempty"`
+}
+
+// ChatResponse represents a chat completion response for the Ollama API.
+type ChatResponse struct {
+	Model              string      `json:"model,omitempty"`
+	CreatedAt          time.Time   `json:"created_at,omitempty"`
+	Message            ChatMessage `json:"message,omitempty"`
+	Done               bool        `json:"done,omitempty"`
+	PromptEvalCount    int32       `json:"prompt_eval_count"`
+	EvalCount          int32       `json:"eval_count"`
+	TotalDuration      int64       `json:"total_duration"`
+	LoadDuration       int64       `json:"load_duration,omitempty"`
+	PromptEvalDuration int64       `json:"prompt_eval_duration,omitempty"`
+	EvalDuration       int64       `json:"eval_duration"`
+}
+
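To make the wire format concrete, here is an illustrative same-package snippet, not from the commit, that marshals a request these types accept; the model name is a placeholder. Note that Options is declared elsewhere in this package, so its zero value may still serialize despite the omitempty tag (omitempty does not apply to structs).

```go
package ollama

import (
	"encoding/json"
	"fmt"
)

func exampleChatRequest() {
	req := ChatRequest{
		Model: "gemini-1.5-flash", // placeholder model name
		Messages: []ChatMessage{
			{Role: "system", Content: "You are a terse assistant."},
			{Role: "user", Content: "Say hello."},
		},
		Format: json.RawMessage(`"json"`), // or an inline JSON schema object
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
}
```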
+// sanitizeJson strips newlines and Markdown code fences from a model
+// reply so that the remaining text can be parsed as JSON.
+func sanitizeJson(s string) string {
+	s = strings.ReplaceAll(s, "\n", "")
+	s = strings.TrimPrefix(s, "```json")
+	s = strings.TrimSuffix(s, "```")
+	s = strings.ReplaceAll(s, "'", "\\'")
+	return s
+}
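A quick illustration, not in the commit, of what sanitizeJson does to a typical fenced model reply:

```go
package ollama

import "fmt"

func exampleSanitizeJson() {
	raw := "```json\n{\"answer\": 42}\n```"
	// Newlines are removed first, then the ```json prefix and ``` suffix.
	fmt.Println(sanitizeJson(raw)) // prints: {"answer": 42}
}
```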
+
+// chatHandler handles POST /api/chat requests.
+func (h *handlers) chatHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		internal.ErrorHandler(w, r, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to read request body: %v", err)
+		return
+	}
+	defer r.Body.Close()
+
+	var req ChatRequest
+	// A malformed request body is a client error, not a server error.
+	if err := json.Unmarshal(body, &req); err != nil {
+		internal.ErrorHandler(w, r, http.StatusBadRequest, "failed to unmarshal chat request: %v", err)
+		return
+	}
+
+	// Handle advanced format parameter: JSON mode or JSON schema enforcement.
+	expectJson := false
+	if len(req.Format) > 0 {
+		expectJson = true
+		var formatVal interface{}
+		if err := json.Unmarshal(req.Format, &formatVal); err != nil {
+			internal.ErrorHandler(w, r, http.StatusBadRequest, "invalid format parameter: %v", err)
+			return
+		}
+		var instr string
+		switch v := formatVal.(type) {
+		case string:
+			if v == "json" {
+				instr = "Please respond with valid JSON."
+			} else {
+				instr = fmt.Sprintf("Please respond with format: %s.", v)
+			}
+		default:
+			schemaBytes, err := json.MarshalIndent(v, "", "  ")
+			if err != nil {
+				schemaBytes = req.Format
+			}
+			instr = fmt.Sprintf("Please format your response according to the following JSON schema:\n%s", string(schemaBytes))
+		}
+		// Integrate with the existing system message if present.
+		found := false
+		for i, m := range req.Messages {
+			if m.Role == "system" {
+				req.Messages[i].Content = m.Content + "\n\n" + instr
+				found = true
+				break
+			}
+		}
+		if !found {
+			req.Messages = append([]ChatMessage{{Role: "system", Content: instr}}, req.Messages...)
+		}
+	}
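For reference, the two format shapes this branch distinguishes could look like the following illustrative values (not from the commit): the first yields the fixed "Please respond with valid JSON." instruction, while the second is pretty-printed by json.MarshalIndent into the system message.

```go
package ollama

import "encoding/json"

var (
	// Ollama-style JSON mode: format is the bare string "json".
	exampleJSONMode = json.RawMessage(`"json"`)

	// Schema enforcement: format is an inline JSON-schema object.
	exampleSchema = json.RawMessage(`{
		"type": "object",
		"properties": {"age": {"type": "integer"}},
		"required": ["age"]
	}`)
)
```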
+
+	model := h.client.GenerativeModel(req.Model)
+	model.GenerationConfig = genai.GenerationConfig{
+		Temperature:     req.Options.Temperature,
+		MaxOutputTokens: req.Options.NumPredict,
+		TopK:            req.Options.TopK,
+		TopP:            req.Options.TopP,
+	}
+	if req.Options.Stop != nil {
+		model.GenerationConfig.StopSequences = []string{*req.Options.Stop}
+	}
+
+	chat := model.StartChat()
+	var lastPart genai.Part
+	for i, m := range req.Messages {
+		// System messages become the model-level system instruction
+		// rather than part of the chat history.
+		if m.Role == "system" {
+			model.SystemInstruction = &genai.Content{
+				Role:  m.Role,
+				Parts: []genai.Part{genai.Text(m.Content)},
+			}
+			continue
+		}
+		// The final message is sent as the new prompt; earlier turns
+		// are replayed as history.
+		if i == len(req.Messages)-1 {
+			lastPart = genai.Text(m.Content)
+			break
+		}
+		// The Gemini API expects model turns in history to use the
+		// "model" role, while Ollama clients send "assistant".
+		role := m.Role
+		if role == "assistant" {
+			role = "model"
+		}
+		chat.History = append(chat.History, &genai.Content{
+			Role:  role,
+			Parts: []genai.Part{genai.Text(m.Content)},
+		})
+	}
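To make the split concrete: given messages [system, user, assistant, user], the system entry becomes model.SystemInstruction, the middle user and assistant turns are appended to chat.History, and only the final user message is sent as lastPart. (If the last message were itself a system message, lastPart would stay nil; callers are expected to end the conversation with a user turn.)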
+
+	// Measure time spent generating the chat response.
+	start := time.Now()
+
+	gresp, err := chat.SendMessage(r.Context(), lastPart)
+	if err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to send chat message: %v", err)
+		return
+	}
+	var builder strings.Builder
+	if len(gresp.Candidates) > 0 {
+		for _, part := range gresp.Candidates[0].Content.Parts {
+			if txt, ok := part.(genai.Text); ok {
+				builder.WriteString(string(txt))
+			}
+		}
+	}
+
+	content := builder.String()
+	if expectJson {
+		// Strip code fences so clients can parse the body as JSON.
+		content = sanitizeJson(content)
+	}
+	// Guard the index: an empty candidate list would otherwise panic.
+	role := "model"
+	if len(gresp.Candidates) > 0 {
+		role = gresp.Candidates[0].Content.Role
+	}
+	resp := ChatResponse{
+		Model:     req.Model,
+		CreatedAt: time.Now(),
+		Message:   ChatMessage{Role: role, Content: content},
+		Done:      true,
+	}
+
+	if gresp.UsageMetadata != nil {
+		resp.PromptEvalCount = gresp.UsageMetadata.PromptTokenCount
+		// Compute the number of tokens in the response.
+		if gresp.UsageMetadata.CandidatesTokenCount > 0 {
+			resp.EvalCount = gresp.UsageMetadata.CandidatesTokenCount
+		} else if gresp.UsageMetadata.TotalTokenCount >= gresp.UsageMetadata.PromptTokenCount {
+			// Fallback: use total tokens minus prompt tokens.
+			resp.EvalCount = gresp.UsageMetadata.TotalTokenCount - gresp.UsageMetadata.PromptTokenCount
+		}
+	}
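For example, if Gemini reports PromptTokenCount 80 and TotalTokenCount 120 but omits CandidatesTokenCount, the fallback reports EvalCount as 120 - 80 = 40.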
+	// Populate duration metadata (in nanoseconds). The proxy cannot
+	// separate load and prompt-eval phases, so the whole round trip is
+	// reported as both the total and the eval duration.
+	elapsed := time.Since(start).Nanoseconds()
+	resp.TotalDuration = elapsed
+	resp.LoadDuration = 0
+	resp.PromptEvalDuration = 0
+	resp.EvalDuration = elapsed
+	if err := json.NewEncoder(w).Encode(&resp); err != nil {
+		internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to encode chat response: %v", err)
+		return
+	}
+}
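End to end, a client could exercise the new endpoint with a hypothetical sketch like the one below (the port and model name are assumptions). Note that the handler replies with a single JSON object carrying done: true, rather than Ollama's default stream of chunked responses.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	body := []byte(`{
		"model": "gemini-1.5-flash",
		"messages": [{"role": "user", "content": "Why is the sky blue?"}],
		"format": "json"
	}`)
	// Assumes the proxy is listening locally on Ollama's default port.
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatalf("request failed: %v", err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // one ChatResponse object, e.g. {"model":"gemini-1.5-flash",...,"done":true}
}
```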
+
 type GenerateRequest struct {
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`