import warnings
import os

import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .base_provider import DataProvider
from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ["Cifar10DataProvider"]


class Cifar10DataProvider(DataProvider):
    DEFAULT_PATH = "./dataset/cifar10"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        valid_size=None,
        resize_scale=0.08,
        distort_color=None,
        n_worker=32,
        image_size=32,
        num_replicas=None,
        rank=None,
    ):
        warnings.filterwarnings("ignore")
        self._save_path = save_path
        self.image_size = image_size  # int or list of ints
        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(
                    img_size
                )
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            # randomly samples an image size for each batch of training images
            train_loader_class = MyDataLoader
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            # a float valid_size is interpreted as a fraction of the train set
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    train_indexes
                )
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    valid_indexes
                )

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            self.valid = self.test
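
    # Note: after __init__ returns, `self.train`, `self.valid`, and `self.test`
    # are all set. When no validation split is requested, `self.valid` simply
    # aliases `self.test`, so callers can always rely on all three loaders.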

    @staticmethod
    def name():
        return "cifar10"

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/cifar10")
        return self._save_path

    @property
    def data_url(self):
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        return datasets.CIFAR10(
            self.train_path, train=True, transform=_transforms, download=True
        )

    def test_dataset(self, _transforms):
        return datasets.CIFAR10(
            self.valid_path, train=False, transform=_transforms, download=True
        )

    @property
    def train_path(self):
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        )

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        # random_crop -> random_horizontal_flip
        # (CIFAR-10 is trained at 32x32, so `image_size` is unused here)
        train_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            # AutoAugment(),
        ]
        train_transforms += [
            transforms.ToTensor(),
            # self.normalize,
        ]
        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose(
            [
                transforms.ToTensor(),
                # self.normalize,
            ]
        )

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[
                self.active_img_size
            ] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        # used for resetting BN running statistics
        if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(
                    image_size=self.active_img_size, print_log=False
                )
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset,
                    num_replicas,
                    rank,
                    True,
                    np.array(chosen_indexes),
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    chosen_indexes
                )
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            # cache the sub-train batches so repeated BN resets reuse them
            self.__dict__["sub_train_%d" % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__["sub_train_%d" % self.active_img_size].append(
                    (images, labels)
                )
        return self.__dict__["sub_train_%d" % self.active_img_size]
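

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the provider). It
# assumes the `proard` package is importable and that CIFAR-10 can be
# downloaded to the default save path; because this module uses relative
# imports, run it via `python -m` rather than as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    provider = Cifar10DataProvider(
        train_batch_size=128,
        test_batch_size=256,
        valid_size=5000,  # hold out 5000 training images for validation
        n_worker=4,
        image_size=32,
    )
    print("data shape:", provider.data_shape)  # (3, 32, 32)
    print("num classes:", provider.n_classes)  # 10
    images, labels = next(iter(provider.train))
    print("train batch:", images.shape, labels.shape)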