# ssds.pipeline.pipeline_anchor_apex

import sys
import torch
import time
from datetime import timedelta

from apex import amp

import ssds.core.tools as tools
import ssds.core.visualize_funcs as vsf
from ssds.modeling.layers.box import extract_targets

CURSOR_UP_ONE = "\x1b[1A"
ERASE_LINE = "\x1b[2K"


class ModelWithLossBasic(torch.nn.Module):
    r"""Wrap a detection model together with its loss computation.

    Computing the classification and localization losses inside ``forward``
    keeps the loss work on the same device as the model, which helps the GPU
    memory stay balanced across replicas in a DDP setup (per the original
    author's note).

    Args:
        model: the detection network; called as ``model(images)`` and expected
            to return per-level ``(loc, conf)`` prediction lists.
        cls_criterion: classification loss, called as ``cls_criterion(pred,
            target, depth)`` and expected to return a per-element loss tensor.
        loc_criterion: localization loss, called as ``loc_criterion(pred,
            target)`` and expected to return a per-element loss tensor.
        num_classes (int): number of object classes, forwarded to
            ``extract_targets``.
        match: matching strategy, forwarded to ``extract_targets``.
        center_sampling_radius: center-sampling radius, forwarded to
            ``extract_targets``.
    """

    def __init__(
        self,
        model,
        cls_criterion,
        loc_criterion,
        num_classes,
        match,
        center_sampling_radius,
    ):
        super(ModelWithLossBasic, self).__init__()
        self.model = model
        self.cls_criterion = cls_criterion
        self.loc_criterion = loc_criterion
        self.num_classes = num_classes
        self.match = match
        self.center_radius = center_sampling_radius

    def forward(self, images, targets, anchors):
        r"""Run the model and return losses normalized by foreground count.

        Returns a 4-tuple ``(cls_loss, loc_loss, cls_losses, loc_losses)``:
        the two scalars are summed over all levels and divided by the total
        number of foreground targets; the two lists hold the unnormalized
        per-level sums.

        :meta private:
        """
        loc, conf = self.model(images)
        cls_losses, loc_losses, fg_targets = [], [], []
        for j, (stride, anchor) in enumerate(anchors.items()):
            size = conf[j].shape[-2:]
            conf_target, loc_target, depth = extract_targets(
                targets,
                anchors,
                self.num_classes,
                stride,
                size,
                self.match,
                self.center_radius,
            )
            # clamp(min=1) guards the later division when a level has no
            # foreground anchors.
            fg_targets.append((depth > 0).sum().float().clamp(min=1))

            # Classification: anchors with depth >= 0 contribute
            # (depth < 0 presumably marks "ignore" — grounded only in the
            # masking below, not in extract_targets itself).
            c = conf[j].view_as(conf_target).float()
            cls_mask = (depth >= 0).expand_as(conf_target).float()
            cls_loss = self.cls_criterion(c, conf_target, depth)
            cls_loss = cls_mask * cls_loss
            cls_losses.append(cls_loss.sum())

            # Localization: only strictly-positive (foreground) anchors
            # contribute.
            l = loc[j].view_as(loc_target).float()
            loc_loss = self.loc_criterion(l, loc_target)
            loc_mask = (depth > 0).expand_as(loc_loss).float()
            loc_loss = loc_mask * loc_loss
            loc_losses.append(loc_loss.sum())

        fg_targets = torch.stack(fg_targets).sum()
        cls_loss = torch.stack(cls_losses).sum() / fg_targets
        loc_loss = torch.stack(loc_losses).sum() / fg_targets
        return cls_loss, loc_loss, cls_losses, loc_losses
def train_anchor_based_epoch(
    model, data_loader, optimizer, anchors, writer, epoch, device, local_rank
):
    r"""Train an anchor-based detector for one epoch with Apex AMP.

    Args:
        model: module returning ``(cls_loss, loc_loss, cls_losses,
            loc_losses)`` when called as ``model(images, targets, anchors)``
            (e.g. :class:`ModelWithLossBasic`, possibly DDP-wrapped).
        data_loader: iterable yielding ``(images, targets)`` batches.
        optimizer: optimizer already initialized with ``apex.amp``.
        anchors: mapping of stride -> anchor, one entry per prediction level.
        writer: TensorBoard writer (used only when ``local_rank == 0``).
        epoch (int): current epoch index, used as the TensorBoard step.
        device: target device for the batch tensors.
        local_rank (int): rank of this process; logging/printing happens only
            on rank 0.
    """
    model.train()
    title = "Train: "

    # Rank 0 keeps running averages: overall cls/loc loss plus one meter per
    # prediction level.
    if local_rank == 0:
        loss_writer = {
            "loc_loss": tools.AverageMeter(),
            "cls_loss": tools.AverageMeter(),
        }
        loss_writer.update(
            {
                "loc_loss_{}".format(j): tools.AverageMeter()
                for j, _ in enumerate(anchors.items())
            }
        )
        loss_writer.update(
            {
                "cls_loss_{}".format(j): tools.AverageMeter()
                for j, _ in enumerate(anchors.items())
            }
        )

    start_time = time.time()
    dataset_len = len(data_loader)
    for batch_idx, (images, targets) in enumerate(data_loader):
        if images.device != device:
            images, targets = images.to(device), targets.to(device)
        if targets.dtype != torch.float:
            targets = targets.float()

        cls_loss, loc_loss, cls_losses, loc_losses = model(images, targets, anchors)
        # Skip batches that produced NaN losses rather than poisoning the
        # gradients.
        if torch.isnan(loc_loss) or torch.isnan(cls_loss):
            continue

        if local_rank == 0:
            for j, (cl, ll) in enumerate(zip(cls_losses, loc_losses)):
                loss_writer["cls_loss_{}".format(j)].update(cl.item())
                loss_writer["loc_loss_{}".format(j)].update(ll.item())
            loss_writer["cls_loss"].update(cls_loss.item())
            loss_writer["loc_loss"].update(loc_loss.item())
            log = {
                "cls_loss": cls_loss.item(),
                "loc_loss": loc_loss.item(),
                "lr": optimizer.param_groups[0]["lr"],
            }

        optimizer.zero_grad()
        total_loss = cls_loss + loc_loss
        # Also skip non-finite totals (e.g. Inf from loss overflow).
        if total_loss.item() == float("Inf") or torch.isnan(total_loss):
            continue
        # Apex AMP scales the loss before backward to avoid fp16 underflow.
        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if local_rank == 0:
            elapsed_time = time.time() - start_time
            estimat_time = elapsed_time * (dataset_len) / (batch_idx + 1)
            # log per iter (carriage return keeps the progress on one line)
            print(
                title + tools.format_dict_of_loss(log),
                "|",
                batch_idx + 1,
                "/",
                dataset_len,
                "| Time:",
                timedelta(seconds=int(elapsed_time)),
                "/",
                timedelta(seconds=int(estimat_time)),
                "\r",
                end="",
            )
            sys.stdout.flush()

    if local_rank == 0:
        log = {"lr": optimizer.param_groups[0]["lr"]}
        log.update({k: v.avg for k, v in loss_writer.items()})
        print(
            CURSOR_UP_ONE
            + ERASE_LINE
            + "===>Avg Train: "
            + tools.format_dict_of_loss(log),
            " | Time: ",
            timedelta(seconds=int(time.time() - start_time)),
        )
        # log for tensorboard
        for key, value in log.items():
            writer.add_scalar("Train/{}".format(key), value, epoch)
        # Convert last batch's boxes from (x, y, w, h) to corner form before
        # drawing them — assumes targets[..., :4] are box coords; TODO confirm
        # against the dataset layout.
        targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]
        vsf.add_imagesWithBoxes(
            writer, "Train Image", images[:5], targets[:5], epoch=epoch
        )
    return