| import timm | |
| import json | |
| import torch | |
| from torchaudio.functional import resample | |
| import numpy as np | |
| from torchaudio.compliance import kaldi | |
| import torch.nn.functional as F | |
| import requests | |
# Hugging Face Hub identifier of the timm checkpoint loaded below.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Created once at import time (pretrained weights requested); .eval() puts the
# module in evaluation mode.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# Bound the request so an unresponsive host cannot hang the import forever, and
# fail loudly on an HTTP error instead of feeding an error page to json.loads.
_label_response = requests.get(LABEL_URL, timeout=30)
_label_response.raise_for_status()
# Label names in the order they appear in the JSON object; predict_class()
# indexes this list by class index, so it relies on the file being ordered
# by class id.
AUDIOSET_LABELS = list(json.loads(_label_response.content).values())

# Not referenced in this part of the file. NOTE(review): presumably the rate
# callers must resample to before calling predict_class() -- confirm.
SAMPLING_RATE = 16_000
# Offset and scale applied to the filterbank features in preprocess().
# NOTE(review): presumably dataset-level feature statistics from the
# checkpoint's training recipe -- confirm before changing.
MEAN = -4.2677393
STD = 4.5689974
def preprocess(x: torch.Tensor) -> torch.Tensor:
    """Turn a waveform into a fixed-size, normalized filterbank feature matrix.

    Args:
        x: Waveform samples. Expected to be 1-D: a leading channel axis is
            added before the samples are handed to ``kaldi.fbank``.
            No sample rate is passed to ``kaldi.fbank``, so the library
            default applies (NOTE(review): 16 kHz per the torchaudio docs --
            confirm callers resample accordingly).

    Returns:
        A ``(1024, 128)`` tensor: 128 mel bins per frame, truncated or
        zero-padded along the frame axis to exactly 1024 frames, then
        shifted by ``MEAN`` and divided by ``2 * STD``.
    """
    n_frames, n_mels = 1024, 128

    # Remove the DC offset of the whole clip before feature extraction.
    centered = x - x.mean()
    features = kaldi.fbank(
        centered.unsqueeze(0),
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=n_mels,
    )

    # Force the frame axis to exactly n_frames: drop anything past the limit,
    # then append all-zero frames if the clip was too short.
    features = features[:n_frames]
    missing = n_frames - features.shape[0]
    if missing > 0:
        features = F.pad(features, (0, 0, 0, missing))

    # Normalization happens after padding, so padded frames are normalized
    # too (they end up at -MEAN / (2 * STD), not at zero).
    return (features - MEAN) / (STD * 2)
def predict_class(x: np.ndarray):
    """Classify an audio clip and return its ten highest-scoring labels.

    Args:
        x: Audio samples. Either 1-D, or multi-dimensional with the axis to
            average away last (NOTE(review): presumably ``(samples,
            channels)`` -- confirm against the caller). Floating-point input
            is used as-is after a cast to float32; integer PCM input is
            scaled by the dtype's maximum value so it lands in roughly
            ``[-1, 1]``.

    Returns:
        A list of ten ``[label, score_percent]`` pairs in descending score
        order, where ``score_percent`` is the sigmoid of the class logit
        multiplied by 100 (scores are independent per class and do not sum
        to 100).

    Raises:
        AssertionError: If the input is not 1-D after the last axis has been
            averaged out.
    """
    samples = np.asarray(x)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer tensors cannot be averaged by torch (``mean`` raises), and
        # raw PCM magnitudes are on a different scale than float audio.
        full_scale = float(np.iinfo(samples.dtype).max)
        samples = samples.astype(np.float32) / full_scale
    # Always hand float32 to the model: NumPy's default float64 would
    # otherwise propagate through the features and mismatch the weights.
    waveform = torch.as_tensor(samples, dtype=torch.float32)

    if waveform.ndim > 1:
        # Down-mix by averaging over the last axis.
        waveform = waveform.mean(-1)
    assert waveform.ndim == 1, (
        f"expected 1-D audio after down-mixing, got {waveform.ndim} dimensions"
    )

    features = preprocess(waveform)
    with torch.inference_mode():
        # Add batch and channel axes in front of the (1024, 128) features.
        logits = MODEL(features.view(1, 1, 1024, 128)).squeeze(0)
        topk_probs, topk_classes = logits.sigmoid().topk(10)

    # Convert to plain Python ints/floats before indexing and scaling.
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]