import spaces
import os
import codecs
from huggingface_hub import login
import gradio as gr
from cached_path import cached_path
import tempfile
from vinorm import TTSnorm
from importlib.resources import files
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
    target_sample_rate as default_target_sample_rate,
    n_mel_channels as default_n_mel_channels,
    hop_length as default_hop_length,
    win_length as default_win_length,
    n_fft as default_n_fft,
    mel_spec_type as default_mel_spec_type,
    target_rms as default_target_rms,
    cross_fade_duration as default_cross_fade_duration,
    ode_method as default_ode_method,
    nfe_step as default_nfe_step,  # 16, 32
    cfg_strength as default_cfg_strength,
    sway_sampling_coef as default_sway_sampling_coef,
    speed as default_speed,
    fix_duration as default_fix_duration,
)
from pathlib import Path
from omegaconf import OmegaConf
from datetime import datetime
import hashlib
import unicodedata
import soundfile as sf

# Retrieve token from secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Log in to Hugging Face
if hf_token:
    login(token=hf_token)
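
# Presumably, authenticating here lets cached_path / huggingface_hub download
# gated or private checkpoints under this account's access rights; public
# checkpoints work without it.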

# Build the cache file path from the text, reference audio, and model name
def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
    os.makedirs(cache_dir, exist_ok=True)
    hash_input = f"{text}|{ref_audio_path}|{model}"
    hash_val = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, f"{hash_val}.wav")
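
# For illustration: the cache key is the SHA-256 of "text|ref_audio|model", so a
# hypothetical call like
#   get_audio_cache_path("xin chào", "ref.wav", "F5TTS_Base")
# always resolves to the same "tts_cache/<sha256-hex>.wav" path for identical inputs.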

def post_process(text):
    text = " " + text + " "
    text = text.replace(" . . ", " . ")
    text = " " + text + " "
    text = text.replace(" .. ", " . ")
    text = " " + text + " "
    text = text.replace(" , , ", " , ")
    text = " " + text + " "
    text = text.replace(" ,, ", " , ")
    text = " " + text + " "
    text = text.replace('"', "")
    return " ".join(text.split())

# Run TTS inference for the given reference audio/transcript and generation text
def infer_tts(ref_audio_orig: str, ref_text_input: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):
    args = {
        "model": "F5TTS_Base",
        "ckpt_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
        "vocab_file": str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
        "ref_audio": ref_audio_orig,
        "ref_text": ref_text_input,
        "gen_text": gen_text,
        "speed": speed,
    }
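
    # The args dict mirrors the f5_tts command-line options, so the fallback
    # chain below (args value, else config value, else library default) can be
    # reused verbatim from infer_cli-style code.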
    config = {}  # tomli.load(open(args.config, "rb"))

    # command-line interface parameters
    model = args["model"] or config.get("model", "F5TTS_Base")
    ckpt_file = args["ckpt_file"] or config.get("ckpt_file", "")
    vocab_file = args["vocab_file"] or config.get("vocab_file", "")
    ref_audio = args["ref_audio"] or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
    ref_text = args["ref_text"] if args["ref_text"] is not None else config.get("ref_text", "Some call me nature, others call me mother nature.")
    gen_text = args["gen_text"] or config.get("gen_text", "Here we generate something just for test.")
    gen_file = args.get("gen_file", "") or config.get("gen_file", "")
    output_dir = args.get("output_dir", "") or config.get("output_dir", "tests")
    output_file = args.get("output_file", "") or config.get("output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav")
    save_chunk = args.get("save_chunk", False) or config.get("save_chunk", False)
    remove_silence = args.get("remove_silence", False) or config.get("remove_silence", False)
    load_vocoder_from_local = args.get("load_vocoder_from_local", False) or config.get("load_vocoder_from_local", False)
    vocoder_name = args.get("vocoder_name", "") or config.get("vocoder_name", default_mel_spec_type)
    target_rms = args.get("target_rms", None) or config.get("target_rms", default_target_rms)
    cross_fade_duration = args.get("cross_fade_duration", None) or config.get("cross_fade_duration", default_cross_fade_duration)
    nfe_step = args.get("nfe_step", None) or config.get("nfe_step", default_nfe_step)
    cfg_strength = args.get("cfg_strength", None) or config.get("cfg_strength", default_cfg_strength)
    sway_sampling_coef = args.get("sway_sampling_coef", None) or config.get("sway_sampling_coef", default_sway_sampling_coef)
    speed = args.get("speed", None) or config.get("speed", default_speed)
    fix_duration = args.get("fix_duration", None) or config.get("fix_duration", default_fix_duration)
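
    # Caveat: the `or` fallback treats any falsy value (0, 0.0, False, "") as
    # "unset", so e.g. an explicit speed of 0 would silently fall through to the
    # config/default value. Harmless here, since the UI slider bounds speed to
    # [0.3, 2.0].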
| if "infer/examples/" in ref_audio: | |
| ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}")) | |
| if "infer/examples/" in gen_file: | |
| gen_file = str(files("f5_tts").joinpath(f"{gen_file}")) | |
| if "voices" in config: | |
| for voice in config["voices"]: | |
| voice_ref_audio = config["voices"][voice]["ref_audio"] | |
| if "infer/examples/" in voice_ref_audio: | |
| config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}")) | |
| # ignore gen_text if gen_file provided | |
| if gen_file: | |
| gen_text = codecs.open(gen_file, "r", "utf-8").read() | |
| # output path | |
| wave_path = Path(output_dir) / output_file | |
| # spectrogram_path = Path(output_dir) / "infer_cli_out.png" | |
| if save_chunk: | |
| output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks") | |
| if not os.path.exists(output_chunk_dir): | |
| os.makedirs(output_chunk_dir) | |

    # load vocoder
    if vocoder_name == "vocos":
        vocoder_local_path = "../checkpoints/vocos-mel-24khz"
    elif vocoder_name == "bigvgan":
        vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
    else:
        # Guard against vocoder_local_path being unbound for unexpected names
        raise ValueError(f"Unsupported vocoder: {vocoder_name}")
    vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
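
    # Note: the local checkpoint paths above are only consulted when
    # load_vocoder_from_local is True; otherwise load_vocoder fetches the
    # vocoder from the Hugging Face Hub.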

    # load TTS model
    model_cfg = OmegaConf.load(
        config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
    ).model
    model_cls = globals()[model_cfg.backbone]

    repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

    if model != "F5TTS_Base":
        assert vocoder_name == model_cfg.mel_spec.mel_spec_type

    # override for previous models
    if model == "F5TTS_Base":
        if vocoder_name == "vocos":
            ckpt_step = 1200000
        elif vocoder_name == "bigvgan":
            model = "F5TTS_Base_bigvgan"
            ckpt_type = "pt"
    elif model == "E2TTS_Base":
        repo_name = "E2-TTS"
        ckpt_step = 1200000

    if not ckpt_file:
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
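
    # With the Vietnamese checkpoint hard-coded in `args`, ckpt_file is always
    # set, so this SWivid/F5-TTS fallback download never actually triggers here.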
| print(f"Using {model}...") | |
| ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) | |
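
    # Performance note: the vocoder and ema_model are rebuilt on every request.
    # A likely optimization (untested here) would be loading them once at module
    # scope and reusing them across calls.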

    if not ref_audio_orig:
        raise gr.Error("Please upload a sample audio file.")
    if not gen_text.strip():
        raise gr.Error("Please enter the text to synthesize.")
    if len(gen_text.split()) > 1000:
        raise gr.Error("Please keep the text to at most 1000 words.")

    try:
        # If the user supplied ref_text, use it; otherwise pass an empty string
        # so the reference transcript is recognized automatically
        ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text_input or "")
        ref_text = unicodedata.normalize("NFC", ref_text.strip())
        gen_text_ = unicodedata.normalize("NFC", gen_text.strip())

        # --- BEGIN: caching logic ---
        cache_path = get_audio_cache_path(gen_text_, ref_audio_orig, model)
        if os.path.exists(cache_path):
            print(f"Using cached audio: {cache_path}")
            final_wave, final_sample_rate = sf.read(cache_path)
            spectrogram = None
        else:
            final_wave, final_sample_rate, spectrogram = infer_process(
                ref_audio, ref_text, gen_text_, ema_model, vocoder, speed=speed
            )
            sf.write(cache_path, final_wave, final_sample_rate)
            print(f"[CACHE] Saved new audio to: {cache_path}")
        # --- END: caching logic ---
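
        # The cache key covers (gen_text, ref_audio path, model), so repeated
        # requests with identical inputs skip inference entirely. Note that
        # `speed` is not part of the key, so changing only the speed still
        # returns the cached audio.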

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
            spectrogram_path = tmp_spectrogram.name
        if spectrogram is not None:
            save_spectrogram(spectrogram, spectrogram_path)
        else:
            # Cache hit: no spectrogram was computed, so return no image
            # rather than an empty temp file
            spectrogram_path = None

        return (final_sample_rate, final_wave), spectrogram_path
    except Exception as e:
        raise gr.Error(f"Error generating voice: {e}")

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎤 F5-TTS: Vietnamese Text-to-Speech Synthesis
    The model was trained on approximately 1000 hours of data with an RTX 3090 GPU.

    Enter text and upload a sample voice to generate natural speech.
    """)
    with gr.Row():
        ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
        ref_text = gr.Textbox(label="📝 Reference Transcript (optional)", placeholder="Enter the Vietnamese transcript for the sample voice, if available...", lines=2)
    gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)
    speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
    btn_synthesize = gr.Button("🔥 Generate Voice")

    with gr.Row():
        output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
        output_spectrogram = gr.Image(label="📊 Spectrogram")
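
    # gr.Audio(type="numpy") expects a (sample_rate, waveform_array) tuple,
    # which matches the first value returned by infer_tts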

    model_limitations = gr.Textbox(
        value="""1. This model may not perform well with numbers, dates, special characters, etc. => A text normalization module is needed.
2. The rhythm of some generated audio may be inconsistent or choppy => Choose clearly pronounced sample audio with minimal pauses for better synthesis quality.
3. By default, the reference transcript is produced with the pho-whisper-medium model, which may not always recognize Vietnamese accurately, leading to poor synthesis quality.
4. Inference on overly long paragraphs may produce poor results.""",
        label="❗ Model Limitations",
        lines=4,
        interactive=False
    )

    btn_synthesize.click(infer_tts, inputs=[ref_audio, ref_text, gen_text, speed], outputs=[output_audio, output_spectrogram])

# Run Gradio with share=True to get a gradio.live link
# demo.queue().launch()
if __name__ == "__main__":
    demo.queue().launch()