Spaces:
Runtime error
Runtime error
| # styletts_plugin.py | |
| import os | |
| import sys | |
| import numpy as np | |
| import yaml | |
| import torch | |
| import phonemizer | |
| from phonemizer.backend.espeak.wrapper import EspeakWrapper | |
| import soundfile as sf | |
| import httpx | |
| import nltk | |
| import subprocess | |
| from libs.inference import StyleTTS2 | |
| try: | |
| nltk.data.find('tokenizers/punkt_tab') | |
| except nltk.downloader.DownloadError: | |
| print("Đang tải NLTK tokenizer 'punkt_tab'...") | |
| nltk.download('punkt_tab') | |
| print("Tải thành công.") | |
| class StyleTTModel(): | |
| def __init__(self, **kwargs): | |
| self.model_weights_path = "models/base_model.pth" | |
| self.model_config_path = "models/config.yaml" | |
| self.speaker_wav = kwargs.get("speaker_wav", "speakers/example_female.wav") | |
| self.language = kwargs.get("language", "en-us") | |
| self.speed = kwargs.get("speed", 1.0) | |
| self.denoise = kwargs.get("denoise", 0.2) | |
| self.avg_style = kwargs.get("avg_style", True) | |
| self.stabilize = kwargs.get("stabilize", True) | |
| self.device = self._get_device() | |
| self.sample_rate = 24000 | |
| self.model = None | |
| def _get_device(self): | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| return "cpu" | |
| def _download_file(self, url: str, destination: str): | |
| print(f"Đang tải file từ {url}...") | |
| try: | |
| os.makedirs(os.path.dirname(destination), exist_ok=True) | |
| with httpx.stream("GET", url, follow_redirects=True, timeout=30) as r: | |
| r.raise_for_status() | |
| with open(destination, 'wb') as f: | |
| for chunk in r.iter_bytes(chunk_size=8192): | |
| f.write(chunk) | |
| print(f"Tải thành công và lưu tại: {destination}") | |
| except Exception as e: | |
| print(f"Lỗi khi tải file bằng httpx: {e}") | |
| raise | |
| def _phonemize(self, text: str, lang: str) -> str: | |
| # Tạo mới instance phonemizer mỗi lần gọi để đảm bảo an toàn luồng | |
| if sys.platform == 'darwin': | |
| try: | |
| # Dùng lệnh brew để tìm đường dẫn cài đặt của espeak-ng một cách an toàn | |
| result = subprocess.run(['brew', '--prefix', 'espeak-ng'], capture_output=True, text=True, check=True) | |
| espeak_ng_prefix = result.stdout.strip() | |
| # Xây dựng đường dẫn đến file thư viện động (.dylib) | |
| # Đây là cách làm ổn định hơn nhiều so với việc mã hóa cứng phiên bản | |
| espeak_lib_path = os.path.join(espeak_ng_prefix, 'lib', 'libespeak-ng.dylib') | |
| if os.path.exists(espeak_lib_path): | |
| EspeakWrapper.set_library(espeak_lib_path) | |
| print(f"✅ Đã tự động tìm và cấu hình eSpeak NG cho macOS tại: {espeak_lib_path}") | |
| else: | |
| print(f"⚠️ Không tìm thấy file thư viện tại {espeak_lib_path}. Hãy chắc chắn bạn đã cài espeak-ng qua Homebrew.") | |
| except (subprocess.CalledProcessError, FileNotFoundError): | |
| print("🛑 Lỗi: Không thể chạy lệnh 'brew'. Hãy chắc chắn Homebrew và espeak-ng đã được cài đặt đúng cách.") | |
| print(" Chạy lệnh 'brew install espeak-ng' trong terminal.") | |
| elif sys.platform == 'win32': | |
| try: | |
| import espeakng_loader | |
| EspeakWrapper.set_library(espeakng_loader.get_library_path()) | |
| EspeakWrapper.data_path = espeakng_loader.get_data_path() | |
| except ImportError: | |
| print("Cảnh báo: Không tìm thấy espeakng_loader.") | |
| phonemizer_instance = phonemizer.backend.EspeakBackend( | |
| language=lang, preserve_punctuation=True, with_stress=True | |
| ) | |
| return phonemizer_instance.phonemize([text])[0] | |
| def cache_speaker_style(self, speaker_wav: str): | |
| """ | |
| Tính toán và cache style của một giọng nói để tái sử dụng. | |
| Hàm này nên được gọi một lần khi bắt đầu cuộc hội thoại. | |
| """ | |
| if self.model is None: | |
| self.load() | |
| print(f"-> Đang tính toán và cache style cho giọng nói: {speaker_wav}") | |
| speaker_info = {"path": speaker_wav, "speed": self.speed} # Tốc độ có thể không cần ở đây | |
| # Sử dụng các tham số mặc định của plugin để cache | |
| with torch.no_grad(): | |
| self.cached_style = self.model.get_styles( | |
| speaker_info, | |
| denoise=self.denoise, | |
| avg_style=self.avg_style | |
| ) | |
| print("-> Cache style thành công.") | |
| def load(self): | |
| print("Đang khởi tạo StyleTTS PyTorch plugin...") | |
| if not os.path.exists(self.model_config_path): | |
| config_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite/resolve/main/Models/config.yaml" | |
| self._download_file(config_url, self.model_config_path) | |
| if not os.path.exists(self.model_weights_path): | |
| weights_url = "https://huggingface.co/dangtr0408/StyleTTS2-lite/resolve/main/Models/base_model.pth" | |
| self._download_file(weights_url, self.model_weights_path) | |
| print("\nBắt đầu tải model PyTorch vào bộ nhớ...") | |
| self.model = StyleTTS2(self.model_config_path, self.model_weights_path) | |
| self.model.eval() | |
| self.model.to(self.device) | |
| print(f"StyleTTS PyTorch plugin đã tải thành công trên thiết bị {self.device}.") | |
| # Tự động cache style cho giọng nói mặc định | |
| print(f"-> Tự động tính toán và cache style cho giọng nói: {self.speaker_wav}") | |
| try: | |
| speaker_info = {"path": self.speaker_wav, "speed": self.speed} | |
| with torch.no_grad(): | |
| self.cached_style = self.model.get_styles( | |
| speaker_info, | |
| denoise=self.denoise, | |
| avg_style=self.avg_style | |
| ) | |
| print("-> Cache style thành công.") | |
| except Exception as e: | |
| print(f"-> CẢNH BÁO: Không thể cache style. Lỗi: {e}") | |
| self.cached_style = None | |
| # "Warm-up" cho phonemizer | |
| print("-> Đang thực hiện warm-up cho phonemizer...") | |
| try: | |
| self._phonemize("warm-up", self.language) | |
| print("-> Phonemizer warm-up thành công.") | |
| except Exception as e: | |
| print(f"-> Cảnh báo: Phonemizer warm-up thất bại: {e}") | |
| return self | |
| def synthesize(self, text: str, **kwargs) -> np.ndarray: | |
| if self.model is None: | |
| self.load() | |
| language = kwargs.get("language", self.language) | |
| speed = kwargs.get("speed", self.speed) | |
| stabilize = kwargs.get("stabilize", self.stabilize) | |
| if not hasattr(self, 'cached_style') or self.cached_style is None: | |
| print("Cảnh báo: Style chưa được cache. Đang tính toán lại...") | |
| speaker_wav = kwargs.get("speaker_wav", self.speaker_wav) | |
| speaker_info = {"path": speaker_wav, "speed": speed} | |
| styles = self.model.get_styles(speaker_info, denoise=kwargs.get("denoise", self.denoise), avg_style=kwargs.get("avg_style", self.avg_style)) | |
| else: | |
| styles = self.cached_style | |
| styles['speed'] = speed | |
| with torch.no_grad(): | |
| phonemes = self._phonemize(text, language) | |
| wav = self.model.generate(phonemes, styles, stabilize=stabilize) | |
| wav = wav / np.max(np.abs(wav)) | |
| return wav.astype(np.float32) | |
| if __name__ == "__main__": | |
| SPEAKER_WAV_PATH = "speakers/example_female.wav" | |
| if not os.path.exists(SPEAKER_WAV_PATH): | |
| print(f"Lỗi: Không tìm thấy file âm thanh mẫu tại '{SPEAKER_WAV_PATH}'.") | |
| else: | |
| # Khởi tạo plugin | |
| styletts_utils = StyleTTModel(speaker_wav=SPEAKER_WAV_PATH) | |
| styletts_utils.load() # Load model trước | |
| print("\n" + "="*50) | |
| print("🔍 KIỂM TRA THIẾT BỊ (DEVICE) RUNTIME") | |
| # 1. PyTorch có "nhìn thấy" GPU không? | |
| cuda_available = torch.cuda.is_available() | |
| print(f" - PyTorch có tìm thấy CUDA không? : {cuda_available}") | |
| if styletts_utils.model: | |
| model_device = next(styletts_utils.model.parameters()).device | |
| print(f" - Model thực sự đang nằm trên? : {model_device}") | |
| if "cuda" in str(model_device): | |
| print("\n>>> KẾT LUẬN: ✅ Model đang chạy trên GPU.") | |
| else: | |
| print("\n>>> KẾT LUẬN: ❌ Model đang chạy trên CPU.") | |
| else: | |
| print(" - Model chưa được load.") | |
| print("="*50) | |
| print("\n--- Thử nghiệm tổng hợp âm thanh ---") | |
| long_text = "StyleTTS 2 is a text-to-speech model that offers zero-shot speaker adaptation." | |
| audio = styletts_utils.synthesize(long_text) | |
| output_path = "plugin_pytorch_output.wav" | |
| styletts_utils.save_audio(audio, output_path) | |
| print(f"✅ Âm thanh đã được lưu thành công tại: {output_path}") |