import os

import torch

from cosense3d.modules import BaseModule
from cosense3d.modules.utils.me_utils import *
from cosense3d.modules.utils.common import pad_r, linear_last, cat_coor_with_idx
from cosense3d.ops.utils import points_in_boxes_gpu
from cosense3d.modules.losses import edl, build_loss
from cosense3d.modules.plugin import build_plugin_module
from cosense3d.modules.plugin.attn import NeighborhoodAttention


class BEV(BaseModule):
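    """Dense BEV head that predicts per-location class scores from sparse BEV features.

    Features at each BEV location are projected to ``num_cls`` logits (or evidences when
    an EDL loss is configured), which the target assigner converts to confidence and
    uncertainty maps; optionally a per-location ROI score is emitted as well.
    """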
def __init__(self,
data_info,
in_dim,
stride,
target_assigner,
loss_cls,
num_cls=1,
class_names_each_head=None,
down_sample_tgt=False,
generate_roi_scr=True,
**kwargs):
super(BEV, self).__init__(**kwargs)
self.in_dim = in_dim
self.class_names_each_head = class_names_each_head
self.down_sample_tgt = down_sample_tgt
self.stride = stride
self.num_cls = num_cls
self.generate_roi_scr = generate_roi_scr
for k, v in data_info.items():
setattr(self, k, v)
update_me_essentials(self, data_info, self.stride)
self.reg_layer = linear_last(in_dim, 32, num_cls, bias=True)
self.tgt_assigner = build_plugin_module(target_assigner)
self.loss_cls = build_loss(**loss_cls)
        self.is_edl = 'edl' in self.loss_cls.name.lower()

    def forward(self, stensor_list, **kwargs):
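        """Predict BEV confidence and uncertainty maps for a batch of sparse tensors.

        The input sparse tensors are formatted, optionally randomly down-sampled during
        training, and their indices converted to metric centers before regression.
        """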
coor, feat, ctr = self.format_input(stensor_list)
if self.training and self.down_sample_tgt:
coor, feat = self.down_sample(coor, feat)
centers = indices2metric(coor, self.voxel_size)
reg = self.reg_layer(feat)
        conf, unc = self.tgt_assigner.get_predictions(
            reg, self.is_edl, self.loss_cls.activation)
out = {
'ctr': centers,
'reg': reg,
'conf': conf,
'unc': unc,
}
if self.generate_roi_scr:
out['scr'] = conf.max(dim=-1).values
return self.format_output(out, len(stensor_list))

    def down_sample(self, coor, feat):
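        """Randomly keep roughly half of the BEV locations and their features."""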
keep = torch.rand_like(feat[:, 0]) > 0.5
coor = coor[keep]
feat = feat[keep]
return coor, feat

    def loss(self, batch_list, gt_boxes, gt_labels, **kwargs):
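        """Compute the BEV classification loss against ground-truth boxes.

        Targets are assigned per BEV center by the target assigner; when negatives are
        down-sampled, only the returned ``valid`` subset contributes to the loss.
        """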
tgt_pts = self.cat_data_from_list(batch_list, 'ctr', pad_idx=True)
        # boxes_vis is only needed by the commented-out debug visualization below
        # boxes_vis = gt_boxes[0][:, :7].detach().cpu().numpy()
gt_boxes = self.cat_data_from_list(gt_boxes, pad_idx=True)
conf = self.cat_data_from_list(batch_list, 'conf')
tgt_pts, tgt_label, valid = self.tgt_assigner.assign(
tgt_pts, gt_boxes[:, :8], len(batch_list), conf, **kwargs)
epoch_num = kwargs.get('epoch', 0)
reg = self.cat_data_from_list(batch_list, 'reg')
# if kwargs['itr'] % 100 == 0:
# from cosense3d.utils.vislib import draw_points_boxes_plt, plt
# from matplotlib import colormaps
# jet = colormaps['jet']
# points = batch_list[0]['ctr'].detach().cpu().numpy()
# scores = batch_list[0]['conf'][:, self.num_cls - 1:].detach().cpu().numpy()
# ax = draw_points_boxes_plt(
# pc_range=[-144, -41.6, -3.0, 144, 41.6, 1.0],
# # points=points,
# boxes_gt=boxes_vis,
# return_ax=True
# )
# ax.scatter(points[:, 0], points[:, 1], c=scores, cmap=jet, s=3, marker='s', vmin=0, vmax=1)
# plt.savefig(f"{os.environ['HOME']}/Downloads/tmp1.jpg")
# plt.close()
if valid is None:
# targets are not down-sampled
avg_factor = max(tgt_label.sum(), 1)
loss_cls = self.loss_cls(
reg,
tgt_label,
temp=epoch_num,
avg_factor=avg_factor
)
else:
            # negative targets have been down-sampled to a fixed ratio w.r.t. the positive samples
loss_cls = self.loss_cls(
reg[valid],
tgt_label,
temp=epoch_num,
)
loss_dict = {'bev_loss': loss_cls}
return loss_dict


class BEVMultiResolution(BaseModule):
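    """Run one ``BEV`` head per stride and merge the per-stride outputs.

    Losses are only computed for the strides listed in ``strides_for_loss``.
    """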
def __init__(self, strides, strides_for_loss, **kwargs):
super().__init__(**kwargs)
self.strides = strides
self.strides_for_loss = strides_for_loss
for s in strides:
kwargs['stride'] = s
setattr(self, f'head_p{s}', BEV(**kwargs))

    def forward(self, stensor_list, *args, **kwargs):
        out_list = [{} for _ in range(len(stensor_list))]
for s in self.strides:
out = getattr(self, f'head_p{s}')(stensor_list)[self.scatter_keys[0]]
for i, x in enumerate(out):
out_list[i][f'p{s}'] = x
return {self.scatter_keys[0]: out_list}

    def loss(self, batch_list, gt_boxes, gt_labels, **kwargs):
loss_dict = {}
for s in self.strides_for_loss:
ldict = getattr(self, f'head_p{s}').loss(
[l[f'p{s}'] for l in batch_list], gt_boxes, gt_labels, **kwargs)
for k, v in ldict.items():
loss_dict[f'{k}_s{s}'] = v
return loss_dict


class ContinuousBEV(BaseModule):
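    """Base class for BEV heads that decode class evidence at continuous reference points.

    Reference points are sampled around the discrete BEV centers (jittered during
    training), a context decoder gathers per-reference-point context from the BEV grid,
    and the resulting evidence is converted to confidence and uncertainty with ``edl``.
    """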
def __init__(self,
out_channels,
data_info,
in_dim,
stride,
context_decoder,
target_assigner,
loss_cls,
class_names_each_head=None,
**kwargs):
super().__init__(**kwargs)
self.in_dim = in_dim
self.class_names_each_head = class_names_each_head
self.stride = stride
for k, v in data_info.items():
setattr(self, k, v)
update_me_essentials(self, data_info, self.stride)
self.context_decoder = build_plugin_module(context_decoder)
self.reg_layer = linear_last(in_dim, 32, out_channels, bias=True)
self.tgt_assigner = build_plugin_module(target_assigner)
self.loss_cls = build_loss(**loss_cls)

    @torch.no_grad()
def sample_reference_points(self, centers, gt_boxes, gt_labels):
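        """Sample reference points and their labels via the target assigner.

        During training the BEV centers are jittered by up to half a cell before
        assignment; at inference the centers are used directly without down-sampling.
        """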
gt_boxes = self.cat_data_from_list(gt_boxes, pad_idx=True)
if self.training:
new_pts = centers.clone()
new_pts[:, 1:] += (torch.rand_like(centers[:, 1:]) - 0.5) * self.res[0]
ref_pts, ref_label, _ = self.tgt_assigner.assign(
new_pts, gt_boxes, len(gt_boxes))
else:
ref_pts, ref_label, _ = self.tgt_assigner.assign(
centers, gt_boxes, len(gt_boxes), down_sample=False)
return ref_pts, ref_label

    def get_evidence(self, ref_pts, coor, feat):
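        """Compute class evidence at the reference points; implemented by subclasses."""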
raise NotImplementedError

    def forward(self, stensor_list, gt_boxes, gt_labels, **kwargs):
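        """Predict evidence, confidence and uncertainty at sampled reference points."""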
coor, feat, ctr = self.format_input(stensor_list)
centers = indices2metric(coor, self.voxel_size)
ref_pts, ref_label = self.sample_reference_points(
centers, gt_boxes, gt_labels)
evidence = self.get_evidence(ref_pts, coor, feat)
conf, unc = edl.evidence_to_conf_unc(evidence)
out = {
'ref_pts': ref_pts,
'ref_lbls': ref_label,
'evi': evidence,
'conf': conf,
'unc': unc
}
return self.format_output(out, len(stensor_list))

    def down_sample(self, coor, feat):
keep = torch.rand_like(feat[:, 0]) > 0.5
coor = coor[keep]
feat = feat[keep]
return coor, feat

    def loss(self, batch_list, **kwargs):
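        """Classification loss between predicted evidence and reference-point labels."""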
tgt_lbl = self.cat_data_from_list(batch_list, 'ref_lbls')
epoch_num = kwargs.get('epoch', 0)
evidence = self.cat_data_from_list(batch_list, 'evi')
        # avg_factor = max(tgt_lbl.sum(), 1)
loss_cls = self.loss_cls(
evidence,
tgt_lbl,
temp=epoch_num,
# avg_factor=avg_factor
)
loss_dict = {'bev_loss': loss_cls}
return loss_dict


class ContiGevBEV(ContinuousBEV):
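    """Continuous BEV head that regresses evidence on the grid first and then decodes
    it at the reference points with the context decoder.
    """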

    def get_evidence(self, ref_pts, coor, feat):
reg = self.reg_layer(feat)
reg = self.context_decoder(ref_pts, coor, reg)
return reg


class ContiAttnBEV(ContinuousBEV):
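    """Continuous BEV head that gathers per-reference-point context from the grid
    features (e.g. via neighborhood attention) before regressing non-negative evidence.
    """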

    def get_evidence(self, ref_pts, coor, feat):
ref_context = self.context_decoder(ref_pts, coor, feat)
reg = self.reg_layer(ref_context)
return reg.relu()