Showing
1 changed file
with
1 additions
and
7 deletions
... | @@ -4,12 +4,6 @@ class BaseNet(nn.Module): | ... | @@ -4,12 +4,6 @@ class BaseNet(nn.Module): |
4 | def __init__(self, backbone, args): | 4 | def __init__(self, backbone, args): |
5 | super(BaseNet, self).__init__() | 5 | super(BaseNet, self).__init__() |
6 | 6 | ||
7 | - #testing | ||
8 | - for layer in backbone.children(): | ||
9 | - print("\nRESNET50 LAYERS\n") | ||
10 | - print(layer) | ||
11 | - | ||
12 | - | ||
13 | # Separate layers | 7 | # Separate layers |
14 | self.first = nn.Sequential(*list(backbone.children())[:1]) | 8 | self.first = nn.Sequential(*list(backbone.children())[:1]) |
15 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) | 9 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) |
... | @@ -20,10 +14,10 @@ class BaseNet(nn.Module): | ... | @@ -20,10 +14,10 @@ class BaseNet(nn.Module): |
20 | def forward(self, x): | 14 | def forward(self, x): |
21 | f = self.first(x) | 15 | f = self.first(x) |
22 | x = self.after(f) | 16 | x = self.after(f) |
17 | + x = x.reshape(x.size(0), -1) | ||
23 | x = self.fc(x) | 18 | x = self.fc(x) |
24 | return x, f | 19 | return x, f |
25 | 20 | ||
26 | - | ||
27 | """ | 21 | """ |
28 | print("before reshape:\n", x.size()) | 22 | print("before reshape:\n", x.size()) |
29 | #[128, 2048, 4, 4] | 23 | #[128, 2048, 4, 4] | ... | ... |
-
Please register or login to post a comment