# Source code for cosense3d.modules.heads.query_guided_petr_head

from typing import List
import os
import torch
from torch import nn

from cosense3d.modules import BaseModule
from cosense3d.modules.plugin import build_plugin_module
from cosense3d.modules.utils.common import inverse_sigmoid
from cosense3d.utils.misc import multi_apply
from cosense3d.utils.box_utils import normalize_bbox, denormalize_bbox
from cosense3d.modules.losses import build_loss
from cosense3d.modules.losses.edl import pred_to_conf_unc


class QueryGuidedPETRHead(BaseModule):
    """Query-guided PETR-style detection head.

    Runs per-level classification and box-regression branches on decoder
    query features and decodes boxes relative to reference points.
    """

    def __init__(self,
                 embed_dims,
                 pc_range,
                 code_weights,
                 num_classes,
                 cls_assigner,
                 box_assigner,
                 loss_cls,
                 loss_box,
                 num_reg_fcs=3,
                 num_pred=3,
                 use_logits=False,
                 reg_channels=None,
                 sparse=False,
                 pred_while_training=False,
                 **kwargs):
        """
        :param embed_dims: channel size of the decoder query features.
        :param pc_range: point-cloud range [x0, y0, z0, x1, y1, z1].
        :param code_weights: per-channel weights for the box code.
        :param num_classes: number of output classes (incl. background slot,
            presumably — confirm against the assigner).
        :param cls_assigner: plugin config for the classification assigner.
        :param box_assigner: plugin config for the box assigner / box coder.
        :param loss_cls: config dict for the classification loss.
        :param loss_box: config dict for the box regression loss.
        :param num_reg_fcs: number of hidden FC layers per branch.
        :param num_pred: number of decoder levels to predict from.
        :param use_logits: if True, regress positions in inverse-sigmoid space.
        :param reg_channels: list of 'name:channels' strings describing the
            box-code layout; if None, a fixed code size of 10 is used.
        :param sparse: if True, queries of all samples are concatenated
            (batch index carried alongside the reference points).
        :param pred_while_training: also emit thresholded detections during
            training.
        """
        super().__init__(**kwargs)
        self.embed_dims = embed_dims
        self.reg_channels = {}
        if reg_channels is None:
            self.code_size = 10
        else:
            # Parse 'name:channels' entries, e.g. 'ctr:3', 'dim:3', 'rot:2'.
            for c in reg_channels:
                name, channel = c.split(':')
                self.reg_channels[name] = int(channel)
            self.code_size = sum(self.reg_channels.values())
        self.num_classes = num_classes
        self.num_reg_fcs = num_reg_fcs
        self.num_pred = num_pred
        self.use_logits = use_logits
        self.sparse = sparse
        self.pred_while_training = pred_while_training
        # Frozen (non-trainable) buffers registered as Parameters.
        self.pc_range = nn.Parameter(torch.tensor(pc_range), requires_grad=False)
        self.code_weights = nn.Parameter(torch.tensor(code_weights), requires_grad=False)
        self.box_assigner = build_plugin_module(box_assigner)
        self.cls_assigner = build_plugin_module(cls_assigner)
        self.loss_cls = build_loss(**loss_cls)
        self.loss_box = build_loss(**loss_box)
        # EDL (evidential) mode is inferred from the loss name.
        self.is_edl = True if 'edl' in self.loss_cls.name.lower() else False
        self._init_layers()
        self.init_weights()

    def _init_layers(self):
        """Build the shared classification and regression MLP branches."""
        cls_branch = []
        for _ in range(self.num_reg_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        fc_cls = nn.Sequential(*cls_branch)

        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
        reg_branch = nn.Sequential(*reg_branch)

        # NOTE(review): every level reuses the SAME module instance, so the
        # num_pred branches share weights. Official DETR/PETR heads deepcopy
        # per level — confirm the sharing here is intentional.
        self.cls_branches = nn.ModuleList(
            [fc_cls for _ in range(self.num_pred)])
        self.reg_branches = nn.ModuleList(
            [reg_branch for _ in range(self.num_pred)])
[docs] def init_weights(self): for m in self.cls_branches: nn.init.xavier_uniform_(m[-1].weight) # follow the official DETR to init parameters for m in self.modules(): if hasattr(m, 'weight') and m.weight.dim() > 1: nn.init.xavier_uniform_(m.weight) self._is_init = True
    def forward(self, feat_in, **kwargs):
        """Decode query features of all levels into logits, boxes and scores.

        :param feat_in: list of per-sample dicts holding 'outs_dec' (decoder
            outputs) and 'ref_pts' (reference points). In sparse mode the
            padded 'ref_pts' carries the sample index in channel 0 — assumed
            layout of ``cat_data_from_list(..., pad_idx=True)``, confirm there.
        :return: ``{self.scatter_keys[0]: outs}`` where ``outs`` is one result
            dict per input sample.
        """
        if self.sparse:
            # Sparse mode: queries of all samples are concatenated along one
            # axis; recover the per-sample index from the padded channel 0.
            outs_dec = self.cat_data_from_list(feat_in, 'outs_dec').permute(1, 0, 2)
            reference_points = self.cat_data_from_list(feat_in, 'ref_pts', pad_idx=True)
            reference_inds = reference_points[..., 0]
            reference_points = reference_points[..., 1:]
        else:
            outs_dec = self.stack_data_from_list(feat_in, 'outs_dec').permute(1, 0, 2, 3)
            reference_points = self.stack_data_from_list(feat_in, 'ref_pts')
            reference_inds = None
        pos_dim = reference_points.shape[-1]
        assert outs_dec.isnan().sum() == 0, "found nan in outs_dec."

        outputs_classes = []
        outputs_coords = []
        # One cls/reg pass per decoder level.
        for lvl in range(len(outs_dec)):
            out_dec = outs_dec[lvl]
            pred_cls = self.cls_branches[lvl](out_dec)
            pred_reg = self.reg_branches[lvl](out_dec)
            if self.use_logits:
                # Regress offsets in logit space relative to the reference
                # points, then squash xyz back into [0, 1].
                reference = inverse_sigmoid(reference_points.clone())
                pred_reg[..., :pos_dim] += reference
                pred_reg[..., :3] = pred_reg[..., :3].sigmoid()
            outputs_classes.append(pred_cls)
            outputs_coords.append(pred_reg)

        all_cls_logits = torch.stack(outputs_classes)
        all_bbox_reg = torch.stack(outputs_coords)
        if self.use_logits:
            # Map normalized xyz back to metric coordinates inside pc_range.
            all_bbox_reg[..., :3] = (all_bbox_reg[..., :3] * (
                    self.pc_range[3:] - self.pc_range[:3]) + self.pc_range[:3])
            reference_points = reference_points * (self.pc_range[3:] - self.pc_range[:3]) + self.pc_range[:3]
        det_boxes, pred_boxes = self.get_pred_boxes(all_bbox_reg, reference_points)
        cls_scores = pred_to_conf_unc(all_cls_logits, self.loss_cls.activation, self.is_edl)[0]

        if self.sparse:
            # Scatter the flattened queries back to their source samples.
            outs = []
            for i in range(len(feat_in)):
                mask = reference_inds == i
                outs.append(
                    {
                        'all_cls_logits': all_cls_logits[:, mask],
                        'all_bbox_reg': all_bbox_reg[:, mask],
                        'ref_pts': reference_points[mask],
                        'all_cls_scores': cls_scores[:, mask],
                        'all_bbox_preds': det_boxes[:, mask],
                        'all_bbox_preds_t': pred_boxes[:, mask] if pred_boxes is not None else None,
                    }
                )
        else:
            outs = [
                {
                    'all_cls_logits': all_cls_logits[:, i],
                    'all_bbox_reg': all_bbox_reg[:, i],
                    'ref_pts': reference_points[i],
                    'all_cls_scores': cls_scores[:, i],
                    'all_bbox_preds': det_boxes[:, i],
                    'all_bbox_preds_t': pred_boxes[:, i] if pred_boxes is not None else None,
                } for i in range(len(feat_in))
            ]

        # Emit thresholded detections at inference time (or always, if
        # pred_while_training is set).
        if self.pred_while_training or not self.training:
            dets = self.get_predictions(cls_scores, det_boxes, pred_boxes,
                                        batch_inds=reference_inds)
            for i, out in enumerate(outs):
                out['preds'] = dets[i]

        return {self.scatter_keys[0]: outs}
[docs] def loss(self, petr_out, gt_boxes_global, gt_labels_global, *args, **kwargs): aux_dict = {self.gt_keys[2:][i]: x for i, x in enumerate(args)} epoch = kwargs.get('epoch', 0) if self.sparse: cls_scores = torch.cat([x for out in petr_out for x in out['all_cls_logits']], dim=0) bbox_reg = torch.cat([x for out in petr_out for x in out['all_bbox_reg']], dim=0) ref_pts = [x['ref_pts'] for x in petr_out for _ in range(self.num_pred)] else: cls_scores = self.stack_data_from_list(petr_out, 'all_cls_logits').flatten(0, 1) bbox_reg = self.stack_data_from_list(petr_out, 'all_bbox_reg').flatten(0, 1) ref_pts = self.stack_data_from_list(petr_out, 'ref_pts').unsqueeze(1).repeat( 1, self.num_pred, 1, 1).flatten(0, 1) gt_boxes_global = [x for x in gt_boxes_global for _ in range(self.num_pred)] # gt_velos = [x[:, 7:] for x in gt_boxes for _ in range(self.num_pred)] gt_labels_global = [x for x in gt_labels_global for _ in range(self.num_pred)] if 'gt_preds' in aux_dict: gt_preds = [x.transpose(1, 0) for x in aux_dict['gt_preds'] for _ in range(self.num_pred)] else: gt_preds = None # cls loss cls_tgt = multi_apply(self.cls_assigner.assign, ref_pts, gt_boxes_global, gt_labels_global, **kwargs) cls_src = cls_scores.view(-1, self.num_classes) from cosense3d.utils.vislib import draw_points_boxes_plt, plt points = ref_pts[0].detach().cpu().numpy() boxes = gt_boxes_global[0][:, :7].detach().cpu().numpy() scores = petr_out[0]['all_cls_scores'][0] scores = scores[:, self.num_classes - 1:].squeeze().detach().cpu().numpy() ax = draw_points_boxes_plt( pc_range=self.pc_range.tolist(), boxes_gt=boxes, return_ax=True ) ax.scatter(points[:, 0], points[:, 1], c=scores, cmap='jet', s=3, marker='s', vmin=0.0, vmax=1) plt.savefig(f"{os.environ['HOME']}/Downloads/tmp.jpg") plt.close() # if kwargs['itr'] % 1 == 0: # from cosense3d.utils.vislib import draw_points_boxes_plt, plt # points = ref_pts[0].detach().cpu().numpy() # boxes = gt_boxes[0][:, :7].detach().cpu().numpy() # scores = 
pred_to_conf_unc( # cls_scores[0], getattr(self.loss_cls, 'activation'), edl=self.is_edl)[0] # scores = scores[:, self.num_classes - 1:].squeeze().detach().cpu().numpy() # ax = draw_points_boxes_plt( # pc_range=self.pc_range.tolist(), # boxes_gt=boxes, # return_ax=True # ) # ax.scatter(points[:, 0], points[:, 1], c=scores, cmap='jet', s=3, marker='s', vmin=0.0, vmax=1.0) # # ax = draw_points_boxes_plt( # # pc_range=self.pc_range.tolist(), # # points=points[cls_tgt[0].squeeze().detach().cpu().numpy() > 0], # # points_c="green", # # ax=ax, # # return_ax=True # # ) # # ax = draw_points_boxes_plt( # # pc_range=self.pc_range.tolist(), # # points=points[scores > 0.5], # # points_c="magenta", # # ax=ax, # # return_ax=True # # ) # plt.savefig(f"{os.environ['HOME']}/Downloads/tmp.jpg") # plt.close() cls_tgt = torch.cat(cls_tgt, dim=0) cared = (cls_tgt >= 0).any(dim=-1) cls_src = cls_src[cared] cls_tgt = cls_tgt[cared] # convert one-hot to labels( cur_labels = torch.zeros_like(cls_tgt[..., 0]).long() lbl_inds, cls_inds = torch.where(cls_tgt) cur_labels[lbl_inds] = cls_inds + 1 avg_factor = max((cur_labels > 0).sum(), 1) loss_cls = self.loss_cls( cls_src, cur_labels, temp=epoch, avg_factor=avg_factor ) # box loss # pad ref pts with batch index if 'gt_preds' in aux_dict: gt_preds = self.cat_data_from_list(gt_preds) box_tgt = self.box_assigner.assign( self.cat_data_from_list(ref_pts, pad_idx=True), self.cat_data_from_list(gt_boxes_global, pad_idx=True), self.cat_data_from_list(gt_labels_global), gt_preds ) ind = box_tgt['idx'][0] # only one head loss_box = 0 bbox_reg = bbox_reg.view(-1, self.code_size) if ind.shape[1] > 0: ptr = 0 for reg_name, reg_dim in self.reg_channels.items(): pred_reg = bbox_reg[:, ptr:ptr+reg_dim].contiguous() if reg_name == 'scr': pred_reg = pred_reg.sigmoid() cur_reg_src = pred_reg[box_tgt['valid_mask'][0]] if reg_name == 'vel': cur_reg_tgt = box_tgt['vel'][0] * 0.1 elif reg_name == 'pred': cur_reg_tgt = box_tgt[reg_name][0] mask = cur_reg_tgt[..., 
0].bool() cur_reg_src = cur_reg_src[mask] cur_reg_tgt = cur_reg_tgt[mask, 1:] else: cur_reg_tgt = box_tgt[reg_name][0] # N, C cur_loss = self.loss_box(cur_reg_src, cur_reg_tgt) loss_box = loss_box + cur_loss ptr += reg_dim return { 'cls_loss': loss_cls, 'box_loss': loss_box, 'cls_max': pred_to_conf_unc( cls_src, self.loss_cls.activation, self.is_edl)[0][..., self.num_classes - 1:].max() }
[docs] def get_pred_boxes(self, bbox_preds, ref_pts): reg = {} ptr = 0 for reg_name, reg_dim in self.reg_channels.items(): reg[reg_name] = bbox_preds[..., ptr:ptr + reg_dim].contiguous() ptr += reg_dim out = self.box_assigner.box_coder.decode(ref_pts[None], reg) if isinstance(out, tuple): det, pred = out else: det = out pred = None return det, pred
[docs] def get_predictions(self, cls_scores, det_boxes, pred_boxes, batch_inds=None): if self.is_edl: scores = cls_scores[-1][..., 1:].sum(dim=-1) else: scores = cls_scores[-1].sum(dim=-1) labels = cls_scores[-1].argmax(dim=-1) pos = scores > self.box_assigner.center_threshold dets = [] if batch_inds is None: inds = range(cls_scores.shape[1]) for i in inds: dets.append({ 'box': det_boxes[-1][i][pos[i]], 'scr': scores[i][pos[i]], 'lbl': labels[i][pos[i]], 'idx': torch.ones_like(labels[i][pos[i]]) * i, }) else: inds = batch_inds.unique() for i in inds: mask = batch_inds == i pos_mask = pos[mask] dets.append({ 'box': det_boxes[-1][mask][pos_mask], 'scr': scores[mask][pos_mask], 'lbl': labels[mask][pos_mask], 'pred': pred_boxes[-1][mask][pos_mask] if pred_boxes is not None else None, 'idx': batch_inds[mask][pos_mask].long() }) return dets