Source code for ssds.modeling.nets.shufflenet

import torchvision
from torchvision.models import shufflenetv2
import torch.utils.model_zoo as model_zoo
from .rutils import register


class ShuffleNetV2(shufflenetv2.ShuffleNetV2):
    """ShuffleNetV2 backbone that returns intermediate stage feature maps.

    Construction is delegated to torchvision's ``ShuffleNetV2``; only the
    forward pass is replaced so that it yields the activations of selected
    stages instead of a single result.

    Stages are numbered by "level": ``stage2`` is level 2, ``stage3`` is
    level 3 and ``stage4`` is level 4.

    Args:
        stages_repeats: Forwarded unchanged to the torchvision constructor.
        stages_out_channels: Forwarded unchanged to the torchvision
            constructor.
        outputs: Collection of levels whose feature maps ``forward`` should
            return. Defaults to ``[4]`` when omitted.
        url: Optional URL of a state dict to load in ``initialize``.
    """

    def __init__(self, stages_repeats, stages_out_channels, outputs=None, url=None):
        super().__init__(stages_repeats, stages_out_channels)
        # Build the default per instance: a literal ``[4]`` in the signature
        # would be one list object shared (and mutable) across all instances.
        self.outputs = [4] if outputs is None else outputs
        self.url = url

    def initialize(self):
        """Load weights from ``self.url`` if one was given; otherwise do nothing."""
        if self.url:
            self.load_state_dict(model_zoo.load_url(self.url))

    def forward(self, x):
        """Run the stem and stages, collecting the requested feature maps.

        Args:
            x: Input passed to ``conv1``.

        Returns:
            A list with the output of every stage whose level is in
            ``self.outputs``, ordered from the lowest level to the highest.
            The list is empty when no requested level is in the 2..4 range
            (including when ``self.outputs`` is empty).
        """
        x = self.maxpool(self.conv1(x))

        # Highest level that is needed; stages beyond it are never executed.
        # ``default=0`` turns an empty ``outputs`` into "run no stage" rather
        # than an opaque ``max()`` ValueError.
        last_level = max(self.outputs, default=0)

        features = []
        stages = (self.stage2, self.stage3, self.stage4)
        for level, stage in enumerate(stages, start=2):
            if level > last_level:
                break
            x = stage(x)
            if level in self.outputs:
                features.append(x)
        return features


@register
def ShuffleNetV2_x1(outputs, **kwargs):
    """Build the x1.0 ShuffleNetV2 backbone.

    Args:
        outputs: Stage levels whose feature maps the backbone should return.
        **kwargs: Accepted and ignored.

    Returns:
        A ``ShuffleNetV2`` with stage repeats ``[4, 8, 4]``, channel widths
        ``[24, 116, 232, 464, 1024]`` and the ``shufflenetv2_x1.0`` weight URL
        attached for a later ``initialize()`` call.
    """
    # NOTE(review): ``shufflenetv2.model_urls`` may not exist in newer
    # torchvision releases -- confirm against the pinned version.
    return ShuffleNetV2(
        [4, 8, 4],
        [24, 116, 232, 464, 1024],
        outputs=outputs,
        url=shufflenetv2.model_urls["shufflenetv2_x1.0"],
    )
@register
def ShuffleNetV2_x2(outputs, **kwargs):
    """Build the x2.0 ShuffleNetV2 backbone.

    No weight URL is supplied, so ``initialize()`` on the returned model
    loads nothing.

    Args:
        outputs: Stage levels whose feature maps the backbone should return.
        **kwargs: Accepted and ignored.

    Returns:
        A ``ShuffleNetV2`` with stage repeats ``[4, 8, 4]`` and channel widths
        ``[24, 244, 488, 976, 2048]``.
    """
    return ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], outputs=outputs)