# SER_wav2vec / server.py
# Source: Hugging Face repo "marshal-yash/SER_wav2vec" ("Update server.py", commit 6317c91, verified)
import os
import io
import tempfile
import subprocess
import numpy as np
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()
# CORS config
# CORS_ORIGINS may hold a comma-separated allow-list of origins; unset OR set
# to an empty string falls back to allowing every origin ('*'). The double
# get() is deliberate: without the ternary, CORS_ORIGINS="" would split to
# [''] instead of ['*'].
origins = os.environ.get('CORS_ORIGINS', '*').split(',') if os.environ.get('CORS_ORIGINS') else ['*']
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------
# Model setup
# ---------------------------
# Path to the ffmpeg binary used for audio transcoding (overridable via env
# so containerized deployments can point at a custom build).
FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
# Hugging Face Hub repo holding the fine-tuned classification model.
MODEL_REPO = "marshal-yash/SER_wav2vec"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and feature extractor from Hugging Face.
# NOTE: this downloads/loads weights at import time — a module-level side
# effect; the server will not start until the model is available.
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
fe = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
model.to(device)
model.eval()
# ---------------------------
# Utility: Convert audio to 16kHz mono
# ---------------------------
def to_wav16k_mono(data: bytes) -> np.ndarray:
    """Decode arbitrary audio bytes to a mono float32 waveform at the
    feature extractor's sampling rate (``fe.sampling_rate``).

    Tries three strategies, each falling back to the next on any failure:
      1. ffmpeg subprocess (widest codec/container support),
      2. soundfile reading the raw bytes directly,
      3. librosa via a temporary file (last resort).

    Parameters
    ----------
    data : bytes
        Raw contents of the uploaded audio file, in any common format.

    Returns
    -------
    np.ndarray
        1-D float32 array resampled to ``fe.sampling_rate``.
    """
    try:
        # Preferred path: ffmpeg decodes, resamples (-ar) and downmixes
        # to mono (-ac 1) in a single pass over stdin/stdout pipes.
        p = subprocess.run(
            [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error',
             '-i', 'pipe:0', '-ar', str(fe.sampling_rate), '-ac', '1',
             '-f', 'wav', 'pipe:1'],
            input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
        )
        audio, sr = sf.read(io.BytesIO(p.stdout), dtype='float32', always_2d=False)
        # Defensive guards: ffmpeg was asked for mono @ target rate, but if
        # the output differs, normalize it here just like the fallback branch
        # does (the original only downmixed in the fallback — inconsistent).
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        if sr != fe.sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
        return audio.astype(np.float32)
    except Exception:
        # ffmpeg missing or failed to decode: try reading the bytes directly.
        try:
            audio, sr = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
            if audio.ndim > 1:
                # Downmix multi-channel audio to mono by averaging channels.
                audio = np.mean(audio, axis=1)
            if sr != fe.sampling_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
            return audio.astype(np.float32)
        except Exception:
            # Last fallback: librosa can sniff formats soundfile cannot, but
            # it wants a real file path — hence the temporary file.
            with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
                tmp.write(data)
                tmp.flush()
                y, _ = librosa.load(tmp.name, sr=fe.sampling_rate, mono=True)
                return y.astype(np.float32)
# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Health-check endpoint: reports that the service is up."""
    payload = {"status": "ok"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the emotion in an uploaded audio file.

    Accepts any audio format decodable by ffmpeg/soundfile/librosa, runs the
    wav2vec2 sequence classifier, and returns all class probabilities sorted
    descending plus the dominant label. Any failure (bad upload, decode
    error, inference error) is reported as HTTP 400 with the exception
    class and message.
    """
    try:
        # Read the whole upload into memory, then normalize to the model's
        # expected mono float32 waveform at fe.sampling_rate.
        data = await file.read()
        audio = to_wav16k_mono(data)
        # Feature extraction -> tensors moved to the model's device.
        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Forward pass (no gradients needed for inference).
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
        # BUG FIX: transformers loads config.id2label with *int* keys, so the
        # original lookup via str(i) always missed and every label degraded
        # to the "class_i" fallback. Try the int key first, then the str key
        # (covers raw JSON-style dicts), then fall back.
        label_map = model.config.id2label
        labels = [
            label_map.get(i, label_map.get(str(i), f"class_{i}"))
            for i in range(len(probs))
        ]
        pairs = sorted(
            [(labels[i], float(probs[i])) for i in range(len(probs))],
            key=lambda x: x[1],
            reverse=True
        )
        dominant = {"label": pairs[0][0], "score": pairs[0][1]} if pairs else {"label": "", "score": 0.0}
        return {
            "results": [{"label": l, "score": s} for l, s in pairs],
            "dominant": dominant
        }
    except Exception as e:
        # Top-level boundary handler: surface the failure to the client as a
        # 400 instead of an opaque 500.
        return JSONResponse(
            status_code=400,
            content={"error": "failed to process audio", "message": f"{e.__class__.__name__}: {e}"}
        )