Hyunji

3D-CNN, VGGNet model

1 +import torch.nn.functional as F
2 +import torch.nn as nn
3 +from torch.autograd import Variable
4 +
5 +class Model(nn.Module):
6 + def __init__(self):
7 + super(Model, self).__init__()
8 + self.features = nn.Sequential(
9 + nn.Conv2d(121, 64, 3, padding=1),
10 + nn.BatchNorm2d(64),
11 + nn.ReLU(),
12 + nn.MaxPool2d(2),
13 + # nn.Dropout2d(),
14 + nn.Conv2d(64, 32, 3, padding=1),
15 + nn.BatchNorm2d(32),
16 + nn.ReLU(),
17 + nn.MaxPool2d(2),
18 + nn.Conv2d(32, 16, 3, padding=1),
19 + nn.BatchNorm2d(16),
20 + nn.ReLU(),
21 + nn.MaxPool2d(2)
22 + )
23 +
24 + self.classifier = nn.Sequential(
25 + nn.Linear(4320, 1024),
26 + # nn.BatchNorm1d(1024),
27 + nn.ReLU(),
28 + nn.Dropout(),
29 + nn.Linear(1024, 512),
30 + # nn.BatchNorm1d(512),
31 + nn.ReLU(),
32 + nn.Dropout(),
33 + nn.Linear(512, 1),
34 + # nn.BatchNorm1d(),
35 + nn.ReLU()
36 + )
37 +
38 +
39 + # self.linear1 = nn.Linear()
40 +
41 +
42 + def forward(self, x):
43 + x = self.features(x)
44 + x = self.classifier(x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]))
45 + # print(x.shape)
46 +
47 + return x
48 +
49 +
50 +class VGGBasedModel(nn.Module):
51 + def __init__(self):
52 + super(VGGBasedModel, self).__init__()
53 + self.features = nn.Sequential(
54 + nn.Conv2d(121, 128, 3, padding=1),
55 + nn.BatchNorm2d(128),
56 + nn.ReLU(),
57 + nn.Conv2d(128, 128, 3, padding=1),
58 + nn.BatchNorm2d(128),
59 + nn.ReLU(),
60 + nn.MaxPool2d(2),
61 +
62 + nn.Conv2d(128, 256, 3, padding=1),
63 + nn.BatchNorm2d(256),
64 + nn.ReLU(),
65 + nn.Conv2d(256, 256, 3, padding=1),
66 + nn.BatchNorm2d(256),
67 + nn.ReLU(),
68 + nn.MaxPool2d(2),
69 +
70 + nn.Conv2d(256, 512, 3, padding=1),
71 + nn.BatchNorm2d(512),
72 + nn.ReLU(),
73 + nn.Conv2d(512, 512, 3, padding=1),
74 + nn.BatchNorm2d(512),
75 + nn.ReLU(),
76 + nn.MaxPool2d(2),
77 +
78 + nn.Conv2d(512, 512, 3, padding=1),
79 + nn.BatchNorm2d(512),
80 + nn.ReLU(),
81 + nn.Conv2d(512, 512, 3, padding=1),
82 + nn.BatchNorm2d(512),
83 + nn.ReLU(),
84 + nn.MaxPool2d(2),
85 +
86 + nn.Conv2d(512, 256, 3, padding=1),
87 + nn.BatchNorm2d(256),
88 + nn.ReLU(),
89 + nn.Conv2d(256, 256, 3, padding=1),
90 + nn.BatchNorm2d(256),
91 + nn.ReLU(),
92 + nn.MaxPool2d(2),
93 +
94 + )
95 +
96 + self.classifier = nn.Sequential(
97 + nn.Linear(3072, 2048),
98 + # nn.BatchNorm1d(1024),
99 + nn.ReLU(),
100 + nn.Dropout(),
101 + nn.Linear(2048, 1024),
102 + nn.ReLU(),
103 + nn.Dropout(),
104 + nn.Linear(1024, 512),
105 + nn.ReLU(),
106 + nn.Dropout(),
107 + nn.Linear(512, 1),
108 + # nn.BatchNorm1d(),
109 + nn.ReLU()
110 + )
111 +
112 +
113 + def forward(self, x):
114 + x = self.features(x)
115 + x = self.classifier(x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]))
116 + # print(x.shape)
117 +
118 + return x
119 +
120 +
121 +class VGGBasedModel2D(nn.Module):
122 + def __init__(self):
123 + super(VGGBasedModel2D, self).__init__()
124 + self.features = nn.Sequential(
125 + nn.Conv2d(1, 128, 3, padding=1),
126 + nn.BatchNorm2d(128),
127 + nn.ReLU(),
128 + nn.Conv2d(128, 128, 3, padding=1),
129 + nn.BatchNorm2d(128),
130 + nn.ReLU(),
131 + nn.MaxPool2d(2),
132 +
133 + nn.Conv2d(128, 256, 3, padding=1),
134 + nn.BatchNorm2d(256),
135 + nn.ReLU(),
136 + nn.Conv2d(256, 256, 3, padding=1),
137 + nn.BatchNorm2d(256),
138 + nn.ReLU(),
139 + nn.MaxPool2d(2),
140 +
141 + nn.Conv2d(256, 512, 3, padding=1),
142 + nn.BatchNorm2d(512),
143 + nn.ReLU(),
144 + nn.Conv2d(512, 512, 3, padding=1),
145 + nn.BatchNorm2d(512),
146 + nn.ReLU(),
147 + nn.MaxPool2d(2),
148 +
149 + nn.Conv2d(512, 512, 3, padding=1),
150 + nn.BatchNorm2d(512),
151 + nn.ReLU(),
152 + nn.Conv2d(512, 512, 3, padding=1),
153 + nn.BatchNorm2d(512),
154 + nn.ReLU(),
155 + nn.MaxPool2d(2),
156 +
157 + nn.Conv2d(512, 256, 3, padding=1),
158 + nn.BatchNorm2d(256),
159 + nn.ReLU(),
160 + nn.Conv2d(256, 256, 3, padding=1),
161 + nn.BatchNorm2d(256),
162 + nn.ReLU(),
163 + nn.MaxPool2d(2),
164 +
165 + )
166 +
167 + self.classifier = nn.Sequential(
168 + nn.Linear(3072, 2048),
169 + # nn.BatchNorm1d(1024),
170 + nn.ReLU(),
171 + nn.Dropout(),
172 + nn.Linear(2048, 1024),
173 + nn.ReLU(),
174 + nn.Dropout(),
175 + nn.Linear(1024, 512),
176 + nn.ReLU(),
177 + nn.Dropout(),
178 + nn.Linear(512, 1),
179 + # nn.BatchNorm1d(),
180 + nn.ReLU()
181 + )
182 +
183 +
184 + def forward(self, x):
185 + x = self.features(x)
186 + x = self.classifier(x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]))
187 + # print(x.shape)
188 +
189 + return x
190 +
191 +class Model3D(nn.Module):
192 + def __init__(self):
193 + super(Model3D, self).__init__()
194 + self.block1 = nn.Sequential(
195 + nn.Conv3d(1, 8, 3, stride=1),
196 + nn.ReLU(),
197 + nn.Conv3d(8, 8, 3, stride=1),
198 + nn.BatchNorm3d(8),
199 + nn.ReLU(),
200 + nn.MaxPool3d(2, stride=2))
201 +
202 + self.block2 = nn.Sequential(
203 + nn.Conv3d(8, 16, 3, stride=1),
204 + nn.ReLU(),
205 + nn.Conv3d(16, 16, 3, stride=1),
206 + nn.BatchNorm3d(16),
207 + nn.ReLU(),
208 + nn.MaxPool3d(2, stride=2))
209 +
210 + self.block3 = nn.Sequential(
211 + nn.Conv3d(16, 32, 3, stride=1),
212 + nn.ReLU(),
213 + nn.Conv3d(32, 32, 3, stride=1),
214 + nn.BatchNorm3d(32),
215 + nn.ReLU(),
216 + nn.MaxPool3d(2, stride=2))
217 +
218 + self.block4 = nn.Sequential(
219 + nn.Conv3d(32, 64, 3, stride=1, padding=1),
220 + nn.ReLU(),
221 + nn.Conv3d(64, 64, 3, stride=1, padding=1),
222 + nn.BatchNorm3d(64),
223 + nn.ReLU(),
224 + nn.MaxPool3d(2, stride=2))
225 +
226 + self.block5 = nn.Sequential(
227 + nn.Conv3d(64, 128, 3, stride=1, padding=1),
228 + nn.ReLU(),
229 + nn.Conv3d(128, 128, 3, stride=1, padding=1),
230 + nn.BatchNorm3d(128),
231 + nn.ReLU(),
232 + nn.MaxPool3d(2, stride=2)
233 + )
234 +
235 + self.classifier = nn.Linear(1536, 1)
236 +
237 +
238 + def forward(self, x):
239 + x = self.block1(x)
240 + # print(x.shape)
241 + x = self.block2(x)
242 + # print(x.shape)
243 + x = self.block3(x)
244 + # print(x.shape)
245 + x = self.block4(x)
246 + # print(x.shape)
247 + x = self.block5(x)
248 + # print(x.shape)
249 + x = self.classifier(x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]*x.shape[4]))
250 + # print(x.shape)
251 +
252 + return x
...\ No newline at end of file ...\ No newline at end of file