zw-kitten-tts-working / get_model.py
HeWhoComes's picture
Update get_model.py
07b35cf verified
import json
import os
from huggingface_hub import hf_hub_download
from onnx_model import KittenTTS_1_Onnx
class KittenTTS:
"""Main KittenTTS class for text-to-speech synthesis."""
def __init__(self, model_name="KittenML/kitten-tts-nano-0.1", cache_dir=None):
"""Initialize KittenTTS with a model from Hugging Face.
Args:
model_name: Hugging Face repository ID or model name
cache_dir: Directory to cache downloaded files
"""
# Handle different model name formats
if "/" not in model_name:
# If just model name provided, assume it's from KittenML
repo_id = f"KittenML/{model_name}"
else:
repo_id = model_name
self.model = download_from_huggingface(repo_id=repo_id, cache_dir=cache_dir)
def generate(self, text, voice="expr-voice-5-m", speed=1.0):
"""Generate audio from text.
Args:
text: Input text to synthesize
voice: Voice to use for synthesis
speed: Speech speed (1.0 = normal)
Returns:
Audio data as numpy array
"""
return self.model.generate(text, voice=voice, speed=speed)
def generate_to_file(self, text, output_path, voice="expr-voice-5-m", speed=1.0, sample_rate=24000):
"""Generate audio from text and save to file.
Args:
text: Input text to synthesize
output_path: Path to save the audio file
voice: Voice to use for synthesis
speed: Speech speed (1.0 = normal)
sample_rate: Audio sample rate
"""
return self.model.generate_to_file(text, output_path, voice=voice, speed=speed, sample_rate=sample_rate)
@property
def available_voices(self):
"""Get list of available voices."""
return self.model.available_voices
def download_from_huggingface(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=None):
"""Download model files from Hugging Face repository.
Args:
repo_id: Hugging Face repository ID
cache_dir: Directory to cache downloaded files
Returns:
KittenTTS_1_Onnx: Instantiated model ready for use
"""
# Download config file first
config_path = hf_hub_download(
repo_id=repo_id,
filename="config.json",
cache_dir=cache_dir
)
# Load config
with open(config_path, 'r') as f:
config = json.load(f)
if config.get("type") != "ONNX1":
raise ValueError("Unsupported model type.")
# Download model and voices files based on config
model_path = hf_hub_download(
repo_id=repo_id,
filename=config["model_file"],
cache_dir=cache_dir
)
voices_path = hf_hub_download(
repo_id=repo_id,
filename=config["voices"],
cache_dir=cache_dir
)
# Instantiate and return model
model = KittenTTS_1_Onnx(model_path=model_path, voices_path=voices_path)
return model
def get_model(repo_id="KittenML/kitten-tts-nano-0.1", cache_dir=None):
"""Get a KittenTTS model (legacy function for backward compatibility)."""
return KittenTTS(repo_id, cache_dir)