# Source code for ssds.pipeline.pipeline_anchor_basic

import sys
from tqdm import tqdm
import torch

import ssds.core.tools as tools
import ssds.core.visualize_funcs as vsf
from ssds.core.evaluation_metrics import MeanAveragePrecision
from ssds.modeling.layers.box import extract_targets

CURSOR_UP_ONE = "\x1b[1A"
ERASE_LINE = "\x1b[2K"


def train_anchor_based_epoch(
    model,
    data_loader,
    optimizer,
    cls_criterion,
    loc_criterion,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    writer,
    epoch,
    device,
):
    r"""Run one training epoch for an anchor-based detector.

    Iterates ``data_loader``, computes per-stride classification and
    localization losses against targets extracted by ``extract_targets``,
    backpropagates their sum, and logs running averages to stdout and to
    the tensorboard ``writer``.

    Args:
        model: detector returning ``(loc, conf)`` lists, one entry per
            anchor stride.
        data_loader: yields ``(images, targets)`` batches; targets are
            assumed to be ``(batch, num_boxes, >=4)`` in xywh format —
            TODO confirm against the dataset definition.
        optimizer: torch optimizer driving ``model``'s parameters.
        cls_criterion: callable ``(conf, conf_target, depth) -> loss``.
        loc_criterion: callable ``(loc, loc_target) -> loss``.
        anchors: ordered mapping ``{stride: anchor}``; its order must
            match the order of ``loc``/``conf`` outputs.
        num_classes: number of object classes.
        match: matching strategy forwarded to ``extract_targets``.
        center_sampling_radius: forwarded to ``extract_targets``.
        writer: tensorboard SummaryWriter.
        epoch: current epoch index (used as the tensorboard step).
        device: device the model lives on.
    """
    model.train()
    title = "Train: "
    progress = tqdm(
        tools.IteratorTimer(data_loader),
        total=len(data_loader),
        smoothing=0.9,
        miniters=1,
        leave=True,
        desc=title,
    )
    # One global meter pair plus one pair per anchor level.
    loss_writer = {"loc_loss": tools.AverageMeter(), "cls_loss": tools.AverageMeter()}
    loss_writer.update(
        {
            "loc_loss_{}".format(j): tools.AverageMeter()
            for j, _ in enumerate(anchors.items())
        }
    )
    loss_writer.update(
        {
            "cls_loss_{}".format(j): tools.AverageMeter()
            for j, _ in enumerate(anchors.items())
        }
    )
    for batch_idx, (images, targets) in enumerate(progress):
        if images.device != device:
            images, targets = images.to(device), targets.to(device)
        if targets.dtype != torch.float:
            targets = targets.float()
        loc, conf = model(images)

        cls_losses, loc_losses, fg_targets = [], [], []
        for j, (stride, anchor) in enumerate(anchors.items()):
            size = conf[j].shape[-2:]
            conf_target, loc_target, depth = extract_targets(
                targets,
                anchors,
                num_classes,
                stride,
                size,
                match,
                center_sampling_radius,
            )
            # clamp(min=1) avoids dividing by zero when a batch has no
            # foreground anchors at this level.
            fg_targets.append((depth > 0).sum().float().clamp(min=1))

            c = conf[j].view_as(conf_target).float()
            # depth >= 0 keeps ignored anchors (depth < 0) out of the
            # classification loss.
            cls_mask = (depth >= 0).expand_as(conf_target).float()
            cls_loss = cls_criterion(c, conf_target, depth)
            cls_loss = cls_mask * cls_loss
            cls_losses.append(cls_loss.sum())

            l = loc[j].view_as(loc_target).float()
            loc_loss = loc_criterion(l, loc_target)
            # Only foreground anchors (depth > 0) contribute to regression.
            loc_mask = (depth > 0).expand_as(loc_loss).float()
            loc_loss = loc_mask * loc_loss
            loc_losses.append(loc_loss.sum())

            # NOTE(review): this only skips the per-level meter update on
            # NaN; the NaN loss stays in cls_losses/loc_losses and is
            # caught by the batch-level NaN check below.
            if torch.isnan(loc_loss.sum()) or torch.isnan(cls_loss.sum()):
                continue
            loss_writer["cls_loss_{}".format(j)].update(cls_losses[-1].item())
            loss_writer["loc_loss_{}".format(j)].update(loc_losses[-1].item())

        # Normalize by the total foreground count across levels.
        fg_targets = torch.stack(fg_targets).sum()
        cls_loss = torch.stack(cls_losses).sum() / fg_targets
        loc_loss = torch.stack(loc_losses).sum() / fg_targets
        if torch.isnan(loc_loss) or torch.isnan(cls_loss):
            continue
        loss_writer["cls_loss"].update(cls_loss.item())
        loss_writer["loc_loss"].update(loc_loss.item())
        log = {
            "cls_loss": cls_loss.item(),
            "loc_loss": loc_loss.item(),
            "lr": optimizer.param_groups[0]["lr"],
        }

        optimizer.zero_grad()
        total_loss = cls_loss + loc_loss
        # Fix: the original `total_loss.item() == float("Inf")` missed
        # -inf; isfinite rejects both infinities (NaN is already excluded
        # above) before the backward pass.
        if not torch.isfinite(total_loss):
            continue
        total_loss.backward()
        optimizer.step()

        # Fix: iterating `progress` already advances the bar once per
        # batch; the original extra `progress.update(1)` double-counted.
        progress.set_description(title + tools.format_dict_of_loss(log))
    progress.close()

    # Epoch summary: learning rate plus every meter's running average.
    log = {"lr": optimizer.param_groups[0]["lr"]}
    log.update({k: v.avg for k, v in loss_writer.items()})
    print(
        CURSOR_UP_ONE + ERASE_LINE + "===>Avg Train: " + tools.format_dict_of_loss(log)
    )
    for key, value in log.items():
        writer.add_scalar("Train/{}".format(key), value, epoch)
    # Convert the last batch's targets from xywh to ltrb for visualization.
    targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]
    vsf.add_imagesWithBoxes(writer, "Train Image", images[:5], targets[:5], epoch=epoch)
    return
def eval_anchor_based_epoch(
    model,
    data_loader,
    decoder,
    cls_criterion,
    loc_criterion,
    anchors,
    num_classes,
    writer,
    epoch,
    device,
):
    r"""Run one evaluation epoch for an anchor-based detector.

    Decodes detections for every batch, accumulates mean average
    precision, and logs mAP, per-class AP (when there are few classes),
    the PR curve, and sample detections to the tensorboard ``writer``.

    Args:
        model: detector returning ``(loc, conf)`` lists per anchor stride.
        data_loader: yields ``(images, targets)``; targets assumed xywh —
            converted to ltrb before being fed to the metric.
        decoder: callable ``(loc, conf, anchors) -> detections`` exposing
            ``conf_threshold`` and ``nms_threshold`` attributes.
        cls_criterion, loc_criterion: unused here (kept for interface
            parity with the training pipeline).
        anchors: ordered mapping ``{stride: anchor}``.
        num_classes: number of object classes.
        writer: tensorboard SummaryWriter.
        epoch: current epoch index (used as the tensorboard step).
        device: device the model lives on.
    """
    model.eval()
    title = "Eval: "
    progress = tqdm(
        tools.IteratorTimer(data_loader),
        total=len(data_loader),
        smoothing=0.9,
        miniters=1,
        leave=True,
        desc=title,
    )
    metric = MeanAveragePrecision(
        num_classes, decoder.conf_threshold, decoder.nms_threshold
    )
    # Fix: evaluation does not need autograd — no_grad avoids building the
    # graph, cutting memory and time without changing any output.
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(progress):
            if images.device != device:
                images, targets = images.to(device), targets.to(device)
            if targets.dtype != torch.float:
                targets = targets.float()
            loc, conf = model(images)
            # removed loss since the conf is sigmoid in the evaluation
            # stage, the conf loss is not meaningful anymore
            detections = decoder(loc, conf, anchors)
            # from xywh to ltrb
            targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]
            metric(detections, targets)
            # Fix: iterating `progress` already advances the bar; the
            # original extra `progress.update(1)` double-counted.
    progress.close()

    mAP, (prec, rec, ap) = metric.get_results()
    log = {"mAP": mAP}
    # Only log per-class AP when the class count is small enough to read.
    if len(ap) < 5:
        for i, a in enumerate(ap):
            log["AP@cls{}".format(i)] = a
    print(
        CURSOR_UP_ONE + ERASE_LINE + "===>Avg Eval: " + tools.format_dict_of_loss(log)
    )
    for key, value in log.items():
        writer.add_scalar("Eval/{}".format(key), value, epoch)
    vsf.add_prCurve(writer, prec, rec, epoch=epoch)
    # Append each box's score as a trailing column for visualization;
    # uses the last batch's detections/images/targets.
    boxes = torch.cat((detections[1], detections[0][..., None]), dim=2)
    vsf.add_imagesWithMatchedBoxes(
        writer, "Eval Image", images[:5], boxes[:5], targets[:5], epoch=epoch
    )
    return