# Source code for cosense3d.modules.plugin.vsa

import copy
import random

import torch
import torch.nn as nn

from cosense3d.ops import pointnet2_utils
from cosense3d.ops.utils import points_in_boxes_gpu
from cosense3d.modules.utils.common import get_voxel_centers, cat_coor_with_idx

# Default configuration for the multi-scale set-abstraction (SA) layers.
# Each entry names a feature source and carries the parameters of one
# StackSAModuleMSG: per-branch MLP channel specs (`mlps`), per-branch grouping
# radii (`pool_radius`) and the number of neighbours sampled within each
# radius (`n_sample`).  `downsample_factor` is the voxel downsampling rate of
# the corresponding feature stage relative to the raw voxel grid
# (presumably the sparse-conv stage stride — confirm against the backbone).
sa_layer_default = dict(
    raw_points=dict(
        mlps=[[16, 16], [16, 16]],
        pool_radius=[0.4, 0.8],
        n_sample=[16, 16],
    ),
    x_conv1=dict(
        downsample_factor=1,
        mlps=[[16, 16], [16, 16]],
        pool_radius=[0.4, 0.8],
        n_sample=[16, 16],
    ),
    x_conv2=dict(
        downsample_factor=2,
        mlps=[[32, 32], [32, 32]],
        pool_radius=[0.8, 1.2],
        n_sample=[16, 32],
    ),
    x_conv3=dict(
        downsample_factor=4,
        mlps=[[64, 64], [64, 64]],
        pool_radius=[1.2, 2.4],
        n_sample=[16, 32],
    ),
    x_conv4=dict(
        downsample_factor=8,
        mlps=[[64, 64], [64, 64]],
        pool_radius=[2.4, 4.8],
        n_sample=[16, 32],
    )
)

# Feature sources aggregated for each keypoint when the caller does not
# specify `features_source` explicitly.
default_feature_source = ['bev', 'x_conv1', 'x_conv2', 'x_conv3', 'x_conv4', 'raw_points']

def bilinear_interpolate_torch(im, x, y):
    """Sample a 2D map at fractional pixel locations via bilinear interpolation.

    Args:
        im: (H, W, C) feature map, indexed as ``im[y, x]`` (a channel-less
            (H, W) map also works).
        x: (N,) fractional column coordinates.
        y: (N,) fractional row coordinates.

    Returns:
        (N, C) interpolated features, one row per query point.
    """
    # Integer corner coordinates surrounding each query point.
    col_lo = torch.floor(x).long()
    row_lo = torch.floor(y).long()
    col_hi = col_lo + 1
    row_hi = row_lo + 1

    # Clamp corners into the map; out-of-range queries degenerate to border
    # samples (and may receive zero total weight at the far edge).
    col_lo = torch.clamp(col_lo, 0, im.shape[1] - 1)
    col_hi = torch.clamp(col_hi, 0, im.shape[1] - 1)
    row_lo = torch.clamp(row_lo, 0, im.shape[0] - 1)
    row_hi = torch.clamp(row_hi, 0, im.shape[0] - 1)

    # Bilinear weights derived from the (clamped) corner coordinates.
    weights = (
        (col_hi.type_as(x) - x) * (row_hi.type_as(y) - y),  # weight of (row_lo, col_lo)
        (col_hi.type_as(x) - x) * (y - row_lo.type_as(y)),  # weight of (row_hi, col_lo)
        (x - col_lo.type_as(x)) * (row_hi.type_as(y) - y),  # weight of (row_lo, col_hi)
        (x - col_lo.type_as(x)) * (y - row_lo.type_as(y)),  # weight of (row_hi, col_hi)
    )
    corners = (
        im[row_lo, col_lo],
        im[row_hi, col_lo],
        im[row_lo, col_hi],
        im[row_hi, col_hi],
    )

    # Weighted sum of the four corners; the double transpose broadcasts the
    # (N,) weights over the channel dimension whether or not `im` has one.
    out = torch.t(torch.t(corners[0]) * weights[0])
    for corner, w in zip(corners[1:], weights[1:]):
        out = out + torch.t(torch.t(corner) * w)
    return out
class VoxelSetAbstraction(nn.Module):
    """Voxel Set Abstraction: aggregate multi-source features at keypoints.

    For a set of FPS-sampled keypoints, pools features from BEV maps,
    intermediate sparse-conv voxel stages (x_conv1..4) and raw points, then
    fuses the concatenation through a small FC layer (PV-RCNN-style VSA).
    """

    def __init__(self,
                 voxel_size,
                 point_cloud_range,
                 num_keypoints=4096,
                 num_out_features=32,
                 point_source='raw_points',
                 features_source=None,
                 num_bev_features=128,
                 bev_stride=8,
                 num_rawpoint_features=3,
                 enlarge_selection_boxes=True,
                 sa_layer=None,
                 min_selected_kpts=128,
                 **kwargs):
        """
        Args:
            voxel_size: voxel edge lengths used to map metric coords to grid.
            point_cloud_range: [x_min, y_min, z_min, x_max, y_max, z_max].
            num_keypoints: upper bound on FPS keypoints per sample.
            num_out_features: channel count after feature fusion.
            point_source: keypoint source; only 'raw_points' is implemented.
            features_source: subset of feature sources to aggregate;
                defaults to ``default_feature_source``.
            num_bev_features: channels of the BEV feature map.
            bev_stride: downsampling stride of the BEV map w.r.t. the voxel grid.
            num_rawpoint_features: columns in raw points (first 3 are xyz).
            enlarge_selection_boxes: if True, grow predicted boxes by 0.5
                before selecting keypoints inside them (see ``forward``).
            sa_layer: SA layer config; defaults to ``sa_layer_default``.
            min_selected_kpts: minimum keypoints kept per sample so BatchNorm
                in the fusion FC does not see a degenerate batch.
        """
        super().__init__()
        self.voxel_size = voxel_size
        self.point_cloud_range = point_cloud_range
        self.features_source = default_feature_source \
            if features_source is None \
            else features_source
        self.num_keypoints = num_keypoints
        self.num_out_features = num_out_features
        self.point_source = point_source
        self.num_bev_features = num_bev_features
        self.bev_stride = bev_stride
        self.num_rawpoint_features = num_rawpoint_features
        self.enlarge_selection_boxes = enlarge_selection_boxes
        self.min_selected_kpts = min_selected_kpts

        self.SA_layers = nn.ModuleList()
        self.SA_layer_names = []
        self.downsample_times_map = {}
        c_in = 0  # running channel total fed into the fusion FC
        sa_layer = sa_layer_default if sa_layer is None else sa_layer
        # Build one StackSAModuleMSG per intermediate conv feature source.
        for src_name in self.features_source:
            if src_name in ['bev', 'raw_points']:
                continue
            self.downsample_times_map[src_name] = sa_layer[src_name]['downsample_factor']
            mlps = copy.copy(sa_layer[src_name]['mlps'])
            for k in range(len(mlps)):
                # Prepend the branch input channels (taken as the first MLP width).
                mlps[k] = [mlps[k][0]] + mlps[k]
            cur_layer = pointnet2_utils.StackSAModuleMSG(
                radii=sa_layer[src_name]['pool_radius'],
                nsamples=sa_layer[src_name]['n_sample'],
                mlps=mlps,
                use_xyz=True,
                pool_method='max_pool',
            )
            self.SA_layers.append(cur_layer)
            self.SA_layer_names.append(src_name)
            # Each branch contributes its last MLP width to the fused feature.
            c_in += sum([x[-1] for x in mlps])

        if 'bev' in self.features_source:
            c_bev = num_bev_features
            c_in += c_bev

        if 'raw_points' in self.features_source:
            mlps = copy.copy(sa_layer['raw_points']['mlps'])
            for k in range(len(mlps)):
                # Raw-point branch input: feature columns beyond xyz.
                mlps[k] = [num_rawpoint_features - 3] + mlps[k]
            self.SA_rawpoints = pointnet2_utils.StackSAModuleMSG(
                radii=sa_layer['raw_points']['pool_radius'],
                nsamples=sa_layer['raw_points']['n_sample'],
                mlps=mlps,
                use_xyz=True,
                pool_method='max_pool'
            )
            c_in += sum([x[-1] for x in mlps])

        # Fuse the concatenated multi-source features down to the output width.
        self.vsa_point_feature_fusion = nn.Sequential(
            nn.Linear(c_in, self.num_out_features, bias=False),
            nn.BatchNorm1d(self.num_out_features),
            nn.ReLU(),
        )
        self.num_point_features = self.num_out_features
        self.num_point_features_before_fusion = c_in
[docs] def interpolate_from_bev_features(self, keypoints_list, bev_features): B = len(bev_features) point_bev_features_list = [] for i in range(B): keypoints = keypoints_list[i][:, :3] x_idxs = (keypoints[..., 0] - self.point_cloud_range[0]) / self.voxel_size[0] y_idxs = (keypoints[..., 1] - self.point_cloud_range[1]) / self.voxel_size[1] x_idxs = x_idxs / self.bev_stride y_idxs = y_idxs / self.bev_stride cur_bev_features = bev_features[i].permute(1, 2, 0) # (H, W, C) point_bev_features = bilinear_interpolate_torch(cur_bev_features, x_idxs, y_idxs) point_bev_features_list.append(point_bev_features) point_bev_features = torch.cat(point_bev_features_list, dim=0) # (B, N, C0) return point_bev_features
    def get_sampled_points(self, points, voxel_coords):
        """Sample keypoints from each sample's points via farthest-point sampling.

        Args:
            points: list (length B) of per-sample point tensors; the first
                three columns are xyz, extra feature columns are carried along.
            voxel_coords: unused in this implementation; kept for interface
                compatibility with alternative (non-raw) point sources.

        Returns:
            list of B keypoint tensors; the number of keypoints per sample
            scales with the point-cloud size and is capped at
            ``self.num_keypoints``.
        """
        B = len(points)
        keypoints_list = []
        for i in range(B):
            if self.point_source == 'raw_points':
                src_points = points[i]
            else:
                raise NotImplementedError
            # # generate random keypoints in the perception view field
            # keypoints = torch.randn((self.num_keypoints, 4), device=src_points.device)
            # keypoints[..., 0] = keypoints[..., 0] * 140
            # keypoints[..., 1] = keypoints[..., 1] * 40
            # # points with height flag 10 are padding/invalid, for later filtering
            # keypoints[..., 2] = 10.0
            sampled_points = src_points.unsqueeze(dim=0)  # (1, N, C)
            # sample points with FPS
            # some cropped pcd may have very few points, select various number
            # of points to ensure similar sample density
            # 50000 is approximately the number of points in one full pcd
            num_kpts = int(self.num_keypoints * sampled_points.shape[1] / 50000) + 1
            num_kpts = min(num_kpts, self.num_keypoints)
            cur_pt_idxs = pointnet2_utils.furthest_point_sample(
                sampled_points[..., :3].contiguous(), num_kpts
            ).long()
            if sampled_points.shape[1] < num_kpts:
                # NOTE(review): when fewer points exist than requested, the
                # tail indices are overwritten with the leading picks —
                # presumably to replace FPS padding with valid duplicates;
                # confirm against the CUDA op's behavior.
                empty_num = num_kpts - sampled_points.shape[1]
                cur_pt_idxs[0, -empty_num:] = cur_pt_idxs[0, :empty_num]
            # Gather the selected points (xyz plus any extra feature columns).
            keypoints = sampled_points[0][cur_pt_idxs[0]]
            # keypoints[:len(kpts[0]), :] = kpts
            keypoints_list.append(keypoints)
        # keypoints = torch.cat(keypoints_list, dim=0)  # (B, M, 3)
        return keypoints_list
    def forward(self, det_out, bev_feat, voxel_feat, points):
        """Aggregate multi-source features at keypoints inside predicted boxes.

        Args:
            det_out: per-sample first-stage detection outputs; each element's
                'preds' dict provides 'box' and 'scr' (score) tensors.
            bev_feat: per-sample BEV feature maps, each (C, H, W).
            voxel_feat: per-sample dicts of sparse voxel features keyed by
                stride ('p1', 'p2', ...), each with 'coor' and 'feat'
                — assumed layout, confirm against the backbone's output.
            points: list (length B) of per-sample raw point tensors.

        Returns:
            dict with per-sample lists: 'point_features', 'point_coords',
            'point_scores', 'boxes', 'scores'.
        """
        B = len(points)
        preds = [x['preds'] for x in det_out]
        keypoints_list = self.get_sampled_points(points, voxel_feat)  # BxNx4

        # Only select the points that are in the predicted bounding boxes
        boxes = cat_coor_with_idx([x['box'] for x in preds])
        scores = torch.cat([x['scr'] for x in preds])
        # At the early training stage, there might be too many boxes,
        # we select limited number of boxes for the second stage.
        if boxes.shape[0] > B * 100:
            topk = scores.topk(k=100 * B).indices
            scores = scores[topk]
            boxes = boxes[topk]
        boxes_tmp = boxes.clone()
        if self.enlarge_selection_boxes:
            # Grow box dimensions so near-boundary keypoints are kept.
            boxes_tmp[:, 4:7] += 0.5
        keypoints = cat_coor_with_idx(keypoints_list)
        if len(boxes_tmp) > 0:
            # For each keypoint: index of the box containing it, -1 if none.
            pts_idx_of_box = points_in_boxes_gpu(keypoints[:, :4], boxes_tmp, batch_size=B)[1]
        else:
            pts_idx_of_box = torch.full((len(keypoints),), fill_value=-1, device=keypoints.device)
        kpt_mask = pts_idx_of_box >= 0
        # Ensure enough points are selected to satisfy the
        # condition of batch norm in the FC layers of feature fusion module
        for i in range(B):
            batch_mask = keypoints[:, 0] == i
            if kpt_mask[batch_mask].sum().item() < self.min_selected_kpts:
                tmp = kpt_mask[batch_mask].clone()
                # Randomly force-enable keypoints (duplicates possible) until
                # the per-sample minimum is reachable.
                tmp[torch.randint(0, batch_mask.sum().item(), (self.min_selected_kpts,))] = True
                kpt_mask[batch_mask] = tmp

        point_features_list = []
        if 'bev' in self.features_source:
            point_bev_features = self.interpolate_from_bev_features(
                keypoints_list, bev_feat
            )
            point_features_list.append(point_bev_features[kpt_mask])

        new_xyz = keypoints[kpt_mask]
        # Score of the box each selected keypoint falls in; 0 for keypoints
        # that were force-selected without a containing box.
        new_xyz_scrs = torch.zeros((kpt_mask.sum().item(),), device=keypoints.device)
        valid = pts_idx_of_box[kpt_mask] >= 0
        new_xyz_scrs[valid] = scores[pts_idx_of_box[kpt_mask][valid]]
        new_xyz_batch_cnt = torch.tensor([(new_xyz[:, 0] == b).sum() for b in range(B)],
                                         device=new_xyz.device).int()

        if 'raw_points' in self.features_source:
            xyz_batch_cnt = torch.tensor([len(pts) for pts in points], device=points[0].device).int()
            raw_points = cat_coor_with_idx(points)
            xyz = raw_points[:, 1:4]
            # Raw points carry no extra feature columns here; SA uses xyz only.
            point_features = None
            pooled_points, pooled_features = self.SA_rawpoints(
                xyz=xyz.contiguous(),
                xyz_batch_cnt=xyz_batch_cnt,
                new_xyz=new_xyz[:, :3].contiguous(),
                new_xyz_batch_cnt=new_xyz_batch_cnt,
                features=point_features,
            )
            point_features_list.append(pooled_features)

        for k, src_name in enumerate(self.SA_layer_names):
            # 'x_convN' -> voxel-feature key 'p{2^(N-1)}' (stride of that stage).
            cur_stride = 2 ** (int(src_name[-1]) - 1)
            cur_coords = [feat[f"p{cur_stride}"]['coor'] for feat in voxel_feat]
            cur_feats = [feat[f"p{cur_stride}"]['feat'] for feat in voxel_feat]
            xyz = get_voxel_centers(
                torch.cat(cur_coords),
                downsample_times=self.downsample_times_map[src_name],
                voxel_size=self.voxel_size,
                point_cloud_range=self.point_cloud_range
            )
            xyz_batch_cnt = torch.tensor([len(coor) for coor in cur_coords],
                                         device=cur_coords[0].device).int()
            pooled_points, pooled_features = self.SA_layers[k](
                xyz=xyz.contiguous(),
                xyz_batch_cnt=xyz_batch_cnt,
                new_xyz=new_xyz[:, :3].contiguous(),
                new_xyz_batch_cnt=new_xyz_batch_cnt,
                features=torch.cat(cur_feats, dim=0),
            )
            point_features_list.append(pooled_features)

        # Concatenate all sources channel-wise, then fuse.
        point_features = torch.cat(point_features_list, dim=1)
        out_dict = {}
        # out_dict['point_features_before_fusion'] = point_features
        point_features = self.vsa_point_feature_fusion(point_features)

        # Split the flat per-keypoint tensors back into per-sample lists.
        cur_idx = 0
        out_dict['point_features'] = []
        out_dict['point_coords'] = []
        out_dict['point_scores'] = []
        out_dict['boxes'] = []
        out_dict['scores'] = []
        for i, num in enumerate(new_xyz_batch_cnt):
            out_dict['point_features'].append(point_features[cur_idx:cur_idx + num])
            out_dict['point_coords'].append(new_xyz[cur_idx:cur_idx + num])
            out_dict['point_scores'].append(new_xyz_scrs[cur_idx:cur_idx + num])
            mask = boxes[:, 0] == i
            out_dict['boxes'].append(boxes[mask, 1:])
            out_dict['scores'].append(scores[mask])
            cur_idx += num
        return out_dict