"""CIFAR-100 data provider.

Wraps torchvision's CIFAR-100 dataset in the DataProvider interface, with an
optional train/validation split, distributed sampling, and elastic image
sizes (when `image_size` is given as a list).
"""

import warnings
import os

import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .base_provider import DataProvider
from proard.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ["Cifar100DataProvider"]


class Cifar100DataProvider(DataProvider):
    DEFAULT_PATH = "./dataset/cifar100"

    def __init__(
        self,
        save_path=None,
        train_batch_size=256,
        test_batch_size=512,
        resize_scale=0.08,  # accepted for interface parity; unused for CIFAR
        distort_color=None,  # accepted for interface parity; unused for CIFAR
        valid_size=None,
        n_worker=32,
        image_size=32,
        num_replicas=None,
        rank=None,
    ):
        warnings.filterwarnings("ignore")
        self._save_path = save_path
        self.image_size = image_size  # int, or list of candidate resolutions

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # elastic resolution: train with randomly sampled image sizes
            from proard.utils.my_dataloader.my_data_loader import MyDataLoader

            assert isinstance(self.image_size, list)
            self.image_size.sort()
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(
                    img_size
                )
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            # hold out part of the training set for validation
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(
                len(train_dataset), valid_size
            )

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(
                    train_dataset, num_replicas, rank, True, np.array(train_indexes)
                )
                valid_sampler = MyDistributedSampler(
                    valid_dataset, num_replicas, rank, True, np.array(valid_indexes)
                )
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    train_indexes
                )
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    valid_indexes
                )

            self.train = train_loader_class(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset,
                batch_size=test_batch_size,
                sampler=valid_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset, num_replicas, rank
                )
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    sampler=train_sampler,
                    num_workers=n_worker,
                    pin_memory=True,
                )
            else:
                self.train = train_loader_class(
                    train_dataset,
                    batch_size=train_batch_size,
                    shuffle=True,
                    num_workers=n_worker,
                    pin_memory=False,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, num_replicas, rank
            )
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                sampler=test_sampler,
                num_workers=n_worker,
                pin_memory=False,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=test_batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=False,
            )

        if self.valid is None:
            # no validation split requested: fall back to the test set
            self.valid = self.test

    @staticmethod
    def name():
        return "cifar100"

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser("~/dataset/cifar100")
        return self._save_path

    @property
    def data_url(self):
        raise ValueError("unable to download %s" % self.name())

    def train_dataset(self, _transforms):
        return datasets.CIFAR100(
            self.train_path, train=True, transform=_transforms, download=True
        )

    def test_dataset(self, _transforms):
        return datasets.CIFAR100(
            self.valid_path, train=False, transform=_transforms, download=True
        )

    @property
    def train_path(self):
        return os.path.join(self.save_path, "train")

    @property
    def valid_path(self):
        return os.path.join(self.save_path, "val")

    @property
    def normalize(self):
        # note: these are the commonly used CIFAR-10 statistics; CIFAR-100's own
        # per-channel values are approximately mean=(0.5071, 0.4865, 0.4409),
        # std=(0.2673, 0.2564, 0.2762)
        return transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]
        )

    def build_train_transform(self, image_size=None, print_log=True):
        # `image_size` and `print_log` are accepted for interface parity; CIFAR
        # images are a fixed 32x32, so neither changes the transform below
        if image_size is None:
            image_size = self.image_size

        train_transforms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        train_transforms += [
            transforms.ToTensor(),
            # `self.normalize` is not applied; tensors stay in [0, 1]
        ]
        return transforms.Compose(train_transforms)

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[
                self.active_img_size
            ] = self.build_valid_transform()
        # point the existing valid/test loaders at the transform for the new size
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(
        self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None
    ):
        # deterministically sample `n_images` training images (seeded with
        # DataProvider.SUB_SEED) and cache the resulting batches, e.g., for
        # re-calibrating BN running statistics
        if self.__dict__.get("sub_train_%d" % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(
                    image_size=self.active_img_size, print_log=False
                )
            )
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(
                    new_train_dataset,
                    num_replicas,
                    rank,
                    True,
                    np.array(chosen_indexes),
                )
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                    chosen_indexes
                )
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset,
                batch_size=batch_size,
                sampler=sub_sampler,
                num_workers=num_worker,
                pin_memory=False,
            )
            self.__dict__["sub_train_%d" % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__["sub_train_%d" % self.active_img_size].append(
                    (images, labels)
                )
        return self.__dict__["sub_train_%d" % self.active_img_size]
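

# A minimal usage sketch (assumptions: the `proard` package is importable, so run
# this as a module, e.g. `python -m proard.data_providers.cifar100`, and
# torchvision can download CIFAR-100 into the default save path). The parameter
# values below are illustrative, not prescribed defaults.
if __name__ == "__main__":
    provider = Cifar100DataProvider(
        train_batch_size=128, test_batch_size=256, valid_size=5000, n_worker=4
    )
    print("data shape:", provider.data_shape)  # (3, 32, 32)
    print("n_classes :", provider.n_classes)  # 100
    images, labels = next(iter(provider.train))
    print("train batch:", tuple(images.shape), tuple(labels.shape))
    # cached fixed subset, e.g., for BN re-calibration
    sub_train = provider.build_sub_train_loader(n_images=2000, batch_size=100)
    print("sub-train batches:", len(sub_train))  # 20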