Source code for cosense3d.modules.plugin.naive_compressor

import torch.nn as nn


class NaiveCompressor(nn.Module):
    """A very naive compressor that only compresses along the channel axis.

    The encoder squeezes the channel dimension from ``input_dim`` down to
    ``input_dim // compress_ratio``; the decoder expands it back to
    ``input_dim`` and refines it with a second convolution. Every
    convolution uses a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the feature map is preserved end to end.

    Args:
        input_dim: Number of channels of the input feature map.
        compress_ratio: Factor by which the channel dimension is reduced in
            the bottleneck (floor division is applied to ``input_dim``).

    Raises:
        ValueError: If ``compress_ratio`` is not positive, or if the
            resulting bottleneck would have fewer than one channel.
    """

    def __init__(self, input_dim, compress_ratio):
        super().__init__()
        if compress_ratio <= 0:
            raise ValueError(
                f"compress_ratio must be positive, got {compress_ratio}"
            )
        compressed_dim = input_dim // compress_ratio
        # A zero-width bottleneck would otherwise produce a degenerate model
        # whose decoder sees no information from the input; fail fast.
        if compressed_dim < 1:
            raise ValueError(
                f"input_dim // compress_ratio must be >= 1, got "
                f"{input_dim} // {compress_ratio} = {compressed_dim}"
            )

        # Channel reduction: input_dim -> compressed_dim.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, compressed_dim,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(compressed_dim, eps=1e-3, momentum=0.01),
            nn.ReLU(),
        )
        # Channel restoration: compressed_dim -> input_dim, then one more
        # same-width conv block.
        self.decoder = nn.Sequential(
            nn.Conv2d(compressed_dim, input_dim,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(input_dim, eps=1e-3, momentum=0.01),
            nn.ReLU(),
            nn.Conv2d(input_dim, input_dim,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(input_dim, eps=1e-3, momentum=0.01),
            nn.ReLU(),
        )

    def forward(self, x):
        """Compress ``x`` along channels and decode it back.

        Args:
            x: Feature map; ``BatchNorm2d`` requires a 4-D input, presumably
                laid out as ``(N, input_dim, H, W)``.

        Returns:
            Tensor with the same shape as ``x``. All values are non-negative
            because the decoder ends with a ReLU.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x