import copy
import math
import torch
from cosense3d.ops.utils import points_in_boxes_gpu
[docs]def build_box_coder(type, **kwargs):
return globals()[type](**kwargs)
[docs]class ResidualBoxCoder(object):
def __init__(self, mode: str='simple_dist'):
"""
:param mode: str, simple_dist | sin_cos_dist | compass_rose
"""
self.mode = mode
if mode == 'simple_dist':
self.code_size = 7
elif mode == 'sin_cos_dist':
self.code_size = 8
elif mode == 'compass_rose':
self.code_size = 10
self.cls_code_size = 2
else:
raise NotImplementedError
[docs] def encode_direction(self, ra, rg):
if self.mode == 'simple_dist':
reg = (rg - ra).view(-1, 1)
return reg, None
elif self.mode == 'sin_cos_dist':
rgx = torch.cos(rg)
rgy = torch.sin(rg)
rax = torch.cos(ra)
ray = torch.sin(ra)
rtx = rgx - rax
rty = rgy - ray
ret = [rtx, rty]
reg = torch.stack(ret, dim=-1) # N 2
return reg, None
elif self.mode == 'compass_rose':
# encode box directions
rgx = torch.cos(rg).view(-1, 1) # N 1
rgy = torch.sin(rg).view(-1, 1) # N 1
ra_ext = torch.cat([ra, ra + math.pi], dim=-1) # N 2, invert
rax = torch.cos(ra_ext) # N 2
ray = torch.sin(ra_ext) # N 2
# cos(a - b) = cos(a)cos(b) + sin(a)sin(b)
# we use arccos instead of a-b to control the difference in 0-pi
diff_angle = torch.arccos(rax * rgx + ray * rgy) # N 2
dir_score = 1 - diff_angle / math.pi # N 2
rtx = rgx - rax # N 2
rty = rgy - ray # N 2
dir_score = dir_score # N 2
ret = [rtx, rty]
reg = torch.cat(ret, dim=-1) # N 4
return reg, dir_score
else:
raise NotImplementedError
[docs] def decode_direction(self, ra, vt, dir_scores=None):
if self.mode == 'simple_dist':
rg = vt + ra
return rg
elif self.mode == 'sin_cos_dist':
rax = torch.cos(ra)
ray = torch.sin(ra)
va = torch.cat([rax, ray], dim=-1)
vg = vt + va
rg = torch.atan2(vg[..., 1], vg[..., 0])
return rg
elif self.mode == 'compass_rose':
ra_ext = torch.cat([ra, ra + math.pi], dim=-1) # N 2, invert
rax = torch.cos(ra_ext) # N 2
ray = torch.sin(ra_ext) # N 2
va = torch.cat([rax, ray], dim=-1)
vg = vt + va
rg = torch.atan2(vg[..., 2:], vg[..., :2]).view(-1, 2)
dirs = torch.argmax(dir_scores, dim=-1).view(-1)
rg = rg[torch.arange(len(rg)), dirs].view(len(vt), -1, 1)
return rg
else:
raise NotImplementedError
[docs] def encode(self, anchors, boxes):
xa, ya, za, la, wa, ha, ra = torch.split(anchors, 1, dim=-1)
xg, yg, zg, lg, wg, hg, rg = torch.split(boxes, 1, dim=-1)
diagonal = torch.sqrt(la ** 2 + wa ** 2)
xt = (xg - xa) / diagonal
yt = (yg - ya) / diagonal
zt = (zg - za) / ha
lt = torch.log(lg / la)
wt = torch.log(wg / wa)
ht = torch.log(hg / ha)
reg_dir, dir_score = self.encode_direction(ra, rg)
ret = [xt, yt, zt, lt, wt, ht, reg_dir]
reg = torch.cat(ret, dim=1) # N 6+4
return reg, dir_score
[docs] def decode(self, anchors, boxes_enc, dir_scores=None):
xa, ya, za, la, wa, ha, ra = torch.split(anchors, 1, dim=-1)
xt, yt, zt, lt, wt, ht = torch.split(boxes_enc[..., :6], 1, dim=-1)
vt = boxes_enc[..., 6:]
diagonal = torch.sqrt(la ** 2 + wa ** 2)
xg = xt * diagonal + xa
yg = yt * diagonal + ya
zg = zt * ha + za
lg = torch.exp(lt) * la
wg = torch.exp(wt) * wa
hg = torch.exp(ht) * ha
rg = self.decode_direction(ra, vt, dir_scores)
return torch.cat([xg, yg, zg, lg, wg, hg, rg], dim=-1)
[docs]class CenterBoxCoder(object):
def __init__(self, with_velo=False, with_pred=False, reg_radius=1.6, z_offset=1.0):
self.with_velo = with_velo
self.with_pred = with_pred
self.reg_radius = reg_radius
self.z_offset = z_offset
self.pred_max_offset = 2.0 + reg_radius
[docs] def encode(self, centers, gt_boxes, meter_per_pixel, gt_preds=None):
"""
:param centers: (N, 3)
:param gt_boxes: (N, 8) [batch_idx, x, y, z, l, w, h, r]
:param meter_per_pixel: tuple with 2 elements
:param gt_preds:
:return:
"""
if isinstance(meter_per_pixel, list):
assert meter_per_pixel[0] == meter_per_pixel[1], 'only support unified pixel size for x and y'
# TODO: adapt meter per pixel
meter_per_pixel = meter_per_pixel[0]
if len(gt_boxes) == 0:
valid = torch.zeros_like(centers[:, 0]).bool()
res = None, None, None, valid
if self.with_velo:
res = res + (None,)
return res
# match centers and gt_boxes
dist_ctr_to_box = torch.norm(centers[:, 1:3].unsqueeze(1)
- gt_boxes[:, 1:3].unsqueeze(0), dim=-1)
cc, bb = torch.meshgrid(centers[:, 0], gt_boxes[:, 0], indexing='ij')
dist_ctr_to_box[cc != bb] = 1000
min_dists, box_idx_of_pts = dist_ctr_to_box.min(dim=1)
diagnal = torch.norm(gt_boxes[:, 4:6].mean(dim=0) / 2)
valid = min_dists < max(diagnal, meter_per_pixel[0])
# valid = min_dists < self.reg_radius
valid_center, valid_box = centers[valid], gt_boxes[box_idx_of_pts[valid]]
valid_pred = None
if self.with_pred and gt_preds is not None:
valid_pred = gt_preds[box_idx_of_pts[valid]]
xc, yc = torch.split(valid_center[:, 1:3], 1, dim=-1)
xg, yg, zg, lg, wg, hg, rg = torch.split(valid_box[:, 1:8], 1, dim=-1)
xt = xg - xc
yt = yg - yc
zt = zg # + self.z_offset
lt = torch.log(lg)
wt = torch.log(wg)
ht = torch.log(hg)
# encode box directions
rgx = torch.cos(rg).view(-1, 1) # N 1
rgy = torch.sin(rg).view(-1, 1) # N 1
ra = torch.arange(0, 2, 0.5).to(xc.device) * math.pi
ra_ext = torch.ones_like(valid_box[:, :4]) * ra.view(-1, 4) # N 4
rax = torch.cos(ra_ext) # N 4
ray = torch.sin(ra_ext) # N 4
# cos(a - b) = cos(a)cos(b) + sin(a)sin(b)
# we use arccos instead of a-b to control the difference in 0-pi
diff_angle = torch.arccos(rax * rgx + ray * rgy) # N 4
dir_score = 1 - diff_angle / math.pi # N 4
rtx = rgx - rax # N 4
rty = rgy - ray # N 4
reg_box = torch.cat([xt, yt, zt, lt, wt, ht], dim=1) # N 6
reg_dir = torch.cat([rtx, rty], dim=1) # N 8
# reg_box[..., :3] /= self.reg_radius
res = (reg_box, reg_dir, dir_score, valid)
if self.with_velo:
res = res + (valid_box[:, 8:10],)
elif valid_box.shape[-1] > 8:
res = res + (valid_box[:, 8:10],)
if self.with_pred and valid_pred is not None:
prev_angles = valid_box[:, 7:8]
preds_tgt = []
mask = []
for i, boxes in enumerate(valid_pred.transpose(1, 0)):
# some gt_boxes do not have gt successors, zero padded virtual successors are used to align the number
# of boxes between gt_boxes and gt_preds, when calculate preds loss, these boxes should be ignored.
mask.append(boxes.any(dim=-1, keepdim=True).float())
diff_xy = (boxes[:, :2] - valid_center[:, 1:3]) # / self.pred_max_offset
diff_z = boxes[:, 2:3] # + self.z_offset
diff_cos = torch.cos(boxes[:, 3:]) - torch.cos(prev_angles)
diff_sin = torch.sin(boxes[:, 3:]) - torch.sin(prev_angles)
preds_tgt.append(torch.cat([diff_xy, diff_z, diff_cos, diff_sin], dim=-1) / (i + 2))
preds_tgt = torch.cat(preds_tgt, dim=-1)
mask = torch.cat(mask, dim=-1).all(dim=-1, keepdim=True)
res = res + (torch.cat([mask, preds_tgt], dim=-1),)
return res
[docs] def decode(self, centers, reg):
"""
:param centers: Tensor (N, 3) or (B, N, 2+).
:param reg: dict,
box - (N, 6) or (B, N, 6)
dir - (N, 8) or (B, N, 8)
scr - (N, 4) or (B, N, 4)
vel - (N, 2) or (B, N, 2), optional
pred - (N, 5) or (B, N, 5), optional
:return: decoded bboxes.
"""
if centers.ndim > 2:
xc, yc = torch.split(centers[..., 0:2], 1, dim=-1)
else:
xc, yc = torch.split(centers[..., 1:3], 1, dim=-1)
# reg['box'][..., :3] *= self.reg_radius
xt, yt, zt, lt, wt, ht = torch.split(reg['box'], 1, dim=-1)
xo = xt + xc
yo = yt + yc
zo = zt #- self.z_offset
lo = torch.exp(lt)
wo = torch.exp(wt)
ho = torch.exp(ht)
# decode box directions
scr_max, max_idx = reg['scr'].max(dim=-1)
shape = max_idx.shape
max_idx = max_idx.view(-1)
ii = torch.arange(len(max_idx))
ra = max_idx.float() * 0.5 * math.pi
ct = reg['dir'][..., :4].view(-1, 4)[ii, max_idx] + torch.cos(ra)
st = reg['dir'][..., 4:].view(-1, 4)[ii, max_idx] + torch.sin(ra)
ro = torch.atan2(st.view(*shape), ct.view(*shape)).unsqueeze(-1)
if centers.ndim > 2:
# dense tensor
ret = torch.cat([xo, yo, zo, lo, wo, ho, ro], dim=-1)
else:
# sparse tensor with batch indices
ret = torch.cat([centers[..., :1], xo, yo, zo, lo, wo, ho, ro], dim=-1)
if self.with_velo:
ret = torch.cat([ret, reg['vel']], dim=-1)
if self.with_pred:
pred = reg['pred'].clone()
b, n, c = pred.shape
pred_len = c // 5
mul = torch.arange(1, pred_len + 1, device=pred.device, dtype=pred.dtype)
pred = pred.view(b, n, -1, 5) * mul.view(1, 1, -1, 1)
xy = pred[..., :2] + centers[..., :2].unsqueeze(-2)
z = pred[..., 2:3]
r = torch.atan2(pred[..., 4] + st.view(*shape, 1), pred[..., 3] + ct.view(*shape, 1)).unsqueeze(-1)
lwh = torch.cat([lo, wo, ho], dim=-1).unsqueeze(-2).repeat(1, 1, pred_len, 1)
pred = torch.cat([xy, z, lwh, r], dim=-1)
ret = (ret, pred)
return ret
[docs]class BoxPredCoder(object):
def __init__(self, with_velo=False):
self.with_velo = with_velo
[docs] def encode(self, centers, gt_boxes, meter_per_pixel, gt_preds):
"""
:param centers: (N, 3)
:param gt_boxes: (N, 8) [batch_idx, x, y, z, l, w, h, r]
:param meter_per_pixel: tuple with 2 elements
:param gt_preds: (N, 8) [batch_idx, x, y, z, l, w, h, r], gt boxes to be predicted
:return: encoded bbox targets.
"""
if isinstance(meter_per_pixel, list):
assert meter_per_pixel[0] == meter_per_pixel[1], 'only support unified pixel size for x and y'
# TODO: adapt meter per pixel
meter_per_pixel = meter_per_pixel[0]
if len(gt_boxes) == 0:
valid = torch.zeros_like(centers[:, 0]).bool()
res = None, None, None, valid
if self.with_velo:
res = res + (None,)
return res
# match centers and gt_boxes
dist_ctr_to_box = torch.norm(centers[:, 1:3].unsqueeze(1)
- gt_boxes[:, 1:3].unsqueeze(0), dim=-1)
cc, bb = torch.meshgrid(centers[:, 0], gt_boxes[:, 0], indexing='ij')
dist_ctr_to_box[cc != bb] = 1000
min_dists, box_idx_of_pts = dist_ctr_to_box.min(dim=1)
diagnal = torch.norm(gt_boxes[:, 4:6].mean(dim=0) / 2)
valid = min_dists < max(diagnal, meter_per_pixel[0])
# valid = min_dists < self.reg_radius
valid_center = centers[valid]
valid_box = gt_preds[box_idx_of_pts[valid]]
xc, yc = torch.split(valid_center[:, 1:3], 1, dim=-1)
xg, yg, zg, lg, wg, hg, rg = torch.split(valid_box[:, 1:8], 1, dim=-1)
xt = xg - xc
yt = yg - yc
zt = zg # + self.z_offset
lt = torch.log(lg)
wt = torch.log(wg)
ht = torch.log(hg)
# encode box directions
rgx = torch.cos(rg).view(-1, 1) # N 1
rgy = torch.sin(rg).view(-1, 1) # N 1
ra = torch.arange(0, 2, 0.5).to(xc.device) * math.pi
ra_ext = torch.ones_like(valid_box[:, :4]) * ra.view(-1, 4) # N 4
rax = torch.cos(ra_ext) # N 4
ray = torch.sin(ra_ext) # N 4
# cos(a - b) = cos(a)cos(b) + sin(a)sin(b)
# we use arccos instead of a-b to control the difference in 0-pi
diff_angle = torch.arccos(rax * rgx + ray * rgy) # N 4
dir_score = 1 - diff_angle / math.pi # N 4
rtx = rgx - rax # N 4
rty = rgy - ray # N 4
reg_box = torch.cat([xt, yt, zt, lt, wt, ht], dim=1) # N 6
reg_dir = torch.cat([rtx, rty], dim=1) # N 8
# reg_box[..., :3] /= self.reg_radius
res = (reg_box, reg_dir, dir_score, valid)
if self.with_velo:
res = res + (valid_box[:, 8:10],)
elif valid_box.shape[-1] > 8:
res = res + (valid_box[:, 8:10],)
return res
[docs] def decode(self, centers, reg):
"""
:param centers: Tensor (N, 3) or (B, N, 2+).
:param reg: dict,
box - (N, 6) or (B, N, 6)
dir - (N, 8) or (B, N, 8)
scr - (N, 4) or (B, N, 4)
vel - (N, 2) or (B, N, 2), optional
pred - (N, 5) or (B, N, 5), optional
:return: decoded bboxes.
"""
if centers.ndim > 2:
xc, yc = torch.split(centers[..., 0:2], 1, dim=-1)
else:
xc, yc = torch.split(centers[..., 1:3], 1, dim=-1)
# reg['box'][..., :3] *= self.reg_radius
xt, yt, zt, lt, wt, ht = torch.split(reg['box'], 1, dim=-1)
xo = xt + xc
yo = yt + yc
zo = zt #- self.z_offset
lo = torch.exp(lt)
wo = torch.exp(wt)
ho = torch.exp(ht)
# decode box directions
scr_max, max_idx = reg['scr'].max(dim=-1)
shape = max_idx.shape
max_idx = max_idx.view(-1)
ii = torch.arange(len(max_idx))
ra = max_idx.float() * 0.5 * math.pi
ct = reg['dir'][..., :4].view(-1, 4)[ii, max_idx] + torch.cos(ra)
st = reg['dir'][..., 4:].view(-1, 4)[ii, max_idx] + torch.sin(ra)
ro = torch.atan2(st.view(*shape), ct.view(*shape)).unsqueeze(-1)
if centers.ndim > 2:
# dense tensor
ret = torch.cat([xo, yo, zo, lo, wo, ho, ro], dim=-1)
else:
# sparse tensor with batch indices
ret = torch.cat([centers[..., :1], xo, yo, zo, lo, wo, ho, ro], dim=-1)
if self.with_velo:
ret = torch.cat([ret, reg['vel']], dim=-1)
return ret