Source code for cosense3d.modules.utils.gaussian_utils

from typing import List
import torch
from torch.distributions.multivariate_normal import _batch_mahalanobis
import torch_scatter
import numpy as np


def weighted_mahalanobis_dists(vars, dists, weights=None):
    """Evaluate unnormalized 2D Gaussian weights ``exp(-0.5 * M)``.

    ``M`` is the value returned by torch's ``_batch_mahalanobis`` for the
    diagonal matrix built from ``vars`` and the offsets ``dists``.

    NOTE(review): the diagonal matrix of ``vars`` is passed to
    ``_batch_mahalanobis`` in the position of a Cholesky factor, so the values
    effectively act as per-axis standard deviations rather than variances.
    This is kept unchanged; confirm against callers before altering it.

    :param vars: (N,), (N, 1) or (N, 2) spread parameters of the Gaussians.
        A single value per sample is used for both axes.
    :param dists: offsets to the Gaussian centers with a trailing dim of 2,
        e.g. (N, M, 2); its leading dims must be compatible with (N, 1).
    :param weights: optional weights multiplied onto the output.
    :return: ``exp(-0.5 * M)`` (optionally weighted), e.g. (N, M) for
        ``dists`` of shape (N, M, 2). Note this is not the squared distance
        itself.
    """
    squeezed = vars.squeeze()
    if vars.dim() >= 2 and vars.shape[-1] == 2 and vars.numel() == 2:
        # A single sample with per-axis values, shape (1, 2). A blanket
        # squeeze would turn it into (2,) and mis-read it as two isotropic
        # samples, so keep the leading sample dimension explicitly.
        vars = vars.reshape(1, 2)
    elif squeezed.dim() <= 1:
        # One value per sample: (N,), (N, 1), or a single scalar-like entry
        # that squeezes to 0-d. Use the same value for both axes.
        flat = squeezed.reshape(-1)
        vars = torch.stack([flat, flat], dim=-1)
    else:
        vars = squeezed

    covs = torch.diag_embed(vars, dim1=1)
    unbroadcasted_scale_tril = covs.unsqueeze(1)  # N 1 2 2

    # _batch_mahalanobis expects a.shape = (i, 1, n, n), b.shape = (..., i, j, n)
    M = _batch_mahalanobis(unbroadcasted_scale_tril, dists)  # N M
    probs = (-0.5 * M).exp()
    if weights is not None:
        probs = probs * weights

    return probs
def mahalanobis_dists_2d(sigmas, dists):
    """Return ``-0.5`` times torch's batched squared Mahalanobis distance.

    The squared ``sigmas`` are placed on the diagonal of a 2x2 matrix per
    Gaussian and handed, together with ``dists``, to ``_batch_mahalanobis``.

    NOTE(review): the matrix of squared sigmas is passed in the position of a
    Cholesky factor; behaviour is kept as-is, confirm the intended scaling.

    :param sigmas: (N, 2), standard deviation of Gaussian distribution
        (presumably N = 1 in practice, so that the batch shapes line up
        after ``dists.unsqueeze(0)`` — TODO confirm).
    :param dists: offsets to the gaussian center with a trailing dim of 2;
        a leading singleton dim is added before the distance computation.
    :return: ``-0.5 * M`` where ``M`` is the value from ``_batch_mahalanobis``.
    """
    scale = torch.diag_embed(sigmas ** 2, dim1=1).unsqueeze(1)
    squared = _batch_mahalanobis(scale, dists.unsqueeze(0))
    return -0.5 * squared
def center_to_img_coor(center_in, lidar_range, pixel_sz):
    """Convert metric x/y centers to (fractional) map-pixel coordinates.

    :param center_in: (N, >=2) tensor; columns 0 and 1 are read as x and y.
    :param lidar_range: sequence indexed at 0/1 for the lower x/y bounds and
        at 3/4 for the upper x/y bounds.
    :param pixel_sz: pixel size, either a single scalar used for both axes or
        a per-axis sequence/array ``(size_x, size_y)``.
    :return: (N, 2) tensor of pixel coordinates, clamped to
        ``[0, map_size - 0.5]`` on each axis.
    """
    # Accept both a scalar and a per-axis pixel size: other functions in this
    # module index ``pixel_sz[0]`` / ``pixel_sz[1]`` on the same argument.
    if isinstance(pixel_sz, (list, tuple)) or getattr(pixel_sz, "ndim", 0) > 0:
        sz_x = pixel_sz[0]
        sz_y = pixel_sz[1] if len(pixel_sz) > 1 else pixel_sz[0]
    else:
        sz_x = sz_y = pixel_sz

    x, y = center_in[:, 0], center_in[:, 1]
    coord_x = (x - lidar_range[0]) / sz_x
    coord_y = (y - lidar_range[1]) / sz_y
    map_sz_x = (lidar_range[3] - lidar_range[0]) / sz_x
    map_sz_y = (lidar_range[4] - lidar_range[1]) / sz_y

    # clamp to fit image size: 1e-6 does not work for center.int()
    coord_x = torch.clamp(coord_x, min=0, max=map_sz_x - 0.5)
    coord_y = torch.clamp(coord_y, min=0, max=map_sz_y - 0.5)
    return torch.stack((coord_x, coord_y), dim=-1)
def cornernet_gaussian_radius(height, width, min_overlap=0.5):
    """Return the element-wise minimum of three candidate Gaussian radii.

    Each candidate comes from a quadratic with coefficients ``(a, b, c)``
    built from the box height/width and ``min_overlap`` and is evaluated as
    ``(b + sqrt(b**2 - 4*a*c)) / 2``.

    NOTE(review): the candidates are divided by 2, not by ``2 * a``; this is
    kept exactly as in the original.

    :param height: tensor of box heights.
    :param width: tensor of box widths (same shape as ``height``).
    :param min_overlap: overlap parameter used in the three quadratics.
    :return: tensor of radii with the same shape as the inputs.
    """

    def _candidate(a, b, c):
        # Shared closed form used for all three cases.
        return (b + (b ** 2 - 4 * a * c).sqrt()) / 2

    radius_1 = _candidate(
        1,
        (height + width),
        width * height * (1 - min_overlap) / (1 + min_overlap),
    )
    radius_2 = _candidate(
        4,
        2 * (height + width),
        (1 - min_overlap) * width * height,
    )
    radius_3 = _candidate(
        4 * min_overlap,
        -2 * min_overlap * (height + width),
        (min_overlap - 1) * width * height,
    )

    smallest = torch.min(radius_1, radius_2)
    return torch.min(smallest, radius_3)
def gaussian_radius(box_dims, pixel_sz, overlap, min_radius=2):
    """Compute integer Gaussian radii (in pixels) for a set of boxes.

    :param box_dims: (N, >=2) tensor; columns 0 and 1 are the box extents
        that get divided by ``pixel_sz[0]`` and ``pixel_sz[1]``.
    :param pixel_sz: per-axis pixel size, indexed at 0 and 1.
    :param overlap: forwarded as ``min_overlap`` to
        ``cornernet_gaussian_radius``.
    :param min_radius: lower bound applied to the truncated radii.
    :return: (N,) integer tensor of radii, each at least ``min_radius``.
    """
    extent_x = box_dims[:, 0] / pixel_sz[0]
    extent_y = box_dims[:, 1] / pixel_sz[1]
    raw_radius = cornernet_gaussian_radius(extent_x, extent_y, min_overlap=overlap)
    # Truncate to whole pixels first, then enforce the minimum radius.
    truncated = raw_radius.int()
    return torch.clamp_min(truncated, min=min_radius)
def gaussian_2d(shape: List[int], sigma: float=1.0) -> np.ndarray:
    """Build a 2D Gaussian kernel centered in an array of the given shape.

    :param shape: Two-element shape (rows, cols) of the map.
    :param sigma: Standard deviation of the Gaussian. Defaults to 1.
    :return: Array of the requested shape with peak value 1 at the center;
        entries below machine epsilon times the peak are set to 0.
    """
    half_rows, half_cols = ((extent - 1.) / 2. for extent in shape)
    y, x = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]

    heatmap = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    # Zero out numerically negligible tails.
    cutoff = np.finfo(heatmap.dtype).eps * heatmap.max()
    heatmap[heatmap < cutoff] = 0
    return heatmap
def draw_gaussian_map(boxes, lidar_range, pixel_sz, batch_size, radius=None, sigma=1, min_radius=2):
    """Rasterize Gaussian blobs around box centers into per-batch 2D maps.

    :param boxes: (N, >=3) tensor; column 0 is read as the batch index and
        columns 1:3 as the center x/y passed to ``center_to_img_coor``.
    :param lidar_range: sequence indexed at 0/1 (lower x/y) and 3/4 (upper x/y).
    :param pixel_sz: per-axis pixel size, indexed at 0 and 1 for the map size
        and forwarded unchanged to ``center_to_img_coor``.
    :param batch_size: number of maps to allocate.
    :param radius: optional (N,) tensor of per-box radii in pixels; defaults
        to 2 for every box.
    :param sigma: currently unused by the computation; kept for interface
        compatibility.
    :param min_radius: radii are divided by this value before being passed to
        ``weighted_mahalanobis_dists`` as the spread parameter.
    :return: (batch_size, size_x, size_y) tensor; each cell holds the maximum
        Gaussian weight over all boxes that cover it.
    """
    size_x = int((lidar_range[3] - lidar_range[0]) // pixel_sz[0])
    size_y = int((lidar_range[4] - lidar_range[1]) // pixel_sz[1])
    if boxes.shape[0] == 0:
        return torch.zeros(batch_size, size_x, size_y, device=boxes.device)
    if radius is None:
        radius = torch.ones_like(boxes[:, 0]) * 2
    # Use plain Python numbers for the sampling window: ``torch.linspace``
    # needs an integer step count, which a (possibly float) 0-d tensor is not.
    radius_max = radius.max().item()
    center = center_to_img_coor(boxes[:, 1:3], lidar_range, pixel_sz)
    ctridx = center.int()
    device = center.device

    # Sample a disc of integer offsets (radius ``radius_max``) around each center.
    steps = int(radius_max * 2 + 1)
    x = torch.linspace(-radius_max, radius_max, steps, device=device)
    offsets = torch.stack(torch.meshgrid(x, x, indexing='ij'), dim=-1)
    offsets = offsets[torch.norm(offsets, dim=-1) <= radius_max]
    samples = ctridx.unsqueeze(1) + offsets.view(1, -1, 2)
    n_offsets = samples.shape[1]
    ind = torch.tile(boxes[:, 0].unsqueeze(1), (1, n_offsets)).unsqueeze(-1)
    samples = torch.cat([ind, samples], dim=-1)
    # Created on the same device as the mask that indexes it below.
    ctr_idx_of_sam = torch.arange(len(center), device=device).unsqueeze(1).tile(1, n_offsets)

    # Keep only the samples that fall inside the map.
    mask = (samples[..., 1] >= 0) & (samples[..., 1] < size_x) & \
           (samples[..., 2] >= 0) & (samples[..., 2] < size_y)
    if not mask.any():
        # Nothing to draw; avoids evaluating the Gaussians on an empty batch.
        return torch.zeros(batch_size, size_x, size_y, device=device)

    valid_ctr_idx = ctr_idx_of_sam[mask]
    new_center = center[valid_ctr_idx]
    new_vars = 1 / min_radius * radius[valid_ctr_idx].float()
    new_samples = samples[mask]
    dists_sam2ctr = new_samples[:, 1:].float() - new_center

    # reshape(-1) instead of squeeze() so a single sample stays 1-D.
    probs = weighted_mahalanobis_dists(
        new_vars,
        dists_sam2ctr.unsqueeze(1),
    ).reshape(-1)
    probs[probs < torch.finfo(probs.dtype).eps * probs.max()] = 0

    # Flatten (batch, ix, iy) for a contiguous (B, size_x, size_y) layout: the
    # stride of ix is size_y. Integer arithmetic avoids float rounding on
    # large maps.
    cells = new_samples.long()
    indices = cells[:, 0] * (size_x * size_y) + cells[:, 1] * size_y + cells[:, 2]

    center_map = torch.zeros(batch_size * size_x * size_y, device=device)
    torch_scatter.scatter(probs, indices, dim=0, out=center_map, reduce='max')
    center_map = center_map.view(batch_size, size_x, size_y)

    return center_map