# Source: pengc02, commit ec9a6bc ("all")
import torch
import torch.nn as nn
import math
# from lib.modules.stylegan2 import Generator
# from lib.ops.styleGAN import grid_sample_gradfix
def conv3x3(in_channels, out_channels, stride=1, use_bn=False):
    """Build a 3x3 conv block: Conv2d -> LeakyReLU(0.2) [-> BatchNorm2d].

    Only strides 1 and 2 are supported. Note that normalization, when
    enabled, is applied *after* the activation (this file's convention).
    """
    assert stride in (1, 2)
    block = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3,
                  stride=stride, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    if use_bn:
        block.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*block)
def deconv3x3(in_channels, out_channels, use_bn=False):
    """Build a 2x upsampling block: bilinear Upsample -> 3x3 Conv2d ->
    LeakyReLU(0.2) [-> BatchNorm2d].

    As elsewhere in this file, BatchNorm (when enabled) follows the
    activation.
    """
    block = [
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        nn.Conv2d(in_channels, out_channels, kernel_size=3,
                  stride=1, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    if use_bn:
        block.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*block)
class ConvStack(nn.Module):
    """A stack of `layer_num` same-resolution convolutions.

    The first conv maps in_dim -> hid_dim, the middle `layer_num - 2` convs
    keep hid_dim, and the last maps hid_dim -> out_dim. When `use_relu` is
    set, LeakyReLU(0.2) follows every conv except the final one. Padding is
    chosen so spatial resolution is preserved.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 hid_dim=None,
                 kernel_size=5,
                 layer_num=3,
                 use_relu=False,
                 ):
        super().__init__()
        assert kernel_size in [3, 5, 7]
        hid_dim = out_dim if hid_dim is None else hid_dim
        pad = (kernel_size - 1) // 2

        def _conv(ci, co):
            # Resolution-preserving conv, no bias (matches the rest of
            # this file).
            return nn.Conv2d(ci, co, kernel_size=kernel_size, stride=1,
                             padding=pad, bias=False)

        modules = [_conv(in_dim, hid_dim)]
        if use_relu:
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        for _ in range(layer_num - 2):
            modules.append(_conv(hid_dim, hid_dim))
            if use_relu:
                modules.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection to out_dim — never followed by an activation.
        modules.append(_conv(hid_dim, out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)
class Unet5d(nn.Module):
    """5-level U-Net over 2D feature maps.

    Four stride-2 encoder stages downsample by 16x in total; four bilinear
    upsampling decoder stages restore the resolution, concatenating the
    matching encoder feature (skip connection) along channels before each
    decoder conv. A 1x1 conv produces the `out_c`-channel output. Input
    height/width must therefore be divisible by 16.
    """

    def __init__(self, in_c, out_c, nf):
        super().__init__()
        # Encoder (module names kept for checkpoint compatibility).
        self.conv1 = conv3x3(in_c, nf, stride=1, use_bn=False)
        self.conv2 = conv3x3(nf, 2 * nf, stride=2, use_bn=True)
        self.conv3 = conv3x3(2 * nf, 4 * nf, stride=2, use_bn=True)
        self.conv4 = conv3x3(4 * nf, 8 * nf, stride=2, use_bn=True)
        self.conv5 = conv3x3(8 * nf, 8 * nf, stride=2, use_bn=True)
        # Decoder: each stage consumes [upsampled, skip] concatenation,
        # hence the doubled input channel counts.
        self.deconv1 = deconv3x3(8 * nf, 8 * nf, use_bn=True)
        self.deconv2 = deconv3x3(2 * 8 * nf, 4 * nf, use_bn=True)
        self.deconv3 = deconv3x3(2 * 4 * nf, 2 * nf, use_bn=True)
        self.deconv4 = deconv3x3(2 * 2 * nf, nf, use_bn=True)
        self.deconv5 = conv3x3(2 * nf, nf, stride=1, use_bn=False)
        self.tail = nn.Conv2d(nf, out_c, kernel_size=1, stride=1,
                              padding=0, bias=True)

    def forward(self, x):
        # x: bs x in_c x H x W (comments below assume H = W = 128).
        # Encoder: collect one feature per scale for the skip connections.
        skips = []
        feat = x
        for enc in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            feat = enc(feat)
            skips.append(feat)
        # skips: [nf x128, 2nf x64, 4nf x32, 8nf x16, 8nf x8]
        # Decoder: deepest feature first, fusing skips coarse-to-fine.
        up = self.deconv1(skips[4])                            # 8nf x 16
        up = self.deconv2(torch.cat([up, skips[3]], dim=1))    # 4nf x 32
        up = self.deconv3(torch.cat([up, skips[2]], dim=1))    # 2nf x 64
        up = self.deconv4(torch.cat([up, skips[1]], dim=1))    # nf  x 128
        up = self.deconv5(torch.cat([up, skips[0]], dim=1))    # nf  x 128
        return self.tail(up)
def grid_sample(image, p2d):
    """Bilinearly sample `image` at normalized 2D points `p2d`.

    Args:
        image: B x C x IH x IW feature map.
        p2d:   B x ... x 2 sampling locations in [-1, 1]; last dim is
               (x, y) — x indexes width, y indexes height.

    Returns:
        B x ... x C sampled features (channels moved to the last dim).

    NOTE(review): the [-1,1] -> [0, size-1] mapping below matches
    F.grid_sample's align_corners=True convention; presumably this was
    hand-rolled for gradient control (cf. the commented-out
    grid_sample_gradfix import at the top of the file) — confirm.
    """
    # p2d: B x ... x 2
    # image: B x C x IH x IW
    B, C, IH, IW = image.shape
    image = image.view(B, C, IH * IW)  # flatten spatial dims for gather()
    assert p2d.shape[0] == B
    assert p2d.shape[-1] == 2
    points_shape = list(p2d.shape[1:-1])  # caller's point layout, restored at the end
    p2d = p2d.contiguous().view(B, 1, -1, 2)  # B x 1 x N x 2
    ix = p2d[..., 0]  # B x 1 x N
    iy = p2d[..., 1]  # B x 1 x N
    # [-1, 1] -> pixel coordinates [0, IW-1] / [0, IH-1].
    ix = ((ix + 1) / 2) * (IW - 1)
    iy = ((iy + 1) / 2) * (IH - 1)
    # The four corner indices (nw/ne/sw/se) are integral and kept out of
    # the autograd graph: gradients flow only through the bilinear
    # weights computed below.
    with torch.no_grad():
        ix_nw = torch.floor(ix)
        iy_nw = torch.floor(iy)
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1
    # Bilinear weights from the *unclamped* coordinates, so they stay
    # differentiable w.r.t. p2d. Each weight is the area of the rectangle
    # opposite its corner.
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)
    # Clamp corner indices into the image so gather() stays in bounds;
    # out-of-range samples therefore reuse edge texels. In-place `out=`
    # keeps this out of the graph (already under no_grad).
    with torch.no_grad():
        torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
        torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)
        torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
        torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)
        torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
        torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)
        torch.clamp(ix_se, 0, IW - 1, out=ix_se)
        torch.clamp(iy_se, 0, IH - 1, out=iy_se)
    # Gather each corner's value from the flattened map (row-major index
    # iy * IW + ix), broadcast across channels.
    nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(B, 1, -1).expand(-1, C, -1))  # B x C x N
    # Weighted sum of the four corners, then restore the caller's shape
    # with channels last.
    out_val = nw_val * nw + ne_val * ne + sw_val * sw + se_val * se  # B x C x N
    out_val = out_val.permute(0, 2, 1).contiguous().view([B] + points_shape + [C])
    return out_val
def triplane_sample(xyz, fmap):
    """Sample tri-plane features at 3D points.

    Args:
        xyz:  B x ... x 3 normalized coordinates in [-1, 1].
        fmap: (1 or B) x (3*C) x H x W stacked plane features; the channel
              dim holds the three C-channel planes back to back.

    Returns:
        B x ... x (3*C) features — the three per-plane bilinear samples
        concatenated along the last dim.

    Bug fix: `grid_sample(image, p2d)` takes the feature map first and the
    2D points second; the previous code passed them in the opposite order,
    which made `image.view(B, C, IH * IW)` fail on the points tensor.
    """
    C = fmap.shape[1] // 3
    assert fmap.shape[1] == 3 * C
    fmap_list = fmap.split(C, dim=1)
    output = []
    # Plane `fmapIdx` is indexed by the coordinate pair (axisIdx1, axisIdx2).
    for fmapIdx, axisIdx1, axisIdx2 in zip([0, 1, 2], [0, 1, 2], [1, 2, 0]):
        p2d = torch.stack([xyz[..., axisIdx1], xyz[..., axisIdx2]], dim=-1)
        feat = grid_sample(fmap_list[fmapIdx].expand(xyz.shape[0], -1, -1, -1),
                           p2d)
        output.append(feat)
    return torch.cat(output, dim=-1)
class TriPlaneFeature(nn.Module):
    """Learnable tri-plane feature volume, shared across the batch.

    Holds a single 1 x (3*feat_dim) x feat_size x feat_size parameter —
    three feat_dim-channel planes stacked along channels — initialized with
    small (0.03-scaled) Gaussian noise.
    """

    def __init__(self, feat_dim, feat_size):
        super().__init__()
        self.feat_dim = feat_dim
        # NOTE: attribute name 'famp' (likely a typo for 'fmap') is kept
        # for checkpoint / state-dict compatibility.
        self.famp = nn.Parameter(torch.randn(1, 3 * feat_dim, feat_size, feat_size).float() * 0.03)

    def forward(self, input):
        # Broadcast the shared planes to the caller's batch size
        # (expand creates a view, no copy).
        return self.famp.expand(input.shape[0], -1, -1, -1)

    @staticmethod
    def sample_feat(xyz, fmap):
        # Bug fix: the sampled features were computed but never returned
        # (cf. UVFeature.sample_feat, which does return its result).
        return triplane_sample(xyz, fmap)
class UVFeature(nn.Module):
    """Learnable 2D (UV-space) feature map, shared across the batch.

    Holds a single 1 x feat_dim x feat_size x feat_size parameter
    initialized with small (0.03-scaled) Gaussian noise.
    """

    def __init__(self, feat_dim, feat_size):
        super().__init__()
        self.feat_dim = feat_dim
        # NOTE: attribute name 'famp' (likely a typo for 'fmap') is kept
        # for checkpoint / state-dict compatibility.
        self.famp = nn.Parameter(torch.randn(1, feat_dim, feat_size, feat_size).float() * 0.03)

    def forward(self, input):
        # Broadcast the shared map to the caller's batch size.
        return self.famp.expand(input.shape[0], -1, -1, -1)

    @staticmethod
    def sample_feat(p2d, fmap):
        # Bug fix: `grid_sample(image, p2d)` takes the feature map first;
        # the previous call passed (p2d, fmap), which crashes inside
        # grid_sample when it unpacks `image.shape` as B x C x IH x IW.
        return grid_sample(fmap, p2d)
# class TriPlaneFeature_StyleGAN(nn.Module):
# def __init__(self, feat_dim, feat_size, semantic_dim=0, style_dim=512, n_mlp=8):
# super().__init__()
# assert 2 ** int(math.log(feat_size, 2)) == feat_size
# self.semantic_dim = max(semantic_dim, 0)
# self.style_dim = style_dim
# self.feat_dim = feat_dim
# self.fc = nn.Linear(style_dim + semantic_dim, style_dim)
# self.generator = Generator(size=feat_size, dim=feat_dim * 3, style_dim=style_dim, n_mlp=n_mlp)
# def forward(self, styles, semantic=None, randomize_noise=True):
# if isinstance(styles, (list, tuple)):
# if semantic is None:
# x = styles
# else:
# x = [self.fc(torch.cat([s, semantic], dim=-1)) for s in styles]
# elif isinstance(styles, torch.Tensor):
# if semantic is None:
# x = [styles]
# else:
# x = [torch.cat([styles, semantic], dim=-1)]
# else:
# raise NotImplementedError
# fmap_x, fmap_y, fmap_z = self.generator(styles=x, randomize_noise=randomize_noise)[0].split(self.feat_dim, dim=1)
# return [fmap_x, fmap_y, fmap_z]
# @staticmethod
# def sample_feat(xyz, fmap_list):
# # xyz: B x N x 3 (-1 ~ 1)
# # im_feat: B x C x H x W
# # output: B x N x f
# assert xyz.shape[-1] == 3
# output = []
# for fmapIdx, axisIdx1, axisIdx2 in zip([0, 1, 2], [1, 2, 0], [2, 0, 1]):
# p2d = torch.stack([xyz[..., axisIdx1], xyz[..., axisIdx2]], dim=-1)
# fmap = fmap_list[fmapIdx].expand(xyz.shape[0], -1, -1, -1)
# p2d = p2d + 1.0 / fmap.shape[-1]
# feat = grid_sample_gradfix.grid_sample(fmap, p2d.unsqueeze(2))[..., 0]
# feat = feat.permute(0, 2, 1)
# output.append(feat)
# return torch.cat(output, dim=-1)
# class UVFeature_StyleGAN(nn.Module):
# def __init__(self, feat_dim, feat_size, semantic_dim=0, style_dim=512, n_mlp=8):
# super().__init__()
# assert 2 ** int(math.log(feat_size, 2)) == feat_size
# self.semantic_dim = max(semantic_dim, 0)
# self.style_dim = style_dim
# self.feat_dim = feat_dim
# self.fc = nn.Linear(style_dim + semantic_dim, style_dim)
# self.generator = Generator(size=feat_size, dim=feat_dim, style_dim=style_dim, n_mlp=n_mlp)
# def forward(self, styles, semantic=None, randomize_noise=True):
# if isinstance(styles, (list, tuple)):
# if semantic is None:
# x = styles
# else:
# x = [self.fc(torch.cat([s, semantic], dim=-1)) for s in styles]
# elif isinstance(styles, torch.Tensor):
# if semantic is None:
# x = [styles]
# else:
# x = [torch.cat([styles, semantic], dim=-1)]
# else:
# raise NotImplementedError
# fmap = self.generator(styles=x, randomize_noise=randomize_noise)[0]
# return fmap
# @staticmethod
# def sample_feat(p2d, fmap):
# # p2d: B x N x 2 (-1 ~ 1)
# # im_feat: B x C x H x W
# # output: B x N x f
# assert p2d.shape[-1] == 2
# fmap = fmap.expand(p2d.shape[0], -1, -1, -1)
# p2d = p2d + 1.0 / fmap.shape[-1]
# feat = grid_sample_gradfix.grid_sample(fmap, p2d.unsqueeze(2))[..., 0]
# feat = feat.permute(0, 2, 1)
# return feat