import torch
# from ssds._C import decode as decode_cuda
# from ssds._C import nms as nms_cuda
INF = 100000
[docs]def generate_anchors(stride, ratio_vals, scales_vals):
    "Generate anchors coordinates from scales/ratios"
    scales = torch.FloatTensor(scales_vals).repeat(len(ratio_vals), 1)
    scales = scales.transpose(0, 1).contiguous().view(-1, 1)
    ratios = torch.FloatTensor(ratio_vals * len(scales_vals))
    wh = torch.FloatTensor([stride]).repeat(len(ratios), 2)
    ws = torch.round(torch.sqrt(wh[:, 0] * wh[:, 1] / ratios))
    dwh = torch.stack([ws, torch.round(ws * ratios)], dim=1)
    xy1 = 0.5 * (wh - dwh * scales)
    xy2 = 0.5 * (wh + dwh * scales) - 1
    return torch.cat([xy1, xy2], dim=1) 
[docs]def box2delta(boxes, anchors):
    "Convert boxes to deltas from anchors"
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
    boxes_wh = boxes[:, 2:] - boxes[:, :2] + 1
    boxes_ctr = boxes[:, :2] + 0.5 * boxes_wh
    return torch.cat(
        [(boxes_ctr - anchors_ctr) / anchors_wh, torch.log(boxes_wh / anchors_wh)], 1
    ) 
[docs]def delta2box(deltas, anchors, size, stride):
    "Convert deltas from anchors to boxes"
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    ctr = anchors[:, :2] + 0.5 * anchors_wh
    pred_ctr = deltas[:, :2] * anchors_wh + ctr
    pred_wh = torch.exp(deltas[:, 2:]) * anchors_wh
    m = torch.zeros([2], device=deltas.device, dtype=deltas.dtype)
    M = torch.tensor([size], device=deltas.device, dtype=deltas.dtype) * stride - 1
    clamp = lambda t: torch.max(m, torch.min(t, M))
    return torch.cat(
        [clamp(pred_ctr - 0.5 * pred_wh), clamp(pred_ctr + 0.5 * pred_wh - 1)], 1
    ) 
[docs]def get_sample_region(boxes, stride, anchor_points, radius=1.5):
    """
    This code is from
    https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
    maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
    """
    # get mins and maxs value for center boarder
    stride = stride * radius
    center = (boxes[:, :2] + boxes[:, 2:]) / 2
    center_boxes = torch.cat((center - stride, center + stride), dim=-1)
    # generate the difference between grid points and center boarder
    # to check whether it is located in the center areas
    lt = (
        anchor_points[:, :, None, :]
        - torch.max(center_boxes[:, :2], boxes[:, :2])[None, None, :]
    )
    rb = (
        torch.min(center_boxes[:, 2:], boxes[:, 2:])[None, None, :]
        - anchor_points[:, :, None, :]
    )
    center_boxes = torch.cat((lt, rb), -1)
    inside_boxes_mask = center_boxes.min(-1)[0] > 0
    return inside_boxes_mask 
[docs]def snap_to_anchors_by_iou(
    boxes,
    size,
    stride,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    is_centerness,
    device,
):
    "Snap target boxes (x, y, w, h) to anchors by the iou between target boxes and anchors"
    num_anchors = anchors.size()[0] if anchors is not None else 1
    width, height = (int(size[0] / stride), int(size[1] / stride))
    if boxes.nelement() == 0:
        if is_centerness:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
        else:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
    boxes, classes = boxes.split(4, dim=1)
    match_threshold, unmatch_threshold = match
    # Generate anchors
    x, y = torch.meshgrid(
        [
            torch.arange(0, size[i], stride, device=device, dtype=classes.dtype)
            for i in range(2)
        ]
    )
    xyxy = torch.stack((x, y, x, y), 2).unsqueeze(0)
    anchors = anchors.view(-1, 1, 1, 4).to(dtype=classes.dtype)
    anchors = (xyxy + anchors).contiguous().view(-1, 4)
    # Compute overlap between boxes and anchors
    boxes = torch.cat([boxes[:, :2], boxes[:, :2] + boxes[:, 2:] - 1], 1)
    xy1 = torch.max(anchors[:, None, :2], boxes[:, :2])
    xy2 = torch.min(anchors[:, None, 2:], boxes[:, 2:])
    inter = torch.prod((xy2 - xy1 + 1).clamp(0), 2)
    boxes_area = torch.prod(boxes[:, 2:] - boxes[:, :2] + 1, 1)
    anchors_area = torch.prod(anchors[:, 2:] - anchors[:, :2] + 1, 1)
    overlap = inter / (anchors_area[:, None] + boxes_area - inter)
    # Keep best box per anchor
    overlap, indices = overlap.max(1)
    box_target = box2delta(boxes[indices], anchors)
    box_target = box_target.view(num_anchors, 1, width, height, 4)
    box_target = box_target.transpose(1, 4).transpose(2, 3)
    box_target = box_target.squeeze().contiguous()
    depth = torch.ones_like(overlap) * -1
    depth[overlap < unmatch_threshold] = 0  # background
    depth[overlap >= match_threshold] = (
        classes[indices][overlap >= match_threshold].squeeze() + 1
    )  # objects
    depth = depth.view(num_anchors, width, height)
    # center_sampling in ATSS
    if center_sampling_radius > 0:
        anchor_points = torch.stack((x, y), dim=2) + stride // 2
        inside_boxes_mask = (
            get_sample_region(boxes, stride, anchor_points, center_sampling_radius)
            .float()
            .max(-1)[0]
        )
        depth = torch.min(depth, inside_boxes_mask[None, ...])
    depth = depth.transpose(1, 2).contiguous()
    # Generate target classes
    cls_target = torch.zeros(
        (anchors.size()[0], num_classes + 1), device=device, dtype=boxes.dtype
    )
    if classes.nelement() == 0:
        classes = torch.LongTensor([num_classes], device=device).expand_as(indices)
    else:
        classes = classes[indices].long()
    classes = classes.view(-1, 1)
    classes[overlap < unmatch_threshold] = num_classes  # background has no class
    cls_target.scatter_(1, classes, 1)
    cls_target = cls_target[:, :num_classes].view(-1, 1, width, height, num_classes)
    cls_target = cls_target.transpose(1, 4).transpose(2, 3)
    cls_target = cls_target.squeeze().contiguous()
    if is_centerness:
        lt = torch.abs(box_target[:, :2] - 0.5 * torch.exp(box_target[:, 2:]))
        rb = torch.abs(box_target[:, :2] - 0.5 * torch.exp(box_target[:, 2:]))
        centerness = torch.sqrt(
            torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
        )
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            centerness.view(num_anchors, 1, height, width),
            depth.view(num_anchors, 1, height, width),
        )
    else:
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            depth.view(num_anchors, 1, height, width),
        ) 
[docs]def snap_to_anchors_by_scale(
    boxes,
    size,
    stride,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    is_centerness,
    device,
):
    "Snap target boxes (x, y, w, h) to anchors by the scale of target boxes"
    num_anchors = anchors.size()[0] if anchors is not None else 1
    width, height = (int(size[0] / stride), int(size[1] / stride))
    if boxes.nelement() == 0:
        if is_centerness:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
        else:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
    boxes, classes = boxes.split(4, dim=1)
    # generate threshold for each anchor
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    anchors_size = torch.sqrt(torch.prod(anchors_wh, dim=1)).unsqueeze(1).unsqueeze(2)
    lower_threshold = (match[0] * anchors_size).clamp(-1)
    upper_threshold = match[1] * anchors_size
    # Generate anchors
    x, y = torch.meshgrid(
        [
            torch.arange(0, size[i], stride, device=device, dtype=classes.dtype)
            for i in range(2)
        ]
    )
    xyxy = torch.stack((x, y, x, y), 2).unsqueeze(0)
    anchors = anchors.view(-1, 1, 1, 4).to(dtype=classes.dtype)
    anchors = (xyxy + anchors).contiguous().view(-1, 4)
    anchor_points = (torch.stack((x, y), dim=2) + stride // 2).to(dtype=classes.dtype)
    # Compute overlap between boxes and anchors
    boxes = torch.cat([boxes[:, :2], boxes[:, :2] + boxes[:, 2:] - 1], 1)
    boxes_area = torch.sqrt(torch.prod(boxes[:, 2:] - boxes[:, :2] + 1, 1))
    # limit the positive sample anchor points inside of the box or center box
    if center_sampling_radius > 0:
        # limit the box size range for each location
        is_cared_in_the_level = (boxes_area >= lower_threshold) & (
            boxes_area <= upper_threshold
        )
        # center_sampling in ATSS
        anchor_points = torch.stack((x, y), dim=2) + stride // 2
        inside_boxes_mask = get_sample_region(boxes, stride, anchor_points).view(
            -1, boxes.shape[0]
        )
    else:
        anchor_points = (torch.stack((x, y), dim=2) + stride // 2).view(-1, 2)
        lt = anchor_points[:, None, :] - boxes[:, :2]
        rb = boxes[:, 2:] - anchor_points[:, None, :]
        box_target = torch.cat([lt, rb], dim=-1)
        # limit the regression range for each location
        max_box_target = box_target.max(dim=-1)[0]
        is_cared_in_the_level = (max_box_target >= lower_threshold) & (
            max_box_target <= upper_threshold
        )
        # no center sampling, it will use all the points within a ground-truth box
        inside_boxes_mask = box_target.min(dim=-1)[0] > 0
    # if there are still more than one objects for a location,
    # we choose the one with minimal area
    mask = (is_cared_in_the_level & inside_boxes_mask).view(-1, boxes.shape[0])
    boxes_area = boxes_area.repeat(mask.shape[0], 1)
    boxes_area[mask == 0] = INF
    mask, _ = mask.max(dim=1)
    min_area, indices = boxes_area.min(dim=1)
    # Keep best box per anchor
    box_target = box2delta(boxes[indices], anchors)
    box_target = box_target.view(num_anchors, 1, width, height, 4)
    box_target = box_target.transpose(1, 4).transpose(2, 3)
    box_target = box_target.squeeze().contiguous()
    depth = torch.ones_like(mask, dtype=classes.dtype) * -1
    depth[mask == 0] = 0  # background
    depth[mask != 0] = classes[indices][mask != 0].squeeze() + 1  # objects
    depth = depth.view(num_anchors, width, height).transpose(1, 2).contiguous()
    # Generate target classes
    cls_target = torch.zeros(
        (anchors.size()[0], num_classes + 1), device=device, dtype=boxes.dtype
    )
    if classes.nelement() == 0:
        classes = torch.LongTensor([num_classes], device=device).expand_as(indices)
    else:
        classes = classes[indices].long()
    classes = classes.view(-1, 1)
    classes[mask == 0] = num_classes  # background has no class
    cls_target.scatter_(1, classes, 1)
    cls_target = cls_target[:, :num_classes].view(-1, 1, width, height, num_classes)
    cls_target = cls_target.transpose(1, 4).transpose(2, 3)
    cls_target = cls_target.squeeze().contiguous()
    if is_centerness:
        lt = torch.abs(box_target[:, :2] - 0.5 * torch.exp(deltas[:, 2:]))
        rb = torch.abs(box_target[:, :2] - 0.5 * torch.exp(deltas[:, 2:]))
        centerness = torch.sqrt(
            torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
        )
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            centerness.view(num_anchors, 1, height, width),
            depth.view(num_anchors, 1, height, width),
        )
    else:
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            depth.view(num_anchors, 1, height, width),
        ) 
[docs]def decode(
    all_cls_head,
    all_box_head,
    stride=1,
    threshold=0.05,
    top_n=1000,
    anchors=None,
    rescore=True,
):
    "Box Decoding and Filtering"
    # if torch.cuda.is_available():
    #     return decode_cuda(all_cls_head.float(), all_box_head.float(),
    #         anchors.view(-1).tolist(), stride, threshold, top_n)
    device = all_cls_head.device
    anchors = anchors.to(device).type(all_cls_head.type())
    num_anchors = anchors.size()[0] if anchors is not None else 1
    num_classes = all_cls_head.size()[1] // num_anchors
    height, width = all_cls_head.size()[-2:]
    batch_size = all_cls_head.size()[0]
    out_scores = torch.zeros((batch_size, top_n), device=device)
    out_boxes = torch.zeros((batch_size, top_n, 4), device=device)
    out_classes = torch.zeros((batch_size, top_n), device=device)
    # Per item in batch
    for batch in range(batch_size):
        cls_head = all_cls_head[batch, :, :, :].contiguous().view(-1)
        box_head = all_box_head[batch, :, :, :].contiguous().view(-1, 4)
        # Keep scores over threshold
        keep = (cls_head >= threshold).nonzero().view(-1)
        if keep.nelement() == 0:
            continue
        # Gather top elements
        scores = torch.index_select(cls_head, 0, keep)
        scores, indices = torch.topk(scores, min(top_n, keep.size()[0]), dim=0)
        indices = torch.index_select(keep, 0, indices).view(-1)
        classes = (indices / width / height) % num_classes
        classes = classes.type(all_cls_head.type())
        # Infer kept bboxes
        x = indices % width
        y = (indices / width) % height
        a = indices / num_classes / height / width
        box_head = box_head.view(num_anchors, 4, height, width)
        boxes = box_head[a, :, y, x]
        if anchors is not None:
            grid = (
                torch.stack([x, y, x, y], 1).type(all_cls_head.type()) * stride
                + anchors[a, :]
            )
            boxes = delta2box(boxes, grid, [width, height], stride)
            if rescore:
                grid_center = (grid[:, :2] + grid[:, 2:]) / 2
                lt = torch.abs(grid_center - boxes[:, :2])
                rb = torch.abs(boxes[:, 2:] - grid_center)
                centerness = torch.sqrt(
                    torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
                )
                scores = scores * centerness
        out_scores[batch, : scores.size()[0]] = scores
        out_boxes[batch, : boxes.size()[0], :] = boxes
        out_classes[batch, : classes.size()[0]] = classes
    return out_scores, out_boxes, out_classes 
[docs]def nms(all_scores, all_boxes, all_classes, nms=0.5, ndetections=100, using_diou=True):
    "Non Maximum Suppression"
    # if torch.cuda.is_available():
    #     return nms_cuda(
    #         all_scores.float(), all_boxes.float(), all_classes.float(), nms, ndetections)
    device = all_scores.device
    batch_size = all_scores.size()[0]
    out_scores = torch.zeros((batch_size, ndetections), device=device)
    out_boxes = torch.zeros((batch_size, ndetections, 4), device=device)
    out_classes = torch.zeros((batch_size, ndetections), device=device)
    # Per item in batch
    for batch in range(batch_size):
        # Discard null scores
        keep = (all_scores[batch, :].view(-1) > 0).nonzero()
        scores = all_scores[batch, keep].view(-1)
        boxes = all_boxes[batch, keep, :].view(-1, 4)
        classes = all_classes[batch, keep].view(-1)
        if scores.nelement() == 0:
            continue
        # Sort boxes
        scores, indices = torch.sort(scores, descending=True)
        boxes, classes = boxes[indices], classes[indices]
        areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1).view(
            -1
        )
        keep = torch.ones(scores.nelement(), device=device, dtype=torch.uint8).view(-1)
        for i in range(ndetections):
            if i >= keep.nonzero().nelement() or i >= scores.nelement():
                i -= 1
                break
            # Find overlapping boxes with lower score
            xy1 = torch.max(boxes[:, :2], boxes[i, :2])
            xy2 = torch.min(boxes[:, 2:], boxes[i, 2:])
            inter = torch.prod((xy2 - xy1 + 1).clamp(0), 1)
            iou = inter / (areas + areas[i] - inter + 1e-7)
            if using_diou:
                outer_lt = torch.min(boxes[:, :2], boxes[i, :2])
                outer_rb = torch.max(boxes[:, 2:], boxes[i, 2:])
                inter_diag = ((boxes[:, :2] - boxes[i, :2]) ** 2).sum(dim=1)
                outer_diag = ((outer_rb - outer_lt) ** 2).sum(dim=1) + 1e-7
                diou = (iou - inter_diag / outer_diag).clamp(-1.0, 1.0)
                iou = diou
            criterion = (scores > scores[i]) | (iou <= nms) | (classes != classes[i])
            criterion[i] = 1
            # Only keep relevant boxes
            scores = scores[criterion.nonzero()].view(-1)
            boxes = boxes[criterion.nonzero(), :].view(-1, 4)
            classes = classes[criterion.nonzero()].view(-1)
            areas = areas[criterion.nonzero()].view(-1)
            keep[(~criterion).nonzero()] = 0
        out_scores[batch, : i + 1] = scores[: i + 1]
        out_boxes[batch, : i + 1, :] = boxes[: i + 1, :]
        out_classes[batch, : i + 1] = classes[: i + 1]
    return out_scores, out_boxes, out_classes