66from ._internals import LlamaContext , LlamaBatch
77
88import ctypes
9- from typing import Union , List
9+ from typing import Union , List , Optional , Any , Tuple
10+
11+ import llama_cpp .llama_types as llama_types
12+ import llama_cpp .llama as llama
13+ import jinja2
14+ from jinja2 .sandbox import ImmutableSandboxedEnvironment
15+ import copy
16+ import numpy as np
17+ import numpy .typing as npt
18+ import os
19+
20+ from .llama_chat_format import ChatFormatter , ChatFormatterResponse
1021
1122class TextChunk :
1223 def __init__ (self , tokens : List [int ]):
@@ -77,6 +88,242 @@ def close(self):
    def __del__(self):
        """Finalizer: delegate to close() so native resources are released
        even if the caller forgets to close explicitly (close() is expected
        to be idempotent, since it may run twice)."""
        self.close()
7990
# Marker string that mtmd_tokenize requires in place of each media item in
# the prompt, decoded from the native mtmd library's default (``<__media__>``).
DEFAULT_MEDIA_MARKER = mtmd.mtmd_default_marker().decode('utf-8')
92+
class Jinja2MultimodalChatFormatter(ChatFormatter):
    """A chat formatter that uses jinja2 templates to format multimodal prompts.

    Renders the chat template in a sandboxed jinja2 environment, extracts
    image/audio/video content from the messages, and rewrites model-specific
    media placeholders to the marker required by ``mtmd_tokenize``
    (``DEFAULT_MEDIA_MARKER``).
    """

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
        stop_token_ids: Optional[List[int]] = None,
        placeholders: Optional[List[str]] = None,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt.

        Args:
            template: jinja2 chat template source.
            eos_token: End-of-sequence token string exposed to the template.
            bos_token: Beginning-of-sequence token string exposed to the template.
            add_generation_prompt: Whether the template should append the
                assistant generation prompt.
            stop_token_ids: Optional token ids that stop generation via a
                stopping criterion.
            placeholders: Model-specific media placeholder strings that are
                replaced with ``DEFAULT_MEDIA_MARKER`` after rendering.
                Defaults to known Qwen3-VL / LLaVA / DeepSeek placeholders.
        """
        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt
        self.stop_token_ids = (
            set(stop_token_ids) if stop_token_ids is not None else None
        )

        # Sandboxed environment: chat templates ship with model files and must
        # not be able to reach arbitrary Python attributes.
        self.chat_template = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(template)

        # Placeholder mapping, mtmd_tokenize requires <__media__>
        self.placeholders = placeholders if placeholders else [
            "<|vision_start|><|image_pad|><|vision_end|>",  # Qwen3-VL
            "<image>",  # LLaVA / Yi
            "<image_placeholder>",  # DeepSeek
        ]

    def __call__(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render the prompt and resolve media referenced by the messages.

        Returns:
            A ``ChatFormatterResponse`` carrying the rendered prompt, stop
            strings, an optional stopping criterion, and the fetched media
            payloads with their types.

        Raises:
            ValueError: raised by the template via ``raise_exception``, for
                video inputs (not supported yet), or for malformed media.
        """
        def raise_exception(message: str):
            raise ValueError(message)

        def strftime_now(format_string: str = "%Y-%m-%d %H:%M:%S") -> str:
            """Return the current time formatted as a string."""
            # Local import: `datetime` is not imported at module level, and is
            # only needed when the template actually calls strftime_now.
            import datetime
            return datetime.datetime.now().strftime(format_string)

        # Deep copy because split_media rewrites media entries in place.
        messages = copy.deepcopy(messages)
        media_urls, media_types = self.split_media(messages)
        medias = []

        for url, m_type in zip(media_urls, media_types):
            if m_type == "video":
                raise ValueError("Video input is not supported yet.")

            data = self._fetch_media(url, m_type)

            # Optional recompression (see _compress_image) is intentionally
            # disabled; raw bytes are forwarded unchanged.
            # if m_type == "image" and isinstance(data, bytes):
            #     data = self._compress_image(data)

            medias.append(data)

        prompt = self.chat_template.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token=self.bos_token,
            raise_exception=raise_exception,
            strftime_now=strftime_now,
            add_generation_prompt=self.add_generation_prompt,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )

        # Normalize every model-specific placeholder to the marker that
        # mtmd_tokenize expects.
        for p in self.placeholders:
            prompt = prompt.replace(p, DEFAULT_MEDIA_MARKER)

        stopping_criteria = None
        if self.stop_token_ids is not None:

            def stop_on_last_token(
                tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
            ) -> bool:
                return tokens[-1] in self.stop_token_ids

            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

        return ChatFormatterResponse(
            prompt=prompt,
            stop=[self.eos_token],
            stopping_criteria=stopping_criteria,
            added_special=True,
            medias=medias,
            media_types=media_types,
        )

    @staticmethod
    def split_media(
        messages: List[llama_types.ChatCompletionRequestMessage],
    ) -> Tuple[List[Union[str, bytes, bytearray]], List[str]]:
        """Extract media references from user messages, in order.

        Mutates ``messages`` in place: each media entry's payload is replaced
        with ``DEFAULT_MEDIA_MARKER`` so the template renders the marker where
        the media belongs.

        Returns:
            ``(media_urls, media_types)`` — parallel lists of the extracted
            payloads and their kind (``"image"``/``"audio"``/``"video"``).

        Raises:
            ValueError: on an unsupported content ``type``.
        """
        media_urls: List[Union[str, bytes, bytearray]] = []
        media_types: List[str] = []

        for message in messages:
            if message.get("role") != "user" or not isinstance(message.get("content"), list):
                continue

            for content in message["content"]:
                if not (isinstance(content, dict) and "type" in content):
                    continue

                c_type = content["type"]
                if c_type == "text":
                    continue

                # Validate the type before touching media_urls so both lists
                # stay consistent if we raise.
                if c_type == "image" or c_type == "image_url":
                    media_types.append("image")
                elif c_type == "audio" or c_type == "audio_url":
                    media_types.append("audio")
                elif c_type == "video" or c_type == "video_url":
                    media_types.append("video")
                else:
                    raise ValueError(f"Unsupported content type {c_type}")

                value = content[c_type]

                # OpenAI-style entries nest the payload as {"url": ...};
                # otherwise the payload is the value itself.
                if isinstance(value, dict) and "url" in value:
                    media_urls.append(value["url"])
                    value["url"] = DEFAULT_MEDIA_MARKER
                else:
                    media_urls.append(value)
                    content[c_type] = DEFAULT_MEDIA_MARKER

        return media_urls, media_types

    @staticmethod
    def _fetch_media(media_input: Union[str, bytes], media_type: str) -> Union[str, bytes, bytearray]:
        """
        Fetch media (audio, image, video...) from local disk, memory, or internet.

        Args:
            media_input: Raw bytes/bytearray, a ``file://`` URI, a ``data:``
                URI, a remote URL, or a plain filesystem path.
            media_type: Kind of media ("image"/"audio"/"video"); currently
                unused here but kept for symmetry with split_media.

        Returns:
            An absolute path (str) for local files, otherwise the raw bytes.

        Raises:
            ValueError: unsupported input type, malformed data URI,
                unrecognized string format, or empty payload.
            FileNotFoundError: a ``file://`` URI points to a missing file.
            ConnectionError: a remote URL could not be fetched.
        """
        # --- from_buffer fast path ---
        if isinstance(media_input, (bytes, bytearray)):
            return media_input

        if not isinstance(media_input, str):
            raise ValueError(f"Unsupported media input type: {type(media_input)}")

        # --- from_file fast path ---
        if media_input.startswith("file://"):
            # `urllib.parse` was previously used without being imported,
            # which made every file:// input raise NameError.
            import urllib.parse

            parsed_path = urllib.parse.urlparse(media_input).path
            # unquote decodes URL-encoded characters (e.g. %20 -> space)
            abs_path = os.path.abspath(urllib.parse.unquote(parsed_path))
            if os.path.exists(abs_path):
                return abs_path
            else:
                raise FileNotFoundError(f"Local file not found: {abs_path}")

        # --- base64 or remote url ---
        raw_bytes = b""
        if media_input.startswith("data:"):
            import base64
            # Split at the first comma: everything after it is the payload,
            # everything before it is the header (mime type, ";base64").
            comma_pos = media_input.find(",")
            if comma_pos == -1:
                raise ValueError("Invalid data URI: missing comma separator")

            raw_bytes = base64.b64decode(media_input[comma_pos + 1:])
        elif "://" in media_input:
            import urllib.request
            from urllib.error import URLError, HTTPError

            # Some hosts reject the default Python user agent.
            headers = {"User-Agent": "Mozilla/5.0"}
            req = urllib.request.Request(media_input, headers=headers)

            try:
                with urllib.request.urlopen(req, timeout=15) as f:
                    raw_bytes = f.read()
            except (URLError, HTTPError) as e:
                raise ConnectionError(f"Failed to fetch media from {media_input}: {e}")

        else:
            # try direct path
            if os.path.exists(media_input):
                return os.path.abspath(media_input)
            raise ValueError("Unrecognized media string format")

        if not raw_bytes:
            raise ValueError("Empty data received")

        return raw_bytes

    @staticmethod
    def _compress_image(image_bytes: bytes) -> bytes:
        """Re-encode an image as RGB JPEG, flattening any transparency.

        Transparent pixels are composited onto a white or black background
        (chosen from the average brightness of the visible pixels) so vision
        models see consistent content.

        Raises:
            ImportError: if Pillow is not installed.
        """
        try:
            from PIL import Image, ImageStat
        except ImportError:
            raise ImportError("Pillow is required for image processing. Install with: pip install pillow")

        import io
        image = Image.open(io.BytesIO(image_bytes))

        # Handle transparency (RGBA, LA, P with transparency, etc.)
        if image.mode in ("RGBA", "LA", "PA") or (image.mode == "P" and "transparency" in image.info):
            # Use alpha channel as mask
            if image.mode == "P":
                image = image.convert("RGBA")

            alpha = image.split()[-1]  # Last channel is alpha
            # Compute average brightness of visible (non-transparent) pixels
            stat = ImageStat.Stat(image.convert("L"), mask=alpha)

            # Choose background: white for dark content, black for bright content
            bg_color = (255, 255, 255)  # white
            if stat.count[0] > 0 and stat.mean[0] > 127:
                bg_color = (0, 0, 0)  # black

            background = Image.new("RGB", image.size, bg_color)
            background.paste(image, mask=alpha)
            image = background

        # Ensure RGB mode for formats like CMYK, palette, etc.
        elif image.mode != "RGB":
            image = image.convert("RGB")

        # Save as high-quality JPEG, suitable for most vision models.
        output = io.BytesIO()
        image.save(output, format="JPEG", quality=95, optimize=True, progressive=True)
        return output.getvalue()
80327
81328# Simple FNV-1a hash implementation to match fnv_hash in C++
82329def fnv_hash (data : bytes ) -> str :
@@ -89,12 +336,12 @@ def fnv_hash(data: bytes) -> str:
89336def mtmd_tokenize (
90337 mctx : mtmd .mtmd_context_p ,
91338 prompt : str ,
92- files_data : list [bytes | str ]) -> MultimodalTokenList :
339+ medias_data : list [Union [ str , bytes , bytearray ] ]) -> MultimodalTokenList :
93340
94341 bitmaps = []
95342 do_hash = False
96343
97- for data in files_data :
344+ for data in medias_data :
98345
99346 bmp = None
100347 if isinstance (data , str ):
@@ -200,3 +447,5 @@ def mtmd_prefill(
200447 raise RuntimeError (f"MTMD eval error: { result } " )
201448
202449 n_past = new_n_past .value
450+
451+ return n_past
0 commit comments