from importlib import import_module
import torch
from torch import nn
import numpy as np
from torch.distributions.multivariate_normal import _batch_mahalanobis
from cosense3d.modules.utils.me_utils import metric2indices
pi = 3.141592653
[docs]def clip_sigmoid(x: torch.Tensor, eps: float=1e-4) -> torch.Tensor:
"""Sigmoid function for input feature.
:param x: Input feature map with the shape of [B, N, H, W].
:param eps: Lower bound of the range to be clamped to.
Defaults to 1e-4.
:return: Feature map after sigmoid.
"""
y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
return y
def cat_name_str(module_name):
    """Convert an underscore-separated module name to CamelCase.

    :param module_name: str, format in xxx_yyy_zzz
    :returns: class_name: str, format in XxxYyyZzz
    """
    words = module_name.split('_')
    # upper-case only the first character; the rest of each word is preserved
    return ''.join(w[:1].upper() + w[1:] for w in words)
def instantiate(module_name, cls_name=None, module_cfg=None, **kwargs):
    """Import ``cosense3d.model.<module_name>`` and build an instance of its class.

    :param module_name: module file name in xxx_yyy format.
    :param cls_name: class name; derived from ``module_name`` via
        :func:`cat_name_str` when None.
    :param module_cfg: when given, passed as the single constructor argument;
        otherwise ``kwargs`` are forwarded to the constructor.
    :returns: the instantiated object.
    """
    if cls_name is None:
        cls_name = cat_name_str(module_name)
    package = import_module(f"cosense3d.model.{module_name}")
    obj_cls = getattr(package, cls_name)
    if module_cfg is None:
        return obj_cls(**kwargs)
    return obj_cls(module_cfg)
def bias_init_with_prob(prior_prob: float) -> float:
    """initialize conv/fc bias value according to a given probability value."""
    # inverse-sigmoid of the prior: a sigmoid over this bias yields prior_prob
    odds = (1 - prior_prob) / prior_prob
    return float(-np.log(odds))
def topk_gather(feat, topk_indexes):
    """Gather entries of ``feat`` along dim 1 using top-k indices.

    :param feat: (B, N, ...) feature tensor.
    :param topk_indexes: (B, K, ...) indices into dim 1, or None for a no-op.
    :returns: (B, K, ...) gathered features (``feat`` itself when indexes are None).
    """
    if topk_indexes is None:
        return feat
    trailing = feat.shape[2:]
    # broadcast the (B, K) index block over every trailing feature dim
    index = topk_indexes.view(*topk_indexes.shape[:2], *([1] * len(trailing)))
    return torch.gather(feat, 1, index.repeat(1, 1, *trailing))
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid (logit).

    :param x: (Tensor) values to invert; clamped into [0, 1] first.
    :param eps: (float) EPS avoiding log(0) / division overflow.
        Defaults 1e-5.
    :returns: Tensor: logit(x), same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)
def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Xavier-initialize ``module.weight`` and fill ``module.bias`` in place.

    :param module: layer whose parameters are initialized (skipped when absent).
    :param gain: scaling factor of the xavier scheme.
    :param bias: constant value written into the bias term.
    :param distribution: 'uniform' or 'normal'.
    """
    assert distribution in ['uniform', 'normal']
    if getattr(module, 'weight', None) is not None:
        init_fn = (nn.init.xavier_uniform_ if distribution == 'uniform'
                   else nn.init.xavier_normal_)
        init_fn(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def limit_period(val, offset=0.5, period=2 * pi):
    """Wrap values into the half-open window [-offset*period, (1-offset)*period).

    :param val: (Tensor) values (typically angles) to wrap.
    :param offset: fraction of the period used to shift the window.
    :param period: wrapping period, defaults to a full circle.
    :returns: Tensor of wrapped values, same shape as ``val``.
    """
    n_periods = torch.floor(val / period + offset)
    return val - n_periods * period
def get_conv2d_layers(conv_name, in_channels, out_channels, n_layers, kernel_size, stride,
                      padding, relu_last=True, sequential=True, **kwargs):
    """
    Build convolutional layers. kernel_size, stride and padding should be a list with the
    lengths that match n_layers
    """
    bias = kwargs.pop('bias', False)
    conv_cls = getattr(nn, conv_name)
    layers = []
    for i in range(n_layers):
        # remaining kwargs are per-layer lists, indexed like kernel_size etc.
        layer_kwargs = {key: val[i] for key, val in kwargs.items()}
        layers.append(conv_cls(in_channels, out_channels, kernel_size[i],
                               stride=stride[i], padding=padding[i],
                               bias=bias, **layer_kwargs))
        layers.append(nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01))
        # the trailing ReLU is skipped when relu_last is disabled
        if relu_last or i < n_layers - 1:
            layers.append(nn.ReLU())
        in_channels = out_channels
    return nn.Sequential(*layers) if sequential else layers
def get_norm_layer(channels, norm):
    """Create a 1D normalization layer.

    :param channels: number of normalized features.
    :param norm: 'LN' for LayerNorm or 'BN' for BatchNorm1d.
    :returns: the normalization module.
    :raises NotImplementedError: for any other ``norm`` string.
    """
    if norm == 'BN':
        return nn.BatchNorm1d(channels)
    if norm == 'LN':
        return nn.LayerNorm(channels)
    raise NotImplementedError
def linear_last(in_channels, mid_channels, out_channels, bias=False, norm='BN'):
    """Linear -> norm -> ReLU -> Linear projection head.

    :param in_channels: input feature size.
    :param mid_channels: hidden feature size.
    :param out_channels: output feature size.
    :param bias: whether the hidden linear layer uses a bias term.
    :param norm: normalization type, 'BN' or 'LN'.
    :returns: nn.Sequential implementing the head.
    """
    head = [nn.Linear(in_channels, mid_channels, bias=bias)]
    head.append(get_norm_layer(mid_channels, norm))
    head.append(nn.ReLU(inplace=True))
    head.append(nn.Linear(mid_channels, out_channels))
    return nn.Sequential(*head)
def linear_layers(in_out, activations=None, norm='BN'):
    """Stack Linear -> norm -> activation blocks.

    :param in_out: channel sizes; each consecutive pair defines one Linear.
    :param activations: None (ReLU everywhere), one activation-class name for
        all layers, or a list with one name per layer.
    :param norm: normalization type, 'BN' or 'LN'.
    :returns: nn.Sequential of the stacked blocks.
    """
    n_layers = len(in_out) - 1
    if activations is None:
        activations = ['ReLU'] * n_layers
    elif isinstance(activations, str):
        activations = [activations] * n_layers
    else:
        assert len(activations) == n_layers
    modules = []
    for act, c_in, c_out in zip(activations, in_out[:-1], in_out[1:]):
        modules.extend([
            nn.Linear(c_in, c_out, bias=False),
            get_norm_layer(c_out, norm),
            getattr(nn, act)(),
        ])
    return nn.Sequential(*modules)
def meshgrid(xmin, xmax, ymin=None, ymax=None, dim=2, n_steps=None, step=None):
    """Create a dense coordinate grid.

    When both ``ymin`` and ``ymax`` are given, a 2D grid spans
    [xmin, xmax] x [ymin, ymax]; otherwise the x-range is repeated over
    ``dim`` axes. Exactly one of ``n_steps`` (torch.linspace) or ``step``
    (torch.arange) must be provided.

    :param xmin: lower bound of the first (and default remaining) axes.
    :param xmax: upper bound of the first (and default remaining) axes.
    :param ymin: optional lower bound of the second axis (requires dim == 2).
    :param ymax: optional upper bound of the second axis (requires dim == 2).
    :param dim: number of grid dimensions, at most 3.
    :param n_steps: number of samples per axis.
    :param step: spacing between samples.
    :return: tensor of shape (*grid_shape, dim) with stacked coordinates.
    :raises NotImplementedError: if neither ``n_steps`` nor ``step`` is given.
    """
    assert dim <= 3, f'dim <= 3, but dim={dim} is given.'
    if ymin is not None and ymax is not None:
        assert dim == 2
        if n_steps is not None:
            axes = (torch.linspace(xmin, xmax, n_steps),
                    torch.linspace(ymin, ymax, n_steps))
        elif step is not None:
            axes = (torch.arange(xmin, xmax, step),
                    torch.arange(ymin, ymax, step))
        else:
            raise NotImplementedError
    else:
        # previous version re-tested ymin/ymax inside this branch, which is
        # unreachable here (the condition is False by construction) — removed.
        if n_steps is not None:
            x = torch.linspace(xmin, xmax, n_steps)
        elif step is not None:
            x = torch.arange(xmin, xmax, step)
        else:
            raise NotImplementedError
        axes = (x,) * dim
    # torch.meshgrid only accepts 'ij' or 'xy'; 'ij' matches the old 2D
    # behavior and also works for dim 1 and 3, where the previous
    # 'ijk'[:dim] string was rejected by torch.
    return torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
def meshgrid_cross(xmins, xmaxs, n_steps=None, steps=None):
    """Create a grid with a separate value range per axis.

    Axis i spans [xmins[i], xmaxs[i] + 1), sampled either with
    ``n_steps[i]`` points (torch.linspace) or at spacing ``steps[i]``
    (torch.arange).

    :param xmins: per-axis lower bounds.
    :param xmaxs: per-axis upper bounds (inclusive, hence the ``+ 1``).
    :param n_steps: per-axis sample counts, or None.
    :param steps: per-axis sample spacings, or None.
    :return: tensor of shape (*grid_shape, dim) with stacked coordinates.
    :raises NotImplementedError: if neither ``n_steps`` nor ``steps`` is given.
    """
    if n_steps is not None:
        assert len(xmins) == len(n_steps)
        axes = [torch.linspace(xmin, xmax + 1, nstp)
                for xmin, xmax, nstp in zip(xmins, xmaxs, n_steps)]
    elif steps is not None:
        axes = [torch.arange(xmin, xmax + 1, stp)
                for xmin, xmax, stp in zip(xmins, xmaxs, steps)]
    else:
        raise NotImplementedError
    # 'ij' indexing works for any number of axes; the previous 'ijk'[:dim]
    # produced a string torch.meshgrid rejects for three axes.
    return torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
def pad_r(tensor, value=0):
    """Append one constant entry along the last dimension.

    :param tensor: (..., C) input tensor.
    :param value: fill value of the appended column.
    :returns: (..., C + 1) tensor with ``value`` appended on the right.
    """
    pad = value * torch.ones_like(tensor[..., :1])
    return torch.cat((tensor, pad), dim=-1)
def pad_l(tensor, value=0):
    """Prepend one constant entry along the last dimension.

    :param tensor: (..., C) input tensor.
    :param value: fill value of the prepended column.
    :returns: (..., C + 1) tensor with ``value`` prepended on the left.
    """
    pad = value * torch.ones_like(tensor[..., :1])
    return torch.cat((pad, tensor), dim=-1)
def cat_coor_with_idx(tensor_list):
    """Concatenate coordinate tensors, prepending each tensor's list index
    as the first column.

    :param tensor_list: list of (Ni, C) coordinate tensors.
    :returns: (sum(Ni), C + 1) tensor whose column 0 holds the source index.
    """
    stamped = [pad_l(coords, idx) for idx, coords in enumerate(tensor_list)]
    return torch.cat(stamped, dim=0)
def fuse_batch_indices(coords, num_cav):
    """Fuse per-CAV batch indices into per-frame indices, in place.

    Voxels of all CAVs that belong to the same frame get the frame index
    written into column 0 of ``coords``.

    :param coords: (N, 1 + D) coordinates whose first column is the CAV
        batch index.
    :param num_cav: list of number of CAVs for each frame.
    :return: ``coords`` with fused (frame-level) batch indices.
    """
    offset = 0
    for frame, n in enumerate(num_cav):
        # CAV indices [offset, offset + n) belong to this frame; assigned
        # frame indices are always below any later frame's offset, so the
        # in-place update cannot be picked up by subsequent iterations.
        in_frame = (coords[:, 0] >= offset) & (coords[:, 0] < offset + n)
        coords[in_frame, 0] = frame
        offset += n
    return coords
def weighted_mahalanobis_dists(reg_evi, reg_var, dists, var0):
    """Per-class Gaussian likelihoods of sample-to-center offsets, weighted
    by the class evidence.

    :param reg_evi: (N, 2) class evidence scores.
    :param reg_var: (N, 2, 2) per-class diagonal variances.
    :param dists: (N, M, 2) offsets between samples and centers.
    :param var0: per-class base variance added to ``reg_var``.
    :return: (N, M, 2) evidence-weighted probabilities.
    """
    log_probs_list = []
    for i in range(reg_evi.shape[1]):
        variances = reg_var[:, i, :] + var0[i]  # (N, 2)
        # Build a diagonal scale_tril. Do NOT squeeze here: the previous
        # ``variances.squeeze()`` dropped the batch dim when N == 1 and
        # produced a wrongly shaped scale_tril.
        covs = torch.diag_embed(variances, dim1=1)  # (N, 2, 2)
        unbroadcasted_scale_tril = covs.unsqueeze(1)  # N 1 2 2
        # a.shape = (i, 1, n, n), b = (..., i, j, n),
        M = _batch_mahalanobis(unbroadcasted_scale_tril, dists)  # N M
        log_probs = -0.5 * M
        log_probs_list.append(log_probs)
    log_probs = torch.stack(log_probs_list, dim=-1)
    probs = log_probs.exp()  # N M 2
    cls_evi = reg_evi.view(-1, 1, 2)  # N 1 2
    probs_weighted = probs * cls_evi
    return probs_weighted
def draw_sample_prob(centers, reg, samples, res, distr_r, det_r, batch_size, var0):
    """Accumulate evidential probabilities at sample locations.

    For each sample point, center points within roughly ``distr_r`` are
    looked up on a dense (batch, grid, grid) index map, and their
    evidence-weighted Gaussian probabilities (see
    :func:`weighted_mahalanobis_dists`) are summed per sample.

    :param centers: (N, 3) center points; column 0 appears to be the batch
        index and columns 1:3 metric xy — TODO confirm against
        ``metric2indices``.
    :param reg: (N, 6) regression output; [:, :2] evidence, [:, 2:] variances.
    :param samples: (S, 3) query points with the same column layout.
    :param res: metric size of one grid cell.
    :param distr_r: metric radius around a sample in which centers contribute.
    :param det_r: detection range; the grid spans [-det_r, det_r].
    :param batch_size: number of frames in the batch.
    :param var0: per-class base variances for the mahalanobis weighting.
    :return: (S, 2) accumulated evidence per sample.
    """
    reg_evi = reg[:, :2]
    reg_var = reg[:, 2:].view(-1, 2, 2)
    grid_size = int(det_r / res) * 2
    # dense lookup map: (batch, y, x) cell -> row index of the center there,
    # -1 where no center falls into the cell
    centers_map = torch.ones((batch_size, grid_size, grid_size),
                             device=reg.device).long() * -1
    ctridx = metric2indices(centers, res).T
    ctridx[1:] += int(grid_size / 2)  # shift metric indices to start at 0
    centers_map[ctridx[0], ctridx[1], ctridx[2]] = torch.arange(ctridx.shape[1],
                                                                device=ctridx.device)
    steps = int(distr_r / res)
    # NOTE(review): the positional ``2`` binds to meshgrid's ``ymin``, not
    # ``dim``; it still yields a 2D grid because ``ymax`` stays None —
    # presumably ``dim=2`` was intended. TODO confirm.
    offset = meshgrid(-steps, steps, 2, n_steps=steps * 2 + 1).to(samples.device)  # s s 2
    # every sample spawns s*s neighbor cells: (sample idx kept in column 0)
    samidx = metric2indices(samples, res).view(-1, 1, 3) \
             + pad_l(offset).view(1, -1, 3)  # n s*s 3
    samidx = samidx.view(-1, 3).T  # 3 n*s*s
    samidx[1:] = (samidx[1:] + (det_r / res))  # shift grid indices to >= 0
    # mask1: neighbor cells that lie inside the grid
    mask1 = torch.logical_and((samidx[1:] >= 0).all(dim=0),
                              (samidx[1:] < (det_r / res * 2)).all(dim=0))
    inds = samidx[:, mask1].long()
    ctr_idx_of_sam = centers_map[inds[0], inds[1], inds[2]]
    # mask2: in-grid cells that actually host a center
    mask2 = ctr_idx_of_sam >= 0
    ctr_idx_of_sam = ctr_idx_of_sam[mask2]
    ns = offset.shape[0]**2
    new_samples = torch.tile(samples.unsqueeze(1),
                             (1, ns, 1)).view(-1, 3)  # n*s*s 3
    new_centers = centers[ctr_idx_of_sam]
    dists_sam2ctr = new_samples[mask1][mask2][:, 1:] - new_centers[:, 1:]
    probs_weighted = weighted_mahalanobis_dists(
        reg_evi[ctr_idx_of_sam],
        reg_var[ctr_idx_of_sam],
        dists_sam2ctr.unsqueeze(1),
        var0=var0
    ).squeeze()
    # scatter the probabilities back onto the full (n*s*s, 2) buffer;
    # NOTE(review): dtype of the buffer follows samidx — assumes
    # metric2indices returns a float tensor; verify.
    sample_evis = torch.zeros_like(samidx[:2].T)
    mask = mask1.clone()
    mask[mask1] = mask2
    sample_evis[mask] = probs_weighted
    # sum the contributions of the s*s neighborhood of each sample
    sample_evis = sample_evis.view(-1, ns, 2).sum(dim=1)
    return sample_evis
def get_voxel_centers(voxel_coords,
                      downsample_times,
                      voxel_size,
                      point_cloud_range):
    """Get metric centers of spconv voxels.

    :param voxel_coords: (N, 3) integer voxel coordinates in (z, y, x) order.
    :param downsample_times: factor the voxel grid was downsampled by.
    :param voxel_size: base (x, y, z) voxel size.
    :param point_cloud_range: range whose first three entries are the
        (x, y, z) minimum of the point cloud.
    :return: (N, 3) voxel centers in metric (x, y, z) space.
    """
    assert voxel_coords.shape[1] == 3
    centers_xyz = voxel_coords[:, [2, 1, 0]].float()  # zyx -> xyz
    scale = torch.tensor(voxel_size, device=centers_xyz.device).float() * downsample_times
    origin = torch.tensor(point_cloud_range[0:3], device=centers_xyz.device).float()
    return (centers_xyz + 0.5) * scale + origin