Source code for cosense3d.modules.heads.det_roi_refine

import copy

import torch.nn as nn
import torch
import numpy as np
from cosense3d.ops import pointnet2_utils
from cosense3d.utils.pclib import rotate_points_along_z_torch
from cosense3d.ops.iou3d_nms_utils import boxes_iou3d_gpu
from cosense3d.utils import box_utils
from cosense3d.modules import BaseModule, plugin
from cosense3d.modules.losses.common import (weighted_smooth_l1_loss,
                                             weighted_sigmoid_binary_cross_entropy)


[docs]class KeypointRoIHead(BaseModule): def __init__(self, num_cls, in_channels, n_fc_channels, roi_grid_pool, target_assigner, dp_ratio=0.3, train_from_epoch=0, **kwargs): super().__init__(**kwargs) self.code_size = 7 self.dp_ratio = dp_ratio self.train_from_epoch = train_from_epoch self.target_assigner = plugin.build_plugin_module(target_assigner) mlps = copy.copy(roi_grid_pool['mlps']) for k in range(len(mlps)): mlps[k] = [in_channels] + mlps[k] self.roi_grid_pool_layer = pointnet2_utils.StackSAModuleMSG( radii=roi_grid_pool['pool_radius'], nsamples=roi_grid_pool['n_sample'], mlps=mlps, use_xyz=True, pool_method=roi_grid_pool['pool_method'], ) grid_size = roi_grid_pool['grid_size'] self.grid_size = grid_size c_out = sum([x[-1] for x in mlps]) pre_channel = grid_size * grid_size * grid_size * c_out fc_layers = [n_fc_channels] * 2 self.shared_fc_layers, pre_channel = self._make_fc_layers(pre_channel, fc_layers) self.cls_layers, pre_channel = self._make_fc_layers(pre_channel, fc_layers, output_channels= num_cls) self.iou_layers, _ = self._make_fc_layers(pre_channel, fc_layers, output_channels= num_cls) self.reg_layers, _ = self._make_fc_layers(pre_channel, fc_layers, output_channels=num_cls * 7) self._init_weights(weight_init='xavier') def _init_weights(self, weight_init='xavier'): if weight_init == 'kaiming': init_func = nn.init.kaiming_normal_ elif weight_init == 'xavier': init_func = nn.init.xavier_normal_ elif weight_init == 'normal': init_func = nn.init.normal_ else: raise NotImplementedError for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): if weight_init == 'normal': init_func(m.weight, mean=0, std=0.001) else: init_func(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001) def _make_fc_layers(self, input_channels, fc_list, output_channels=None): fc_layers = [] pre_channel = input_channels for k in range(len(fc_list)): fc_layers.extend([ nn.Conv1d(pre_channel, fc_list[k], kernel_size=1, bias=False), # nn.BatchNorm1d(fc_list[k]), nn.ReLU() ]) pre_channel = fc_list[k] if self.dp_ratio > 0: fc_layers.append(nn.Dropout(self.dp_ratio)) if output_channels is not None: fc_layers.append( nn.Conv1d(pre_channel, output_channels, kernel_size=1, bias=True)) fc_layers = nn.Sequential(*fc_layers) return fc_layers, pre_channel
[docs] def get_global_grid_points_of_roi(self, rois): rois = rois.view(-1, rois.shape[-1]) batch_size_rcnn = rois.shape[0] # (B, 6x6x6, 3) local_roi_grid_points = self.get_dense_grid_points(rois, batch_size_rcnn, self.grid_size) global_roi_grid_points = rotate_points_along_z_torch( local_roi_grid_points.clone(), rois[:, 6] ).squeeze(dim=1) global_center = rois[:, 0:3].clone() global_roi_grid_points += global_center.unsqueeze(dim=1) return global_roi_grid_points, local_roi_grid_points
[docs] @staticmethod def get_dense_grid_points(rois, batch_size_rcnn, grid_size): """ Get the local coordinates of each grid point of a roi in the coordinate system of the roi(origin lies in the center of this roi. """ faked_features = rois.new_ones((grid_size, grid_size, grid_size)) dense_idx = torch.stack(torch.where(faked_features), dim=1) # (N, 3) [x_idx, y_idx, z_idx] dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (B, 6x6x6, 3) local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] roi_grid_points = ( dense_idx + 0.5) / grid_size * local_roi_size.unsqueeze( dim=1) \ - (local_roi_size.unsqueeze( dim=1) / 2) # (B, 6x6x6, 3) return roi_grid_points
[docs] def roi_grid_pool(self, preds): B = len(preds) rois = torch.cat([p['boxes'] for p in preds], dim=0) point_features = torch.cat([p['feat'] for p in preds], dim=0) # (BxN, 6x6x6, 3) global_roi_grid_points, local_roi_grid_points = \ self.get_global_grid_points_of_roi(rois) xyz = torch.cat([p['coor'] for p in preds], dim=0) xyz_batch_cnt = xyz.new_zeros(B).int() for bs_idx in range(B): xyz_batch_cnt[bs_idx] = len(preds[bs_idx]['coor']) new_xyz = global_roi_grid_points.view(-1, 3) new_xyz_batch_cnt = xyz.new_zeros(B).int() for bs_idx in range(B): new_xyz_batch_cnt[bs_idx] = len(preds[bs_idx]['boxes']) * self.grid_size ** 3 pooled_points, pooled_features = self.roi_grid_pool_layer( xyz=xyz[:, :3].contiguous(), xyz_batch_cnt=xyz_batch_cnt, new_xyz=new_xyz[:, :3].contiguous(), new_xyz_batch_cnt=new_xyz_batch_cnt, features=point_features.contiguous(), # weighted point features ) # (M1 + M2 ..., C) # (BxN, 6x6x6, C) pooled_features = pooled_features.view(-1, self.grid_size ** 3, pooled_features.shape[-1]) return pooled_features
[docs] def forward(self, preds, **kwargs): epoch = kwargs.get('epoch', self.train_from_epoch + 1) if epoch < self.train_from_epoch: return {self.scatter_keys[0]: [None for _ in preds]} # RoI aware pooling pooled_features = self.roi_grid_pool(preds) batch_size_rcnn = pooled_features.shape[0] pooled_features = pooled_features.permute(0, 2, 1). \ contiguous().view(batch_size_rcnn, -1, self.grid_size, self.grid_size, self.grid_size) # (BxN, C, 6, 6, 6) shared_features = self.shared_fc_layers( pooled_features.view(batch_size_rcnn, -1, 1)) rcnn_cls = self.cls_layers(shared_features).transpose( 1, 2).contiguous().squeeze( dim=1) # (B, 1 or 2) rcnn_iou = self.iou_layers(shared_features).transpose( 1, 2).contiguous().squeeze( dim=1) # (B, 1) rcnn_reg = self.reg_layers(shared_features).transpose( 1, 2).contiguous().squeeze( dim=1) # (B, C) roi_preds = None if not self.training: rois = torch.cat([p['boxes'] for p in preds], dim=0) roi_preds = self.target_assigner.get_predictions( rcnn_cls, rcnn_iou, rcnn_reg, rois ) idx = 0 out_list = [] for p in preds: num = len(p['boxes']) out_dict = { 'rois': p['boxes'], 'rcnn_cls': rcnn_cls[idx:idx+num], 'rcnn_iou': rcnn_iou[idx:idx+num], 'rcnn_reg': rcnn_reg[idx:idx+num], } if roi_preds is not None: out_dict['preds'] = {k: v[idx:idx+num] for k, v in roi_preds.items()} out_list.append(out_dict) idx += num return {self.scatter_keys[0]: out_list}
[docs] def loss(self, out, gt_boxes, epoch, **kwargs): """ Parameters ---------- output_dict : dict target_dict : dict """ if epoch < self.train_from_epoch: return {} rois = [x['rois'] for x in out] label_dict = self.target_assigner.assign(rois, gt_boxes) # rcnn out rcnn_cls = self.cat_data_from_list(out, 'rcnn_cls').view(1, -1, 1) rcnn_iou = self.cat_data_from_list(out, 'rcnn_iou').view(1, -1, 1) rcnn_reg = self.cat_data_from_list(out, 'rcnn_reg').view(1, -1, 7) tgt_cls = label_dict['cls_tgt'].view(1, -1, 1) tgt_iou = label_dict['iou_tgt'].view(1, -1, 1) tgt_reg = label_dict['reg_tgt'].view(1, -1, 7) pos_norm = tgt_cls.sum() # cls loss loss_cls = weighted_sigmoid_binary_cross_entropy(rcnn_cls, tgt_cls) # iou loss # TODO: also count the negative samples loss_iou = weighted_smooth_l1_loss(rcnn_iou, tgt_iou, weights=tgt_cls).mean() # regression loss # Target resampling : Generate a weights mask to force the regressor concentrate on low iou predictions # sample 50% with iou>0.7 and 50% < 0.7 weights = torch.ones(tgt_iou.shape, device=tgt_iou.device) weights[tgt_cls == 0] = 0 neg = torch.logical_and(tgt_iou < 0.7, tgt_cls != 0) pos = torch.logical_and(tgt_iou >= 0.7, tgt_cls != 0) num_neg = int(neg.sum(dim=1)) num_pos = int(pos.sum(dim=1)) num_pos_smps = max(num_neg, 2) pos_indices = torch.where(pos)[1] not_selsected = torch.randperm(num_pos)[:num_pos - num_pos_smps] # not_selsected_indices = pos_indices[not_selsected] weights[:, pos_indices[not_selsected]] = 0 loss_reg = weighted_smooth_l1_loss(rcnn_reg, tgt_reg, weights=weights / max(weights.sum(), 1)).sum() loss_dict = { 'rcnn_cls_loss': loss_cls, 'rcnn_iou_loss': loss_iou, 'rcnn_reg_loss': loss_reg, } return loss_dict