Update app.py
app.py
CHANGED
@@ -1,22 +1,16 @@
 import random
 import numpy as np
 import torch
+from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
 import gradio as gr
 import spaces
-from pathlib import Path
-from typing import Optional
-
-# Import model class as before
-from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
 
-# --- Device detection ---
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"🚀 Running on device: {DEVICE}")
 
 # --- Global Model Initialization ---
-MODEL: Optional[ChatterboxMultilingualTTS] = None
+MODEL = None
 
-# --- Language config (kept exactly as you provided) ---
 LANGUAGE_CONFIG = {
     "ar": {
         "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
@@ -113,7 +107,7 @@ LANGUAGE_CONFIG = {
 }
 
 # --- UI Helpers ---
-def default_audio_for_ui(lang: str) -> Optional[str]:
+def default_audio_for_ui(lang: str) -> str | None:
     return LANGUAGE_CONFIG.get(lang, {}).get("audio")
 
 
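Editor's note: the annotation change in this hunk swaps `typing.Optional` for the PEP 604 union syntax, which is also why the first hunk drops the `from typing import Optional` import. A minimal sketch of the two equivalent spellings (the `|` form requires Python 3.10+):

```python
from typing import Optional

def old_style(lang: str) -> Optional[str]:   # pre-3.10 spelling (removed)
    ...

def new_style(lang: str) -> str | None:      # PEP 604 spelling (added)
    ...
```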
@@ -127,118 +121,101 @@ def get_supported_languages_display() -> str:
     for code, name in sorted(SUPPORTED_LANGUAGES.items()):
         language_items.append(f"**{name}** (`{code}`)")
 
     # Split into 2 lines
     mid = len(language_items) // 2
     line1 = " • ".join(language_items[:mid])
     line2 = " • ".join(language_items[mid:])
 
     return f"""
-### Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
+### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
 {line1}
-
 {line2}
 """
 
 
-
-def _load_model_with_cpu_fallback(device_str: str) -> ChatterboxMultilingualTTS:
-    """
-    Try to load the model normally. If a CUDA-deserialization RuntimeError occurs,
-    patch torch.load temporarily to force map_location=cpu and retry.
-    """
-    global MODEL
-
-    try:
-        # First attempt: let the model loader handle device as asked.
-        print(f"[model loader] Attempting to load model with device='{device_str}'")
-        return ChatterboxMultilingualTTS.from_pretrained(device_str)
-    except RuntimeError as e:
-        msg = str(e)
-        # Detect the common CUDA-deserialization error
-        if "Attempting to deserialize object on a CUDA device" in msg or "cuda" in msg and "is not available" in msg:
-            print("[model loader] Caught CUDA-deserialization error; retrying with forced CPU map_location.")
-            # Backup original torch.load
-            original_torch_load = torch.load
-
-            def _torch_load_cpu_fallback(*args, **kwargs):
-                # If user did not pass a map_location, force CPU
-                if "map_location" not in kwargs:
-                    kwargs["map_location"] = torch.device("cpu")
-                return original_torch_load(*args, **kwargs)
-
-            try:
-                torch.load = _torch_load_cpu_fallback  # monkeypatch
-                # Try again — pass "cpu" explicitly to the loader when retrying.
-                return ChatterboxMultilingualTTS.from_pretrained("cpu")
-            finally:
-                # Restore original
-                torch.load = original_torch_load
-        else:
-            # Not the CUDA-deserialization error we expected — re-raise
-            raise
-
-
-def get_or_load_model() -> ChatterboxMultilingualTTS:
+def get_or_load_model():
     """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
-    …"""
+    and ensures it's on the correct device."""
     global MODEL
     if MODEL is None:
         print("Model not loaded, initializing...")
         try:
-            MODEL = _load_model_with_cpu_fallback(DEVICE)
-
-            try:
-                MODEL.to(torch.device(DEVICE))
-            except Exception as e:
-                # If moving to CUDA fails (e.g., CPU-only), log and continue using CPU.
-                print(f"[model loader] Warning: failed to move model to {DEVICE}: {e}")
-            print(f"Model loaded successfully. Internal device attribute: {getattr(MODEL, 'device', 'N/A')}")
+            MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
+            if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
+                MODEL.to(DEVICE)
+            print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
         except Exception as e:
-            print(f"…")
+            print(f"Error loading model: {e}")
             raise
     return MODEL
 
+# Attempt to load the model at startup.
+try:
+    get_or_load_model()
+except Exception as e:
+    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
 
 def set_seed(seed: int):
     """Sets the random seed for reproducibility across torch, numpy, and random."""
     torch.manual_seed(seed)
     if DEVICE == "cuda":
         torch.cuda.manual_seed(seed)
         torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
 
-
-def resolve_audio_prompt(language_id: str, provided_path: Optional[str]) -> Optional[str]:
-
+def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
+    """
+    Decide which audio prompt to use:
+    - If user provided a path (upload/mic/url), use it.
+    - Else, fall back to language-specific default (if any).
+    """
     if provided_path and str(provided_path).strip():
         return provided_path
     return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
 
 
-# --- The TTS generation function used by Gradio ---
 @spaces.GPU
 def generate_tts_audio(
     text_input: str,
     language_id: str,
-    audio_prompt_path_input: Optional[str] = None,
+    audio_prompt_path_input: str = None,
     exaggeration_input: float = 0.5,
     temperature_input: float = 0.8,
     seed_num_input: int = 0,
     cfgw_input: float = 0.5
 ) -> tuple[int, np.ndarray]:
+    """
+    Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
+    Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi.
+
+    This tool synthesizes natural-sounding speech from input text. When a reference audio file
+    is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
+    maintains the prosody, tone, and vocal qualities of the reference speaker, or uses a default voice if no reference is provided.
+    Args:
+        text_input (str): The text to synthesize into speech (maximum 300 characters)
+        language_id (str): The language code for synthesis (e.g. en, fr, de, es, it, pt, hi)
+        audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
+        exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
+        temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
+        seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
+        cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5; use 0 for language transfer.
+    Returns:
+        tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
+    """
     current_model = get_or_load_model()
+
     if current_model is None:
         raise RuntimeError("TTS model is not loaded.")
 
     if seed_num_input != 0:
         set_seed(int(seed_num_input))
 
-    print(f"Generating audio for text: '{…
+    print(f"Generating audio for text: '{text_input[:50]}...'")
 
-    # Resolve prompt (user-provided or default)
-    chosen_prompt = resolve_audio_prompt(language_id, audio_prompt_path_input)
+    # Handle optional audio prompt
+    chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
 
     generate_kwargs = {
         "exaggeration": exaggeration_input,
         "temperature": temperature_input,
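Editor's note: the deleted `_load_model_with_cpu_fallback` helper worked around a checkpoint saved from a CUDA device raising a `RuntimeError` when deserialized on a CPU-only host. The underlying fix is the `map_location` argument of `torch.load`; the helper applied it indirectly by monkeypatching `torch.load`, on the premise that `from_pretrained` calls `torch.load` internally. A minimal standalone sketch of the same pattern, with a hypothetical checkpoint path:

```python
import torch

CKPT_PATH = "checkpoints/model.pt"  # hypothetical path, for illustration only

try:
    # Fails on CPU-only hosts if the checkpoint was saved from CUDA tensors.
    state = torch.load(CKPT_PATH)
except RuntimeError:
    # Remap every storage onto the CPU instead of the device it was saved from.
    state = torch.load(CKPT_PATH, map_location=torch.device("cpu"))
```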
@@ -248,32 +225,28 @@ def generate_tts_audio(
         generate_kwargs["audio_prompt_path"] = chosen_prompt
         print(f"Using audio prompt: {chosen_prompt}")
     else:
-        print("No audio prompt provided; using default…")
+        print("No audio prompt provided; using default voice.")
 
-    # Call model.generate (keep same call signature as original)
     wav = current_model.generate(
-        text_input,
+        text_input[:300],  # Truncate text to max chars
         language_id=language_id,
         **generate_kwargs
     )
-
     print("Audio generation complete.")
-
-    sr = getattr(current_model, "sr", 22050)
-    return (sr, wav.squeeze(0).numpy())
+    return (current_model.sr, wav.squeeze(0).numpy())
 
-
-# --- UI (keeps original layout, minimal changes) ---
 with gr.Blocks() as demo:
     gr.Markdown(
         """
         # Chatterbox Multilingual Demo
-        Generate high-quality multilingual speech from text with reference audio styling, supporting…
+        Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
+
+        For a hosted version of Chatterbox Multilingual and for finetuning, please visit [resemble.ai](https://app.resemble.ai)
         """
     )
 
+    # Display supported languages
     gr.Markdown(get_supported_languages_display())
-
    with gr.Row():
         with gr.Column():
             initial_lang = "fr"
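Editor's note: the `return (current_model.sr, wav.squeeze(0).numpy())` added in the hunk above relies on Gradio's convention that an audio output can be returned as a `(sample_rate, numpy_array)` tuple. A self-contained sketch of that convention (the sine-tone generator is illustrative, not part of this app):

```python
import numpy as np
import gradio as gr

def tone(freq: float = 440.0) -> tuple[int, np.ndarray]:
    sr = 22050                                   # sample rate in Hz
    t = np.linspace(0, 1.0, sr, endpoint=False)  # one second of samples
    return sr, (0.5 * np.sin(2 * np.pi * freq * t)).astype(np.float32)

with gr.Blocks() as demo:
    out = gr.Audio(label="Output")
    gr.Button("Play").click(tone, outputs=[out])

demo.launch()
```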
@@ -282,26 +255,26 @@
                 label="Text to synthesize (max chars 300)",
                 max_lines=5
             )
 
             language_id = gr.Dropdown(
                 choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
                 value=initial_lang,
                 label="Language",
                 info="Select the language for text-to-speech synthesis"
             )
 
             ref_wav = gr.Audio(
                 sources=["upload", "microphone"],
                 type="filepath",
                 label="Reference Audio File (Optional)",
                 value=default_audio_for_ui(initial_lang)
             )
 
             gr.Markdown(
-                "Note…",
+                "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
                 elem_classes=["audio-note"]
             )
 
             exaggeration = gr.Slider(
                 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
             )
@@ -342,12 +315,4 @@ with gr.Blocks() as demo:
         outputs=[audio_output],
     )
 
-
-try:
-    get_or_load_model()
-except Exception as e:
-    print(f"Startup warning: Model failed to warm-load. You can still try Generate; error: {e}")
-
-# Launch Gradio
-if __name__ == "__main__":
-    demo.launch(mcp_server=True)
+demo.launch(mcp_server=True)
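Editor's note: with `mcp_server=True`, Gradio also serves the app as an MCP (Model Context Protocol) server and builds each tool's description from the exposed function's docstring and type hints, which is presumably why this commit adds the long Google-style docstring to `generate_tts_audio`. Dropping the `if __name__ == "__main__":` guard means the server now launches whenever the module is executed, which is how Spaces runs `app.py` anyway. A minimal sketch of the pattern (the `letter_counter` example is illustrative, not from this repo; `mcp_server=True` requires a recent Gradio installed with the `gradio[mcp]` extra):

```python
import gradio as gr

def letter_counter(word: str, letter: str) -> int:
    """Count how many times a letter appears in a word.

    Args:
        word (str): The word to search.
        letter (str): The letter to count.
    Returns:
        int: Number of occurrences.
    """
    return word.lower().count(letter.lower())

demo = gr.Interface(fn=letter_counter, inputs=["text", "text"], outputs="number")
demo.launch(mcp_server=True)  # exposes letter_counter as an MCP tool
```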