import torch
import torch.nn as nn
from cosense3d.modules import BaseModule, plugin
from cosense3d.modules.utils.misc import SELayer_Linear, MLN
from cosense3d.modules.utils.positional_encoding import pos2posemb2d
class LidarPETRHead(BaseModule):
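    """Query-based LiDAR detection head in the PETR style.

    Gathers the top-k scoring BEV features, attaches sinusoidal position
    embeddings to their centres, and decodes a fixed set of learnable
    object queries against them with a transformer decoder.
    """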
def __init__(self,
in_channels,
transformer,
feature_stride,
lidar_range,
topk=2048,
memory_len=256,
num_query=644,
**kwargs):
super().__init__(**kwargs)
self.transformer = plugin.build_plugin_module(transformer)
self.embed_dims = self.transformer.embed_dims
self.num_pose_feat = 64
self.pos_dim = 2
self.in_channels = in_channels
self.feature_stride = feature_stride
self.topk = topk
self.num_query = num_query
self.lidar_range = nn.Parameter(torch.tensor(lidar_range), requires_grad=False)
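        # learnable 2D reference points, one per object query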
self.reference_points = nn.Embedding(self.num_query, self.pos_dim)
self._init_layers()
def _init_layers(self):
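        # MLP lifting the 2D sinusoidal position embedding to the transformer width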
self.position_embeding = nn.Sequential(
nn.Linear(self.num_pose_feat * self.pos_dim, self.embed_dims * 4),
nn.ReLU(),
nn.Linear(self.embed_dims * 4, self.embed_dims),
)
self.memory_embed = nn.Sequential(
nn.Linear(self.in_channels, self.embed_dims),
nn.ReLU(),
nn.Linear(self.embed_dims, self.embed_dims),
)
self.query_embedding = nn.Sequential(
nn.Linear(self.num_pose_feat * self.pos_dim, self.embed_dims),
nn.ReLU(),
nn.Linear(self.embed_dims, self.embed_dims),
)
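        # gates the position embedding with the embedded features (see forward)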
# can be replaced with MLN
self.featurized_pe = SELayer_Linear(self.embed_dims)
    def init_weights(self):
        # initialise parameters following the official DETR
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                nn.init.xavier_uniform_(m.weight)
self._is_init = True
    def forward(self, rois, bev_feat, memory, **kwargs):
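        """Decode object queries against the top-k BEV features.

        :param rois: per-sample dicts; the 'scr' scores drive the top-k selection.
        :param bev_feat: per-sample BEV features keyed by stride ('p{stride}' -> 'ctr', 'feat').
        :param memory: not used directly; replaced below by the embedded top-k features.
        :return: dict mapping the first scatter key to per-sample decoder outputs and reference points.
        """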
feat, ctr = self.gather_topk(rois, bev_feat)
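        # normalise the gathered centres to [0, 1] over the x/y detection range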
pos = ((ctr - self.lidar_range[:2]) /
(self.lidar_range[3:5] - self.lidar_range[:2]))
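        # position embedding of the centres, gated by the embedded features
        # (the incoming `memory` argument is overwritten here)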
pos_emb = self.position_embeding(pos2posemb2d(pos, self.num_pose_feat))
memory = self.memory_embed(feat)
pos_emb = self.featurized_pe(pos_emb, memory)
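        # broadcast the learnable reference points to every sample in the batch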
reference_points = (self.reference_points.weight).unsqueeze(0).repeat(memory.shape[0], 1, 1)
query_pos = self.query_embedding(pos2posemb2d(reference_points, self.num_pose_feat))
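        # content queries start as zeros; the decoder fills them by attending to the memory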
tgt = torch.zeros_like(query_pos)
outs_dec, _ = self.transformer(memory, tgt, query_pos, pos_emb)
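        # one output dict per input sample, scattered back under the configured key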
outs = [
{
'outs_dec': outs_dec[:, i],
'ref_pts': reference_points[i],
} for i in range(len(rois))
]
return {self.scatter_keys[0]: outs}
    def gather_topk(self, rois, bev_feats):
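        """Select the top-k highest-scoring BEV locations per sample.

        Returns the selected features and centre coordinates, each stacked
        into a (batch, topk, ...) tensor.
        """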
topk_feat, topk_ctr = [], []
for roi, bev_feat in zip(rois, bev_feats):
ctr = bev_feat[f'p{self.feature_stride}']['ctr']
feat = bev_feat[f'p{self.feature_stride}']['feat']
scores = roi['scr']
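            # padding for samples with fewer than topk candidates is not implemented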
if scores.shape[0] < self.topk:
raise NotImplementedError
else:
topk_inds = torch.topk(scores, k=self.topk).indices
topk_ctr.append(ctr[topk_inds])
topk_feat.append(feat[topk_inds])
topk_ctr = torch.stack(topk_ctr, dim=0)
topk_feat = torch.stack(topk_feat, dim=0)
return topk_feat, topk_ctr