Source code for cosense3d.agents.core.hooks



import os
import time

import torch
import torch.nn.functional as F
from importlib import import_module

from cosense3d.ops.utils import points_in_boxes_gpu


[docs]class Hooks: def __init__(self, cfg): self.hooks = [] if cfg is None: return for hook_cfg in cfg: self.hooks.append( globals()[hook_cfg['type']](**hook_cfg) ) def __call__(self, runner, hook_stage, **kwargs): for hook in self.hooks: getattr(hook, hook_stage)(runner, **kwargs)
[docs] def set_logger(self, logger): for hook in self.hooks: hook.set_logger(logger)
[docs]class BaseHook: def __init__(self, **kwargs): pass
[docs] def pre_iter(self, runner, **kwargs): pass
[docs] def post_iter(self, runner, **kwargs): pass
[docs] def pre_epoch(self, runner, **kwargs): pass
[docs] def post_epoch(self, runner, **kwargs): pass
[docs] def set_logger(self, logger): self.logger = logger
[docs]class MemoryUsageHook(BaseHook): def __init__(self, device='cuda:0', **kwargs): super().__init__(**kwargs) self.device = device
[docs] def post_iter(self, runner, **kwargs): memory = torch.cuda.max_memory_allocated(self.device) / 1024 / 1024 torch.cuda.empty_cache() runner.logger.update(memory=memory)
[docs]class CPMStatisticHook(BaseHook): def __init__(self, device='cuda:0', **kwargs): super().__init__(**kwargs) self.device = device
[docs] def set_logger(self, logger): super().set_logger(logger) logdir = os.path.join(logger.logdir, 'detection_eval') os.makedirs(logdir, exist_ok=True) self.logdir = logdir
[docs] def post_epoch(self, runner, **kwargs): cpm_rec = runner.controller.cav_manager.cpm_size_recorder thr = runner.controller.cav_manager.cavs[0][0].share_score_thr ss = (f"########## CPM size @ {thr} ###########\n" f"Mean: {cpm_rec.mean[0] * 4 / 1024:.2f} KB, Std: {cpm_rec.std[0] * 4 / 1024:.2f} KB") print(ss) self.logger.log(ss)
[docs]class TrainTimerHook(BaseHook): def __init__(self, **kwargs): super().__init__(**kwargs) self.elapsed_time = 0 self.start_time = None self.mean_time_per_itr = None self.observations = 0
[docs] def pre_epoch(self, runner, **kwargs): if self.start_time is None: self.start_time = time.time() self.last_time = time.time()
[docs] def post_iter(self, runner, **kwargs): cur_time = time.time() self.elapsed_time = (cur_time - self.start_time) / 3600 # total_run_iter = (runner.total_iter * (runner.epoch - runner.start_epoch)) + runner.iter # time_per_iter = self.elapsed_time / total_run_iter time_per_iter = (cur_time - self.last_time) / 3600 m = self.observations if self.mean_time_per_itr is None: self.mean_time_per_itr = time_per_iter else: self.mean_time_per_itr = m / (m + 1) * self.mean_time_per_itr + 1 / (m + 1) * time_per_iter iter_remain = runner.total_iter * (runner.total_epochs - runner.epoch + 1) - runner.iter time_remain = self.mean_time_per_itr * iter_remain runner.logger.update(t_remain=time_remain, t_used=self.elapsed_time) self.last_time = cur_time self.observations += 1
[docs]class CheckPointsHook(BaseHook): def __init__(self, max_ckpt=3, epoch_every=None, iter_every=None, **kwargs): super().__init__(**kwargs) self.max_ckpt = max_ckpt self.epoch_every = epoch_every self.iter_every = iter_every
[docs] def post_epoch(self, runner, **kwargs): if runner.gpu_id != 0: return self.save(runner, f'epoch{runner.epoch}.pth') if runner.epoch > self.max_ckpt: if (self.epoch_every is None or not (runner.epoch - self.max_ckpt) % self.epoch_every == 0): filename = os.path.join( runner.logger.logdir, f'epoch{runner.epoch - self.max_ckpt}.pth') if os.path.exists(filename): os.remove(filename)
[docs] def post_iter(self, runner, **kwargs): if runner.gpu_id != 0: return if self.iter_every is not None and runner.iter % self.iter_every == 0: self.save(runner, f'latest.pth')
[docs] @staticmethod def save(runner, name): save_path = os.path.join(runner.logger.logdir, name) print(f'Saving checkpoint to {save_path}.') torch.save({ 'epoch': runner.epoch, 'model': runner.forward_runner.state_dict(), 'optimizer': runner.optimizer.state_dict(), 'lr_scheduler': runner.lr_scheduler.state_dict(), }, save_path)
[docs]class DetectionNMSHook(BaseHook): def __init__(self, nms_thr, pre_max_size, det_key='detection', **kwargs): super().__init__(**kwargs) self.nms_thr = nms_thr self.pre_max_size = pre_max_size self.nms = import_module('cosense3d.ops.iou3d_nms_utils').nms_gpu self.det_key = det_key self.defual_pred_keys = ['box', 'scr', 'lbl', 'idx']
[docs] def post_iter(self, runner, **kwargs): detection_out = runner.controller.data_manager.gather_ego_data(self.det_key) preds = [] cav_ids = [] for cav_id, values in detection_out.items(): cav_ids.append(cav_id) boxes = values['preds']['box'] scores = values['preds']['scr'] labels = values['preds']['lbl'] indices = values['preds']['idx'] # map index for retrieving features out = {} if 'center' in values: out['ctr'] = values['center'] if 'conf' in values: out['conf'] = values['conf'] if len(values['preds']['box']) == 0: out.update({ 'box': torch.zeros((0, 7), device=boxes.device), 'scr': torch.zeros((0,), device=scores.device), 'lbl': torch.zeros((0,), device=labels.device), 'idx': torch.zeros(indices.shape[0] if isinstance(indices, torch.Tensor) else (0,), device=indices.device), }) if 'pred' in values['preds']: out['pred'] = torch.zeros((0, 2, 7), device=boxes.device) else: keep = self.nms( boxes[..., :7], scores, thresh=self.nms_thr, pre_maxsize=self.pre_max_size ) out.update({ 'box': boxes[keep], 'scr': scores[keep], 'lbl': labels[keep], 'idx': indices[keep], }) if 'pred' in values['preds'] and values['preds']['pred'] is not None: out['pred'] = values['preds']['pred'][keep] assert len(out['pred']) != len(out['box']) preds.append(out) # from cosense3d.utils.vislib import draw_points_boxes_plt # points = out['ctr'].detach().cpu().numpy() # boxes = out['box'].detach().cpu().numpy() # draw_points_boxes_plt( # pc_range=[-140.8, -38.4, -3.0, 140.8, 38.4, 1.0], # boxes_pred=boxes, # points=points, # filename="/home/yuan/Pictures/tmp.png" # ) runner.controller.data_manager.scatter(cav_ids, {self.det_key: preds})
[docs]class EvalDetectionBEVHook(BaseHook): def __init__(self, pc_range, iou_thr=[0.5, 0.7], save_result=False, det_key='detection', gt_key='global_bboxes_3d', **kwargs): super().__init__(**kwargs) self.iou_thr = iou_thr self.pc_range = pc_range self.save_result = save_result self.det_key = det_key self.gt_key = gt_key self.result = {iou: {'tp': [], 'fp': [], 'gt': 0, 'scr': []} for iou in iou_thr} self.eval_funcs = import_module('cosense3d.utils.eval_detection_utils')
[docs] def set_logger(self, logger): super().set_logger(logger) logdir = os.path.join(logger.logdir, 'detection_eval') os.makedirs(logdir, exist_ok=True) self.logdir = logdir
[docs] def post_iter(self, runner, **kwargs): detection = runner.controller.data_manager.gather_ego_data(self.det_key) gt_boxes = runner.controller.data_manager.gather_ego_data(self.gt_key) for i, (cav_id, preds) in enumerate(detection.items()): if 'preds' in preds: preds = preds['preds'] preds['box'], preds['scr'], preds['lbl'], preds['idx'], preds['time'] = \ self.filter_box_ranges(preds['box'], preds['scr'], preds['lbl'], preds['idx'], preds.get('time', None)) cur_gt_boxes = self.filter_box_ranges(gt_boxes[cav_id])[0] cur_points = runner.controller.data_manager.gather_batch(i, 'points') if self.save_result: ego_key = cav_id senario = runner.controller.data_manager.gather_ego_data('scenario')[ego_key] frame = runner.controller.data_manager.gather_ego_data('frame')[ego_key] filename = f"{senario}.{frame}.{ego_key.split('.')[1]}.pth" result = {'detection': preds, 'gt_boxes': cur_gt_boxes, 'points': cur_points} torch.save(result, os.path.join(self.logdir, filename)) for iou in self.iou_thr: self.eval_funcs.caluclate_tp_fp( preds['box'][..., :7], preds['scr'], cur_gt_boxes[..., :7], self.result, iou )
[docs] def filter_box_ranges(self, boxes, scores=None, labels=None, indices=None, times=None): mask = boxes.new_ones((len(boxes),)).bool() if boxes.ndim == 3: centers = boxes.mean(dim=1) else: centers = boxes[:, :3] for i in range(3): mask = mask & (centers[:, i] > self.pc_range[i]) & (centers[:, i] < self.pc_range[i + 3]) boxes = boxes[mask] if scores is not None: scores = scores[mask] if labels is not None: labels = labels[mask] if indices is not None: try: indices = indices[mask] except: print("Number of boxes doesn't match the number of indices") if times is not None: times = times[mask] return boxes, scores, labels, indices, times
[docs] def post_epoch(self, runner, **kwargs): fmt_str = ("################\n" "DETECTION RESULT\n" "################\n") out_dict = self.eval_funcs.eval_final_results( self.result, self.iou_thr, global_sort_detections=True ) fmt_str += "OPV2V BEV Global sorted:\n" fmt_str += self.format_final_result(out_dict) fmt_str += "----------------\n" print(fmt_str) self.logger.log(fmt_str)
[docs] def format_final_result(self, out_dict): fmt_str = "" for iou in self.iou_thr: iou_str = f"{int(iou * 100)}" fmt_str += f"AP@{iou_str}: {out_dict[f'ap_{iou_str}']:.3f}\n" return fmt_str
[docs]class EvalDetectionHook(BaseHook): def __init__(self, pc_range, iou_thr=[0.5, 0.7], metrics=['CoSense3D'], save_result=False, det_key='detection', gt_key='global_bboxes_3d', **kwargs): super().__init__(**kwargs) self.iou_thr = iou_thr self.pc_range = pc_range self.save_result = save_result self.det_key = det_key self.gt_key = gt_key for m in metrics: assert m in ['OPV2V', 'CoSense3D'] setattr(self, f'{m.lower()}_result', {iou: {'tp': [], 'fp': [], 'gt': 0, 'scr': []} for iou in iou_thr}) self.metrics = metrics self.eval_funcs = import_module('cosense3d.utils.eval_detection_utils')
[docs] def set_logger(self, logger): super().set_logger(logger) logdir = os.path.join(logger.logdir, 'detection_eval') os.makedirs(logdir, exist_ok=True) self.logdir = logdir
[docs] def post_iter(self, runner, **kwargs): detection = runner.controller.data_manager.gather_ego_data(self.det_key) gt_boxes = runner.controller.data_manager.gather_ego_data(self.gt_key) for i, (cav_id, preds) in enumerate(detection.items()): if 'preds' in preds: preds = preds['preds'] preds['box'], preds['scr'], preds['lbl'], preds['idx'], preds['time'] = \ self.filter_box_ranges(preds['box'], preds['scr'], preds['lbl'], preds['idx'], preds.get('time', None)) cur_gt_boxes = self.filter_box_ranges(gt_boxes[cav_id])[0] cur_points = runner.controller.data_manager.gather_batch(i, 'points') if self.save_result: ego_key = cav_id senario = runner.controller.data_manager.gather_ego_data('scenario')[ego_key] frame = runner.controller.data_manager.gather_ego_data('frame')[ego_key] filename = f"{senario}.{frame}.{ego_key.split('.')[1]}.pth" result = {'detection': preds, 'gt_boxes': cur_gt_boxes, 'points': cur_points} torch.save(result, os.path.join(self.logdir, filename)) for iou in self.iou_thr: if 'OPV2V' in self.metrics: result_dict = getattr(self, f'opv2v_result') self.eval_funcs.caluclate_tp_fp( preds['box'][..., :7], preds['scr'], cur_gt_boxes[..., :7], result_dict, iou ) if 'CoSense3D' in self.metrics: result_dict = getattr(self, f'cosense3d_result') tp = self.eval_funcs.ops_cal_tp( preds['box'][..., :7].detach(), cur_gt_boxes[..., :7].detach(), IoU_thr=iou ) result_dict[iou]['tp'].append(tp.cpu()) result_dict[iou]['gt'] += len(cur_gt_boxes) result_dict[iou]['scr'].append(preds['scr'].detach().cpu())
[docs] def filter_box_ranges(self, boxes, scores=None, labels=None, indices=None, times=None): mask = boxes.new_ones((len(boxes),)).bool() if boxes.ndim == 3: centers = boxes.mean(dim=1) else: centers = boxes[:, :3] for i in range(3): mask = mask & (centers[:, i] > self.pc_range[i]) & (centers[:, i] < self.pc_range[i + 3]) boxes = boxes[mask] if scores is not None: scores = scores[mask] if labels is not None: labels = labels[mask] if indices is not None: try: indices = indices[mask] except: print("Number of boxes doesn't match the number of indices") if times is not None: times = times[mask] return boxes, scores, labels, indices, times
[docs] def post_epoch(self, runner, **kwargs): fmt_str = ("################\n" "DETECTION RESULT\n" "################\n") if 'OPV2V' in self.metrics: result_dict = getattr(self, f'opv2v_result') out_dict = self.eval_funcs.eval_final_results( result_dict, self.iou_thr, global_sort_detections=True ) fmt_str += "OPV2V BEV Global sorted:\n" fmt_str += self.format_final_result(out_dict) fmt_str += "----------------\n" out_dict = self.eval_funcs.eval_final_results( result_dict, self.iou_thr, global_sort_detections=False ) fmt_str += "OPV2V BEV Local sorted:\n" fmt_str += self.format_final_result(out_dict) fmt_str += "----------------\n" if 'CoSense3D' in self.metrics: out_dict = self.eval_cosense3d_final() fmt_str += "CoSense3D Global sorted:\n" fmt_str += self.format_final_result(out_dict) fmt_str += "----------------\n" print(fmt_str) self.logger.log(fmt_str)
[docs] def format_final_result(self, out_dict): fmt_str = "" for iou in self.iou_thr: iou_str = f"{int(iou * 100)}" fmt_str += f"AP@{iou_str}: {out_dict[f'ap_{iou_str}']:.3f}\n" # fmt_str += f"Precision@{iou_str}: {out_dict[f'mpre_{iou_str}']:.3f}\n" # fmt_str += f"Recall@{iou_str}: {out_dict[f'mrec_{iou_str}']:.3f}\n" return fmt_str
[docs] def eval_cosense3d_final(self): out_dict = {} result_dict = getattr(self, f'cosense3d_result') for iou in self.iou_thr: scores = torch.cat(result_dict[iou]['scr'], dim=0) tps = torch.cat(result_dict[iou]['tp'], dim=0) n_pred = len(scores) n_gt = result_dict[iou]['gt'] ap, mpre, mrec, _ = self.eval_funcs.cal_ap_all_point(scores, tps, n_pred, n_gt) iou_str = f"{int(iou * 100)}" out_dict.update({f'ap_{iou_str}': ap, f'mpre_{iou_str}': mpre, f'mrec_{iou_str}': mrec}) return out_dict
[docs]class EvalBEVSemsegHook(BaseHook): def __init__(self, test_range, test_res=0.4, save_result=False, eval_static=True, bev_semseg_key='bev_semseg', gt_bev_key='bevmap', gt_boxes_key='global_bboxes_3d', **kwargs): super().__init__(**kwargs) self.test_range = test_range self.test_res = test_res self.save_result = save_result self.eval_static = eval_static self.bev_semseg_key = bev_semseg_key self.gt_bev_key = gt_bev_key self.gt_boxes_key = gt_boxes_key self.thrs = torch.arange(0.1, 1.1, 0.1) self.sx = int(round((self.test_range[3] - self.test_range[0]) / self.test_res)) self.sy = int(round((self.test_range[4] - self.test_range[1]) / self.test_res)) self.res_dict = { 'iou_dynamic_all': [], 'iou_dynamic_obs': [], 'iou_static_all': [], 'iou_static_obs': [], }
[docs] def set_logger(self, logger): super().set_logger(logger) logdir = os.path.join(logger.logdir, 'bev_semseg_eval') os.makedirs(logdir, exist_ok=True) self.logdir = logdir
[docs] def post_iter(self, runner, **kwargs): scene_tokens = runner.controller.data_manager.gather_ego_data('scene_tokens') frame = runner.controller.data_manager.gather_ego_data('frame') semseg = runner.controller.data_manager.gather_ego_data(self.bev_semseg_key) gt_bevmaps = runner.controller.data_manager.gather_ego_data(self.gt_bev_key) gt_boxes = runner.controller.data_manager.gather_ego_data(self.gt_boxes_key) for i, (cav_id, preds) in enumerate(semseg.items()): token = f'{scene_tokens[cav_id]}.{frame[cav_id]}' gt_dynamic_map = self.gt_dynamic_map(gt_boxes[cav_id]) self.cal_ious(preds, gt_dynamic_map, 'dynamic', token) if self.eval_static: gt_static_map = self.gt_static_map(gt_bevmaps[cav_id]) self.cal_ious(preds, gt_static_map, 'static', token)
[docs] def cal_ious(self, preds, gt_map, tag, token=None): conf = self.crop_map(preds[f'conf_map_{tag}']) unc = self.crop_map(preds[f'unc_map_{tag}']) obs_mask = self.crop_map(preds[f'obs_mask_{tag}']) self.res_dict[f'iou_{tag}_all'].append(self.iou(conf, unc, gt_map)) self.res_dict[f'iou_{tag}_obs'].append(self.iou(conf, unc, gt_map, obs_mask)) if self.save_result: img = torch.cat([gt_map, unc, conf[..., 1]], dim=0).detach().cpu().numpy() import matplotlib.pyplot as plt plt.imshow(img.T) plt.savefig(os.path.join(self.logdir, f'{token}.{tag}.jpg')) plt.close()
[docs] def iou(self, conf, unc, gt, obs_mask=None): ious = [] for thr in self.thrs: if obs_mask is None: pos_mask = torch.argmax(conf, dim=-1).bool() pos_mask = torch.logical_and(pos_mask, unc <= thr) gt_ = gt else: pos_mask = torch.argmax(conf[obs_mask], dim=-1).bool() pos_mask = torch.logical_and(pos_mask, unc[obs_mask] <= thr) gt_ = gt[obs_mask] mi = torch.logical_and(pos_mask, gt_).sum() mu = torch.logical_or(pos_mask, gt_).sum() ious.append(mi / mu) return torch.stack(ious, dim=0)
[docs] def gt_dynamic_map(self, boxes): # filter box range mask = boxes.new_ones((len(boxes),)).bool() dynamic_map = torch.ones((self.sx, self.sy), device=boxes.device) centers = boxes[:, :3] for i in range(3): mask = mask & (centers[:, i] > self.test_range[i]) & (centers[:, i] < self.test_range[i + 3]) boxes = boxes[mask] if len(boxes) > 0: indices = torch.stack(torch.where(dynamic_map), dim=1) xy = indices.float() xy = (xy + 0.5) * self.test_res xy[:, 0] += self.test_range[0] xy[:, 1] += self.test_range[1] xyz = F.pad(xy, (1, 1), 'constant', 0.0) boxes = F.pad(boxes, (1, 0), 'constant', 0.0) boxes[:, 3] = 0 boxes_decomposed, box_idx_of_pts = points_in_boxes_gpu( xyz, boxes, batch_size=1 ) inds = indices[box_idx_of_pts >= 0].T dynamic_map[inds[0], inds[1]] = 0 dynamic_map = torch.logical_not(dynamic_map) return dynamic_map
[docs] def gt_static_map(self, bevmap): # map has higher resolution, downsample 2x # bevmap = torch.flip(bevmap, [0]) return bevmap[::2, ::2]
[docs] def crop_map(self, bevmap): sx, sy = bevmap.shape[:2] sx_crop = (sx - self.sx) // 2 sy_crop = (sy - self.sy) // 2 return bevmap[sx_crop:-sx_crop, sy_crop:-sy_crop]
[docs] def post_epoch(self, runner, **kwargs): fmt_str = ("#################\n" "BEV SEMSEG RESULT\n" "#################\n") fmt_str += f"{'thr':18s} | " + " ".join([f"{v:4.1f} " for v in self.thrs]) + "\n" fmt_str += "-" * (23 + 70) + "\n" for k, vs in self.res_dict.items(): vs = torch.stack(vs, dim=0).mean(dim=0) * 100 if isinstance(vs, int): continue s1 = f"{k:18s} | " if isinstance(vs, float): s2 = f"{vs:4.1f} \n" else: s2 = " ".join([f"{v:4.1f} " for v in vs]) + "\n" fmt_str += s1 + s2 print(fmt_str) self.logger.log(fmt_str)