import sys
import torch
import time
from datetime import timedelta
from apex import amp
import ssds.core.tools as tools
import ssds.core.visualize_funcs as vsf
from ssds.modeling.layers.box import extract_targets
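
# ANSI escape sequences used to overwrite the console progress line in place.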
CURSOR_UP_ONE = "\x1b[1A"
ERASE_LINE = "\x1b[2K"


class ModelWithLossBasic(torch.nn.Module):
    r""" Class use to help the gpu memory becomes more balance in ddp model
    """
    def __init__(
        self,
        model,
        cls_criterion,
        loc_criterion,
        num_classes,
        match,
        center_sampling_radius,
    ):
        super(ModelWithLossBasic, self).__init__()
        self.model = model
        self.cls_criterion = cls_criterion
        self.loc_criterion = loc_criterion
        self.num_classes = num_classes
        self.match = match
        self.center_radius = center_sampling_radius

    def forward(self, images, targets, anchors):
        r""" 
        :meta private:
        """
        loc, conf = self.model(images)
        cls_losses, loc_losses, fg_targets = [], [], []
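        # Accumulate losses level by level; each entry of `anchors` pairs a
        # feature-map stride with its anchor set.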
        for j, (stride, anchor) in enumerate(anchors.items()):
            size = conf[j].shape[-2:]
            conf_target, loc_target, depth = extract_targets(
                targets,
                anchors,
                self.num_classes,
                stride,
                size,
                self.match,
                self.center_radius,
            )
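            # Count foreground locations (depth > 0); clamp to 1 so the later
            # normalization never divides by zero.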
            fg_targets.append((depth > 0).sum().float().clamp(min=1))
            c = conf[j].view_as(conf_target).float()
            cls_mask = (depth >= 0).expand_as(conf_target).float()
            cls_loss = self.cls_criterion(c, conf_target, depth)
            cls_loss = cls_mask * cls_loss
            cls_losses.append(cls_loss.sum())
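            # The localization loss is only counted on foreground locations.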
            l = loc[j].view_as(loc_target).float()
            loc_loss = self.loc_criterion(l, loc_target)
            loc_mask = (depth > 0).expand_as(loc_loss).float()
            loc_loss = loc_mask * loc_loss
            loc_losses.append(loc_loss.sum())
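        # Normalize both losses by the total number of foreground targets
        # across all levels.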
        fg_targets = torch.stack(fg_targets).sum()
        cls_loss = torch.stack(cls_losses).sum() / fg_targets
        loc_loss = torch.stack(loc_losses).sum() / fg_targets
        return cls_loss, loc_loss, cls_losses, loc_losses 


def train_anchor_based_epoch(
    model, data_loader, optimizer, anchors, writer, epoch, device, local_rank
):
    r""" the pipeline for training
    """
    model.train()
    title = "Train: "
    if local_rank == 0:
        loss_writer = {
            "loc_loss": tools.AverageMeter(),
            "cls_loss": tools.AverageMeter(),
        }
        loss_writer.update(
            {
                "loc_loss_{}".format(j): tools.AverageMeter()
                for j, _ in enumerate(anchors.items())
            }
        )
        loss_writer.update(
            {
                "cls_loss_{}".format(j): tools.AverageMeter()
                for j, _ in enumerate(anchors.items())
            }
        )
        start_time = time.time()
        dataset_len = len(data_loader)
    for batch_idx, (images, targets) in enumerate(data_loader):
        if images.device != device:
            images, targets = images.to(device), targets.to(device)
        if targets.dtype != torch.float:
            targets = targets.float()
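        # The wrapped model returns the normalized losses plus the per-level
        # loss lists used for logging.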
        cls_loss, loc_loss, cls_losses, loc_losses = model(images, targets, anchors)
        if torch.isnan(loc_loss) or torch.isnan(cls_loss):
            continue
        if local_rank == 0:
            for j, (cl, ll) in enumerate(zip(cls_losses, loc_losses)):
                loss_writer["cls_loss_{}".format(j)].update(cl.item())
                loss_writer["loc_loss_{}".format(j)].update(ll.item())
            loss_writer["cls_loss"].update(cls_loss.item())
            loss_writer["loc_loss"].update(loc_loss.item())
            log = {
                "cls_loss": cls_loss.item(),
                "loc_loss": loc_loss.item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
        optimizer.zero_grad()
        total_loss = cls_loss + loc_loss
        if total_loss.item() == float("Inf") or torch.isnan(total_loss):
            continue
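        # apex amp scales the loss for mixed precision and restores the
        # gradients to the correct scale before optimizer.step().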
        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if local_rank == 0:
            elapsed_time = time.time() - start_time
            estimated_time = elapsed_time * dataset_len / (batch_idx + 1)
            # log per iter
            print(
                title + tools.format_dict_of_loss(log),
                "|",
                batch_idx + 1,
                "/",
                dataset_len,
                "| Time:",
                timedelta(seconds=int(elapsed_time)),
                "/",
                timedelta(seconds=int(estimated_time)),
                "\r",
                end="",
            )
            sys.stdout.flush()
    if local_rank == 0:
        log = {"lr": optimizer.param_groups[0]["lr"]}
        log.update({k: v.avg for k, v in loss_writer.items()})
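        # Move the cursor up and erase the last progress line before printing
        # the epoch summary.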
        print(
            CURSOR_UP_ONE
            + ERASE_LINE
            + "===>Avg Train: "
            + tools.format_dict_of_loss(log),
            " | Time: ",
            timedelta(seconds=int(time.time() - start_time)),
        )
        # log for tensorboard
        for key, value in log.items():
            writer.add_scalar("Train/{}".format(key), value, epoch)
        targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]
        vsf.add_imagesWithBoxes(
            writer, "Train Image", images[:5], targets[:5], epoch=epoch
        )
    return
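

# A minimal usage sketch: the detector is wrapped together with its criteria so
# the loss is computed inside ``forward``, then the wrapper is handed to apex amp
# and DistributedDataParallel before running epochs. The names `detector`,
# `cls_criterion`, `loc_criterion`, `optimizer`, `anchors`, `train_loader`,
# `writer`, `num_classes`, `match`, and `center_sampling_radius` below are
# hypothetical placeholders, not objects defined in this module:
#
#   model = ModelWithLossBasic(
#       detector, cls_criterion, loc_criterion,
#       num_classes, match, center_sampling_radius,
#   ).to(device)
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#   model = torch.nn.parallel.DistributedDataParallel(
#       model, device_ids=[local_rank], output_device=local_rank
#   )
#   for epoch in range(num_epochs):
#       train_anchor_based_epoch(
#           model, train_loader, optimizer, anchors, writer, epoch, device, local_rank
#       )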