pengc02's picture
all
ec9a6bc
import torch
import pytorch3d.ops
import smplx
from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d, nearest_face
from utils.knn import knn_gather
import config
def calc_blending_weight(query_pts, smpl_v, smpl_f, smpl_lbs = None, near_thres = 0.08, method = 'NN'):
    """
    Transfer per-vertex skinning weights from a body mesh onto query points.

    :param query_pts: (B, N, 3) query points
    :param smpl_v: (B, M, 3) mesh vertices
    :param smpl_f: (B, F, 3) mesh faces (vertex indices)
    :param smpl_lbs: per-vertex skinning weights, presumably (B, M, J); when None, the
        module-level ``smpl_skinning_weights`` is expanded over the batch dimension
    :param near_thres: distance threshold used to build the "near the surface" mask
    :param method: 'NN' copies the weights of the nearest vertex; ANY other value
        (documented as 'barycentric') interpolates the weights of the nearest face
    :return: (pts_w, near_flag) -- per-point blended weights and a boolean mask of
        points considered close to the mesh
    """
    assert (query_pts.shape[0] == smpl_v.shape[0] == smpl_f.shape[0])
    if smpl_lbs is None:
        smpl_lbs = smpl_skinning_weights[None].expand(query_pts.shape[0], -1, -1)

    if method != 'NN':
        # Barycentric interpolation of the three corner weights of the closest face.
        # NOTE(review): only smpl_f[0] is handed to the face search, which presumably
        # assumes every batch element shares the same face topology -- confirm.
        face_dists, face_idx, bary = nearest_face_pytorch3d(query_pts, smpl_v, smpl_f[0])
        corner_ids = torch.gather(smpl_f.long(), 1, face_idx[:, :, None].long().expand(-1, -1, 3))  # (B, N, 3)
        corner_lbs = knn_gather(smpl_lbs, corner_ids)
        blended = (bary[..., None] * corner_lbs).sum(2)
        # NOTE(review): the NN branch compares SQUARED distances with near_thres ** 2,
        # whereas here the helper's distance is compared with near_thres directly.
        # Whether nearest_face_pytorch3d returns squared or Euclidean distances is not
        # visible here -- confirm the two thresholds agree.
        return blended, face_dists < near_thres

    # Nearest-vertex lookup; knn_points reports squared distances, hence near_thres ** 2.
    sq_dists, nn_idx, _ = pytorch3d.ops.knn_points(query_pts, smpl_v, K = 1)
    is_near = sq_dists[:, :, 0] < near_thres ** 2
    nn_lbs = pytorch3d.ops.knn_gather(smpl_lbs, nn_idx)[:, :, 0]  # drop the K = 1 axis
    return nn_lbs, is_near
def skinning(points, lbs, jnt_mats, return_pt_mats = False):
    """
    Forward linear blend skinning of points.

    Every point gets a 4x4 transform equal to the weighted sum of the joint
    transforms. Its 3x3 linear part is applied to the point, then its translation
    column is added.

    :param points: (B, N, 3) canonical points
    :param lbs: (B, N, J) per-point blend weights (J joints; the original docstring
        said 24 -- with the SMPL-X model used in this file J is presumably larger, confirm)
    :param jnt_mats: (B, J, 4, 4) joint transforms
    :param return_pt_mats: when truthy, also return the blended per-point matrices
    :return: (B, N, 3) posed points, or (posed points, (B, N, 4, 4) blended matrices)
    """
    blended = torch.einsum('bnj,bjxy->bnxy', lbs, jnt_mats)
    rot = blended[..., :3, :3]
    trans = blended[..., :3, 3]
    posed = torch.einsum('bnxy,bny->bnx', rot, points) + trans
    if not return_pt_mats:
        return posed
    return posed, blended
def skinning_normal(normals, lbs, jnt_mats):
    """
    Pose normals (or other direction vectors) with linear blend skinning.

    Only the 3x3 linear part of the blended joint transforms is applied. The
    translation column is ignored because it does not affect directions.

    :param normals: (B, N, 3) canonical normals
    :param lbs: (B, N, J) per-point blend weights
    :param jnt_mats: (B, J, 4, 4) joint transforms
    :return: (B, N, 3) posed normals (not renormalized)

    NOTE(review): the blended linear part is applied directly rather than its
    inverse-transpose, and the result is not renormalized. This is exact only when
    the blended 3x3 matrix is a rotation -- confirm that is acceptable for callers.
    """
    # Blend only the 3x3 linear parts: the remaining 7 of 16 entries of each 4x4
    # matrix were computed and then discarded, wasting work and a (B, N, 4, 4) buffer.
    lin_mats = torch.einsum('bnj,bjxy->bnxy', lbs, jnt_mats[..., :3, :3])
    live_normals = torch.einsum('bnxy,bny->bnx', lin_mats, normals)
    return live_normals
# Module-level SMPL-X body model, built once at import time from
# config.PROJ_DIR + '/smpl_files/smplx' (neutral gender, batch_size 1).
# NOTE(review): with use_pca=False, num_pca_comps=45 is presumably ignored by smplx -- confirm.
smpl_model = smplx.SMPLX(model_path = config.PROJ_DIR + '/smpl_files/smplx', gender = 'neutral', use_pca = False, num_pca_comps = 45, flat_hand_mean = True, batch_size = 1)
# The model's per-vertex skinning weights: cloned so this tensor never aliases the model's
# own buffer, cast to float32 and moved to config.device. calc_blending_weight expands it
# over the batch as its default `smpl_lbs`.
smpl_skinning_weights = smpl_model.lbs_weights.clone().to(torch.float32).to(config.device)