HeWhoComes's picture
Update app.py
948d66a verified
#!/usr/bin/env python3
"""
ZW Kitten - Production FastAPI interface for KittenTTS
"""
import os
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
import uvicorn
import numpy as np
import wave
# Import KittenTTS directly without demo fallback
from get_model import KittenTTS
from onnx_model import KittenTTS_1_Onnx
# Fixed version without arbitrary audio trimming
class FixedKittenTTS_1_Onnx(KittenTTS_1_Onnx):
    """ONNX KittenTTS variant whose ``generate`` returns untrimmed audio.

    Only ``generate`` is overridden; input preparation and the ONNX session
    come from the parent class.
    """

    def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
        """Run the model and return its first output without slicing it.

        Args:
            text: Text to synthesize.
            voice: Voice identifier forwarded to ``_prepare_inputs``.
            speed: Speed value forwarded to ``_prepare_inputs``.

        Returns:
            The first element of the ONNX session outputs, returned as-is
            (no samples are cut from either end).
        """
        feed = self._prepare_inputs(text, voice, speed)
        # Passing None as the output list asks the session for every output;
        # the audio is taken from position 0 and handed back in full.
        return self.session.run(None, feed)[0]
class FixedKittenTTS:
    """Wrapper around ``KittenTTS`` that swaps in the non-truncating ``generate``."""

    def __init__(self):
        """Create the underlying ``KittenTTS`` and patch its model in place."""
        self.kitten = KittenTTS()
        # Bind FixedKittenTTS_1_Onnx.generate to the already-loaded model
        # instance, so the patched method is used instead of the model's own.
        patched_generate = FixedKittenTTS_1_Onnx.generate.__get__(
            self.kitten.model, KittenTTS_1_Onnx
        )
        self.kitten.model.generate = patched_generate

    def generate(self, text, voice="expr-voice-2-m", speed=1.0):
        """Synthesize ``text`` by delegating to the wrapped ``KittenTTS``.

        Args:
            text: Text to synthesize.
            voice: Voice identifier passed through as a keyword argument.
            speed: Speed value passed through as a keyword argument.

        Returns:
            Whatever ``self.kitten.generate`` returns.
        """
        result = self.kitten.generate(text, voice=voice, speed=speed)
        return result
# FastAPI application object; the route handlers in this file register on it.
app = FastAPI(title="ZW Kitten TTS")

# Initialize TTS once at import time; the /generate handler uses this instance.
tts = FixedKittenTTS()
# The leading check mark was mis-encoded in the original message; restored here.
print("✅ Fixed KittenTTS initialized (no truncation bug)")
# Character name -> KittenTTS voice id. Unknown characters use _DEFAULT_VOICE.
_VOICE_MAP = {
    'claude': 'expr-voice-2-m',
    'keen': 'expr-voice-2-f',
    'tran': 'expr-voice-3-m',
    'isla': 'expr-voice-4-f',
    'system': 'expr-voice-5-m',
    'narrator': 'expr-voice-5-f',
}
_DEFAULT_VOICE = 'expr-voice-2-m'

# Emotion name -> speed value passed to the TTS. Unknown emotions use 1.0.
_SPEED_MAP = {
    'neutral': 1.0,
    'cosmic_awareness': 0.8,
    'determined': 1.1,
    'whisper': 0.9,
    'urgent': 1.2,
    'calm': 0.9,
}

# Sample rate written into the WAV header (use your model's sr if different).
_SAMPLE_RATE_HZ = 24000


def _audio_to_pcm16(audio):
    """Convert model output to mono 16-bit PCM samples.

    The input is cast to float32, reduced to one channel, clipped to
    [-1.0, 1.0] and scaled to the int16 range.
    """
    samples = np.asarray(audio, dtype=np.float32)
    if samples.ndim > 1:
        # Drop singleton axes first so both [N, 1] and [1, N] yield N samples
        # (indexing [:, 0] directly would reduce [1, N] to a single sample).
        samples = np.squeeze(samples)
        if samples.ndim > 1:
            # Genuinely multi-channel: keep the first column, as before.
            samples = samples[:, 0]
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * 32767.0).astype(np.int16)


@app.post("/generate")
async def generate_speech(request: Request):
    """Generate speech from text and return a link to the resulting WAV file.

    Expects a JSON object with:
        text: Required non-empty string to synthesize.
        character: Optional string selecting a voice (default ``'claude'``).
        emotion: Optional string selecting a speed (default ``'neutral'``).

    Returns:
        200 with ``success``, ``audio_url``, ``zw_block`` and ``info`` on
        success; 400 with ``error`` when the body is not a JSON object, a
        field is not a string, or ``text`` is empty; 500 with ``error`` when
        synthesis or writing the file fails.

    The WAV file is created in the default temporary directory and is not
    deleted here, because it is fetched afterwards through ``/audio/{filename}``.
    """
    # Malformed JSON is a client error, not a server crash.
    try:
        data = await request.json()
    except ValueError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON body must be an object"}, status_code=400)

    text = data.get('text', '')
    character = data.get('character', 'claude')
    emotion = data.get('emotion', 'neutral')
    # The fields are used with str methods and as dict keys below, so reject
    # other JSON types up front instead of failing later.
    if not all(isinstance(value, str) for value in (text, character, emotion)):
        return JSONResponse(
            {"error": "'text', 'character' and 'emotion' must be strings"},
            status_code=400,
        )

    text = text.strip()
    if not text:
        return JSONResponse({"error": "No text provided"}, status_code=400)

    voice = _VOICE_MAP.get(character, _DEFAULT_VOICE)
    speed = _SPEED_MAP.get(emotion, 1.0)

    try:
        # Generate real audio with KittenTTS.
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked while it runs — confirm that is acceptable.
        audio = tts.generate(text, voice=voice, speed=speed)
        pcm = _audio_to_pcm16(audio)

        # Reserve a unique file name, then close the descriptor immediately so
        # no handle is leaked; the file is reopened by path for writing.
        fd, wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            with wave.open(wav_path, "wb") as w:
                w.setnchannels(1)  # mono
                w.setsampwidth(2)  # 16-bit
                w.setframerate(_SAMPLE_RATE_HZ)
                w.writeframes(pcm.tobytes())
        except Exception:
            # Do not leave a partial or empty file behind on failure.
            os.remove(wav_path)
            raise
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

    zw_block = (
        "!zw/dialogue.intent:\n"
        f"character: {character}\n"
        f"emotion: {emotion}\n"
        f"line: {text}"
    )
    pretty_emotion = emotion.replace('_', ' ').title()
    # The emoji below were mis-encoded in the original literal; restored here.
    info = "\n".join([
        f"🎭 Character: {character.title()}",
        f"😊 Emotion: {pretty_emotion} (speed: {speed:.1f}x)",
        f"🎵 Voice: {voice}",
        f'📝 Text: "{text}"',
        "✅ Audio generated successfully!",
    ])

    return JSONResponse({
        "success": True,
        "audio_url": f"/audio/{os.path.basename(wav_path)}",
        "zw_block": zw_block,
        "info": info,
    })
@app.get("/audio/{filename}")
async def serve_audio(filename: str):
    """Serve a generated WAV file from the temporary directory.

    Args:
        filename: Bare file name taken from the URL path.

    Returns:
        The file as ``audio/wav`` when it is an existing regular ``.wav`` file
        directly inside the temporary directory; otherwise a 404 JSON error.
    """
    not_found = JSONResponse({"error": "File not found"}, status_code=404)

    # Only bare "*.wav" names are accepted: this rejects "..", names with
    # path separators, and unrelated files that happen to be in the temp dir.
    if os.path.basename(filename) != filename or not filename.endswith(".wav"):
        return not_found

    file_path = os.path.join(tempfile.gettempdir(), filename)
    # isfile (not exists) so a directory is never handed to FileResponse.
    if not os.path.isfile(file_path):
        return not_found
    return FileResponse(file_path, media_type="audio/wav")
@app.get("/", response_class=HTMLResponse)
def serve_interface():
    """Serve the web interface.

    Returns:
        The HTML page as a string. The CSS, body structure and JavaScript are
        placeholders in this version of the file.
    """
    # The emoji in <title>, <h1> and the repository link were mis-encoded in
    # the original literal; they are restored to the intended characters.
    return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>🐱 ZW Kitten TTS</title>
<style>
/* ... (keep all your existing CSS styles) ... */
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🐱 ZW Kitten</h1>
<div class="subtitle">Text-to-Speech with Character Personalities</div>
</div>
<!-- ... (keep all your existing HTML structure) ... -->
<div class="repo-link">
<a href="https://github.com/SmokesBowls/zw-kitten-tts" target="_blank">
📚 GitHub Repository
</a>
</div>
</div>
<script>
// ... (keep all your existing JavaScript) ...
</script>
</body>
</html>
"""
@app.get("/health")
def health_check():
    """Liveness probe: always reports a healthy status."""
    status_payload = {"status": "healthy"}
    return status_payload
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
    )