import logging
from typing import List, Optional

import pandas as pd

from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
from src.datasets.wavefake_dataset import WaveFakeDataset
from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal
from src.datasets.MLAADv3_dataset import MLAADv3
from src.datasets.MAILABS_dataset import MAILABS
from src.datasets.aihub_dataset import AIHUB
from src.datasets.KoAAD_dataset import KoAAD

LOGGER = logging.getLogger()


class DetectionDataset(SimpleAudioFakeDataset):
    """Combines several bonafide/spoofed audio datasets into a single detection dataset.

    Each dataset whose root path is provided is loaded for the given subset,
    and the per-dataset sample tables are concatenated into `self.samples`.
    """

    def __init__(
        self,
        asvspoof_path=None,
        wavefake_path=None,
        fakeavceleb_path=None,
        asvspoof2019_path=None,
        MLAADv3_path=None,
        MAILABS_path=None,
        AIHUB_path=None,
        KoAAD_path=None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )

        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            MLAADv3_path=MLAADv3_path,
            MAILABS_path=MAILABS_path,
            AIHUB_path=AIHUB_path,
            KoAAD_path=KoAAD_path,
            subset=subset,
        )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        # Balance the bonafide/spoof classes, then optionally subsample.
        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info(f"Using reduced number of samples - {reduced_number}!")
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        subset: str,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        MLAADv3_path: Optional[str],
        MAILABS_path: Optional[str],
        AIHUB_path: Optional[str],
        KoAAD_path: Optional[str],
    ) -> List[SimpleAudioFakeDataset]:
        """Instantiate every dataset whose root path was provided; skip the rest."""
        datasets = []

        if asvspoof_path is not None:
            asvspoof_dataset = DeepFakeASVSpoofDataset(asvspoof_path, subset=subset)
            datasets.append(asvspoof_dataset)

        if wavefake_path is not None:
            wavefake_dataset = WaveFakeDataset(wavefake_path, subset=subset)
            datasets.append(wavefake_dataset)

        if fakeavceleb_path is not None:
            fakeavceleb_dataset = FakeAVCelebDataset(fakeavceleb_path, subset=subset)
            datasets.append(fakeavceleb_dataset)

        if asvspoof2019_path is not None:
            la_dataset = ASVSpoof2019DatasetOriginal(
                asvspoof2019_path, fold_subset=subset
            )
            datasets.append(la_dataset)

        if MLAADv3_path is not None:
            MLAADv3_dataset = MLAADv3(MLAADv3_path, subset=subset)
            datasets.append(MLAADv3_dataset)

        if MAILABS_path is not None:
            MAILABS_dataset = MAILABS(MAILABS_path, subset=subset)
            datasets.append(MAILABS_dataset)

        if AIHUB_path is not None:
            aihub_dataset = AIHUB(AIHUB_path, subset=subset)
            datasets.append(aihub_dataset)

        if KoAAD_path is not None:
            KoAAD_dataset = KoAAD(KoAAD_path, subset=subset)
            datasets.append(KoAAD_dataset)

        return datasets

    def oversample_dataset(self):
        """Duplicate randomly drawn bonafide samples until both classes are the same size."""
        samples = self.samples.groupby(by=["label"])
        bona_length = len(samples.groups["bonafide"])
        spoof_length = len(samples.groups["spoof"])

        diff_length = spoof_length - bona_length

        if diff_length < 0:
            raise NotImplementedError

        if diff_length > 0:
            bonafide = samples.get_group("bonafide").sample(diff_length, replace=True)
            self.samples = pd.concat([self.samples, bonafide], ignore_index=True)

    def undersample_dataset(self):
        """Sample the spoofed class down to the number of bonafide samples."""
        samples = self.samples.groupby(by=["label"])
        bona_length = len(samples.groups["bonafide"])
        spoof_length = len(samples.groups["spoof"])

        if spoof_length < bona_length:
            raise NotImplementedError

        if spoof_length > bona_length:
            spoofs = samples.get_group("spoof").sample(bona_length, replace=True)
            self.samples = pd.concat(
                [samples.get_group("bonafide"), spoofs], ignore_index=True
            )

    def get_bonafide_only(self):
        samples = self.samples.groupby(by=["label"])
        self.samples = samples.get_group("bonafide")
        return self.samples

    def get_spoof_only(self):
        samples = self.samples.groupby(by=["label"])
        self.samples = samples.get_group("spoof")
        return self.samples
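

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module's API).
# The dataset roots below are placeholder paths, not real locations; any
# source you do not have can be left as None and _init_datasets() skips it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = DetectionDataset(
        asvspoof_path="/data/asvspoof",  # placeholder root, adjust to your setup
        wavefake_path="/data/wavefake",  # placeholder root, adjust to your setup
        subset="train",
        oversample=True,
        reduced_number=10_000,
    )
    print(f"Loaded {len(dataset.samples)} samples")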