Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
def change_num_input_channels(model, in_channels=1):
    """
    Adapt the first 3-channel convolution of ``model`` to ``in_channels`` inputs.

    Assumes number of input channels in model is 3.

    The first ``nn.Conv2d``/``nn.Conv3d`` (in ``model.modules()`` order) whose
    ``in_channels`` is 3 is modified in place: its kernel is summed over the
    input-channel dimension, divided by ``in_channels`` and tiled
    ``in_channels`` times. The total kernel mass is therefore preserved, so
    feeding the same image in every new channel gives the same response as
    feeding that image replicated over the original 3 channels. Only that
    first matching layer is touched; its bias is left as is.

    Args:
        model: module to search and modify in place.
        in_channels: desired number of input channels (positive integer).

    Returns:
        The same ``model`` object, for call chaining.

    Raises:
        ValueError: if ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(
            f"in_channels must be a positive integer, got {in_channels}"
        )
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            old_weight = m.weight
            # First, sum across channels (dim 1 is the input-channel axis),
            # then divide by the new number of channels.
            W = old_weight.detach().sum(1, keepdim=True) / in_channels
            # Then, repeat by number of channels along that same axis.
            size = [1] * W.ndim
            size[1] = in_channels
            m.in_channels = in_channels
            # Keep the layer frozen if it was frozen: nn.Parameter would
            # otherwise default to requires_grad=True.
            m.weight = nn.Parameter(
                W.repeat(size), requires_grad=old_weight.requires_grad
            )
            break
    return model
def change_initial_stride(model, stride, in_channels):
    """
    Set the stride of the first convolution that takes ``in_channels`` inputs.

    Walks ``model.modules()`` in order and picks the first ``nn.Conv2d`` or
    ``nn.Conv3d`` whose ``in_channels`` equals the given value; that layer's
    ``stride`` attribute is overwritten with ``stride`` exactly as passed.
    When no layer matches, the model is left unchanged.

    Returns the same ``model`` object.
    """
    conv_types = (nn.Conv2d, nn.Conv3d)
    first_conv = next(
        (
            layer
            for layer in model.modules()
            if isinstance(layer, conv_types) and layer.in_channels == in_channels
        ),
        None,
    )
    if first_conv is not None:
        first_conv.stride = stride
    return model