Source code for cosense3d.modules.plugin.pillar_encoder

import torch
from torch import nn
import torch.nn.functional as F

from cosense3d.modules.utils.conv import ConvModule
from cosense3d.modules.utils.init import xavier_init


class PFNLayer(nn.Module):
    """Pillar Feature Net layer: a linear projection with optional
    BatchNorm, followed by max-pooling over the points of each pillar."""

    def __init__(self, in_channels, out_channels, use_norm=True, last_layer=False):
        super().__init__()
        self.last_vfe = last_layer
        self.use_norm = use_norm
        if not self.last_vfe:
            # Intermediate layers output half the channels; the other half is
            # filled by the broadcast max-pooled features in forward().
            out_channels = out_channels // 2

        if self.use_norm:
            self.linear = nn.Linear(in_channels, out_channels, bias=False)
            self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)
        else:
            self.linear = nn.Linear(in_channels, out_channels, bias=True)

        # Maximum number of voxels processed per nn.Linear call (see forward).
        self.part = 50000

    def forward(self, inputs):
        if inputs.shape[0] > self.part:
            # nn.Linear can produce unreliable results when the batch
            # dimension is very large, so process the input in chunks of
            # self.part voxels.
            num_parts = inputs.shape[0] // self.part
            part_linear_out = [
                self.linear(inputs[num_part * self.part:(num_part + 1) * self.part])
                for num_part in range(num_parts + 1)]
            x = torch.cat(part_linear_out, dim=0)
        else:
            x = self.linear(inputs)
        # Temporarily disable cuDNN for the normalization of the permuted tensor.
        torch.backends.cudnn.enabled = False
        x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) if self.use_norm else x
        torch.backends.cudnn.enabled = True
        x = F.relu(x)
        x_max = torch.max(x, dim=1, keepdim=True)[0]

        if self.last_vfe:
            return x_max
        else:
            # Broadcast the pooled features back to every point and append
            # them to the per-point features.
            x_repeat = x_max.repeat(1, inputs.shape[1], 1)
            x_concatenated = torch.cat([x, x_repeat], dim=2)
            return x_concatenated
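

# Illustrative usage of PFNLayer (editor's sketch, not part of the original
# module; the shapes below are assumptions chosen for demonstration).
def _pfn_layer_demo():
    layer = PFNLayer(in_channels=10, out_channels=64, use_norm=True,
                     last_layer=True)
    # 32 pillars, 20 points per pillar, 10 features per point.
    pillars = torch.rand(32, 20, 10)
    out = layer(pillars)
    # A last layer max-pools over the points, yielding one vector per pillar.
    assert out.shape == (32, 1, 64)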


class PillarEncoder(nn.Module):
    """Encodes voxelized point clouds into per-pillar feature vectors
    (PointPillars-style pillar feature net)."""

    def __init__(self, features, voxel_size, lidar_range, channels, use_norm=True):
        super(PillarEncoder, self).__init__()
        self.voxel_size = nn.Parameter(torch.tensor(voxel_size), requires_grad=False)
        self.lidar_range = nn.Parameter(torch.tensor(lidar_range), requires_grad=False)
        self.offset = nn.Parameter(self.voxel_size / 2 + self.lidar_range[:3],
                                   requires_grad=False)
        self.num_point_features = sum(
            [getattr(self, f"{f}_dim") for f in features])
        self.features = features
        assert isinstance(channels, list)
        self.channels = [self.num_point_features] + channels
        self.out_channels = channels[-1]
        self.use_norm = use_norm
        self._init_layers(self.channels)

    def _init_layers(self, channels):
        pfn_layers = []
        for i in range(len(channels) - 1):
            in_filters = channels[i]
            out_filters = channels[i + 1]
            pfn_layers.append(
                PFNLayer(in_filters, out_filters, self.use_norm,
                         last_layer=(i >= len(channels) - 2))
            )
        self.pfn_layers = nn.ModuleList(pfn_layers)
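
    # Channel bookkeeping (editor's note): with num_point_features == 10 and
    # channels == [64, 64], _init_layers builds PFNLayer(10, 64) (non-last,
    # so its Linear maps 10 -> 32 and forward() re-concatenates the pooled
    # half back to 64) followed by PFNLayer(64, 64, last_layer=True).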

    def forward(self, voxel_features, coords, voxel_num_points):
        # Offset of each point to the mean of the points in its pillar.
        points_mean = voxel_features[..., :3].sum(dim=1, keepdim=True) / \
            voxel_num_points.view(-1, 1, 1)
        f_cluster = voxel_features[..., :3] - points_mean
        # Offset of each point to the metric center of its pillar; coords are
        # ordered (batch, z, y, x), hence the [3, 2, 1] index to get (x, y, z).
        coords_metric = coords[:, [3, 2, 1]].unsqueeze(1) * self.voxel_size + self.offset
        f_center = voxel_features[..., :3] - coords_metric

        features = self.compose_voxel_feature(voxel_features) + [f_cluster, f_center]
        features = torch.cat(features, dim=-1)

        # Zero out the features of padded (empty) point slots.
        voxel_count = features.shape[1]
        mask = self.get_paddings_indicator(voxel_num_points, voxel_count, axis=0)
        features *= mask.unsqueeze(-1)
        for pfn in self.pfn_layers:
            features = pfn(features)
        features = features.squeeze()
        return features
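
    # Shape walk-through (editor's note): for voxel_features of shape
    # (N, P, C) and voxel_num_points of shape (N,), f_cluster and f_center
    # are each (N, P, 3). The concatenated tensor therefore has
    # sum(per-feature dims) + 6 channels, which must equal
    # self.num_point_features for the first PFNLayer to accept it; after the
    # stacked PFN layers and the final squeeze the output is (N, out_channels).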

    def compose_voxel_feature(self, voxel_features):
        features = []
        if 'absolute_xyz' in self.features:
            features.append(voxel_features[..., :3])
        if 'distance' in self.features:
            features.append(torch.norm(voxel_features[..., :3], 2, -1, keepdim=True))
        if 'intensity' in self.features:
            assert voxel_features.shape[-1] >= 4
            features.append(voxel_features[..., 3:4])
        return features

    @staticmethod
    def get_paddings_indicator(actual_num, max_num, axis=0):
        """Create a boolean mask marking which of the `max_num` point slots
        of each voxel hold actual points rather than padding."""
        actual_num = torch.unsqueeze(actual_num, axis + 1)
        max_num_shape = [1] * len(actual_num.shape)
        max_num_shape[axis + 1] = -1
        max_num = torch.arange(max_num, dtype=torch.int,
                               device=actual_num.device).view(max_num_shape)
        paddings_indicator = actual_num.int() > max_num
        return paddings_indicator
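
    # Worked example (editor's sketch): for two pillars holding 3 and 1
    # points respectively, padded to max_num = 4 slots:
    #
    #     >>> PillarEncoder.get_paddings_indicator(torch.tensor([3, 1]), 4)
    #     tensor([[ True,  True,  True, False],
    #             [ True, False, False, False]])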

    @property
    def distance_dim(self):
        return 1

    @property
    def absolute_xyz_dim(self):
        return 6

    @property
    def xyz_dim(self):
        return 3

    @property
    def intensity_dim(self):
        return 1
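

# Illustrative end-to-end usage (editor's sketch, not part of the original
# module). All sizes and the feature list are assumptions; this combination
# is chosen so that the declared feature dims ('xyz' + 'absolute_xyz' +
# 'intensity' = 10) match the 10 channels actually concatenated in forward()
# (absolute xyz, intensity, f_cluster, f_center).
def _pillar_encoder_demo():
    encoder = PillarEncoder(features=['xyz', 'absolute_xyz', 'intensity'],
                            voxel_size=[0.4, 0.4, 4.0],
                            lidar_range=[-50.0, -50.0, -3.0, 50.0, 50.0, 1.0],
                            channels=[64])
    voxel_features = torch.rand(1000, 32, 4)       # (voxels, max points, xyzi)
    coords = torch.randint(0, 100, (1000, 4))      # (batch, z, y, x) indices
    voxel_num_points = torch.randint(1, 33, (1000,))
    out = encoder(voxel_features, coords, voxel_num_points)
    assert out.shape == (1000, 64)                 # one feature vector per pillar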