import io
import os
import subprocess
import tempfile

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()

# CORS config: comma-separated origins from the environment, "*" by default.
_cors_env = os.environ.get("CORS_ORIGINS")
origins = _cors_env.split(",") if _cors_env else ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------
# Model setup
# ---------------------------
FFMPEG_BIN = os.environ.get("FFMPEG_BIN", "ffmpeg")
MODEL_REPO = "marshal-yash/SER_wav2vec"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and feature extractor from Hugging Face (network I/O at import).
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
fe = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
model.to(device)
model.eval()


# ---------------------------
# Utility: Convert audio to 16kHz mono
# ---------------------------
def to_wav16k_mono(data: bytes) -> np.ndarray:
    """Decode arbitrary audio bytes to mono float32 at the model's sample rate.

    Decoding is best-effort, in order of preference:
      1. ffmpeg (widest container/codec support; resamples and downmixes in
         one pass via ``-ar``/``-ac``),
      2. soundfile on the raw bytes,
      3. librosa via a temporary file.

    Args:
        data: Raw bytes of the uploaded audio file, any container/codec.

    Returns:
        1-D float32 numpy array at ``fe.sampling_rate``.
    """
    try:
        # Use ffmpeg if available; check=True raises on nonzero exit so we
        # fall through to the pure-Python decoders.
        p = subprocess.run(
            [
                FFMPEG_BIN, "-hide_banner", "-loglevel", "error",
                "-i", "pipe:0",
                "-ar", str(fe.sampling_rate), "-ac", "1",
                "-f", "wav", "pipe:1",
            ],
            input=data,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
        audio, sr = sf.read(io.BytesIO(p.stdout), dtype="float32", always_2d=False)
        # ffmpeg was asked for fe.sampling_rate already; this is a safety net.
        if sr != fe.sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
        return audio.astype(np.float32)
    except Exception:
        # Fallback: try reading the original bytes directly with soundfile.
        try:
            audio, sr = sf.read(io.BytesIO(data), dtype="float32", always_2d=False)
            if audio.ndim > 1:
                # Downmix multi-channel to mono.
                audio = np.mean(audio, axis=1)
            if sr != fe.sampling_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
            return audio.astype(np.float32)
        except Exception:
            # Last fallback: let librosa's loader sniff the format from disk.
            with tempfile.NamedTemporaryFile(delete=True, suffix=".audio") as tmp:
                tmp.write(data)
                tmp.flush()
                y, _ = librosa.load(tmp.name, sr=fe.sampling_rate, mono=True)
                return y.astype(np.float32)


# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the emotion of an uploaded audio file.

    Returns a JSON body with all class scores sorted descending
    (``results``) and the top-scoring class (``dominant``). Any decoding
    or inference failure yields a 400 with the exception summary.
    """
    try:
        # Read and normalize the uploaded audio.
        data = await file.read()
        audio = to_wav16k_mono(data)

        # Extract features and move them to the model's device.
        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass (no gradients needed for inference).
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

        label_map = model.config.id2label
        # BUG FIX: transformers normalizes id2label keys to int after
        # from_pretrained, so the original str(i)-only lookup always missed
        # and every class was reported as "class_{i}". Try the int key
        # first, then str (raw-JSON configs), then the placeholder.
        labels = [
            label_map.get(i, label_map.get(str(i), f"class_{i}"))
            for i in range(len(probs))
        ]
        pairs = sorted(
            [(labels[i], float(probs[i])) for i in range(len(probs))],
            key=lambda x: x[1],
            reverse=True,
        )
        dominant = (
            {"label": pairs[0][0], "score": pairs[0][1]}
            if pairs
            else {"label": "", "score": 0.0}
        )
        return {
            "results": [{"label": l, "score": s} for l, s in pairs],
            "dominant": dominant,
        }
    except Exception as e:
        # Top-level boundary: surface the failure class and message to the client.
        return JSONResponse(
            status_code=400,
            content={
                "error": "failed to process audio",
                "message": f"{e.__class__.__name__}: {e}",
            },
        )