Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from pytorch3d.transforms import axis_angle_to_quaternion | |
| from network.mlp import MLPLinear | |
| from utils.embedder import get_embedder | |
| class HandAvatar(nn.Module): | |
| def __init__(self, | |
| multires = 4, | |
| view_multires = -1, | |
| pose_dim = 15*4): | |
| super(HandAvatar, self).__init__() | |
| self.pos_embedder, self.pos_dim = get_embedder(multires, 3) | |
| if view_multires == -1: | |
| self.view_embedder, self.view_dim = None, 0 | |
| else: | |
| self.view_embedder, self.view_dim = get_embedder(view_multires, 3) | |
| self.pose_dim = pose_dim | |
| self.tex_mlp = MLPLinear( | |
| in_channels = self.pos_dim + 1 + self.view_dim + pose_dim, | |
| inter_channels = [64, 64, 64, 64, 64], | |
| out_channels = 3, | |
| last_op = nn.Sigmoid() | |
| ) | |
| def forward(self, cano_xyz, sdf, view_dir, hand_pose): | |
| batch_size, n_pts = cano_xyz.shape[:2] | |
| in_feat = torch.cat([self.pos_embedder(cano_xyz), sdf], -1) | |
| hand_pose = axis_angle_to_quaternion(hand_pose.reshape(batch_size, -1, 3)).reshape(batch_size, -1) | |
| if self.view_embedder is not None: | |
| in_feat = torch.cat([in_feat, self.view_embedder(view_dir)], -1) | |
| in_feat = torch.cat([in_feat, hand_pose[:, None].expand(-1, n_pts, -1)], -1) | |
| color = self.tex_mlp(in_feat) | |
| return color | |