import torch
import torch.nn as nn
import numpy as np
import pytorch3d.ops
import pytorch3d.transforms
import trimesh

import config
from network.mlp import MLPLinear, SdfMLP
from network.density import LaplaceDensity
from network.volume import CanoBlendWeightVolume
from network.hand_avatar import HandAvatar
from utils.embedder import get_embedder
import utils.nerf_util as nerf_util
import utils.smpl_util as smpl_util
import utils.geo_util as geo_util
from utils.posevocab_custom_ops.near_far_smpl import near_far_smpl
from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
from utils.knn import knn_gather
import root_finding


class TemplateNet(nn.Module):
    def __init__(self, opt):
        super(TemplateNet, self).__init__()
        self.opt = opt

        self.pos_embedder, self.pos_dim = get_embedder(opt['multires'], 3)

        # canonical blend weight volume
        self.cano_weight_volume = CanoBlendWeightVolume(config.opt['train']['data']['data_dir'] + '/cano_weight_volume.npz')
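        # The volume stores LBS skinning weights sampled on a regular grid in
        # canonical space (presumably diffused from the body surface into the
        # surrounding volume), so arbitrary 3D points can be skinned via
        # trilinear interpolation instead of a nearest-surface lookup.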

        self.pose_feat_dim = 0

        """ geometry networks """
        geo_mlp_opt = {
            'in_channels': self.pos_dim + self.pose_feat_dim,
            'out_channels': 256 + 1,
            'inter_channels': [512, 256, 256, 256, 256, 256],
            'nlactv': nn.Softplus(beta = 100),
            'res_layers': [4],
            'geometric_init': True,
            'bias': 0.7,
            'weight_norm': True
        }
        self.geo_mlp = SdfMLP(**geo_mlp_opt)
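        # With 'geometric_init', the SDF MLP is presumably initialized in the
        # IGR/SAL style: weights are chosen so the initial SDF approximates a
        # sphere of radius 'bias' (0.7 here), which stabilizes early training.
        # Softplus(beta = 100) is a smooth ReLU substitute that keeps the SDF
        # differentiable for the normal computation below.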
| """ texture networks """ | |
| if self.opt['use_viewdir']: | |
| self.viewdir_embedder, self.viewdir_dim = get_embedder(self.opt['multires_viewdir'], 3) | |
| else: | |
| self.viewdir_embedder, self.viewdir_dim = None, 0 | |
| tex_mlp_opt = { | |
| 'in_channels': 256 + self.viewdir_dim, | |
| 'out_channels': 3, | |
| 'inter_channels': [256, 256, 256], | |
| 'nlactv': nn.ReLU(), | |
| 'last_op': nn.Sigmoid() | |
| } | |
| self.tex_mlp = MLPLinear(**tex_mlp_opt) | |
| print('# MLPs: ') | |
| print(self.geo_mlp) | |
| print(self.tex_mlp) | |
| # sdf2density | |
| self.density_func = LaplaceDensity(params_init = {'beta': 0.01}) | |
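        # LaplaceDensity presumably follows the VolSDF construction: density is
        # the CDF of a zero-mean Laplace distribution applied to the SDF (under
        # this codebase's sign convention),
        #     sigma(s) = alpha * (0.5 + 0.5 * sign(s) * (1 - exp(-|s| / beta))),
        # so density saturates inside the surface and decays outside; a small
        # beta (0.01) concentrates density tightly around the zero level set.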

        # hand avatars
        self.with_hand = self.opt.get('with_hand', False)
        self.left_hand = HandAvatar()
        self.right_hand = HandAvatar()

        # for root finding
        from network.volume import compute_gradient_volume
        if self.opt.get('volume_type', 'diff') == 'diff':
            self.weight_volume = self.cano_weight_volume.diff_weight_volume[0].permute(1, 2, 3, 0).contiguous()
        else:
            self.weight_volume = self.cano_weight_volume.ori_weight_volume[0].permute(1, 2, 3, 0).contiguous()
        self.grad_volume = compute_gradient_volume(self.weight_volume.permute(3, 0, 1, 2), self.cano_weight_volume.voxel_size).permute(2, 3, 4, 0, 1)\
            .reshape(self.cano_weight_volume.res_x, self.cano_weight_volume.res_y, self.cano_weight_volume.res_z, -1).contiguous()
        self.res = torch.tensor([self.cano_weight_volume.res_x, self.cano_weight_volume.res_y, self.cano_weight_volume.res_z], dtype = torch.int32, device = config.device)

        self._initialize_hands()

    def _initialize_hands(self):
        smplx_lhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_lhand_to_mano_rhand.npz', allow_pickle = True)
        smplx_rhand_to_mano_rhand_data = np.load(config.PROJ_DIR + '/smpl_files/mano/smplx_rhand_to_mano_rhand.npz', allow_pickle = True)
        smpl_lhand_vert_id = np.copy(smplx_lhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        smpl_rhand_vert_id = np.copy(smplx_rhand_to_mano_rhand_data['smpl_vert_id_to_mano'])
        self.smpl_lhand_vert_id = torch.from_numpy(smpl_lhand_vert_id).to(config.device)
        self.smpl_rhand_vert_id = torch.from_numpy(smpl_rhand_vert_id).to(config.device)
        self.smpl_hands_vert_id = torch.cat([self.smpl_lhand_vert_id, self.smpl_rhand_vert_id], 0)
        mano_face_closed = np.loadtxt(config.PROJ_DIR + '/smpl_files/mano/mano_face_close.txt').astype(np.int64)
        self.mano_face_closed = torch.from_numpy(mano_face_closed).to(config.device)
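        # The left hand reuses the right-hand MANO topology mirrored along x,
        # so its triangle indices are reversed ([2, 1, 0]) to flip the winding
        # order and keep face normals pointing outward after mirroring.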
        self.mano_face_closed_2hand = torch.cat([self.mano_face_closed[:, [2, 1, 0]], self.mano_face_closed + self.smpl_lhand_vert_id.shape[0]], 0)

    def forward_cano_body_nerf(self, xyz, viewdirs, pose, compute_grad = False):
        """
        :param xyz: (B, N, 3)
        :param viewdirs: (B, N, 3)
        :param pose: (B, pose_dim)
        :param compute_grad: whether to compute the gradient of the SDF w.r.t. xyz (used as the normal)
        :return: dict with 'sdf', 'density', 'color', 'cano_xyz' and optionally 'normal'
        """
        if compute_grad:
            xyz.requires_grad_()

        # pose_feat = self.pose_feat[None, None].expand(xyz.shape[0], xyz.shape[1], -1)
        # pose_feat = torch.cat([self.pos_embedder(xyz), pose_feat], -1)
        pose_feat = self.pos_embedder(xyz)
        geo_feat = self.geo_mlp(pose_feat)
        sdf, geo_feat = torch.split(geo_feat, [1, geo_feat.shape[-1] - 1], -1)

        if self.viewdir_embedder is not None:
            if viewdirs is None:
                viewdirs = torch.zeros_like(xyz)
            geo_feat = torch.cat([geo_feat, self.viewdir_embedder(viewdirs)], -1)
        color = self.tex_mlp(geo_feat)
        density = self.density_func(sdf)

        ret = {
            'sdf': -sdf,  # negated so that the returned SDF is negative outside and positive inside
            'density': density,
            'color': color,
            'cano_xyz': xyz.detach()
        }

        if compute_grad:
            d_output = torch.ones_like(sdf, requires_grad = False, device = sdf.device)
            normal = torch.autograd.grad(outputs = sdf,
                                         inputs = xyz,
                                         grad_outputs = d_output,
                                         create_graph = self.training,
                                         retain_graph = self.training,
                                         only_inputs = True)[0]
            ret.update({
                'normal': normal
            })
        return ret

    def forward_cano_hand_nerf(self, xyz, sdf, viewdirs, hand_pose, module = 'left_hand'):
        net = getattr(self, module)
        return net(xyz, sdf, viewdirs, hand_pose)

    def fuse_hands(self, body_ret, posed_xyz, view_dirs, batch, space = 'live'):
        # get hand correspondences
        batch_size, n_pts = posed_xyz.shape[:2]

        def process_one_hand(side = 'left'):
            hand_v = batch['%s_live_mano_v' % side] if space == 'live' else batch['%s_cano_mano_v' % side]
            hand_n = batch['%s_live_mano_n' % side] if space == 'live' else batch['%s_cano_mano_n' % side]
            hand_f = self.mano_face_closed[:, [2, 1, 0]] if side == 'left' else self.mano_face_closed
            dists, face_indices, bc_coords = nearest_face_pytorch3d(posed_xyz, hand_v, hand_f)
            face_vertex_ids = torch.gather(hand_f[None].expand(batch_size, -1, -1), 1, face_indices[:, :, None].long().expand(-1, -1, 3))  # (B, N, 3)
            cano_hand_v = geo_util.normalize_vert_bbox(batch['%s_cano_mano_v' % side], dim = 1, per_axis = True)
            face_cano_mano_v = knn_gather(cano_hand_v, face_vertex_ids)
            pts_cano_mano_v = (bc_coords[..., None] * face_cano_mano_v).sum(2)
            face_live_mano_v = knn_gather(hand_v, face_vertex_ids)
            pts_live_mano_v = (bc_coords[..., None] * face_live_mano_v).sum(2)
            # face_normal = torch.cross(face_live_smpl_v[:, :, 1] - face_live_smpl_v[:, :, 0], face_live_smpl_v[:, :, 2] - face_live_smpl_v[:, :, 0])
            face_live_mano_n = knn_gather(hand_n, face_vertex_ids)
            pts_live_mano_n = (bc_coords[..., None] * face_live_mano_n).sum(2)
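            # Signed distance to the MANO surface: the sign comes from the dot
            # product between the interpolated surface normal and the offset
            # from the surface point to the query, negated to match the body
            # network's convention (negative outside, positive inside).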
            pts_smpl_sdf = -torch.sign(torch.einsum('bni,bni->bn', pts_live_mano_n, posed_xyz - pts_live_mano_v)) * dists
            return pts_cano_mano_v, pts_smpl_sdf.unsqueeze(-1)

        left_cano_mano_v, left_mano_sdf = process_one_hand('left')
        right_cano_mano_v, right_mano_sdf = process_one_hand('right')

        # fuse
        zero_hand_pose = torch.zeros((1, 15 * 3)).to(left_cano_mano_v)  # 15 MANO joints x 3 axis-angle params
        color_lhand = self.forward_cano_hand_nerf(left_cano_mano_v, left_mano_sdf, view_dirs, zero_hand_pose, module = 'left_hand')
        color_rhand = self.forward_cano_hand_nerf(right_cano_mano_v, right_mano_sdf, view_dirs, zero_hand_pose, module = 'right_hand')

        # compute the weights for blending the outputs of the body network and the hand networks
        # wl = torch.sigmoid(1000 * (left_mano_sdf + 0.1)) * torch.sigmoid(25 * (left_cano_mano_v[..., 0:1] + 0.8))
        # wr = torch.sigmoid(1000 * (right_mano_sdf + 0.1)) * torch.sigmoid(-25 * (right_cano_mano_v[..., 0:1] - 0.8))
        cano_xyz = body_ret['cano_xyz']
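        # wl/wr are smooth ramps along the x axis of each hand's bbox-normalized
        # canonical frame: points beyond the wrist side of the box are dominated
        # by the hand networks, with a sigmoid of slope 25 giving a soft
        # transition. Points below the SMPL body center (lower body) cannot
        # belong to a hand and are forced to weight 0.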
        wl = torch.sigmoid(25 * (geo_util.normalize_vert_bbox(batch['left_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] + 0.8))
        wr = torch.sigmoid(-25 * (geo_util.normalize_vert_bbox(batch['right_cano_mano_v'], attris = cano_xyz, dim = 1, per_axis = True)[..., 0:1] - 0.8))
        wl[cano_xyz[..., 1] < batch['cano_smpl_center'][0, 1]] = 0.
        wr[cano_xyz[..., 1] < batch['cano_smpl_center'][0, 1]] = 0.
        s = torch.maximum(wl + wr, torch.ones_like(wl))
        wl, wr = wl / s, wr / s

        # blend the outputs of the body network and the hand networks
        w = wl + wr
        # factor = 10
        # left_mano_sdf *= factor
        # right_mano_sdf *= factor
        body_ret['sdf'] = wl * left_mano_sdf + wr * right_mano_sdf + (1.0 - w) * body_ret['sdf']
        body_ret['color'] = wl * color_lhand + wr * color_rhand + (1.0 - w) * body_ret['color']
        body_ret['density'] = self.density_func(-body_ret['sdf'])  # negate back to the raw convention expected by the density function

    def forward_cano_radiance_field(self, xyz, view_dirs, batch):
        body_ret = self.forward_cano_body_nerf(xyz, view_dirs, None, compute_grad = self.training)
        return body_ret

    def transform_cano2live(self, cano_pts, batch, normals = None, near_thres = 0.08):
        cano2live_jnt_mats = batch['cano2live_jnt_mats'].clone()
        if not self.with_hand:
            # make the hand transformations fully rigid: all finger joints inherit the wrist transform
            cano2live_jnt_mats[:, 25: 40] = cano2live_jnt_mats[:, 20: 21]
            cano2live_jnt_mats[:, 40: 55] = cano2live_jnt_mats[:, 21: 22]
        pts_w = self.cano_weight_volume.forward_weight(cano_pts)
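        # Standard linear blend skinning: each point's transform is the
        # weight-blended sum of the joint transforms G_j,
        #     x_posed = sum_j w_j(x) * G_j * x_cano,
        # with w_j trilinearly interpolated from the canonical weight volume.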
        pt_mats = torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats)
        posed_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]
        if normals is None:
            return posed_pts
        else:
            posed_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)
            return posed_pts, posed_normals

    def transform_live2cano(self, posed_pts, batch, normals = None, near_thres = 0.08):
        cano2live_jnt_mats = batch['cano2live_jnt_mats'].clone()
        if not self.with_hand:
            cano2live_jnt_mats[:, 25: 40] = cano2live_jnt_mats[:, 20: 21]
            cano2live_jnt_mats[:, 40: 55] = cano2live_jnt_mats[:, 21: 22]

        """ live_pts -> cano_pts """
        batch_size, n_pts = posed_pts.shape[:2]
        with torch.no_grad():
            if 'live_mesh_v' in batch:
            # if False:
                tar_v = batch['live_mesh_v']
                tar_f = batch['live_mesh_f']
                tar_lbs = batch['live_mesh_lbs']
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, tar_v, tar_f, tar_lbs, near_thres, method = 'NN')
            else:
                tar_v = batch['live_smpl_v']
                tar_f = batch['smpl_faces']
                tar_lbs = None
                pts_w, near_flag = smpl_util.calc_blending_weight(posed_pts, tar_v, tar_f, tar_lbs, near_thres, method = 'barycentric')
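        # Inverse skinning has no closed form because the correct weights live
        # at the (unknown) canonical point. As an initial estimate, weights are
        # borrowed from the nearest posed-surface point and the blended
        # transform is inverted:  x_cano ~= (sum_j w_j * G_j)^-1 * x_posed.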
        pt_mats = torch.einsum('bnj,bjxy->bnxy', pts_w, cano2live_jnt_mats)
        pt_mats = torch.linalg.inv(pt_mats)
        cano_pts = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], posed_pts) + pt_mats[..., :3, 3]
        # cano_pts_bk = cano_pts.detach().clone()
        if normals is not None:
            cano_normals = torch.einsum('bnxy,bny->bnx', pt_mats[..., :3, :3], normals)
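        # Optionally refine the estimate by iterative root finding: solve
        # LBS(W(x_cano), x_cano) = x_posed for x_cano via a custom op, using
        # the precomputed weight volume and its gradient. Points dominated by
        # joints 7/8/10/11 (ankles and feet in SMPL-X) are skipped, presumably
        # because the iteration is unreliable there.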
        if self.opt['use_root_finding']:
            argmax_lbs = torch.argmax(pts_w, -1)
            nonopt_bone_ids = [7, 8, 10, 11]
            nonopt_pts_flag = torch.zeros((batch_size, n_pts), dtype = torch.bool).to(argmax_lbs.device)
            for i in nonopt_bone_ids:
                nonopt_pts_flag = torch.logical_or(nonopt_pts_flag, argmax_lbs == i)
            root_finding_flag = torch.logical_not(nonopt_pts_flag)
            if root_finding_flag.any():
                cano_pts_ = cano_pts[root_finding_flag].unsqueeze(0)
                posed_pts_ = posed_pts[root_finding_flag].unsqueeze(0)
                if not cano_pts_.is_contiguous():
                    cano_pts_ = cano_pts_.contiguous()
                if not posed_pts_.is_contiguous():
                    posed_pts_ = posed_pts_.contiguous()
                root_finding.root_finding(
                    self.weight_volume,
                    self.grad_volume,
                    posed_pts_,
                    cano_pts_,
                    cano2live_jnt_mats,
                    self.cano_weight_volume.volume_bounds,
                    self.res,
                    cano_pts_,  # passed twice: apparently both the initial guess and the in-place output buffer
                    0.1,
                    10
                )
                cano_pts[root_finding_flag] = cano_pts_[0]

        if normals is None:
            return cano_pts, near_flag
        else:
            return cano_pts, cano_normals, near_flag

    def render(self, batch, chunk_size = 2048, depth_guided_sampling = None, space = 'live', white_bkgd = False):
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near = batch['near']
        far = batch['far']

        if depth_guided_sampling is None:  # guard against the default; the dict is indexed unconditionally below
            depth_guided_sampling = {'flag': False}
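        # Three ray-sampling strategies:
        #   1. depth-guided: tighten [near, far] to a band around the observed depth,
        #   2. smpl-guided: intersect rays with the posed SMPL mesh for per-ray bounds,
        #   3. uniform: fall back to the dataset-provided near/far.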
        if depth_guided_sampling['flag']:
            print('# depth-guided sampling')
            valid_dist_flag = batch['dist'] > 1e-6
            dist = batch['dist'][valid_dist_flag]
            near_dist = depth_guided_sampling['near_sur_dist']
            far_dist = depth_guided_sampling['near_sur_dist']
            near[valid_dist_flag] = dist - near_dist
            far[valid_dist_flag] = dist + far_dist
            N_ray_samples = depth_guided_sampling['N_ray_samples']
        else:
            if depth_guided_sampling.get('type', 'smpl') == 'smpl':
                print('# smpl-guided sampling')
                valid_dist_flag = torch.ones_like(near, dtype = torch.bool)
                near, far, intersect_flag = near_far_smpl(batch['live_smpl_v'][0], ray_o[0], ray_d[0])
                near[~intersect_flag] = batch['near'][0][~intersect_flag]
                far[~intersect_flag] = batch['far'][0][~intersect_flag]
                near = near.unsqueeze(0)
                far = far.unsqueeze(0)
                N_ray_samples = 64
            elif depth_guided_sampling.get('type', 'smpl') == 'uniform':
                print('# uniform sampling')
                valid_dist_flag = torch.ones_like(near, dtype = torch.bool)
                N_ray_samples = 64

        if self.training:
            chunk_size = batch['ray_o'].shape[1]
        batch_size, n_pixels = ray_o.shape[:2]
        output_list = []
        for i in range(0, n_pixels, chunk_size):
            near_chunk = near[:, i: i + chunk_size]
            far_chunk = far[:, i: i + chunk_size]
            ray_o_chunk = ray_o[:, i: i + chunk_size]
            ray_d_chunk = ray_d[:, i: i + chunk_size]
            valid_dist_flag_chunk = valid_dist_flag[:, i: i + chunk_size]

            # sample points on each ray
            pts, z_vals = nerf_util.sample_pts_on_rays(ray_o_chunk, ray_d_chunk, near_chunk, far_chunk,
                                                       N_samples = N_ray_samples,
                                                       perturb = self.training,
                                                       depth_guided_mask = valid_dist_flag_chunk)

            # # debug: visualize pts
            # import trimesh
            # pts_trimesh = trimesh.PointCloud(pts[0].cpu().numpy().reshape(-1, 3))
            # pts_trimesh.export('./debug/sampled_pts_%s.obj' % ('training' if self.training else 'testing'))
            # exit(1)

            # flatten
            _, n_pixels_chunk, n_samples = pts.shape[:3]
            pts = pts.view(batch_size, n_pixels_chunk * n_samples, -1)
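            # Per-sample segment lengths delta_i = z_{i+1} - z_i, used below in
            # alpha_i = 1 - exp(-sigma_i * delta_i); the last distance is
            # duplicated since the final sample has no successor.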
            dists = z_vals[..., 1:] - z_vals[..., :-1]
            dists = torch.cat([dists, dists[..., -1:]], -1)

            # query
            if space == 'live':
                cano_pts, near_flag = self.transform_live2cano(pts, batch)
            elif space == 'cano':
                cano_pts = pts
            else:
                raise ValueError('Invalid rendering space!')
            viewdirs = ray_d_chunk / torch.norm(ray_d_chunk, dim = -1, keepdim = True)
            viewdirs = viewdirs[:, :, None, :].expand(-1, -1, n_samples, -1).reshape(batch_size, n_pixels_chunk * n_samples, -1)

            # perturb view directions with Gaussian noise to avoid overfitting
            if self.training:
                with torch.no_grad():
                    noise = torch.randn_like(viewdirs) * 0.1
                    viewdirs = viewdirs + noise
                    viewdirs = viewdirs / torch.norm(viewdirs, dim = -1, keepdim = True)

            ret = self.forward_cano_radiance_field(cano_pts, viewdirs, batch)
            if self.with_hand:
                self.fuse_hands(ret, pts, viewdirs, batch, space)

            ret['color'] = ret['color'].view(batch_size, n_pixels_chunk, n_samples, -1)
            ret['density'] = ret['density'].view(batch_size, n_pixels_chunk, n_samples, -1)

            # integration
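            # Discrete volume rendering, as in NeRF:
            #   alpha_i = 1 - exp(-sigma_i * delta_i)
            #   T_i     = prod_{j < i} (1 - alpha_j)
            #   C       = sum_i T_i * alpha_i * c_i
            # nerf_util.raw2outputs performs this compositing on the packed
            # (color, alpha) tensor.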
            alpha = 1. - torch.exp(-ret['density'] * dists[..., None])
            raw = torch.cat([ret['color'], alpha], dim = -1)
            rgb_map, disp_map, acc_map, weights, depth_map = nerf_util.raw2outputs(raw, z_vals, white_bkgd = white_bkgd)

            output_chunk = {
                'rgb_map': rgb_map,  # (batch_size, n_pixels_chunk, 3)
                'acc_map': acc_map
            }
            if 'normal' in ret:
                output_chunk.update({
                    'normal': ret['normal'].view(batch_size, n_pixels_chunk, -1, 3)
                })
            if 'tv_loss' in ret:
                output_chunk.update({
                    'tv_loss': ret['tv_loss'].view(1, 1, -1)
                })
            output_list.append(output_chunk)

        keys = output_list[0].keys()
        output_list = {k: torch.cat([r[k] for r in output_list], dim = 1) for k in keys}

        # for patch-based ray sampling: scatter the rendered rays back into the full patch grid
        if 'mask_within_patch' in batch:
            _, ray_num = batch['mask_within_patch'].shape
            rgb_map = torch.zeros((batch_size, ray_num, 3), dtype = torch.float32, device = config.device)
            acc_map = torch.zeros((batch_size, ray_num), dtype = torch.float32, device = config.device)
            rgb_map[batch['mask_within_patch']] = output_list['rgb_map'].reshape(-1, 3)
            acc_map[batch['mask_within_patch']] = output_list['acc_map'].reshape(-1)
            batch['color_gt'][~batch['mask_within_patch']] = 0.
            batch['mask_gt'][~batch['mask_within_patch']] = 0.
            output_list['rgb_map'] = rgb_map
            output_list['acc_map'] = acc_map
        return output_list
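

# ------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The option keys and call
# signature mirror the ones read above; the concrete values are
# assumptions, and the real ones come from the project's config files
# and data loader (which also builds 'batch').
#
#   opt = {
#       'multires': 6,              # assumed positional-encoding setting
#       'use_viewdir': True,
#       'multires_viewdir': 4,      # assumed
#       'with_hand': False,
#       'use_root_finding': True,
#   }
#   net = TemplateNet(opt).to(config.device)
#   out = net.render(batch,
#                    chunk_size = 2048,
#                    depth_guided_sampling = {'flag': False, 'type': 'smpl'},
#                    space = 'live')
#   rgb, acc = out['rgb_map'], out['acc_map']
# ------------------------------------------------------------------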