import spaces
import os
import codecs
import tempfile
import hashlib
import unicodedata
from datetime import datetime
from pathlib import Path
from importlib.resources import files

import gradio as gr
import soundfile as sf
from huggingface_hub import login
from cached_path import cached_path
from omegaconf import OmegaConf
from vinorm import TTSnorm

from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
    target_sample_rate as default_target_sample_rate,
    n_mel_channels as default_n_mel_channels,
    hop_length as default_hop_length,
    win_length as default_win_length,
    n_fft as default_n_fft,
    mel_spec_type as default_mel_spec_type,
    target_rms as default_target_rms,
    cross_fade_duration as default_cross_fade_duration,
    ode_method as default_ode_method,
    nfe_step as default_nfe_step,  # 16, 32
    cfg_strength as default_cfg_strength,
    sway_sampling_coef as default_sway_sampling_coef,
    speed as default_speed,
    fix_duration as default_fix_duration,
)

# Retrieve token from secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Log in to Hugging Face
if hf_token:
    login(token=hf_token)


# Build the cache file path from the generated text, reference audio, and model name
def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
    os.makedirs(cache_dir, exist_ok=True)
    hash_input = f"{text}|{ref_audio_path}|{model}"
    hash_val = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{hash_val}.wav")


def post_process(text):
    text = " " + text + " "
    text = text.replace(" . . ", " . ")
    text = " " + text + " "
    text = text.replace(" .. ", " . ")
    text = " " + text + " "
    text = text.replace(" , , ", " , ")
    text = " " + text + " "
    text = text.replace(" ,, ", " , ")
    text = " " + text + " "
    text = text.replace('"', "")
    return " ".join(text.split())


# Load models and run inference (executed on GPU via ZeroGPU)
@spaces.GPU
def infer_tts(ref_audio_orig: str, ref_text_input: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):
    args = {
        "model": "F5TTS_Base",
        "ckpt_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
        "vocab_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
        "ref_audio": ref_audio_orig,
        "ref_text": ref_text_input,
        "gen_text": gen_text,
        "speed": speed,
    }
    config = {}  # tomli.load(open(args.config, "rb"))

    # command-line interface parameters
    model = args["model"] or config.get("model", "F5TTS_Base")
    ckpt_file = args["ckpt_file"] or config.get("ckpt_file", "")
    vocab_file = args["vocab_file"] or config.get("vocab_file", "")
    ref_audio = args["ref_audio"] or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
    ref_text = (
        args["ref_text"]
        if args["ref_text"] is not None
        else config.get("ref_text", "Some call me nature, others call me mother nature.")
    )
    gen_text = args["gen_text"] or config.get("gen_text", "Here we generate something just for test.")
    gen_file = args.get("gen_file", "") or config.get("gen_file", "")
    output_dir = args.get("output_dir", "") or config.get("output_dir", "tests")
    output_file = args.get("output_file", "") or config.get(
        "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
    )
    save_chunk = args.get("save_chunk", False) or config.get("save_chunk", False)
    remove_silence = args.get("remove_silence", False) or config.get("remove_silence", False)
    load_vocoder_from_local = args.get("load_vocoder_from_local", False) or config.get("load_vocoder_from_local", False)
    vocoder_name = args.get("vocoder_name", "") or config.get("vocoder_name", default_mel_spec_type)
    target_rms = args.get("target_rms", None) or config.get("target_rms", default_target_rms)
    cross_fade_duration = args.get("cross_fade_duration", None) or config.get("cross_fade_duration", default_cross_fade_duration)
    nfe_step = args.get("nfe_step", None) or config.get("nfe_step", default_nfe_step)
    cfg_strength = args.get("cfg_strength", None) or config.get("cfg_strength", default_cfg_strength)
    sway_sampling_coef = args.get("sway_sampling_coef", None) or config.get("sway_sampling_coef", default_sway_sampling_coef)
    speed = args.get("speed", None) or config.get("speed", default_speed)
    fix_duration = args.get("fix_duration", None) or config.get("fix_duration", default_fix_duration)

    if "infer/examples/" in ref_audio:
        ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
    if "infer/examples/" in gen_file:
        gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
    if "voices" in config:
        for voice in config["voices"]:
            voice_ref_audio = config["voices"][voice]["ref_audio"]
            if "infer/examples/" in voice_ref_audio:
                config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))

    # ignore gen_text if gen_file provided
    if gen_file:
        gen_text = codecs.open(gen_file, "r", "utf-8").read()

    # output path
    wave_path = Path(output_dir) / output_file
    # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
    if save_chunk:
        output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
        if not os.path.exists(output_chunk_dir):
            os.makedirs(output_chunk_dir)

    # load vocoder
    if vocoder_name == "vocos":
        vocoder_local_path = "../checkpoints/vocos-mel-24khz"
    elif vocoder_name == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)

    # load TTS model
    model_cfg = OmegaConf.load(
        config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
    ).model
    model_cls = globals()[model_cfg.backbone]

    repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

    if model != "F5TTS_Base":
        assert vocoder_name == model_cfg.mel_spec.mel_spec_type

    # override for previous models
    if model == "F5TTS_Base":
        if vocoder_name == "vocos":
            ckpt_step = 1200000
        elif vocoder_name == "bigvgan":
            model = "F5TTS_Base_bigvgan"
            ckpt_type = "pt"
    elif model == "E2TTS_Base":
        repo_name = "E2-TTS"
        ckpt_step = 1200000

    if not ckpt_file:
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))

    print(f"Using {model}...")
    ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)

    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text content to generate voice.")
    if len(gen_text.split()) > 1000:
        raise gr.Error("Please enter text content with fewer than 1000 words.")

    try:
        # If the user provided a reference transcript, use it; otherwise pass an empty
        # string so the reference audio is transcribed automatically.
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text_input or "")
        ref_text = unicodedata.normalize("NFC", ref_text.strip())
        gen_text_ = unicodedata.normalize("NFC", gen_text.strip())

        # --- BEGIN: caching logic ---
        cache_path = get_audio_cache_path(gen_text_, ref_audio_orig, model)
        if os.path.exists(cache_path):
            print(f"Using cached audio: {cache_path}")
            final_wave, final_sample_rate = sf.read(cache_path)
            spectrogram = None
        else:
            final_wave, final_sample_rate, spectrogram = infer_process(
                ref_audio,
                ref_text,
                gen_text_,
                ema_model,
                vocoder,
                speed=speed,
            )
            sf.write(cache_path, final_wave, final_sample_rate)
            print(f"[CACHE] Saved new audio to: {cache_path}")
        # --- END: caching logic ---

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
            spectrogram_path = tmp_spectrogram.name
        if spectrogram is not None:
            save_spectrogram(spectrogram, spectrogram_path)

        return (final_sample_rate, final_wave), spectrogram_path
    except Exception as e:
        raise gr.Error(f"Error generating voice: {e}")


# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎤 F5-TTS: Vietnamese Text-to-Speech Synthesis

# The model was trained on approximately 1000 hours of data using an RTX 3090 GPU.

Enter text and upload a sample voice to generate natural speech.
""")

    with gr.Row():
        ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        ref_text = gr.Textbox(
            label="📝 Reference Transcript (optional)",
            placeholder="Enter the Vietnamese transcript for the sample voice, if available...",
            lines=2,
        )

    gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)
    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    btn_synthesize = gr.Button("🔥 Generate Voice")

    with gr.Row():
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        output_spectrogram = gr.Image(label="📊 Spectrogram")

    model_limitations = gr.Textbox(
        value="""1. This model may not perform well with numerical characters, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audios may be inconsistent or choppy => It is recommended to select clearly pronounced sample audios with minimal pauses for better synthesis quality.
3. By default, the reference audio is transcribed with the pho-whisper-medium model, which may not always recognize Vietnamese accurately, resulting in poor voice synthesis quality.
4. Inference with overly long paragraphs may produce poor results.""",
        label="❗ Model Limitations",
        lines=4,
        interactive=False,
    )

    btn_synthesize.click(infer_tts, inputs=[ref_audio, ref_text, gen_text, speed], outputs=[output_audio, output_spectrogram])

# Run Gradio with share=True to get a gradio.live link
# demo.queue().launch()
if __name__ == "__main__":
    demo.queue().launch()