import torch
import torch.nn as nn
import math

# from lib.modules.stylegan2 import Generator
# from lib.ops.styleGAN import grid_sample_gradfix
def conv3x3(in_channels, out_channels, stride=1, use_bn=False):
    # 3x3 conv + LeakyReLU; note that BatchNorm, when enabled, is appended
    # after the activation rather than between conv and activation.
    assert stride == 1 or stride == 2
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
              nn.LeakyReLU(0.2, inplace=True)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)
def deconv3x3(in_channels, out_channels, use_bn=False):
    # Bilinear 2x upsample followed by a 3x3 conv + LeakyReLU (+ optional BatchNorm).
    layers = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
              nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
              nn.LeakyReLU(0.2, inplace=True)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)
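# Editor's sketch (illustrative, not part of the original source): a minimal
# shape check for the two helpers above. conv3x3 with stride=2 halves the
# spatial size; deconv3x3 doubles it back.
def _demo_conv_helpers():
    x = torch.randn(1, 8, 32, 32)
    down = conv3x3(8, 16, stride=2)(x)   # 1 x 16 x 16 x 16
    up = deconv3x3(16, 8)(down)          # 1 x 8 x 32 x 32
    assert down.shape == (1, 16, 16, 16) and up.shape == (1, 8, 32, 32)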
class ConvStack(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 hid_dim=None,
                 kernel_size=5,
                 layer_num=3,
                 use_relu=False,
                 ):
        super().__init__()
        assert kernel_size in [3, 5, 7]
        if hid_dim is None:
            hid_dim = out_dim
        padding = (kernel_size - 1) // 2
        layers = []
        layers.append(nn.Conv2d(in_dim, hid_dim, kernel_size=kernel_size, stride=1, padding=padding, bias=False))
        if use_relu:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        for i in range(layer_num - 2):
            layers.append(nn.Conv2d(hid_dim, hid_dim, kernel_size=kernel_size, stride=1, padding=padding, bias=False))
            if use_relu:
                layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(hid_dim, out_dim, kernel_size=kernel_size, stride=1, padding=padding, bias=False))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
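# Editor's sketch (illustrative, not part of the original source): ConvStack
# keeps spatial resolution fixed; with layer_num=3 it is conv-conv-conv, with
# activations only when use_relu=True and none after the final conv.
def _demo_conv_stack():
    stack = ConvStack(in_dim=16, out_dim=32, kernel_size=5, layer_num=3, use_relu=True)
    x = torch.randn(2, 16, 64, 64)
    assert stack(x).shape == (2, 32, 64, 64)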
class Unet5d(nn.Module):
    def __init__(self, in_c, out_c, nf):
        super().__init__()
        self.conv1 = conv3x3(in_c, nf, stride=1, use_bn=False)
        self.conv2 = conv3x3(nf, 2 * nf, stride=2, use_bn=True)
        self.conv3 = conv3x3(2 * nf, 4 * nf, stride=2, use_bn=True)
        self.conv4 = conv3x3(4 * nf, 8 * nf, stride=2, use_bn=True)
        self.conv5 = conv3x3(8 * nf, 8 * nf, stride=2, use_bn=True)
        self.deconv1 = deconv3x3(8 * nf, 8 * nf, use_bn=True)
        self.deconv2 = deconv3x3(2 * 8 * nf, 4 * nf, use_bn=True)
        self.deconv3 = deconv3x3(2 * 4 * nf, 2 * nf, use_bn=True)
        self.deconv4 = deconv3x3(2 * 2 * nf, nf, use_bn=True)
        self.deconv5 = conv3x3(2 * nf, nf, stride=1, use_bn=False)
        self.tail = nn.Conv2d(nf, out_c, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        # x: bs x in_c x 128 x 128
        x1 = self.conv1(x)   # bs x nf x 128 x 128
        x2 = self.conv2(x1)  # bs x 2nf x 64 x 64
        x3 = self.conv3(x2)  # bs x 4nf x 32 x 32
        x4 = self.conv4(x3)  # bs x 8nf x 16 x 16
        x5 = self.conv5(x4)  # bs x 8nf x 8 x 8
        y1 = self.deconv1(x5)                          # bs x 8nf x 16 x 16
        y2 = self.deconv2(torch.cat([y1, x4], dim=1))  # bs x 4nf x 32 x 32
        y3 = self.deconv3(torch.cat([y2, x3], dim=1))  # bs x 2nf x 64 x 64
        y4 = self.deconv4(torch.cat([y3, x2], dim=1))  # bs x nf x 128 x 128
        y5 = self.deconv5(torch.cat([y4, x1], dim=1))  # bs x nf x 128 x 128
        out = self.tail(y5)
        return out
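# Editor's sketch (illustrative, not part of the original source): the encoder
# downsamples four times, so input height/width must be divisible by 16; the
# shape comments above assume 128 x 128.
def _demo_unet5d():
    net = Unet5d(in_c=3, out_c=1, nf=16)
    x = torch.randn(2, 3, 128, 128)
    assert net(x).shape == (2, 1, 128, 128)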
def grid_sample(image, p2d):
    # Manual bilinear sampling with align_corners=True semantics (cf. F.grid_sample).
    # p2d: B x ... x 2, normalized coordinates in [-1, 1]
    # image: B x C x IH x IW
    B, C, IH, IW = image.shape
    image = image.view(B, C, IH * IW)
    assert p2d.shape[0] == B
    assert p2d.shape[-1] == 2
    points_shape = list(p2d.shape[1:-1])
    p2d = p2d.contiguous().view(B, 1, -1, 2)  # B x 1 x N x 2
    ix = p2d[..., 0]  # B x 1 x N
    iy = p2d[..., 1]  # B x 1 x N

    # Map normalized coordinates to pixel coordinates.
    ix = ((ix + 1) / 2) * (IW - 1)
    iy = ((iy + 1) / 2) * (IH - 1)

    # Integer corner indices carry no gradient; only the bilinear weights do.
    with torch.no_grad():
        ix_nw = torch.floor(ix)
        iy_nw = torch.floor(iy)
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

    # Bilinear weights of the four corners (differentiable w.r.t. ix, iy).
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)

    with torch.no_grad():
        torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
        torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)
        torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
        torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)
        torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
        torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)
        torch.clamp(ix_se, 0, IW - 1, out=ix_se)
        torch.clamp(iy_se, 0, IH - 1, out=iy_se)

    # Gather the four corner values and blend them with the bilinear weights.
    nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N

    out_val = nw_val * nw + ne_val * ne + sw_val * sw + se_val * se  # B x C x N
    out_val = out_val.permute(0, 2, 1).contiguous().view([B] + points_shape + [C])
    return out_val
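# Editor's sketch (illustrative, not part of the original source): a sanity
# check of the manual sampler against torch.nn.functional.grid_sample; for
# in-range points the two should agree under align_corners=True.
def _check_grid_sample():
    import torch.nn.functional as F
    image = torch.randn(2, 4, 16, 16)
    p2d = torch.rand(2, 100, 2) * 2 - 1  # random points in [-1, 1]
    ours = grid_sample(image, p2d)       # B x N x C
    ref = F.grid_sample(image, p2d.unsqueeze(2), mode='bilinear',
                        padding_mode='border', align_corners=True)[..., 0]
    assert torch.allclose(ours, ref.permute(0, 2, 1), atol=1e-4)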
def triplane_sample(xyz, fmap):
    # xyz: B x ... x 3 (-1 ~ 1); fmap: three C-channel planes stacked along dim 1.
    C = fmap.shape[1] // 3
    assert fmap.shape[1] == 3 * C
    fmap_list = fmap.split(C, dim=1)
    output = []
    for fmapIdx, axisIdx1, axisIdx2 in zip([0, 1, 2], [0, 1, 2], [1, 2, 0]):
        feat = grid_sample(fmap_list[fmapIdx].expand(xyz.shape[0], -1, -1, -1),
                           torch.stack([xyz[..., axisIdx1], xyz[..., axisIdx2]], dim=-1))
        output.append(feat)
    return torch.cat(output, dim=-1)
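# Editor's sketch (illustrative, not part of the original source): tri-plane
# sampling projects each 3D point onto the xy, yz and zx planes and
# concatenates the three bilinear lookups, giving B x N x (3 * C) features.
def _demo_triplane_sample():
    fmap = torch.randn(2, 3 * 8, 32, 32)  # three 8-channel planes
    xyz = torch.rand(2, 100, 3) * 2 - 1   # points in [-1, 1]^3
    feat = triplane_sample(xyz, fmap)
    assert feat.shape == (2, 100, 24)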
class TriPlaneFeature(nn.Module):
    def __init__(self, feat_dim, feat_size):
        super().__init__()
        self.feat_dim = feat_dim
        self.fmap = nn.Parameter(torch.randn(1, 3 * feat_dim, feat_size, feat_size).float() * 0.03)

    def forward(self, input):
        return self.fmap.expand(input.shape[0], -1, -1, -1)

    @staticmethod
    def sample_feat(xyz, fmap):
        return triplane_sample(xyz, fmap)
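# Editor's sketch (illustrative, not part of the original source): learnable
# tri-plane features; forward() just broadcasts the single learned parameter
# map to the batch size of its input.
def _demo_triplane_feature():
    tp = TriPlaneFeature(feat_dim=8, feat_size=32)
    xyz = torch.rand(2, 100, 3) * 2 - 1
    fmap = tp(xyz)                                 # 2 x 24 x 32 x 32
    feat = TriPlaneFeature.sample_feat(xyz, fmap)  # 2 x 100 x 24
    assert feat.shape == (2, 100, 24)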
class UVFeature(nn.Module):
    def __init__(self, feat_dim, feat_size):
        super().__init__()
        self.feat_dim = feat_dim
        self.fmap = nn.Parameter(torch.randn(1, feat_dim, feat_size, feat_size).float() * 0.03)

    def forward(self, input):
        return self.fmap.expand(input.shape[0], -1, -1, -1)

    @staticmethod
    def sample_feat(p2d, fmap):
        return grid_sample(fmap, p2d)
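# Editor's sketch (illustrative, not part of the original source): the UV
# variant is a single learnable 2D feature map sampled directly at (u, v)
# points in [-1, 1]^2.
def _demo_uv_feature():
    uv = UVFeature(feat_dim=8, feat_size=32)
    p2d = torch.rand(2, 100, 2) * 2 - 1
    feat = UVFeature.sample_feat(p2d, uv(p2d))  # 2 x 100 x 8
    assert feat.shape == (2, 100, 8)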
# class TriPlaneFeature_StyleGAN(nn.Module):
#     def __init__(self, feat_dim, feat_size, semantic_dim=0, style_dim=512, n_mlp=8):
#         super().__init__()
#         assert 2 ** int(math.log(feat_size, 2)) == feat_size
#         self.semantic_dim = max(semantic_dim, 0)
#         self.style_dim = style_dim
#         self.feat_dim = feat_dim
#         self.fc = nn.Linear(style_dim + semantic_dim, style_dim)
#         self.generator = Generator(size=feat_size, dim=feat_dim * 3, style_dim=style_dim, n_mlp=n_mlp)
#
#     def forward(self, styles, semantic=None, randomize_noise=True):
#         if isinstance(styles, (list, tuple)):
#             if semantic is None:
#                 x = styles
#             else:
#                 x = [self.fc(torch.cat([s, semantic], dim=-1)) for s in styles]
#         elif isinstance(styles, torch.Tensor):
#             if semantic is None:
#                 x = [styles]
#             else:
#                 x = [torch.cat([styles, semantic], dim=-1)]
#         else:
#             raise NotImplementedError
#         fmap_x, fmap_y, fmap_z = self.generator(styles=x, randomize_noise=randomize_noise)[0].split(self.feat_dim, dim=1)
#         return [fmap_x, fmap_y, fmap_z]
#
#     @staticmethod
#     def sample_feat(xyz, fmap_list):
#         # xyz: B x N x 3 (-1 ~ 1)
#         # im_feat: B x C x H x W
#         # output: B x N x f
#         assert xyz.shape[-1] == 3
#         output = []
#         for fmapIdx, axisIdx1, axisIdx2 in zip([0, 1, 2], [1, 2, 0], [2, 0, 1]):
#             p2d = torch.stack([xyz[..., axisIdx1], xyz[..., axisIdx2]], dim=-1)
#             fmap = fmap_list[fmapIdx].expand(xyz.shape[0], -1, -1, -1)
#             p2d = p2d + 1.0 / fmap.shape[-1]
#             feat = grid_sample_gradfix.grid_sample(fmap, p2d.unsqueeze(2))[..., 0]
#             feat = feat.permute(0, 2, 1)
#             output.append(feat)
#         return torch.cat(output, dim=-1)


# class UVFeature_StyleGAN(nn.Module):
#     def __init__(self, feat_dim, feat_size, semantic_dim=0, style_dim=512, n_mlp=8):
#         super().__init__()
#         assert 2 ** int(math.log(feat_size, 2)) == feat_size
#         self.semantic_dim = max(semantic_dim, 0)
#         self.style_dim = style_dim
#         self.feat_dim = feat_dim
#         self.fc = nn.Linear(style_dim + semantic_dim, style_dim)
#         self.generator = Generator(size=feat_size, dim=feat_dim, style_dim=style_dim, n_mlp=n_mlp)
#
#     def forward(self, styles, semantic=None, randomize_noise=True):
#         if isinstance(styles, (list, tuple)):
#             if semantic is None:
#                 x = styles
#             else:
#                 x = [self.fc(torch.cat([s, semantic], dim=-1)) for s in styles]
#         elif isinstance(styles, torch.Tensor):
#             if semantic is None:
#                 x = [styles]
#             else:
#                 x = [torch.cat([styles, semantic], dim=-1)]
#         else:
#             raise NotImplementedError
#         fmap = self.generator(styles=x, randomize_noise=randomize_noise)[0]
#         return fmap
#
#     @staticmethod
#     def sample_feat(p2d, fmap):
#         # p2d: B x N x 2 (-1 ~ 1)
#         # im_feat: B x C x H x W
#         # output: B x N x f
#         assert p2d.shape[-1] == 2
#         fmap = fmap.expand(p2d.shape[0], -1, -1, -1)
#         p2d = p2d + 1.0 / fmap.shape[-1]
#         feat = grid_sample_gradfix.grid_sample(fmap, p2d.unsqueeze(2))[..., 0]
#         feat = feat.permute(0, 2, 1)
#         return feat