File size: 6,381 Bytes
258b448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee1f900
258b448
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import base64
import tempfile
from flask import Flask, render_template, request, jsonify, send_file
from werkzeug.utils import secure_filename
from cached_path import cached_path
from vinorm import TTSnorm
from huggingface_hub import login
import numpy as np
import soundfile as sf

from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    preprocess_ref_audio_text,
    load_vocoder,
    load_model,
    infer_process,
    save_spectrogram,
)

app = Flask(__name__)
app.config.update(
    MAX_CONTENT_LENGTH=50 * 1024 * 1024,  # 50MB max file size
    UPLOAD_FOLDER=tempfile.gettempdir(),
    ALLOWED_EXTENSIONS={'wav', 'mp3', 'ogg', 'flac', 'm4a'},
)

# Hugging Face token is read from the environment (secrets); login is
# skipped entirely when the variable is unset or empty.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)

def post_process(text):
    """Tidy normalized text: merge doubled periods/commas, drop double quotes, squeeze whitespace.

    Each replacement is applied once, in order, to the text padded with a
    single space on both sides so that patterns at the very start or end of
    the string can still match. The result has no leading/trailing
    whitespace and single spaces between tokens.
    """
    substitutions = (
        (" . . ", " . "),
        (" .. ", " . "),
        (" , , ", " , "),
        (" ,, ", " , "),
        ('"', ""),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = (" " + cleaned + " ").replace(pattern, replacement)
    return " ".join(cleaned.split())

def allowed_file(filename):
    """Return True when the text after the last '.' in filename (case-insensitive) is an allowed extension."""
    _, dot, extension = filename.rpartition('.')
    # An empty separator means the name contains no '.' at all.
    return bool(dot) and extension.lower() in app.config['ALLOWED_EXTENSIONS']

# Load the vocoder and the TTS model a single time at import so every
# request reuses the same instances.
print("Loading models...")
vocoder = load_vocoder()
_dit_config = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
_checkpoint_path = cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/model_last.pt")
_vocab_path = cached_path("hf://hynt/F5-TTS-Vietnamese-ViVoice/config.json")
model = load_model(
    DiT,
    _dit_config,
    ckpt_path=str(_checkpoint_path),
    vocab_file=str(_vocab_path),
)
print("Models loaded successfully!")

@app.route('/')
def index():
    """Serve the landing page rendered from the index.html template."""
    page = render_template('index.html')
    return page

def _read_base64(path):
    """Return the contents of the file at ``path`` as a base64-encoded string."""
    with open(path, 'rb') as handle:
        return base64.b64encode(handle.read()).decode('utf-8')


@app.route('/api/synthesize', methods=['POST'])
def synthesize():
    """
    API endpoint for text-to-speech synthesis.

    Parameters (multipart/form-data):
    - ref_audio: reference audio file (wav, mp3, ogg, flac or m4a)
    - gen_text: text to synthesize (string, at most 1000 whitespace-separated words)
    - speed: synthesis speed (float between 0.3 and 2.0, default: 1.0)

    Returns:
    - 200 with JSON containing the base64-encoded audio, the base64-encoded
      spectrogram image and the sample rate
    - 400 with JSON {'error': ...} when the request fails validation
    - 500 with JSON {'error': ...} when anything else goes wrong
    """
    try:
        # Validate request
        if 'ref_audio' not in request.files:
            return jsonify({'error': 'No audio file provided'}), 400

        upload = request.files['ref_audio']
        if upload.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        if not allowed_file(upload.filename):
            return jsonify({'error': 'Invalid file format. Allowed: wav, mp3, ogg, flac, m4a'}), 400

        gen_text = request.form.get('gen_text', '').strip()
        if not gen_text:
            return jsonify({'error': 'No text provided'}), 400

        if len(gen_text.split()) > 1000:
            return jsonify({'error': 'Text too long. Maximum 1000 words'}), 400

        # A non-numeric speed is a client error, not a server failure.
        try:
            speed = float(request.form.get('speed', 1.0))
        except (TypeError, ValueError):
            return jsonify({'error': 'Speed must be between 0.3 and 2.0'}), 400
        # Chained comparison so that NaN (which fails every comparison) is rejected too.
        if not 0.3 <= speed <= 2.0:
            return jsonify({'error': 'Speed must be between 0.3 and 2.0'}), 400

        # allowed_file() has already confirmed a '.' exists and the extension
        # is one of the allowed values, so it is safe to use in a path. The
        # client-supplied base name is deliberately not used: it may collide
        # between requests or sanitize down to a name without an extension.
        extension = upload.filename.rsplit('.', 1)[1].lower()

        # Every request works in its own private directory so concurrent
        # requests cannot overwrite each other's files; the directory and
        # everything in it is removed on exit, including on errors.
        with tempfile.TemporaryDirectory(dir=app.config['UPLOAD_FOLDER']) as workdir:
            filepath = os.path.join(workdir, f'ref_audio.{extension}')
            upload.save(filepath)

            ref_audio = None
            try:
                # Process audio
                ref_audio, ref_text = preprocess_ref_audio_text(filepath, "")

                # Generate speech
                final_wave, final_sample_rate, spectrogram = infer_process(
                    ref_audio,
                    ref_text.lower(),
                    post_process(TTSnorm(gen_text)).lower(),
                    model,
                    vocoder,
                    speed=speed
                )

                # Save audio to a temporary file and convert it to base64
                audio_path = os.path.join(workdir, 'output.wav')
                sf.write(audio_path, final_wave, final_sample_rate)
                audio_base64 = _read_base64(audio_path)

                # Save spectrogram and convert it to base64
                spec_path = os.path.join(workdir, 'spectrogram.png')
                save_spectrogram(spectrogram, spec_path)
                spec_base64 = _read_base64(spec_path)
            finally:
                # The preprocessed reference audio may live outside workdir,
                # so it is removed explicitly, on failure as well as success.
                if ref_audio and os.path.exists(ref_audio):
                    os.remove(ref_audio)

        return jsonify({
            'success': True,
            'audio': audio_base64,
            'spectrogram': spec_base64,
            'sample_rate': final_sample_rate,
            'message': 'Speech synthesized successfully'
        })

    except Exception as e:
        # Top-level boundary: record the traceback server-side before
        # reporting the failure to the client.
        app.logger.exception('Error generating speech')
        return jsonify({'error': f'Error generating speech: {str(e)}'}), 500

@app.route('/api/health', methods=['GET'])
def health():
    """Report a static service status payload as JSON."""
    status_payload = {
        'status': 'healthy',
        'model': 'F5-TTS Vietnamese',
        'version': '1.0.0',
    }
    return jsonify(status_payload)

@app.route('/api/info', methods=['GET'])
def info():
    """Describe the model, its known limitations and the request limits as JSON."""
    known_limitations = [
        'May not perform well with numerical characters, dates, special characters',
        'Rhythm of some generated audios may be inconsistent or choppy',
        'Reference audio text uses pho-whisper-medium which may not always accurately recognize Vietnamese',
        'Inference with overly long paragraphs may produce poor results',
    ]
    details = {
        'model_name': 'F5-TTS Vietnamese',
        'description': 'Vietnamese Text-to-Speech synthesis model trained on ~1000 hours of data',
        'limitations': known_limitations,
        'max_words': 1000,
        'speed_range': [0.3, 2.0],
        'supported_audio_formats': ['wav', 'mp3', 'ogg', 'flac', 'm4a'],
    }
    return jsonify(details)

if __name__ == '__main__':
    # Listen on all interfaces; the port is taken from $PORT and falls back to 7860.
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)