Showing
1 changed file
with
235 additions
and
0 deletions
3DCNN_VGGNet_2DResNet/model_resnet.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import math | ||
3 | +import torch.utils.model_zoo as model_zoo | ||
4 | + | ||
5 | + | ||
6 | +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', | ||
7 | + 'resnet152'] | ||
8 | + | ||
9 | + | ||
10 | +model_urls = { | ||
11 | + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||
12 | + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||
13 | + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||
14 | + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||
15 | + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||
16 | +} | ||
17 | + | ||
18 | + | ||
19 | +def conv3x3(in_planes, out_planes, stride=1): | ||
20 | + """3x3 convolution with padding""" | ||
21 | + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
22 | + padding=1, bias=False) | ||
23 | + | ||
24 | + | ||
25 | +class BasicBlock(nn.Module): | ||
26 | + expansion = 1 | ||
27 | + | ||
28 | + def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
29 | + super(BasicBlock, self).__init__() | ||
30 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
31 | + #alter batchnorm2d to groupnorm | ||
32 | + self.bn1 = nn.BatchNorm2d(planes) | ||
33 | + self.relu = nn.ReLU(inplace=True) | ||
34 | + self.conv2 = conv3x3(planes, planes) | ||
35 | + self.bn2 = nn.BatchNorm2d(planes) | ||
36 | + self.downsample = downsample | ||
37 | + self.stride = stride | ||
38 | + | ||
39 | + def forward(self, x): | ||
40 | + residual = x | ||
41 | + | ||
42 | + out = self.conv1(x) | ||
43 | + out = self.bn1(out) | ||
44 | + out = self.relu(out) | ||
45 | + | ||
46 | + out = self.conv2(out) | ||
47 | + out = self.bn2(out) | ||
48 | + | ||
49 | + if self.downsample is not None: | ||
50 | + residual = self.downsample(x) | ||
51 | + | ||
52 | + out += residual | ||
53 | + out = self.relu(out) | ||
54 | + | ||
55 | + return out | ||
56 | + | ||
57 | + | ||
58 | +class Bottleneck(nn.Module): | ||
59 | + expansion = 4 | ||
60 | + | ||
61 | + def __init__(self, inplanes, planes, stride=1, downsample=None): | ||
62 | + super(Bottleneck, self).__init__() | ||
63 | + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | ||
64 | + self.bn1 = nn.BatchNorm2d(planes) | ||
65 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, | ||
66 | + padding=1, bias=False) | ||
67 | + self.bn2 = nn.BatchNorm2d(planes) | ||
68 | + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) | ||
69 | + self.bn3 = nn.BatchNorm2d(planes * 4) | ||
70 | + self.relu = nn.ReLU(inplace=True) | ||
71 | + self.downsample = downsample | ||
72 | + self.stride = stride | ||
73 | + | ||
74 | + def forward(self, x): | ||
75 | + residual = x | ||
76 | + | ||
77 | + out = self.conv1(x) | ||
78 | + out = self.bn1(out) | ||
79 | + out = self.relu(out) | ||
80 | + | ||
81 | + out = self.conv2(out) | ||
82 | + out = self.bn2(out) | ||
83 | + out = self.relu(out) | ||
84 | + | ||
85 | + out = self.conv3(out) | ||
86 | + out = self.bn3(out) | ||
87 | + | ||
88 | + if self.downsample is not None: | ||
89 | + residual = self.downsample(x) | ||
90 | + | ||
91 | + out += residual | ||
92 | + out = self.relu(out) | ||
93 | + | ||
94 | + return out | ||
95 | + | ||
96 | + | ||
97 | +class ResNet(nn.Module): | ||
98 | + #change num_classes to 1 | ||
99 | + def __init__(self, block, layers, num_classes=1): | ||
100 | + self.inplanes = 64 | ||
101 | + super(ResNet, self).__init__() | ||
102 | + #first param changed to 1 | ||
103 | + self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, | ||
104 | + bias=False) | ||
105 | + self.bn1 = nn.BatchNorm2d(64) | ||
106 | + self.relu = nn.ReLU(inplace=True) | ||
107 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
108 | + self.layer1 = self._make_layer(block, 64, layers[0]) | ||
109 | + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) | ||
110 | + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) | ||
111 | + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) | ||
112 | + #change kernel size from 7 to 4 | ||
113 | + self.avgpool = nn.AvgPool2d(4, stride=1) | ||
114 | + #change from 512 to 1024 | ||
115 | + self.fc = nn.Linear(1024 * block.expansion, num_classes) | ||
116 | + | ||
117 | + for m in self.modules(): | ||
118 | + if isinstance(m, nn.Conv2d): | ||
119 | + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | ||
120 | + m.weight.data.normal_(0, math.sqrt(2. / n)) | ||
121 | + elif isinstance(m, nn.BatchNorm2d): | ||
122 | + m.weight.data.fill_(1) | ||
123 | + m.bias.data.zero_() | ||
124 | + | ||
125 | + def _make_layer(self, block, planes, blocks, stride=1): | ||
126 | + downsample = None | ||
127 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
128 | + downsample = nn.Sequential( | ||
129 | + nn.Conv2d(self.inplanes, planes * block.expansion, | ||
130 | + kernel_size=1, stride=stride, bias=False), | ||
131 | + nn.BatchNorm2d(planes * block.expansion), | ||
132 | + ) | ||
133 | + | ||
134 | + layers = [] | ||
135 | + layers.append(block(self.inplanes, planes, stride, downsample)) | ||
136 | + self.inplanes = planes * block.expansion | ||
137 | + for i in range(1, blocks): | ||
138 | + layers.append(block(self.inplanes, planes)) | ||
139 | + | ||
140 | + return nn.Sequential(*layers) | ||
141 | + | ||
142 | + def forward(self, x): | ||
143 | + x = self.conv1(x) | ||
144 | + x = self.bn1(x) | ||
145 | + x = self.relu(x) | ||
146 | + x = self.maxpool(x) | ||
147 | + | ||
148 | + x = self.layer1(x) | ||
149 | + x = self.layer2(x) | ||
150 | + x = self.layer3(x) | ||
151 | + x = self.layer4(x) | ||
152 | + | ||
153 | + # print(x.shape) # --> [32,512,4,5] | ||
154 | + x = self.avgpool(x) | ||
155 | + # print(x.shape) # --> [32,512,1,2], kernel = 4 | ||
156 | + x = x.view(x.size(0), -1) # --> [32x1024] | ||
157 | + x = self.fc(x) | ||
158 | + | ||
159 | + return x | ||
160 | + | ||
161 | + | ||
162 | +def resnet18(pretrained=False, **kwargs): | ||
163 | + """Constructs a ResNet-18 model. | ||
164 | + | ||
165 | + Args: | ||
166 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
167 | + """ | ||
168 | + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) | ||
169 | + if pretrained: | ||
170 | + #model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) # | ||
171 | + state_dict = model_zoo.load_url(model_urls['resnet18']) | ||
172 | + | ||
173 | + from collections import OrderedDict | ||
174 | + new_state_dict = OrderedDict() | ||
175 | + | ||
176 | + for k, v in state_dict.items(): | ||
177 | + if 'module' not in k: | ||
178 | + k = 'module.'+k | ||
179 | + else: | ||
180 | + k = k.replace('features.module.', 'module.features.') | ||
181 | + new_state_dict[k]=v | ||
182 | + model.load_state_dict(new_state_dict) | ||
183 | + return model | ||
184 | + | ||
185 | + | ||
186 | + | ||
187 | +def resnet34(pretrained=False, **kwargs): | ||
188 | + """Constructs a ResNet-34 model. | ||
189 | + | ||
190 | + Args: | ||
191 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
192 | + """ | ||
193 | + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) | ||
194 | + if pretrained: | ||
195 | + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) | ||
196 | + return model | ||
197 | + | ||
198 | + | ||
199 | + | ||
200 | +def resnet50(pretrained=False, **kwargs): | ||
201 | + """Constructs a ResNet-50 model. | ||
202 | + | ||
203 | + Args: | ||
204 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
205 | + """ | ||
206 | + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) | ||
207 | + if pretrained: | ||
208 | + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) | ||
209 | + return model | ||
210 | + | ||
211 | + | ||
212 | + | ||
213 | +def resnet101(pretrained=False, **kwargs): | ||
214 | + """Constructs a ResNet-101 model. | ||
215 | + | ||
216 | + Args: | ||
217 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
218 | + """ | ||
219 | + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) | ||
220 | + if pretrained: | ||
221 | + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) | ||
222 | + return model | ||
223 | + | ||
224 | + | ||
225 | + | ||
226 | +def resnet152(pretrained=False, **kwargs): | ||
227 | + """Constructs a ResNet-152 model. | ||
228 | + | ||
229 | + Args: | ||
230 | + pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
231 | + """ | ||
232 | + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) | ||
233 | + if pretrained: | ||
234 | + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) | ||
235 | + return model |
-
Please register or login to post a comment