|
|
from typing import Dict, List, Any |
|
|
from kokoro import KPipeline |
|
|
from IPython.display import display, Audio |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import io |
|
|
import os |
|
|
import base64 |
|
|
|
|
|
class EndpointHandler():
    """Custom Hugging Face Inference Endpoints handler for Kokoro TTS.

    Accepts a request payload of the form
    ``{"inputs": {"text": "...", "voice": "..."}}``, synthesizes speech with
    the Kokoro pipeline, and returns an HTTP-style response dict whose body
    is the base64-encoded WAV audio.
    """

    # Kokoro emits audio at 24 kHz; single source of truth for both the WAV
    # header and the duration calculation.
    SAMPLE_RATE = 24000

    def __init__(self, model_dir: str):
        # lang_code='a' selects the American-English voice pack.
        # NOTE(review): model_dir is accepted but not forwarded to KPipeline —
        # the pipeline resolves its own weights; confirm this is intentional.
        self.pipeline = KPipeline(lang_code='a')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for the request payload.

        Args:
            data: Request dict; ``data["inputs"]["text"]`` is the text to
                speak and ``data["inputs"]["voice"]`` the Kokoro voice id
                (passed through to the pipeline as-is).

        Returns:
            An HTTP-style response dict. On success: status 200 with a
            base64-encoded WAV body, audio headers, and the clip length in
            seconds. On missing text: status 400. If synthesis yields no
            audio: status 500.
        """
        inputs = data.get("inputs", {})
        text = inputs.get("text")
        voice = inputs.get("voice")

        # Fail fast with a client error instead of crashing inside the
        # pipeline on a None/empty prompt.
        if not text:
            return {
                "statusCode": 400,
                "body": "Missing required field: inputs.text",
                "isBase64Encoded": False,
            }

        # The pipeline yields (graphemes, phonemes, audio) per chunk; only
        # the audio is needed here.
        audio_segments = []
        for _gs, _ps, audio in self.pipeline(text, voice):
            audio_segments.append(audio)

        # torch.cat on an empty list raises; report a server error instead.
        if not audio_segments:
            return {
                "statusCode": 500,
                "body": "Speech synthesis produced no audio",
                "isBase64Encoded": False,
            }

        # as_tensor avoids copying segments that are already tensors.
        full_audio = torch.cat([torch.as_tensor(a) for a in audio_segments])
        audio_length_seconds = len(full_audio) / self.SAMPLE_RATE

        # Serialize to WAV entirely in memory, then base64-encode for the
        # JSON response body.
        buffer = io.BytesIO()
        sf.write(buffer, full_audio.numpy(), self.SAMPLE_RATE, format='WAV')
        audio_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return {
            "headers": {
                "Content-Disposition": "attachment; filename=output.wav",
                "Content-Type": "audio/wav",
            },
            "body": audio_b64,
            "statusCode": 200,
            "isBase64Encoded": True,
            "audio_length_seconds": audio_length_seconds,
        }