import torch
import torch.nn as nn
import torchvision.models as models
from einops import rearrange
from cosense3d.modules import BaseModule
from cosense3d.modules.plugin import build_plugin_module
from cosense3d.modules.utils.positional_encoding import img_locations
class ResnetEncoder(BaseModule):
    """Resnet family to encode image.

    Wraps a torchvision ResNet backbone, loaded with its ``DEFAULT``
    pretrained weights. It returns the feature maps of selected residual
    stages (``layer1`` .. ``layer4``), optionally passed through a neck
    plugin first.

    Args:
        num_layers: ResNet depth. It is used to look up
            ``torchvision.models.resnet{num_layers}`` and
            ``torchvision.models.ResNet{num_layers}_Weights``.
        feat_indices: residual stage indices (1-4) whose outputs are
            collected and, if configured, fed to ``neck``. Duplicates are
            ignored; the indices are stored sorted.
        out_index: a single stage index (``int``) or a sequence of stage
            indices (tuple or list). These pick the returned features out
            of the collected ones. Every index must appear in
            ``feat_indices``.
        img_size: input image size. It is divided by each output stride
            to get ``feat_sizes``. Presumably ``(height, width)``; confirm
            against the caller.
        neck: optional plugin config, built with ``build_plugin_module``.
        **kwargs: forwarded to ``BaseModule``.

    Raises:
        ValueError: if ``num_layers`` has no matching torchvision ResNet,
            or if ``out_index`` contains an index missing from
            ``feat_indices``.
    """

    def __init__(self, num_layers, feat_indices, out_index, img_size,
                 neck=None, **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        # De-duplicate so that position i of the collected feature list
        # always matches self.feat_indices[i] (forward relies on .index()).
        self.feat_indices = sorted(set(feat_indices))
        self.out_index = out_index
        self.img_size = img_size

        # Normalise out_index to a sequence of stage indices. Any non-int
        # (tuple, list, ...) is treated as a sequence; forward() uses the
        # same int-vs-sequence test so both stay consistent.
        indices = (out_index, ) if isinstance(out_index, int) else out_index
        missing = [idx for idx in indices if idx not in self.feat_indices]
        if missing:
            raise ValueError(
                f"out_index {missing} must be contained in feat_indices "
                f"{self.feat_indices}")

        # Output of residual stage i is downsampled by 2 ** (i + 1).
        self.strides = [2 ** (idx + 1) for idx in indices]
        self.feat_sizes = [(img_size[0] // stride, img_size[1] // stride)
                           for stride in self.strides]

        # scatter_keys presumably comes from BaseModule/kwargs (not visible
        # here). When image coordinates are requested, precompute one
        # non-trainable location grid per output feature size.
        if 'img_coor' in self.scatter_keys:
            self.img_locations = nn.ParameterList([
                nn.Parameter(img_locations(img_size, feat_size),
                             requires_grad=False)
                for feat_size in self.feat_sizes
            ])

        resnet = getattr(models, f'resnet{self.num_layers}', None)
        if resnet is None:
            raise ValueError(
                f"{self.num_layers} is not a valid number of resnet layers")
        resnet_weights = getattr(models, f"ResNet{self.num_layers}_Weights")
        self.encoder = resnet(weights=resnet_weights.DEFAULT)
        self.neck = build_plugin_module(neck) if neck is not None else None

    def forward(self, input_images, **kwargs):
        """Encode a batch of images.

        Args:
            input_images: a sequence of per-sample image collections.
                ``len()`` of each entry gives the image count for that
                sample. They are merged by ``self.compose_imgs``, which is
                expected to return a channels-last ``(b, h, w, c)`` tensor
                (as the permute below assumes).
            **kwargs: unused.

        Returns:
            Whatever ``self.format_output`` produces. It is given a single
            feature tensor when ``out_index`` is an int, otherwise a list
            of tensors ordered like ``out_index``.
        """
        num_imgs = [len(x) for x in input_images]
        imgs = self.compose_imgs(input_images)
        # (b, h, w, c) -> (b, c, h, w), the layout torchvision expects.
        imgs = imgs.permute(0, 3, 1, 2).contiguous()

        # ResNet stem.
        x = self.encoder.conv1(imgs)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)

        out = []
        # Stages deeper than the last requested one would be discarded, so
        # stop early instead of computing them.
        last_stage = max(self.feat_indices, default=0)
        for i in range(1, 5):
            if i > last_stage:
                break
            x = getattr(self.encoder, f'layer{i}')(x)
            if i in self.feat_indices:
                out.append(x)

        if self.neck is not None:
            out = self.neck(out)

        # Same int-vs-sequence test as __init__, so list-valued out_index
        # (e.g. from YAML/JSON configs) works as well as tuples.
        if isinstance(self.out_index, int):
            out = out[self.feat_indices.index(self.out_index)]
        else:
            out = [out[self.feat_indices.index(i)] for i in self.out_index]
        return self.format_output(out, num_imgs)