File size: 3,259 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import pytorch3d.ops

import smplx
from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d, nearest_face
from utils.knn import knn_gather
import config


def calc_blending_weight(query_pts, smpl_v, smpl_f, smpl_lbs = None, near_thres = 0.08, method = 'NN'):
    """
    Transfer per-vertex skinning weights from the body mesh onto arbitrary query points.

    :param query_pts: (B, N, 3) points that need blending weights
    :param smpl_v: (B, M, 3) body mesh vertices
    :param smpl_f: (B, F, 3) per-batch triangle vertex indices
    :param smpl_lbs: per-vertex skinning weights, presumably (B, M, J); when None, the module-level
        `smpl_skinning_weights` is broadcast (via expand, no copy) over the batch
    :param near_thres: distance threshold used to build the near-surface mask
    :param method: 'NN' takes the weights of the nearest mesh vertex; ANY other value interpolates
        the weights of the nearest face's three vertices with barycentric coordinates
    :return: (pts_w, near_flag) -- blended weights per query point, and a boolean mask that is True
        where the point is considered close to the mesh
    """
    assert query_pts.shape[0] == smpl_v.shape[0] == smpl_f.shape[0]
    n_batch = query_pts.shape[0]

    if smpl_lbs is None:
        smpl_lbs = smpl_skinning_weights.unsqueeze(0).expand(n_batch, -1, -1)

    if method == 'NN':
        sq_dists, nn_idx, _ = pytorch3d.ops.knn_points(query_pts, smpl_v, K = 1)
        # knn_points reports squared distances, so compare against the squared threshold
        near_mask = sq_dists[:, :, 0] < near_thres ** 2
        nn_weights = pytorch3d.ops.knn_gather(smpl_lbs, nn_idx)  # single neighbour per query (K = 1)
        return nn_weights[:, :, 0], near_mask

    # Barycentric interpolation over the closest face.
    # Only the first batch element's faces are used for the closest-face query, so this presumably
    # assumes every batch element shares the same face topology.
    dists_to_face, face_idx, bary = nearest_face_pytorch3d(query_pts, smpl_v, smpl_f[0])
    corner_ids = torch.gather(smpl_f.long(), 1, face_idx.long()[:, :, None].expand(-1, -1, 3))  # (B, N, 3)
    corner_weights = knn_gather(smpl_lbs, corner_ids)
    blended = torch.sum(bary.unsqueeze(-1) * corner_weights, dim = 2)
    # NOTE(review): unlike the NN branch, the distance here is compared against near_thres itself,
    # not its square; that is only consistent if nearest_face_pytorch3d returns unsquared
    # distances -- confirm.
    return blended, dists_to_face < near_thres


def skinning(points, lbs, jnt_mats, return_pt_mats = False):
    """
    Forward linear blend skinning of points.

    Each point's matrix is the lbs-weighted sum of the joint matrices. The point is then mapped by
    that matrix's upper-left 3x3 block plus its translation column.

    :param points: (B, N, 3) points to deform
    :param lbs: (B, N, J) per-point blending weights (J must match jnt_mats)
    :param jnt_mats: (B, J, 4, 4) per-joint transforms
    :param return_pt_mats: when True, also return the blended per-point matrices
    :return: (B, N, 3) deformed points, or a tuple (deformed points, (B, N, 4, 4) blended matrices)
        when return_pt_mats is True
    """
    blended = torch.einsum('bnj,bjxy->bnxy', lbs, jnt_mats)
    linear, offset = blended[..., :3, :3], blended[..., :3, 3]
    posed = torch.einsum('bnxy,bny->bnx', linear, points) + offset
    return (posed, blended) if return_pt_mats else posed


def skinning_normal(normals, lbs, jnt_mats):
    """
    Apply linear blend skinning to direction vectors such as normals.

    Only the upper-left 3x3 block of each blended matrix is applied, so translations are ignored.
    That block is used directly (not its inverse-transpose), and the outputs are not re-normalised.

    :param normals: (B, N, 3) vectors to transform
    :param lbs: (B, N, J) per-point blending weights
    :param jnt_mats: (B, J, 4, 4) per-joint transforms
    :return: (B, N, 3) transformed vectors
    """
    linear = torch.einsum('bnj,bjxy->bnxy', lbs, jnt_mats)[..., :3, :3]
    return torch.einsum('bnxy,bny->bnx', linear, normals)


# SMPL-X body model, constructed once when this module is imported.
smpl_model = smplx.SMPLX(
    model_path = config.PROJ_DIR + '/smpl_files/smplx',
    gender = 'neutral',
    use_pca = False,
    num_pca_comps = 45,
    flat_hand_mean = True,
    batch_size = 1,
)
# Independent float32 copy of the model's per-vertex skinning weights on config.device;
# calc_blending_weight falls back to it when no smpl_lbs is supplied.
smpl_skinning_weights = smpl_model.lbs_weights.clone().to(device = config.device, dtype = torch.float32)