Source code for cosense3d.modules.utils.positional_encoding

# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection (https://github.com/open-mmlab/mmdetection)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
#  Modified by Shihao Wang
#  Modified by Yunshuang Yuan
# ------------------------------------------------------------------------
import math
import torch
import torch.nn as nn 
import numpy as np


def ratio2coord(ratio, lidar_range):
    """Map normalized ratios in [0, 1] back to metric coordinates inside the lidar range."""
    return ratio * (lidar_range[3:] - lidar_range[:3]) + lidar_range[:3]


def coor2ratio(coor, lidar_range):
    """Map metric coordinates to normalized ratios in [0, 1] w.r.t. the lidar range."""
    return (coor - lidar_range[:3]) / (lidar_range[3:] - lidar_range[:3])
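
# A minimal round-trip sketch (illustration only, not part of the original
# module). `lidar_range` is assumed to be a 6-element tensor
# [x_min, y_min, z_min, x_max, y_max, z_max]:
#
#   lidar_range = torch.tensor([-50., -50., -3., 50., 50., 1.])
#   coor = torch.tensor([[10., -20., 0.]])
#   ratio = coor2ratio(coor, lidar_range)   # values in [0, 1]
#   assert torch.allclose(ratio2coord(ratio, lidar_range), coor)
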
def img_locations(img_size, feat_size=None, stride=None):
    """Compute normalized (x, y) centers of feature-map cells on the image plane.

    Exactly one of ``feat_size`` and ``stride`` is expected; the missing one
    is derived from the image size.
    """
    H, W = img_size
    if feat_size is None:
        assert stride is not None
        h, w = H // stride, W // stride
    elif stride is None:
        h, w = feat_size
        stride = H // h
    # Cell-center pixel offsets, normalized by image width / height.
    shifts_x = (torch.arange(
        0, stride * w, step=stride,
        dtype=torch.float32
    ) + stride // 2) / W
    shifts_y = (torch.arange(
        0, h * stride, step=stride,
        dtype=torch.float32
    ) + stride // 2) / H
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing='ij')
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    coors = torch.stack((shift_x, shift_y), dim=1)
    coors = coors.reshape(h, w, 2)
    return coors
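
# Usage sketch (illustration only): for a 608 x 608 image at stride 8 this
# yields a (76, 76, 2) grid of normalized (x, y) cell-center coordinates:
#
#   coors = img_locations((608, 608), stride=8)   # coors.shape == (76, 76, 2)
#   # coors[..., 0] are x-ratios, coors[..., 1] are y-ratios, both in (0, 1).
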
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    """Sinusoidal embedding of normalized 3D positions; output dim is 3 * num_pos_feats."""
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_z = pos[..., 2, None] / dim_t
    # Interleave sin/cos over the frequency dimension for each axis.
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)
    posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)
    return posemb
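
# Usage sketch (illustration only): positions normalized to [0, 1] map to a
# 3 * num_pos_feats embedding:
#
#   pos = torch.rand(900, 3)          # e.g. 900 reference points
#   emb = pos2posemb3d(pos)           # emb.shape == (900, 384)
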
def pos2posemb2d(pos, num_pos_feats=128, temperature=10000):
    """Sinusoidal embedding of normalized 2D positions; output dim is 2 * num_pos_feats."""
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    posemb = torch.cat((pos_y, pos_x), dim=-1)
    return posemb


def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    """Sinusoidal embedding of normalized 1D positions; output dim is num_pos_feats."""
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    return pos_x


def nerf_positional_encoding(
        tensor: torch.Tensor,
        num_encoding_functions: int = 6,
        include_input: bool = False,
        log_sampling: bool = True
) -> torch.Tensor:
    r"""Apply NeRF-style positional encoding to the input.

    :param tensor: Input tensor to be positionally encoded.
    :param num_encoding_functions: Number of encoding functions used to
        compute a positional encoding (default: 6).
    :param include_input: Whether or not to include the input in the
        positional encoding (default: False).
    :param log_sampling: Whether to sample the frequency bands in log space
        rather than linearly (default: True).
    :return: Positional encoding of the input tensor.
    """
    # Optionally prepend the raw input to the encoding.
    encoding = [tensor] if include_input else []
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    for freq in frequency_bands:
        for func in [torch.sin, torch.cos]:
            encoding.append(func(tensor * freq))

    # Special case: a single element needs no concatenation.
    if len(encoding) == 1:
        return encoding[0]
    else:
        return torch.cat(encoding, dim=-1)
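
# Usage sketch (illustration only): with the defaults, a D-dimensional input
# maps to 2 * num_encoding_functions * D features (plus D if include_input):
#
#   x = torch.rand(1024, 3)
#   enc = nerf_positional_encoding(x)                       # (1024, 36)
#   enc = nerf_positional_encoding(x, include_input=True)   # (1024, 39)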