조현아

reshape fc input

...@@ -4,12 +4,6 @@ class BaseNet(nn.Module): ...@@ -4,12 +4,6 @@ class BaseNet(nn.Module):
4 def __init__(self, backbone, args): 4 def __init__(self, backbone, args):
5 super(BaseNet, self).__init__() 5 super(BaseNet, self).__init__()
6 6
7 - #testing
8 - for layer in backbone.children():
9 - print("\nRESNET50 LAYERS\n")
10 - print(layer)
11 -
12 -
13 # Separate layers 7 # Separate layers
14 self.first = nn.Sequential(*list(backbone.children())[:1]) 8 self.first = nn.Sequential(*list(backbone.children())[:1])
15 self.after = nn.Sequential(*list(backbone.children())[1:-1]) 9 self.after = nn.Sequential(*list(backbone.children())[1:-1])
...@@ -20,10 +14,10 @@ class BaseNet(nn.Module): ...@@ -20,10 +14,10 @@ class BaseNet(nn.Module):
20 def forward(self, x): 14 def forward(self, x):
21 f = self.first(x) 15 f = self.first(x)
22 x = self.after(f) 16 x = self.after(f)
17 + x = x.reshape(x.size(0), -1)
23 x = self.fc(x) 18 x = self.fc(x)
24 return x, f 19 return x, f
25 20
26 -
27 """ 21 """
28 print("before reshape:\n", x.size()) 22 print("before reshape:\n", x.size())
29 #[128, 2048, 4, 4] 23 #[128, 2048, 4, 4]
......