조현아

reshape fc input

......@@ -4,12 +4,6 @@ class BaseNet(nn.Module):
def __init__(self, backbone, args):
super(BaseNet, self).__init__()
#testing
for layer in backbone.children():
print("\nRESNET50 LAYERS\n")
print(layer)
# Separate layers
self.first = nn.Sequential(*list(backbone.children())[:1])
self.after = nn.Sequential(*list(backbone.children())[1:-1])
......@@ -20,10 +14,10 @@ class BaseNet(nn.Module):
def forward(self, x):
f = self.first(x)
x = self.after(f)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
return x, f
"""
print("before reshape:\n", x.size())
#[128, 2048, 4, 4]
......