# text2speech_docker / flask_app.py
# manhteky123's picture
# Upload flask_app.py
# ee1f900 verified
import os
import base64
import tempfile
from flask import Flask, render_template, request, jsonify, send_file
from werkzeug.utils import secure_filename
from cached_path import cached_path
from vinorm import TTSnorm
from huggingface_hub import login
import numpy as np
import soundfile as sf
from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_process,
save_spectrogram,
)
# Flask application object and its upload-related settings.
app = Flask(__name__)
app.config.update(
    MAX_CONTENT_LENGTH=50 * 1024 * 1024,  # 50MB max file size
    UPLOAD_FOLDER=tempfile.gettempdir(),  # uploads are staged in the system temp dir
    ALLOWED_EXTENSIONS={'wav', 'mp3', 'ogg', 'flac', 'm4a'},
)

# Read the Hugging Face token from the environment and log in only when
# one is actually configured.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)
def post_process(text):
    """Tidy normalized text: drop doubled punctuation marks and quotes.

    Each rule is applied once, left to right, on a space-padded copy of the
    text; the result has all whitespace runs collapsed to single spaces and
    no leading/trailing whitespace.
    """
    # (pattern, replacement) pairs, applied in this exact order.
    rules = (
        (" . . ", " . "),
        (" .. ", " . "),
        (" , , ", " , "),
        (" ,, ", " , "),
        ('"', ""),
    )
    cleaned = text
    for pattern, replacement in rules:
        # Pad so that a mark at either end of the text is still
        # surrounded by spaces and can match the pattern.
        cleaned = f" {cleaned} ".replace(pattern, replacement)
    return " ".join(cleaned.split())
def allowed_file(filename):
    """Return True when the filename's extension is in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively;
    a filename without any dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in app.config['ALLOWED_EXTENSIONS']
# Load the vocoder and the TTS model a single time, at import, so that
# request handlers can reuse the module-level `vocoder` and `model`.
print("Loading models...")
vocoder = load_vocoder()
model = load_model(
    DiT,
    {
        "dim": 1024,
        "depth": 22,
        "heads": 16,
        "ff_mult": 2,
        "text_dim": 512,
        "conv_layers": 4,
    },
    ckpt_path=str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")),
    # NOTE(review): the vocab file is fetched from a file named config.json;
    # presumably that is how this repo publishes its vocabulary - confirm.
    vocab_file=str(cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")),
)
print("Models loaded successfully!")
@app.route('/')
def index():
    """Serve the main page by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page
def _file_to_base64(path):
    """Read the file at ``path`` and return its bytes as a base64 string."""
    with open(path, 'rb') as handle:
        return base64.b64encode(handle.read()).decode('utf-8')


@app.route('/api/synthesize', methods=['POST'])
def synthesize():
    """
    API endpoint for text-to-speech synthesis.

    Parameters (multipart/form-data):
    - ref_audio: reference audio file; its extension must pass allowed_file()
    - gen_text: text to synthesize (string, at most 1000 whitespace-separated words)
    - speed: synthesis speed (float between 0.3 and 2.0, default: 1.0)

    Returns:
    - 200 with JSON holding the base64-encoded audio, the base64-encoded
      spectrogram file, the sample rate and a status message
    - 400 with a JSON ``error`` when a request field is missing or invalid
    - 500 with a JSON ``error`` when anything else fails
    """
    try:
        # Validate request
        if 'ref_audio' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400
        file = request.files['ref_audio']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file format. Allowed: wav, mp3, ogg, flac, m4a'}), 400

        gen_text = request.form.get('gen_text', '').strip()
        if not gen_text:
            return jsonify({'error': 'No text provided'}), 400
        if len(gen_text.split()) > 1000:
            return jsonify({'error': 'Text too long. Maximum 1000 words'}), 400

        # A non-numeric speed is a client error, not a server failure.
        try:
            speed = float(request.form.get('speed', 1.0))
        except (TypeError, ValueError):
            return jsonify({'error': 'Speed must be between 0.3 and 2.0'}), 400
        # Written as a chained comparison so that NaN is rejected as well.
        if not 0.3 <= speed <= 2.0:
            return jsonify({'error': 'Speed must be between 0.3 and 2.0'}), 400

        # Every request works in its own scratch directory: fixed file names
        # shared by all requests would let concurrent calls overwrite or
        # delete each other's files. The directory is removed on exit, on
        # both the success and the failure path.
        with tempfile.TemporaryDirectory(dir=app.config['UPLOAD_FOLDER']) as workdir:
            # The client-supplied name is not reused; only its extension,
            # already checked by allowed_file(), is kept.
            extension = file.filename.rsplit('.', 1)[1].lower()
            filepath = os.path.join(workdir, f'reference.{extension}')
            file.save(filepath)

            ref_audio = None
            try:
                # Process audio
                ref_audio, ref_text = preprocess_ref_audio_text(filepath, "")

                # Generate speech
                final_wave, final_sample_rate, spectrogram = infer_process(
                    ref_audio,
                    ref_text.lower(),
                    post_process(TTSnorm(gen_text)).lower(),
                    model,
                    vocoder,
                    speed=speed
                )

                # Save audio, then convert it to base64
                audio_path = os.path.join(workdir, 'output.wav')
                sf.write(audio_path, final_wave, final_sample_rate)
                audio_base64 = _file_to_base64(audio_path)

                # Save spectrogram, then convert it to base64
                spec_path = os.path.join(workdir, 'spectrogram.png')
                save_spectrogram(spectrogram, spec_path)
                spec_base64 = _file_to_base64(spec_path)
            finally:
                # The preprocessed reference may live outside the scratch
                # directory, so it is removed explicitly.
                # NOTE(review): confirm preprocess_ref_audio_text does not
                # cache this path for reuse by later requests.
                if ref_audio and os.path.exists(ref_audio):
                    os.remove(ref_audio)

        return jsonify({
            'success': True,
            'audio': audio_base64,
            'spectrogram': spec_base64,
            'sample_rate': final_sample_rate,
            'message': 'Speech synthesized successfully'
        })
    except Exception as e:
        # Top-level boundary: keep the JSON error contract for the client,
        # but record the traceback on the server.
        app.logger.exception('Speech synthesis failed')
        return jsonify({'error': f'Error generating speech: {str(e)}'}), 500
@app.route('/api/health', methods=['GET'])
def health():
    """Report a static health status together with the model name and version."""
    status = {
        'status': 'healthy',
        'model': 'F5-TTS Vietnamese',
        'version': '1.0.0'
    }
    return jsonify(status)
@app.route('/api/info', methods=['GET'])
def info():
    """Describe the model: name, description, known limitations and input limits."""
    known_limitations = [
        'May not perform well with numerical characters, dates, special characters',
        'Rhythm of some generated audios may be inconsistent or choppy',
        'Reference audio text uses pho-whisper-medium which may not always accurately recognize Vietnamese',
        'Inference with overly long paragraphs may produce poor results'
    ]
    details = {
        'model_name': 'F5-TTS Vietnamese',
        'description': 'Vietnamese Text-to-Speech synthesis model trained on ~1000 hours of data',
        'limitations': known_limitations,
        'max_words': 1000,
        'speed_range': [0.3, 2.0],
        'supported_audio_formats': ['wav', 'mp3', 'ogg', 'flac', 'm4a']
    }
    return jsonify(details)
if __name__ == '__main__':
    # Start the Flask server on all interfaces; the port comes from the
    # PORT environment variable and falls back to 7860.
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)