조현아

run resnet & FAA getBraTS_5

This diff is collapsed. Click to expand it.
...@@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): ...@@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None):
54 if torch.cuda.device_count() > 1: 54 if torch.cuda.device_count() > 1:
55 print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) 55 print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
56 model = nn.DataParallel(model) 56 model = nn.DataParallel(model)
57 + elif torch.cuda.device_count() == 1:
58 + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
57 59
58 start_t = time.time() 60 start_t = time.time()
59 for step in range(args.start_step, args.max_step): 61 for step in range(args.start_step, args.max_step):
60 batch = next(data_loader) 62 batch = next(data_loader)
63 +
61 _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) 64 _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device)
62 65
63 if step % args.print_step == 0: 66 if step % args.print_step == 0:
...@@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat ...@@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat
173 device = torch.device('cuda:%d' % device_id) 176 device = torch.device('cuda:%d' % device_id)
174 _transform = [] 177 _transform = []
175 178
176 - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) 179 + print('[+] Child %d training started (GPU: %d)' % (k, device_id))
177 180
178 # train child model 181 # train child model
179 child_model = copy.deepcopy(model) 182 child_model = copy.deepcopy(model)
...@@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat ...@@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat
188 191
189 return _transform 192 return _transform
190 193
191 - 194 +#fast_auto_augment(args, model, K=4, B=1, num_process=4)
192 def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): 195 def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5):
193 args_str = json.dumps(args._asdict()) 196 args_str = json.dumps(args._asdict())
194 dataset = get_dataset(args, None, 'trainval') 197 dataset = get_dataset(args, None, 'trainval')
......
...@@ -4,6 +4,12 @@ class BaseNet(nn.Module): ...@@ -4,6 +4,12 @@ class BaseNet(nn.Module):
4 def __init__(self, backbone, args): 4 def __init__(self, backbone, args):
5 super(BaseNet, self).__init__() 5 super(BaseNet, self).__init__()
6 6
7 + #testing
8 + for layer in backbone.children():
9 + print("\nRESNET50 LAYERS\n")
10 + print(layer)
11 +
12 +
7 # Separate layers 13 # Separate layers
8 self.first = nn.Sequential(*list(backbone.children())[:1]) 14 self.first = nn.Sequential(*list(backbone.children())[:1])
9 self.after = nn.Sequential(*list(backbone.children())[1:-1]) 15 self.after = nn.Sequential(*list(backbone.children())[1:-1])
...@@ -14,6 +20,20 @@ class BaseNet(nn.Module): ...@@ -14,6 +20,20 @@ class BaseNet(nn.Module):
14 def forward(self, x): 20 def forward(self, x):
15 f = self.first(x) 21 f = self.first(x)
16 x = self.after(f) 22 x = self.after(f)
17 - x = x.reshape(x.size(0), -1)
18 x = self.fc(x) 23 x = self.fc(x)
19 return x, f 24 return x, f
25 +
26 +
27 +"""
28 + print("before reshape:\n", x.size())
29 + #[128, 2048, 4, 4]
30 + # #cifar 내장[128, 2048, 1, 1]
31 + x = x.reshape(x.size(0), -1)
32 + print("after reshape:\n", x.size())
33 + #[128, 32768]
34 + #cifar [128, 2048]
35 + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10]
36 + print("fc :\n", self.fc)
37 + #Linear(in_features=2048, out_features=10, bias=True)
38 + #cifar Linear(in_features=2048, out_features=1000, bias=True)
39 +"""
......
1 +import torch
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +
5 +
6 +class BasicBlock(nn.Module):
7 + expansion = 1
8 +
9 + def __init__(self, in_planes, planes, stride=1):
10 + super(BasicBlock, self).__init__()
11 + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
12 + self.bn1 = nn.BatchNorm2d(planes)
13 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14 + self.bn2 = nn.BatchNorm2d(planes)
15 +
16 + self.shortcut = nn.Sequential()
17 + if stride != 1 or in_planes != self.expansion*planes:
18 + self.shortcut = nn.Sequential(
19 + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
20 + nn.BatchNorm2d(self.expansion*planes)
21 + )
22 +
23 + def forward(self, x):
24 + out = F.relu(self.bn1(self.conv1(x)))
25 + out = self.bn2(self.conv2(out))
26 + out += self.shortcut(x)
27 + out = F.relu(out)
28 + return out
29 +
30 +
31 +class Bottleneck(nn.Module):
32 + expansion = 4
33 +
34 + def __init__(self, in_planes, planes, stride=1):
35 + super(Bottleneck, self).__init__()
36 + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
37 + self.bn1 = nn.BatchNorm2d(planes)
38 + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
39 + self.bn2 = nn.BatchNorm2d(planes)
40 + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
41 + self.bn3 = nn.BatchNorm2d(self.expansion*planes)
42 +
43 + self.shortcut = nn.Sequential()
44 + if stride != 1 or in_planes != self.expansion*planes:
45 + self.shortcut = nn.Sequential(
46 + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
47 + nn.BatchNorm2d(self.expansion*planes)
48 + )
49 +
50 + def forward(self, x):
51 + out = F.relu(self.bn1(self.conv1(x)))
52 + out = F.relu(self.bn2(self.conv2(out)))
53 + out = self.bn3(self.conv3(out))
54 + out += self.shortcut(x)
55 + out = F.relu(out)
56 + return out
57 +
58 +
59 +class ResNet(nn.Module):
60 + def __init__(self, block, num_blocks, num_classes=10):
61 + super(ResNet, self).__init__()
62 + self.in_planes = 64
63 +
64 + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
65 + self.bn1 = nn.BatchNorm2d(64)
66 + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
67 + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
68 + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
69 + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
70 + self.linear = nn.Linear(512*block.expansion, num_classes)
71 +
72 + def _make_layer(self, block, planes, num_blocks, stride):
73 + strides = [stride] + [1]*(num_blocks-1)
74 + layers = []
75 + for stride in strides:
76 + layers.append(block(self.in_planes, planes, stride))
77 + self.in_planes = planes * block.expansion
78 + return nn.Sequential(*layers)
79 +
80 + def forward(self, x):
81 + out = F.relu(self.bn1(self.conv1(x)))
82 + out = self.layer1(out)
83 + out = self.layer2(out)
84 + out = self.layer3(out)
85 + out = self.layer4(out)
86 + out = F.avg_pool2d(out, 4)
87 + out = out.view(out.size(0), -1)
88 + out = self.linear(out)
89 + return out
90 +
91 +
92 +def ResNet18():
93 + return ResNet(BasicBlock, [2,2,2,2])
94 +
95 +def ResNet34():
96 + return ResNet(BasicBlock, [3,4,6,3])
97 +
98 +def ResNet50():
99 + return ResNet(Bottleneck, [3,4,6,3])
100 +
101 +def ResNet101():
102 + return ResNet(Bottleneck, [3,4,23,3])
103 +
104 +def ResNet152():
105 + return ResNet(Bottleneck, [3,8,36,3])
106 +
107 +
108 +def test():
109 + net = ResNet18()
110 + y = net(torch.randn(1,3,32,32))
111 + print(y.size())
This diff is collapsed. Click to expand it.
...@@ -6,6 +6,7 @@ import pickle as cp ...@@ -6,6 +6,7 @@ import pickle as cp
6 import glob 6 import glob
7 import numpy as np 7 import numpy as np
8 import pandas as pd 8 import pandas as pd
9 +
9 from natsort import natsorted 10 from natsort import natsorted
10 from PIL import Image 11 from PIL import Image
11 import torch 12 import torch
...@@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split ...@@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split
21 from sklearn.model_selection import KFold 22 from sklearn.model_selection import KFold
22 23
23 from networks import basenet 24 from networks import basenet
25 +from networks import grayResNet, grayResNet2
24 26
25 DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' 27 DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
26 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' 28 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
...@@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): ...@@ -55,40 +57,6 @@ def split_dataset(args, dataset, k):
55 57
56 return Dm_indexes, Da_indexes 58 return Dm_indexes, Da_indexes
57 59
58 -def split_dataset2222(args, dataset, k):
59 - # load dataset
60 - X = list(range(len(dataset)))
61 -
62 - # split to k-fold
63 - #assert len(X) == len(Y)
64 -
65 - def _it_to_list(_it):
66 - return list(zip(*list(_it)))
67 -
68 - x_train = ()
69 - x_test = ()
70 -
71 - for i in range(k):
72 - #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1)
73 - xtr, xte = train_test_split(X, random_state=None, test_size=0.1)
74 - x_train.append(np.array(xtr))
75 - x_test.append(np.array(xte))
76 -
77 - y_train = np.array([0]* len(x_train))
78 - y_test = np.array([0]* len(x_test))
79 -
80 - x_train = tuple(x_train)
81 - x_test = tuple(x_test)
82 -
83 - trainset = (zip(x_train, y_train),)
84 - testset = (zip(x_test, y_test),)
85 -
86 - Dm_indexes, Da_indexes = trainset, testset
87 -
88 - print(type(Dm_indexes), np.shape(Dm_indexes))
89 - print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes)
90 -
91 - return Dm_indexes, Da_indexes
92 60
93 def concat_image_features(image, features, max_features=3): 61 def concat_image_features(image, features, max_features=3):
94 _, h, w = image.shape 62 _, h, w = image.shape
...@@ -159,8 +127,22 @@ def parse_args(kwargs): ...@@ -159,8 +127,22 @@ def parse_args(kwargs):
159 127
160 128
161 def select_model(args): 129 def select_model(args):
162 - if args.network in models.__dict__: 130 + # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(),
163 - backbone = models.__dict__[args.network]() 131 + # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()}
132 +
133 +
134 + # grayResNet2
135 + resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(),
136 + 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()}
137 +
138 + if args.network in resnet_dict:
139 + backbone = resnet_dict[args.network]
140 + #testing
141 + # print("\nRESNET50 LAYERS\n")
142 + # for layer in backbone.children():
143 + # print(layer)
144 + # print("LAYER THE END\n")
145 +
164 model = basenet.BaseNet(backbone, args) 146 model = basenet.BaseNet(backbone, args)
165 else: 147 else:
166 Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') 148 Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
......