| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | import os.path as osp |
| | import numpy as np |
| | from PIL import Image |
| | from termcolor import colored |
| | import torchvision.transforms as transforms |
| |
|
| |
|
class NormalDataset():
    """Per-(subject, rotation) dataset of rendered PNG images.

    Each item is a dict holding metadata (``dataset``, ``subject``,
    ``rotation``, ``scale``) plus one image tensor per name listed in
    ``cfg.net.in_nml`` and for ``render_B`` / ``render_L``.  Images are read
    from ``<root>/<dataset>_<rotation_num>views/<subject>/<name>/<rot:03d>.png``.
    """

    # Sizes used when a dataset only ships ``all.txt`` and the split files
    # must be generated: the first TRAIN_SPLIT_SIZE entries go to train,
    # the next TEST_SPLIT_SIZE to test, and the remainder to val.
    TRAIN_SPLIT_SIZE = 500
    TEST_SPLIT_SIZE = 5

    def __init__(self, cfg, split='train'):
        """Read the configuration, load subject lists and build transforms.

        Args:
            cfg: config object; reads ``root``, ``batch_size``, ``overfit``,
                ``dataset.{types, input_size, scales, rotation_num}`` and
                ``net.in_nml`` (a sequence of ``(name, channels)`` pairs).
            split: split name; ``<root>/<dataset>/<split>.txt`` is loaded
                (and generated from ``all.txt`` if missing).
        """
        self.split = split
        self.root = cfg.root
        self.bsize = cfg.batch_size
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.scales = self.opt.scales

        # Input image names with their channel counts; the two render
        # images are always loaded in addition to the configured inputs.
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['render_B', 'render_L']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        if self.split != 'train':
            # Non-train splits use three views: 0, 120 and 240 degrees.
            self.rotations = range(0, 360, 120)
        else:
            # `np.int` was removed in NumPy 1.24; it was an alias of `int`.
            self.rotations = np.arange(
                0, 360, 360 // self.opt.rotation_num).astype(int)

        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset)
            self.datasets_dict[dataset] = {
                # ndmin=1 keeps a single-subject file as a 1-element array.
                "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"),
                                       dtype=str,
                                       ndmin=1),
                "scale": self.scales[dataset_id],
            }

        self.subject_list = self.get_subject_list(split)

        # RGB image -> resized tensor in [-1, 1].
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Alpha mask -> resized tensor in [0, 1] (mean 0 / std 1 is a no-op).
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

    @staticmethod
    def _load_list(path):
        """Read a whitespace-separated text file of names as a list of str.

        ``ndmin=1`` matters: for a one-entry file ``np.loadtxt`` otherwise
        returns a 0-d array whose ``tolist()`` is a plain ``str``, and
        ``list += str`` would then append individual characters.
        """
        return np.loadtxt(path, dtype=str, ndmin=1).tolist()

    @staticmethod
    def _pad_to_batch(items, bsize):
        """Return *items* extended by cycling its entries to a multiple of *bsize*.

        A list whose length is already a multiple of *bsize* (including an
        empty list) is returned unchanged.
        """
        if not items:
            return list(items)
        shortfall = -len(items) % bsize
        return items + [items[i % len(items)] for i in range(shortfall)]

    def get_subject_list(self, split):
        """Collect ``"<dataset>/<subject>"`` entries for *split* over all datasets.

        If ``<root>/<dataset>/<split>.txt`` does not exist, ``all.txt`` is
        split into ``train.txt``, ``test.txt`` and ``val.txt`` (sizes from
        TRAIN_SPLIT_SIZE / TEST_SPLIT_SIZE), which are written beside it.
        For any split other than ``'test'`` the list is padded to a multiple
        of the batch size and shuffled in place.
        """
        subject_list = []

        for dataset in self.datasets:
            dataset_dir = osp.join(self.root, dataset)
            split_txt = osp.join(dataset_dir, f'{split}.txt')

            if not osp.exists(split_txt):
                full_txt = osp.join(dataset_dir, 'all.txt')
                print(f"split {full_txt} into train/val/test")

                full_lst = np.array(
                    [dataset + "/" + item for item in self._load_list(full_txt)],
                    dtype=str)
                n_train = self.TRAIN_SPLIT_SIZE
                train_lst, test_lst, val_lst = np.split(
                    full_lst, [n_train, n_train + self.TEST_SPLIT_SIZE])

                # Build target paths explicitly: `full_txt.replace("all", ...)`
                # would also rewrite any "all" occurring in the root directory.
                for split_name, lst in (("train", train_lst),
                                        ("test", test_lst),
                                        ("val", val_lst)):
                    np.savetxt(osp.join(dataset_dir, f"{split_name}.txt"),
                               lst,
                               fmt="%s")

            print(f"load from {split_txt}")
            subject_list += self._load_list(split_txt)

        if split != 'test':
            subject_list = self._pad_to_batch(subject_list, self.bsize)
            print(colored(f"total: {len(subject_list)}", "yellow"))
            random.shuffle(subject_list)

        return subject_list

    def __len__(self):
        """Number of (subject, rotation) pairs."""
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):
        """Load every configured input image for one (subject, rotation) pair.

        Indices are subject-major: ``index = subject_idx * n_rot + rot_idx``.
        With ``overfit`` enabled every index maps to item 0.

        Returns:
            dict with ``dataset``, ``subject``, ``rotation``, ``scale`` and
            one tensor per name in ``self.in_total``.
        """
        if self.overfit:
            index = 0

        mid, rid = divmod(index, len(self.rotations))
        rotation = self.rotations[rid]
        parts = self.subject_list[mid].split("/")
        dataset, subject = parts[0], parts[1]
        render_folder = "/".join(
            [dataset + f"_{self.opt.rotation_num}views", subject])
        render_dir = osp.join(self.root, render_folder)
        frame = f'{rotation:03d}.png'

        data_dict = {
            'dataset': dataset,
            'subject': subject,
            'rotation': rotation,
            'scale': self.datasets_dict[dataset]["scale"],
            # An input named 'image' is read from `render/`, not `image/`.
            'image_path': osp.join(render_dir, 'render', frame),
        }

        for name, channel in zip(self.in_total, self.in_total_dim):
            # Default location is `<render_dir>/<name>/<rot>.png` unless a
            # path for this name was pre-seeded above.
            path = data_dict.setdefault(f'{name}_path',
                                        osp.join(render_dir, name, frame))
            data_dict[name] = self.imagepath2tensor(path, channel, inv=False)

        # Drop bookkeeping path entries; only metadata and tensors remain.
        return {
            key: value
            for key, value in data_dict.items()
            if '_path' not in key and '_dir' not in key
        }

    def imagepath2tensor(self, path, channel=3, inv=False):
        """Load an RGBA image as an alpha-masked tensor.

        The RGB part is normalized to [-1, 1] and multiplied by the alpha
        mask; the first *channel* channels are kept, and the sign is flipped
        when *inv* is true.
        """
        # `Image.open` is lazy and keeps the file open; the context manager
        # releases the handle once the pixels are converted.
        with Image.open(path) as img:
            rgba = img.convert('RGBA')
        image = self.image_to_tensor(rgba.convert('RGB'))
        mask = self.mask_to_tensor(rgba.split()[-1])
        image = (image * mask)[:channel]

        # (0.5 - inv) * 2.0 is +1.0 for inv=False and -1.0 for inv=True.
        return (image * (0.5 - inv) * 2.0).float()
| |
|