import torch
# from ssds._C import decode as decode_cuda
# from ssds._C import nms as nms_cuda
INF = 100000

def generate_anchors(stride, ratio_vals, scales_vals):
    "Generate anchors coordinates from scales/ratios"
    scales = torch.FloatTensor(scales_vals).repeat(len(ratio_vals), 1)
    scales = scales.transpose(0, 1).contiguous().view(-1, 1)
    ratios = torch.FloatTensor(ratio_vals * len(scales_vals))

    wh = torch.FloatTensor([stride]).repeat(len(ratios), 2)
    ws = torch.round(torch.sqrt(wh[:, 0] * wh[:, 1] / ratios))
    dwh = torch.stack([ws, torch.round(ws * ratios)], dim=1)
    xy1 = 0.5 * (wh - dwh * scales)
    xy2 = 0.5 * (wh + dwh * scales) - 1
    return torch.cat([xy1, xy2], dim=1)
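
# Illustrative usage sketch (not part of the original module): generate_anchors builds
# one (x1, y1, x2, y2) template per ratio/scale pair, laid out around a single
# stride-sized cell; callers below tile these templates over the feature map.
# The stride/ratio/scale values here are assumptions chosen only for demonstration.
#
#   anchors = generate_anchors(stride=8,
#                              ratio_vals=[0.5, 1.0, 2.0],
#                              scales_vals=[4 * 2 ** (i / 3) for i in range(3)])
#   anchors.shape  # torch.Size([9, 4]): 3 ratios x 3 scales templates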

def box2delta(boxes, anchors):
    "Convert boxes to deltas from anchors"
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
    boxes_wh = boxes[:, 2:] - boxes[:, :2] + 1
    boxes_ctr = boxes[:, :2] + 0.5 * boxes_wh
    return torch.cat(
        [(boxes_ctr - anchors_ctr) / anchors_wh, torch.log(boxes_wh / anchors_wh)], 1
    )
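
# Illustrative sketch (assumption, for clarity only): box2delta matches boxes and
# anchors row by row, so both inputs are (N, 4) in (x1, y1, x2, y2) form and the
# output is (N, 4) deltas: centre offsets scaled by the anchor size, then log size ratios.
#
#   boxes   = torch.tensor([[20., 20., 59., 59.]])
#   anchors = torch.tensor([[ 0.,  0., 63., 63.]])
#   box2delta(boxes, anchors)  # [[0.125, 0.125, log(40/64), log(40/64)]]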

def delta2box(deltas, anchors, size, stride):
    "Convert deltas from anchors to boxes"
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    ctr = anchors[:, :2] + 0.5 * anchors_wh
    pred_ctr = deltas[:, :2] * anchors_wh + ctr
    pred_wh = torch.exp(deltas[:, 2:]) * anchors_wh

    m = torch.zeros([2], device=deltas.device, dtype=deltas.dtype)
    M = torch.tensor([size], device=deltas.device, dtype=deltas.dtype) * stride - 1
    clamp = lambda t: torch.max(m, torch.min(t, M))
    return torch.cat(
        [clamp(pred_ctr - 0.5 * pred_wh), clamp(pred_ctr + 0.5 * pred_wh - 1)], 1
    )
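
# Illustrative sketch (assumption): delta2box inverts box2delta and clips the result
# to [0, size * stride - 1], so boxes that already lie inside the image round-trip
# exactly. The sizes below are made up for the example.
#
#   anchors = torch.tensor([[0., 0., 31., 31.], [16., 16., 79., 79.]])
#   boxes   = torch.tensor([[4., 4., 27., 27.], [20.,  8., 70., 60.]])
#   deltas  = box2delta(boxes, anchors)
#   delta2box(deltas, anchors, size=[80, 80], stride=1)  # ~= boxes (up to float rounding)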

def get_sample_region(boxes, stride, anchor_points, radius=1.5):
    """
    This code is from
    https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
    maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
    """
    # get the min and max values of the center border
    stride = stride * radius
    center = (boxes[:, :2] + boxes[:, 2:]) / 2
    center_boxes = torch.cat((center - stride, center + stride), dim=-1)
    # compute the offsets between the grid points and the center border
    # to check whether each point lies inside the center area
    lt = (
        anchor_points[:, :, None, :]
        - torch.max(center_boxes[:, :2], boxes[:, :2])[None, None, :]
    )
    rb = (
        torch.min(center_boxes[:, 2:], boxes[:, 2:])[None, None, :]
        - anchor_points[:, :, None, :]
    )
    center_boxes = torch.cat((lt, rb), -1)
    inside_boxes_mask = center_boxes.min(-1)[0] > 0
    return inside_boxes_mask
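
# Illustrative sketch (assumption): get_sample_region flags which grid points fall in
# the "center region" of each ground-truth box, i.e. a radius * stride window around
# the box centre intersected with the box itself.
#
#   boxes  = torch.tensor([[0., 0., 63., 63.]])                 # (N, 4) in xyxy
#   xs, ys = torch.meshgrid([torch.arange(0., 128., 8.)] * 2)
#   points = torch.stack((xs, ys), dim=2) + 8 // 2              # (16, 16, 2) grid centres
#   mask = get_sample_region(boxes, stride=8, anchor_points=points, radius=1.5)
#   mask.shape  # (16, 16, 1): True where a grid point lies in box 0's centre area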

def snap_to_anchors_by_iou(
    boxes,
    size,
    stride,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    is_centerness,
    device,
):
    "Snap target boxes (x, y, w, h) to anchors by the iou between target boxes and anchors"
    num_anchors = anchors.size()[0] if anchors is not None else 1
    width, height = (int(size[0] / stride), int(size[1] / stride))

    if boxes.nelement() == 0:
        if is_centerness:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
        else:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
    boxes, classes = boxes.split(4, dim=1)
    match_threshold, unmatch_threshold = match

    # Generate anchors
    x, y = torch.meshgrid(
        [
            torch.arange(0, size[i], stride, device=device, dtype=classes.dtype)
            for i in range(2)
        ]
    )
    xyxy = torch.stack((x, y, x, y), 2).unsqueeze(0)
    anchors = anchors.view(-1, 1, 1, 4).to(dtype=classes.dtype)
    anchors = (xyxy + anchors).contiguous().view(-1, 4)

    # Compute overlap between boxes and anchors
    boxes = torch.cat([boxes[:, :2], boxes[:, :2] + boxes[:, 2:] - 1], 1)
    xy1 = torch.max(anchors[:, None, :2], boxes[:, :2])
    xy2 = torch.min(anchors[:, None, 2:], boxes[:, 2:])
    inter = torch.prod((xy2 - xy1 + 1).clamp(0), 2)
    boxes_area = torch.prod(boxes[:, 2:] - boxes[:, :2] + 1, 1)
    anchors_area = torch.prod(anchors[:, 2:] - anchors[:, :2] + 1, 1)
    overlap = inter / (anchors_area[:, None] + boxes_area - inter)

    # Keep best box per anchor
    overlap, indices = overlap.max(1)
    box_target = box2delta(boxes[indices], anchors)
    box_target = box_target.view(num_anchors, 1, width, height, 4)
    box_target = box_target.transpose(1, 4).transpose(2, 3)
    box_target = box_target.squeeze().contiguous()

    depth = torch.ones_like(overlap) * -1
    depth[overlap < unmatch_threshold] = 0  # background
    depth[overlap >= match_threshold] = (
        classes[indices][overlap >= match_threshold].squeeze() + 1
    )  # objects
    depth = depth.view(num_anchors, width, height)

    # center_sampling in ATSS
    if center_sampling_radius > 0:
        anchor_points = torch.stack((x, y), dim=2) + stride // 2
        inside_boxes_mask = (
            get_sample_region(boxes, stride, anchor_points, center_sampling_radius)
            .float()
            .max(-1)[0]
        )
        depth = torch.min(depth, inside_boxes_mask[None, ...])
    depth = depth.transpose(1, 2).contiguous()

    # Generate target classes
    cls_target = torch.zeros(
        (anchors.size()[0], num_classes + 1), device=device, dtype=boxes.dtype
    )
    if classes.nelement() == 0:
        classes = torch.tensor([num_classes], device=device, dtype=torch.long).expand_as(
            indices
        )
    else:
        classes = classes[indices].long()
    classes = classes.view(-1, 1)
    classes[overlap < unmatch_threshold] = num_classes  # background has no class
    cls_target.scatter_(1, classes, 1)
    cls_target = cls_target[:, :num_classes].view(-1, 1, width, height, num_classes)
    cls_target = cls_target.transpose(1, 4).transpose(2, 3)
    cls_target = cls_target.squeeze().contiguous()
    if is_centerness:
        # distances from the anchor centre to the left/top and right/bottom box borders,
        # expressed in delta units
        lt = torch.abs(box_target[:, :2] - 0.5 * torch.exp(box_target[:, 2:]))
        rb = torch.abs(box_target[:, :2] + 0.5 * torch.exp(box_target[:, 2:]))
        centerness = torch.sqrt(
            torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
        )
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            centerness.view(num_anchors, 1, height, width),
            depth.view(num_anchors, 1, height, width),
        )
    else:
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            depth.view(num_anchors, 1, height, width),
        )
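
# Illustrative sketch (assumption, based on the signature above): how a training
# pipeline might build targets for one feature level with IoU matching. The image
# size, stride, thresholds and box layout (x, y, w, h, class id) are assumptions here.
#
#   anchors = generate_anchors(stride=8, ratio_vals=[0.5, 1.0, 2.0], scales_vals=[4.0])
#   targets = torch.tensor([[12., 20., 40., 48., 3.]])   # one box: x, y, w, h, class id
#   cls_t, box_t, depth = snap_to_anchors_by_iou(
#       targets, size=(256, 256), stride=8, anchors=anchors, num_classes=80,
#       match=(0.5, 0.4), center_sampling_radius=0, is_centerness=False,
#       device=targets.device)
#   # cls_t: (3, 80, 32, 32), box_t: (3, 4, 32, 32), depth: (3, 1, 32, 32)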

def snap_to_anchors_by_scale(
    boxes,
    size,
    stride,
    anchors,
    num_classes,
    match,
    center_sampling_radius,
    is_centerness,
    device,
):
    "Snap target boxes (x, y, w, h) to anchors by the scale of target boxes"
    num_anchors = anchors.size()[0] if anchors is not None else 1
    width, height = (int(size[0] / stride), int(size[1] / stride))

    if boxes.nelement() == 0:
        if is_centerness:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
        else:
            return (
                torch.zeros([num_anchors, num_classes, height, width], device=device),
                torch.zeros([num_anchors, 4, height, width], device=device),
                torch.zeros([num_anchors, 1, height, width], device=device),
            )
    boxes, classes = boxes.split(4, dim=1)

    # generate threshold for each anchor
    anchors_wh = anchors[:, 2:] - anchors[:, :2] + 1
    anchors_size = torch.sqrt(torch.prod(anchors_wh, dim=1)).unsqueeze(1).unsqueeze(2)
    lower_threshold = (match[0] * anchors_size).clamp(-1)
    upper_threshold = match[1] * anchors_size

    # Generate anchors
    x, y = torch.meshgrid(
        [
            torch.arange(0, size[i], stride, device=device, dtype=classes.dtype)
            for i in range(2)
        ]
    )
    xyxy = torch.stack((x, y, x, y), 2).unsqueeze(0)
    anchors = anchors.view(-1, 1, 1, 4).to(dtype=classes.dtype)
    anchors = (xyxy + anchors).contiguous().view(-1, 4)
    anchor_points = (torch.stack((x, y), dim=2) + stride // 2).to(dtype=classes.dtype)

    # Compute overlap between boxes and anchors
    boxes = torch.cat([boxes[:, :2], boxes[:, :2] + boxes[:, 2:] - 1], 1)
    boxes_area = torch.sqrt(torch.prod(boxes[:, 2:] - boxes[:, :2] + 1, 1))
    # limit the positive sample anchor points inside of the box or center box
    if center_sampling_radius > 0:
        # limit the box size range for each location
        is_cared_in_the_level = (boxes_area >= lower_threshold) & (
            boxes_area <= upper_threshold
        )
        # center_sampling in ATSS
        anchor_points = torch.stack((x, y), dim=2) + stride // 2
        inside_boxes_mask = get_sample_region(
            boxes, stride, anchor_points, center_sampling_radius
        ).view(-1, boxes.shape[0])
    else:
        anchor_points = (torch.stack((x, y), dim=2) + stride // 2).view(-1, 2)
        lt = anchor_points[:, None, :] - boxes[:, :2]
        rb = boxes[:, 2:] - anchor_points[:, None, :]
        box_target = torch.cat([lt, rb], dim=-1)
        # limit the regression range for each location
        max_box_target = box_target.max(dim=-1)[0]
        is_cared_in_the_level = (max_box_target >= lower_threshold) & (
            max_box_target <= upper_threshold
        )
        # no center sampling, it will use all the points within a ground-truth box
        inside_boxes_mask = box_target.min(dim=-1)[0] > 0
    # if there is still more than one object at a location,
    # we choose the one with the smallest area
    mask = (is_cared_in_the_level & inside_boxes_mask).view(-1, boxes.shape[0])
    boxes_area = boxes_area.repeat(mask.shape[0], 1)
    boxes_area[mask == 0] = INF
    mask, _ = mask.max(dim=1)
    min_area, indices = boxes_area.min(dim=1)
    # Keep best box per anchor
    box_target = box2delta(boxes[indices], anchors)
    box_target = box_target.view(num_anchors, 1, width, height, 4)
    box_target = box_target.transpose(1, 4).transpose(2, 3)
    box_target = box_target.squeeze().contiguous()

    depth = torch.ones_like(mask, dtype=classes.dtype) * -1
    depth[mask == 0] = 0  # background
    depth[mask != 0] = classes[indices][mask != 0].squeeze() + 1  # objects
    depth = depth.view(num_anchors, width, height).transpose(1, 2).contiguous()
    # Generate target classes
    cls_target = torch.zeros(
        (anchors.size()[0], num_classes + 1), device=device, dtype=boxes.dtype
    )
    if classes.nelement() == 0:
        classes = torch.tensor([num_classes], device=device, dtype=torch.long).expand_as(
            indices
        )
    else:
        classes = classes[indices].long()
    classes = classes.view(-1, 1)
    classes[mask == 0] = num_classes  # background has no class
    cls_target.scatter_(1, classes, 1)
    cls_target = cls_target[:, :num_classes].view(-1, 1, width, height, num_classes)
    cls_target = cls_target.transpose(1, 4).transpose(2, 3)
    cls_target = cls_target.squeeze().contiguous()
    if is_centerness:
        # distances from the anchor centre to the left/top and right/bottom box borders,
        # expressed in delta units
        lt = torch.abs(box_target[:, :2] - 0.5 * torch.exp(box_target[:, 2:]))
        rb = torch.abs(box_target[:, :2] + 0.5 * torch.exp(box_target[:, 2:]))
        centerness = torch.sqrt(
            torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
        )
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            centerness.view(num_anchors, 1, height, width),
            depth.view(num_anchors, 1, height, width),
        )
    else:
        return (
            cls_target.view(num_anchors, num_classes, height, width),
            box_target.view(num_anchors, 4, height, width),
            depth.view(num_anchors, 1, height, width),
        )
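
# Illustrative sketch (assumption): the scale-based variant assigns a location to a
# box by size range (FCOS-style) rather than by IoU, so `match` is read as a pair of
# multipliers on the anchor size (lower, upper). A rough call for one level:
#
#   anchors = generate_anchors(stride=8, ratio_vals=[1.0], scales_vals=[4.0])
#   targets = torch.tensor([[12., 20., 40., 48., 3.]])   # x, y, w, h, class id
#   cls_t, box_t, depth = snap_to_anchors_by_scale(
#       targets, size=(256, 256), stride=8, anchors=anchors, num_classes=80,
#       match=(0.0, 2.0), center_sampling_radius=0, is_centerness=False,
#       device=targets.device)
#   # same output layout as snap_to_anchors_by_iou: (1, 80, 32, 32), (1, 4, 32, 32), (1, 1, 32, 32)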

def decode(
    all_cls_head,
    all_box_head,
    stride=1,
    threshold=0.05,
    top_n=1000,
    anchors=None,
    rescore=True,
):
    "Box Decoding and Filtering"
    # if torch.cuda.is_available():
    #     return decode_cuda(all_cls_head.float(), all_box_head.float(),
    #         anchors.view(-1).tolist(), stride, threshold, top_n)
    device = all_cls_head.device
    anchors = anchors.to(device).type(all_cls_head.type())
    num_anchors = anchors.size()[0] if anchors is not None else 1
    num_classes = all_cls_head.size()[1] // num_anchors
    height, width = all_cls_head.size()[-2:]
    batch_size = all_cls_head.size()[0]

    out_scores = torch.zeros((batch_size, top_n), device=device)
    out_boxes = torch.zeros((batch_size, top_n, 4), device=device)
    out_classes = torch.zeros((batch_size, top_n), device=device)
for batch in range(batch_size):
cls_head = all_cls_head[batch, :, :, :].contiguous().view(-1)
box_head = all_box_head[batch, :, :, :].contiguous().view(-1, 4)
# Keep scores over threshold
keep = (cls_head >= threshold).nonzero().view(-1)
if keep.nelement() == 0:
continue
# Gather top elements
scores = torch.index_select(cls_head, 0, keep)
scores, indices = torch.topk(scores, min(top_n, keep.size()[0]), dim=0)
indices = torch.index_select(keep, 0, indices).view(-1)
classes = (indices / width / height) % num_classes
classes = classes.type(all_cls_head.type())
# Infer kept bboxes
x = indices % width
y = (indices / width) % height
a = indices / num_classes / height / width
box_head = box_head.view(num_anchors, 4, height, width)
boxes = box_head[a, :, y, x]
        if anchors is not None:
            grid = (
                torch.stack([x, y, x, y], 1).type(all_cls_head.type()) * stride
                + anchors[a, :]
            )
            boxes = delta2box(boxes, grid, [width, height], stride)
            if rescore:
                grid_center = (grid[:, :2] + grid[:, 2:]) / 2
                lt = torch.abs(grid_center - boxes[:, :2])
                rb = torch.abs(boxes[:, 2:] - grid_center)
                centerness = torch.sqrt(
                    torch.prod(torch.min(lt, rb) / torch.max(lt, rb), dim=1)
                )
                scores = scores * centerness

        out_scores[batch, : scores.size()[0]] = scores
        out_boxes[batch, : boxes.size()[0], :] = boxes
        out_classes[batch, : classes.size()[0]] = classes

    return out_scores, out_boxes, out_classes
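
# Illustrative sketch (assumption): decode is applied per feature level to the raw
# head outputs. With A anchors and C classes, the class head is (B, A * C, H, W) and
# the box head is (B, A * 4, H, W); the shapes below are made up for the example.
#
#   anchors = generate_anchors(stride=8, ratio_vals=[0.5, 1.0, 2.0], scales_vals=[4.0])
#   cls_head = torch.rand(2, 3 * 80, 32, 32)   # B=2, A=3, C=80
#   box_head = torch.randn(2, 3 * 4, 32, 32)
#   scores, boxes, classes = decode(cls_head, box_head, stride=8,
#                                   threshold=0.05, top_n=1000, anchors=anchors)
#   # scores: (2, 1000), boxes: (2, 1000, 4), classes: (2, 1000), zero-padded per image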

def nms(all_scores, all_boxes, all_classes, nms=0.5, ndetections=100, using_diou=True):
    "Non Maximum Suppression"
    # if torch.cuda.is_available():
    #     return nms_cuda(
    #         all_scores.float(), all_boxes.float(), all_classes.float(), nms, ndetections)
    device = all_scores.device
    batch_size = all_scores.size()[0]
    out_scores = torch.zeros((batch_size, ndetections), device=device)
    out_boxes = torch.zeros((batch_size, ndetections, 4), device=device)
    out_classes = torch.zeros((batch_size, ndetections), device=device)

    # Per item in batch
    for batch in range(batch_size):
        # Discard null scores
        keep = (all_scores[batch, :].view(-1) > 0).nonzero()
        scores = all_scores[batch, keep].view(-1)
        boxes = all_boxes[batch, keep, :].view(-1, 4)
        classes = all_classes[batch, keep].view(-1)
        if scores.nelement() == 0:
            continue

        # Sort boxes
        scores, indices = torch.sort(scores, descending=True)
        boxes, classes = boxes[indices], classes[indices]
        areas = (
            (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
        ).view(-1)
        keep = torch.ones(scores.nelement(), device=device, dtype=torch.uint8).view(-1)
        for i in range(ndetections):
            if i >= keep.nonzero().nelement() or i >= scores.nelement():
                i -= 1
                break

            # Find overlapping boxes with lower score
            xy1 = torch.max(boxes[:, :2], boxes[i, :2])
            xy2 = torch.min(boxes[:, 2:], boxes[i, 2:])
            inter = torch.prod((xy2 - xy1 + 1).clamp(0), 1)
            iou = inter / (areas + areas[i] - inter + 1e-7)
            if using_diou:
                outer_lt = torch.min(boxes[:, :2], boxes[i, :2])
                outer_rb = torch.max(boxes[:, 2:], boxes[i, 2:])
                inter_diag = ((boxes[:, :2] - boxes[i, :2]) ** 2).sum(dim=1)
                outer_diag = ((outer_rb - outer_lt) ** 2).sum(dim=1) + 1e-7
                diou = (iou - inter_diag / outer_diag).clamp(-1.0, 1.0)
                iou = diou
            criterion = (scores > scores[i]) | (iou <= nms) | (classes != classes[i])
            criterion[i] = 1

            # Only keep relevant boxes
            scores = scores[criterion.nonzero()].view(-1)
            boxes = boxes[criterion.nonzero(), :].view(-1, 4)
            classes = classes[criterion.nonzero()].view(-1)
            areas = areas[criterion.nonzero()].view(-1)
            keep[(~criterion).nonzero()] = 0

        out_scores[batch, : i + 1] = scores[: i + 1]
        out_boxes[batch, : i + 1, :] = boxes[: i + 1, :]
        out_classes[batch, : i + 1] = classes[: i + 1]

    return out_scores, out_boxes, out_classes
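
# Illustrative sketch (assumption): nms is chained after decode, once the per-level
# detections have been gathered. The names below follow the decode example above.
#
#   scores, boxes, classes = decode(cls_head, box_head, stride=8, anchors=anchors)
#   scores, boxes, classes = nms(scores, boxes, classes, nms=0.5, ndetections=100)
#   # each output is padded/truncated to ndetections entries per image;
#   # with using_diou=True the overlap test uses DIoU instead of plain IoU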