File size: 7,202 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from functools import cache

# Full-string check for a single template variable: "{_", an ASCII letter, then any
# letters/digits/underscores, then "_}" -- e.g. "{_content_}" or "{_eos_id_}".
TEMPLATE_VAR_VALIDATION_PAT = re.compile(r'^\{_[A-Za-z][A-Za-z0-9_]*_\}$')
# Looser pattern used to locate candidate variables inside a template string. The whole
# match is a capture group so that `split` keeps the matched variables in its output.
TEMPLATE_VAR_SEARCH_PAT = re.compile(r'({_[^}]+_})')


class ChatTemplateMixin:
    """Mixin that lets a tokenizer-like class tokenize chat conversations.

    The host class is expected to define a ``chat_template`` attribute, which is ``None``
    when no template is configured. The host object itself is handed to
    ``tokenize_with_chat_template`` as the tokenizer. It presumably must therefore provide
    ``text_to_ids`` and any special-token attributes the template refers to.
    """

    @property
    def has_chat_template(self):
        """True when ``self.chat_template`` is set (not ``None``)."""
        return self.chat_template is not None

    def apply_chat_template(self, messages):
        """Tokenize ``messages`` with ``self.chat_template`` and return the token ids.

        Raises AssertionError (when assertions are enabled) if no chat template is set.
        """
        template = self.chat_template
        assert template is not None
        return tokenize_with_chat_template(self, messages, template)


@cache
def is_template_var(s):
    """Check whether ``s`` is exactly one template variable.

    A template variable is ``{_`` followed by an ASCII letter, then any ASCII letters, digits
    or underscores, and finally ``_}`` -- e.g. ``{_content_}``.

    Returns the ``re.Match`` object (truthy) on success, ``None`` otherwise. Results are
    memoized per input string by ``functools.cache``.
    """
    return TEMPLATE_VAR_VALIDATION_PAT.match(s)


def extract_template_parts(template, skip_empty=True):
    """Lazily split ``template`` into literal text and ``{_..._}`` variable pieces, in order.

    Pieces matched as variables here are only candidates; callers use ``is_template_var``
    for the stricter check. When ``skip_empty`` is truthy, the empty strings that splitting
    produces (e.g. around a leading or trailing variable) are dropped.

    e.g. "[INST] {_content_} [/INST]" -> "[INST] ", "{_content_}", " [/INST]"
    """
    pieces = TEMPLATE_VAR_SEARCH_PAT.split(template)
    yield from (piece for piece in pieces if not skip_empty or piece != '')


def strip_template_wrap(s):
    """Return the variable name inside a ``{_..._}`` wrapper, or ``s`` unchanged.

    e.g. ``'{_content_}' -> 'content'``; ``'[INST]' -> '[INST]'`` (not a template variable).
    """
    # Drop the two-character "{_" prefix and "_}" suffix only for real template variables.
    return s[2:-2] if is_template_var(s) else s


def render_chat_turn(message, template):
    """Render one chat turn, splitting the output at every special-token variable.

    The template is scanned left to right:
      * literal text is copied as-is;
      * ``{_content_}`` is replaced by ``message['content']``;
      * any other template variable (e.g. ``{_eos_}``) ends the current segment. The text
        rendered so far is yielded together with that variable's name, so the caller can
        append the corresponding special token after encoding the text.
    A final ``(text, None)`` pair is always yielded for whatever follows the last
    special-token variable; it may be ``''``.

    Args:
        message (Dict): a single, non-batched turn. Only ``message['content']`` is read, and
            only when the template contains ``{_content_}``. It must be a ``str``, because the
            rendered pieces are joined with ``''.join``.
            e.g. {'role': 'user', 'content': 'What is your favourite fruit?'}
        template (str): e.g. "[INST] {_content_} [/INST]"

    Yields:
        (str, str | None): rendered text and the name of the special token that follows it
        (``None`` for the trailing segment). For the example above:
            ("[INST] What is your favourite fruit? [/INST]", None)
    """
    pieces = []
    for part in extract_template_parts(template):
        if not is_template_var(part):
            # Literal text: copy verbatim.
            pieces.append(part)
            continue
        name = strip_template_wrap(part)
        if name == 'content':
            pieces.append(message['content'])
        else:
            # Any other variable names a special token: flush the text rendered so far
            # together with that token's name and start a new segment.
            yield ''.join(pieces), name
            pieces = []
    yield ''.join(pieces), None


def encode_string_with_special_token(tokenizer, inputs, special_token):
    """
    Tokenizes a string or a list of strings into token_ids and appends (at the end) the id
    of ``special_token`` if one is given.

    A list of strings is concatenated before tokenization, so it is encoded as one piece of
    text. Empty text (``''``, ``[]`` or a list of empty strings) is never passed to the
    tokenizer. Inputs that are neither ``str`` nor ``list`` contribute no tokens.

    Args:
        tokenizer: object exposing ``text_to_ids(str)`` (e.g. an SPM tokenizer). When
            ``special_token`` is given, it must also have an attribute of that name whose value
            is appended as-is.
        inputs: (Str, List[Str])
            e.g. "Alex" or ["Alex", "nvidia"]
        special_token: (Str or None): name of a tokenizer attribute.
            e.g. "eos"

    Returns:
        (list[int]): list of token_ids
        e.g.
            input="Alex", special_token="eos"
            Alex->[3413]
            eos->[2]

            Will return the following:
            [3413, 2]

    Raises:
        AssertionError: (when assertions are enabled) if ``tokenizer`` has no attribute
            named ``special_token``.
    """
    # Concatenate list input first so that an all-empty list is skipped exactly like ''.
    text = ''.join(inputs) if isinstance(inputs, list) else inputs
    ans = []
    if isinstance(text, str) and text != '':
        ans += tokenizer.text_to_ids(text)
    if special_token is not None:
        # TODO(@akoumparouli): limit which attributes user-defined string can query.
        assert hasattr(tokenizer, special_token), f"Special_token {special_token} is not part of tokenizer"
        ans.append(getattr(tokenizer, special_token))
    return ans


def tokenize_with_chat_template(tokenizer, messages, template):
    """Tokenize a conversation according to a chat template.

    Args:
        tokenizer: passed through to ``encode_string_with_special_token``. It must provide
            ``text_to_ids`` and every special-token attribute named in the template.
        messages (List[Dict]): non-batched turns. Each has a ``'role'`` present in
            ``template['roles']`` and, when its role template uses ``{_content_}``, a str
            ``'content'``.
        template (Dict): must contain ``'roles'`` (role name -> turn template string). It may
            contain ``'prefix'``, a template string rendered once before the first message.

    Returns:
        (list[int]): token ids for the whole conversation.

    Raises:
        AssertionError: (when assertions are enabled) on malformed input, an unknown role,
            or an empty result.
    """
    assert is_chat_input(messages), "Expected input to be chat-template"
    assert len(messages) > 0, "Expected non-empty messages"
    assert 'roles' in template, "Expected template to have key `roles`."

    def encode(text, special_token):
        return encode_string_with_special_token(tokenizer, text, special_token)

    ans = []
    if 'prefix' in template:
        # The prefix is rendered with '' as its "message", so it presumably must not use
        # {_content_}. Its segments are encoded one by one, separately from the messages.
        for part, special_token in render_chat_turn('', template['prefix']):
            ans += encode(part, special_token)
    # Rendered text is buffered across turns and only tokenized when a special token closes a
    # segment, so adjacent text (even from different turns) is encoded as a single string.
    buffer = []
    for message in messages:
        assert message['role'] in template['roles'], (message['role'], template['roles'])
        msg_template = template['roles'][message['role']]
        for text, special_token in render_chat_turn(message, msg_template):
            buffer.append(text)
            if special_token is not None:
                ans += encode(buffer, special_token)
                buffer = []
    # Text after the last special token is encoded without a trailing special token.
    ans += encode(buffer, None)
    assert len(ans) > 0, 'Expected non-empty output'
    return ans


def extract_turns(messages, axis):
    """
    Extract one conversation from collated (batched) chat messages.

    In collated messages, every turn dict maps each key to a sequence with one entry per
    conversation in the batch. This returns, for each turn, a dict mapping the same keys to
    their ``axis``-th entry (indexing raises IndexError if a value is too short). For example:

    messages = [
        {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']},
        {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]},
        {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']}
    ]
    ans = extract_turns(messages, axis=1)

    ans = [
        {'role': 'user', 'content': 'What is your favourite fruit?'},
        {'role': 'assistant', 'content': "good squeeze of fresh lemon"},
        {'role': 'user', 'content': 'Do you have tomato salad recipes?'}
    ]
    """
    return [{key: values[axis] for key, values in turn.items()} for turn in messages]


def explode_chat_template_input(messages):
    """
    Split collated (batched) chat messages into one list of turns per conversation.

    Example input
    [
       {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']},
       {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]},
       {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']}
    ]

    Notice the 2D axis system of the messages variable, one for the list and one for each item in the list (i.e.
    the 'content' contains multiple messages).

    Yields ``extract_turns(messages, axis=i)`` for each conversation index ``i``. For ``i = 0`` above:
    [
       {'role': 'user', 'content': 'What is your favourite condiment?'},
       {'role': 'assistant', 'content': "Well, I'm quite partial to a "},
       {'role': 'user', 'content': 'Do you have mayonnaise recipes?'}
    ]

    This is a generator: the input checks below (AssertionError, when assertions are
    enabled) run on the first ``next()``, not when the function is called.
    """
    assert isinstance(messages, list), "Expected messages to be a list"
    assert len(messages) > 0, "Expected non empty messages"
    assert all(isinstance(x, dict) for x in messages), "Expected messages to contain dicts"
    assert all(
        'role' in x and 'content' in x for x in messages
    ), "Expected messages each dict to contain 'role' and 'content' fields"
    n = len(messages[0]['role'])
    assert all(
        len(x['role']) == n for x in messages
    ), "Expected all batch messages to contain equal number of roles in all turns"
    # 'content' must be collated like 'role'. Otherwise extract_turns fails with an
    # IndexError (too few entries) or silently drops conversations (too many).
    assert all(
        len(x['content']) == n for x in messages
    ), "Expected all batch messages to contain equal number of contents in all turns"
    for i in range(n):
        yield extract_turns(messages, axis=i)


def is_chat_input(messages):
    """Return True if ``messages`` looks like chat input: a non-empty list whose first item is a dict.

    Only the first element is inspected.
    """
    # TODO(@akoumparouli): improve validation.
    if not isinstance(messages, list) or not messages:
        return False
    return isinstance(messages[0], dict)