| | |
| |
|
| | import ast |
| | import json |
| | import re |
| | from collections.abc import Sequence |
| | from typing import Union |
| |
|
| | import partial_json_parser |
| | from partial_json_parser.core.options import Allow |
| |
|
| | from vllm.entrypoints.openai.protocol import ( |
| | ChatCompletionRequest, |
| | DeltaFunctionCall, DeltaMessage, |
| | DeltaToolCall, |
| | ExtractedToolCallInformation, |
| | FunctionCall, |
| | ToolCall, |
| | ) |
| | from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| | ToolParser, |
| | ToolParserManager, |
| | ) |
| | from vllm.logger import init_logger |
| | from vllm.transformers_utils.tokenizer import AnyTokenizer |
| | from vllm.utils import random_uuid |
| |
|
| | logger = init_logger(__name__) |
| |
|
| |
|
@ToolParserManager.register_module("nemotron_json")
class NemotronJSONToolParser(ToolParser):
    """Tool parser for Nemotron-style outputs.

    The model emits tool calls as a JSON array wrapped in literal
    ``<TOOLCALL>...</TOOLCALL>`` tags, e.g.::

        <TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>

    Only non-streaming extraction is supported; the streaming hook raises
    ``NotImplementedError``.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping fields required by the ToolParser interface.
        # They are initialized for compatibility but unused, since streaming
        # extraction is not implemented for this parser.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        # Literal sentinel tags surrounding the JSON tool-call payload.
        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # DOTALL so the payload may span multiple lines; non-greedy so we
        # capture only the first <TOOLCALL>...</TOOLCALL> section's interior.
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) model response.

        Args:
            model_output: The full decoded model response.
            request: The originating chat-completion request (unused here,
                required by the ``ToolParser`` interface).

        Returns:
            ``ExtractedToolCallInformation`` with ``tools_called=True`` and
            the parsed calls when a ``<TOOLCALL>`` section is present and
            parseable; otherwise the raw output is returned as plain content.
        """
        # Fast path: no tool-call tag at all -> plain content response.
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            # Take the first tagged section's interior.
            str_tool_calls = self.tool_call_regex.findall(
                model_output)[0].strip()

            # The model sometimes omits the enclosing JSON-array brackets;
            # restore them so json.loads sees a well-formed array.
            # (Bug fix: the closing bracket must be APPENDED — the original
            # code prepended it, which guaranteed a parse failure.)
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                str_tool_calls = str_tool_calls + "]"

            json_tool_calls = json.loads(str_tool_calls)

            tool_calls: list[ToolCall] = []
            for tool_call in json_tool_calls:
                try:
                    arguments = tool_call["arguments"]
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=tool_call["name"],
                                # Arguments may arrive as a dict (serialize)
                                # or as a pre-serialized JSON string (pass
                                # through unchanged).
                                arguments=json.dumps(arguments,
                                                     ensure_ascii=False)
                                if isinstance(arguments, dict) else arguments,
                            ),
                        ))
                except Exception:
                    # Best-effort: skip a malformed entry rather than
                    # discarding the whole batch of tool calls.
                    continue

            # Everything before the (last) tool-call tag is normal content.
            content = model_output[:model_output.
                                   rfind(self.tool_call_start_token)]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            # Lazy %-style args: formatting only happens if the record is
            # emitted; logger.exception also records the traceback.
            logger.exception(
                "Error in extracting tool call from response. Response: %s",
                model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming tool-call extraction is not supported by this parser.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Tool calling is not supported in streaming mode!")