# -*- coding: utf-8 -*-
"""
Gradio app.py
- StyleTTS2-vi with precomputed style embeddings (.pth)
- UI with alpha/beta/metrics
- Style Mixer: 4 fixed slots (Kore, Puck, Algenib, Leda); only the weights are
  adjustable, and they are auto-normalized
- Always shows the 4 reference samples (accordion)
- No speaker dropdown and no automatic reference-sample selection
"""
import os
import re
import glob
import time
import yaml
import torch
import librosa
import numpy as np
import gradio as gr
from munch import Munch
from soe_vinorm import SoeNormalizer

# ==============================================================
# Basic configuration
# ==============================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR_OUT = 24000
ALPHA, BETA, DIFFUSION_STEPS, EMBEDDING_SCALE = 0.0, 0.0, 5, 1.0
REF_DIR = "ref_voice"  # directory holding the reference audio samples (.wav)

# ==============================================================
# StyleTTS2 modules
# ==============================================================
from models import *
from utils import *
from models import build_model
from text_utils import TextCleaner
from Utils_extend_v1.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

textcleaner = TextCleaner()

# ==============================================================
# Load model and checkpoint
# ==============================================================
from huggingface_hub import hf_hub_download

# Single source of truth for where the checkpoint lives, so the download
# target and the path loaded later cannot drift apart.
CHECKPOINT_DIR = "Models/gemini_vi"
CHECKPOINT_FILE = "gemini_2nd_00045.pth"

hf_hub_download(
    repo_id="ltphuongunited/styletts2_vi",
    filename=CHECKPOINT_FILE,
    local_dir=CHECKPOINT_DIR,
    local_dir_use_symlinks=False,
)

CHECKPOINT_PTH = f"{CHECKPOINT_DIR}/{CHECKPOINT_FILE}"
# NOTE(review): the config is not downloaded above; it is assumed to already
# exist locally in CHECKPOINT_DIR — confirm.
CONFIG_PATH = f"{CHECKPOINT_DIR}/config_gemini_vi_en.yml"

# Use a context manager so the config file handle is closed right after parsing.
with open(CONFIG_PATH, encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

ASR_config = config.get("ASR_config", False)
ASR_path = config.get("ASR_path", False)
F0_path = config.get("F0_path", False)
PLBERT_dir = config.get("PLBERT_dir", False)

text_aligner = load_ASR_models(ASR_path, ASR_config)
pitch_extractor = load_F0_models(F0_path)
plbert = load_plbert(PLBERT_dir)

model_params = recursive_munch(config["model_params"])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# Move every sub-module to the target device and switch it to inference mode.
for key in model:
    model[key].to(DEVICE)
    model[key].eval()


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Models wrapped in ``nn.DataParallel`` save their parameters under that
    prefix. Keys without the prefix are kept unchanged, so only the wrapper
    prefix is removed.
    """
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")["net"]
for key in model:
    if key not in ckpt:
        continue
    try:
        model[key].load_state_dict(ckpt[key])
    except RuntimeError:
        # Key names did not match exactly: retry without the wrapper prefix
        # using a non-strict load, and report whatever still failed to match
        # instead of silently ignoring it.
        result = model[key].load_state_dict(_strip_module_prefix(ckpt[key]), strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(
                f"[WARN] Checkpoint '{key}': {len(result.missing_keys)} missing, "
                f"{len(result.unexpected_keys)} unexpected keys (loaded non-strictly)."
            )

sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
    clamp=False,
)

# ==============================================================
# Phonemizer
# ==============================================================
import phonemizer

vi_phonemizer = phonemizer.backend.EspeakBackend(
    language="vi", preserve_punctuation=True, with_stress=True
)


def phonemize_text(text: str) -> str:
    """Phonemize Vietnamese *text* with espeak and drop ``(en)``/``(vi)`` language tags."""
    ps = vi_phonemizer.phonemize([text])[0]
    return ps.replace("(en)", "").replace("(vi)", "").strip()


def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
    """Build a boolean mask of shape ``[batch, max_len]`` that is True at padded positions.

    Position ``j`` of row ``b`` is True when ``j + 1 > lengths[b]``.
    """
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


# ==============================================================
# Precomputed style embeddings
# ==============================================================
STYLE_PTH = "Models/styles_speaker_parallel.pth"
print(f"Loading precomputed styles: {STYLE_PTH}")
styles_dict = torch.load(STYLE_PTH, map_location=DEVICE)

# Preferred speaker order; the first available one is the fallback speaker
# used when the mixer is empty.
SPEAKER_ORDER_PREF = ["Kore", "Puck", "Algenib", "Leda"]
DEFAULT_SPK = next((s for s in SPEAKER_ORDER_PREF if s in styles_dict), list(styles_dict.keys())[0])


def _speaker_names():
    """Return the speaker names in styles_dict (or a fixed fallback list if it is not a dict)."""
    if isinstance(styles_dict, dict):
        return list(styles_dict.keys())
    return ["Kore", "Algenib", "Puck", "Leda"]


def _style_row(speaker: str, phoneme_len: int) -> int:
    """Return the 0-based row of ``styles_dict[speaker]`` used for *phoneme_len*.

    A phoneme count ``k`` selects row ``k - 1``. The count is clamped to
    ``[1, n_rows]``, so very short and very long inputs reuse the first and
    last rows respectively.
    """
    n_rows = styles_dict[speaker].shape[0]
    return min(max(int(phoneme_len), 1), n_rows) - 1


def get_style_by_length(speaker: str, phoneme_len: int):
    """Return the precomputed style of *speaker* for *phoneme_len* as a ``[1, D]`` tensor on DEVICE.

    ``styles_dict[speaker]`` is expected to be ``[510, 1, 256]`` or
    ``[510, 256]``. Leading singleton dims of the selected row are squeezed
    away before a batch dim is added.
    """
    feat = styles_dict[speaker][_style_row(speaker, phoneme_len)]
    if feat.ndim == 3:  # [1, 1, 256]
        feat = feat.squeeze(0)
    if feat.ndim == 2:  # [1, 256]
        feat = feat.squeeze(0)
    return feat.unsqueeze(0).to(DEVICE)  # [1, 256]


# ==============================================================
# Style mixing utils
# ==============================================================
def parse_mix_spec(spec: str) -> dict:
    """Parse ``'Kore:0.75,Puck:0.25'`` into ``{'Kore': 0.75, 'Puck': 0.25}``.

    Malformed entries, empty names, and non-finite or non-positive weights are
    skipped. Repeated names have their weights summed.
    """
    mix = {}
    if not spec or not isinstance(spec, str):
        return mix
    for part in spec.split(","):
        if ":" not in part:
            continue
        k, v = part.split(":", 1)
        k = (k or "").strip()
        if not k:
            continue
        try:
            w = float((v or "").strip())
        except ValueError:
            continue
        if not np.isfinite(w) or w <= 0:
            continue
        mix[k] = mix.get(k, 0.0) + w
    return mix


def _resolve_mix(mix_dict) -> dict:
    """Return ``{speaker: weight}`` for known speakers, with weights normalized to sum to 1.

    Entries whose weight is not a finite positive number are dropped. Speakers
    missing from styles_dict are dropped with a warning. Normalizing only over
    the surviving speakers keeps the mixed style at full magnitude. Returns
    ``{}`` when nothing usable remains.
    """
    valid = {}
    for spk, w in (mix_dict or {}).items():
        try:
            w = float(w)
        except (TypeError, ValueError):
            continue
        if not np.isfinite(w) or w <= 0:
            continue
        if spk not in styles_dict:
            print(f"[WARN] Speaker '{spk}' không có trong styles_dict, bỏ qua.")
            continue
        valid[spk] = valid.get(spk, 0.0) + w
    total = sum(valid.values())
    if total <= 0:
        return {}
    return {spk: w / total for spk, w in valid.items()}


def get_style_mixed_by_length(mix_dict: dict, phoneme_len: int):
    """Blend the styles of several speakers by weight. Returns ``[1, 256]`` on DEVICE.

    Falls back to DEFAULT_SPK when *mix_dict* has no usable (known, positive)
    entries.
    """
    weights = _resolve_mix(mix_dict)
    if not weights:
        return get_style_by_length(DEFAULT_SPK, phoneme_len)
    mix_feat = None
    for spk, wi in weights.items():
        contrib = get_style_by_length(spk, phoneme_len) * wi
        mix_feat = contrib if mix_feat is None else mix_feat + contrib
    return mix_feat  # [1, 256]


# ==============================================================
# Audio postprocess (librosa): trim + denoise + remove internal silence
# ==============================================================
def _simple_spectral_denoise(y, sr, n_fft=1024, hop=256, prop_decrease=0.8):
    """Subtract a per-frequency median magnitude ("noise floor") from the STFT of *y*.

    The original phase is kept through a real-valued gain mask, and the output
    has the same length as *y*.
    """
    if y.size == 0:
        return y
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop, win_length=n_fft)
    S = np.abs(D)
    noise = np.median(S, axis=1, keepdims=True)
    S_clean = np.maximum(S - prop_decrease * noise, 0.0)
    gain = S_clean / (S + 1e-8)
    return librosa.istft(D * gain, hop_length=hop, win_length=n_fft, length=len(y))


def _concat_with_crossfade(segments, crossfade_samples=0):
    """Concatenate 1-D float arrays, linearly cross-fading adjacent segments.

    Each overlap is ``min(crossfade_samples, len(prev), len(next))`` samples.
    """
    if not segments:
        return np.array([], dtype=np.float32)
    out = segments[0].astype(np.float32, copy=True)
    for seg in segments[1:]:
        seg = seg.astype(np.float32, copy=False)
        if crossfade_samples > 0 and out.size > 0 and seg.size > 0:
            cf = min(crossfade_samples, out.size, seg.size)
            fade_out = np.linspace(1.0, 0.0, cf, dtype=np.float32)
            fade_in = 1.0 - fade_out
            tail = out[-cf:] * fade_out + seg[:cf] * fade_in
            out = np.concatenate([out[:-cf], tail, seg[cf:]], axis=0)
        else:
            out = np.concatenate([out, seg], axis=0)
    return out


def _reduce_internal_silence(y, sr, top_db=30, min_keep_ms=40, crossfade_ms=8):
    """Keep only the non-silent intervals of *y* (at least *min_keep_ms* long) and cross-fade them together.

    Returns *y* unchanged when no interval qualifies.
    """
    if y.size == 0:
        return y
    intervals = librosa.effects.split(y, top_db=top_db)
    if intervals.size == 0:
        return y
    min_keep = int(sr * (min_keep_ms / 1000.0))
    segs = [y[s:e] for s, e in intervals if e - s >= min_keep]
    if not segs:
        return y
    crossfade = int(sr * (crossfade_ms / 1000.0))
    return _concat_with_crossfade(segs, crossfade_samples=crossfade)


def postprocess_audio(y, sr, trim_top_db=30, denoise=True,
                      denoise_n_fft=1024, denoise_hop=256, denoise_strength=0.8,
                      remove_internal_silence=True, split_top_db=30,
                      min_keep_ms=40, crossfade_ms=8):
    """Trim, optionally denoise, and optionally remove internal silence from *y*.

    NaN/inf values are zeroed. The result is rescaled only if its peak exceeds
    1.0, and is returned as float32.
    """
    if y.size == 0:
        return y.astype(np.float32)
    y_trim, _ = librosa.effects.trim(y, top_db=trim_top_db)
    if y_trim.size == 0:
        # Entirely silent input: nothing left to process, and np.max below
        # would fail on an empty array.
        return y_trim.astype(np.float32)
    if denoise:
        y_trim = _simple_spectral_denoise(
            y_trim, sr, n_fft=denoise_n_fft, hop=denoise_hop, prop_decrease=denoise_strength
        )
    if remove_internal_silence:
        y_trim = _reduce_internal_silence(
            y_trim, sr, top_db=split_top_db, min_keep_ms=min_keep_ms, crossfade_ms=crossfade_ms
        )
    y_trim = np.nan_to_num(y_trim, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    m = np.max(np.abs(y_trim)) + 1e-8
    if m > 1.0:
        y_trim = y_trim / m
    return y_trim


# ==============================================================
# Inference core
# ==============================================================
def _cosine_sim(a, b):
    """Return the mean cosine similarity between *a* and *b*, computed along dim 1."""
    return torch.nn.functional.cosine_similarity(a, b, dim=1).mean().item()


def inference_one(text, ref_feat, alpha=ALPHA, beta=BETA,
                  diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE,
                  ps=None, postprocess=False):
    """Synthesize *text* with the style vector *ref_feat*.

    Args:
        text: input text. It is only phonemized when *ps* is not supplied.
        ref_feat: reference style, assumed ``[1, 256]`` on DEVICE.
            ``[:, :128]`` goes to the decoder and ``[:, 128:]`` to the
            prosody predictor.
        alpha: blend weight of the diffusion-sampled ``[:, :128]`` half
            against *ref_feat*. 0 means use *ref_feat* only.
        beta: blend weight of the diffusion-sampled ``[:, 128:]`` half
            against *ref_feat*.
        diffusion_steps, embedding_scale: forwarded to the diffusion sampler.
            The sampler is skipped when alpha == beta == 0.
        ps: optional precomputed phoneme string, so callers that already
            phonemized do not pay for it twice.
        postprocess: when True, run ``postprocess_audio`` on the waveform.

    Returns:
        ``(wav, phonemes, simi_timbre, simi_prosody)``: the waveform as a NumPy
        array, the phoneme string, and the cosine similarities between the
        sampled style halves and *ref_feat*.
    """
    if ps is None:
        ps = phonemize_text(text)
    token_ids = textcleaner(ps)
    token_ids.insert(0, 0)  # prepend token id 0 (presumably the pad/start token)
    tokens = torch.LongTensor(token_ids).unsqueeze(0).to(DEVICE)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
    text_mask = length_to_mask(input_lengths).to(DEVICE)

    with torch.no_grad():
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_d).transpose(-1, -2)

        if alpha == 0 and beta == 0:
            # Nothing from a diffusion sample would be used, so skip sampling.
            s_pred = ref_feat.clone()  # [1, 256]
        else:
            s_pred = sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
                embedding=bert_d,
                embedding_scale=embedding_scale,
                features=ref_feat,  # [1, 256]
                num_steps=diffusion_steps,
            ).squeeze(1)  # [1, 256]

        s, ref = s_pred[:, 128:], s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
        s = beta * s + (1 - beta) * ref_feat[:, 128:]

        # --- Metrics (cosine) ---
        simi_timbre = _cosine_sim(s_pred[:, :128], ref_feat[:, :128])
        simi_prosody = _cosine_sim(s_pred[:, 128:], ref_feat[:, 128:])

        # --- Duration / alignment ---
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        # reshape(-1) rather than squeeze(): keeps a 1-D tensor even when the
        # input has a single token, so per-token indexing below stays valid.
        pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)
        spans = [int(v) for v in pred_dur.tolist()]  # one host transfer instead of per-token .item()
        n_tokens = int(input_lengths.item())
        pred_aln = torch.zeros(n_tokens, sum(spans), device=DEVICE)
        frame = 0
        for i, span in enumerate(spans):
            pred_aln[i, frame:frame + span] = 1.0
            frame += span

        en = d.transpose(-1, -2) @ pred_aln.unsqueeze(0)
        if model_params.decoder.type == "hifigan":
            # Shift right by one frame (first frame repeated) for the HiFi-GAN decoder.
            en = torch.cat([en[:, :, :1], en[:, :, :-1]], dim=2)
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = t_en @ pred_aln.unsqueeze(0)
        if model_params.decoder.type == "hifigan":
            asr = torch.cat([asr[:, :, :1], asr[:, :, :-1]], dim=2)

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    wav = out.squeeze().detach().cpu().numpy()
    if wav.shape[-1] > 50:
        wav = wav[:-50]  # drop the final 50 samples

    if postprocess:
        # Optional clean-up: trim + denoise + remove internal silence.
        wav = postprocess_audio(
            wav, SR_OUT,
            trim_top_db=30,
            denoise=True,
            denoise_n_fft=1024, denoise_hop=256, denoise_strength=0.8,
            remove_internal_silence=True,
            split_top_db=30, min_keep_ms=40, crossfade_ms=8,
        )
    return wav, ps, simi_timbre, simi_prosody


# ==============================================================
# Ref-audio mapping (scan ./ref_voice for a sample file per speaker)
# ==============================================================
def _norm(s: str) -> str:
    """Strip accents, lowercase, and keep only ``[a-z0-9_-.]`` characters."""
    import unicodedata
    s = unicodedata.normalize("NFKD", s)
    s = "".join(c for c in s if not unicodedata.combining(c))
    s = s.lower()
    return re.sub(r"[^a-z0-9_\-\.]+", "", s)


def build_ref_map(ref_dir: str) -> dict:
    """Map each speaker name to a ``.wav`` file found recursively under *ref_dir*.

    A file whose normalized name contains ``_<speaker>_`` is preferred.
    Otherwise, any file whose name contains the speaker name is used. Paths are
    sorted first so the choice is deterministic. Speakers without a match are
    omitted.
    """
    paths = sorted(glob.glob(os.path.join(ref_dir, "**", "*.wav"), recursive=True))
    by_name = {_norm(os.path.basename(p)): p for p in paths}

    spk_map = {}
    for spk in _speaker_names():
        spk_n = _norm(spk)
        hit = next((p for k, p in by_name.items() if f"_{spk_n}_" in k), None)
        if hit is None:
            hit = next((p for k, p in by_name.items() if spk_n in k), None)
        if hit:
            spk_map[spk] = hit
    return spk_map


REF_MAP = build_ref_map(REF_DIR)


def get_ref_path_for_speaker(spk: str):
    """Return the reference-sample path for *spk*, or None if none was found."""
    return REF_MAP.get(spk)


# ==============================================================
# Gradio wrapper (receives speaker_mix_spec as a hidden string)
# ==============================================================
_NORMALIZER = None


def _get_normalizer():
    """Create the SoeNormalizer lazily, once, and reuse it for every request.

    NOTE(review): assumes ``normalize()`` keeps no per-call mutable state;
    confirm before raising Gradio's concurrency limit.
    """
    global _NORMALIZER
    if _NORMALIZER is None:
        _NORMALIZER = SoeNormalizer()
    return _NORMALIZER


def run_inference(text, alpha, beta, speaker_mix_spec):
    """Gradio callback: normalize, phonemize, pick or mix a style, and synthesize.

    Returns ``((SR_OUT, wav_float32), info_dict)`` for the Audio and JSON outputs.
    Raises ``gr.Error`` for empty input.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text.")
    alpha = ALPHA if alpha is None else float(alpha)
    beta = BETA if beta is None else float(beta)

    text = _get_normalizer().normalize(text).replace(" ,", ",").replace(" .", ".")
    ps = phonemize_text(text)
    phoneme_len = len(ps.replace(" ", ""))

    # Normalized weights over speakers that actually exist, so the reported
    # mix matches what is blended.
    weights = _resolve_mix(parse_mix_spec(speaker_mix_spec))
    if weights:
        ref_feat = get_style_mixed_by_length(weights, phoneme_len)
        mix_info = {spk: round(float(w), 3) for spk, w in weights.items()}
        chosen_speakers = list(weights)
    else:
        ref_feat = get_style_by_length(DEFAULT_SPK, phoneme_len)
        mix_info = {DEFAULT_SPK: 1.0}
        chosen_speakers = [DEFAULT_SPK]
    # Row of the style table actually used (derived from the tensor, not a hard-coded size).
    ref_idx = _style_row(chosen_speakers[0], phoneme_len)

    t0 = time.perf_counter()
    wav, ps_out, simi_timbre, simi_prosody = inference_one(
        text, ref_feat, alpha=alpha, beta=beta, ps=ps
    )
    gen_time = time.perf_counter() - t0
    rtf = gen_time / max(1e-6, len(wav) / SR_OUT)

    info = {
        "Text after soe_vinorms:": text,
        "Speakers": chosen_speakers,
        "Mix weights (normalized)": mix_info,
        "Phonemes": ps_out,
        "Phoneme length": phoneme_len,
        "Ref index": ref_idx,
        "simi_timbre": round(float(simi_timbre), 4),
        "simi_prosody": round(float(simi_prosody), 4),
        "alpha": alpha,
        "beta": beta,
        "RTF": round(float(rtf), 3),
        "Device": DEVICE,
    }
    return (SR_OUT, wav.astype(np.float32)), info


# ==============================================================
# UI helpers: FIXED 4-speaker mixer
# ==============================================================
def _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order):
    """Turn the 4 slider weights into ``(mix_spec, view_dict, sum_markdown)``.

    Slot ``i`` belongs to ``order[i]``. Zero or empty weights are dropped, and
    repeated names are summed. When *normalize* is true, the weights are scaled
    to sum to 1.
    """
    acc = {}
    for spk, w in zip(order, (w1, w2, w3, w4)):
        w = float(w or 0.0)
        if w > 0:
            acc[spk] = acc.get(spk, 0.0) + w
    if not acc:
        return "", {}, "**Sum:** 0.000"

    total = sum(acc.values())
    if normalize and total > 0:
        acc = {spk: w / total for spk, w in acc.items()}

    mix_spec = ",".join(f"{s}:{w:.4f}" for s, w in acc.items())
    mix_view = {"weights": {s: round(w, 3) for s, w in acc.items()}, "normalized": bool(normalize)}
    sum_md = f"**Sum:** {sum(acc.values()):.3f}"  # same 3-decimal format as the empty case
    return mix_spec, mix_view, sum_md


def _fixed_speaker_order(available, slots=4):
    """Return exactly *slots* speaker names for the fixed sample and mixer slots.

    Available preferred speakers (SPEAKER_ORDER_PREF) come first, then the other
    available speakers. If there are still fewer than *slots*, the remaining
    preferred names are added as placeholders so the UI never indexes past the
    end. The mixer ignores names missing from styles_dict.
    """
    order = [s for s in SPEAKER_ORDER_PREF if s in available]
    for candidates in (available, SPEAKER_ORDER_PREF):
        for s in candidates:
            if len(order) >= slots:
                return order[:slots]
            if s not in order:
                order.append(s)
    return order[:slots]


DEFAULT_TEXT = "Trăng treo lơ lửng trên đỉnh núi chơ vơ, ánh sáng bàng bạc phủ lên bãi đá ngổn ngang. Con dế thổn thức trong khe cỏ, tiếng gió hun hút lùa qua hốc núi trập trùng. Dưới thung lũng, đàn trâu gặm cỏ ung dung, hơi sương vẩn đục, lảng bảng giữa đồng khuya tĩnh mịch."

# ==============================================================
# Gradio UI
# ==============================================================
with gr.Blocks(title="StyleTTS2-vi Demo") as demo:
    gr.Markdown("# StyleTTS2-vi Demo")
    with gr.Row():
        with gr.Column():
            text_inp = gr.Textbox(label="Text", lines=4, value=DEFAULT_TEXT)

            # FIXED slot order shared by the sample players and the mixer.
            fixed_order = _fixed_speaker_order(_speaker_names())

            # === Always show the 4 voice samples, 2 per row ===
            with gr.Accordion("Reference samples", open=True):
                for row_start in (0, 2):
                    with gr.Row():
                        for spk in fixed_order[row_start:row_start + 2]:
                            with gr.Column():
                                gr.Markdown(f"**{spk}**")
                                gr.Audio(
                                    value=get_ref_path_for_speaker(spk),
                                    label=f"{spk} sample",
                                    type="filepath",
                                    interactive=False,
                                )

            # ---- Fixed 4-slot Style Mixer ----
            with gr.Accordion("Style Mixer", open=True):
                normalize_ck = gr.Checkbox(value=True, label="Normalize weights to 1")
                weight_sliders = []
                for row_start in (0, 2):
                    with gr.Row(equal_height=True):
                        for slot in range(row_start, row_start + 2):
                            with gr.Column():
                                gr.Markdown(f"**{fixed_order[slot]}**")
                                weight_sliders.append(
                                    gr.Slider(0.0, 1.0, value=0.0, step=0.05,
                                              label=f"Weight {slot + 1}", container=False)
                                )
                w1, w2, w3, w4 = weight_sliders

                mix_sum_md = gr.Markdown("**Sum:** 0.000")
                mix_view_json = gr.JSON(label="Mixer weights (view)")
                mix_spec_state = gr.State("")  # mix-spec string consumed by the backend
                order_state = gr.State(fixed_order)  # keeps the fixed slot order

            with gr.Row():
                alpha_n = gr.Number(value=ALPHA, label="alpha diffusion (0-1, timbre)", precision=3)
                beta_n = gr.Number(value=BETA, label="beta diffusion (0-1, prosody)", precision=3)

            btn = gr.Button("Đọc 🔊🔥", variant="primary")

        with gr.Column():
            out_audio = gr.Audio(label="Synthesised Audio", type="numpy")
            metrics = gr.JSON(label="Metrics")

    # Any weight/normalize change -> rebuild the fixed spec and refresh the sum/json view.
    mixer_inputs = [normalize_ck, w1, w2, w3, w4, order_state]
    for comp in (normalize_ck, w1, w2, w3, w4):
        comp.change(
            _build_mix_spec_ui_fixed,
            inputs=mixer_inputs,
            outputs=[mix_spec_state, mix_view_json, mix_sum_md],
        )

    # Synthesize from mix_spec_state; an empty spec falls back to DEFAULT_SPK.
    btn.click(
        run_inference,
        inputs=[text_inp, alpha_n, beta_n, mix_spec_state],
        outputs=[out_audio, metrics],
    )


if __name__ == "__main__":
    demo.launch()