import threading
import time
import traceback

import msvcrt  # Windows-only: used to poll the console for keypresses
import requests
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from src.utils.config import settings
from src.utils import (
    VoiceGenerator,
    get_ai_response,
    play_audio_with_interrupt,
    init_vad_pipeline,
    detect_speech_segments,
    record_continuous_audio,
    check_for_speech,
    transcribe_audio,
)
from src.utils.audio_queue import AudioGenerationQueue
from src.utils.llm import parse_stream_chunk
from src.utils.text_chunker import TextChunker

settings.setup_directories()
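
# Wall-clock checkpoints (from time.perf_counter()) for a single voice turn.
# Reset at the start of each turn and rendered by print_timing_chart().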
timing_info = {
    "vad_start": None,
    "transcription_start": None,
    "llm_first_token": None,
    "audio_queued": None,
    "first_audio_play": None,
    "playback_start": None,
    "end": None,
    "transcription_duration": None,
}


def process_input(
    session: requests.Session,
    user_input: str,
    messages: list,
    generator: VoiceGenerator,
    speed: float,
) -> tuple[bool, object]:
    """Processes user input, generates a response, and handles audio output.

    Args:
        session (requests.Session): The requests session to use.
        user_input (str): The user's input text.
        messages (list): The list of messages to send to the LLM.
        generator (VoiceGenerator): The voice generator object.
        speed (float): The playback speed.

    Returns:
        tuple[bool, object]: Whether playback was interrupted, plus the audio
            captured from the interrupting speech (None if not interrupted).
    """
    global timing_info
    # Reset per-turn checkpoints, but keep the transcription timings that
    # main() recorded just before calling this function.
    preserved = {
        k: timing_info[k] for k in ("transcription_start", "transcription_duration")
    }
    timing_info = {k: None for k in timing_info}
    timing_info.update(preserved)
    timing_info["vad_start"] = time.perf_counter()

    messages.append({"role": "user", "content": user_input})
    print("\nThinking...")
    try:
        response_stream = get_ai_response(
            session=session,
            messages=messages,
            llm_model=settings.LLM_MODEL,
            llm_url=settings.OLLAMA_URL,
            max_tokens=settings.MAX_TOKENS,
            stream=True,
        )
        if not response_stream:
            print("Failed to get AI response stream.")
            return False, None
        audio_queue = AudioGenerationQueue(generator, speed)
        audio_queue.start()
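
        # Producer/consumer split: the loop below feeds sentence-sized chunks
        # into audio_queue (which runs TTS in the background), while a separate
        # playback thread drains the generated audio and watches for barge-in.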
        chunker = TextChunker()
        complete_response = []
        # Run playback on a daemon thread; capture the worker's result so the
        # caller can react to a barge-in.
        playback_result = [False, None]

        def playback_wrapper():
            timing_info["playback_start"] = time.perf_counter()
            playback_result[0], playback_result[1] = audio_playback_worker(audio_queue)

        playback_thread = threading.Thread(target=playback_wrapper, daemon=True)
        playback_thread.start()
        for chunk in response_stream:
            data = parse_stream_chunk(chunk)
            if not data or "choices" not in data:
                continue

            choice = data["choices"][0]
            if "delta" in choice and "content" in choice["delta"]:
                content = choice["delta"]["content"]
                if content:
                    if not timing_info["llm_first_token"]:
                        timing_info["llm_first_token"] = time.perf_counter()
                    print(content, end="", flush=True)
                    chunker.current_text.append(content)
                    text = "".join(chunker.current_text)
                    if chunker.should_process(text):
                        if not timing_info["audio_queued"]:
                            timing_info["audio_queued"] = time.perf_counter()
                        # Queue the speakable prefix for TTS and keep the
                        # leftover tail for the next iteration.
                        remaining = chunker.process(text, audio_queue)
                        chunker.current_text = [remaining]
                        complete_response.append(text[: len(text) - len(remaining)])

            if choice.get("finish_reason") == "stop":
                # Flush whatever text is still buffered, then stop reading.
                final_text = "".join(chunker.current_text).strip()
                if final_text:
                    chunker.process(final_text, audio_queue)
                    complete_response.append(final_text)
                break
| messages.append({"role": "assistant", "content": " ".join(complete_response)}) | |
| print() | |
| time.sleep(0.1) | |
| audio_queue.stop() | |
| playback_thread.join() | |
| def playback_wrapper(): | |
| timing_info["playback_start"] = time.perf_counter() | |
| result = audio_playback_worker(audio_queue) | |
| return result | |
| playback_thread = threading.Thread(target=playback_wrapper) | |
| timing_info["end"] = time.perf_counter() | |
| print_timing_chart(timing_info) | |
| return False, None | |
    except Exception as e:
        print(f"\nError during streaming: {str(e)}")
        if "audio_queue" in locals():
            audio_queue.stop()
        return False, None


def audio_playback_worker(audio_queue) -> tuple[bool, object]:
    """Manages audio playback in a separate thread, handling interruptions.

    Args:
        audio_queue (AudioGenerationQueue): The audio queue object.

    Returns:
        tuple[bool, object]: Whether playback was interrupted, plus the audio
            captured from the interrupting speech (None if not interrupted).
    """
    global timing_info
    was_interrupted = False
    interrupt_audio = None
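
    # Poll loop: check the microphone for barge-in, play the next generated
    # clip if one is ready, otherwise sleep briefly and exit once the queue
    # has stopped and fully drained.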
    try:
        while True:
            speech_detected, audio_data = check_for_speech()
            if speech_detected:
                was_interrupted = True
                interrupt_audio = audio_data
                break

            audio_data, _ = audio_queue.get_next_audio()
            if audio_data is not None:
                if not timing_info["first_audio_play"]:
                    timing_info["first_audio_play"] = time.perf_counter()
                was_interrupted, interrupt_data = play_audio_with_interrupt(audio_data)
                if was_interrupted:
                    interrupt_audio = interrupt_data
                    break
            else:
                time.sleep(settings.PLAYBACK_DELAY)
                if (
                    not audio_queue.is_running
                    and audio_queue.sentence_queue.empty()
                    and audio_queue.audio_queue.empty()
                ):
                    break
    except Exception as e:
        print(f"Error in audio playback: {str(e)}")

    return was_interrupted, interrupt_audio


def main():
    """Main function to run the voice chat bot."""
    with requests.Session() as session:
        try:
            generator = VoiceGenerator(settings.MODELS_DIR, settings.VOICES_DIR)
            messages = [{"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT}]
| print("\nInitializing Whisper model...") | |
| whisper_processor = WhisperProcessor.from_pretrained(settings.WHISPER_MODEL) | |
| whisper_model = WhisperForConditionalGeneration.from_pretrained( | |
| settings.WHISPER_MODEL | |
| ) | |
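            # The processor handles audio feature extraction and token decoding;
            # the model performs the actual speech-to-text. Both are loaded once
            # at startup so each utterance only pays inference cost.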
| print("\nInitializing Voice Activity Detection...") | |
| vad_pipeline = init_vad_pipeline(settings.HUGGINGFACE_TOKEN) | |
| print("\n=== Voice Chat Bot Initializing ===") | |
| print("Device being used:", generator.device) | |
| print("\nInitializing voice generator...") | |
| result = generator.initialize(settings.TTS_MODEL, settings.VOICE_NAME) | |
| print(result) | |
| speed = settings.SPEED | |
            try:
                print("\nWarming up the LLM model...")
                # Health check against the default local Ollama root; adjust if
                # OLLAMA_URL points at a different host.
                health = session.get("http://localhost:11434", timeout=3)
                if health.status_code != 200:
                    print("Ollama not running! Start it first.")
                    return
                warmup_response = get_ai_response(
                    session=session,
                    messages=[
                        {"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT},
                        {"role": "user", "content": "Hi!"},
                    ],
                    llm_model=settings.LLM_MODEL,
                    llm_url=settings.OLLAMA_URL,
                    max_tokens=settings.MAX_TOKENS,
                    stream=False,
                )
                if not warmup_response:
                    print("Failed to initialize the AI model!")
                    return
            except requests.RequestException as e:
                print(f"Warmup failed: {str(e)}")
| print("\n\n=== Voice Chat Bot Ready ===") | |
| print("The bot is now listening for speech.") | |
| print("Just start speaking, and I'll respond automatically!") | |
| print("You can interrupt me anytime by starting to speak.") | |
            while True:
                try:
                    # msvcrt.kbhit() is Windows-only: a pending keypress switches
                    # this turn to typed input instead of the microphone.
                    if msvcrt.kbhit():
                        user_input = input("\nYou (text): ").strip()
                        if user_input.lower() == "quit":
                            print("Goodbye!")
                            break
                        if user_input:
                            process_input(
                                session, user_input, messages, generator, speed
                            )
                        continue
                    audio_data = record_continuous_audio()
                    if audio_data is not None:
                        speech_segments = detect_speech_segments(
                            vad_pipeline, audio_data
                        )
                        if speech_segments is not None:
                            print("\nTranscribing detected speech...")
                            timing_info["transcription_start"] = time.perf_counter()
                            user_input = transcribe_audio(
                                whisper_processor, whisper_model, speech_segments
                            )
                            timing_info["transcription_duration"] = (
                                time.perf_counter() - timing_info["transcription_start"]
                            )
                            if user_input.strip():
                                print(f"You (voice): {user_input}")
                                was_interrupted, speech_data = process_input(
                                    session, user_input, messages, generator, speed
                                )
                                # If the user barged in mid-response, transcribe
                                # the captured speech and answer it immediately.
                                if was_interrupted and speech_data is not None:
                                    speech_segments = detect_speech_segments(
                                        vad_pipeline, speech_data
                                    )
                                    if speech_segments is not None:
                                        print("\nTranscribing interrupted speech...")
                                        user_input = transcribe_audio(
                                            whisper_processor,
                                            whisper_model,
                                            speech_segments,
                                        )
                                        if user_input.strip():
                                            print(f"You (voice): {user_input}")
                                            process_input(
                                                session,
                                                user_input,
                                                messages,
                                                generator,
                                                speed,
                                            )
                        else:
                            print("No clear speech detected, please try again.")

                    # Reuse the HTTP connection across turns.
                    session.headers.update({"Connection": "keep-alive"})
                except KeyboardInterrupt:
                    print("\nStopping...")
                    break
                except Exception as e:
                    print(f"Error: {str(e)}")
                    continue
        except Exception as e:
            print(f"Error: {str(e)}")
            print("\nFull traceback:")
            traceback.print_exc()


def print_timing_chart(metrics):
    """Prints a timing chart of the checkpoints recorded for the last turn."""
    events = [
        ("VAD started", metrics["vad_start"]),
        ("Transcription started", metrics["transcription_start"]),
        ("LLM first token", metrics["llm_first_token"]),
        ("Audio queued", metrics["audio_queued"]),
        ("Playback started", metrics["playback_start"]),
        ("First audio played", metrics["first_audio_play"]),
        ("End-to-end response", metrics["end"]),
    ]
    # Keep only the checkpoints that were actually recorded, in time order.
    timed = sorted(
        ((name, t) for name, t in events if t is not None), key=lambda e: e[1]
    )
    if not timed:
        return
    base_time = timed[0][1]

    print("\nTiming Chart:")
    print(f"{'Event':<25} | {'Time (s)':>9} | {'Δ+':>6}")
    print("-" * 45)
    prev_time = base_time
    for name, t in timed:
        elapsed = t - base_time
        delta = t - prev_time
        print(f"{name:<25} | {elapsed:9.2f} | {delta:6.2f}")
        prev_time = t
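
# Example output of print_timing_chart() (illustrative numbers only; real
# values depend on the host machine and the configured models):
#
# Timing Chart:
# Event                     |  Time (s) |     Δ+
# ---------------------------------------------
# Transcription started     |      0.00 |   0.00
# VAD started               |      0.65 |   0.65
# LLM first token           |      0.82 |   0.17
# Audio queued              |      0.95 |   0.13
# Playback started          |      0.96 |   0.01
# First audio played        |      1.20 |   0.24
# End-to-end response       |      3.40 |   2.20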
| if __name__ == "__main__": | |
| main() | |