import torchvision
from torchvision.models import shufflenetv2
import torch.utils.model_zoo as model_zoo
from .rutils import register
class ShuffleNetV2(shufflenetv2.ShuffleNetV2):
    """ShuffleNetV2 backbone that returns intermediate stage feature maps.

    Subclasses torchvision's ``ShuffleNetV2`` and overrides ``forward`` so that,
    instead of running the full classifier, it returns the outputs of the
    stages whose level numbers appear in ``outputs``. Level 2 is ``stage2``,
    level 3 is ``stage3`` and level 4 is ``stage4``. The inherited ``conv5``
    and ``fc`` layers are never run by ``forward``.

    Args:
        stages_repeats: Per-stage block counts, forwarded unchanged to the
            torchvision constructor.
        stages_out_channels: Per-stage channel widths, forwarded unchanged to
            the torchvision constructor.
        outputs: Iterable of stage levels (2, 3 and/or 4) whose feature maps
            ``forward`` should return. Defaults to ``[4]`` (last stage only).
        url: Optional URL of a state dict that :meth:`initialize` loads.
    """

    def __init__(self, stages_repeats, stages_out_channels, outputs=None, url=None):
        super().__init__(stages_repeats, stages_out_channels)
        # Take a private copy, so that neither a shared default list nor the
        # caller's list is aliased by this module.
        self.outputs = [4] if outputs is None else list(outputs)
        self.url = url

    def initialize(self):
        """Load weights from ``self.url`` if a URL was supplied; otherwise do nothing."""
        if self.url:
            self.load_state_dict(model_zoo.load_url(self.url))

    def forward(self, x):
        """Run the stem and the requested stages, collecting feature maps.

        Args:
            x: Input batch passed to ``conv1``. NOTE(review): presumably
                (N, 3, H, W) images as torchvision expects; confirm.

        Returns:
            list: Feature maps for the requested levels, in ascending level
            order (not the order given in ``outputs``). The list is empty if
            ``outputs`` is empty.
        """
        x = self.maxpool(self.conv1(x))
        # Compute this once, outside the loop. With default=1, an empty
        # selection stops before any stage runs instead of raising ValueError.
        last_level = max(self.outputs, default=1)
        features = []
        stages = (self.stage2, self.stage3, self.stage4)
        for level, stage in enumerate(stages, start=2):
            if level > last_level:
                break  # No deeper stage is requested, so skip the extra compute.
            x = stage(x)
            if level in self.outputs:
                features.append(x)
        return features
@register
def ShuffleNetV2_x1(outputs, **kwargs):
    """Build a 1.0x-width ShuffleNetV2 backbone.

    The model is configured with torchvision's ``shufflenetv2_x1.0`` weights
    URL. The weights are only downloaded and loaded when the returned model's
    ``initialize()`` is called.

    Args:
        outputs: Stage levels (2-4) whose feature maps ``forward`` returns.
        **kwargs: Accepted for a uniform factory signature; ignored.

    Returns:
        ShuffleNetV2: The configured backbone.
    """
    # NOTE(review): ``shufflenetv2.model_urls`` is not present in newer
    # torchvision releases (which use weights enums); confirm the pinned version.
    return ShuffleNetV2(
        [4, 8, 4],
        [24, 116, 232, 464, 1024],
        outputs=outputs,
        url=shufflenetv2.model_urls["shufflenetv2_x1.0"],
    )
@register
def ShuffleNetV2_x2(outputs, **kwargs):
    """Build a 2.0x-width ShuffleNetV2 backbone.

    No pretrained weights URL is configured, so ``initialize()`` on the
    returned model leaves its weights as initialised by the constructor.

    Args:
        outputs: Stage levels (2-4) whose feature maps ``forward`` returns.
        **kwargs: Accepted for a uniform factory signature; ignored.

    Returns:
        ShuffleNetV2: The configured backbone.
    """
    return ShuffleNetV2(
        [4, 8, 4],
        [24, 244, 488, 976, 2048],
        outputs=outputs,
    )