Spaces:
Sleeping
Sleeping
| # inference/pipeline.py | |
| import os | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Optional | |
| from utils.hparams import set_hparams, hparams | |
| from inference.ds_variance import DiffSingerVarianceInfer | |
| from inference.ds_acoustic import DiffSingerAcousticInfer | |
| from utils.infer_utils import parse_commandline_spk_mix, trans_key | |
| from webapp.services.parsing.ds_validator import validate_ds | |
# Directory that contains the ``inference/`` package, i.e. two levels above
# this module's resolved location.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Base directory under which each experiment stores "<exp_name>/config.yaml".
HF_CHECKPOINTS_DIR = "/tmp/cantussvs_v1/checkpoints"
def _load_hparams(exp_name: str) -> None:
    """Load the global ``hparams`` for *exp_name* from ``HF_CHECKPOINTS_DIR``.

    The experiment is handed to ``set_hparams`` through ``sys.argv``; the
    caller is responsible for restoring ``sys.argv`` afterwards.
    """
    config_path = os.path.join(HF_CHECKPOINTS_DIR, exp_name, "config.yaml")
    sys.argv = [
        "",
        "--config", config_path,
        "--exp_name", exp_name,
        "--infer",
    ]
    set_hparams(print_hparams=False)


def _load_ds_segments(ds_file: Path, error_prefix: str) -> list:
    """Read a DS JSON file, normalise it to a list and validate every entry.

    Raises ``ValueError`` (prefixed with *error_prefix* and chained to the
    original error) when ``validate_ds`` rejects an entry.
    """
    with open(ds_file, "r", encoding="utf-8") as f:
        segments = json.load(f)
    # A DS file may hold a single object instead of a list of them.
    if not isinstance(segments, list):
        segments = [segments]
    for segment in segments:
        try:
            validate_ds(segment)
        except Exception as e:
            raise ValueError(f"{error_prefix}: {e}") from e
    return segments


def run_inference(
    ds_path: Path,
    output_dir: Path,
    title: str,
    *,
    variance_exp: str = "regular_variance_v1",
    acoustic_exp: str = "debug_test",
    seed: int = 42,
    num_runs: int = 1,
    key_shift: int = 0,
    gender: Optional[float] = None
) -> Path:
    """Run the full pipeline: variance model => acoustic model.

    Args:
        ds_path: Input DS (JSON) file; a single object or a list of objects.
        output_dir: Directory that receives ``<title>.ds`` and ``<title>.wav``.
            It is created if it does not exist yet.
        title: Base name (without extension) of the generated files.
        variance_exp: Experiment name of the variance model.
        acoustic_exp: Experiment name of the acoustic model.
        seed: Seed forwarded to both inference stages.
        num_runs: Number of runs forwarded to the acoustic stage.
        key_shift: Transposition passed to ``trans_key``; ``0`` skips it.
        gender: Optional value stored under ``"gender"`` on every DS entry
            when the loaded config enables ``use_key_shift_embed``.

    Returns:
        Path to the generated WAV file (``output_dir / f"{title}.wav"``).

    Raises:
        FileNotFoundError: ``ds_path`` does not exist.
        ValueError: the input DS, or the DS written by the variance stage,
            fails ``validate_ds``.
        RuntimeError: a stage finished without producing its output file.
    """
    # Fail fast on a missing input, before any config is loaded or any
    # process-wide state is modified.
    if not ds_path.exists():
        raise FileNotFoundError(f"Input DS file not found: {ds_path}")

    # Each hparams load rewrites sys.argv; remember the caller's value so it
    # can be put back no matter how the pipeline exits.
    original_argv = sys.argv
    try:
        # ==== Variance config ==== #
        # Loaded once, from the same checkpoint directory the models use, so
        # the flags consulted below belong to the model that actually runs.
        print(f"[pipeline] Loading variance exp: {variance_exp}")
        _load_hparams(variance_exp)
        print("[pipeline] Variance hparams keys:", sorted(hparams.keys()))

        params = _load_ds_segments(ds_path, "Invalid input DS file")

        # Ensure ph_seq present: fall back to the space-separated characters
        # of "text" (with its own spaces removed).
        for p in params:
            if "ph_seq" not in p:
                text = p.get("text", "")
                p["ph_seq"] = " ".join(text.replace(" ", ""))

        # Transpose
        if key_shift != 0:
            params = trans_key(params, key_shift)

        # Speaker mix
        # NOTE(review): parse_commandline_spk_mix is called with None, as
        # before; confirm it accepts that when use_spk_id is enabled.
        spk_mix = parse_commandline_spk_mix(None) if hparams.get("use_spk_id") else None
        for p in params:
            if gender is not None and hparams.get("use_key_shift_embed"):
                p["gender"] = gender
            if spk_mix is not None:
                p["spk_mix"] = spk_mix

        output_dir.mkdir(parents=True, exist_ok=True)

        # ==== Variance Inference ==== #
        var_infer = DiffSingerVarianceInfer(ckpt_steps=None, predictions={"dur", "pitch"})
        ds_out = output_dir / f"{title}.ds"
        # Exactly one variance run: only the single file `ds_out` is read back.
        var_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=1, seed=seed)
        if not ds_out.exists():
            raise RuntimeError(f"Variance inference failed; missing {ds_out}")

        # Reload (and re-validate) params from the variance output.
        params = _load_ds_segments(ds_out, "Invalid DS after variance inference")

        # ==== Acoustic Inference ==== #
        print(f"[pipeline] Loading acoustic exp: {acoustic_exp}")
        _load_hparams(acoustic_exp)
        print("[pipeline] Acoustic hparams keys:", sorted(hparams.keys()))

        # The earlier gender check ran against the variance config. Repeat it
        # against the acoustic config so a requested gender is not dropped
        # when only the acoustic model enables key-shift embedding.
        if gender is not None and hparams.get("use_key_shift_embed"):
            for p in params:
                p["gender"] = gender

        ac_infer = DiffSingerAcousticInfer(load_vocoder=True, ckpt_steps=None)
        ac_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=num_runs, seed=seed)
        wav_out = output_dir / f"{title}.wav"
        if not wav_out.exists():
            raise RuntimeError(f"Acoustic inference failed; missing {wav_out}")
        return wav_out
    finally:
        sys.argv = original_argv
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them straight
    # to run_inference().
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Run full DiffSinger inference pipeline")

    # Required positionals, both parsed as filesystem paths.
    for positional in ("ds_path", "output_dir"):
        cli.add_argument(positional, type=Path)

    # Optional flags as (flag, value type, default).
    option_specs = (
        ("--title", str, None),
        ("--variance_exp", str, "regular_variance_v1"),
        ("--acoustic_exp", str, "debug_test"),
        ("--seed", int, 42),
        ("--num_runs", int, 1),
        ("--key_shift", int, 0),
        ("--gender", float, None),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    opts = cli.parse_args()

    # Without a (non-empty) --title, name the outputs after the input file.
    if opts.title:
        output_title = opts.title
    else:
        output_title = opts.ds_path.stem

    run_inference(
        opts.ds_path,
        opts.output_dir,
        output_title,
        variance_exp=opts.variance_exp,
        acoustic_exp=opts.acoustic_exp,
        seed=opts.seed,
        num_runs=opts.num_runs,
        key_shift=opts.key_shift,
        gender=opts.gender,
    )