# Source code for ssds.modeling.nets.effnet

""" This file is similar to the efficientnet file, but loads the model through torch hub
(or through the geffnet package when an exportable model is requested).
"""

import torch
import torch.nn as nn
from .rutils import register


class EffNet(nn.Module):
    """Feature extractor built from the stem and blocks of an EfficientNet model.

    The source model is created by ``geffnet.create_model`` when ``exportable``
    is true, otherwise by ``torch.hub.load`` from
    ``rwightman/gen-efficientnet-pytorch``. Its ``conv_stem``, ``bn1`` and
    ``act1`` layers are reused as-is, and the first ``NUM_BLOCKS`` entries of
    ``model.blocks`` are registered on this module as ``block1`` ...
    ``block7``.

    Args:
        model_name: Model name handed to the factory (e.g. ``"efficientnet_b0"``).
        outputs: Collection of 1-based block levels whose outputs ``forward``
            should return.
        exportable: When true, import ``geffnet``, call
            ``geffnet.config.set_exportable(True)`` and build the model with
            ``geffnet.create_model`` instead of ``torch.hub.load``.
        **kwargs: Forwarded unchanged to the model factory.
    """

    # Number of entries of ``model.blocks`` re-registered as block1..blockN.
    NUM_BLOCKS = 7

    def __init__(self, model_name, outputs, exportable=False, **kwargs):
        super(EffNet, self).__init__()
        self.outputs = outputs

        if exportable:
            # Imported lazily so geffnet is only required for exportable models.
            import geffnet

            geffnet.config.set_exportable(True)
            model = geffnet.create_model(model_name, **kwargs)
        else:
            model = torch.hub.load(
                "rwightman/gen-efficientnet-pytorch", model_name, **kwargs
            )

        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1
        self.act1 = model.act1
        # ``model.blocks`` children are named "0", "1", ...; expose them here
        # under 1-based names so that levels in ``outputs`` map to ``block<level>``.
        for index in range(self.NUM_BLOCKS):
            self.add_module(
                "block{}".format(index + 1), getattr(model.blocks, str(index))
            )

    def forward(self, x):
        """Run the stem and blocks, collecting the requested block outputs.

        Args:
            x: Input passed to ``conv_stem``.

        Returns:
            list: Outputs of the blocks whose level is in ``self.outputs``, in
            increasing level order. Empty when ``self.outputs`` is empty or
            names no level in ``1..NUM_BLOCKS``.
        """
        x = self.act1(self.bn1(self.conv_stem(x)))

        # Computed once up front; ``default=0`` makes an empty ``outputs``
        # yield an empty result instead of raising ValueError from ``max``.
        deepest = max(self.outputs, default=0)

        features = []
        for level in range(1, self.NUM_BLOCKS + 1):
            # Blocks past the deepest requested level are never needed.
            if level > deepest:
                break
            x = getattr(self, "block{}".format(level))(x)
            if level in self.outputs:
                features.append(x)

        return features

    def initialize(self):
        """Do nothing; this class performs no weight initialisation of its own."""
        pass


@register
def EffNetB0(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b0`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b0", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB1(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b1`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b1", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB2(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b2`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b2", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB3(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b3`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b3", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB4(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b4`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b4", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB5(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b5`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b5", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB6(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b6`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b6", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB7(outputs, **kwargs):
    """Create an ``EffNet`` for ``efficientnet_b7`` with ``pretrained=True``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2`` and ``pretrained=True``.
    """
    return EffNet(
        "efficientnet_b7", outputs, drop_connect_rate=0.2, pretrained=True
    )
@register
def EffNetB0Ex(outputs, **kwargs):
    """Create an exportable ``EffNet`` for ``efficientnet_b0`` with ``pretrained=True``.

    Identical arguments to the non-exportable B0 call except that
    ``exportable=True`` is passed to ``EffNet``.

    Args:
        outputs: Passed through to ``EffNet`` as its ``outputs`` argument.
        **kwargs: Accepted but unused; they are not forwarded to ``EffNet``.

    Returns:
        EffNet: Built with ``drop_connect_rate=0.2``, ``pretrained=True`` and
        ``exportable=True``.
    """
    return EffNet(
        "efficientnet_b0",
        outputs,
        drop_connect_rate=0.2,
        pretrained=True,
        exportable=True,
    )