import os
import base64
import tempfile
from flask import Flask, render_template, request, jsonify, send_file
from werkzeug.utils import secure_filename
from cached_path import cached_path
from vinorm import TTSnorm
from huggingface_hub import login
import numpy as np
import soundfile as sf
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
)

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024  # 50 MB max upload size
app.config['UPLOAD_FOLDER'] = tempfile.gettempdir()
app.config['ALLOWED_EXTENSIONS'] = {'wav', 'mp3', 'ogg', 'flac', 'm4a'}

# Retrieve the Hugging Face token from the Space secrets
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Log in to Hugging Face if a token is available
if hf_token:
    login(token=hf_token)


def post_process(text):
    """Post-process text by cleaning up punctuation and spacing."""
    text = " " + text + " "
    text = text.replace(" . . ", " . ")
    text = " " + text + " "
    text = text.replace(" .. ", " . ")
    text = " " + text + " "
    text = text.replace(" , , ", " , ")
    text = " " + text + " "
    text = text.replace(" ,, ", " , ")
    text = " " + text + " "
    text = text.replace('"', "")
    return " ".join(text.split())
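
# Illustrative example of the cleanup above (not part of the original listing):
#   post_process('xin chao . . cac ban ,, "nhe"')  ->  'xin chao . cac ban , nhe'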


def allowed_file(filename):
    """Check whether the file extension is allowed."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
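
# For example (illustrative): allowed_file('voice.WAV') -> True, allowed_file('notes.txt') -> False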


# Load models once at startup
print("Loading models...")
vocoder = load_vocoder()
model = load_model(
    DiT,
    dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
    ckpt_path=str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
    vocab_file=str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
)
print("Models loaded successfully!")


@app.route('/')
def index():
    """Render the main page."""
    return render_template('index.html')


@app.route('/synthesize', methods=['POST'])
def synthesize():
    """
    API endpoint for text-to-speech synthesis.

    Parameters:
        - ref_audio: reference audio file (multipart/form-data)
        - gen_text: text to synthesize (string)
        - speed: synthesis speed (float, default: 1.0)

    Returns:
        - JSON with base64-encoded audio and spectrogram
    """
    try:
        # Validate request
        if 'ref_audio' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400

        file = request.files['ref_audio']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file format. Allowed: wav, mp3, ogg, flac, m4a'}), 400

        gen_text = request.form.get('gen_text', '').strip()
        if not gen_text:
            return jsonify({'error': 'No text provided'}), 400
        if len(gen_text.split()) > 1000:
            return jsonify({'error': 'Text too long. Maximum 1000 words'}), 400

        speed = float(request.form.get('speed', 1.0))
        if speed < 0.3 or speed > 2.0:
            return jsonify({'error': 'Speed must be between 0.3 and 2.0'}), 400

        # Save the uploaded reference audio
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Preprocess (and transcribe) the reference audio
        ref_audio, ref_text = preprocess_ref_audio_text(filepath, "")

        # Generate speech
        final_wave, final_sample_rate, spectrogram = infer_process(
            ref_audio,
            ref_text.lower(),
            post_process(TTSnorm(gen_text)).lower(),
            model,
            vocoder,
            speed=speed
        )

        # Save audio to a temporary file
        audio_path = os.path.join(app.config['UPLOAD_FOLDER'], 'output.wav')
        sf.write(audio_path, final_wave, final_sample_rate)

        # Encode the audio as base64
        with open(audio_path, 'rb') as f:
            audio_base64 = base64.b64encode(f.read()).decode('utf-8')

        # Save the spectrogram image
        spec_path = os.path.join(app.config['UPLOAD_FOLDER'], 'spectrogram.png')
        save_spectrogram(spectrogram, spec_path)

        # Encode the spectrogram as base64
        with open(spec_path, 'rb') as f:
            spec_base64 = base64.b64encode(f.read()).decode('utf-8')

        # Clean up temporary files
        os.remove(filepath)
        os.remove(audio_path)
        os.remove(spec_path)
        if os.path.exists(ref_audio):
            os.remove(ref_audio)

        return jsonify({
            'success': True,
            'audio': audio_base64,
            'spectrogram': spec_base64,
            'sample_rate': final_sample_rate,
            'message': 'Speech synthesized successfully'
        })

    except Exception as e:
        return jsonify({'error': f'Error generating speech: {str(e)}'}), 500
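

# Illustrative client call (not part of the app itself); assumes the server is
# reachable at http://localhost:7860 and the /synthesize route shown above:
#
#   import requests
#   with open('ref.wav', 'rb') as f:
#       resp = requests.post(
#           'http://localhost:7860/synthesize',
#           files={'ref_audio': ('ref.wav', f, 'audio/wav')},
#           data={'gen_text': 'Xin chào các bạn', 'speed': '1.0'},
#       )
#   payload = resp.json()  # 'audio' and 'spectrogram' are base64-encoded strings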


@app.route('/health')
def health():
    """Health check endpoint."""
    return jsonify({
        'status': 'healthy',
        'model': 'F5-TTS Vietnamese',
        'version': '1.0.0'
    })


@app.route('/info')
def info():
    """Return model information and limitations."""
    return jsonify({
        'model_name': 'F5-TTS Vietnamese',
        'description': 'Vietnamese text-to-speech model trained on ~1000 hours of data',
        'limitations': [
            'May not perform well with numerals, dates, or special characters',
            'Rhythm of some generated audio may be inconsistent or choppy',
            'Reference transcription uses pho-whisper-medium, which may not always recognize Vietnamese accurately',
            'Inference on overly long paragraphs may produce poor results'
        ],
        'max_words': 1000,
        'speed_range': [0.3, 2.0],
        'supported_audio_formats': ['wav', 'mp3', 'ogg', 'flac', 'm4a']
    })


if __name__ == '__main__':
    # Run the Flask app
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)
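
# Illustrative usage from a shell (port defaults to 7860 unless PORT is set):
#   $ python app.py
#   $ curl http://localhost:7860/health   # JSON status document
#   $ curl http://localhost:7860/info     # model description and limits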