# -*- coding: utf-8 -*-
"""
Gradio app.py - wired to your 'inference_one' implementation
- Reference voice: upload OR choose from ref_voice/
- Uses phonemize_text(), compute_style(), inference_one() exactly like your snippet
- NOW: adds UI controls (numeric inputs) for alpha and beta and threads them into inference
"""
import os
import time
import yaml
import numpy as np
import torch
import torch.nn.functional as F  # needed for F.cosine_similarity in inference_one
import torchaudio
import librosa
import gradio as gr
from munch import Munch
# -----------------------------
# Reproducibility
# -----------------------------
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)
# -----------------------------
# Device / sample-rate
# -----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR_OUT = 24000 # target audio rate for synthesis
# -----------------------------
# External modules from the project
# -----------------------------
from models import * # noqa: F401,F403
from utils import * # noqa: F401,F403
from models import build_model
from text_utils import TextCleaner
from Utils_extend_v1.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
textcleaner = TextCleaner()
# -----------------------------
# Config / model loading
# -----------------------------
from huggingface_hub import hf_hub_download
hf_hub_download(
repo_id="ltphuongunited/styletts2_vi",
filename="epoch_2nd_00058.pth",
local_dir="Models/multi_phoaudio_gemini",
local_dir_use_symlinks=False,
)
# CONFIG_PATH = os.getenv("MODEL_CONFIG", "Models/multi_phoaudio_gemini/config_phoaudio_gemini_small.yml")
# CHECKPOINT_PTH = os.getenv("MODEL_CKPT", "Models/multi_phoaudio_gemini/epoch_2nd_00058.pth")
CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"
# Load config
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
# Build components
ASR_config = config.get("ASR_config", False)
ASR_path = config.get("ASR_path", False)
F0_path = config.get("F0_path", False)
PLBERT_dir = config.get("PLBERT_dir", False)
text_aligner = load_ASR_models(ASR_path, ASR_config)
pitch_extractor = load_F0_models(F0_path)
plbert = load_plbert(PLBERT_dir)
model_params = recursive_munch(config["model_params"])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# to device & eval
_ = [model[k].to(DEVICE) for k in model]
_ = [model[k].eval() for k in model]
# Load checkpoint
if not os.path.isfile(CHECKPOINT_PTH):
raise FileNotFoundError(f"Checkpoint not found at '{CHECKPOINT_PTH}'")
ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")
params = ckpt["net"]
for key in model:
if key in params:
try:
model[key].load_state_dict(params[key])
except Exception:
from collections import OrderedDict
state_dict = params[key]
new_state = OrderedDict()
for k, v in state_dict.items():
                name = k[7:] if k.startswith("module.") else k  # strip 'module.' prefix if present
new_state[name] = v
model[key].load_state_dict(new_state, strict=False)
_ = [model[k].eval() for k in model]
# Diffusion sampler
sampler = DiffusionSampler(
model.diffusion.diffusion,
sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
clamp=False,
)
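# The sampler denoises a random 256-dim latent into a style vector conditioned on the BERT phoneme
# embedding; inference_one then blends that prediction with the reference features via alpha/beta.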
# -----------------------------
# Audio helper: mel preprocessing
# -----------------------------
_to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
_MEAN, _STD = -4.0, 4.0
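# Audio is resampled to 24 kHz before this transform, so the 300-sample hop is 12.5 ms and the
# 1200-sample window is 50 ms; log-mels are normalized with mean -4 / std 4, presumably matching
# the preprocessing used to train the checkpoint.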
def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
return mask
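# True marks padded positions beyond each sequence length; the mask feeds the text encoder directly
# and, inverted, becomes the BERT attention_mask in inference_one.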
def preprocess(wave: np.ndarray) -> torch.Tensor:
"""Same name as your snippet: np.float -> mel (normed)"""
wave_tensor = torch.from_numpy(wave).float()
mel_tensor = _to_mel(wave_tensor)
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - _MEAN) / _STD
return mel_tensor
# -----------------------------
# Phonemizer (vi)
# -----------------------------
import phonemizer
vi_phonemizer = phonemizer.backend.EspeakBackend(language="vi", preserve_punctuation=True, with_stress=True)
global_phonemizer = vi_phonemizer
def phonemize_text(text: str) -> str:
ps = global_phonemizer.phonemize([text])[0]
return ps.replace("(en)", "").replace("(vi)", "").strip()
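# espeak-ng tags language switches with flags such as "(en)" / "(vi)" when it falls back to another
# voice; they are stripped here so they never reach the token cleaner.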
# -----------------------------
# Style extractor (from file path)
# -----------------------------
def compute_style(model, path, device):
"""Compute style/prosody reference from a wav file path"""
wave, sr = librosa.load(path, sr=None, mono=True)
audio, _ = librosa.effects.trim(wave, top_db=30)
    if sr != SR_OUT:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR_OUT)
mel_tensor = preprocess(audio).to(device)
with torch.no_grad():
ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
return torch.cat([ref_s, ref_p], dim=1) # [1, 256]
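# The reference vector is 256-dim: the first 128 dims come from style_encoder (timbre / acoustic
# style), the last 128 from predictor_encoder (prosody); inference_one splits it at index 128.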
# Style extractor (from numpy array)
def compute_style_from_numpy(model, arr: np.ndarray, sr: int, device):
if arr.ndim > 1:
arr = librosa.to_mono(arr.T)
audio, _ = librosa.effects.trim(arr, top_db=30)
    if sr != SR_OUT:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SR_OUT)
mel_tensor = preprocess(audio).to(device)
with torch.no_grad():
ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
return torch.cat([ref_s, ref_p], dim=1)
# -----------------------------
# Inference (your exact logic)
# -----------------------------
# Tunables (still as defaults; UI will override)
ALPHA = 0.3
BETA = 0.7
DIFFUSION_STEPS = 5
EMBEDDING_SCALE = 1.0
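# How these interact (see inference_one below):
# - alpha blends the diffusion-sampled timbre with the reference timbre
#   (ref = alpha * diffused + (1 - alpha) * reference), so alpha=0 keeps the reference timbre.
# - beta does the same blend for prosody.
# - alpha == beta == 0 skips the diffusion sampler and uses the reference features as-is.
# - diffusion_steps is the number of sampler steps; embedding_scale scales the BERT text
#   conditioning passed to the sampler.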
def inference_one(text, ref_feat, ipa_text=None,
alpha=ALPHA, beta=BETA, diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
# text -> phonemes -> tokens
ps = ipa_text if ipa_text is not None else phonemize_text(text)
tokens = textcleaner(ps)
tokens.insert(0, 0) # prepend BOS
tokens = torch.LongTensor(tokens).to(DEVICE).unsqueeze(0) # [1, T]
with torch.no_grad():
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
text_mask = length_to_mask(input_lengths).to(DEVICE)
# encoders
t_en = model.text_encoder(tokens, input_lengths, text_mask)
bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_d).transpose(-1, -2)
if alpha == 0 and beta == 0:
print("Ignore Diffusion")
ref = ref_feat[:, :128]
s = ref_feat[:, 128:]
            simi_timbre, simi_prosody = 1.0, 1.0
else:
print("Have Diffusion")
# diffusion for style latent
s_pred = sampler(
noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
embedding=bert_d,
embedding_scale=embedding_scale,
features=ref_feat, # [1, 256]
num_steps=diffusion_steps,
).squeeze(1) # [1, 256]
s = s_pred[:, 128:] # prosody
ref = s_pred[:, :128] # timbre
# blend with real ref features
ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
s = beta * s + (1 - beta) * ref_feat[:, 128:]
with torch.no_grad():
                ref0 = ref_feat[:, :128]  # original timbre (kept for reference; unused below)
                s0 = ref_feat[:, 128:]    # original prosody (kept for reference; unused below)
eps = 1e-8
def stats(name, new, base):
delta = new - base
l2_delta = torch.norm(delta, dim=1) # ||Δ||
l2_base = torch.norm(base, dim=1) + eps # ||x||
rel_l2 = (l2_delta / l2_base) # ||Δ|| / ||x||
mae = torch.mean(torch.abs(delta), dim=1) # MAE
cos_sim = F.cosine_similarity(new, base, dim=1) # cos(new, base)
snr_db = 20.0 * torch.log10(l2_base / (l2_delta + eps)) # SNR ~ 20*log10(||x||/||Δ||)
                    # # Batch size is usually 1 at inference, but the printout below iterates per sample for generality
# for i in range(new.shape[0]):
# print(f"[{name}][sample {i}] "
# f"L2Δ={l2_delta[i]:.4f} | relL2={rel_l2[i]:.4f} | MAE={mae[i]:.6f} | "
# f"cos={cos_sim[i]:.4f} | SNR={snr_db[i]:.2f} dB")
return cos_sim
simi_timbre = stats("REF(timbre)", s_pred[:, :128], ref_feat[:, :128]).detach().cpu().squeeze().item()
simi_prosody = stats("S(prosody)", s_pred[:, 128:], ref_feat[:, 128:]).detach().cpu().squeeze().item()
# duration prediction
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
# alignment
T = int(pred_dur.sum().item())
pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
c = 0
for i in range(input_lengths.item()):
span = int(pred_dur[i].item())
pred_aln[i, c:c+span] = 1.0
c += span
# prosody enc
en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
        if model_params.decoder.type == "hifigan":
            # shift the aligned features right by one frame (first frame duplicated) for the hifigan decoder
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
# content (ASR-aligned)
asr = (t_en @ pred_aln.unsqueeze(0))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new
# decode
out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
wav = out.squeeze().detach().cpu().numpy()
if wav.shape[-1] > 50:
wav = wav[..., :-50]
return wav, ps, simi_timbre, simi_prosody
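# Minimal usage sketch (illustrative only; the reference path and output filename are hypothetical,
# and soundfile is assumed to be installed):
#     ref_feat = compute_style(model, "ref_voice/0_Fonos.vn/sample.wav", DEVICE)
#     wav, ps, simi_t, simi_p = inference_one("Xin chào.", ref_feat, alpha=0.3, beta=0.7)
#     import soundfile as sf; sf.write("out.wav", wav, SR_OUT)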
# -----------------------------
# Gradio UI
# -----------------------------
ROOT_REF = "ref_voice"
EXTS = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
# -------- scan ref_voice/<id>_<speaker>/*.wav --------
def scan_ref_voice(root=ROOT_REF):
"""
return:
speakers: list[str] # ví dụ: ["0_Fonos.vn", "1_James_A._Robinson", ...]
files_by_spk: dict[str, list[str]] # speaker_dir -> [full_path,...]
"""
speakers, files_by_spk = [], {}
if not os.path.isdir(root):
return speakers, files_by_spk
for spk_dir in sorted(os.listdir(root)):
full_dir = os.path.join(root, spk_dir)
if not os.path.isdir(full_dir) or spk_dir.startswith("."):
continue
lst = []
for fn in sorted(os.listdir(full_dir)):
if os.path.splitext(fn)[1].lower() in EXTS:
lst.append(os.path.join(full_dir, fn))
if lst:
speakers.append(spk_dir)
files_by_spk[spk_dir] = lst
return speakers, files_by_spk
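# Expected layout (speaker folder names as in the docstring above; the clip name is illustrative):
#     ref_voice/
#         0_Fonos.vn/
#             clip_001.wav
#         1_James_A._Robinson/
#             ...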
SPEAKERS, FILES_BY_SPK = scan_ref_voice()
with gr.Blocks(title="StyleTTS2-vi Demo ✨") as demo:
gr.Markdown("# StyleTTS2-vi Demo ✨")
with gr.Row():
with gr.Column():
text_inp = gr.Textbox(label="Text", lines=4,
value="Thời tiết hôm nay tại Hà Nội, nhiệt độ khoảng 27 độ C, có nắng nhẹ, rất hợp lý để mình đi dạo công viên nhé.")
            # --- single Audio widget (takes a filepath) ---
ref_audio = gr.Audio(
label="Reference Audio",
type="filepath", # nhận đường dẫn file
sources=["upload","microphone"], # vẫn cho upload/mic
interactive=True,
)
            ref_path = gr.Textbox(label="Reference path", interactive=False)
            # --- pick a speaker -> show its files ---
spk_dd = gr.Dropdown(
label="Speaker",
choices=["(None)"] + SPEAKERS,
value="(None)",
)
file_dd = gr.Dropdown(
label="Voice in speaker",
choices=["(None)"],
value="(None)",
)
            # when a speaker is selected -> refresh the file list
def on_pick_speaker(spk):
if spk == "(None)":
return gr.update(choices=["(None)"], value="(None)")
files = FILES_BY_SPK.get(spk, [])
                # show only basenames to keep the dropdown compact
labels = [os.path.basename(p) for p in files]
                # the label is mapped back to a full path (by basename) in on_pick_file; default to the first item
return gr.update(choices=labels, value=(labels[0] if labels else "(None)"))
spk_dd.change(on_pick_speaker, inputs=spk_dd, outputs=file_dd)
            # map label (basename) -> full path for the currently selected speaker
def on_pick_file(spk, label):
if spk == "(None)" or label == "(None)":
return gr.update(value=None), ""
files = FILES_BY_SPK.get(spk, [])
                # find the matching file by basename
for p in files:
if os.path.basename(p) == label:
                        return gr.update(value=p), p  # set it into the Audio widget and show the path
return gr.update(value=None), ""
file_dd.change(on_pick_file, inputs=[spk_dd, file_dd], outputs=[ref_audio, ref_path])
            # if the user uploads or records, show the temp file path right away
def on_audio_changed(fp):
return fp or ""
ref_audio.change(on_audio_changed, inputs=ref_audio, outputs=ref_path)
# --- NEW: alpha/beta numeric inputs ---
with gr.Row():
alpha_n = gr.Number(value=ALPHA, label="alpha (0-1, timbre)", precision=3)
beta_n = gr.Number(value=BETA, label="beta (0-1, prosody)", precision=3)
            btn = gr.Button("Speak 🔊🔥", variant="primary")
with gr.Column():
out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
metrics = gr.JSON(label="Metrics")
    # ---- Inference: run from the reference filepath ----
def _run(text, ref_fp, alpha, beta):
        # ref_fp is a string path (because type='filepath')
if isinstance(ref_fp, str) and os.path.isfile(ref_fp):
wav, _ = librosa.load(ref_fp, sr=SR_OUT, mono=True)
ref_feat = compute_style_from_numpy(model, wav, SR_OUT, DEVICE)
ref_src = ref_fp
else:
ref_feat = torch.zeros(1, 256).to(DEVICE)
ref_src = "(None)"
t0 = time.time()
wav, ps, simi_timbre, simi_prosody = inference_one(text, ref_feat, alpha=float(alpha), beta=float(beta))
wav = wav.astype(np.float32)
gen_time = time.time() - t0
rtf = gen_time / max(1e-6, len(wav)/SR_OUT)
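        # rtf above = wall-clock synthesis time / generated audio duration (< 1.0 means faster than real time)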
info = {
"simi_timbre": round(float(simi_timbre), 4) ,
"simi_prosody": round(float(simi_prosody), 4) ,
"Phonemes": ps,
"Sample rate": SR_OUT,
"RTF": round(float(rtf), 3),
"Device": DEVICE,
}
return (SR_OUT, wav), info
btn.click(_run, inputs=[text_inp, ref_audio, alpha_n, beta_n], outputs=[out_audio, metrics])
if __name__ == "__main__":
demo.launch()