import re
import torch
import torch.nn as nn
from torchvision.models import densenet
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
from .rutils import register
class DenseNet(nn.Module):
    """DenseNet feature extractor that returns selected dense-block outputs.

    The network is a stem (``conv1``) followed by ``len(block_config)`` dense
    blocks, with a transition module between consecutive blocks. There is no
    classifier head: ``forward`` returns the outputs of the dense blocks whose
    1-based level is listed in ``outputs``.

    Args:
        growth_rate: ``growth_rate`` passed to every dense block.
        block_config: number of layers in each dense block.
        num_init_features: number of channels produced by the stem convolution.
        bn_size: ``bn_size`` passed to every dense block.
        drop_rate: ``drop_rate`` passed to every dense block.
        memory_efficient: ``memory_efficient`` passed to every dense block.
        outputs: 1-based dense-block levels whose outputs ``forward`` returns.
            ``None`` (the default) means no levels are requested.
        url: optional checkpoint URL consumed by :meth:`initialize`.
    """

    # Matches dense-layer parameter keys spelled like "...norm.1.weight" so
    # that they can be rewritten to "...norm1.weight" (the separating dot is
    # dropped).
    _LEGACY_KEY_PATTERN = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    # Checkpoint entries whose key contains one of these fragments have no
    # counterpart in this module and are discarded before loading.
    _DROPPED_KEY_FRAGMENTS = ("classifier.", "features.norm5.")

    def __init__(
        self,
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        bn_size=4,
        drop_rate=0,
        memory_efficient=False,
        outputs=None,
        url=None,
    ):
        super(DenseNet, self).__init__()
        self.url = url
        # Use a fresh list per instance; a shared ``[]` default would be the
        # same object on every model built without an explicit ``outputs``.
        self.outputs = [] if outputs is None else outputs
        self.block_config = block_config

        # Stem: 7x7 stride-2 convolution on 3-channel input, batch norm, ReLU
        # and a 3x3 stride-2 max pool.
        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv",
                        nn.Conv2d(
                            3,
                            num_init_features,
                            kernel_size=7,
                            stride=2,
                            padding=3,
                            bias=False,
                        ),
                    ),
                    ("norm", nn.BatchNorm2d(num_init_features)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("pool", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Dense blocks, each followed by a channel-halving transition except
        # after the last block.
        num_features = num_init_features
        last_index = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            block = densenet._DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != last_index:
                trans = densenet._Transition(
                    num_input_features=num_features,
                    num_output_features=num_features // 2,
                )
                self.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def initialize(self):
        """Load the checkpoint at ``self.url`` into this module, if a URL is set.

        Checkpoint keys are renamed to this module's attribute names and the
        entries listed in ``_DROPPED_KEY_FRAGMENTS`` are discarded before the
        state dict is loaded. Does nothing when ``self.url`` is falsy.
        """
        if not self.url:
            return
        checkpoint = model_zoo.load_url(self.url)
        renames = self._key_renames()
        # Rewrite the loaded mapping in place so that anything else attached
        # to it by the loader is passed through to ``load_state_dict``.
        for key in list(checkpoint):
            value = checkpoint.pop(key)
            new_key = self._translate_key(key, renames)
            if new_key is not None:
                checkpoint[new_key] = value
        self.load_state_dict(checkpoint)

    def _key_renames(self):
        """Return the checkpoint-fragment -> module-name rename table."""
        renames = {
            "features.conv0.": "conv1.conv.",
            "features.norm0.": "conv1.norm.",
        }
        for level in range(1, len(self.block_config) + 1):
            for stage in ("denseblock", "transition"):
                name = "{}{}.".format(stage, level)
                renames["features." + name] = name
        return renames

    @classmethod
    def _translate_key(cls, key, renames):
        """Map one checkpoint key to this module's key, or ``None`` to drop it."""
        legacy = cls._LEGACY_KEY_PATTERN.match(key)
        if legacy:
            key = legacy.group(1) + legacy.group(2)
        for old, new in renames.items():
            if old in key:
                key = key.replace(old, new)
                break
        if any(fragment in key for fragment in cls._DROPPED_KEY_FRAGMENTS):
            return None
        return key

    def forward(self, x):
        """Run the network and collect the requested dense-block outputs.

        Args:
            x: input passed to the stem (``conv1``).

        Returns:
            A list with the output of every dense block whose 1-based level is
            in ``self.outputs``, ordered by increasing level. The list is empty
            when no level is requested.
        """
        # Deepest level that is actually needed; computed once, and 0 when no
        # outputs were requested so that an empty ``outputs`` does not raise.
        last_level = max(self.outputs, default=0)

        x = self.conv1(x)
        features = []
        for level in range(1, len(self.block_config) + 1):
            if level > last_level:
                # Nothing deeper was requested; skip the remaining blocks.
                break
            if level > 1:
                x = getattr(self, "transition{}".format(level - 1))(x)
            x = getattr(self, "denseblock{}".format(level))(x)
            if level in self.outputs:
                features.append(x)
        return features
@register
def DenseNet121(outputs, **kwargs):
    """Build a DenseNet-121 backbone configured with the torchvision weights URL.

    Args:
        outputs: 1-based dense-block levels whose outputs the model's
            ``forward`` returns; passed straight to ``DenseNet``.
        **kwargs: accepted so callers may pass extra options, but currently
            ignored.

    Returns:
        A ``DenseNet`` with growth rate 32, block configuration
        ``(6, 12, 24, 16)`` and 64 stem features. Only the checkpoint URL is
        stored here; weights are loaded by ``DenseNet.initialize``.
    """
    # NOTE(review): newer torchvision releases appear to have dropped the
    # per-module ``model_urls`` dict — confirm the pinned torchvision version.
    return DenseNet(
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        outputs=outputs,
        url=densenet.model_urls["densenet121"],
    )
# print(DenseNet121([4]))