Source code for cosense3d.agents.core.base_runner



from cosense3d.utils.train_utils import *
from cosense3d.agents.core.hooks import Hooks


[docs]class BaseRunner: def __init__(self, dataloader, controller, gpus=0, log_every=10, hooks=None, **kwargs ): self.dataloader = dataloader self.data_iter = iter(dataloader) self.total_iter = len(dataloader) self.iter = 1 self.epoch = 1 self.controller = controller self.forward_runner = controller.forward_runner self.hooks = Hooks(hooks) self.gpus = gpus self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') self.log_every = log_every self.init()
[docs] def init(self): if self.forward_runner is not None: self.forward_runner.to(self.device)
[docs] def setup_logger(self, *args, **kwargs): pass
[docs] def set_logdir(self, logdir): self.logger.log_path = logdir
@property def logdir(self): if hasattr(self, 'logger'): return self.logger.logdir else: return None
[docs] def run(self): raise NotImplementedError
[docs] def next_batch(self): if self.iter >= self.total_iter: self.iter = 1 self.epoch += 1 self.data_iter = iter(self.dataloader) batch = next(self.data_iter) return batch
[docs] def vis_data(self, keys=None, **kwargs): if keys is None: keys = ['points', 'imgs', 'bboxes2d', 'lidar2img', 'global_labels', 'local_labels'] else: keys = list(set(keys)) return self.controller.data_manager.gather_vis_data(keys=keys)