"""This file is similar to the efficientnet file, but uses torch hub (or geffnet
when an exportable model is requested) to build the model.
"""
import torch
import torch.nn as nn
from .rutils import register
class EffNet(nn.Module):
    """EfficientNet backbone that returns the outputs of selected blocks.

    The underlying network is created either through ``torch.hub`` (from
    ``rwightman/gen-efficientnet-pytorch``) or, when ``exportable`` is true,
    through the ``geffnet`` package after enabling its exportable config.
    Only the stem (``conv_stem``, ``bn1``, ``act1``) and the first seven
    entries of ``model.blocks`` are kept; the blocks are registered as
    ``block1`` .. ``block7``.

    Args:
        model_name: Name of the model entry point, e.g. ``"efficientnet_b0"``.
        outputs: Collection of 1-based block levels whose outputs
            ``forward`` returns.
        exportable: If true, build the model with ``geffnet`` after calling
            ``geffnet.config.set_exportable(True)``.
        **kwargs: Forwarded unchanged to the model constructor.
    """

    # Number of blocks copied from the source model and walked by ``forward``.
    _NUM_BLOCKS = 7

    def __init__(self, model_name, outputs, exportable=False, **kwargs):
        super().__init__()
        self.outputs = outputs
        if exportable:
            # Imported lazily so geffnet is only required for this path.
            import geffnet

            geffnet.config.set_exportable(True)
            model = geffnet.create_model(model_name, **kwargs)
        else:
            model = torch.hub.load(
                "rwightman/gen-efficientnet-pytorch", model_name, **kwargs
            )
        self.conv_stem = model.conv_stem
        self.bn1 = model.bn1
        self.act1 = model.act1
        # Re-register blocks under 1-based names: model.blocks "0" -> "block1".
        for index in range(self._NUM_BLOCKS):
            self.add_module(
                "block{}".format(index + 1),
                getattr(model.blocks, "{}".format(index)),
            )

    def forward(self, x):
        """Run the stem and blocks, collecting the requested block outputs.

        Args:
            x: Input passed to ``conv_stem``.

        Returns:
            A list with the output of every block whose level is in
            ``self.outputs``, in increasing level order.  Blocks past the
            highest requested level are not executed.  An empty ``outputs``
            yields an empty list.
        """
        # max() of an empty collection raises; nothing was requested anyway.
        if len(self.outputs) == 0:
            return []
        deepest = max(self.outputs)  # loop-invariant, computed once
        x = self.act1(self.bn1(self.conv_stem(x)))
        collected = []
        for level in range(1, self._NUM_BLOCKS + 1):
            if level > deepest:
                break
            x = getattr(self, "block{}".format(level))(x)
            if level in self.outputs:
                collected.append(x)
        return collected

    def initialize(self):
        """Do nothing; present as an explicit no-op hook."""
        pass
@register
def EffNetB0(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b0``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b0", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB1(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b1``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b1", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB2(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b2``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b2", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB3(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b3``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b3", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB4(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b4``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b4", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB5(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b5``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b5", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB6(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b6``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b6", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB7(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b7``.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2`` and
        ``pretrained=True``.
    """
    return EffNet("efficientnet_b7", outputs, drop_connect_rate=0.2, pretrained=True)
@register
def EffNetB0Ex(outputs, **kwargs):
    """Build an ``EffNet`` around ``efficientnet_b0`` via the exportable path.

    Args:
        outputs: Block levels whose outputs ``EffNet.forward`` returns.
        **kwargs: Accepted for signature compatibility; currently ignored.

    Returns:
        An ``EffNet`` created with ``drop_connect_rate=0.2``,
        ``pretrained=True`` and ``exportable=True`` (the model is built
        with ``geffnet`` instead of ``torch.hub``).
    """
    return EffNet(
        "efficientnet_b0",
        outputs,
        drop_connect_rate=0.2,
        pretrained=True,
        exportable=True,
    )