# -*- coding: utf-8 -*-
"""
Gradio app.py - wired to your 'inference_one' implementation
- Reference voice: upload OR choose from ref_voice/
- Uses phonemize_text(), compute_style(), inference_one() exactly like your snippet
- NOW: adds UI numeric inputs for alpha and beta and threads them into inference
"""
import os
import time
import yaml
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import librosa
import gradio as gr
from munch import Munch
# -----------------------------
# Reproducibility
# -----------------------------
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)

# -----------------------------
# Device / sample-rate
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR_OUT = 24000  # target audio rate for synthesis

# -----------------------------
# External modules from the project
# -----------------------------
from models import *  # noqa: F401,F403
from utils import *  # noqa: F401,F403
from models import build_model
from text_utils import TextCleaner
from Utils_extend_v1.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

textcleaner = TextCleaner()
# -----------------------------
# Config / model loading
# -----------------------------
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="ltphuongunited/styletts2_vi",
    filename="epoch_2nd_00058.pth",
    local_dir="Models/multi_phoaudio_gemini",
    local_dir_use_symlinks=False,
)

# CONFIG_PATH = os.getenv("MODEL_CONFIG", "Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml")
# CHECKPOINT_PTH = os.getenv("MODEL_CKPT", "Models/multi_phoaudio_gemini/epoch_2nd_00058.pth")
CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"
# Load config
config = yaml.safe_load(open(CONFIG_PATH))

# Build components
ASR_config = config.get("ASR_config", False)
ASR_path = config.get("ASR_path", False)
F0_path = config.get("F0_path", False)
PLBERT_dir = config.get("PLBERT_dir", False)

text_aligner = load_ASR_models(ASR_path, ASR_config)
pitch_extractor = load_F0_models(F0_path)
plbert = load_plbert(PLBERT_dir)

model_params = recursive_munch(config["model_params"])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)

# to device & eval
_ = [model[k].to(DEVICE) for k in model]
_ = [model[k].eval() for k in model]

# Load checkpoint
if not os.path.isfile(CHECKPOINT_PTH):
    raise FileNotFoundError(f"Checkpoint not found at '{CHECKPOINT_PTH}'")
ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")
params = ckpt["net"]

for key in model:
    if key in params:
        try:
            model[key].load_state_dict(params[key])
        except Exception:
            from collections import OrderedDict
            state_dict = params[key]
            new_state = OrderedDict()
            for k, v in state_dict.items():
                # strip the 'module.' prefix left over from DataParallel, if present
                name = k[7:] if k.startswith("module.") else k
                new_state[name] = v
            model[key].load_state_dict(new_state, strict=False)
_ = [model[k].eval() for k in model]

# Diffusion sampler
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
    clamp=False,
)
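# Note: the ADPM2 sampler with a Karras sigma schedule (sigma_min=1e-4, sigma_max=3.0,
# rho=9.0) mirrors the settings used in the reference StyleTTS2 inference scripts; the
# number of denoising steps is supplied per call via `num_steps` inside inference_one().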
# -----------------------------
# Audio helper: mel preprocessing
# -----------------------------
_to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
_MEAN, _STD = -4.0, 4.0

def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
    """Boolean mask that is True at padding positions (index >= length)."""
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask

def preprocess(wave: np.ndarray) -> torch.Tensor:
    """Same name as your snippet: float waveform -> normalized log-mel of shape [1, 80, frames]."""
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = _to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - _MEAN) / _STD
    return mel_tensor
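# Minimal sanity-check sketch (not called anywhere in the app): it assumes a 24 kHz mono
# float32 waveform and only verifies the shape convention the rest of the code relies on,
# i.e. preprocess() yields a normalized log-mel of shape [1, 80, n_frames] with
# hop_length=300 (~80 frames per second of audio).
def _example_check_preprocess():
    one_second = np.random.randn(SR_OUT).astype(np.float32)
    mel = preprocess(one_second)
    assert mel.dim() == 3 and mel.shape[0] == 1 and mel.shape[1] == 80
    return tuple(mel.shape)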
# -----------------------------
# Phonemizer (vi)
# -----------------------------
import phonemizer

vi_phonemizer = phonemizer.backend.EspeakBackend(language="vi", preserve_punctuation=True, with_stress=True)
global_phonemizer = vi_phonemizer

def phonemize_text(text: str) -> str:
    ps = global_phonemizer.phonemize([text])[0]
    return ps.replace("(en)", "").replace("(vi)", "").strip()
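# Usage note: phonemize_text("xin chào") returns whatever IPA string the installed
# espeak-ng Vietnamese voice produces (the exact output is environment-dependent), with
# espeak's "(en)"/"(vi)" language-switch tags stripped out before tokenization.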
# -----------------------------
# Style extractor (from file path)
# -----------------------------
def compute_style(model, path, device):
    """Compute the style/prosody reference feature from a wav file path."""
    wave, sr = librosa.load(path, sr=None, mono=True)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    if sr != SR_OUT:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR_OUT)
    mel_tensor = preprocess(audio).to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
    return torch.cat([ref_s, ref_p], dim=1)  # [1, 256]

# Style extractor (from numpy array)
def compute_style_from_numpy(model, arr: np.ndarray, sr: int, device):
    if arr.ndim > 1:
        arr = librosa.to_mono(arr.T)
    audio, _ = librosa.effects.trim(arr, top_db=30)
    if sr != SR_OUT:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR_OUT)
    mel_tensor = preprocess(audio).to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
    return torch.cat([ref_s, ref_p], dim=1)
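# Illustrative helper (a sketch, not used by the app): the 256-dim reference feature is
# the concatenation of a 128-dim timbre/acoustic embedding from style_encoder and a
# 128-dim prosody embedding from predictor_encoder; inference_one() slices it back apart
# with ref_feat[:, :128] and ref_feat[:, 128:].
def _split_ref_feat(ref_feat: torch.Tensor):
    return ref_feat[:, :128], ref_feat[:, 128:]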
# -----------------------------
# Inference (your exact logic)
# -----------------------------
# Tunables (still as defaults; UI will override)
ALPHA = 0.3
BETA = 0.7
DIFFUSION_STEPS = 5
EMBEDDING_SCALE = 1.0
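# How the blend weights are used below: with the diffusion branch active, the sampled
# style is mixed with the reference as ref = alpha * ref_pred + (1 - alpha) * ref_timbre
# and s = beta * s_pred + (1 - beta) * ref_prosody, so alpha/beta = 0 keeps the reference
# voice untouched while alpha/beta = 1 uses only the diffusion-predicted style; when both
# are exactly 0 the diffusion sampler is skipped entirely.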
def inference_one(text, ref_feat, ipa_text=None,
                  alpha=ALPHA, beta=BETA, diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
    # text -> phonemes -> tokens
    ps = ipa_text if ipa_text is not None else phonemize_text(text)
    tokens = textcleaner(ps)
    tokens.insert(0, 0)  # prepend BOS
    tokens = torch.LongTensor(tokens).to(DEVICE).unsqueeze(0)  # [1, T]

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
        text_mask = length_to_mask(input_lengths).to(DEVICE)

        # encoders
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_d).transpose(-1, -2)

        if alpha == 0 and beta == 0:
            print("Ignore Diffusion")
            ref = ref_feat[:, :128]
            s = ref_feat[:, 128:]
            simi_timbre, simi_prosody = 1, 1
        else:
            print("Have Diffusion")
            # diffusion for style latent
            s_pred = sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
                embedding=bert_d,
                embedding_scale=embedding_scale,
                features=ref_feat,  # [1, 256]
                num_steps=diffusion_steps,
            ).squeeze(1)  # [1, 256]
            s = s_pred[:, 128:]    # prosody
            ref = s_pred[:, :128]  # timbre

            # blend with real ref features
            ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
            s = beta * s + (1 - beta) * ref_feat[:, 128:]

            with torch.no_grad():
                ref0 = ref_feat[:, :128]  # original timbre
                s0 = ref_feat[:, 128:]    # original prosody
                eps = 1e-8

                def stats(name, new, base):
                    delta = new - base
                    l2_delta = torch.norm(delta, dim=1)         # ||Δ||
                    l2_base = torch.norm(base, dim=1) + eps     # ||x||
                    rel_l2 = l2_delta / l2_base                 # ||Δ|| / ||x||
                    mae = torch.mean(torch.abs(delta), dim=1)   # MAE
                    cos_sim = F.cosine_similarity(new, base, dim=1)          # cos(new, base)
                    snr_db = 20.0 * torch.log10(l2_base / (l2_delta + eps))  # SNR ~ 20*log10(||x||/||Δ||)
                    # # Inference batch size is usually 1, but print per sample to stay general
                    # for i in range(new.shape[0]):
                    #     print(f"[{name}][sample {i}] "
                    #           f"L2Δ={l2_delta[i]:.4f} | relL2={rel_l2[i]:.4f} | MAE={mae[i]:.6f} | "
                    #           f"cos={cos_sim[i]:.4f} | SNR={snr_db[i]:.2f} dB")
                    return cos_sim

                simi_timbre = stats("REF(timbre)", s_pred[:, :128], ref_feat[:, :128]).detach().cpu().squeeze().item()
                simi_prosody = stats("S(prosody)", s_pred[:, 128:], ref_feat[:, 128:]).detach().cpu().squeeze().item()

        # duration prediction
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # alignment
        T = int(pred_dur.sum().item())
        pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
        c = 0
        for i in range(input_lengths.item()):
            span = int(pred_dur[i].item())
            pred_aln[i, c:c + span] = 1.0
            c += span

        # prosody enc
        en = d.transpose(-1, -2) @ pred_aln.unsqueeze(0)
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        # content (ASR-aligned)
        asr = t_en @ pred_aln.unsqueeze(0)
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new

        # decode
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    wav = out.squeeze().detach().cpu().numpy()
    if wav.shape[-1] > 50:
        wav = wav[..., :-50]
    return wav, ps, simi_timbre, simi_prosody
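# Minimal offline usage sketch (assumptions: a valid reference wav path and the
# 'soundfile' package being installed; the default output path is only an example).
# It is not wired into the Gradio UI and is never called automatically.
def synthesize_to_file(text, ref_wav_path, out_path="out.wav", alpha=ALPHA, beta=BETA):
    import soundfile as sf
    ref_feat = compute_style(model, ref_wav_path, DEVICE)
    wav, _, _, _ = inference_one(text, ref_feat, alpha=alpha, beta=beta)
    sf.write(out_path, wav.astype(np.float32), SR_OUT)
    return out_path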
# -----------------------------
# Gradio UI
# -----------------------------
ROOT_REF = "ref_voice"
EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}

# -------- scan ref_voice/<id>_<speaker>/*.wav --------
def scan_ref_voice(root=ROOT_REF):
    """
    Returns:
        speakers: list[str]                 # e.g. ["0_Fonos.vn", "1_James_A._Robinson", ...]
        files_by_spk: dict[str, list[str]]  # speaker_dir -> [full_path, ...]
    """
    speakers, files_by_spk = [], {}
    if not os.path.isdir(root):
        return speakers, files_by_spk
    for spk_dir in sorted(os.listdir(root)):
        full_dir = os.path.join(root, spk_dir)
        if not os.path.isdir(full_dir) or spk_dir.startswith("."):
            continue
        lst = []
        for fn in sorted(os.listdir(full_dir)):
            if os.path.splitext(fn)[1].lower() in EXTS:
                lst.append(os.path.join(full_dir, fn))
        if lst:
            speakers.append(spk_dir)
            files_by_spk[spk_dir] = lst
    return speakers, files_by_spk

SPEAKERS, FILES_BY_SPK = scan_ref_voice()
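# Expected layout (example, matching the docstring above; filenames are illustrative):
#   ref_voice/
#     0_Fonos.vn/
#       sample_01.wav
#     1_James_A._Robinson/
#       sample_01.wav
# scan_ref_voice() then yields SPEAKERS = ["0_Fonos.vn", "1_James_A._Robinson"] and
# FILES_BY_SPK mapping each speaker directory to its full audio paths; if ref_voice/
# does not exist, both stay empty and only upload/microphone references are available.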
with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo:
    gr.Markdown("# StyleTTS2-vi Demo ✨")

    with gr.Row():
        with gr.Column():
            text_inp = gr.Textbox(
                label="Text", lines=4,
                value="Thời tiết hôm nay tại Hà Nội, nhiệt độ khoảng 27 độ C, có nắng nhẹ, rất hợp lý để mình đi dạo công viên nhé.",
            )

            # --- a single audio box (takes a filepath) ---
            ref_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",                   # receives a file path
                sources=["upload", "microphone"],  # still allows upload/mic
                interactive=True,
            )
            ref_path = gr.Textbox(label="Reference path", interactive=False)

            # --- pick a speaker -> show its files ---
            spk_dd = gr.Dropdown(
                label="Speaker",
                choices=["(None)"] + SPEAKERS,
                value="(None)",
            )
            file_dd = gr.Dropdown(
                label="Voice in speaker",
                choices=["(None)"],
                value="(None)",
            )

            # when a speaker is picked -> refresh the file list
            def on_pick_speaker(spk):
                if spk == "(None)":
                    return gr.update(choices=["(None)"], value="(None)")
                files = FILES_BY_SPK.get(spk, [])
                # show only the basenames to keep the dropdown compact
                labels = [os.path.basename(p) for p in files]
                # labels are mapped back to paths by basename; preselect the first entry
                return gr.update(choices=labels, value=(labels[0] if labels else "(None)"))

            spk_dd.change(on_pick_speaker, inputs=spk_dd, outputs=file_dd)

            # map a label (basename) -> full path for the currently selected speaker
            def on_pick_file(spk, label):
                if spk == "(None)" or label == "(None)":
                    return gr.update(value=None), ""
                files = FILES_BY_SPK.get(spk, [])
                # find the file matching the basename
                for p in files:
                    if os.path.basename(p) == label:
                        return gr.update(value=p), p  # set the Audio widget + show the path
                return gr.update(value=None), ""

            file_dd.change(on_pick_file, inputs=[spk_dd, file_dd], outputs=[ref_audio, ref_path])

            # if the user uploads or records, show the temporary file path right away
            def on_audio_changed(fp):
                return fp or ""

            ref_audio.change(on_audio_changed, inputs=ref_audio, outputs=ref_path)

            # --- NEW: alpha/beta numeric inputs ---
            with gr.Row():
                alpha_n = gr.Number(value=ALPHA, label="alpha (0-1, timbre)", precision=3)
                beta_n = gr.Number(value=BETA, label="beta (0-1, prosody)", precision=3)

            btn = gr.Button("Đọc 🔊🔥", variant="primary")

        with gr.Column():
            out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
            metrics = gr.JSON(label="Metrics")
    # ---- Inference: run from a filepath ----
    def _run(text, ref_fp, alpha, beta):
        # ref_fp is a string path (because type='filepath')
        if isinstance(ref_fp, str) and os.path.isfile(ref_fp):
            wav, _ = librosa.load(ref_fp, sr=SR_OUT, mono=True)
            ref_feat = compute_style_from_numpy(model, wav, SR_OUT, DEVICE)
            ref_src = ref_fp
        else:
            ref_feat = torch.zeros(1, 256).to(DEVICE)
            ref_src = "(None)"

        t0 = time.time()
        wav, ps, simi_timbre, simi_prosody = inference_one(text, ref_feat, alpha=float(alpha), beta=float(beta))
        wav = wav.astype(np.float32)
        gen_time = time.time() - t0
        # real-time factor: generation time divided by the duration of the synthesized audio
        rtf = gen_time / max(1e-6, len(wav) / SR_OUT)

        info = {
            "simi_timbre": round(float(simi_timbre), 4),
            "simi_prosody": round(float(simi_prosody), 4),
            "Phonemes": ps,
            "Sample rate": SR_OUT,
            "RTF": round(float(rtf), 3),
            "Device": DEVICE,
        }
        return (SR_OUT, wav), info

    btn.click(_run, inputs=[text_inp, ref_audio, alpha_n, beta_n], outputs=[out_audio, metrics])

if __name__ == "__main__":
    demo.launch()