# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Union

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging

__all__ = ['AggregateTokenizer', 'TokenizerWrapper']

class DummyTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        self.vocab_size = len(vocab)

    # minimum compatibility: all the monolingual tokenizers have a vocab;
    # additional methods could be added here
    def get_vocab(self):
        return self.vocab

class AggregateTokenizer(TokenizerSpec):
    '''
    AggregateTokenizer, allowing one to combine multiple regular monolingual tokenizers into one tokenizer.
    The intuition is that we can use existing tokenizers "as is", without retraining, and associate each
    tokenizer with a language id during text processing (the language id is used to route an incoming text
    sample to the right tokenizer), as well as with a token id range for detokenization
    (e.g. [0..127] for tokenizer A, [128..255] for tokenizer B), so that the original text can be
    reconstructed. Note that the incoming dict of langs / tokenizers is assumed to be ordered,
    i.e. the first tokenizer is assigned the lowest interval of token ids.

    Args:
        tokenizers: dict of tokenizers, keys are lang ids, values are the actual tokenizers
    '''
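
    # Example (hypothetical sizes): given AggregateTokenizer({'en': tok_en, 'es': tok_es})
    # where each monolingual tokenizer has 128 tokens, 'en' ids occupy [0..127] and 'es'
    # ids occupy [128..255] in the aggregate vocabulary; aggregate id 200 therefore
    # decodes through tok_es using the offset id 200 - 128 = 72.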

    def __init__(self, tokenizers: Dict):

        self.tokenizers_dict = tokenizers
        self.vocabulary = []

        # the tokenizers should produce non-overlapping, ordered token ids
        # keys are language ids
        self.token_id_offset = {}
        # keys are tokenizer numbers
        self.token_id_offset_by_tokenizer_num = {}

        offset = 0
        for i, (lang, tokenizer) in enumerate(self.tokenizers_dict.items()):
            self.token_id_offset[lang] = offset
            self.token_id_offset_by_tokenizer_num[i] = offset
            offset += len(tokenizer.vocab)

        for tokenizer in self.tokenizers_dict.values():
            self.vocabulary.extend(tokenizer.vocab)

        self.vocab_size = len(self.vocabulary)
        logging.info(f'Aggregate vocab size: {self.vocab_size}')

        # for compatibility purposes only -- right now only the get_vocab method
        # is supported, returning the joint vocab across all tokenizers
        self.tokenizer = DummyTokenizer(self.vocabulary)

        # lookup tables to speed up token-to-text operations
        # if there are two tokenizers, [0, 1] / ['en', 'es'], each with 128 tokens, the aggregate
        # token id range will be [0..255]. The method below builds three lookup tables:
        # first, to convert an incoming token id -- e.g. 200 -- into its real id (200 - 128 = 72)
        # second, to find the tokenizer that should process that token (1)
        # third, to find the lang id for that token ('es')
        offset_token_ids_by_token_id, tokenizers_by_token_id, langs_by_token_id = self._calculate_offsets()

        self.offset_token_ids_by_token_id = offset_token_ids_by_token_id
        self.tokenizers_by_token_id = tokenizers_by_token_id
        self.langs_by_token_id = langs_by_token_id

    def _calculate_offsets(self):
        offsets = {}
        tokenizers = {}
        langs = {}

        cur_num = 0
        tot = len(self.tokenizers_dict)
        for id in range(len(self.vocabulary)):
            off_id = id - list(self.token_id_offset.values())[cur_num]
            if cur_num + 1 < tot:
                if id >= list(self.token_id_offset.values())[cur_num + 1]:
                    cur_num += 1
                    off_id = id - list(self.token_id_offset.values())[cur_num]
            offsets[id] = off_id
            tokenizers[id] = list(self.tokenizers_dict.values())[cur_num]
            langs[id] = list(self.tokenizers_dict.keys())[cur_num]

        return offsets, tokenizers, langs

    def text_to_tokens(self, text, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.text_to_tokens(text)

    def text_to_ids(self, text, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        token_ids = tokenizer.text_to_ids(text)
        # shift the monolingual ids into this language's slice of the aggregate range
        token_ids[:] = [t + self.token_id_offset[lang_id] for t in token_ids]
        return token_ids

    def tokens_to_text(self, tokens, lang_id):
        if isinstance(tokens, np.ndarray):
            tokens = tokens.tolist()

        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.decode_pieces(tokens)

    def ids_to_text(self, ids):
        if isinstance(ids, (np.ndarray, torch.Tensor)):
            ids = ids.tolist()

        tokens = []
        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            tokens.extend(tokenizer.ids_to_tokens([offset_id]))
        text = ''.join(tokens).replace('▁', ' ')

        return text

    def token_to_id(self, token, lang_id):
        tokenizer = self.tokenizers_dict[lang_id]
        return tokenizer.token_to_id(token) + self.token_id_offset[lang_id]

    def ids_to_tokens(self, ids):
        tokens = []

        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            tokens.append(token)

        return tokens

    def ids_to_text_and_langs(self, ids):
        text_and_langs = []

        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            text = token.replace('▁', ' ')
            text = text.strip()  # strip for display purposes
            lang = self.langs_by_token_id[id]
            text_and_langs.append({'char': text, 'lang': lang})

        return text_and_langs

    def ids_to_words_and_langs(self, ids):
        words_and_langs = []

        word_ids = []  # tokens belonging to the current word
        for id in ids:
            offset_id = self.offset_token_ids_by_token_id[id]
            tokenizer = self.tokenizers_by_token_id[id]
            token = tokenizer.ids_to_tokens([offset_id])[0]
            if token.startswith('▁'):
                if len(word_ids) > 0:  # if this isn't the first word
                    word = self.ids_to_text(word_ids)
                    word = word.strip()  # strip for display purposes
                    lang = self.ids_to_lang(word_ids)
                    wl = {'word': word, 'lang': lang}
                    words_and_langs.append(wl)
                    word_ids = []
            word_ids.append(id)

        if len(word_ids) > 0:  # the last tokens
            word = self.ids_to_text(word_ids)
            word = word.strip()  # strip for display purposes
            lang = self.ids_to_lang(word_ids)
            wl = {'word': word, 'lang': lang}
            words_and_langs.append(wl)

        return words_and_langs

    def ids_to_lang(self, ids):
        lang_cnts = {}

        for id in ids:
            lang = self.langs_by_token_id[id]
            lang_cnt = lang_cnts.get(lang)
            if lang_cnt is not None:
                lang_cnts[lang] = lang_cnt + 1
            else:
                lang_cnts[lang] = 1

        max_lang = ''
        max_lang_cnt = -1
        for lang, lang_cnt in lang_cnts.items():
            if lang_cnt > max_lang_cnt:
                max_lang = lang
                max_lang_cnt = lang_cnt

        return max_lang

    def tokens_to_ids(self, tokens: Union[str, List[str]], langs: Union[str, List[str]]) -> Union[int, List[int]]:
        if isinstance(tokens, str):
            tokens = [tokens]
        if isinstance(langs, str):
            langs = [langs]

        ids = []
        for i, token in enumerate(tokens):
            lang_id = langs[i]
            ids.append(self.token_to_id(token, lang_id))

        return ids

    def get_bos(self, lang_id: str) -> int:
        return self.tokenizers_dict[lang_id].bos + self.token_id_offset[lang_id]

    def get_eos(self, lang_id: str) -> int:
        return self.tokenizers_dict[lang_id].eos + self.token_id_offset[lang_id]

    @property
    def vocab(self):
        return self.vocabulary

    @property
    def langs(self):
        return list(self.tokenizers_dict.keys())

class TokenizerWrapper:
    """
    Provides a unified interface for a NeMo Tokenizer, an AggregateTokenizer, and a (char) Parser.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        if isinstance(tokenizer, AggregateTokenizer):
            self._impl = self._call_agg_tokenizer
        elif isinstance(tokenizer, TokenizerSpec):
            self._impl = self._call_tokenizer
        else:
            self._impl = self._call_parser

    def __call__(self, text: str, lang: str | None = None):
        return self._impl(text, lang)

    def _call_agg_tokenizer(self, text: str, lang: str | None = None):
        assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer."
        return self._tokenizer.text_to_ids(text, lang)

    def _call_tokenizer(self, text: str, lang: str | None = None):
        return self._tokenizer.text_to_ids(text)

    def _call_parser(self, text: str, lang: str | None = None):
        return self._tokenizer(text)