from tqdm import tqdm
import torch
import ssds.core.tools as tools
import ssds.core.visualize_funcs as vsf
from ssds.core.evaluation_metrics import MeanAveragePrecision
from ssds.modeling.layers.box import extract_targets
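# ANSI escape sequences used to overwrite the tqdm progress line with the epoch summary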
CURSOR_UP_ONE = "\x1b[1A"
ERASE_LINE = "\x1b[2K"
def train_anchor_based_epoch(
    model,
    data_loader,
    optimizer,
    cls_criterion,
    loc_criterion,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    writer,
    epoch,
    device,
):
    r""" the pipeline for training
    """
    model.train()
    title = "Train: "
    progress = tqdm(
        tools.IteratorTimer(data_loader),
        total=len(data_loader),
        smoothing=0.9,
        miniters=1,
        leave=True,
        desc=title,
    )
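    # running averages: one meter per aggregate loss and, below, one per anchor level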
    loss_writer = {"loc_loss": tools.AverageMeter(), "cls_loss": tools.AverageMeter()}
    loss_writer.update(
        {
            "loc_loss_{}".format(j): tools.AverageMeter()
            for j, _ in enumerate(anchors.items())
        }
    )
    loss_writer.update(
        {
            "cls_loss_{}".format(j): tools.AverageMeter()
            for j, _ in enumerate(anchors.items())
        }
    )
    for batch_idx, (images, targets) in enumerate(progress):
        if images.device != device:
            images, targets = images.to(device), targets.to(device)
        if targets.dtype != torch.float:
            targets = targets.float()
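        # forward pass: the detector returns per-level loc/conf outputs, one per anchor stride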
        loc, conf = model(images)
        cls_losses, loc_losses, fg_targets = [], [], []
        for j, (stride, anchor) in enumerate(anchors.items()):
            size = conf[j].shape[-2:]
            conf_target, loc_target, depth = extract_targets(
                targets,
                anchors,
                num_classes,
                stride,
                size,
                match,
                center_sampling_radius,
            )
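            # `depth` appears to encode the per-anchor match state:
            # < 0 -> ignored, 0 -> background, > 0 -> matched to a ground-truth box;
            # clamp(min=1) below guards against dividing by zero on levels with no positives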
            fg_targets.append((depth > 0).sum().float().clamp(min=1))
            c = conf[j].view_as(conf_target).float()
            cls_mask = (depth >= 0).expand_as(conf_target).float()
            cls_loss = cls_criterion(c, conf_target, depth)
            cls_loss = cls_mask * cls_loss
            cls_losses.append(cls_loss.sum())
            l = loc[j].view_as(loc_target).float()
            loc_loss = loc_criterion(l, loc_target)
            loc_mask = (depth > 0).expand_as(loc_loss).float()
            loc_loss = loc_mask * loc_loss
            loc_losses.append(loc_loss.sum())
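            # skip only the per-level logging on NaN; the batch-level check below
            # decides whether the whole optimization step is dropped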
            if torch.isnan(loc_loss.sum()) or torch.isnan(cls_loss.sum()):
                continue
            loss_writer["cls_loss_{}".format(j)].update(cls_losses[-1].item())
            loss_writer["loc_loss_{}".format(j)].update(loc_losses[-1].item())
        fg_targets = torch.stack(fg_targets).sum()
        cls_loss = torch.stack(cls_losses).sum() / fg_targets
        loc_loss = torch.stack(loc_losses).sum() / fg_targets
        if torch.isnan(loc_loss) or torch.isnan(cls_loss):
            continue
        loss_writer["cls_loss"].update(cls_loss.item())
        loss_writer["loc_loss"].update(loc_loss.item())
        log = {
            "cls_loss": cls_loss.item(),
            "loc_loss": loc_loss.item(),
            "lr": optimizer.param_groups[0]["lr"],
        }
        optimizer.zero_grad()
        total_loss = cls_loss + loc_loss
        # drop steps whose loss has overflowed to infinity
        if torch.isinf(total_loss):
            continue
        total_loss.backward()
        optimizer.step()
        # per-iteration logging; iterating over the tqdm object already advances
        # the bar, so no explicit progress.update() call is needed
        progress.set_description(title + tools.format_dict_of_loss(log))
    progress.close()
    log = {"lr": optimizer.param_groups[0]["lr"]}
    log.update({k: v.avg for k, v in loss_writer.items()})
    print(
        CURSOR_UP_ONE + ERASE_LINE + "===>Avg Train: " + tools.format_dict_of_loss(log)
    )
    # log for tensorboard
    for key, value in log.items():
        writer.add_scalar("Train/{}".format(key), value, epoch)
    targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]  # from xywh to ltrb
    vsf.add_imagesWithBoxes(writer, "Train Image", images[:5], targets[:5], epoch=epoch)
    return
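# A minimal usage sketch for the epoch loop above; `cfg` and the `build_*`
# helpers are hypothetical placeholders, not part of this module:
#
#     model, optimizer = build_model(cfg), build_optimizer(cfg)
#     for epoch in range(cfg.num_epochs):
#         train_anchor_based_epoch(
#             model, train_loader, optimizer, cls_criterion, loc_criterion,
#             anchors, num_classes, match, center_sampling_radius,
#             writer, epoch, device,
#         )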
def eval_anchor_based_epoch(
    model,
    data_loader,
    decoder,
    cls_criterion,
    loc_criterion,
    anchors,
    num_classes,
    writer,
    epoch,
    device,
):
    r""" the pipeline for evaluation
    """
    model.eval()
    title = "Eval: "
    progress = tqdm(
        tools.IteratorTimer(data_loader),
        total=len(data_loader),
        smoothing=0.9,
        miniters=1,
        leave=True,
        desc=title,
    )
    metric = MeanAveragePrecision(
        num_classes, decoder.conf_threshold, decoder.nms_threshold
    )
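    # the metric accumulates matches batch by batch and is reduced once after the loop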
    for batch_idx, (images, targets) in enumerate(progress):
        if images.device != device:
            images, targets = images.to(device), targets.to(device)
        if targets.dtype != torch.float:
            targets = targets.float()
        loc, conf = model(images)
        # no loss is computed here: conf is already passed through a sigmoid at
        # evaluation time, so the training classification loss is no longer meaningful
        detections = decoder(loc, conf, anchors)
        targets[:, :, 2:4] = targets[:, :, :2] + targets[:, :, 2:4]  # from xywh to ltrb
        metric(detections, targets)
        # iterating over the tqdm object already advances the progress bar per batch
    progress.close()
    mAP, (prec, rec, ap) = metric.get_results()
    log = {"mAP": mAP}
    if len(ap) < 5:
        for i, a in enumerate(ap):
            log["AP@cls{}".format(i)] = a
    print(
        CURSOR_UP_ONE + ERASE_LINE + "===>Avg Eval: " + tools.format_dict_of_loss(log)
    )
    # log for tensorboard
    for key, value in log.items():
        writer.add_scalar("Eval/{}".format(key), value, epoch)
    vsf.add_prCurve(writer, prec, rec, epoch=epoch)
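    # detections[0] holds per-box scores and detections[1] the boxes; concatenating
    # them lets the visualization draw each box with its confidence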
    boxes = torch.cat((detections[1], detections[0][..., None]), dim=2)
    vsf.add_imagesWithMatchedBoxes(
        writer, "Eval Image", images[:5], boxes[:5], targets[:5], epoch=epoch
    )
    return
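# A minimal self-contained sketch (not part of the original pipeline) of the
# in-place xywh -> ltrb conversion that both epoch functions apply to `targets`
# before visualization; e.g. a (10, 20, 30, 40) box becomes (10, 20, 40, 60):
def _xywh_to_ltrb(boxes):
    r"""Convert (x, y, w, h, ...) boxes to (left, top, right, bottom, ...) in place."""
    boxes[:, :, 2:4] = boxes[:, :, :2] + boxes[:, :, 2:4]
    return boxes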