# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
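
# Streams a local WAV file to Riva streaming ASR, optionally translates each
# interim and final transcript with Riva NMT, and plays the audio on a local
# output device while printing interim (">>") and final ("##") results.
#
# Example invocation (the script name and audio path below are placeholders):
#   python riva_streaming_asr_translate.py --audio-file sample_en.wav --target_language es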

import argparse
import sys
import wave

import grpc
import pyaudio

import api.nmt_pb2 as nmt
import api.nmt_pb2_grpc as nmtsrv
import riva_api.audio_pb2 as riva
import riva_api.riva_asr_pb2 as rivaasr
import riva_api.riva_asr_pb2_grpc as rivaasr_srv


def get_args():
    parser = argparse.ArgumentParser(description="Streaming transcription and translation via Riva AI Speech Services")
    parser.add_argument("--riva-server", default="localhost:50051", type=str, help="URI of the Riva ASR gRPC server endpoint")
    parser.add_argument("--audio-file", required=True, help="path to a local WAV file to stream")
    parser.add_argument("--output-device", type=int, default=None, help="index of the output device to use for playback")
    parser.add_argument("--list-devices", action="store_true", help="list output device indices and exit")
    parser.add_argument("--nmt-server", default="localhost:50052", help="URI of the Riva NMT gRPC server endpoint")
    parser.add_argument("--asr_only", action="store_true", help="skip machine translation and display only the ASR transcript")
    parser.add_argument("--target_language", default="es", help="target language to translate into")
    parser.add_argument(
        "--asr_punctuation",
        action="store_true",
        help="use Riva's punctuation model for ASR transcript postprocessing",
    )
    return parser.parse_args()


def listen_print_loop(responses, nmt_stub, target_language, asr_only=False):
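    """Consume streaming ASR responses, optionally translate each transcript, and print interim and final results."""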
    num_chars_printed = 0
    prev_utterances = []
    for response in responses:
        if not response.results:
            continue
        result = response.results[0]
        if not result.alternatives:
            continue
        transcript = result.alternatives[0].transcript
        original_transcript = transcript
        if not asr_only:
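            # Translate the English transcript into the requested target language via the Riva NMT service.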
            req = nmt.TranslateTextRequest(texts=[transcript], source_language='en', target_language=target_language)
            translation = nmt_stub.TranslateText(req).translations[0].translation
            transcript = translation
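        # Pad with spaces so a shorter hypothesis fully overwrites the longer one printed before it.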
        overwrite_chars = ' ' * (num_chars_printed - len(transcript))
        if not result.is_final:
            sys.stdout.write(">> " + transcript + overwrite_chars + '\r')
            sys.stdout.flush()
            num_chars_printed = len(transcript) + 3
        else:
            print("## " + transcript + overwrite_chars + "\n")
            num_chars_printed = 0
            prev_utterances.append(original_transcript)


CHUNK = 1024
args = get_args()
wf = wave.open(args.audio_file, 'rb')
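
# Open gRPC channels to the Riva ASR server and the Riva NMT server.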
channel = grpc.insecure_channel(args.riva_server)
client = rivaasr_srv.RivaSpeechRecognitionStub(channel)
nmt_channel = grpc.insecure_channel(args.nmt_server)
nmt_stub = nmtsrv.RivaTranslateStub(nmt_channel)

config = rivaasr.RecognitionConfig(
    encoding=riva.AudioEncoding.LINEAR_PCM,
    sample_rate_hertz=wf.getframerate(),
    language_code="en-US",
    max_alternatives=1,
    enable_automatic_punctuation=args.asr_punctuation,
)
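# interim_results=True asks the server to stream partial hypotheses before each utterance is finalized.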
streaming_config = rivaasr.StreamingRecognitionConfig(config=config, interim_results=True)

# instantiate PyAudio (1)
p = pyaudio.PyAudio()

if args.list_devices:
    for i in range(p.get_device_count()):
        info = p.get_device_info_by_index(i)
        if info['maxOutputChannels'] < 1:
            continue
        print(f"{info['index']}: {info['name']}")
    sys.exit(0)

# open stream (2)
stream = p.open(
    output_device_index=args.output_device,
    format=p.get_format_from_width(wf.getsampwidth()),
    channels=wf.getnchannels(),
    rate=wf.getframerate(),
    output=True,
)


# read data and stream it for recognition (3)
def generator(w, s):
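    # The first request carries only the streaming config; each subsequent
    # request carries a chunk of audio, which is also written to the output
    # device so playback stays in sync with recognition.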
    d = w.readframes(CHUNK)
    yield rivaasr.StreamingRecognizeRequest(streaming_config=s)
    while len(d) > 0:
        yield rivaasr.StreamingRecognizeRequest(audio_content=d)
        stream.write(d)
        d = w.readframes(CHUNK)
    return


responses = client.StreamingRecognize(generator(wf, streaming_config))
listen_print_loop(responses, nmt_stub, target_language=args.target_language, asr_only=args.asr_only)

# stop stream (4)
stream.stop_stream()
stream.close()

# close PyAudio (5)
p.terminate()