yash184 commited on
Commit
237e0e0
·
verified ·
1 Parent(s): 8574f55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -99
app.py CHANGED
@@ -1,22 +1,16 @@
1
  import random
2
  import numpy as np
3
  import torch
 
4
  import gradio as gr
5
  import spaces
6
- from pathlib import Path
7
- from typing import Optional
8
-
9
- # Import model class as before
10
- from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
11
 
12
- # --- Device detection ---
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  print(f"🚀 Running on device: {DEVICE}")
15
 
16
  # --- Global Model Initialization ---
17
- MODEL: Optional[ChatterboxMultilingualTTS] = None
18
 
19
- # --- Language config (kept exactly as you provided) ---
20
  LANGUAGE_CONFIG = {
21
  "ar": {
22
  "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
@@ -113,7 +107,7 @@ LANGUAGE_CONFIG = {
113
  }
114
 
115
  # --- UI Helpers ---
116
- def default_audio_for_ui(lang: str) -> Optional[str]:
117
  return LANGUAGE_CONFIG.get(lang, {}).get("audio")
118
 
119
 
@@ -127,118 +121,101 @@ def get_supported_languages_display() -> str:
127
  for code, name in sorted(SUPPORTED_LANGUAGES.items()):
128
  language_items.append(f"**{name}** (`{code}`)")
129
 
130
- # Split into 2 lines for readability
131
  mid = len(language_items) // 2
132
  line1 = " • ".join(language_items[:mid])
133
  line2 = " • ".join(language_items[mid:])
134
 
135
  return f"""
136
- ### Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
137
  {line1}
138
-
139
  {line2}
140
  """
141
 
142
 
143
- # --- Safe model loader with CPU fallback wrapper for torch.load ---
144
- def _load_model_with_cpu_fallback(device_str: str) -> ChatterboxMultilingualTTS:
145
- """
146
- Try to load the model normally. If a CUDA-deserialization RuntimeError occurs,
147
- patch torch.load temporarily to force map_location=cpu and retry.
148
- """
149
- global MODEL
150
-
151
- try:
152
- # First attempt: let the model loader handle device as asked.
153
- print(f"[model loader] Attempting to load model with device='{device_str}'")
154
- return ChatterboxMultilingualTTS.from_pretrained(device_str)
155
- except RuntimeError as e:
156
- msg = str(e)
157
- # Detect the common CUDA-deserialization error
158
- if "Attempting to deserialize object on a CUDA device" in msg or "cuda" in msg and "is not available" in msg:
159
- print("[model loader] Caught CUDA-deserialization error; retrying with forced CPU map_location.")
160
- # Backup original torch.load
161
- original_torch_load = torch.load
162
-
163
- def _torch_load_cpu_fallback(*args, **kwargs):
164
- # If user did not pass a map_location, force CPU
165
- if "map_location" not in kwargs:
166
- kwargs["map_location"] = torch.device("cpu")
167
- return original_torch_load(*args, **kwargs)
168
-
169
- try:
170
- torch.load = _torch_load_cpu_fallback # monkeypatch
171
- # Try again — pass "cpu" explicitly to the loader when retrying.
172
- return ChatterboxMultilingualTTS.from_pretrained("cpu")
173
- finally:
174
- # Restore original
175
- torch.load = original_torch_load
176
- else:
177
- # Not the CUDA-deserialization error we expected — re-raise
178
- raise
179
-
180
-
181
- def get_or_load_model() -> ChatterboxMultilingualTTS:
182
  """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
183
- using safe CPU fallback handling for CUDA-saved checkpoints."""
184
  global MODEL
185
  if MODEL is None:
186
  print("Model not loaded, initializing...")
187
  try:
188
- MODEL = _load_model_with_cpu_fallback(DEVICE)
189
- # Move model to desired device if required and supported.
190
- if hasattr(MODEL, "to"):
191
- try:
192
- MODEL.to(torch.device(DEVICE))
193
- except Exception as e:
194
- # If moving to CUDA fails (e.g., CPU-only), log and continue using CPU.
195
- print(f"[model loader] Warning: failed to move model to {DEVICE}: {e}")
196
- print(f"Model loaded successfully. Internal device attribute: {getattr(MODEL, 'device', 'N/A')}")
197
  except Exception as e:
198
- print(f"CRITICAL: Failed to load model. Error: {e}")
199
  raise
200
  return MODEL
201
 
 
 
 
 
 
202
 
203
  def set_seed(seed: int):
204
  """Sets the random seed for reproducibility across torch, numpy, and random."""
205
  torch.manual_seed(seed)
206
- if DEVICE == "cuda" and torch.cuda.is_available():
207
  torch.cuda.manual_seed(seed)
208
  torch.cuda.manual_seed_all(seed)
209
  random.seed(seed)
210
  np.random.seed(seed)
211
-
212
-
213
- def resolve_audio_prompt(language_id: str, provided_path: Optional[str]) -> Optional[str]:
214
- """Choose provided prompt or default language prompt."""
 
 
 
215
  if provided_path and str(provided_path).strip():
216
  return provided_path
217
  return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
218
 
219
 
220
- # --- The TTS generation function used by Gradio ---
221
  @spaces.GPU
222
  def generate_tts_audio(
223
  text_input: str,
224
  language_id: str,
225
- audio_prompt_path_input: Optional[str] = None,
226
  exaggeration_input: float = 0.5,
227
  temperature_input: float = 0.8,
228
  seed_num_input: int = 0,
229
  cfgw_input: float = 0.5
230
  ) -> tuple[int, np.ndarray]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  current_model = get_or_load_model()
 
232
  if current_model is None:
233
  raise RuntimeError("TTS model is not loaded.")
234
 
235
  if seed_num_input != 0:
236
  set_seed(int(seed_num_input))
237
 
238
- print(f"Generating audio for text: '{(text_input or '')[:60]}...' language={language_id}")
 
 
 
239
 
240
- # Resolve prompt (user-provided or default)
241
- chosen_prompt = resolve_audio_prompt(language_id, audio_prompt_path_input)
242
  generate_kwargs = {
243
  "exaggeration": exaggeration_input,
244
  "temperature": temperature_input,
@@ -248,32 +225,28 @@ def generate_tts_audio(
248
  generate_kwargs["audio_prompt_path"] = chosen_prompt
249
  print(f"Using audio prompt: {chosen_prompt}")
250
  else:
251
- print("No audio prompt provided; using default/model voice.")
252
-
253
- # Call model.generate (keep same call signature as original)
254
  wav = current_model.generate(
255
- (text_input or "")[:300],
256
  language_id=language_id,
257
  **generate_kwargs
258
  )
259
-
260
  print("Audio generation complete.")
261
- # Ensure returned shape -> (sr, numpy array)
262
- sr = getattr(current_model, "sr", 22050)
263
- return (sr, wav.squeeze(0).numpy())
264
 
265
-
266
- # --- UI (keeps original layout, minimal changes) ---
267
  with gr.Blocks() as demo:
268
  gr.Markdown(
269
  """
270
  # Chatterbox Multilingual Demo
271
- Generate high-quality multilingual speech from text with reference audio styling, supporting many languages.
 
 
272
  """
273
  )
274
-
 
275
  gr.Markdown(get_supported_languages_display())
276
-
277
  with gr.Row():
278
  with gr.Column():
279
  initial_lang = "fr"
@@ -282,26 +255,26 @@ with gr.Blocks() as demo:
282
  label="Text to synthesize (max chars 300)",
283
  max_lines=5
284
  )
285
-
286
  language_id = gr.Dropdown(
287
  choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
288
  value=initial_lang,
289
  label="Language",
290
  info="Select the language for text-to-speech synthesis"
291
  )
292
-
293
  ref_wav = gr.Audio(
294
  sources=["upload", "microphone"],
295
  type="filepath",
296
  label="Reference Audio File (Optional)",
297
  value=default_audio_for_ui(initial_lang)
298
  )
299
-
300
  gr.Markdown(
301
- "Note: Ensure reference clip language matches the selected language. To avoid language transfer, set CFG weight to 0.",
302
  elem_classes=["audio-note"]
303
  )
304
-
305
  exaggeration = gr.Slider(
306
  0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
307
  )
@@ -342,12 +315,4 @@ with gr.Blocks() as demo:
342
  outputs=[audio_output],
343
  )
344
 
345
- # Attempt to warm-load model on startup (optional; will raise if fails)
346
- try:
347
- get_or_load_model()
348
- except Exception as e:
349
- print(f"Startup warning: Model failed to warm-load. You can still try Generate; error: {e}")
350
-
351
- # Launch Gradio
352
- if __name__ == "__main__":
353
- demo.launch(mcp_server=True)
 
1
  import random
2
  import numpy as np
3
  import torch
4
+ from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
5
  import gradio as gr
6
  import spaces
 
 
 
 
 
7
 
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  print(f"🚀 Running on device: {DEVICE}")
10
 
11
  # --- Global Model Initialization ---
12
+ MODEL = None
13
 
 
14
  LANGUAGE_CONFIG = {
15
  "ar": {
16
  "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
 
107
  }
108
 
109
  # --- UI Helpers ---
110
+ def default_audio_for_ui(lang: str) -> str | None:
111
  return LANGUAGE_CONFIG.get(lang, {}).get("audio")
112
 
113
 
 
121
  for code, name in sorted(SUPPORTED_LANGUAGES.items()):
122
  language_items.append(f"**{name}** (`{code}`)")
123
 
124
+ # Split into 2 lines
125
  mid = len(language_items) // 2
126
  line1 = " • ".join(language_items[:mid])
127
  line2 = " • ".join(language_items[mid:])
128
 
129
  return f"""
130
+ ### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
131
  {line1}
 
132
  {line2}
133
  """
134
 
135
 
136
+ def get_or_load_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  """Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
138
+ and ensures it's on the correct device."""
139
  global MODEL
140
  if MODEL is None:
141
  print("Model not loaded, initializing...")
142
  try:
143
+ MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
144
+ if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
145
+ MODEL.to(DEVICE)
146
+ print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
 
 
 
 
 
147
  except Exception as e:
148
+ print(f"Error loading model: {e}")
149
  raise
150
  return MODEL
151
 
152
+ # Attempt to load the model at startup.
153
+ try:
154
+ get_or_load_model()
155
+ except Exception as e:
156
+ print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
157
 
158
  def set_seed(seed: int):
159
  """Sets the random seed for reproducibility across torch, numpy, and random."""
160
  torch.manual_seed(seed)
161
+ if DEVICE == "cuda":
162
  torch.cuda.manual_seed(seed)
163
  torch.cuda.manual_seed_all(seed)
164
  random.seed(seed)
165
  np.random.seed(seed)
166
+
167
+ def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
168
+ """
169
+ Decide which audio prompt to use:
170
+ - If user provided a path (upload/mic/url), use it.
171
+ - Else, fall back to language-specific default (if any).
172
+ """
173
  if provided_path and str(provided_path).strip():
174
  return provided_path
175
  return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
176
 
177
 
 
178
  @spaces.GPU
179
  def generate_tts_audio(
180
  text_input: str,
181
  language_id: str,
182
+ audio_prompt_path_input: str = None,
183
  exaggeration_input: float = 0.5,
184
  temperature_input: float = 0.8,
185
  seed_num_input: int = 0,
186
  cfgw_input: float = 0.5
187
  ) -> tuple[int, np.ndarray]:
188
+ """
189
+ Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
190
+ Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi.
191
+
192
+ This tool synthesizes natural-sounding speech from input text. When a reference audio file
193
+ is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
194
+ maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
195
+ Args:
196
+ text_input (str): The text to synthesize into speech (maximum 300 characters)
197
+ language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi)
198
+ audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
199
+ exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
200
+ temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
201
+ seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
202
+ cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
203
+ Returns:
204
+ tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
205
+ """
206
  current_model = get_or_load_model()
207
+
208
  if current_model is None:
209
  raise RuntimeError("TTS model is not loaded.")
210
 
211
  if seed_num_input != 0:
212
  set_seed(int(seed_num_input))
213
 
214
+ print(f"Generating audio for text: '{text_input[:50]}...'")
215
+
216
+ # Handle optional audio prompt
217
+ chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
218
 
 
 
219
  generate_kwargs = {
220
  "exaggeration": exaggeration_input,
221
  "temperature": temperature_input,
 
225
  generate_kwargs["audio_prompt_path"] = chosen_prompt
226
  print(f"Using audio prompt: {chosen_prompt}")
227
  else:
228
+ print("No audio prompt provided; using default voice.")
229
+
 
230
  wav = current_model.generate(
231
+ text_input[:300], # Truncate text to max chars
232
  language_id=language_id,
233
  **generate_kwargs
234
  )
 
235
  print("Audio generation complete.")
236
+ return (current_model.sr, wav.squeeze(0).numpy())
 
 
237
 
 
 
238
  with gr.Blocks() as demo:
239
  gr.Markdown(
240
  """
241
  # Chatterbox Multilingual Demo
242
+ Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
243
+
244
+ For a hosted version of Chatterbox Multilingual and for finetuning, please visit [resemble.ai](https://app.resemble.ai)
245
  """
246
  )
247
+
248
+ # Display supported languages
249
  gr.Markdown(get_supported_languages_display())
 
250
  with gr.Row():
251
  with gr.Column():
252
  initial_lang = "fr"
 
255
  label="Text to synthesize (max chars 300)",
256
  max_lines=5
257
  )
258
+
259
  language_id = gr.Dropdown(
260
  choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
261
  value=initial_lang,
262
  label="Language",
263
  info="Select the language for text-to-speech synthesis"
264
  )
265
+
266
  ref_wav = gr.Audio(
267
  sources=["upload", "microphone"],
268
  type="filepath",
269
  label="Reference Audio File (Optional)",
270
  value=default_audio_for_ui(initial_lang)
271
  )
272
+
273
  gr.Markdown(
274
+ "💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
275
  elem_classes=["audio-note"]
276
  )
277
+
278
  exaggeration = gr.Slider(
279
  0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
280
  )
 
315
  outputs=[audio_output],
316
  )
317
 
318
+ demo.launch(mcp_server=True)