import os
import io
import tempfile
import subprocess
import numpy as np
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor

# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()

# CORS config
origins = (os.environ.get('CORS_ORIGINS') or '*').split(',')
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
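# Example (hypothetical origins): CORS_ORIGINS="https://example.com,https://app.example.com"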

# ---------------------------
# Model setup
# ---------------------------
FFMPEG_BIN = os.environ.get('FFMPEG_BIN', 'ffmpeg')
MODEL_REPO = "marshal-yash/SER_wav2vec"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and feature extractor from Hugging Face
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_REPO)
fe = AutoFeatureExtractor.from_pretrained(MODEL_REPO)
model.to(device)
model.eval()
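# eval() disables dropout so inference is deterministic; the class labels
# used in /predict are read from model.config.id2label.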

# ---------------------------
# Utility: Convert audio to 16kHz mono
# ---------------------------
def to_wav16k_mono(data: bytes) -> np.ndarray:
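    """Decode arbitrary audio bytes to a mono float32 array at fe.sampling_rate."""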
    try:
        # Preferred path: pipe the bytes through ffmpeg to decode any container/codec
        p = subprocess.run(
            [FFMPEG_BIN, '-hide_banner', '-loglevel', 'error',
             '-i', 'pipe:0', '-ar', str(fe.sampling_rate), '-ac', '1',
             '-f', 'wav', 'pipe:1'],
            input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
        )
        audio, sr = sf.read(io.BytesIO(p.stdout), dtype='float32', always_2d=False)
        if sr != fe.sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
        return audio.astype(np.float32)
    except Exception:
        # Fallback: soundfile can decode common formats (WAV/FLAC/OGG) from memory
        try:
            audio, sr = sf.read(io.BytesIO(data), dtype='float32', always_2d=False)
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)
            if sr != fe.sampling_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=fe.sampling_rate)
            return audio.astype(np.float32)
        except Exception:
            # Last resort: write to a temp file and let librosa/audioread sniff the format
            with tempfile.NamedTemporaryFile(delete=True, suffix='.audio') as tmp:
                tmp.write(data)
                tmp.flush()
                y, _ = librosa.load(tmp.name, sr=fe.sampling_rate, mono=True)
                return y.astype(np.float32)

# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    return {"status": "ok"}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
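    """Run speech emotion classification on an uploaded audio file."""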
    try:
        # Read audio file
        data = await file.read()
        audio = to_wav16k_mono(data)

        # Extract features
        inputs = fe(audio, sampling_rate=fe.sampling_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass
        with torch.no_grad():
            logits = model(**inputs).logits

        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
        # id2label keys may be ints or strings depending on how the config was loaded
        label_map = model.config.id2label
        labels = [label_map.get(i, label_map.get(str(i), f"class_{i}")) for i in range(len(probs))]

        pairs = sorted(
            [(labels[i], float(probs[i])) for i in range(len(probs))],
            key=lambda x: x[1],
            reverse=True
        )

        dominant = {"label": pairs[0][0], "score": pairs[0][1]} if pairs else {"label": "", "score": 0.0}

        return {
            "results": [{"label": l, "score": s} for l, s in pairs],
            "dominant": dominant
        }

    except Exception as e:
        return JSONResponse(
            status_code=400,
            content={"error": "failed to process audio", "message": f"{e.__class__.__name__}: {e}"}
        )
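
# ---------------------------
# Local usage (a sketch)
# ---------------------------
# A minimal way to run and exercise this app, assuming uvicorn is installed
# and this file is saved as app.py (hypothetical names):
#
#     uvicorn app:app --host 0.0.0.0 --port 8000
#     curl -X POST -F "file=@sample.wav" http://localhost:8000/predict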