"""
This class is about swap fusion applications
"""
import torch
from einops import rearrange
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce
from cosense3d.modules import BaseModule
from cosense3d.modules.plugin.cobevt import NaiveDecoder
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
    def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
    def forward(self, x):
return self.net(x)
# swap attention, following the MaxViT block design
class Attention(nn.Module):
"""
Unit Attention class. Todo: mask is not added yet.
Parameters
----------
dim: int
Input feature dimension.
dim_head: int
The head dimension.
dropout: float
Dropout rate
agent_size: int
The agent can be different views, timestamps or vehicles.
"""
def __init__(
self,
dim,
dim_head=32,
dropout=0.,
agent_size=6,
window_size=7
):
super().__init__()
assert (dim % dim_head) == 0, \
'dimension should be divisible by dimension per head'
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.window_size = [agent_size, window_size, window_size]
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.attend = nn.Sequential(
nn.Softmax(dim=-1)
)
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.Dropout(dropout)
)
self.relative_position_bias_table = nn.Embedding(
(2 * self.window_size[0] - 1) *
(2 * self.window_size[1] - 1) *
(2 * self.window_size[2] - 1),
self.heads) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for
# each token inside the window
coords_d = torch.arange(self.window_size[0])
coords_h = torch.arange(self.window_size[1])
coords_w = torch.arange(self.window_size[2])
# 3, Wd, Wh, Ww
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing='ij'))
coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
# 3, Wd*Wh*Ww, Wd*Wh*Ww
relative_coords = \
coords_flatten[:, :, None] - coords_flatten[:, None, :]
# Wd*Wh*Ww, Wd*Wh*Ww, 3
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# shift to start from 0
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 2] += self.window_size[2] - 1
relative_coords[:, :, 0] *= \
(2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
self.register_buffer("relative_position_index",
relative_position_index)
    def forward(self, x, mask=None):
# x shape: b, l, h, w, w_h, w_w, c
batch, agent_size, height, width, window_height, window_width, _, device, h \
= *x.shape, x.device, self.heads
# flatten
x = rearrange(x, 'b l x y w1 w2 d -> (b x y) (l w1 w2) d')
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# split heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add positional bias
L = agent_size * window_height * window_width
bias = self.relative_position_bias_table(self.relative_position_index[:L, :L])
sim = sim + rearrange(bias, 'i j h -> h i j')
        # mask shape if provided: b x y w1 w2 e l, with e = 1
        if mask is not None:
            # b x y w1 w2 e l -> (b x y) e (l w1 w2)
            mask = rearrange(mask, 'b x y w1 w2 e l -> (b x y) e (l w1 w2)')
            # (b x y) 1 e (l w1 w2); broadcasts against sim of shape b h i j
            mask = mask.unsqueeze(1)
            sim = sim.masked_fill(mask == 0, -float('inf'))
# attention
attn = self.attend(sim)
# aggregate
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h (l w1 w2) d -> b l w1 w2 (h d)',
l=agent_size, w1=window_height, w2=window_width)
# combine heads out
out = self.to_out(out)
return rearrange(out, '(b x y) l w1 w2 d -> b l x y w1 w2 d',
b=batch, x=height, y=width)
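
# A minimal usage sketch for Attention (illustrative shapes, not taken from any
# config in this repo): the input is window-partitioned into
# (batch, agents, n_win_h, n_win_w, window, window, dim).
#
#   attn = Attention(dim=128, dim_head=32, dropout=0., agent_size=2, window_size=4)
#   x = torch.rand(1, 2, 8, 8, 4, 4, 128)
#   out = attn(x)  # same shape as the input: (1, 2, 8, 8, 4, 4, 128)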
class SwapFusionBlockMask(nn.Module):
"""
Swap Fusion Block contains window attention and grid attention with
mask enabled for multi-vehicle cooperation.
"""
def __init__(self,
input_dim,
mlp_dim,
dim_head,
window_size,
agent_size,
drop_out):
super(SwapFusionBlockMask, self).__init__()
self.window_size = window_size
self.window_attention = PreNormResidual(input_dim,
Attention(input_dim, dim_head,
drop_out,
agent_size,
window_size))
self.window_ffd = PreNormResidual(input_dim,
FeedForward(input_dim, mlp_dim,
drop_out))
self.grid_attention = PreNormResidual(input_dim,
Attention(input_dim, dim_head,
drop_out,
agent_size,
window_size))
self.grid_ffd = PreNormResidual(input_dim,
FeedForward(input_dim, mlp_dim,
drop_out))
    def forward(self, x, mask):
# x: b l c h w
# mask: b h w 1 l
# window attention -> grid attention
mask_swap = mask
        # mask: b (x w1) (y w2) e l -> b x y w1 w2 e l
mask_swap = rearrange(mask_swap,
'b (x w1) (y w2) e l -> b x y w1 w2 e l',
w1=self.window_size, w2=self.window_size)
x = rearrange(x, 'b m d (x w1) (y w2) -> b m x y w1 w2 d',
w1=self.window_size, w2=self.window_size)
x = self.window_attention(x, mask=mask_swap)
x = self.window_ffd(x)
x = rearrange(x, 'b m x y w1 w2 d -> b m d (x w1) (y w2)')
# grid attention
mask_swap = mask
mask_swap = rearrange(mask_swap,
'b (w1 x) (w2 y) e l -> b x y w1 w2 e l',
w1=self.window_size, w2=self.window_size)
x = rearrange(x, 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d',
w1=self.window_size, w2=self.window_size)
x = self.grid_attention(x, mask=mask_swap)
x = self.grid_ffd(x)
x = rearrange(x, 'b m x y w1 w2 d -> b m d (w1 x) (w2 y)')
return x
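
# Usage sketch for SwapFusionBlockMask (illustrative shapes): features are
# (b, agents, c, h, w) with an agent mask of shape (b, h, w, 1, agents);
# h and w must be divisible by window_size.
#
#   blk = SwapFusionBlockMask(input_dim=128, mlp_dim=256, dim_head=32,
#                             window_size=4, agent_size=2, drop_out=0.)
#   feat = torch.rand(1, 2, 128, 32, 32)
#   msk = torch.ones(1, 32, 32, 1, 2)
#   out = blk(feat, msk)  # -> (1, 2, 128, 32, 32)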
class SwapFusionBlock(nn.Module):
"""
Swap Fusion Block contains window attention and grid attention.
"""
def __init__(self,
input_dim,
mlp_dim,
dim_head,
window_size,
agent_size,
drop_out):
super(SwapFusionBlock, self).__init__()
        # x: (batch, agents, dim, h, w)
self.block = nn.Sequential(
Rearrange('b m d (x w1) (y w2) -> b m x y w1 w2 d',
w1=window_size, w2=window_size),
PreNormResidual(input_dim, Attention(input_dim, dim_head, drop_out,
agent_size, window_size)),
PreNormResidual(input_dim,
FeedForward(input_dim, mlp_dim, drop_out)),
Rearrange('b m x y w1 w2 d -> b m d (x w1) (y w2)'),
Rearrange('b m d (w1 x) (w2 y) -> b m x y w1 w2 d',
w1=window_size, w2=window_size),
PreNormResidual(input_dim, Attention(input_dim, dim_head, drop_out,
agent_size, window_size)),
PreNormResidual(input_dim,
FeedForward(input_dim, mlp_dim, drop_out)),
Rearrange('b m x y w1 w2 d -> b m d (w1 x) (w2 y)'),
)
    def forward(self, x, mask=None):
        # TODO: add mask handling for multi-agent cooperation
x = self.block(x)
return x
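
# Usage sketch for SwapFusionBlock (illustrative shapes), same layout as the
# masked variant but without the per-agent mask:
#
#   blk = SwapFusionBlock(input_dim=128, mlp_dim=256, dim_head=32,
#                         window_size=4, agent_size=2, drop_out=0.)
#   feat = torch.rand(1, 2, 128, 32, 32)   # (b, agents, c, h, w)
#   out = blk(feat)                        # -> (1, 2, 128, 32, 32)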
class SwapFusionEncoder(BaseModule):
"""
Data rearrange -> swap block -> mlp_head
"""
def __init__(self,
input_dim=128,
mlp_dim=256,
agent_size=5,
window_size=8,
dim_head=32,
drop_out=0.1,
depth=3,
mask=False,
decoder=None,
**kwargs):
super(SwapFusionEncoder, self).__init__(**kwargs)
self.layers = nn.ModuleList([])
self.depth = depth
self.mask = mask
swap_fusion_block = SwapFusionBlockMask if self.mask else SwapFusionBlock
for i in range(self.depth):
block = swap_fusion_block(input_dim,
mlp_dim,
dim_head,
window_size,
agent_size,
drop_out)
self.layers.append(block)
# mlp head
self.mlp_head = nn.Sequential(
Reduce('b m d h w -> b d h w', 'mean'),
Rearrange('b d h w -> b h w d'),
nn.LayerNorm(input_dim),
nn.Linear(input_dim, input_dim),
Rearrange('b h w d -> b d h w')
)
if decoder is not None:
self.decoder = NaiveDecoder(decoder)
    def forward(self, ego_feat, coop_cpm, **kwargs):
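        """
        Fuse ego and cooperative BEV features.

        Parameters
        ----------
        ego_feat: list
            Length-B list of ego BEV features, each of shape (C, H, W).
        coop_cpm: list
            Length-B list of dicts; each value holds one cooperative agent's
            'bev_feat' (C, H, W) and 'bev_mask' (1, H, W).
        """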
B = len(ego_feat)
C, H, W = ego_feat[0].shape
x = []
mask = []
num_cavs = []
for xe, xc in zip(ego_feat, coop_cpm):
values = xc.values()
ego_mask = torch.ones_like(xe[:1])
x.append([xe,] + [v['bev_feat'] for v in values])
mask.append([ego_mask,] + [v['bev_mask'] for v in values])
num_cavs.append(len(values) + 1)
l = max(num_cavs)
x_pad = ego_feat[0].new_zeros(B, l, C, H, W)
mask_pad = ego_feat[0].new_zeros(B, H, W, 1, l)
for i in range(B):
x_pad[i, :len(x[i])] = torch.stack(x[i], dim=0)
mask_pad[i, :, :, :, :len(x[i])] = torch.stack(mask[i], dim=-1).permute(1, 2, 0, 3)
for stage in self.layers:
x_pad = stage(x_pad, mask=mask_pad)
out = self.mlp_head(x_pad)
if hasattr(self, 'decoder'):
out = self.decoder(out.unsqueeze(1))
out = rearrange(out, 'b l c h w -> (b l) c h w')
return {self.scatter_keys[0]: out}
if __name__ == "__main__":
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
args = {'input_dim': 512,
'mlp_dim': 512,
'agent_size': 4,
'window_size': 8,
'dim_head': 4,
'drop_out': 0.1,
'depth': 2,
'mask': True
}
block = SwapFusionEncoder(args)
block.cuda()
test_data = torch.rand(1, 4, 512, 32, 32)
test_data = test_data.cuda()
mask = torch.ones(1, 32, 32, 1, 4)
mask = mask.cuda()
output = block(test_data, mask)
print(output)