# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union
import numpy as np
import torch
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging
__all__ = ['AggregateTokenizer', 'TokenizerWrapper']
class DummyTokenizer:
def __init__(self, vocab):
self.vocab = vocab
self.vocab_size = len(vocab)
    # minimal compatibility layer: all the monolingual tokenizers expose a vocab;
    # additional methods could be added here as needed
def get_vocab(self):
return self.vocab
class AggregateTokenizer(TokenizerSpec):
    '''
    AggregateTokenizer allows combining multiple regular monolingual tokenizers into a single tokenizer.
    The intuition is that we can use existing tokenizers "as is", without retraining, and associate each tokenizer with a language id
    during text processing (the language id is used to route the incoming text sample to the right tokenizer),
    as well as with a token id range for detokenization (e.g. [0..127] for tokenizer A, [128..255] for tokenizer B) so
    that the original text can be reconstructed. Note that we assume that the incoming dict of langs / tokenizers
    is ordered, i.e. the first tokenizer is assigned the lowest interval of token ids.
    Args:
        tokenizers: dict of tokenizers, keys are lang ids, values are actual tokenizers
    '''
def __init__(self, tokenizers: Dict):
self.tokenizers_dict = tokenizers
self.vocabulary = []
# the tokenizers should produce non-overlapping, ordered token ids
# keys are language ids
self.token_id_offset = {}
# keys are tokenizer numbers
self.token_id_offset_by_tokenizer_num = {}
offset = 0
i = 0
for lang, tokenizer in self.tokenizers_dict.items():
self.token_id_offset[lang] = offset
self.token_id_offset_by_tokenizer_num[i] = offset
offset += len(tokenizer.vocab)
i += 1
for tokenizer in self.tokenizers_dict.values():
self.vocabulary.extend(tokenizer.vocab)
self.vocab_size = len(self.vocabulary)
logging.info(f'Aggregate vocab size: {self.vocab_size}')
# for compatibility purposes only -- right now only the get_vocab method
# is supported, returning the joint vocab across all tokenizers
self.tokenizer = DummyTokenizer(self.vocabulary)
# lookup tables to speed up token to text operations
        # if there are two tokenizers, [0, 1], for ['en', 'es'], each with 128 tokens, the aggregate tokenizer
        # token range will be [0..255]. The method below provides three lookup tables:
        # first, to convert an incoming token id -- e.g. 200 -- into its local id within its own tokenizer (200 - 128 = 72)
        # second, to find the tokenizer that should process that token (tokenizer 1)
        # third, to find the lang id for that token ('es')
offset_token_ids_by_token_id, tokenizers_by_token_id, langs_by_token_id = self._calculate_offsets()
self.offset_token_ids_by_token_id = offset_token_ids_by_token_id
self.tokenizers_by_token_id = tokenizers_by_token_id
self.langs_by_token_id = langs_by_token_id
def _calculate_offsets(self):
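        # walk the joint vocab once and, for every aggregate token id, record its
        # offset (local) id, the tokenizer that owns it, and the corresponding lang id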
offsets = {}
tokenizers = {}
langs = {}
cur_num = 0
tot = len(self.tokenizers_dict)
for id in range(len(self.vocabulary)):
off_id = id - list(self.token_id_offset.values())[cur_num]
if cur_num + 1 < tot:
if id >= list(self.token_id_offset.values())[cur_num + 1]:
cur_num += 1
off_id = id - list(self.token_id_offset.values())[cur_num]
offsets[id] = off_id
tokenizers[id] = list(self.tokenizers_dict.values())[cur_num]
langs[id] = list(self.tokenizers_dict.keys())[cur_num]
return offsets, tokenizers, langs
def text_to_tokens(self, text, lang_id):
tokenizer = self.tokenizers_dict[lang_id]
return tokenizer.text_to_tokens(text)
def text_to_ids(self, text, lang_id):
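        # tokenize with the per-language tokenizer, then shift the ids into this language's id range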
tokenizer = self.tokenizers_dict[lang_id]
token_ids = tokenizer.text_to_ids(text)
token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids]
return token_ids
def tokens_to_text(self, tokens, lang_id):
if isinstance(tokens, np.ndarray):
tokens = tokens.tolist()
tokenizer = self.tokenizers_dict[lang_id]
return tokenizer.decode_pieces(tokens)
def ids_to_text(self, ids):
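        # route each aggregate id back to its own tokenizer's id space, join the pieces,
        # and undo the SentencePiece-style '▁' word-boundary marker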
if isinstance(ids, (np.ndarray, torch.Tensor)):
ids = ids.tolist()
tokens = []
for id in ids:
offset_id = self.offset_token_ids_by_token_id[id]
tokenizer = self.tokenizers_by_token_id[id]
tokens.extend(tokenizer.ids_to_tokens([offset_id]))
text = ''.join(tokens).replace('▁', ' ')
return text
def token_to_id(self, token, lang_id):
tokenizer = self.tokenizers_dict[lang_id]
return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]
def ids_to_tokens(self, ids):
tokens = []
for id in ids:
offset_id = self.offset_token_ids_by_token_id[id]
tokenizer = self.tokenizers_by_token_id[id]
token = tokenizer.ids_to_tokens([offset_id])[0]
tokens.append(token)
return tokens
def ids_to_text_and_langs(self, ids):
text_and_langs = []
for id in ids:
offset_id = self.offset_token_ids_by_token_id[id]
tokenizer = self.tokenizers_by_token_id[id]
token = tokenizer.ids_to_tokens([offset_id])[0]
text = token.replace('▁', ' ')
text = text.strip() # strip for display purposes
lang = self.langs_by_token_id[id]
text_and_langs.append({'char': text, 'lang': lang})
return text_and_langs
def ids_to_words_and_langs(self, ids):
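        # group token ids into words at the '▁' word-boundary marker and label each
        # word with the (majority) language of its tokens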
words_and_langs = []
word_ids = [] # tokens belonging to the current word
for id in ids:
offset_id = self.offset_token_ids_by_token_id[id]
tokenizer = self.tokenizers_by_token_id[id]
token = tokenizer.ids_to_tokens([offset_id])[0]
if token.startswith('▁'):
if len(word_ids) > 0: # if this isn't the first word
word = self.ids_to_text(word_ids)
word = word.strip() # strip for display purposes
lang = self.ids_to_lang(word_ids)
wl = {'word': word, 'lang': lang}
words_and_langs.append(wl)
word_ids = []
word_ids.append(id)
if len(word_ids) > 0: # the last tokens
word = self.ids_to_text(word_ids)
word = word.strip() # strip for display purposes
lang = self.ids_to_lang(word_ids)
wl = {'word': word, 'lang': lang}
words_and_langs.append(wl)
return words_and_langs
def ids_to_lang(self, ids):
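        # majority vote: return the language owning the largest number of the given token ids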
lang_cnts = {}
for id in ids:
lang = self.langs_by_token_id[id]
lang_cnt = lang_cnts.get(lang)
if lang_cnt is not None:
lang_cnts[lang] = lang_cnt + 1
else:
lang_cnts[lang] = 1
max_lang = ''
max_lang_cnt = -1
for lang, lang_cnt in lang_cnts.items():
if lang_cnt > max_lang_cnt:
max_lang = lang
max_lang_cnt = lang_cnt
return max_lang
def tokens_to_ids(self, tokens: Union[str, List[str]], langs: Union[str, List[str]]) -> Union[int, List[int]]:
if isinstance(tokens, str):
tokens = [tokens]
if isinstance(langs, str):
langs = [langs]
ids = []
for i, token in enumerate(tokens):
lang_id = langs[i]
ids.append(self.token_to_id(token, lang_id))
return ids
def get_bos(self, lang_id: str) -> int:
return self.tokenizers_dict[lang_id].bos + self.token_id_offset[lang_id]
def get_eos(self, lang_id: str) -> int:
return self.tokenizers_dict[lang_id].eos + self.token_id_offset[lang_id]
@property
def vocab(self):
return self.vocabulary
@property
def langs(self):
return list(self.tokenizers_dict.keys())
class TokenizerWrapper:
"""
    Provides a unified callable text-to-ids interface over a NeMo Tokenizer, an AggregateTokenizer, or a (char) Parser.
"""
def __init__(self, tokenizer):
self._tokenizer = tokenizer
if isinstance(tokenizer, AggregateTokenizer):
self._impl = self._call_agg_tokenizer
elif isinstance(tokenizer, TokenizerSpec):
self._impl = self._call_tokenizer
else:
self._impl = self._call_parser
def __call__(self, text: str, lang: str | None = None):
return self._impl(text, lang)
def _call_agg_tokenizer(self, text: str, lang: str | None = None):
assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer."
return self._tokenizer.text_to_ids(text, lang)
def _call_tokenizer(self, text: str, lang: str | None = None):
return self._tokenizer.text_to_ids(text)
def _call_parser(self, text: str, lang: str | None = None):
return self._tokenizer(text)
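# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API).
# `_ToyCharTokenizer` below is a hypothetical stand-in for a real NeMo
# monolingual tokenizer -- it implements only the subset of the interface
# that AggregateTokenizer touches here (vocab, text_to_ids, ids_to_tokens,
# token_to_id). It exists purely to demonstrate how token ids are offset and
# routed per language.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _ToyCharTokenizer:
        """Hypothetical character-level tokenizer, for demonstration only."""

        def __init__(self, alphabet: str):
            # '▁' plays the role of the SentencePiece word-boundary marker
            self.vocab = ['▁'] + list(alphabet)
            self._token_to_id = {t: i for i, t in enumerate(self.vocab)}

        def text_to_ids(self, text: str) -> List[int]:
            return [self._token_to_id[c] for c in text.replace(' ', '▁')]

        def ids_to_tokens(self, ids: List[int]) -> List[str]:
            return [self.vocab[i] for i in ids]

        def token_to_id(self, token: str) -> int:
            return self._token_to_id[token]

    # two toy "languages"; 'en' gets aggregate ids 0..26, 'es' gets 27..54
    en_tok = _ToyCharTokenizer('abcdefghijklmnopqrstuvwxyz')
    es_tok = _ToyCharTokenizer('abcdefghijklmnopqrstuvwxyzñ')
    agg = AggregateTokenizer({'en': en_tok, 'es': es_tok})

    # text_to_ids routes to the per-language tokenizer, then shifts the ids
    # by that tokenizer's offset so the ranges do not overlap
    ids = agg.text_to_ids(' hello', 'en') + agg.text_to_ids(' mundo', 'es')
    print(ids)                              # 'es' ids start at offset 27
    print(agg.ids_to_text(ids))             # ' hello mundo'
    print(agg.ids_to_words_and_langs(ids))  # [{'word': 'hello', 'lang': 'en'}, {'word': 'mundo', 'lang': 'es'}]

    # TokenizerWrapper dispatches on the tokenizer type; for an
    # AggregateTokenizer the language id must be supplied
    wrapper = TokenizerWrapper(agg)
    print(wrapper(' hello', lang='en'))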