# TTS-Talker / app_tts.py
import spaces
import os
import codecs
from huggingface_hub import login
import gradio as gr
from cached_path import cached_path
import tempfile
from vinorm import TTSnorm
from importlib.resources import files
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_process,
save_spectrogram,
target_sample_rate as default_target_sample_rate,
n_mel_channels as default_n_mel_channels,
hop_length as default_hop_length,
win_length as default_win_length,
n_fft as default_n_fft,
mel_spec_type as default_mel_spec_type,
target_rms as default_target_rms,
cross_fade_duration as default_cross_fade_duration,
ode_method as default_ode_method,
nfe_step as default_nfe_step, # 16, 32
cfg_strength as default_cfg_strength,
sway_sampling_coef as default_sway_sampling_coef,
speed as default_speed,
fix_duration as default_fix_duration
)
from pathlib import Path
from omegaconf import OmegaConf
from datetime import datetime
import hashlib
import unicodedata
# Retrieve token from secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Log in to Hugging Face
if hf_token:
login(token=hf_token)
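# Note: the token is only needed when the referenced checkpoints are gated or private;
# e.g. set HUGGINGFACEHUB_API_TOKEN=hf_xxx in the environment before launching the app.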
# Build the cache file path from a hash of the text, reference audio path, and model name
def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
os.makedirs(cache_dir, exist_ok=True)
hash_input = f"{text}|{ref_audio_path}|{model}"
hash_val = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
return os.path.join(cache_dir, f"{hash_val}.wav")
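# Illustrative only (hypothetical values): identical (text, ref_audio_path, model) triples map to
# the same SHA-256-named file, so repeated requests reuse the cached wav, e.g.
#   get_audio_cache_path("xin chào", "voices/ref.wav", "F5TTS_Base")
#   -> "tts_cache/<sha256 hexdigest>.wav"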
def post_process(text):
    # Collapse duplicated punctuation left by text normalization, strip stray quotes,
    # and squeeze repeated whitespace.
    text = " " + text + " "
    text = text.replace(" . . ", " . ")
    text = text.replace(" .. ", " . ")
    text = text.replace(" , , ", " , ")
    text = text.replace(" ,, ", " , ")
    text = text.replace('"', "")
    return " ".join(text.split())
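# For example, post_process(" xin chào , ,  thế giới .. ") returns "xin chào , thế giới ."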
# Inference entry point: models are loaded inside the handler, which runs on GPU via @spaces.GPU
@spaces.GPU
def infer_tts(ref_audio_orig: str, ref_text_input: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):
args = {
"model": "F5TTS_Base",
"ckpt_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
"vocab_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
"ref_audio": ref_audio_orig,
"ref_text": ref_text_input,
"gen_text": gen_text,
"speed": speed
}
    config = {}  # no external TOML config is loaded in this app; values below fall back to defaults
# command-line interface parameters
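    # Most values resolve as: explicit argument (if truthy) -> config entry -> library default;
    # ref_text only falls back when it is None, so an intentionally empty transcript is preserved.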
model = args["model"] or config.get("model", "F5TTS_Base")
ckpt_file = args["ckpt_file"] or config.get("ckpt_file", "")
vocab_file = args["vocab_file"] or config.get("vocab_file", "")
ref_audio = args["ref_audio"] or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
ref_text = args["ref_text"] if args["ref_text"] is not None else config.get("ref_text", "Some call me nature, others call me mother nature.")
gen_text = args["gen_text"] or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.get("gen_file", "") or config.get("gen_file", "")
output_dir = args.get("output_dir", "") or config.get("output_dir", "tests")
output_file = args.get("output_file", "") or config.get("output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav")
save_chunk = args.get("save_chunk", False) or config.get("save_chunk", False)
remove_silence = args.get("remove_silence", False) or config.get("remove_silence", False)
load_vocoder_from_local = args.get("load_vocoder_from_local", False) or config.get("load_vocoder_from_local", False)
vocoder_name = args.get("vocoder_name", "") or config.get("vocoder_name", default_mel_spec_type)
target_rms = args.get("target_rms", None) or config.get("target_rms", default_target_rms)
cross_fade_duration = args.get("cross_fade_duration", None) or config.get("cross_fade_duration", default_cross_fade_duration)
nfe_step = args.get("nfe_step", None) or config.get("nfe_step", default_nfe_step)
cfg_strength = args.get("cfg_strength", None) or config.get("cfg_strength", default_cfg_strength)
sway_sampling_coef = args.get("sway_sampling_coef", None) or config.get("sway_sampling_coef", default_sway_sampling_coef)
speed = args.get("speed", None) or config.get("speed", default_speed)
fix_duration = args.get("fix_duration", None) or config.get("fix_duration", default_fix_duration)
if "infer/examples/" in ref_audio:
ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
for voice in config["voices"]:
voice_ref_audio = config["voices"][voice]["ref_audio"]
if "infer/examples/" in voice_ref_audio:
config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
# ignore gen_text if gen_file provided
if gen_file:
gen_text = codecs.open(gen_file, "r", "utf-8").read()
# output path
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
if not os.path.exists(output_chunk_dir):
os.makedirs(output_chunk_dir)
# load vocoder
if vocoder_name == "vocos":
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
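    # local_path is only consulted when load_vocoder_from_local is True; otherwise load_vocoder
    # downloads the pretrained vocoder weights from the Hugging Face Hub.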
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
# load TTS model
model_cfg = OmegaConf.load(
config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
).model
model_cls = globals()[model_cfg.backbone]
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
if vocoder_name == "vocos":
ckpt_step = 1200000
elif vocoder_name == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
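    # In this app ckpt_file is always set above (the Vietnamese checkpoint), so this fallback to the
    # official SWivid checkpoints appears to be retained from the upstream infer_cli and is not hit.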
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
if not ref_audio_orig:
raise gr.Error("Please upload a sample audio file.")
if not gen_text.strip():
raise gr.Error("Please enter the text content to generate voice.")
if len(gen_text.split()) > 1000:
        raise gr.Error("Please enter text content of at most 1000 words.")
try:
        # If the user supplied ref_text, use it; otherwise pass an empty string so the reference is transcribed automatically
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text_input or "")
ref_text = unicodedata.normalize("NFC", ref_text.strip())
gen_text_ = unicodedata.normalize("NFC", gen_text.strip())
        # --- BEGIN: audio cache logic ---
cache_path = get_audio_cache_path(gen_text_, ref_audio_orig, model)
import soundfile as sf
if os.path.exists(cache_path):
print(f"Using cached audio: {cache_path}")
final_wave, final_sample_rate = sf.read(cache_path)
spectrogram = None
else:
final_wave, final_sample_rate, spectrogram = infer_process(
ref_audio, ref_text, gen_text_, ema_model, vocoder, speed=speed
)
            sf.write(cache_path, final_wave, final_sample_rate)
            print(f"[CACHE] Saved new audio to: {cache_path}")
        # --- END: audio cache logic ---
        # Only render a spectrogram for freshly generated audio; cache hits return None for the image
        spectrogram_path = None
        if spectrogram is not None:
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
                spectrogram_path = tmp_spectrogram.name
            save_spectrogram(spectrogram, spectrogram_path)
        return (final_sample_rate, final_wave), spectrogram_path
except Exception as e:
raise gr.Error(f"Error generating voice: {e}")
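# Illustrative direct call (hypothetical paths; in the app this function is driven by the Gradio
# button handler below):
#   (sample_rate, wave), spectrogram_png = infer_tts("voices/ref.wav", "", "Xin chào thế giới.", speed=1.0)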
# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎤 F5-TTS: Vietnamese Text-to-Speech Synthesis
The model was trained on approximately 1000 hours of data using an RTX 3090 GPU.
Enter text and upload a sample voice to generate natural speech.
""")
with gr.Row():
ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        ref_text = gr.Textbox(label="📝 Reference Transcript (optional)", placeholder="Enter the Vietnamese transcript of the sample voice, if available...", lines=2)
gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)
speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
btn_synthesize = gr.Button("🔥 Generate Voice")
with gr.Row():
output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
output_spectrogram = gr.Image(label="📊 Spectrogram")
model_limitations = gr.Textbox(
value="""1. This model may not perform well with numerical characters, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audios may be inconsistent or choppy => It is recommended to select clearly pronounced sample audios with minimal pauses for better synthesis quality.
3. By default, the reference audio transcript is generated with the pho-whisper-medium model, which may not always recognize Vietnamese accurately, resulting in poor voice synthesis quality.
4. Inference with overly long paragraphs may produce poor results.""",
label="❗ Model Limitations",
lines=4,
interactive=False
)
btn_synthesize.click(infer_tts, inputs=[ref_audio, ref_text, gen_text, speed], outputs=[output_audio, output_spectrogram])
# Pass share=True to launch() below to get a public gradio.live link
if __name__ == "__main__":
demo.queue().launch()
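# A minimal client-side sketch (assumes the app is reachable at the default local URL, the
# gradio_client package is installed, and the endpoint name mirrors the handler function above):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   audio, spectrogram = client.predict(
#       handle_file("voices/ref.wav"),   # sample voice
#       "",                              # reference transcript (optional)
#       "Xin chào, đây là giọng nói tổng hợp thử nghiệm.",  # text to synthesize
#       1.0,                             # speed
#       api_name="/infer_tts",
#   )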