# -*- coding: utf-8 -*- """ Gradio app.py - wired to your 'inference_one' implementation - Reference voice: upload OR choose from train_ref/ - Uses phonemize_text(), compute_style(), inference_one() exactly like your snippet - NOW: adds UI sliders for alpha and beta and threads them into inference """ import os import time import yaml import numpy as np import torch import torchaudio import librosa import gradio as gr from munch import Munch # ----------------------------- # Reproducibility # ----------------------------- torch.manual_seed(0) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True np.random.seed(0) # ----------------------------- # Device / sample-rate # ----------------------------- DEVICE = "cuda" if torch.cuda.is_available() else "cpu" SR_OUT = 24000 # target audio rate for synthesis # ----------------------------- # External modules from the project # ----------------------------- from models import * # noqa: F401,F403 from utils import * # noqa: F401,F403 from models import build_model from text_utils import TextCleaner from Utils_extend_v1.PLBERT.util import load_plbert from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule textcleaner = TextCleaner() # ----------------------------- # Config / model loading # ----------------------------- from huggingface_hub import hf_hub_download hf_hub_download( repo_id="ltphuongunited/styletts2_vi", filename="epoch_2nd_00058.pth", local_dir="Models/multi_phoaudio_gemini", local_dir_use_symlinks=False, ) # CONFIG_PATH = os.getenv("MODEL_CONFIG", "Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml") # CHECKPOINT_PTH = os.getenv("MODEL_CKPT", "Models/multi_phoaudio_gemini/epoch_2nd_00058.pth") CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth" CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml" # Load config config = yaml.safe_load(open(CONFIG_PATH)) # Build components ASR_config = config.get("ASR_config", False) ASR_path = config.get("ASR_path", False) F0_path = config.get("F0_path", False) PLBERT_dir = config.get("PLBERT_dir", False) text_aligner = load_ASR_models(ASR_path, ASR_config) pitch_extractor = load_F0_models(F0_path) plbert = load_plbert(PLBERT_dir) model_params = recursive_munch(config["model_params"]) model = build_model(model_params, text_aligner, pitch_extractor, plbert) # to device & eval _ = [model[k].to(DEVICE) for k in model] _ = [model[k].eval() for k in model] # Load checkpoint if not os.path.isfile(CHECKPOINT_PTH): raise FileNotFoundError(f"Checkpoint not found at '{CHECKPOINT_PTH}'") ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu") params = ckpt["net"] for key in model: if key in params: try: model[key].load_state_dict(params[key]) except Exception: from collections import OrderedDict state_dict = params[key] new_state = OrderedDict() for k, v in state_dict.items(): name = k[7:] # strip 'module.' if present new_state[name] = v model[key].load_state_dict(new_state, strict=False) _ = [model[k].eval() for k in model] # Diffusion sampler sampler = DiffusionSampler( model.diffusion.diffusion, sampler=ADPM2Sampler(), sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0), clamp=False, ) # ----------------------------- # Audio helper: mel preprocessing # ----------------------------- _to_mel = torchaudio.transforms.MelSpectrogram( n_mels=80, n_fft=2048, win_length=1200, hop_length=300 ) _MEAN, _STD = -4.0, 4.0 def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor: mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) mask = torch.gt(mask + 1, lengths.unsqueeze(1)) return mask def preprocess(wave: np.ndarray) -> torch.Tensor: """Same name as your snippet: np.float -> mel (normed)""" wave_tensor = torch.from_numpy(wave).float() mel_tensor = _to_mel(wave_tensor) mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - _MEAN) / _STD return mel_tensor # ----------------------------- # Phonemizer (vi) # ----------------------------- import phonemizer vi_phonemizer = phonemizer.backend.EspeakBackend(language="vi", preserve_punctuation=True, with_stress=True) global_phonemizer = vi_phonemizer def phonemize_text(text: str) -> str: ps = global_phonemizer.phonemize([text])[0] return ps.replace("(en)", "").replace("(vi)", "").strip() # ----------------------------- # Style extractor (from file path) # ----------------------------- def compute_style(model, path, device): """Compute style/prosody reference from a wav file path""" wave, sr = librosa.load(path, sr=None, mono=True) audio, _ = librosa.effects.trim(wave, top_db=30) if sr != SR_OUT: audio = librosa.resample(audio, sr, SR_OUT) mel_tensor = preprocess(audio).to(device) with torch.no_grad(): ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) return torch.cat([ref_s, ref_p], dim=1) # [1, 256] # Style extractor (from numpy array) def compute_style_from_numpy(model, arr: np.ndarray, sr: int, device): if arr.ndim > 1: arr = librosa.to_mono(arr.T) audio, _ = librosa.effects.trim(arr, top_db=30) if sr != SR_OUT: audio = librosa.resample(audio, sr, SR_OUT) mel_tensor = preprocess(audio).to(device) with torch.no_grad(): ref_s = model.style_encoder(mel_tensor.unsqueeze(1)) ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1)) return torch.cat([ref_s, ref_p], dim=1) # ----------------------------- # Inference (your exact logic) # ----------------------------- # Tunables (still as defaults; UI will override) ALPHA = 0.3 BETA = 0.7 DIFFUSION_STEPS = 5 EMBEDDING_SCALE = 1.0 def inference_one(text, ref_feat, ipa_text=None, alpha=ALPHA, beta=BETA, diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE): # text -> phonemes -> tokens ps = ipa_text if ipa_text is not None else phonemize_text(text) tokens = textcleaner(ps) tokens.insert(0, 0) # prepend BOS tokens = torch.LongTensor(tokens).to(DEVICE).unsqueeze(0) # [1, T] with torch.no_grad(): input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE) text_mask = length_to_mask(input_lengths).to(DEVICE) # encoders t_en = model.text_encoder(tokens, input_lengths, text_mask) bert_d = model.bert(tokens, attention_mask=(~text_mask).int()) d_en = model.bert_encoder(bert_d).transpose(-1, -2) if alpha == 0 and beta == 0: print("Ignore Diffusion") ref = ref_feat[:, :128] s = ref_feat[:, 128:] simi_timbre, simi_prosody = 1,1 else: print("Have Diffusion") # diffusion for style latent s_pred = sampler( noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE), embedding=bert_d, embedding_scale=embedding_scale, features=ref_feat, # [1, 256] num_steps=diffusion_steps, ).squeeze(1) # [1, 256] s = s_pred[:, 128:] # prosody ref = s_pred[:, :128] # timbre # blend with real ref features ref = alpha * ref + (1 - alpha) * ref_feat[:, :128] s = beta * s + (1 - beta) * ref_feat[:, 128:] with torch.no_grad(): ref0 = ref_feat[:, :128] # timbre gốc s0 = ref_feat[:, 128:] # prosody gốc eps = 1e-8 def stats(name, new, base): delta = new - base l2_delta = torch.norm(delta, dim=1) # ||Δ|| l2_base = torch.norm(base, dim=1) + eps # ||x|| rel_l2 = (l2_delta / l2_base) # ||Δ|| / ||x|| mae = torch.mean(torch.abs(delta), dim=1) # MAE cos_sim = F.cosine_similarity(new, base, dim=1) # cos(new, base) snr_db = 20.0 * torch.log10(l2_base / (l2_delta + eps)) # SNR ~ 20*log10(||x||/||Δ||) # # Inference batch thường =1, nhưng vẫn in theo batch để tổng quát # for i in range(new.shape[0]): # print(f"[{name}][sample {i}] " # f"L2Δ={l2_delta[i]:.4f} | relL2={rel_l2[i]:.4f} | MAE={mae[i]:.6f} | " # f"cos={cos_sim[i]:.4f} | SNR={snr_db[i]:.2f} dB") return cos_sim simi_timbre = stats("REF(timbre)", s_pred[:, :128], ref_feat[:, :128]).detach().cpu().squeeze().item() simi_prosody = stats("S(prosody)", s_pred[:, 128:], ref_feat[:, 128:]).detach().cpu().squeeze().item() # duration prediction d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask) x, _ = model.predictor.lstm(d) duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1) pred_dur = torch.round(duration.squeeze()).clamp(min=1) # alignment T = int(pred_dur.sum().item()) pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE) c = 0 for i in range(input_lengths.item()): span = int(pred_dur[i].item()) pred_aln[i, c:c+span] = 1.0 c += span # prosody enc en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0)) if model_params.decoder.type == "hifigan": asr_new = torch.zeros_like(en); asr_new[:, :, 0] = en[:, :, 0]; asr_new[:, :, 1:] = en[:, :, 0:-1]; en = asr_new F0_pred, N_pred = model.predictor.F0Ntrain(en, s) # content (ASR-aligned) asr = (t_en @ pred_aln.unsqueeze(0)) if model_params.decoder.type == "hifigan": asr_new = torch.zeros_like(asr); asr_new[:, :, 0] = asr[:, :, 0]; asr_new[:, :, 1:] = asr[:, :, 0:-1]; asr = asr_new # decode out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0)) wav = out.squeeze().detach().cpu().numpy() if wav.shape[-1] > 50: wav = wav[..., :-50] return wav, ps, simi_timbre, simi_prosody # ----------------------------- # Gradio UI # ----------------------------- SR_OUT = 24000 ROOT_REF = "ref_voice" EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"} # -------- scan ref_voice/_/*.wav -------- def scan_ref_voice(root=ROOT_REF): """ return: speakers: list[str] # ví dụ: ["0_Fonos.vn", "1_James_A._Robinson", ...] files_by_spk: dict[str, list[str]] # speaker_dir -> [full_path,...] """ speakers, files_by_spk = [], {} if not os.path.isdir(root): return speakers, files_by_spk for spk_dir in sorted(os.listdir(root)): full_dir = os.path.join(root, spk_dir) if not os.path.isdir(full_dir) or spk_dir.startswith("."): continue lst = [] for fn in sorted(os.listdir(full_dir)): if os.path.splitext(fn)[1].lower() in EXTS: lst.append(os.path.join(full_dir, fn)) if lst: speakers.append(spk_dir) files_by_spk[spk_dir] = lst return speakers, files_by_spk SPEAKERS, FILES_BY_SPK = scan_ref_voice() with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo: gr.Markdown("# StyleTTS2-vi Demo ✨") with gr.Row(): with gr.Column(): text_inp = gr.Textbox(label="Text", lines=4, value="Thời tiết hôm nay tại Hà Nội, nhiệt độ khoảng 27 độ C, có nắng nhẹ, rất hợp lý để mình đi dạo công viên nhé.") # --- 1 ô audio duy nhất (nhận filepath) --- ref_audio = gr.Audio( label="Reference Audio", type="filepath", # nhận đường dẫn file sources=["upload","microphone"], # vẫn cho upload/mic interactive=True, ) ref_path = gr.Textbox(label="Đường dẫn reference", interactive=False) # --- chọn speaker -> hiện file tương ứng --- spk_dd = gr.Dropdown( label="Speaker", choices=["(None)"] + SPEAKERS, value="(None)", ) file_dd = gr.Dropdown( label="Voice in speaker", choices=["(None)"], value="(None)", ) # khi chọn speaker -> cập nhật danh sách file def on_pick_speaker(spk): if spk == "(None)": return gr.update(choices=["(None)"], value="(None)") files = FILES_BY_SPK.get(spk, []) # hiển thị chỉ tên file cho gọn labels = [os.path.basename(p) for p in files] # ta sẽ map label->path bằng index; set value = mục đầu tiên return gr.update(choices=labels, value=(labels[0] if labels else "(None)")) spk_dd.change(on_pick_speaker, inputs=spk_dd, outputs=file_dd) # map label (basename) -> full path theo speaker hiện tại def on_pick_file(spk, label): if spk == "(None)" or label == "(None)": return gr.update(value=None), "" files = FILES_BY_SPK.get(spk, []) # tìm đúng file theo basename for p in files: if os.path.basename(p) == label: return gr.update(value=p), p # set vào Audio + hiển thị path return gr.update(value=None), "" file_dd.change(on_pick_file, inputs=[spk_dd, file_dd], outputs=[ref_audio, ref_path]) # nếu người dùng upload/mic thì hiển thị luôn đường dẫn file tạm def on_audio_changed(fp): return fp or "" ref_audio.change(on_audio_changed, inputs=ref_audio, outputs=ref_path) # --- NEW: alpha/beta numeric inputs --- with gr.Row(): alpha_n = gr.Number(value=ALPHA, label="alpha (0-1, timbre)", precision=3) beta_n = gr.Number(value=BETA, label="beta (0-1, prosody)", precision=3) btn = gr.Button("Đọc 🔊🔥", variant="primary") with gr.Column(): out_audio = gr.Audio(label="Synthesised Audio", type="numpy") metrics = gr.JSON(label="Metrics") # ---- Inference: xử lý từ filepath ---- def _run(text, ref_fp, alpha, beta): # ref_fp là string path (do type='filepath') if isinstance(ref_fp, str) and os.path.isfile(ref_fp): wav, _ = librosa.load(ref_fp, sr=SR_OUT, mono=True) ref_feat = compute_style_from_numpy(model, wav, SR_OUT, DEVICE) ref_src = ref_fp else: ref_feat = torch.zeros(1, 256).to(DEVICE) ref_src = "(None)" t0 = time.time() wav, ps, simi_timbre, simi_prosody = inference_one(text, ref_feat, alpha=float(alpha), beta=float(beta)) wav = wav.astype(np.float32) gen_time = time.time() - t0 rtf = gen_time / max(1e-6, len(wav)/SR_OUT) info = { "simi_timbre": round(float(simi_timbre), 4) , "simi_prosody": round(float(simi_prosody), 4) , "Phonemes": ps, "Sample rate": SR_OUT, "RTF": round(float(rtf), 3), "Device": DEVICE, } return (SR_OUT, wav), info btn.click(_run, inputs=[text_inp, ref_audio, alpha_n, beta_n], outputs=[out_audio, metrics]) if __name__ == "__main__": demo.launch()