Source code for cosense3d.agents.core.train_runner



import os, glob, logging, warnings
from datetime import datetime

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from cosense3d.utils.train_utils import *
from cosense3d.utils.lr_scheduler import build_lr_scheduler
from cosense3d.utils.logger import LogMeter
from cosense3d.utils.misc import ensure_dir
from cosense3d.agents.core.base_runner import BaseRunner
from cosense3d.agents.utils.deco import save_ckpt_on_error


class TrainRunner(BaseRunner):
    def __init__(self,
                 max_epoch,
                 optimizer,
                 lr_scheduler,
                 gpus=0,
                 resume_from=None,
                 load_from=None,
                 run_name='default',
                 log_dir='work_dir',
                 use_wandb=False,
                 debug=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.gpus = gpus
        self.gpu_id = 0
        self.dist = False
        self.debug = debug
        if gpus > 0:
            self.dist = True
            self.gpu_id = int(os.environ.get("LOCAL_RANK", 0))
            self.forward_runner.to_gpu(self.gpu_id)
            self.forward_runner = DDP(self.forward_runner,
                                      device_ids=[self.gpu_id])
        self.optimizer = build_optimizer(self.forward_runner, optimizer)
        self.lr_scheduler = build_lr_scheduler(
            self.optimizer, lr_scheduler, len(self.dataloader))
        self.total_epochs = max_epoch
        self.start_epoch = 1
        self.resume(resume_from, load_from)
        self.setup_logger(resume_from, run_name, log_dir, use_wandb)
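
    # A minimal sketch (not part of the original source) of the distributed
    # setup this constructor relies on: launchers such as `torchrun` start one
    # process per GPU and export LOCAL_RANK, which is read above to pick the
    # device and wrap the model in DistributedDataParallel. `MyModel` is a
    # hypothetical placeholder, not a cosense3d API.
    #
    #   import torch.distributed as dist
    #   dist.init_process_group(backend='nccl')
    #   local_rank = int(os.environ.get("LOCAL_RANK", 0))
    #   torch.cuda.set_device(local_rank)
    #   model = MyModel().cuda(local_rank)
    #   model = DDP(model, device_ids=[local_rank])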

    def setup_logger(self, resume_from, run_name, log_dir, use_wandb):
        if resume_from is not None:
            if os.path.isfile(resume_from):
                log_path = os.path.dirname(resume_from)
            else:
                log_path = resume_from
        else:
            now = datetime.now().strftime('%m-%d-%H-%M-%S')
            run_name = run_name + '_' + now
            log_path = os.path.join(log_dir, run_name)
            ensure_dir(log_path)
        wandb_project_name = run_name if use_wandb else None
        self.logger = LogMeter(self.total_iter, log_path,
                               log_every=self.log_every,
                               wandb_project=wandb_project_name)

    def resume(self, resume_from, load_from):
        if resume_from is not None or load_from is not None:
            load_path = resume_from if resume_from is not None else load_from
            assert os.path.exists(load_path), \
                f'resume/load path does not exist: {load_path}.'
            if os.path.isdir(load_path):
                ckpts = glob.glob(os.path.join(load_path, 'epoch*.pth'))
                if len(ckpts) > 0:
                    # pick the checkpoint with the largest epoch number
                    epochs = [int(os.path.basename(ckpt)[5:-4]) for ckpt in ckpts]
                    max_idx = epochs.index(max(epochs))
                    ckpt = ckpts[max_idx]
                elif os.path.exists(os.path.join(load_path, 'last.pth')):
                    ckpt = os.path.join(load_path, 'last.pth')
                else:
                    raise IOError(f'No checkpoint found in directory {load_path}.')
            elif os.path.isfile(load_path):
                ckpt = load_path
            else:
                raise IOError(f'Failed to load checkpoint from {load_path}.')
            logging.info(f"Resuming the model from checkpoint: {ckpt}")
            ckpt = torch.load(ckpt)
            load_model_dict(self.forward_runner, ckpt['model'])
            if resume_from is not None:
                self.start_epoch = ckpt['epoch'] + 1
                self.epoch = ckpt['epoch'] + 1
                if 'lr_scheduler' in ckpt:
                    self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
                try:
                    if 'optimizer' in ckpt:
                        self.optimizer.load_state_dict(ckpt['optimizer'])
                except Exception:
                    warnings.warn("Cannot load optimizer state_dict, "
                                  "there might be training parameter changes, "
                                  "please consider using 'load-from'.")
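
    # A minimal sketch (an assumption, not shown in this module) of a
    # checkpoint layout that `resume` can consume; it only relies on the keys
    # read above ('model', 'epoch', and optionally 'optimizer' and
    # 'lr_scheduler') and on a file name matching the 'epoch*.pth' glob:
    #
    #   ckpt = {
    #       'model': forward_runner.state_dict(),
    #       'epoch': epoch,
    #       'optimizer': optimizer.state_dict(),
    #       'lr_scheduler': lr_scheduler.state_dict(),
    #   }
    #   torch.save(ckpt, os.path.join(log_path, f'epoch{epoch}.pth'))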

    def run(self):
        with torch.autograd.set_detect_anomaly(True):
            for i in range(self.start_epoch, self.total_epochs + 1):
                self.hooks(self, 'pre_epoch')
                self.run_epoch()
                self.hooks(self, 'post_epoch')
                self.lr_scheduler.step_epoch(i)
                self.epoch += 1
                self.iter = 1

    def step(self):
        data = self.next_batch()
        self.run_itr(data)

    def run_epoch(self):
        if self.dist:
            # reshuffle the distributed sampler differently on every epoch
            self.dataloader.sampler.set_epoch(self.epoch)
        for data in self.dataloader:
            self.hooks(self, 'pre_iter')
            self.run_itr(data)
            self.hooks(self, 'post_iter')

    @save_ckpt_on_error
    def run_itr(self, data):
        load_tensors_to_gpu(data, self.gpu_id)
        self.optimizer.zero_grad()
        total_loss, loss_dict = self.controller.train_forward(
            data, epoch=self.epoch, itr=self.iter, gpu_id=self.gpu_id)
        total_loss.backward()
        grad_norm = clip_grads(self.controller.parameters)
        loss_dict['grad_norm'] = grad_norm
        # update parameters
        self.optimizer.step()
        self.lr_scheduler.step_itr(self.iter + self.epoch * self.total_iter)

        if self.logger is not None and self.gpu_id == 0:
            rec_lr = self.lr_scheduler.get_last_lr()[0]
            self.logger.log(self.epoch, self.iter, rec_lr, **loss_dict)

        del data
        self.iter += 1
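

# A minimal illustration (an assumption, not the cosense3d implementation) of
# what a gradient-clipping helper like the `clip_grads` call in `run_itr`
# typically does: clip the global gradient norm of all trainable parameters
# and return that norm so it can be logged. The name and the max_norm value
# are hypothetical.
def clip_grads_sketch(parameters, max_norm=35.0, norm_type=2):
    params = [p for p in parameters if p.requires_grad and p.grad is not None]
    return torch.nn.utils.clip_grad_norm_(params, max_norm, norm_type=norm_type)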