# submission-frugalai/tasks/fast_model.py
import os
import pickle
import struct
import warnings

import lightgbm as lgb
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from datasets.formatting import query_table
from sklearn.exceptions import NotFittedError
from torchaudio.transforms import Spectrogram

warnings.filterwarnings("ignore")

SR = 12000  # Target sampling rate (Hz) used throughout the pipeline.
class FastModel:
"""
A class designed for training and predicting using LightGBM, incorporating spectral and cepstral features.
### Workflow:
1. Batch Loading and Decoding:
Load audio data in batches directly from a table and decode byte-encoded information.
2. Processing Audio:
- Resampling, Padding, or Truncating:
Adjust audio durations by padding, cutting, or resampling as needed.
- Spectral and Cepstral Feature Extraction:
- Compute the spectrogram for audio signals.
- Focus on a selected frequency range (~50-1500 Hz) to derive the cepstrum, calculated as the FFT of the logarithm of the spectrogram.
- Average both spectrogram and cepstral features over the time axis and combine them into a unified feature vector.
3. Model Application:
Use the extracted features as input for the LightGBM model to perform predictions.
### Options for Energy Optimization:
- Feature Selection:
Mask less significant features to reduce computation.
- Signal Truncation:
Process only a limited duration (e.g., a few seconds) of the audio signal.
- Hardware Acceleration:
Utilize CUDA to speed up feature computation when supported.
Attributes
----------
feature_params : dict
        Parameters for configuring the spectrogram and cepstral feature transformations.
lgbm_params : dict, optional
Parameters for configuring the LightGBM model.
model_file : str
Path for saving or loading the trained LightGBM model.
padding_method : str
Padding method to apply when the waveform size is smaller than the desired size.
waveform_duration : float
Duration of the audio waveform to process, in seconds.
mask_features : bool
Whether to enable feature masking for dimensionality reduction.
mask_file : str
Path to save or load the feature mask file.
mask_ratio : float
The ratio of features to retain when feature masking is applied.
batch_size : int
Number of samples per batch during training and prediction.
apply_offset_on_fit : bool
        Whether to apply sample offsets during `fit`. Useful when `waveform_duration` is less than 3 seconds.
device : str
Device used for computation ("cpu" or "cuda").
Methods
-------
_save_feature_mask(model, n_features, ratio):
Saves the most important features as a mask.
_load_feature_mask():
Loads the feature mask from the saved file.
fit(dataset):
Trains the LightGBM model on audio features extracted from the dataset.
predict(dataset, get_proba=False):
Predicts labels or probabilities for a dataset using the trained model.
get_features(audios, spectrogram_transformer, cepstral_transformer):
Extracts features from raw audio using spectrogram and cepstral transformations.
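
    Examples
    --------
    A minimal usage sketch. The parameter values below are illustrative rather
    than a tuned configuration, and ``train_ds`` / ``test_ds`` are assumed to be
    Hugging Face ``Dataset`` objects whose first column holds wav-encoded audio
    bytes and whose second column holds labels.

    >>> import torch
    >>> feature_params = {
    ...     "n_fft": 1024, "hop_length": 512, "pad": 0,
    ...     "win_spectrogram": torch.hann_window, "win_cepstral": torch.hann_window,
    ...     "power": 2, "pad_mode": "reflect",
    ...     "f_min": 50, "f_max": 1500, "fc_min": 0.05, "fc_max": 0.95,
    ... }
    >>> model = FastModel(feature_params, lgbm_params={"objective": "binary"},
    ...                   model_file="lgbm_model.txt")
    >>> model.fit(train_ds)              # doctest: +SKIP
    >>> y_pred = model.predict(test_ds)  # doctest: +SKIP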
"""
def __init__(
self,
feature_params,
lgbm_params=None,
padding_method="zero",
waveform_duration=3,
model_file=None,
mask_features=False,
mask_file="feature_mask.pkl",
mask_ratio=0.25,
batch_size=5000,
apply_offset_on_fit=True,
device="cpu",
):
self.feature_params = feature_params
self.lgbm_params = lgbm_params
self.model_file = model_file
self.padding_method = padding_method
self.waveform_duration = waveform_duration
self.mask_features = mask_features
self.mask_file = mask_file
self.mask_ratio = mask_ratio
self.batch_size = batch_size
self.apply_offset_on_fit = apply_offset_on_fit
self.device = torch.device(
"cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
)
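        # STFT spectrogram used for the spectral part of the feature vector.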
self.spectrogram_transformer = Spectrogram(
n_fft=self.feature_params["n_fft"],
hop_length=self.feature_params["hop_length"],
pad=self.feature_params["pad"],
window_fn=self.feature_params["win_spectrogram"],
power=self.feature_params["power"],
pad_mode=self.feature_params["pad_mode"],
onesided=True,
center=False,
).to(self.device)
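        # Frequency axis of the spectrogram bins; only the f_min-f_max band feeds the cepstrum.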
self.f = torch.fft.rfftfreq(self.feature_params["n_fft"], d=1.0 / SR)
self.ind_f_filtered = torch.tensor(
(self.f > self.feature_params["f_min"]) & (self.f < self.feature_params["f_max"]),
device=self.device,
)
        self.n_fft_cepstral = int(self.ind_f_filtered.sum().item())
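        # A second Spectrogram over the band-limited log-spectrum yields the cepstrum
        # (FFT of the log-spectrogram), as described in the class docstring.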
self.cepstral_transformer = Spectrogram(
n_fft=self.n_fft_cepstral,
hop_length=self.n_fft_cepstral,
pad=0,
window_fn=self.feature_params["win_cepstral"],
power=self.feature_params["power"],
pad_mode=self.feature_params["pad_mode"],
onesided=True,
center=False,
).to(self.device)
self.cf = torch.fft.rfftfreq(self.n_fft_cepstral, d=0.5)
self.ind_cf_filtered = torch.tensor(
(self.cf > self.feature_params["fc_min"]) & (self.cf < self.feature_params["fc_max"]),
device=self.device,
)
def _save_feature_mask(self, model, n_features, ratio):
feature_importance = model.feature_importance(importance_type="gain")
sorted_indices = np.argsort(feature_importance)[::-1]
top_indices = sorted_indices[: max(1, int(n_features * ratio))]
mask = np.zeros(n_features, dtype=bool)
mask[top_indices] = True
with open(self.mask_file, "wb") as f:
pickle.dump(mask, f)
def _load_feature_mask(self):
with open(self.mask_file, "rb") as f:
return pickle.load(f)
def fit(self, dataset):
"""
Trains a LightGBM model on features extracted from the dataset.
Parameters
----------
dataset : Dataset
Dataset object containing audio samples and their corresponding labels.
Raises
------
        ValueError
            If an unsupported sampling rate is encountered while loading the dataset.
"""
features, labels = [], []
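        # Each training audio is read at several offsets into the raw data (see
        # batch_audio_loader), a simple time-shift augmentation; prediction always uses offset 0.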
offsets = [0, 12000, 24000] if self.apply_offset_on_fit else [0]
for offset in offsets:
for audio, label in batch_audio_loader(
dataset,
waveform_duration=self.waveform_duration,
batch_size=self.batch_size,
padding_method=self.padding_method,
offset=offset,
):
feature = self.get_features(
audio, self.spectrogram_transformer, self.cepstral_transformer
)
features.append(feature)
labels.extend(label)
x_train = torch.cat(features, dim=0)
train_data = lgb.Dataset(x_train.cpu(), label=labels)
model = lgb.train(self.lgbm_params, train_data)
if self.mask_features:
self._save_feature_mask(model, x_train.shape[1], self.mask_ratio)
mask = self._load_feature_mask()
x_train = x_train[:, mask]
train_data = lgb.Dataset(x_train.cpu(), label=labels)
model = lgb.train(self.lgbm_params, train_data)
model.save_model(self.model_file)
def predict(self, dataset, get_proba=False):
"""
Predicts labels or probabilities for a dataset using the trained model.
Parameters
----------
dataset : Dataset
The dataset containing audio data for prediction.
get_proba : bool, optional
If True, returns class probabilities rather than binary predictions (default is False).
Returns
-------
numpy.ndarray
If `get_proba` is True, returns a 1D array of class probabilities.
If `get_proba` is False, returns a 1D array of binary predictions (0 or 1).
Raises
------
NotFittedError
If the model is not yet trained.
FileNotFoundError
If the model file does not exist.
"""
if not self.model_file:
raise NotFittedError("The model is not trained yet. Train using the `fit` method.")
if not os.path.isfile(self.model_file):
raise FileNotFoundError(f"Model file {self.model_file} not found.")
features = []
for audio, _ in batch_audio_loader(
dataset,
waveform_duration=self.waveform_duration,
batch_size=self.batch_size,
padding_method=self.padding_method,
):
feature = self.get_features(
audio, self.spectrogram_transformer, self.cepstral_transformer
)
features.append(feature)
features = torch.cat(features, dim=0)
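        # Free cached GPU memory after feature extraction (no effect when running on CPU).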
torch.cuda.empty_cache()
if self.mask_features:
mask = self._load_feature_mask()
features = features[:, mask]
model = lgb.Booster(model_file=self.model_file)
y_score = model.predict(features.cpu())
return y_score if get_proba else (y_score >= 0.5).astype(int)
def get_features(self, audios, spectrogram_transformer, cepstral_transformer):
"""
Extracts features from raw audio using spectrogram and cepstrum transformations.
Parameters
----------
audios : torch.Tensor
A batch of audio waveforms as 1D tensors.
spectrogram_transformer : Spectrogram
            Transformation used to compute spectrogram features.
cepstral_transformer : Spectrogram
Transformation used to compute cepstral features.
Returns
-------
torch.Tensor
Extracted features for the audio batch. Includes both cepstral and log-scaled spectrogram features.
Raises
------
ValueError
If the input audio tensor is empty or invalid.
"""
audios = audios.to(self.device)
sxx = spectrogram_transformer(audios) # shape : (n_audios, n_f, n_blocks)
sxx = torch.log10(torch.clamp(sxx.permute(0, 2, 1), min=1e-10))
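        # Cepstrum: spectrum of the band-limited log-spectrogram, computed per time block.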
cepstral_mat = cepstral_transformer(sxx[:, :, self.ind_f_filtered]).squeeze(dim=3)[
:, :, self.ind_cf_filtered
]
return torch.cat(
[
cepstral_mat.mean(dim=1),
sxx.mean(dim=1),
],
dim=1,
)
def batch_audio_loader(
dataset,
waveform_duration=3,
batch_size=1,
sr=12000,
device="cpu",
padding_method=None,
offset=0,
):
"""
Loads and preprocesses audio data from a dataset for training or inference in batches.
Parameters
----------
dataset : Dataset
The dataset containing audio samples and labels.
waveform_duration : float, optional
Desired duration of the audio waveforms in seconds (default is 3).
batch_size : int, optional
Number of audio samples per batch (default is 1).
sr : int, optional
Target sampling rate for audio processing (default is 12000).
device : str, optional
Device for processing ("cpu" or "cuda") (default is "cpu").
padding_method : str, optional
        Method used to pad waveforms shorter than the desired size ("zero", "reflect", "replicate", or "circular").
offset : int, optional
Number of samples to skip before processing the first audio sample (default is 0).
Yields
------
tuple
A tuple (batch_audios, batch_labels), where:
- batch_audios is a tensor of processed audio waveforms.
- batch_labels is a tensor of corresponding audio labels.
Raises
------
ValueError
If an unsupported sampling rate is encountered in the dataset.
"""
def process_resampling(resample_buffer, resample_indices, batch_audios, sr, target_sr):
if resample_buffer:
resampler = torchaudio.transforms.Resample(
orig_freq=sr, new_freq=target_sr, lowpass_filter_width=6
)
resampled = resampler(torch.stack(resample_buffer))
for idx, original_idx in enumerate(resample_indices):
batch_audios[original_idx] = resampled[idx]
device = torch.device("cuda" if device == "cuda" and torch.cuda.is_available() else "cpu")
batch_audios, batch_labels = [], []
resample_24000, resample_24000_indices = [], []
for i in range(len(dataset)):
pa_subtable = query_table(dataset._data, i, indices=dataset._indices)
wav_bytes = pa_subtable[0][0][0].as_py()
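        # Canonical 44-byte RIFF/WAVE header: the sample rate is stored at bytes 24-27 and the
        # data chunk size (in bytes) at bytes 40-43, both little-endian; samples start at byte 44.
        # Dividing the chunk size by 2 converts it to a count of 16-bit samples.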
sampling_rate = struct.unpack("<I", wav_bytes[24:28])[0]
if sampling_rate not in [sr, sr * 2]:
raise ValueError(
f"Unsupported sampling rate: {sampling_rate}Hz. Only {sr}Hz and {sr * 2}Hz are allowed."
)
data_size = struct.unpack("<I", wav_bytes[40:44])[0] // 2
if data_size == 0:
batch_audios.append(torch.zeros(int(waveform_duration * SR)))
else:
try:
waveform = (
torch.frombuffer(wav_bytes[44:], dtype=torch.int16, offset=offset)[
: int(waveform_duration * sampling_rate)
].float()
/ 32767
)
            except Exception:
                # May happen during fit for short audios when the offset points past the
                # available data; the offset is always 0 during predict.
                continue
waveform = apply_padding(
waveform, int(waveform_duration * sampling_rate), padding_method
)
if sampling_rate == sr:
batch_audios.append(waveform)
elif sampling_rate == 2 * sr:
resample_24000.append(waveform)
resample_24000_indices.append(len(batch_audios))
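                # Placeholder; replaced by the resampled waveform in process_resampling.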
batch_audios.append(None)
batch_labels.append(pa_subtable[1][0].as_py())
if len(batch_audios) == batch_size:
# Perform resampling once and take advantage of Torch's vectorization capabilities.
process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
batch_audios_on_device = torch.stack(batch_audios).to(device)
batch_labels_on_device = torch.tensor(batch_labels).to(device)
yield batch_audios_on_device, batch_labels_on_device
batch_audios, batch_labels = [], []
resample_24000, resample_24000_indices = [], []
if batch_audios:
process_resampling(resample_24000, resample_24000_indices, batch_audios, sr * 2, SR)
batch_audios_on_device = torch.stack(batch_audios).to(device)
batch_labels_on_device = torch.tensor(batch_labels).to(device)
yield batch_audios_on_device, batch_labels_on_device
def apply_padding(waveform, output_size, padding_method="zero"):
"""
Applies padding to the waveform when its size is smaller than the desired output size.
Parameters
----------
waveform : torch.Tensor
Input 1D waveform tensor.
output_size : int
Desired output size after padding or truncation.
    padding_method : str, default="zero"
        Padding method to apply: "zero", "reflect", "replicate", or "circular".
Returns
-------
torch.Tensor
Padded or truncated waveform of size `output_size`.
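
    Raises
    ------
    ValueError
        If `padding_method` is not one of the supported methods.

    Examples
    --------
    Zero-padding a 3-sample waveform to a length of 5:

    >>> import torch
    >>> apply_padding(torch.ones(3), 5, padding_method="zero")
    tensor([1., 1., 1., 0., 0.])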
"""
if waveform.size(0) >= output_size:
return waveform[:output_size]
total_pad = output_size - waveform.size(0)
if padding_method == "zero":
return F.pad(waveform, (0, total_pad), mode="constant", value=0)
if padding_method in ["reflect", "replicate", "circular"]:
        # F.pad with these modes can fail when the pad length exceeds the waveform length,
        # so tile the waveform until it is long enough.
if waveform.size(0) < total_pad:
num_repeats = (total_pad // waveform.size(0)) + 1
waveform = torch.tile(waveform, (num_repeats,))
total_pad = output_size - waveform.size(0)
return F.pad(waveform.unsqueeze(0), (0, total_pad), mode=padding_method).squeeze()
raise ValueError(f"Invalid padding method: {padding_method}")