Source code for cosense3d.modules.plugin.downsample_conv

"""
Modules that downsample feature maps with stacked 2-D convolutions (a configurable first conv followed by a fixed 3x3 conv in each stage).
"""
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two stacked 2-D convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_channels`` to ``out_channels`` using the
    caller-supplied ``kernel_size``, ``stride`` and ``padding``; it is the
    only layer that can change the spatial resolution. The second one is a
    fixed 3x3 convolution with padding 1 (stride 1), so it preserves the
    spatial size produced by the first.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size of the first convolution.
        stride: Stride of the first convolution.
        padding: Zero-padding of the first convolution, forwarded as-is to
            ``nn.Conv2d`` (an int, not a flag).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int, padding: int):
        super().__init__()
        # Attribute name and layer order are kept as-is so the state_dict keys
        # (``double_conv.0.*`` and ``double_conv.2.*``) stay unchanged.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv + ReLU stages to ``x`` and return the result.

        ``x`` must be a tensor accepted by ``nn.Conv2d`` with ``in_channels``
        channels, e.g. of shape (N, C, H, W).
        """
        return self.double_conv(x)
class DownsampleConv(nn.Module):
    """Sequential stack of ``DoubleConv`` stages.

    Stage ``i`` is built from ``kernel_sizes[i]``, ``dims[i]``, ``strides[i]``
    and ``paddings[i]``. The first stage reads ``in_channels`` channels;
    every later stage reads the ``dims`` value of the stage before it.

    Args:
        in_channels: Number of channels of the input feature map.
        kernel_sizes: Kernel size of the first conv in each stage.
        dims: Output channel count of each stage.
        strides: Stride of the first conv in each stage.
        paddings: Padding of the first conv in each stage.

    Raises:
        ValueError: If the four per-stage sequences differ in length.
    """

    def __init__(self, in_channels, kernel_sizes=(1,), dims=(256,),
                 strides=(1,), paddings=(0,)):
        super().__init__()
        # Materialize once so any iterable is accepted and lengths can be
        # compared: a bare zip() would silently drop the trailing stages of
        # the longer sequences on a misconfiguration.
        kernel_sizes = list(kernel_sizes)
        dims = list(dims)
        strides = list(strides)
        paddings = list(paddings)
        if len({len(kernel_sizes), len(dims), len(strides), len(paddings)}) > 1:
            raise ValueError(
                "kernel_sizes, dims, strides and paddings must have the same "
                f"length, got {len(kernel_sizes)}, {len(dims)}, "
                f"{len(strides)} and {len(paddings)}")

        # Attribute name kept as ``layers`` so state_dict keys are unchanged.
        self.layers = nn.ModuleList()
        for ksize, dim, stride, padding in zip(
                kernel_sizes, dims, strides, paddings):
            self.layers.append(DoubleConv(in_channels, dim, kernel_size=ksize,
                                          stride=stride, padding=padding))
            # The next stage consumes this stage's output channels.
            in_channels = dim

    def forward(self, x):
        """Run ``x`` through every stage in order and return the last output.

        With no stages configured, ``x`` is returned unchanged.
        """
        for layer in self.layers:
            x = layer(x)
        return x