Source code for cosense3d.modules.utils.common

from importlib import import_module

import torch
from torch import nn
import numpy as np

from torch.distributions.multivariate_normal import _batch_mahalanobis
from cosense3d.modules.utils.me_utils import metric2indices

# Full double-precision pi; limit_period uses ``2 * pi`` as its default period.
pi = np.pi


def clip_sigmoid(x: torch.Tensor, eps: float=1e-4) -> torch.Tensor:
    """Apply a sigmoid and clamp the result away from 0 and 1.

    The sigmoid is computed in place, so ``x`` itself is overwritten with its
    (unclamped) sigmoid values; the clamped tensor is returned separately.

    :param x: Input feature map (documented by the original author as having
        the shape [B, N, H, W]).
    :param eps: Margin kept from both ends of the range, i.e. the output lies
        in ``[eps, 1 - eps]``. Defaults to 1e-4.
    :return: Feature map after sigmoid and clamping.
    """
    activated = x.sigmoid_()
    return activated.clamp(min=eps, max=1 - eps)
def cat_name_str(module_name):
    """Turn an underscore-separated module name into a CamelCase class name.

    Only the first character of every ``_``-separated part is upper-cased; the
    remaining characters are kept as they are.

    :param module_name: str, format in xxx_yyy_zzz
    :returns: class_name: str, format in XxxYyyZzz
    """
    return ''.join(
        part[:1].upper() + part[1:] for part in module_name.split('_')
    )
def instantiate(module_name, cls_name=None, module_cfg=None, **kwargs):
    """Import ``cosense3d.model.<module_name>`` and build an object from it.

    :param module_name: name of the sub-module of ``cosense3d.model`` to import.
    :param cls_name: name of the class to look up in that module. When None,
        it is derived from ``module_name`` with :func:`cat_name_str`.
    :param module_cfg: if given, it is passed to the class as the single
        positional argument and ``kwargs`` are not used.
    :param kwargs: keyword arguments for the class, used only when
        ``module_cfg`` is None.
    :returns: the constructed instance.
    """
    package = import_module(f"cosense3d.model.{module_name}")
    if cls_name is None:
        cls_name = cat_name_str(module_name)
    obj_cls = getattr(package, cls_name)

    if module_cfg is not None:
        return obj_cls(module_cfg)
    return obj_cls(**kwargs)
def bias_init_with_prob(prior_prob: float) -> float:
    """Return the bias whose sigmoid equals the given prior probability.

    Used to initialize conv/fc bias values according to a probability value.

    :param prior_prob: the target probability.
    :return: ``-log((1 - prior_prob) / prior_prob)`` as a Python float.
    """
    odds_against = (1 - prior_prob) / prior_prob
    return float(-np.log(odds_against))
def topk_gather(feat, topk_indexes):
    """Select entries of ``feat`` along dimension 1 using the given indices.

    The first two dimensions of ``topk_indexes`` are kept, the index tensor is
    viewed with singleton trailing dimensions and then repeated over all
    dimensions of ``feat`` after the second one before gathering.

    :param feat: tensor to gather from.
    :param topk_indexes: index tensor, or None to leave ``feat`` untouched.
    :returns: the gathered tensor, or ``feat`` itself if no indices are given.
    """
    if topk_indexes is None:
        return feat

    trailing_dims = feat.shape[2:]
    index_shape = list(topk_indexes.shape[:2]) + [1] * (len(feat.shape) - 2)
    expanded = topk_indexes.view(*index_shape).repeat(1, 1, *trailing_dims)
    return torch.gather(feat, 1, expanded)
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid (logit) with clamping.

    The input is first limited to ``[0, 1]``; both ``x`` and ``1 - x`` are then
    bounded from below by ``eps`` before the log-ratio is taken.

    :param x: (Tensor) The tensor to do the inverse.
    :param eps: (float) EPS avoid numerical overflow. Defaults 1e-5.
    :returns: Tensor: The x has passed the inverse function of sigmoid,
        has same shape with input.
    """
    prob = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(prob, min=eps)
    denominator = torch.clamp(1 - prob, min=eps)
    return torch.log(numerator / denominator)
def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Xavier-initialize a module's weight and set its bias to a constant.

    A ``weight`` or ``bias`` attribute that is missing or None is skipped.

    :param module: module to initialize in place.
    :param gain: gain forwarded to the Xavier initializer.
    :param bias: constant written into the bias.
    :param distribution: 'uniform' or 'normal'.
    """
    assert distribution in ['uniform', 'normal']

    weight = getattr(module, 'weight', None)
    if weight is not None:
        use_uniform = distribution == 'uniform'
        init_fn = nn.init.xavier_uniform_ if use_uniform else nn.init.xavier_normal_
        init_fn(weight, gain=gain)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
def limit_period(val, offset=0.5, period=2 * pi):
    """Wrap values into a single period.

    The result lies in ``[-offset * period, (1 - offset) * period)``.

    :param val: tensor of values to wrap.
    :param offset: fraction of the period that lies below zero after wrapping.
    :param period: length of the period; defaults to ``2 * pi`` using the
        module-level ``pi`` constant.
    :returns: tensor of wrapped values.
    """
    n_periods = torch.floor(val / period + offset)
    return val - n_periods * period
def get_conv2d_layers(conv_name, in_channels, out_channels, n_layers, kernel_size,
                      stride, padding, relu_last=True, sequential=True, **kwargs):
    """Build a stack of convolution + BatchNorm2d (+ ReLU) layers.

    kernel_size, stride and padding should be a list with the lengths that
    match n_layers.

    :param conv_name: name of the layer class looked up in ``torch.nn``.
    :param in_channels: input channels of the first layer; every following
        layer takes ``out_channels`` as input.
    :param out_channels: output channels of every layer.
    :param n_layers: number of convolution layers.
    :param kernel_size: per-layer kernel sizes.
    :param stride: per-layer strides.
    :param padding: per-layer paddings.
    :param relu_last: whether a ReLU also follows the last layer.
    :param sequential: return an ``nn.Sequential`` if True, otherwise the
        plain list of modules.
    :param kwargs: ``bias`` (default False) is applied to every convolution;
        every other entry is indexed per layer and forwarded to the
        convolution constructor.
    :returns: ``nn.Sequential`` or list of modules.
    """
    bias = kwargs.pop('bias', False)

    layers = []
    for idx in range(n_layers):
        per_layer_kwargs = {key: values[idx] for key, values in kwargs.items()}
        conv = getattr(nn, conv_name)(
            in_channels, out_channels, kernel_size[idx],
            stride=stride[idx], padding=padding[idx], bias=bias,
            **per_layer_kwargs
        )
        layers.append(conv)
        layers.append(nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01))

        is_last = idx == n_layers - 1
        if relu_last or not is_last:
            layers.append(nn.ReLU())
        in_channels = out_channels

    return nn.Sequential(*layers) if sequential else layers
def get_norm_layer(channels, norm):
    """Create a normalization layer by its short name.

    :param channels: number of features passed to the layer constructor.
    :param norm: 'LN' for ``nn.LayerNorm`` or 'BN' for ``nn.BatchNorm1d``.
    :returns: the normalization module.
    :raises NotImplementedError: for any other ``norm`` value.
    """
    if norm == 'LN':
        return nn.LayerNorm(channels)
    if norm == 'BN':
        return nn.BatchNorm1d(channels)
    raise NotImplementedError
def linear_last(in_channels, mid_channels, out_channels, bias=False, norm='BN'):
    """Build a two-layer MLP: Linear -> norm -> ReLU -> Linear.

    :param in_channels: input features of the first linear layer.
    :param mid_channels: hidden feature size.
    :param out_channels: output features of the last linear layer.
    :param bias: whether the first linear layer has a bias; the last linear
        layer uses the ``nn.Linear`` default.
    :param norm: normalization name forwarded to :func:`get_norm_layer`.
    :returns: ``nn.Sequential`` with the four modules.
    """
    # Modules are created in the same order in which they appear in the stack.
    hidden = nn.Linear(in_channels, mid_channels, bias=bias)
    hidden_norm = get_norm_layer(mid_channels, norm)
    activation = nn.ReLU(inplace=True)
    head = nn.Linear(mid_channels, out_channels)
    return nn.Sequential(hidden, hidden_norm, activation, head)
def linear_layers(in_out, activations=None, norm='BN'):
    """Build an MLP of bias-free Linear -> norm -> activation stages.

    :param in_out: sequence of feature sizes; consecutive pairs give the
        input/output size of each linear layer.
    :param activations: None (ReLU everywhere), the name of one ``torch.nn``
        activation class used for all stages, or one name per stage.
    :param norm: normalization name forwarded to :func:`get_norm_layer`.
    :returns: ``nn.Sequential`` with three modules per stage.
    """
    n_stages = len(in_out) - 1
    if activations is None:
        activations = ['ReLU'] * n_stages
    elif isinstance(activations, str):
        activations = [activations] * n_stages
    else:
        assert len(activations) == (len(in_out) - 1)

    stages = []
    for i, (c_in, c_out) in enumerate(zip(in_out[:-1], in_out[1:])):
        stages.extend([
            nn.Linear(c_in, c_out, bias=False),
            get_norm_layer(c_out, norm),
            getattr(nn, activations[i])(),
        ])
    return nn.Sequential(*stages)
def meshgrid(xmin, xmax, ymin=None, ymax=None, dim=2, n_steps=None, step=None):
    """Build a dense coordinate grid with matrix ('ij') indexing.

    If both ``ymin`` and ``ymax`` are given, a 2D grid with separate x and y
    ranges is built. Otherwise the x axis is reused for all ``dim`` axes.

    :param xmin: start of the x range.
    :param xmax: end of the x range (included with ``n_steps``, excluded with
        ``step``, following ``torch.linspace`` / ``torch.arange``).
    :param ymin: optional start of the y range.
    :param ymax: optional end of the y range.
    :param dim: number of grid dimensions, at most 3; must be 2 when a y
        range is given.
    :param n_steps: number of samples per axis; takes precedence over ``step``.
    :param step: spacing between samples, used when ``n_steps`` is None.
    :returns: tensor of shape ``(*grid_shape, dim)`` holding the coordinates
        of every grid point.
    :raises NotImplementedError: if neither ``n_steps`` nor ``step`` is given.
    """
    assert dim <= 3, f'dim <= 3, but dim={dim} is given.'
    has_y_range = ymin is not None and ymax is not None
    if has_y_range:
        assert dim == 2

    def _axis(lo, hi):
        # n_steps wins over step when both are provided.
        if n_steps is not None:
            return torch.linspace(lo, hi, n_steps)
        if step is not None:
            return torch.arange(lo, hi, step)
        raise NotImplementedError

    if has_y_range:
        xs = (_axis(xmin, xmax), _axis(ymin, ymax))
    else:
        xs = (_axis(xmin, xmax),) * dim

    # torch.meshgrid only accepts 'ij' or 'xy' as indexing, whatever the
    # number of inputs; 'ij' keeps the axis order for 1, 2 and 3 dimensions.
    return torch.stack(torch.meshgrid(*xs, indexing='ij'), dim=-1)
def meshgrid_cross(xmins, xmaxs, n_steps=None, steps=None):
    """Build a dense grid from independent per-axis ranges ('ij' indexing).

    Each axis ``k`` spans from ``xmins[k]`` to ``xmaxs[k] + 1``; this upper
    limit is included when ``n_steps`` is used (``torch.linspace``) and
    excluded when ``steps`` is used (``torch.arange``).

    :param xmins: per-axis range starts.
    :param xmaxs: per-axis range ends (1 is added before building the axis).
    :param n_steps: per-axis number of samples; takes precedence over
        ``steps`` and must have the same length as ``xmins``.
    :param steps: per-axis sample spacing, used when ``n_steps`` is None.
    :returns: tensor of shape ``(*grid_shape, n_axes)`` holding the
        coordinates of every grid point.
    :raises NotImplementedError: if neither ``n_steps`` nor ``steps`` is given.
    """
    if n_steps is not None:
        assert len(xmins) == len(n_steps)
        xs = [torch.linspace(xmin, xmax + 1, nstp)
              for xmin, xmax, nstp in zip(xmins, xmaxs, n_steps)]
    elif steps is not None:
        xs = [torch.arange(xmin, xmax + 1, stp)
              for xmin, xmax, stp in zip(xmins, xmaxs, steps)]
    else:
        raise NotImplementedError

    # torch.meshgrid only accepts 'ij' or 'xy' as indexing, whatever the
    # number of inputs, so 'ij' is used for every dimensionality.
    return torch.stack(torch.meshgrid(*xs, indexing='ij'), dim=-1)
def pad_r(tensor, value=0):
    """Append one entry filled with ``value`` at the end of the last dimension.

    :param tensor: tensor to pad.
    :param value: fill value of the added entry.
    :returns: tensor whose last dimension is one element longer.
    """
    first_slice = tensor[..., :1]
    padding = torch.ones_like(first_slice) * value
    return torch.cat((tensor, padding), dim=-1)
def pad_l(tensor, value=0):
    """Prepend one entry filled with ``value`` to the last dimension.

    :param tensor: tensor to pad.
    :param value: fill value of the added entry.
    :returns: tensor whose last dimension is one element longer.
    """
    first_slice = tensor[..., :1]
    padding = torch.ones_like(first_slice) * value
    return torch.cat((padding, tensor), dim=-1)
def cat_coor_with_idx(tensor_list):
    """Concatenate tensors after prepending each one's list position.

    Every tensor gets its index in ``tensor_list`` as a new leading entry of
    the last dimension (via :func:`pad_l`); the results are concatenated
    along dimension 0.

    :param tensor_list: list of tensors to concatenate.
    :returns: the concatenated tensor.
    """
    indexed = [pad_l(t, idx) for idx, t in enumerate(tensor_list)]
    return torch.cat(indexed, dim=0)
def fuse_batch_indices(coords, num_cav):
    """Fuse voxels of CAVs from the same frame.

    Column 0 of ``coords`` is read as a running CAV index over all frames.
    It is overwritten in place with the index of the frame the CAV belongs
    to. Indices outside ``[0, sum(num_cav))`` are left unchanged.

    :param coords: coordinates tensor whose first column holds the CAV index.
        It is modified in place.
    :param num_cav: list of number of CAVs for each frame
    :return: ``coords`` (the same tensor object) with fused batch indices.
    """
    # Compare against a snapshot of the original indices: matching on the
    # column that is being overwritten would let an already written frame
    # index be picked up again by a later frame (e.g. when a frame has 0 CAVs).
    cav_idx = coords[:, 0].clone()

    idx_start = 0
    for frame_idx, n_cav in enumerate(num_cav):
        idx_end = idx_start + n_cav
        in_frame = torch.logical_and(cav_idx >= idx_start, cav_idx < idx_end)
        coords[in_frame, 0] = frame_idx
        idx_start = idx_end
    return coords
def weighted_mahalanobis_dists(reg_evi, reg_var, dists, var0):
    """Weight Gaussian-like kernel values of distances by class evidence.

    For every class ``i`` a diagonal matrix is built from
    ``reg_var[:, i, :] + var0[i]`` and handed to torch's
    ``_batch_mahalanobis`` together with ``dists``. The result ``M`` is turned
    into ``exp(-0.5 * M)`` and multiplied with the evidence of that class.

    NOTE(review): the diagonal matrix of variances is passed directly as the
    scale (Cholesky-factor) argument of ``_batch_mahalanobis`` - confirm
    whether a square root of the variances was intended.

    :param reg_evi: evidences, indexed as (N, n_cls); the original code
        assumes n_cls == 2.
    :param reg_var: variances, indexed as (N, n_cls, n_dims).
    :param dists: distances forwarded to ``_batch_mahalanobis``; presumably
        (N, M, n_dims) - TODO confirm against callers.
    :param var0: per-class offset added to the variances, indexed by class.
    :returns: weighted kernel values; (N, M, n_cls) according to the original
        inline shape comments.
    """
    n_cls = reg_evi.shape[1]
    log_probs_list = []
    for i in range(n_cls):
        variances = reg_var[:, i, :] + var0[i]
        # Flatten to (N, n_dims) explicitly. A bare squeeze() would also drop
        # the batch dimension when N == 1 and make diag_embed(dim1=1) fail.
        variances = variances.reshape(-1, variances.shape[-1])
        covs = torch.diag_embed(variances, dim1=1)
        unbroadcasted_scale_tril = covs.unsqueeze(1)  # N 1 2 2

        # a.shape = (i, 1, n, n), b = (..., i, j, n),
        M = _batch_mahalanobis(unbroadcasted_scale_tril, dists)  # N M
        log_probs_list.append(-0.5 * M)

    probs = torch.stack(log_probs_list, dim=-1).exp()  # N M 2
    cls_evi = reg_evi.reshape(-1, 1, n_cls)  # N 1 2
    return probs * cls_evi
def draw_sample_prob(centers, reg, samples, res, distr_r, det_r, batch_size, var0):
    """Accumulate, for each sample, the weighted evidence of nearby centers.

    Centers are rasterized into a per-batch index map. Every sample is
    expanded by a square window of grid offsets, and each window cell that
    hits a center contributes that center's weighted kernel value (see
    :func:`weighted_mahalanobis_dists`). Contributions are summed per sample.

    :param centers: center coordinates; column 0 is used as batch index and
        the remaining columns as metric coordinates (presumably (N, 3) -
        TODO confirm).
    :param reg: per-center regression output; columns 0-1 are used as
        evidences and the remaining columns are viewed as (-1, 2, 2) variances.
    :param samples: sample coordinates with the same layout as ``centers``.
    :param res: grid resolution used to convert metric values to indices.
    :param distr_r: radius of the offset window around each sample.
    :param det_r: detection radius; the index map covers
        ``int(det_r / res) * 2`` cells per axis.
    :param batch_size: number of batch entries in the index map.
    :param var0: per-class variance offset forwarded to
        :func:`weighted_mahalanobis_dists`.
    :returns: tensor with one row of 2 accumulated evidences per sample.
    """
    # from utils.vislib import draw_points_boxes_plt
    # vis_ctrs = centers[centers[:, 0]==0, 1:].cpu().numpy()
    # vis_sams = samples[samples[:, 0]==0, 1:].cpu().numpy()
    #
    # ax = draw_points_boxes_plt(50, vis_ctrs, points_c='det_r', return_ax=True)
    # draw_points_boxes_plt(50, vis_sams, points_c='b', ax=ax)
    reg_evi = reg[:, :2]
    reg_var = reg[:, 2:].view(-1, 2, 2)

    # Index map: -1 for empty cells, otherwise the row index of the center.
    grid_size = int(det_r / res) * 2
    centers_map = torch.ones((batch_size, grid_size, grid_size),
                             device=reg.device).long() * -1
    ctridx = metric2indices(centers, res).T
    ctridx[1:] += int(grid_size / 2)
    centers_map[ctridx[0], ctridx[1], ctridx[2]] = torch.arange(ctridx.shape[1],
                                                                device=ctridx.device)

    steps = int(distr_r / res)
    # dim is passed by keyword: as the third positional argument the value
    # would bind to ``ymin`` instead of ``dim``.
    offset = meshgrid(-steps, steps, dim=2, n_steps=steps * 2 + 1).to(samples.device)  # s s 2
    samidx = metric2indices(samples, res).view(-1, 1, 3) \
             + pad_l(offset).view(1, -1, 3)  # n s*s 3
    samidx = samidx.view(-1, 3).T  # 3 n*s*s
    samidx[1:] = (samidx[1:] + (det_r / res))
    # mask1: window cells that fall inside the index map.
    mask1 = torch.logical_and((samidx[1:] >= 0).all(dim=0),
                              (samidx[1:] < (det_r / res * 2)).all(dim=0))
    inds = samidx[:, mask1].long()
    ctr_idx_of_sam = centers_map[inds[0], inds[1], inds[2]]
    # mask2: in-range window cells that actually contain a center.
    mask2 = ctr_idx_of_sam >= 0
    ctr_idx_of_sam = ctr_idx_of_sam[mask2]
    ns = offset.shape[0]**2
    new_samples = torch.tile(samples.unsqueeze(1), (1, ns, 1)).view(-1, 3)  # n*s*s 3
    new_centers = centers[ctr_idx_of_sam]
    dists_sam2ctr = new_samples[mask1][mask2][:, 1:] - new_centers[:, 1:]

    probs_weighted = weighted_mahalanobis_dists(
        reg_evi[ctr_idx_of_sam],
        reg_var[ctr_idx_of_sam],
        dists_sam2ctr.unsqueeze(1),
        var0=var0
    ).squeeze()

    sample_evis = torch.zeros_like(samidx[:2].T)
    # Combine both masks into one over all window cells.
    mask = mask1.clone()
    mask[mask1] = mask2
    sample_evis[mask] = probs_weighted
    sample_evis = sample_evis.view(-1, ns, 2).sum(dim=1)

    return sample_evis
def get_voxel_centers(voxel_coords, downsample_times, voxel_size, point_cloud_range):
    """Get centers of spconv voxels.

    The three index columns are reversed (the original comment labels the
    result as xyz, so the input is presumably ordered z, y, x). They are then
    shifted by half a voxel, scaled by the effective voxel size and offset by
    the lower bound of the point cloud range.

    :param voxel_coords: (N, 3) voxel indices.
    :param downsample_times: factor multiplied onto ``voxel_size``.
    :param voxel_size: voxel size per axis.
    :param point_cloud_range: range whose first three entries are used as
        the offset.
    :return: (N, 3) float tensor of voxel centers.
    """
    assert voxel_coords.shape[1] == 3
    device = voxel_coords.device

    xyz_indices = voxel_coords[:, [2, 1, 0]].float()  # (xyz)
    scaled_voxel_size = torch.tensor(voxel_size, device=device).float() * downsample_times
    range_min = torch.tensor(point_cloud_range[0:3], device=device).float()
    return (xyz_indices + 0.5) * scaled_voxel_size + range_min