Skip to content

Commit 6a61085

Browse files
juliendenizekhluu
authored andcommitted
[BUGFIX] Fix regex pattern for Mistral Tool Call (#29918)
Signed-off-by: juliendenize <[email protected]> (cherry picked from commit 1b1e35a)
1 parent 9057fc2 commit 6a61085

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

tests/models/language/generation/test_mistral.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,38 @@ def get_vocab():
315315
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
316316
# No additional content outside the tool call should be returned.
317317
assert parsed.content is None
318+
319+
# multiple calls
320+
multiple_args_dict = [
321+
{
322+
"city": "Dallas",
323+
"state": "TX",
324+
"unit": "fahrenheit",
325+
"sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}},
326+
},
327+
{},
328+
{"a": 0},
329+
{"a": 1, "b": "c"},
330+
]
331+
names = ["get_current_weather", "get_current_weather_2", "random", "random_2"]
332+
333+
model_output = "".join(
334+
[
335+
f"{parser.bot_token}{name}{json.dumps(args)}"
336+
for name, args in zip(names, multiple_args_dict)
337+
]
338+
)
339+
340+
parsed = parser.extract_tool_calls(model_output, None)
341+
342+
# Assertions: the tool call is detected and the full nested JSON is parsed
343+
# without truncation.
344+
assert parsed.tools_called
345+
assert len(parsed.tool_calls) == len(multiple_args_dict)
346+
347+
for i, tool_call in enumerate(parsed.tool_calls):
348+
assert MistralToolCall.is_valid_id(tool_call.id)
349+
assert tool_call.function.name == names[i]
350+
assert json.loads(tool_call.function.arguments) == multiple_args_dict[i]
351+
# No additional content outside the tool call should be returned.
352+
assert parsed.content is None

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, tokenizer: TokenizerLike):
8080
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
8181
if _is_fn_name_regex_support(self.model_tokenizer):
8282
self.fn_name_regex = re.compile(
83-
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
83+
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
8484
)
8585
else:
8686
self.fn_name_regex = None

0 commit comments

Comments
 (0)