import torch
from .box import decode, nms
class Decoder(object):
    r"""
    Decode per-level predictions into boxes and filter them with NMS.

    class Decoder contains the decoder func and nms func

    * decoder
        decoder is used to decode the boxes from loc and conf feature map, check :meth:`ssds.modeling.layers.box.decode` for more details.
    * nms
        nms is used to filter the decoded boxes by its confidence and box location, check :meth:`ssds.modeling.layers.box.nms` for more details.

    Args:
        conf_threshold: confidence threshold forwarded to ``decode``.
        nms_threshold: threshold forwarded to ``nms``.
        top_n: limit forwarded to ``nms``.
        top_n_per_level: per-level limit forwarded to ``decode``.
        rescore: forwarded to ``decode`` as its ``rescore`` keyword.
        use_diou: forwarded to ``nms`` as its ``using_diou`` keyword.
    """

    def __init__(
        self, conf_threshold, nms_threshold, top_n, top_n_per_level, rescore, use_diou
    ):
        # The settings are only stored here; they are consumed in __call__.
        self.conf_threshold = conf_threshold
        self.nms_threshold = nms_threshold
        self.top_n = top_n
        self.top_n_per_level = top_n_per_level
        self.rescore = rescore
        self.use_diou = use_diou

    def __call__(self, loc, conf, anchors):
        r"""
        Decode and filter boxes

        Args:
            loc: iterable of per-level location predictions, one entry per level.
            conf: iterable of per-level confidence predictions, one entry per level.
            anchors: mapping whose items are ``(stride, anchor)`` pairs, iterated
                in the same order as ``loc`` and ``conf``.

        Returns:
            The result of ``nms``:
            out_scores,  (batch, top_n)
            out_boxes,   (batch, top_n, 4) with ltrb format
            out_classes, (batch, top_n)
        """
        # NOTE: zip stops at the shortest input, so a mismatch between the number
        # of loc/conf levels and anchor entries is silently truncated.
        per_level = [
            decode(
                level_conf,
                level_loc,
                stride,
                self.conf_threshold,
                self.top_n_per_level,
                anchor,
                rescore=self.rescore,
            )
            for level_loc, level_conf, (stride, anchor) in zip(
                loc, conf, anchors.items()
            )
        ]
        # Transpose "list of per-level tuples" into "tuple of per-output lists"
        # and join every output across levels along dim 1.
        merged = [torch.cat(tensors, 1) for tensors in zip(*per_level)]
        return nms(*merged, self.nms_threshold, self.top_n, using_diou=self.use_diou)