# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection (https://github.com/open-mmlab/mmdetection)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
# Modified by Shihao Wang
# Modified by Yunshuang Yuan
# ------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import numpy as np
def ratio2coord(ratio, lidar_range):
    """Convert normalized ratios back into absolute coordinates.

    ``lidar_range`` is split into a lower corner (first three entries, the
    value reached at ratio 0) and an upper corner (last three entries, the
    value reached at ratio 1). Presumably ``[x_min, y_min, z_min, x_max,
    y_max, z_max]`` -- confirm with callers. Both ``ratio`` and
    ``lidar_range`` must support element-wise arithmetic (numpy arrays or
    torch tensors); plain lists are not supported.

    Args:
        ratio: Normalized coordinates, broadcastable against a 3-vector.
        lidar_range: 6-element array/tensor holding the lower and upper corners.

    Returns:
        ``lower + ratio * (upper - lower)``.
    """
    lower, upper = lidar_range[:3], lidar_range[3:]
    return ratio * (upper - lower) + lower
def coor2ratio(coor, lidar_range):
    """Normalize absolute coordinates into ratios relative to ``lidar_range``.

    This is the inverse of ``ratio2coord``. A coordinate equal to the lower
    corner ``lidar_range[:3]`` maps to 0, and one equal to the upper corner
    ``lidar_range[3:]`` maps to 1. Values outside the range are not clipped.

    Args:
        coor: Absolute coordinates, broadcastable against a 3-vector.
        lidar_range: 6-element array/tensor (lower corner, then upper corner).

    Returns:
        ``(coor - lower) / (upper - lower)``.
    """
    lower = lidar_range[:3]
    extent = lidar_range[3:] - lower
    return (coor - lower) / extent
def img_locations(img_size, feat_size=None, stride=None):
    """Return the normalized (x, y) centers of every feature-map cell.

    The image of size ``img_size = (H, W)`` pixels is tiled into cells of
    ``stride`` pixels. The center of each cell is expressed as a fraction
    of the image width (x) and height (y).

    Args:
        img_size: ``(H, W)`` of the input image in pixels.
        feat_size: Optional ``(h, w)`` of the feature map. If omitted it is
            derived as ``(H // stride, W // stride)``.
        stride: Optional cell size in pixels. If omitted it is derived from
            ``feat_size`` as ``H // h`` and reused for the width as well.
            If both ``feat_size`` and ``stride`` are given, both are used
            as-is.

    Returns:
        Float tensor of shape ``(h, w, 2)``. ``[..., 0]`` holds the normalized
        x coordinate and ``[..., 1]`` the normalized y coordinate.

    Raises:
        ValueError: If neither ``feat_size`` nor ``stride`` is provided.
    """
    H, W = img_size
    if feat_size is None:
        if stride is None:
            raise ValueError("img_locations requires either feat_size or stride")
        h, w = H // stride, W // stride
    else:
        # Previously, passing both feat_size and stride left h/w unassigned.
        h, w = feat_size
        if stride is None:
            # NOTE: stride is inferred from the height only and reused for x.
            stride = H // h

    # Cell centers in pixels (offset by half a stride), divided by image size.
    shifts_x = (torch.arange(0, stride * w, step=stride) + stride // 2) / W
    shifts_y = (torch.arange(0, stride * h, step=stride) + stride // 2) / H
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing='ij')
    # Stacking the (h, w) grids on the last dim yields (h, w, 2) directly.
    return torch.stack((shift_x, shift_y), dim=-1)
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    """Sinusoidal embedding of 3D positions.

    ``pos`` is scaled by 2*pi. That suggests inputs normalized to [0, 1],
    but confirm with callers. Components ``pos[..., 0]``, ``pos[..., 1]``
    and ``pos[..., 2]`` are each embedded into ``num_pos_feats`` channels.
    Within each component, sin (even frequencies) and cos (odd frequencies)
    are interleaved. ``num_pos_feats`` must be even for the sin/cos halves
    to stack.

    Args:
        pos: Tensor whose last dim holds at least 3 coordinates.
        num_pos_feats: Channels per coordinate.
        temperature: Base of the geometric frequency progression.

    Returns:
        Tensor of shape ``pos.shape[:-1] + (3 * num_pos_feats,)``, laid out
        as (y, x, z) embeddings concatenated in that order.
    """
    pos = pos * (2 * math.pi)
    idx = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    freqs = temperature ** (2 * torch.div(idx, 2, rounding_mode='floor') / num_pos_feats)

    def _embed(axis):
        phase = pos[..., axis, None] / freqs
        sin_cos = torch.stack((phase[..., 0::2].sin(), phase[..., 1::2].cos()), dim=-1)
        return sin_cos.flatten(-2)

    # Channel order is y, x, z -- keep it, downstream weights depend on it.
    return torch.cat([_embed(1), _embed(0), _embed(2)], dim=-1)
def pos2posemb2d(pos, num_pos_feats=128, temperature=10000):
    """Sinusoidal embedding of 2D positions.

    Each of ``pos[..., 0]`` (x) and ``pos[..., 1]`` (y) is scaled by 2*pi,
    divided by a geometric frequency table and expanded into
    ``num_pos_feats`` interleaved sin/cos channels. ``num_pos_feats`` must be
    even.

    Args:
        pos: Tensor whose last dim holds at least 2 coordinates.
        num_pos_feats: Channels per coordinate.
        temperature: Base of the geometric frequency progression.

    Returns:
        Tensor of shape ``pos.shape[:-1] + (2 * num_pos_feats,)``, with the
        y embedding first and the x embedding second.
    """
    scaled = pos * (2 * math.pi)
    idx = torch.arange(num_pos_feats, dtype=torch.float32, device=scaled.device)
    freqs = temperature ** (2 * torch.div(idx, 2, rounding_mode='floor') / num_pos_feats)

    parts = []
    for axis in (1, 0):  # y first, then x
        phase = scaled[..., axis, None] / freqs
        interleaved = torch.stack((phase[..., 0::2].sin(), phase[..., 1::2].cos()), dim=-1)
        parts.append(interleaved.flatten(-2))
    return torch.cat(parts, dim=-1)
def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    """Sinusoidal embedding of a scalar position.

    Only ``pos[..., 0]`` is used. It is scaled by 2*pi, divided by a
    geometric frequency table, and expanded into ``num_pos_feats``
    interleaved sin/cos channels. ``num_pos_feats`` must be even.

    Args:
        pos: Tensor whose last dim holds at least 1 value.
        num_pos_feats: Number of output channels.
        temperature: Base of the geometric frequency progression.

    Returns:
        Tensor of shape ``pos.shape[:-1] + (num_pos_feats,)``.
    """
    two_pi = 2 * math.pi
    scaled = pos * two_pi
    channel = torch.arange(num_pos_feats, dtype=torch.float32, device=scaled.device)
    # Channels 2k and 2k+1 share a frequency: temperature ** (2k / num_pos_feats).
    exponent = 2 * torch.div(channel, 2, rounding_mode='floor') / num_pos_feats
    freqs = temperature ** exponent
    phase = scaled[..., 0, None] / freqs
    sines = phase[..., 0::2].sin()
    cosines = phase[..., 1::2].cos()
    return torch.stack((sines, cosines), dim=-1).flatten(-2)
def nerf_positional_encoding(
    tensor: torch.Tensor,
    num_encoding_functions: int = 6,
    include_input: bool = False,
    log_sampling: bool = True,
) -> torch.Tensor:
    r"""Apply NeRF-style sinusoidal positional encoding to the input.

    For every frequency ``f`` the input is mapped to ``sin(tensor * f)`` and
    ``cos(tensor * f)``. All pieces are concatenated along the last dim.

    :param tensor: Input tensor to be positionally encoded.
    :param num_encoding_functions: Number of frequencies ``L`` used to
        compute the encoding (default: 6).
    :param include_input: Whether to prepend the raw input to the encoding
        (default: False).
    :param log_sampling: If True, frequencies are ``2**0, 2**1, ..., 2**(L-1)``.
        Otherwise they are ``L`` values linearly spaced between ``2**0`` and
        ``2**(L-1)`` (default: True).
    :return: Tensor whose last dim is
        ``(int(include_input) + 2 * L) * tensor.shape[-1]``. When the encoding
        consists only of the raw input (``L == 0`` and ``include_input``), the
        input tensor itself is returned unchanged.
    :raises ValueError: If ``L == 0`` and ``include_input`` is False, since
        there would be nothing to concatenate.
    """
    # Trivially, the input tensor is added to the positional encoding.
    encoding = [tensor] if include_input else []
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    for freq in frequency_bands:
        for func in (torch.sin, torch.cos):
            encoding.append(func(tensor * freq))

    if not encoding:
        # torch.cat([]) would fail with an opaque error; be explicit instead.
        raise ValueError(
            "nerf_positional_encoding needs num_encoding_functions > 0 "
            "or include_input=True"
        )
    # Special case, for no positional encoding: hand back the input itself.
    if len(encoding) == 1:
        return encoding[0]
    return torch.cat(encoding, dim=-1)