Spaces:
Sleeping
Sleeping
File size: 3,979 Bytes
6317c91 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 5159013 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 5159013 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 4fcb97a 6317c91 5159013 6317c91 4fcb97a 6317c91 4fcb97a 5159013 6317c91 4fcb97a 5159013 6317c91 4fcb97a 5159013 4fcb97a 5159013 6317c91 5159013 6317c91 5159013 4fcb97a 5159013 6317c91 5159013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
import io
import tempfile
import subprocess
import numpy as np
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()

# CORS config: CORS_ORIGINS is a comma-separated list of allowed origins;
# unset or empty allows any origin.
# Fix: the original read the env var twice and passed a dead '*' default to
# the first lookup (the conditional's else branch already supplied ['*']).
_cors_env = os.environ.get('CORS_ORIGINS')
origins = _cors_env.split(',') if _cors_env else ['*']
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------
# Model setup
# ---------------------------
# ffmpeg binary path (overridable via env) and the Hub repo holding the
# fine-tuned speech-emotion classifier.
FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
MODEL_REPO = "marshal-yash/SER_wav2vec"

# Prefer a GPU whenever torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pull the matching feature extractor and classifier from the Hub, move the
# model to the chosen device, and switch it to inference mode.
fe = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO).to(device)
model.eval()
# ---------------------------
# Utility: Convert audio to 16kHz mono
# ---------------------------
def to_wav16k_mono(data: bytes) -> np.ndarray:
    """Decode raw audio bytes to a mono float32 signal at the feature
    extractor's sampling rate.

    Three decoders are tried in order of preference:
      1. ffmpeg (widest container/codec support; resamples + downmixes),
      2. soundfile reading the raw bytes from memory,
      3. librosa loading from a temporary file as a last resort.
    """
    target_sr = fe.sampling_rate
    try:
        # ffmpeg does the resample (-ar) and mono downmix (-ac 1) in one pass.
        proc = subprocess.run(
            [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error',
             '-i', 'pipe:0', '-ar', str(target_sr), '-ac', '1',
             '-f', 'wav', 'pipe:1'],
            input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
        )
        samples, rate = sf.read(io.BytesIO(proc.stdout), dtype='float32', always_2d=False)
        if rate != target_sr:
            samples = librosa.resample(samples, orig_sr=rate, target_sr=target_sr)
        return samples.astype(np.float32)
    except Exception:
        pass  # ffmpeg missing or failed to decode; fall through to soundfile
    try:
        # soundfile handles common formats (wav/flac/ogg) straight from memory.
        samples, rate = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
        if samples.ndim > 1:
            samples = np.mean(samples, axis=1)  # downmix channels to mono
        if rate != target_sr:
            samples = librosa.resample(samples, orig_sr=rate, target_sr=target_sr)
        return samples.astype(np.float32)
    except Exception:
        # Last resort: let librosa's backends decode from a real file path.
        # NOTE(review): NamedTemporaryFile re-opened by name may fail on
        # Windows while still open — confirm deployment target is POSIX.
        with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
            tmp.write(data)
            tmp.flush()
            y, _ = librosa.load(tmp.name, sr=target_sr, mono=True)
            return y.astype(np.float32)
# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up."""
    return {"status": "ok"}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the emotion in an uploaded audio file.

    Returns every class label with its softmax probability, sorted
    descending, plus the dominant (top-1) prediction. Any failure —
    undecodable audio, model error — is reported as a 400 JSON payload
    rather than a 500 traceback.
    """
    try:
        # Decode the upload to mono float32 at the model's sampling rate.
        data = await file.read()
        audio = to_wav16k_mono(data)
        # Feature-extract and move tensors to the model's device.
        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
        # BUG FIX: transformers normalizes config.id2label to *int* keys on
        # load, so the original str(i) lookup always missed and every label
        # degraded to "class_<i>". Try the int key first, then the string
        # key, before falling back to the synthetic name.
        label_map = model.config.id2label
        labels = [
            label_map.get(i, label_map.get(str(i), f"class_{i}"))
            for i in range(len(probs))
        ]
        pairs = sorted(
            [(labels[i], float(probs[i])) for i in range(len(probs))],
            key=lambda x: x[1],
            reverse=True
        )
        dominant = {"label": pairs[0][0], "score": pairs[0][1]} if pairs else {"label": "", "score": 0.0}
        return {
            "results": [{"label": l, "score": s} for l, s in pairs],
            "dominant": dominant
        }
    except Exception as e:
        # Boundary handler: report the failure class and message to the client.
        return JSONResponse(
            status_code=400,
            content={"error": "failed to process audio", "message": f"{e.__class__.__name__}: {e}"}
        )
|