# Uploaded to the Hugging Face Hub by weatherforecast1024.
# Commit message: "Upload folder using huggingface_hub".
# Commit d2f661a (verified).
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm as sn
from ..utils import activation, normalization
class ResBlock3D(nn.Module):
    """Pre-activation 3D residual block with optional down-/up-sampling.

    The main branch computes
    ``norm1 -> act1 -> conv1 -> norm2 -> act2 -> conv2``.  The skip branch
    applies ``proj`` (a 1x1x1 convolution when ``in_channels`` differs from
    ``out_channels``, identity otherwise) followed by ``resample``.  The
    output is the sum of both branches.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
        resample: ``"down"`` pools the skip branch with ``nn.AvgPool3d``
            (``ceil_mode=True``) and strides ``conv1``; ``"up"`` upsamples
            the skip branch with trilinear ``nn.Upsample`` and uses
            transposed convolutions, with ``conv2`` strided.  Any other
            value (including ``None``) disables resampling.
        resample_factor: Per-axis resampling factor; a sequence of three
            ints, or a single int applied to all three axes.
        kernel_size: Per-axis convolution kernel size; a sequence of three
            ints, or a single int applied to all three axes.  Padding is
            ``k // 2`` per axis.
        act: Activation type handed to ``activation``.  A single string is
            used for both activations; otherwise ``act[0]`` and ``act[1]``
            select the first and second one.
        norm: Normalization type handed to ``normalization``.
        norm_kwargs: Extra keyword arguments for ``normalization``.
        spectral_norm: If true, wrap ``conv1``, ``conv2`` and (when it is a
            convolution) ``proj`` with spectral normalization.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
        self, in_channels, out_channels, resample=None,
        resample_factor=(1, 1, 1), kernel_size=(3, 3, 3),
        act='swish', norm='group', norm_kwargs=None,
        spectral_norm=False,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Accept scalars as well as 3-sequences, like the torch layers do.
        kernel_size = self._as_triple(kernel_size)
        resample_factor = self._as_triple(resample_factor)

        # Skip-branch channel projection.
        if in_channels != out_channels:
            self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.proj = nn.Identity()

        padding = tuple(k // 2 for k in kernel_size)

        # NOTE: submodules are created in a fixed order (proj, resample,
        # conv1, conv2, act1, act2, norm1, norm2) so that parameter
        # registration and random initialization stay reproducible.
        if resample == "down":
            self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True)
            self.conv1 = nn.Conv3d(
                in_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor,
                padding=padding,
            )
            self.conv2 = nn.Conv3d(
                out_channels, out_channels,
                kernel_size=kernel_size, padding=padding,
            )
        elif resample == "up":
            self.resample = nn.Upsample(
                scale_factor=resample_factor, mode='trilinear')
            self.conv1 = nn.ConvTranspose3d(
                in_channels, out_channels,
                kernel_size=kernel_size, padding=padding,
            )
            # With dilation 1 a transposed convolution yields
            # (n - 1) * s - 2 * p + k + output_padding elements per axis;
            # this choice makes that exactly n * s.
            output_padding = tuple(
                2 * p + s - k
                for (p, s, k) in zip(padding, resample_factor, kernel_size)
            )
            self.conv2 = nn.ConvTranspose3d(
                out_channels, out_channels,
                kernel_size=kernel_size, stride=resample_factor,
                padding=padding, output_padding=output_padding,
            )
        else:
            self.resample = nn.Identity()
            self.conv1 = nn.Conv3d(
                in_channels, out_channels,
                kernel_size=kernel_size, padding=padding,
            )
            self.conv2 = nn.Conv3d(
                out_channels, out_channels,
                kernel_size=kernel_size, padding=padding,
            )

        # A single activation name is used for both activations.
        if isinstance(act, str):
            act = (act, act)
        self.act1 = activation(act_type=act[0])
        self.act2 = activation(act_type=act[1])

        if norm_kwargs is None:
            norm_kwargs = {}
        # norm1 runs before conv1, hence it sees ``in_channels`` channels.
        self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs)
        self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs)

        if spectral_norm:
            self.conv1 = sn(self.conv1)
            self.conv2 = sn(self.conv2)
            # An identity projection has no weight to normalize.
            if not isinstance(self.proj, nn.Identity):
                self.proj = sn(self.proj)

    @staticmethod
    def _as_triple(value):
        """Return ``value`` as a 3-tuple, repeating a scalar on every axis."""
        try:
            return tuple(value)
        except TypeError:
            return (value,) * 3

    def forward(self, x):
        """Return the sum of the main branch and the skip branch for ``x``."""
        skip = self.resample(self.proj(x))
        out = self.conv1(self.act1(self.norm1(x)))
        out = self.conv2(self.act2(self.norm2(out)))
        return out + skip