Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import cv2 as cv | |
| import trimesh | |
| import yaml | |
| import tqdm | |
| import smplx | |
| from network.volume import CanoBlendWeightVolume | |
| from utils.renderer import Renderer | |
| import config | |
| os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" | |
def save_pos_map(pos_map, path):
    """Export the non-empty entries of a position map as a point cloud.

    An entry is kept when the norm of its vector along the last axis is
    strictly positive; entries that are exactly zero are treated as empty.

    Args:
        pos_map: NumPy array whose last axis holds the point coordinates.
        path: Destination handed to ``trimesh.PointCloud.export``.
    """
    norms = np.linalg.norm(pos_map, axis=-1)
    valid_points = pos_map[norms > 0.]
    # Report how many points survived the emptiness filter.
    print('Point nums %d' % valid_points.shape[0])
    trimesh.PointCloud(valid_points).export(path)
def interpolate_lbs(pts, vertices, faces, vertex_lbs):
    """Interpolate per-vertex attributes of a mesh at arbitrary query points.

    Each query point is assigned to a mesh face by ``nearest_face_pytorch3d``
    and the per-vertex attributes of that face are blended with the returned
    barycentric coordinates. All tensor work is done on the CUDA device.

    Args:
        pts: NumPy array of query points.
        vertices: NumPy array of mesh vertices.
        faces: NumPy array of face vertex indices.
        vertex_lbs: Torch tensor of per-vertex attributes (named after LBS
            weights by the caller).

    Returns:
        NumPy array with the interpolated attributes, one row per query point
        (assumes the helper returns a leading batch axis, which is dropped
        here -- TODO confirm against ``barycentric_interpolate``).
    """
    from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
    from utils.geo_util import barycentric_interpolate

    def _as_cuda(array, dtype):
        # NumPy array -> CUDA tensor of the requested dtype.
        return torch.from_numpy(array).to(dtype).cuda()

    query_pts = _as_cuda(pts, torch.float32)[None]
    mesh_verts = _as_cuda(vertices, torch.float32)[None]
    mesh_faces = _as_cuda(faces, torch.int64)
    # The returned distances are not needed for the interpolation.
    _, nearest_face_ids, barycentric = nearest_face_pytorch3d(query_pts, mesh_verts, mesh_faces)

    batched_attrs = vertex_lbs[None].to(torch.float32).cuda()
    batched_faces = _as_cuda(faces, torch.int64)[None]
    interpolated = barycentric_interpolate(
        vert_attris=batched_attrs,
        faces=batched_faces,
        face_ids=nearest_face_ids,
        bc_coords=barycentric,
    )
    return interpolated[0].cpu().numpy()
# Side length in pixels of one rendered canonical view. The front and back
# views are concatenated horizontally, giving a map_size x (2 * map_size) map.
map_size = 1024


def _write_exr(path, image):
    """Write ``image`` to ``path`` with OpenCV and fail loudly on error.

    ``cv.imwrite`` signals failure through its boolean return value, so an
    unchecked call can silently leave an expected output file missing.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    if not cv.imwrite(path, image):
        raise OSError('Failed to write image to %s' % path)


def _render_front_back(renderer, vertices, attributes, front_mv, back_mv):
    """Render per-vertex ``attributes`` from the front and the back camera.

    The back view is mirrored horizontally and appended to the right of the
    front view. Only the first three channels of each render are kept.
    """
    renderer.set_model(vertices, attributes)
    renderer.set_camera(front_mv)
    front_map = renderer.render()[:, :, :3]
    renderer.set_camera(back_mv)
    back_map = cv.flip(renderer.render()[:, :, :3], 1)
    return np.concatenate([front_map, back_map], 1)


def main():
    """Generate canonical and per-frame position maps for a dataset.

    Steps, in order:
      1. Load the YAML config given by ``-c/--config_path`` and build the
         dataset class named in ``opt['train']['dataset']``.
      2. Pose SMPL-X in the canonical pose from ``config`` and use either
         ``<data_dir>/template.ply`` (if present) or the SMPL-X mesh itself
         as the template.
      3. Render front/back canonical position and normal maps and write them
         as EXR files into ``<data_dir>/smpl_pos_map``.
      4. Obtain blend weights for every valid pixel and save them as
         ``init_pts_lbs.npy``.
      5. For every entry of ``dataset.pose_list``, transform the canonical
         points with the blended joint transforms, downsample the resulting
         map by a factor of two and write it as ``%08d.exr``.
    """
    from argparse import ArgumentParser
    import importlib

    arg_parser = ArgumentParser()
    # The path is mandatory: without it the script used to crash in open(None).
    arg_parser.add_argument('-c', '--config_path', type=str, required=True,
                            help='Configuration file path.')
    args = arg_parser.parse_args()

    # NOTE(review): yaml.FullLoader is kept to preserve behavior; prefer
    # yaml.safe_load if configs may ever come from an untrusted source.
    with open(args.config_path, encoding='UTF-8') as config_file:
        opt = yaml.load(config_file, Loader=yaml.FullLoader)

    dataset_module = opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
    MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
    dataset = MvRgbDataset(**opt['train']['data'])
    data_dir, frame_list = dataset.data_dir, dataset.pose_list

    pos_map_dir = data_dir + '/smpl_pos_map'
    os.makedirs(pos_map_dir, exist_ok=True)

    cano_renderer = Renderer(map_size, map_size, shader_name='vertex_attribute')

    smpl_model = smplx.SMPLX(config.PROJ_DIR + '/smpl_files/smplx', gender='neutral',
                             use_pca=False, num_pca_comps=45, flat_hand_mean=True,
                             batch_size=1)
    # astype() copies every array, so the archive can be closed right away.
    with np.load(data_dir + '/smpl_params.npz') as smpl_npz:
        smpl_data = {k: torch.from_numpy(v.astype(np.float32)) for k, v in smpl_npz.items()}

    with torch.no_grad():
        cano_smpl = smpl_model.forward(
            betas=smpl_data['betas'],
            global_orient=config.cano_smpl_global_orient[None],
            transl=config.cano_smpl_transl[None],
            body_pose=config.cano_smpl_body_pose[None],
        )
    cano_smpl_v = cano_smpl.vertices[0].cpu().numpy()
    # Bounding-box center of the canonical SMPL-X mesh; the cameras look at it.
    cano_center = 0.5 * (cano_smpl_v.min(0) + cano_smpl_v.max(0))
    smpl_faces = smpl_model.faces.astype(np.int64)

    if os.path.exists(data_dir + '/template.ply'):
        print('# Loading template from %s' % (data_dir + '/template.ply'))
        template = trimesh.load(data_dir + '/template.ply', process=False)
        using_template = True
    else:
        print(f'# Cannot find template.ply from {data_dir}, using SMPL-X as template')
        template = trimesh.Trimesh(cano_smpl_v, smpl_faces, process=False)
        using_template = False

    cano_smpl_v = template.vertices.astype(np.float32)
    smpl_faces = template.faces.astype(np.int64)
    # Duplicate vertices and normals per face corner for the renderer.
    face_corner_ids = smpl_faces.reshape(-1)
    cano_smpl_v_dup = cano_smpl_v[face_corner_ids]
    cano_smpl_n_dup = template.vertex_normals.astype(np.float32)[face_corner_ids]

    # define front & back view matrices
    camera_offset = np.array([0, 0, -10], np.float32)
    front_mv = np.identity(4, np.float32)
    front_mv[:3, 3] = -cano_center + camera_offset
    # Negate the y and z rows (presumably to match the renderer's camera
    # convention -- confirm against utils.renderer).
    front_mv[1:3] *= -1

    back_mv = np.identity(4, np.float32)
    rot_y = cv.Rodrigues(np.array([0, np.pi, 0], np.float32))[0]
    back_mv[:3, :3] = rot_y
    back_mv[:3, 3] = -rot_y @ cano_center + camera_offset
    back_mv[1:3] *= -1

    # render canonical smpl position maps
    cano_pos_map = _render_front_back(cano_renderer, cano_smpl_v_dup, cano_smpl_v_dup,
                                      front_mv, back_mv)
    _write_exr(pos_map_dir + '/cano_smpl_pos_map.exr', cano_pos_map)

    # render canonical smpl normal maps
    cano_nml_map = _render_front_back(cano_renderer, cano_smpl_v_dup, cano_smpl_n_dup,
                                      front_mv, back_mv)
    _write_exr(pos_map_dir + '/cano_smpl_nml_map.exr', cano_nml_map)

    # Pixels with a non-zero position are the ones covered by the template.
    body_mask = np.linalg.norm(cano_pos_map, axis=-1) > 0.
    cano_pts = cano_pos_map[body_mask]
    if using_template:
        weight_volume = CanoBlendWeightVolume(data_dir + '/cano_weight_volume.npz')
        pts_lbs = weight_volume.forward_weight(torch.from_numpy(cano_pts)[None].cuda())[0]
    else:
        # interpolate_lbs returns a NumPy array, so convert it back to a tensor.
        pts_lbs = interpolate_lbs(cano_pts, cano_smpl_v, smpl_faces, smpl_model.lbs_weights)
        pts_lbs = torch.from_numpy(pts_lbs).cuda()
    np.save(pos_map_dir + '/init_pts_lbs.npy', pts_lbs.cpu().numpy())

    inv_cano_smpl_A = torch.linalg.inv(cano_smpl.A).cuda()
    body_mask = torch.from_numpy(body_mask).cuda()
    cano_pts = torch.from_numpy(cano_pts).cuda()
    pts_lbs = pts_lbs.cuda()

    for pose_idx in tqdm.tqdm(frame_list, desc='Generating positional maps...'):
        with torch.no_grad():
            # global_orient, transl and the hand poses are intentionally not
            # passed (hence "woRoot"), so the model's defaults are used.
            live_smpl_woRoot = smpl_model.forward(
                betas=smpl_data['betas'],
                # global_orient = smpl_data['global_orient'][pose_idx][None],
                # transl = smpl_data['transl'][pose_idx][None],
                body_pose=smpl_data['body_pose'][pose_idx][None],
                jaw_pose=smpl_data['jaw_pose'][pose_idx][None],
                expression=smpl_data['expression'][pose_idx][None],
                # left_hand_pose = smpl_data['left_hand_pose'][pose_idx][None],
                # right_hand_pose = smpl_data['right_hand_pose'][pose_idx][None]
            )

        # Per-joint transform from the canonical pose to the current pose.
        cano2live_jnt_mats_woRoot = torch.matmul(live_smpl_woRoot.A.cuda(), inv_cano_smpl_A)[0]
        # Blend the joint transforms per point, then apply rotation + translation.
        pt_mats = torch.einsum('nj,jxy->nxy', pts_lbs, cano2live_jnt_mats_woRoot)
        live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]

        live_pos_map = torch.zeros((map_size, 2 * map_size, 3)).to(live_pts)
        live_pos_map[body_mask] = live_pts
        # Halve the resolution with nearest-neighbour sampling (HWC -> CHW -> HWC).
        live_pos_map = F.interpolate(live_pos_map.permute(2, 0, 1)[None], size=None,
                                     scale_factor=[0.5, 0.5], mode='nearest')[0]
        live_pos_map = live_pos_map.permute(1, 2, 0).cpu().numpy()
        _write_exr(pos_map_dir + '/%08d.exr' % pose_idx, live_pos_map)


if __name__ == '__main__':
    main()