Source code for cosense3d.modules.projection.fax

import torch
from torch import nn
from einops import rearrange, repeat, reduce
from torchvision.models.resnet import Bottleneck

from cosense3d.modules.plugin.cobevt import CrossViewSwapAttention, Attention, BEVEmbedding
from cosense3d.modules import BaseModule
ResNetBottleNeck = lambda c: Bottleneck(c, c // 4)
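# A quick check of the wrapper above (a standalone sketch, not part of the
# original source): torchvision's Bottleneck uses expansion = 4, so
# Bottleneck(c, c // 4) maps c input channels to c // 4 * 4 = c output
# channels and the block can be applied residually without a downsample path.
#
#   block = ResNetBottleNeck(64)
#   y = block(torch.randn(1, 64, 32, 32))
#   assert y.shape == (1, 64, 32, 32)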


class FAXModule(BaseModule):
    def __init__(
            self,
            middle,
            dim,
            img_size,
            strides,
            feat_dims,
            cross_view,
            cross_view_swap,
            bev_embedding,
            self_attn,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.img_size = img_size

        cross_views = list()
        layers = list()
        downsample_layers = list()

        for i, (stride, num_layers) in enumerate(zip(strides, middle)):
            feat_dim = feat_dims[i]
            feat_height, feat_width = img_size[0] // stride, img_size[1] // stride

            # Cross-view attention between the camera features and the BEV
            # grid at this resolution level.
            cva = CrossViewSwapAttention(
                feat_height, feat_width, feat_dim, dim[i], i,
                **cross_view, **cross_view_swap)
            cross_views.append(cva)

            # Residual bottleneck blocks that refine the BEV features while
            # keeping the channel count at dim[i].
            layer = nn.Sequential(*[ResNetBottleNeck(dim[i]) for _ in range(num_layers)])
            layers.append(layer)

            if i < len(middle) - 1:
                # Downsample the BEV grid by 2x between pyramid levels.
                downsample_layers.append(nn.Sequential(
                    nn.Sequential(
                        nn.Conv2d(dim[i], dim[i] // 4, kernel_size=3,
                                  stride=1, padding=1, bias=False),
                        nn.PixelUnshuffle(2),
                        nn.Conv2d(dim[i+1], dim[i+1], 3, padding=1, bias=False),
                        nn.BatchNorm2d(dim[i+1]),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(dim[i+1], dim[i+1], 1, padding=0, bias=False),
                        nn.BatchNorm2d(dim[i+1]))))

        self.bev_embedding = BEVEmbedding(dim[0], **bev_embedding)
        self.cross_views = nn.ModuleList(cross_views)
        self.layers = nn.ModuleList(layers)
        self.downsample_layers = nn.ModuleList(downsample_layers)
        self.self_attn = Attention(dim[-1], **self_attn)
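    # Sanity check for the downsampling block's channel arithmetic (a
    # standalone sketch, not part of the original source): the 3x3 conv
    # first reduces channels to c // 4, then PixelUnshuffle(2) folds each
    # 2x2 spatial patch into channels, restoring c channels at half the
    # spatial resolution.
    #
    #   c = 128
    #   block = nn.Sequential(
    #       nn.Conv2d(c, c // 4, 3, padding=1, bias=False),
    #       nn.PixelUnshuffle(2))
    #   y = block(torch.randn(1, c, 64, 64))
    #   assert y.shape == (1, c, 32, 32)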
    def forward(self, img_feat, intrinsic, extrinsic, **kwargs):
        B = len(img_feat)
        N = len(intrinsic[0])
        intrinsic = self.cat_list(intrinsic, recursive=True)
        extrinsic = self.cat_list(extrinsic, recursive=True)
        # Invert the camera matrices once so the cross-view attention can
        # project BEV queries back into each image plane.
        I_inv = torch.stack([I.inverse()[:3, :3] for I in intrinsic], dim=0
                            ).reshape(B, N, 3, 3)
        E_inv = torch.stack([E.inverse() for E in extrinsic], dim=0
                            ).reshape(B, N, 4, 4)

        # Learned BEV prior, broadcast over the batch.
        x = self.bev_embedding.get_prior()    # d H W
        x = repeat(x, '... -> b ...', b=B)    # B d H W

        for i, (cross_view, layer) in enumerate(zip(self.cross_views, self.layers)):
            # Stack the i-th feature level of all samples: (B, N, C, H, W).
            feature = torch.stack([feat[i] for feat in img_feat], dim=0)
            x = cross_view(i, x, self.bev_embedding, feature, I_inv, E_inv)
            x = layer(x)
            if i < len(img_feat[0]) - 1:
                x = self.downsample_layers[i](x)

        x = self.self_attn(x)
        return {self.scatter_keys[0]: x}
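# A minimal usage sketch (illustrative assumptions, not part of the original
# source): `fax` is a configured FAXModule, `img_feat` is a list of B
# per-sample feature pyramids (one (N, C_l, H_l, W_l) tensor per level for
# N cameras), and the camera matrices are given per sample and per camera.
#
#   out = fax(img_feat=img_feat,
#             intrinsic=intrinsic,    # B lists of N 3x3 intrinsic matrices
#             extrinsic=extrinsic)    # B lists of N 4x4 extrinsic matrices
#   bev = out[fax.scatter_keys[0]]    # fused BEV feature map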