from torch import nn
from torch.nn.utils.parametrizations import spectral_norm as sn

from ..utils import activation, normalization


def _as_triple(value):
    """Return *value* as a 3-tuple, repeating a scalar along every axis."""
    try:
        return tuple(value)
    except TypeError:
        # Not iterable: treat it as one size shared by all three axes.
        return (value,) * 3


class ResBlock3D(nn.Module):
    """Pre-activation 3D residual block with optional up/down resampling.

    The main branch applies ``norm -> act -> conv`` twice; the skip branch
    applies a 1x1x1 projection (only when the channel count changes) followed
    by a resampling layer. The two branches are summed in ``forward``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        resample: ``"down"`` to shrink the spatial size, ``"up"`` to grow it.
            Any other value (including the default ``None``) keeps the
            spatial size unchanged.
        resample_factor: Per-axis integer resampling factor. Either a single
            int (applied to all three axes) or a sequence of three ints.
        kernel_size: Per-axis convolution kernel size. Either a single int or
            a sequence of three ints. The padding used is ``k // 2``, so the
            two branches only produce matching shapes for odd kernel sizes.
        act: Activation name passed to the project ``activation`` helper as
            ``act_type``. A single string is used for both activations; a
            pair selects the first and second activation separately.
        norm: Normalization name passed to the project ``normalization``
            helper as ``norm_type``.
        norm_kwargs: Extra keyword arguments forwarded to ``normalization``
            for both normalization layers. ``None`` means no extra arguments.
        spectral_norm: If true, wrap both convolutions and the projection
            (when it is a real convolution) with spectral normalization.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        resample=None,
        resample_factor=(1, 1, 1),
        kernel_size=(3, 3, 3),
        act='swish',
        norm='group',
        norm_kwargs=None,
        spectral_norm=False,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Normalise once so that ints and lists behave like tuples below.
        # nn.Upsample in particular only treats a real tuple as per-axis
        # scale factors, and the padding arithmetic needs to iterate.
        kernel_size = _as_triple(kernel_size)
        resample_factor = _as_triple(resample_factor)

        # Skip-branch projection: only needed when the channel count changes.
        if in_channels != out_channels:
            self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.proj = nn.Identity()

        # "Same"-style padding (exact for odd kernel sizes).
        padding = tuple(k // 2 for k in kernel_size)

        if resample == "down":
            # ceil_mode=True makes the pooled skip branch match the strided
            # convolution: both yield ceil(size / factor) for odd kernels.
            self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True)
            self.conv1 = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=resample_factor,
                padding=padding,
            )
            self.conv2 = nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            )
        elif resample == "up":
            self.resample = nn.Upsample(
                scale_factor=resample_factor, mode='trilinear'
            )
            self.conv1 = nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            )
            # Chosen so the transposed convolution outputs exactly
            # size * factor along each axis:
            # (size - 1) * s - 2 * p + k + output_padding == size * s.
            output_padding = tuple(
                2 * p + s - k
                for p, s, k in zip(padding, resample_factor, kernel_size)
            )
            self.conv2 = nn.ConvTranspose3d(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=resample_factor,
                padding=padding,
                output_padding=output_padding,
            )
        else:
            self.resample = nn.Identity()
            self.conv1 = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            )
            self.conv2 = nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            )

        # A single activation name is shared by both activation layers.
        if isinstance(act, str):
            act = (act, act)
        self.act1 = activation(act_type=act[0])
        self.act2 = activation(act_type=act[1])

        if norm_kwargs is None:
            norm_kwargs = {}
        # norm1 sees the block input, norm2 sees the output of conv1.
        self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs)
        self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs)

        if spectral_norm:
            self.conv1 = sn(self.conv1)
            self.conv2 = sn(self.conv2)
            # nn.Identity has no weight to parametrize, so leave it alone.
            if not isinstance(self.proj, nn.Identity):
                self.proj = sn(self.proj)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the summed result.

        Args:
            x: Input tensor with ``in_channels`` channels, in the layout
                expected by ``nn.Conv3d``.

        Returns:
            The sum of the main branch and the projected/resampled skip
            branch.
        """
        # Skip branch: match channels first, then match the spatial size.
        x_in = self.resample(self.proj(x))

        # Main branch: norm -> act -> conv, twice.
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.conv2(x)
        return x + x_in