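"""Training and testing entry point for a 3D Gaussian Splatting avatar model.

The AvatarTrainer below supports three stages: ``pretrain`` (regressing the
networks towards the fitted canonical Gaussians), ``train`` (full image-space
optimization) and ``test`` (free-viewpoint / animation rendering). See the
``__main__`` block at the bottom for the command-line interface.
"""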
import os

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'

import yaml
import shutil
import collections
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import glob
import datetime
import trimesh
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import importlib
import json

import config
from network.lpips import LPIPS
from dataset.dataset_pose import PoseDataset
import utils.net_util as net_util
import utils.visualize_util as visualize_util
from utils.renderer import Renderer
from utils.net_util import to_cuda
from utils.obj_io import save_mesh_as_ply
from gaussians.obj_io import save_gaussians_as_ply
def safe_exists(path):
    if path is None:
        return False
    return os.path.exists(path)
class AvatarTrainer:
    def __init__(self, opt):
        self.opt = opt
        self.patch_size = 512
        self.iter_idx = 0
        self.iter_num = 800000
        self.lr_init = float(self.opt['train'].get('lr_init', 5e-4))

        avatar_module = self.opt['model'].get('module', 'network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.opt['model']).to(config.device)
        self.optm = torch.optim.Adam(
            self.avatar_net.parameters(), lr = self.lr_init
        )

        self.random_bg_color = self.opt['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device)
        self.loss_weight = self.opt['train']['loss_weight']
        self.finetune_color = self.opt['train']['finetune_color']

        print('# Parameter number of AvatarNet is %d' % sum(p.numel() for p in self.avatar_net.parameters()))

    def update_lr(self):
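        """Cosine-annealed learning rate:
        lr = lr_init * ((cos(pi * t / T) + 1) / 2 * (1 - alpha) + alpha),
        with t the current iteration, T = self.iter_num and alpha = 0.05,
        so the rate decays smoothly from lr_init to alpha * lr_init."""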
        alpha = 0.05
        progress = self.iter_idx / self.iter_num
        learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
        lr = self.lr_init * learning_factor
        for param_group in self.optm.param_groups:
            param_group['lr'] = lr
        return lr
    def requires_net_grad(self, net: torch.nn.Module, flag = True):
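        """Enable or disable gradients for all parameters of the given sub-network."""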
        for p in net.parameters():
            p.requires_grad = flag
    def crop_image(self, gt_mask, patch_size, randomly, *args):
        """
        :param gt_mask: (H, W)
        :param patch_size: resize the cropped patch to the given patch_size
        :param randomly: whether to randomly sample the patch
        :param args: input images with shape of (C, H, W)
        """
        mask_uv = torch.argwhere(gt_mask > 0.)
        min_v, min_u = mask_uv.min(0)[0]
        max_v, max_u = mask_uv.max(0)[0]
        len_v = max_v - min_v
        len_u = max_u - min_u
        max_size = max(len_v, len_u)
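        # The foreground bounding box is padded to a square filled with the
        # background color; that square is then either randomly cropped to
        # `patch_size` (when it is large enough) or bilinearly resized to it.
        # The same crop/resize is applied to every image in `args`.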
        cropped_images = []
        if randomly and max_size > patch_size:
            random_v = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size)
            random_u = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size)
        for image in args:
            cropped_image = self.bg_color_cuda[:, None, None] * torch.ones((3, max_size, max_size), dtype = image.dtype, device = image.device)
            if len_v > len_u:
                start_u = (max_size - len_u) // 2
                cropped_image[:, :, start_u: start_u + len_u] = image[:, min_v: max_v, min_u: max_u]
            else:
                start_v = (max_size - len_v) // 2
                cropped_image[:, start_v: start_v + len_v, :] = image[:, min_v: max_v, min_u: max_u]

            if randomly and max_size > patch_size:
                cropped_image = cropped_image[:, random_v: random_v + patch_size, random_u: random_u + patch_size]
            else:
                cropped_image = F.interpolate(cropped_image[None], size = (patch_size, patch_size), mode = 'bilinear')[0]
            cropped_images.append(cropped_image)

            # cv.imshow('cropped_image', cropped_image.detach().cpu().numpy().transpose(1, 2, 0))
            # cv.imshow('cropped_gt_image', cropped_gt_image.detach().cpu().numpy().transpose(1, 2, 0))
            # cv.waitKey(0)

        if len(cropped_images) > 1:
            return cropped_images
        else:
            return cropped_images[0]
    def compute_lpips_loss(self, image, gt_image):
        assert image.shape[1] == image.shape[2] and gt_image.shape[1] == gt_image.shape[2]
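        # Images are handled in OpenCV's BGR order throughout this file, so the
        # channels are flipped to RGB before calling LPIPS; normalize = True
        # tells LPIPS that inputs are in [0, 1] rather than [-1, 1].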
        lpips_loss = self.lpips.forward(
            image[None, [2, 1, 0]],
            gt_image[None, [2, 1, 0]],
            normalize = True
        ).mean()
        return lpips_loss
    def forward_one_pass_pretrain(self, items):
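        """One pretraining step: regress the position / opacity / scale /
        rotation networks (via L1) towards the parameters of the fitted
        canonical Gaussian model, so that image-space training starts from a
        sensible initialization."""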
        total_loss = 0
        batch_losses = {}
        l1_loss = torch.nn.L1Loss()
        items = net_util.delete_batch_idx(items)
        pose_map = items['smpl_pos_map'][:3]

        position_loss = l1_loss(self.avatar_net.get_positions(pose_map), self.avatar_net.cano_gaussian_model.get_xyz)
        total_loss += position_loss
        batch_losses.update({
            'position': position_loss.item()
        })

        opacity, scales, rotations = self.avatar_net.get_others(pose_map)
        opacity_loss = l1_loss(opacity, self.avatar_net.cano_gaussian_model.get_opacity)
        total_loss += opacity_loss
        batch_losses.update({
            'opacity': opacity_loss.item()
        })

        scale_loss = l1_loss(scales, self.avatar_net.cano_gaussian_model.get_scaling)
        total_loss += scale_loss
        batch_losses.update({
            'scale': scale_loss.item()
        })

        rotation_loss = l1_loss(rotations, self.avatar_net.cano_gaussian_model.get_rotation)
        total_loss += rotation_loss
        batch_losses.update({
            'rotation': rotation_loss.item()
        })

        total_loss.backward()
        self.optm.step()
        self.optm.zero_grad()

        return total_loss, batch_losses
    def forward_one_pass(self, items):
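        """One full training step: render the posed avatar, then optimize an
        L1 color loss, an optional mask loss, a patch-based LPIPS loss and an
        offset regularizer, weighted by `self.loss_weight`."""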
        # forward_start = torch.cuda.Event(enable_timing = True)
        # forward_end = torch.cuda.Event(enable_timing = True)
        # backward_start = torch.cuda.Event(enable_timing = True)
        # backward_end = torch.cuda.Event(enable_timing = True)
        # step_start = torch.cuda.Event(enable_timing = True)
        # step_end = torch.cuda.Event(enable_timing = True)

        if self.random_bg_color:
            self.bg_color = np.random.rand(3)
            self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device)

        total_loss = 0
        batch_losses = {}
        items = net_util.delete_batch_idx(items)

        """ Optimize generator """
        if self.finetune_color:
            self.requires_net_grad(self.avatar_net.color_net, True)
            self.requires_net_grad(self.avatar_net.position_net, False)
            self.requires_net_grad(self.avatar_net.other_net, True)
        else:
            self.requires_net_grad(self.avatar_net, True)

        # forward_start.record()
        render_output = self.avatar_net.render(items, self.bg_color)
        image = render_output['rgb_map'].permute(2, 0, 1)
        offset = render_output['offset']

        # mask image & set bg color
        items['color_img'][~items['mask_img']] = self.bg_color_cuda
        gt_image = items['color_img'].permute(2, 0, 1)
        mask_img = items['mask_img'].to(torch.float32)
        boundary_mask_img = 1. - items['boundary_mask_img'].to(torch.float32)

        image = image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None]
        gt_image = gt_image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None]

        # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy())
        # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy())
        # cv.waitKey(0)

        if self.loss_weight['l1'] > 0.:
            l1_loss = torch.abs(image - gt_image).mean()
            total_loss += self.loss_weight['l1'] * l1_loss
            batch_losses.update({
                'l1_loss': l1_loss.item()
            })

        if self.loss_weight.get('mask', 0.) and 'mask_map' in render_output:
            rendered_mask = render_output['mask_map'].squeeze(-1) * boundary_mask_img
            gt_mask = mask_img * boundary_mask_img
            # cv.imshow('rendered_mask', rendered_mask.detach().cpu().numpy())
            # cv.imshow('gt_mask', gt_mask.detach().cpu().numpy())
            # cv.waitKey(0)
            mask_loss = torch.abs(rendered_mask - gt_mask).mean()
            # mask_loss = torch.nn.BCELoss()(rendered_mask, gt_mask)
            total_loss += self.loss_weight.get('mask', 0.) * mask_loss
            batch_losses.update({
                'mask_loss': mask_loss.item()
            })
        if self.loss_weight['lpips'] > 0.:
            # crop images; sample random patches only after 300k iterations
            random_patch_flag = self.iter_idx >= 300000
            image, gt_image = self.crop_image(mask_img, self.patch_size, random_patch_flag, image, gt_image)
            # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy())
            # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy())
            # cv.waitKey(0)
            lpips_loss = self.compute_lpips_loss(image, gt_image)
            total_loss += self.loss_weight['lpips'] * lpips_loss
            batch_losses.update({
                'lpips_loss': lpips_loss.item()
            })

        if self.loss_weight['offset'] > 0.:
            offset_loss = torch.linalg.norm(offset, dim = -1).mean()
            total_loss += self.loss_weight['offset'] * offset_loss
            batch_losses.update({
                'offset_loss': offset_loss.item()
            })
        # forward_end.record()

        # backward_start.record()
        total_loss.backward()
        # backward_end.record()

        # step_start.record()
        self.optm.step()
        self.optm.zero_grad()
        # step_end.record()

        # torch.cuda.synchronize()
        # print(f'Forward costs: {forward_start.elapsed_time(forward_end) / 1000.}, '
        #       f'Backward costs: {backward_start.elapsed_time(backward_end) / 1000.}, '
        #       f'Step costs: {step_start.elapsed_time(step_end) / 1000.}')

        return total_loss, batch_losses
    def pretrain(self):
        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
        self.dataset = MvRgbDataset(**self.opt['train']['data'])
        batch_size = self.opt['train']['batch_size']
        num_workers = self.opt['train']['num_workers']
        batch_num = len(self.dataset) // batch_size
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size = batch_size,
                                                 shuffle = True,
                                                 num_workers = num_workers,
                                                 drop_last = True)

        # tb writer
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('pretrain_%Y_%m_%d_%H_%M_%S')
        writer = SummaryWriter(log_dir)
        smooth_interval = 10
        smooth_count = 0
        smooth_losses = {}

        for epoch_idx in range(0, 9999999):
            self.epoch_idx = epoch_idx
            for batch_idx, items in enumerate(dataloader):
                self.iter_idx = batch_idx + epoch_idx * batch_num
                items = to_cuda(items)

                # one_step_start.record()
                total_loss, batch_losses = self.forward_one_pass_pretrain(items)
                # one_step_end.record()
                # torch.cuda.synchronize()
                # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.))

                # record batch loss
                for key, loss in batch_losses.items():
                    if key in smooth_losses:
                        smooth_losses[key] += loss
                    else:
                        smooth_losses[key] = loss
                smooth_count += 1

                if self.iter_idx % smooth_interval == 0:
                    log_info = 'epoch %d, batch %d, iter %d, ' % (epoch_idx, batch_idx, self.iter_idx)
                    for key in smooth_losses.keys():
                        smooth_losses[key] /= smooth_count
                        writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx)
                        log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key]))
                        smooth_losses[key] = 0.
                    smooth_count = 0
                    print(log_info)
                    with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                        fp.write(log_info + '\n')

                if self.iter_idx % 200 == 0 and self.iter_idx != 0:
                    self.mini_test(pretraining = True)

                if self.iter_idx == 5000:
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/pretrained'
                    os.makedirs(model_folder, exist_ok = True)
                    self.save_ckpt(model_folder, save_optm = True)
                    self.iter_idx = 0
                    return
    def train(self):
        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
        self.dataset = MvRgbDataset(**self.opt['train']['data'])
        batch_size = self.opt['train']['batch_size']
        num_workers = self.opt['train']['num_workers']
        batch_num = len(self.dataset) // batch_size
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size = batch_size,
                                                 shuffle = True,
                                                 num_workers = num_workers,
                                                 drop_last = True)

        if 'lpips' in self.opt['train']['loss_weight']:
            self.lpips = LPIPS(net = 'vgg').to(config.device)
            for p in self.lpips.parameters():
                p.requires_grad = False

        if self.opt['train']['prev_ckpt'] is not None:
            start_epoch, self.iter_idx = self.load_ckpt(self.opt['train']['prev_ckpt'], load_optm = True)
            start_epoch += 1
            self.iter_idx += 1
        else:
            prev_ckpt_path = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
            if safe_exists(prev_ckpt_path):
                start_epoch, self.iter_idx = self.load_ckpt(prev_ckpt_path, load_optm = True)
                start_epoch += 1
                self.iter_idx += 1
            else:
                if safe_exists(self.opt['train']['pretrained_dir']):
                    self.load_ckpt(self.opt['train']['pretrained_dir'], load_optm = False)
                elif safe_exists(self.opt['train']['net_ckpt_dir'] + '/pretrained'):
                    self.load_ckpt(self.opt['train']['net_ckpt_dir'] + '/pretrained', load_optm = False)
                else:
                    raise FileNotFoundError('Cannot find pretrained checkpoint!')
                self.optm.state = collections.defaultdict(dict)
                start_epoch = 0
                self.iter_idx = 0

        # one_step_start = torch.cuda.Event(enable_timing = True)
        # one_step_end = torch.cuda.Event(enable_timing = True)

        # tb writer
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        writer = SummaryWriter(log_dir)
        yaml.dump(self.opt, open(log_dir + '/config_bk.yaml', 'w'), sort_keys = False)
        smooth_interval = 10
        smooth_count = 0
        smooth_losses = {}

        for epoch_idx in range(start_epoch, 9999999):
            self.epoch_idx = epoch_idx
            for batch_idx, items in enumerate(dataloader):
                lr = self.update_lr()
                items = to_cuda(items)

                # one_step_start.record()
                total_loss, batch_losses = self.forward_one_pass(items)
                # one_step_end.record()
                # torch.cuda.synchronize()
                # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.))

                # record batch loss
                for key, loss in batch_losses.items():
                    if key in smooth_losses:
                        smooth_losses[key] += loss
                    else:
                        smooth_losses[key] = loss
                smooth_count += 1

                if self.iter_idx % smooth_interval == 0:
                    log_info = 'epoch %d, batch %d, iter %d, lr %e, ' % (epoch_idx, batch_idx, self.iter_idx, lr)
                    for key in smooth_losses.keys():
                        smooth_losses[key] /= smooth_count
                        writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx)
                        log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key]))
                        smooth_losses[key] = 0.
                    smooth_count = 0
                    print(log_info)
                    with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                        fp.write(log_info + '\n')
                    torch.cuda.empty_cache()

                if self.iter_idx % self.opt['train']['eval_interval'] == 0 and self.iter_idx != 0:
                    eval_cano_pts = self.iter_idx % (10 * self.opt['train']['eval_interval']) == 0
                    self.mini_test(eval_cano_pts = eval_cano_pts)

                if self.iter_idx % self.opt['train']['ckpt_interval']['batch'] == 0 and self.iter_idx != 0:
                    for folder in glob.glob(self.opt['train']['net_ckpt_dir'] + '/batch_*'):
                        shutil.rmtree(folder)
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/batch_%d' % self.iter_idx
                    os.makedirs(model_folder, exist_ok = True)
                    self.save_ckpt(model_folder, save_optm = True)

                if self.iter_idx == self.iter_num:
                    print('# Training is done.')
                    return

                self.iter_idx += 1

            """ End of epoch """
            if epoch_idx % self.opt['train']['ckpt_interval']['epoch'] == 0 and epoch_idx != 0:
                model_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_%d' % epoch_idx
                os.makedirs(model_folder, exist_ok = True)
                self.save_ckpt(model_folder)
            if batch_num > 50:
                latest_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
                os.makedirs(latest_folder, exist_ok = True)
                self.save_ckpt(latest_folder)
    def mini_test(self, pretraining = False, eval_cano_pts = False):
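        """Render one training view and one held-out testing view for the
        current network state and save them (side by side with ground truth
        when available) under the checkpoint directory; optionally export the
        canonical points as a PLY file."""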
        self.avatar_net.eval()

        img_factor = self.opt['train'].get('eval_img_factor', 1.0)

        # training data
        pose_idx, view_idx = self.opt['train'].get('eval_training_ids', (310, 19))
        intr = self.dataset.intr_mats[view_idx].copy()
        intr[:2] *= img_factor
        item = self.dataset.getitem(0,
                                    pose_idx = pose_idx,
                                    view_idx = view_idx,
                                    training = False,
                                    eval = True,
                                    img_h = int(self.dataset.img_heights[view_idx] * img_factor),
                                    img_w = int(self.dataset.img_widths[view_idx] * img_factor),
                                    extr = self.dataset.extr_mats[view_idx],
                                    intr = intr,
                                    exact_hand_pose = True)
        items = net_util.to_cuda(item, add_batch = False)
        gs_render = self.avatar_net.render(items, self.bg_color)
        # gs_render = self.avatar_net.render_debug(items)
        rgb_map = gs_render['rgb_map']
        rgb_map.clip_(0., 1.)
        rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8)
        # cv.imshow('rgb_map', rgb_map)
        # cv.waitKey(0)
        if not pretraining:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/training'
        else:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/training'
        gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx)
        if gt_image is not None:
            gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor)
            rgb_map = np.concatenate([rgb_map, gt_image], 1)
        os.makedirs(output_dir, exist_ok = True)
        cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map)
        if eval_cano_pts:
            os.makedirs(output_dir + '/cano_pts', exist_ok = True)
            save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy())

        # testing data
        pose_idx, view_idx = self.opt['train'].get('eval_testing_ids', (310, 19))
        intr = self.dataset.intr_mats[view_idx].copy()
        intr[:2] *= img_factor
        item = self.dataset.getitem(0,
                                    pose_idx = pose_idx,
                                    view_idx = view_idx,
                                    training = False,
                                    eval = True,
                                    img_h = int(self.dataset.img_heights[view_idx] * img_factor),
                                    img_w = int(self.dataset.img_widths[view_idx] * img_factor),
                                    extr = self.dataset.extr_mats[view_idx],
                                    intr = intr,
                                    exact_hand_pose = True)
        items = net_util.to_cuda(item, add_batch = False)
        gs_render = self.avatar_net.render(items, bg_color = self.bg_color)
        # gs_render = self.avatar_net.render_debug(items)
        rgb_map = gs_render['rgb_map']
        rgb_map.clip_(0., 1.)
        rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8)
        # cv.imshow('rgb_map', rgb_map)
        # cv.waitKey(0)
        if not pretraining:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/testing'
        else:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/testing'
        gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx)
        if gt_image is not None:
            gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor)
            rgb_map = np.concatenate([rgb_map, gt_image], 1)
        os.makedirs(output_dir, exist_ok = True)
        cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map)
        if eval_cano_pts:
            os.makedirs(output_dir + '/cano_pts', exist_ok = True)
            save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy())

        self.avatar_net.train()
    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
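        """Write `cfg_args` and `cameras.json` in the model-folder layout used
        by the official 3D Gaussian Splatting codebase (and its SIBR viewer),
        so the dumped Gaussians can be inspected with that tooling."""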
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
                     "white_background=False, data_device='cuda', eval=False)" % (
                         3, self.opt['train']['data']['data_dir'], dump_dir)
            fp.write(outstr)

        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            cam_jsons = []
            for ci in range(len(extrs)):
                extr, intr = extrs[ci], intrs[ci]
                img_h, img_w = img_heights[ci], img_widths[ci]
                w2c = extr
                c2w = np.linalg.inv(w2c)
                pos = c2w[:3, 3]
                rot = c2w[:3, :3]
                serializable_array_2d = [x.tolist() for x in rot]
                camera_entry = {
                    'id': ci,
                    'img_name': '%08d' % ci,
                    'width': int(img_w),
                    'height': int(img_h),
                    'position': pos.tolist(),
                    'rotation': serializable_array_2d,
                    'fy': float(intr[1, 1]),
                    'fx': float(intr[0, 0]),
                }
                cam_jsons.append(camera_entry)
            json.dump(cam_jsons, fp)
        return
    def test(self):
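        """Animation / free-viewpoint rendering: load a checkpoint, build the
        testing pose sequence, set up the virtual camera according to
        `view_setting`, and save the rendered maps (plus optional hand and
        sleeve masks and the posed Gaussians) to the output directory."""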
        self.avatar_net.eval()

        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
        training_dataset = MvRgbDataset(**self.opt['train']['data'], training = False)
        if self.opt['test'].get('n_pca', -1) >= 1:
            training_dataset.compute_pca(n_components = self.opt['test']['n_pca'])

        if 'pose_data' in self.opt['test']:
            testing_dataset = PoseDataset(**self.opt['test']['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0])
            dataset_name = testing_dataset.dataset_name
            seq_name = testing_dataset.seq_name
        else:
            testing_dataset = MvRgbDataset(**self.opt['test']['data'], training = False)
            dataset_name = 'training'
            seq_name = ''

        self.dataset = testing_dataset
        iter_idx = self.load_ckpt(self.opt['test']['prev_ckpt'], False)[1]

        output_dir = self.opt['test'].get('output_dir', None)
        if output_dir is None:
            view_setting = config.opt['test'].get('view_setting', 'free')
            if view_setting == 'camera':
                view_folder = 'cam_%03d' % config.opt['test']['render_view_idx']
            else:
                view_folder = view_setting + '_view'
            exp_name = os.path.basename(os.path.dirname(self.opt['test']['prev_ckpt']))
            output_dir = f'./test_results/{training_dataset.subject_name}/{exp_name}/{dataset_name}_{seq_name}_{view_folder}' + '/batch_%06d' % iter_idx

        use_pca = self.opt['test'].get('n_pca', -1) >= 1
        if use_pca:
            output_dir += '/pca_%d_sigma_%.2f' % (self.opt['test'].get('n_pca', -1), float(self.opt['test'].get('sigma_pca', 1.)))
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)

        os.makedirs(output_dir + '/live_skeleton', exist_ok = True)
        os.makedirs(output_dir + '/rgb_map', exist_ok = True)
        os.makedirs(output_dir + '/rgb_map_wo_hand', exist_ok = True)
        os.makedirs(output_dir + '/torso_map', exist_ok = True)
        os.makedirs(output_dir + '/mask_map', exist_ok = True)
        os.makedirs(output_dir + '/posed_gaussians', exist_ok = True)
        os.makedirs(output_dir + '/posed_params', exist_ok = True)
        os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
        os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)

        geo_renderer = None
        item_0 = self.dataset.getitem(0, training = False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
        # keep only the yaw component of the global orientation
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]

        time_start = torch.cuda.Event(enable_timing = True)
        time_start_all = torch.cuda.Event(enable_timing = True)
        time_end = torch.cuda.Event(enable_timing = True)

        data_num = len(self.dataset)
        if self.opt['test'].get('fix_hand', False):
            self.avatar_net.generate_mean_hands()
        log_time = False

        extr_list = []
        intr_list = []
        img_h_list = []
        img_w_list = []
        for idx in tqdm(range(data_num), desc = 'Rendering avatars...'):
            if log_time:
                time_start.record()
                time_start_all.record()

            img_scale = self.opt['test'].get('img_scale', 1.0)
            view_setting = config.opt['test'].get('view_setting', 'free')
            if view_setting == 'camera':
                # training view setting
                cam_id = config.opt['test']['render_view_idx']
                intr = self.dataset.intr_mats[cam_id].copy()
                intr[:2] *= img_scale
                extr = self.dataset.extr_mats[cam_id].copy()
                img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale)
            elif view_setting.startswith('free'):
                # free view setting: orbit the camera around the subject
                frame_num_per_circle = 360
                rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('degree120'):
                # sweep the camera between +-60 degrees around the front view
                frame_per_cycle = 480
                max_degree = 60
                frame_half_cycle = frame_per_cycle // 2
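                # rot_Y follows a triangle wave: it ramps from -max_degree to
                # +max_degree over the first half cycle and back over the
                # second, then gets converted to radians in [0, 2*pi).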
                if idx % frame_per_cycle < frame_half_cycle:
                    rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                else:
                    rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                # to radians in [0, 2*pi)
                rot_Y = rot_Y * np.pi / 180
                if rot_Y < 0:
                    rot_Y = rot_Y + 2 * np.pi
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('degree90'):
                # sweep the camera between +-45 degrees around the front view
                frame_per_cycle = 360
                max_degree = 45
                frame_half_cycle = frame_per_cycle // 2
                if idx % frame_per_cycle < frame_half_cycle:
                    rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                else:
                    rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                # to radians in [0, 2*pi)
                rot_Y = rot_Y * np.pi / 180
                if rot_Y < 0:
                    rot_Y = rot_Y + 2 * np.pi
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('front'):
                # front view setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = 0.,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('back'):
                # back view setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = np.pi,
                                                   rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
            elif view_setting.startswith('moving'):
                # moving camera setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   # tar_pos = np.array([0, 0, 3.0]),
                                                   # rot_Y = -0.3,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = 0.,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
            elif view_setting.startswith('cano'):
                # canonical-space rendering
                cano_center = self.dataset.cano_bounds.mean(0)
                extr = np.identity(4, np.float32)
                extr[:3, 3] = -cano_center
                rot_x = np.identity(4, np.float32)
                rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
                extr = rot_x @ extr
                f_len = 5000
                extr[2, 3] += f_len / 512
                intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
                img_w, img_h = 1024, 1024
            else:
                raise ValueError('Invalid view setting for animation!')

            self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)

            # also save extr, intr, img_h and img_w to a json file
            camera_info = []
            for i in range(len(extr_list)):
                camera = {
                    'extr': extr_list[i].tolist(),
                    'intr': intr_list[i].tolist(),
                    'img_h': img_h_list[i],
                    'img_w': img_w_list[i],
                }
                camera_info.append(camera)
            with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
                json.dump(camera_info, fp)
            getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem
            item = getitem_func(
                idx,
                training = False,
                extr = extr,
                intr = intr,
                img_w = img_w,
                img_h = img_h
            )
            items = to_cuda(item, add_batch = False)

            if view_setting.startswith('moving') or view_setting == 'free_moving':
                current_center = items['live_bounds'].cpu().numpy().mean(0)
                delta = current_center - object_center
                object_center[0] += delta[0]
                # object_center[1] += delta[1]
                # object_center[2] += delta[2]

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Loading data costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if self.opt['test'].get('render_skeleton', False):
                from utils.visualize_skeletons import construct_skeletons
                skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
                skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False)

                if geo_renderer is None:
                    geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1))
                extr, intr = item['extr'], item['intr']
                geo_renderer.set_camera(extr, intr)
                geo_renderer.set_model(skel_vertices[skel_faces.reshape(-1)], skel_mesh.vertex_normals.astype(np.float32)[skel_faces.reshape(-1)])
                skel_img = geo_renderer.render()[:, :, :3]
                skel_img = (skel_img * 255).astype(np.uint8)
                cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering skeletons costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if 'smpl_pos_map' not in items:
                self.avatar_net.get_pose_map(items)

            # pca
            if use_pca:
                mask = training_dataset.pos_map_mask
                live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
                front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
                pose_conds = front_live_pos_map[mask]
                new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = float(self.opt['test'].get('sigma_pca', 2.)))
                front_live_pos_map[mask] = new_pose_conds
                live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
                items.update({
                    'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(config.device).permute(2, 0, 1)
                })

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering pose conditions costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
            output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
            mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering avatar costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()
            if 'rgb_map' in output_wo_hand:
                rgb_map_wo_hand = output_wo_hand['rgb_map']

            if 'full_body_rgb_map' in mask_output:
                os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
                full_body_mask = mask_output['full_body_rgb_map']
                full_body_mask.clip_(0., 1.)
                full_body_mask = (full_body_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/full_body_mask/%08d.png' % item['data_idx'], full_body_mask.cpu().numpy())

            if 'hand_only_rgb_map' in mask_output:
                os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)
                hand_only_mask = mask_output['hand_only_rgb_map']
                hand_only_mask.clip_(0., 1.)
                hand_only_mask = (hand_only_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % item['data_idx'], hand_only_mask.cpu().numpy())

            if 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
                # detect the hand paint colors ([1, 0, 0] for the right hand,
                # [0, 0, 1] for the left) in the full-body and hand-only renders
                body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['full_body_rgb_map'].device))
                body_red_mask = (body_red_mask * body_red_mask).sum(dim = 2) < 0.01  # need save
                hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['hand_only_rgb_map'].device))
                hand_red_mask = (hand_red_mask * hand_red_mask).sum(dim = 2) < 0.01
                if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
                if_mask_r_hand = if_mask_r_hand.cpu().numpy()

                body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['full_body_rgb_map'].device))
                body_blue_mask = (body_blue_mask * body_blue_mask).sum(dim = 2) < 0.01  # need save
                hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['hand_only_rgb_map'].device))
                hand_blue_mask = (hand_blue_mask * hand_blue_mask).sum(dim = 2) < 0.01
                if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
                if_mask_l_hand = if_mask_l_hand.cpu().numpy()

                # masks of the occluded parts of each hand: pixels visible in
                # the hand-only render but hidden in the full-body render
                red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
                blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
                all_mask = red_mask | blue_mask

                # save the three masks to three folders
                os.makedirs(output_dir + '/hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/r_hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/l_hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/hand_visual', exist_ok = True)
                all_mask = (all_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/hand_mask/%08d.png' % item['data_idx'], all_mask.cpu().numpy())
                r_hand_mask = (body_red_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % item['data_idx'], r_hand_mask.cpu().numpy())
                l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % item['data_idx'], l_hand_mask.cpu().numpy())

                hand_visual = [if_mask_r_hand, if_mask_l_hand]
                # save the per-hand visibility flags to npy
                with open(output_dir + '/hand_visual/%08d.npy' % item['data_idx'], 'wb') as f:
                    np.save(f, hand_visual)

                # now build the sleeve masks
                if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
                    os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True)
                    os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True)
                    mask = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
                    mask = mask.cpu().numpy().astype(np.uint8)
                    # structuring element; adjust its size to control how much the mask is dilated
                    kernel = np.ones((5, 5), np.uint8)
                    # apply the dilation
                    mask = cv.dilate(mask, kernel, iterations = 3)
                    mask = torch.tensor(mask).to(config.device)

                    left_hand_mask = mask_output['left_hand_rgb_map']
                    left_hand_mask.clip_(0., 1.)
                    # any non-white pixel belongs to the mask
                    left_hand_mask = (torch.tensor([1., 1., 1.], device = left_hand_mask.device) - left_hand_mask)
                    left_hand_mask = (left_hand_mask * left_hand_mask).sum(dim = 2) > 0.01
                    # remove the dilated two-hand mask
                    left_hand_mask = left_hand_mask & ~mask

                    right_hand_mask = mask_output['right_hand_rgb_map']
                    right_hand_mask.clip_(0., 1.)
                    right_hand_mask = (torch.tensor([1., 1., 1.], device = right_hand_mask.device) - right_hand_mask)
                    right_hand_mask = (right_hand_mask * right_hand_mask).sum(dim = 2) > 0.01
                    right_hand_mask = right_hand_mask & ~mask

                    # save
                    left_hand_mask = (left_hand_mask * 255).to(torch.uint8)
                    cv.imwrite(output_dir + '/left_sleeve_mask/%08d.png' % item['data_idx'], left_hand_mask.cpu().numpy())
                    right_hand_mask = (right_hand_mask * 255).to(torch.uint8)
                    cv.imwrite(output_dir + '/right_sleeve_mask/%08d.png' % item['data_idx'], right_hand_mask.cpu().numpy())
            rgb_map = output['rgb_map']
            rgb_map.clip_(0., 1.)
            rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
            cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % item['data_idx'], rgb_map)

            # use r_hand_mask and l_hand_mask to paste the masked regions of
            # the hand-free render over rgb_map
            if 'rgb_map' in output_wo_hand and 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
                rgb_map_wo_hand = output_wo_hand['rgb_map']
                rgb_map_wo_hand.clip_(0., 1.)
                rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
                r_mask = (r_hand_mask > 128).cpu().numpy()
                l_mask = (l_hand_mask > 128).cpu().numpy()
                mask = r_mask | l_mask
                mask = mask.astype(np.uint8)
                # structuring element; adjust its size to control how much the mask is dilated
                kernel = np.ones((5, 5), np.uint8)
                # apply the dilation
                mask = cv.dilate(mask, kernel, iterations = 3)
                mask = mask.astype(np.bool_)
                mask = np.expand_dims(mask, axis = 2)
                mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask
                cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.jpg' % item['data_idx'], mix)

            if 'torso_map' in output:
                os.makedirs(output_dir + '/torso_map', exist_ok = True)
                torso_map = output['torso_map'][:, :, 0]
                torso_map.clip_(0., 1.)
                torso_map = (torso_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/torso_map/%08d.png' % item['data_idx'], torso_map.cpu().numpy())

            if 'mask_map' in output:
                os.makedirs(output_dir + '/mask_map', exist_ok = True)
                mask_map = output['mask_map'][:, :, 0]
                mask_map.clip_(0., 1.)
                mask_map = (mask_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/mask_map/%08d.png' % item['data_idx'], mask_map.cpu().numpy())

            if self.opt['test'].get('save_tex_map', False):
                os.makedirs(output_dir + '/cano_tex_map', exist_ok = True)
                cano_tex_map = output['cano_tex_map']
                cano_tex_map.clip_(0., 1.)
                cano_tex_map = (cano_tex_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/cano_tex_map/%08d.jpg' % item['data_idx'], cano_tex_map.cpu().numpy())

            if self.opt['test'].get('save_ply', False):
                if item['data_idx'] == 0:
                    save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % item['data_idx'], output['posed_gaussians'])
                for k in output['posed_gaussians'].keys():
                    if isinstance(output['posed_gaussians'][k], torch.Tensor):
                        output['posed_gaussians'][k] = output['posed_gaussians'][k].detach().cpu().numpy()
                np.savez(output_dir + '/posed_gaussians/%08d.npz' % item['data_idx'], **output['posed_gaussians'])
                np.savez(output_dir + ('/posed_params/%08d.npz' % item['data_idx']),
                         betas = training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
                         global_orient = item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                         transl = item['transl'].reshape([-1]).detach().cpu().numpy(),
                         body_pose = item['body_pose'].reshape([-1]).detach().cpu().numpy())

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))

            torch.cuda.empty_cache()
    def save_ckpt(self, path, save_optm = True):
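        """Save the network (and optionally the optimizer) state to `path`."""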
        os.makedirs(path, exist_ok = True)
        net_dict = {
            'epoch_idx': self.epoch_idx,
            'iter_idx': self.iter_idx,
            'avatar_net': self.avatar_net.state_dict(),
        }
        print('Saving networks to ', path + '/net.pt')
        torch.save(net_dict, path + '/net.pt')

        if save_optm:
            optm_dict = {
                'avatar_net': self.optm.state_dict(),
            }
            print('Saving optimizers to ', path + '/optm.pt')
            torch.save(optm_dict, path + '/optm.pt')
    def load_ckpt(self, path, load_optm = True):
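        """Load the network (and optionally the optimizer) state from `path`;
        returns the stored (epoch_idx, iter_idx)."""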
        print('Loading networks from ', path + '/net.pt')
        net_dict = torch.load(path + '/net.pt')
        if 'avatar_net' in net_dict:
            self.avatar_net.load_state_dict(net_dict['avatar_net'])
        else:
            print('[WARNING] Cannot find "avatar_net" in the network checkpoint!')
        epoch_idx = net_dict['epoch_idx']
        iter_idx = net_dict['iter_idx']

        if load_optm and os.path.exists(path + '/optm.pt'):
            print('Loading optimizers from ', path + '/optm.pt')
            optm_dict = torch.load(path + '/optm.pt')
            if 'avatar_net' in optm_dict:
                self.optm.load_state_dict(optm_dict['avatar_net'])
            else:
                print('[WARNING] Cannot find "avatar_net" in the optimizer checkpoint!')

        return epoch_idx, iter_idx
if __name__ == '__main__':
    torch.manual_seed(31359)
    np.random.seed(31359)
    # torch.autograd.set_detect_anomaly(True)

    from argparse import ArgumentParser
    arg_parser = ArgumentParser()
    arg_parser.add_argument('-c', '--config_path', type = str, help = 'Configuration file path.')
    arg_parser.add_argument('-m', '--mode', type = str, help = 'Running mode.', default = 'train')
    args = arg_parser.parse_args()

    config.load_global_opt(args.config_path)
    if args.mode is not None:
        config.opt['mode'] = args.mode

    trainer = AvatarTrainer(config.opt)
    if config.opt['mode'] == 'train':
        if not safe_exists(config.opt['train']['net_ckpt_dir'] + '/pretrained') \
                and not safe_exists(config.opt['train']['pretrained_dir']) \
                and not safe_exists(config.opt['train']['prev_ckpt']):
            trainer.pretrain()
        trainer.train()
    elif config.opt['mode'] == 'test':
        trainer.test()
    else:
        raise NotImplementedError('Invalid running mode!')
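
# Example invocation (illustrative only: the script and config file names
# below are assumptions, not part of this repository):
#   python main_avatar.py -c configs/example_subject/avatar.yaml -m train
#   python main_avatar.py -c configs/example_subject/avatar.yaml -m test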