# Source code for ssds.modeling.nets.mobilenet

import torch
import torch.nn as nn
from torchvision.models import mobilenet
import torch.utils.model_zoo as model_zoo
from .rutils import register


class SepConvBNReLU(nn.Sequential):
    """Depthwise-separable convolution unit.

    Runs a per-channel (depthwise) ``kernel_size`` convolution followed by a
    1x1 (pointwise) convolution; each convolution is bias-free and is followed
    by BatchNorm and ReLU6.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, expand_ratio=1):
        """
        Args:
            in_planes (int): Number of input channels
            out_planes (int): Number of output channels
            kernel_size (int): Kernel size of the depthwise convolution
            stride (int): Stride of the depthwise convolution
            expand_ratio (int): Accepted but not used by this block
        """
        # "same"-style padding for odd kernel sizes
        same_pad = (kernel_size - 1) // 2

        # dw: one filter group per input channel
        depthwise = nn.Conv2d(
            in_planes,
            in_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_pad,
            groups=in_planes,
            bias=False,
        )
        depthwise_bn = nn.BatchNorm2d(in_planes)

        # pw: 1x1 convolution that mixes channels
        pointwise = nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False
        )
        pointwise_bn = nn.BatchNorm2d(out_planes)

        # Positional order fixes the sub-module indices (0..5) used as state-dict keys.
        stages = (
            depthwise,
            depthwise_bn,
            nn.ReLU6(inplace=True),
            pointwise,
            pointwise_bn,
            nn.ReLU6(inplace=True),
        )
        super().__init__(*stages)


class MobileNet(nn.Module):
    """MobileNet V1 / V2 classification network.

    The network is a stem convolution (``conv1``), a sequence of stages
    registered as ``layer1`` .. ``layerN``, an extra 1x1 ``head_conv`` for V2
    only, and a dropout + linear ``classifier``.
    """

    def __init__(self, num_classes=1000, width_mult=1.0, version="v1", round_nearest=8):
        """
        MobileNet V1 / V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            version (str): "v1" builds stages from SepConvBNReLU blocks,
                "v2" builds them from torchvision's InvertedResidual blocks
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        Raises:
            ValueError: If ``version`` is neither "v1" nor "v2"
        """
        # Fail early with a clear message; otherwise an unknown version falls
        # through both branches below and surfaces as an UnboundLocalError.
        if version not in ("v1", "v2"):
            raise ValueError(
                "unsupported MobileNet version {!r}; expected 'v1' or 'v2'".format(version)
            )

        super(MobileNet, self).__init__()

        input_channel = 32
        # Each settings row is (t, c, n, s):
        #   t - expand ratio passed to the block
        #   c - output channels of the stage (before width_mult / rounding)
        #   n - number of blocks in the stage
        #   s - stride of the first block in the stage (the rest use 1)
        if version == "v2":
            settings = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
            last_channel = 1280
            layer = mobilenet.InvertedResidual
        else:  # version == "v1"
            settings = [
                # t, c, n, s
                [1, 64, 1, 1],
                [1, 128, 2, 2],
                [1, 256, 2, 2],
                [1, 512, 6, 2],
                [1, 1024, 2, 2],
            ]
            last_channel = 1024
            layer = SepConvBNReLU
        self.settings = settings
        self.version = version

        # building first layer
        input_channel = mobilenet._make_divisible(
            input_channel * width_mult, round_nearest
        )
        self.last_channel = mobilenet._make_divisible(
            last_channel * max(1.0, width_mult), round_nearest
        )
        self.conv1 = mobilenet.ConvBNReLU(3, input_channel, stride=2)
        # building the stages (inverted residual blocks for v2, separable convs for v1)
        for j, (t, c, n, s) in enumerate(settings):
            output_channel = mobilenet._make_divisible(c * width_mult, round_nearest)
            layers = []
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(
                    layer(input_channel, output_channel, stride=stride, expand_ratio=t)
                )
                input_channel = output_channel
            self.add_module("layer{}".format(j + 1), nn.Sequential(*layers))
        # building last several layers
        if self.version == "v2":
            self.head_conv = mobilenet.ConvBNReLU(
                input_channel, self.last_channel, kernel_size=1
            )
        else:
            # v1 has no head conv, so the classifier is fed the last stage's
            # output directly. Use that stage's real width: for width_mult < 1
            # it is smaller than the un-shrunk value computed above (for
            # width_mult >= 1 both values are the same).
            self.last_channel = input_channel

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the stem, every stage, the v2 head conv (if any), global
        average pooling and the classifier; return the classifier output."""
        x = self.conv1(x)
        for level in range(1, len(self.settings) + 1):
            x = getattr(self, "layer{}".format(level))(x)
        if self.version == "v2":
            x = self.head_conv(x)
        # pool to 1x1 and flatten to (batch, channels)
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x


class MobileNetEx(MobileNet):
    """MobileNet variant that returns intermediate stage feature maps.

    ``forward`` returns a list with the output of every ``layer{N}`` stage
    whose 1-based index ``N`` is listed in ``outputs``; the head conv and
    classifier are not run.

    Args:
        width_mult (float): Width multiplier forwarded to ``MobileNet``
        version (str): "v1" or "v2", forwarded to ``MobileNet``
        outputs (iterable of int): 1-based stage indices to return
        url (str or None): Checkpoint URL used by ``initialize``; when falsy,
            ``initialize`` does nothing
    """

    def __init__(self, width_mult=1.0, version="v1", outputs=(7,), url=None):
        super(MobileNetEx, self).__init__(width_mult=width_mult, version=version)
        self.url = url
        # Copy so instances never share (or mutate) the caller's / default sequence.
        self.outputs = list(outputs)

    @staticmethod
    def _rename_keys(checkpoint, change_dict):
        """Rename, in place, each checkpoint key that contains a ``change_dict`` pattern."""
        # Iterate over a snapshot because the mapping is mutated while renaming.
        for key in list(checkpoint):
            for old, new in change_dict.items():
                if old in key:
                    checkpoint[key.replace(old, new)] = checkpoint.pop(key)
                    # The key is gone now; a second match would pop a missing key.
                    break

    def initialize(self):
        """Load the checkpoint at ``self.url`` (if set) into this model.

        Checkpoint keys are renamed from the ``features.*`` layout to this
        class's ``conv1`` / ``layer{N}`` / ``head_conv`` layout. For v1, the
        checkpoint's ``classifier.`` entries are dropped and any key missing
        from the checkpoint keeps the model's current value.
        """
        if not self.url:
            return

        checkpoint = model_zoo.load_url(self.url)
        if self.version == "v2":
            # features.0 is the stem; features.1.. are the blocks in order;
            # the entry after the last block is the head conv.
            change_dict = {"features.0.": "conv1."}
            f_idx = 1
            for j, (t, c, n, s) in enumerate(self.settings):
                for i in range(n):
                    change_dict[
                        "features.{}.".format(f_idx)
                    ] = "layer{}.{}.".format(j + 1, i)
                    f_idx += 1
            change_dict["features.{}.".format(f_idx)] = "head_conv."
            self._rename_keys(checkpoint, change_dict)
        else:
            change_dict = {"features.Conv2d_0.conv.": "conv1."}
            f_idx = 1
            for j, (t, c, n, s) in enumerate(self.settings):
                for i in range(n):
                    # Sub-modules 0/1 of "depthwise" map to indices 0/1 of the
                    # block, and sub-modules 0/1 of "pointwise" to indices 3/4.
                    for z in range(2):
                        change_dict[
                            "features.Conv2d_{}.depthwise.{}".format(f_idx, z)
                        ] = "layer{}.{}.{}".format(j + 1, i, z)
                        change_dict[
                            "features.Conv2d_{}.pointwise.{}".format(f_idx, z)
                        ] = "layer{}.{}.{}".format(j + 1, i, z + 3)
                    f_idx += 1
            self._rename_keys(checkpoint, change_dict)

            # Drop the checkpoint's classifier entries so the model keeps its own.
            remove_dict = ["classifier."]
            for k in list(checkpoint):
                if any(_k in k for _k in remove_dict):
                    checkpoint.pop(k)

            # Start from the model's current state so keys absent from the
            # checkpoint do not make load_state_dict fail.
            org_checkpoint = self.state_dict()
            org_checkpoint.update(checkpoint)
            checkpoint = org_checkpoint

        self.load_state_dict(checkpoint)

    def forward(self, x):
        """Return the feature maps of the requested stages, in stage order."""
        x = self.conv1(x)

        wanted = self.outputs
        # Stop after the deepest requested stage; with nothing requested no
        # stage is run and the result is an empty list.
        deepest = min(max(wanted, default=0), len(self.settings))

        outputs = []
        for level in range(1, deepest + 1):  # only 1 conv before
            x = getattr(self, "layer{}".format(level))(x)
            if level in wanted:
                outputs.append(x)

        return outputs


@register
def MobileNetV1(outputs, **kwargs):
    """Build a width-1.0 MobileNet V1 feature extractor.

    Args:
        outputs: 1-based stage indices passed to ``MobileNetEx`` as ``outputs``
        **kwargs: Accepted for call compatibility; not used

    Returns:
        MobileNetEx: Model configured with the V1 checkpoint URL; the weights
        are not loaded by this function.
    """
    return MobileNetEx(
        width_mult=1.0,
        version="v1",
        outputs=outputs,
        url="https://www.dropbox.com/s/kygo8l6dwah3djv/mobilenet_v1_1.0_224.pth?dl=1",
    )
@register
def MobileNetV2(outputs, **kwargs):
    """Build a width-1.0 MobileNet V2 feature extractor.

    Args:
        outputs: 1-based stage indices passed to ``MobileNetEx`` as ``outputs``
        **kwargs: Accepted for call compatibility; not used

    Returns:
        MobileNetEx: Model configured with torchvision's ``mobilenet_v2``
        checkpoint URL; the weights are not loaded by this function.
    """
    # NOTE(review): relies on the imported torchvision ``mobilenet`` module
    # exposing ``model_urls`` - confirm against the installed torchvision version.
    return MobileNetEx(
        width_mult=1.0,
        version="v2",
        outputs=outputs,
        url=mobilenet.model_urls["mobilenet_v2"],
    )