Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import os | |
| from concurrent import futures | |
| import api.nmt_pb2 as nmt | |
| import api.nmt_pb2_grpc as nmtsrv | |
| import grpc | |
| import torch | |
| import nemo.collections.nlp as nemo_nlp | |
| from nemo.utils import logging | |
def get_args(argv=None):
    """Parse command-line options for the NMT gRPC server.

    Args:
        argv: Optional list of argument strings to parse instead of the process
            command line. ``None`` (the default) keeps the original behavior of
            reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_dir``, ``punctuation_model``, ``port``,
        ``batch_size``, ``beam_size``, ``len_pen`` and ``max_delta_length``.

    Exits (via ``SystemExit``) on unknown/missing options or a non-positive
    ``--batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.",
    )
    parser.add_argument(
        "--punctuation_model",
        default="",
        type=str,
        help="Optionally provide a path to a .nemo file for punctuation and capitalization (recommended if working with Riva speech recognition outputs)",
    )
    parser.add_argument(
        "--port", default=50052, type=int, required=False, help="Port the gRPC server listens on."
    )
    parser.add_argument(
        "--batch_size", type=int, default=256, help="Maximum number of utterances translated per batch"
    )
    parser.add_argument("--beam_size", type=int, default=1, help="Beam Size")
    parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty")
    parser.add_argument("--max_delta_length", type=int, default=5, help="Max Delta Generation Length.")
    args = parser.parse_args(argv)
    # Requests are split into chunks of this size; a non-positive value would
    # otherwise fail (or silently drop every utterance) only at request time.
    if args.batch_size < 1:
        parser.error(f"--batch_size must be a positive integer, got {args.batch_size}")
    return args
def batches(lst, n):
    """Yield successive chunks of at most ``n`` items from ``lst``.

    Args:
        lst: Any sliceable sequence (list, tuple, str, ...). Each chunk is a
            slice of it, so chunks have the same type as ``lst``.
        n: Chunk size; must be a positive integer.

    Yields:
        Consecutive slices ``lst[i:i + n]``; the last one may be shorter.
        An empty ``lst`` yields nothing.

    Raises:
        ValueError: if ``n < 1``. Because this is a generator, the error is
            raised when iteration starts, not at call time.
    """
    # Without this check a negative n makes range() empty and every item is
    # silently dropped; n == 0 would fail with an unhelpful range() error.
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n}")
    for start in range(0, len(lst), n):
        yield lst[start : start + n]
class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
    """gRPC servicer that translates text with NeMo machine-translation models.

    Every ``.nemo`` file found directly inside ``model_dir`` is restored as an
    ``MTEncDecModel`` and indexed by its ``src_language`` / ``tgt_language``
    pair. If a punctuation/capitalization model is configured, it is applied
    to each batch of input text before translation.
    """

    def __init__(
        self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256,
    ):
        """Load all translation models and the optional punctuation model.

        Args:
            model_dir: Directory scanned (non-recursively) for ``.nemo`` translation models.
            punctuation_model_path: Path to a ``.nemo`` punctuation/capitalization model,
                or ``""`` to disable punctuation restoration.
            beam_size: Beam size set on every loaded model's ``beam_search``.
            len_pen: Length penalty set on every loaded model's ``beam_search``.
            max_delta_length: Max delta generation length set on every loaded model's ``beam_search``.
            batch_size: Maximum number of utterances passed to a model per ``translate`` call.

        Raises:
            FileNotFoundError: if ``punctuation_model_path`` is non-empty and does not exist.
            NotImplementedError: if a model path does not end in ``.nemo``.
            ValueError: if a model lacks language attributes or two models share a language pair.
        """
        # Nested mapping: self._models[src_language][tgt_language] -> model.
        self._models = {}
        self._beam_size = beam_size
        self._len_pen = len_pen
        self._max_delta_length = max_delta_length
        self._batch_size = batch_size
        self._punctuation_model_path = punctuation_model_path
        self._model_dir = model_dir

        # Sorted so load order (and which file triggers a duplicate-pair error) is deterministic.
        model_paths = sorted(
            os.path.join(model_dir, fname) for fname in os.listdir(model_dir) if fname.endswith(".nemo")
        )
        if not model_paths:
            logging.warning(f"No .nemo translation models found in {model_dir}")
        for model_path in model_paths:
            logging.info(f"Loading model {model_path}")
            self._load_model(model_path)

        if self._punctuation_model_path != "":
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not os.path.exists(punctuation_model_path):
                raise FileNotFoundError(f"Punctuation model not found: {punctuation_model_path}")
            logging.info(f"Loading punctuation model {punctuation_model_path}")
            self._load_puncutation_model(punctuation_model_path)

        logging.info("Models loaded. Ready for inference requests.")

    def _load_puncutation_model(self, punctuation_model_path):
        """Restore the punctuation/capitalization model into ``self.punctuation_model``.

        The misspelled method name is kept for backward compatibility.

        Raises:
            NotImplementedError: if ``punctuation_model_path`` does not end in ``.nemo``.
        """
        if not punctuation_model_path.endswith(".nemo"):
            # NotImplemented is a constant, not an exception; calling it raised TypeError.
            raise NotImplementedError(f"Only support .nemo files, but got: {punctuation_model_path}")
        model = nemo_nlp.models.PunctuationCapitalizationModel.restore_from(restore_path=punctuation_model_path)
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
        self.punctuation_model = model

    def _load_model(self, model_path):
        """Restore one translation model, configure its beam search and register it.

        Raises:
            NotImplementedError: if ``model_path`` does not end in ``.nemo``.
            ValueError: if the model has no ``src_language``/``tgt_language`` attributes,
                or a model for the same language pair is already registered.
        """
        if not model_path.endswith(".nemo"):
            raise NotImplementedError(f"Only support .nemo files, but got: {model_path}")

        logging.info("Attempting to initialize from .nemo file")
        model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path)
        model = model.eval()
        model.beam_search.beam_size = self._beam_size
        model.beam_search.len_pen = self._len_pen
        model.beam_search.max_delta_length = self._max_delta_length
        if torch.cuda.is_available():
            model = model.cuda()

        if not hasattr(model, "src_language") or not hasattr(model, "tgt_language"):
            raise ValueError(
                "Could not find src_language and tgt_language in model attributes. If using NeMo rc1 checkpoints, please edit the config files to add model.src_language and model.tgt_language"
            )
        src_language = model.src_language
        tgt_language = model.tgt_language

        targets = self._models.setdefault(src_language, {})
        if tgt_language in targets:
            raise ValueError(f"Already found model for language pair {src_language}-{tgt_language}")
        targets[tgt_language] = model

    def TranslateText(self, request, context):
        """Translate ``request.texts`` from ``source_language`` to ``target_language``.

        Texts are processed in chunks of at most ``batch_size`` utterances. If no
        model is registered for the requested pair, the RPC status is set to
        INVALID_ARGUMENT and an empty response is returned.
        """
        logging.info(f"Request received w/ {len(request.texts)} utterances")

        model = self._models.get(request.source_language, {}).get(request.target_language)
        if model is None:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                f"Could not find source-target language pair {request.source_language}-{request.target_language} in list of models."
            )
            return nmt.TranslateTextResponse()

        results = []
        for batch in batches(list(request.texts), self._batch_size):
            if self._punctuation_model_path != "":
                batch = self.punctuation_model.add_punctuation_capitalization(batch)
            batch_results = model.translate(text=batch)
            results.extend(nmt.Translation(translation=x) for x in batch_results)
        return nmt.TranslateTextResponse(translations=results)
def serve():
    """Start the NMT gRPC server and block until it is terminated.

    Options come from ``get_args``. Requests are handled on a pool of 10
    worker threads, and the server is bound without TLS (insecure port) on the
    wildcard address ``[::]`` at the configured ``--port``.
    """
    opts = get_args()
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    translator = RivaTranslateServicer(
        model_dir=opts.model_dir,
        punctuation_model_path=opts.punctuation_model,
        beam_size=opts.beam_size,
        len_pen=opts.len_pen,
        batch_size=opts.batch_size,
        max_delta_length=opts.max_delta_length,
    )
    nmtsrv.add_RivaTranslateServicer_to_server(translator, grpc_server)

    listen_address = f"[::]:{opts.port}"
    grpc_server.add_insecure_port(listen_address)
    grpc_server.start()
    grpc_server.wait_for_termination()
if __name__ == "__main__":
    # Launch the server only when run as a script, not when imported.
    serve()