Showing
6 changed files
with
155 additions
and
39 deletions
code/FAA2/cifar_utils.py
0 → 100644
This diff is collapsed. Click to expand it.
... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): | ... | @@ -54,10 +54,13 @@ def train_child(args, model, dataset, subset_indx, device=None): |
54 | if torch.cuda.device_count() > 1: | 54 | if torch.cuda.device_count() > 1: |
55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | 55 | print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) |
56 | model = nn.DataParallel(model) | 56 | model = nn.DataParallel(model) |
57 | + elif torch.cuda.device_count() == 1: | ||
58 | + print('\n[+] Use {} GPUs'.format(torch.cuda.device_count())) | ||
57 | 59 | ||
58 | start_t = time.time() | 60 | start_t = time.time() |
59 | for step in range(args.start_step, args.max_step): | 61 | for step in range(args.start_step, args.max_step): |
60 | batch = next(data_loader) | 62 | batch = next(data_loader) |
63 | + | ||
61 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) | 64 | _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, None, device) |
62 | 65 | ||
63 | if step % args.print_step == 0: | 66 | if step % args.print_step == 0: |
... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -173,7 +176,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
173 | device = torch.device('cuda:%d' % device_id) | 176 | device = torch.device('cuda:%d' % device_id) |
174 | _transform = [] | 177 | _transform = [] |
175 | 178 | ||
176 | - print('[+] Child %d training strated (GPU: %d)' % (k, device_id)) | 179 | + print('[+] Child %d training started (GPU: %d)' % (k, device_id)) |
177 | 180 | ||
178 | # train child model | 181 | # train child model |
179 | child_model = copy.deepcopy(model) | 182 | child_model = copy.deepcopy(model) |
... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -188,7 +191,7 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
188 | 191 | ||
189 | return _transform | 192 | return _transform |
190 | 193 | ||
191 | - | 194 | +#fast_auto_augment(args, model, K=4, B=1, num_process=4) |
192 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | 195 | def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): |
193 | args_str = json.dumps(args._asdict()) | 196 | args_str = json.dumps(args._asdict()) |
194 | dataset = get_dataset(args, None, 'trainval') | 197 | dataset = get_dataset(args, None, 'trainval') | ... | ... |
... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): | ... | @@ -4,6 +4,12 @@ class BaseNet(nn.Module): |
4 | def __init__(self, backbone, args): | 4 | def __init__(self, backbone, args): |
5 | super(BaseNet, self).__init__() | 5 | super(BaseNet, self).__init__() |
6 | 6 | ||
7 | + #testing | ||
8 | + for layer in backbone.children(): | ||
9 | + print("\nRESNET50 LAYERS\n") | ||
10 | + print(layer) | ||
11 | + | ||
12 | + | ||
7 | # Separate layers | 13 | # Separate layers |
8 | self.first = nn.Sequential(*list(backbone.children())[:1]) | 14 | self.first = nn.Sequential(*list(backbone.children())[:1]) |
9 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) | 15 | self.after = nn.Sequential(*list(backbone.children())[1:-1]) |
... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): | ... | @@ -14,6 +20,20 @@ class BaseNet(nn.Module): |
14 | def forward(self, x): | 20 | def forward(self, x): |
15 | f = self.first(x) | 21 | f = self.first(x) |
16 | x = self.after(f) | 22 | x = self.after(f) |
17 | - x = x.reshape(x.size(0), -1) | ||
18 | x = self.fc(x) | 23 | x = self.fc(x) |
19 | return x, f | 24 | return x, f |
25 | + | ||
26 | + | ||
27 | +""" | ||
28 | + print("before reshape:\n", x.size()) | ||
29 | + #[128, 2048, 4, 4] | ||
30 | + # #cifar 내장[128, 2048, 1, 1] | ||
31 | + x = x.reshape(x.size(0), -1) | ||
32 | + print("after reshape:\n", x.size()) | ||
33 | + #[128, 32768] | ||
34 | + #cifar [128, 2048] | ||
35 | + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10] | ||
36 | + print("fc :\n", self.fc) | ||
37 | + #Linear(in_features=2048, out_features=10, bias=True) | ||
38 | + #cifar Linear(in_features=2048, out_features=1000, bias=True) | ||
39 | +""" | ... | ... |
code/FAA2/networks/grayResNet.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | + | ||
5 | + | ||
6 | +class BasicBlock(nn.Module): | ||
7 | + expansion = 1 | ||
8 | + | ||
9 | + def __init__(self, in_planes, planes, stride=1): | ||
10 | + super(BasicBlock, self).__init__() | ||
11 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
12 | + self.bn1 = nn.BatchNorm2d(planes) | ||
13 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
14 | + self.bn2 = nn.BatchNorm2d(planes) | ||
15 | + | ||
16 | + self.shortcut = nn.Sequential() | ||
17 | + if stride != 1 or in_planes != self.expansion*planes: | ||
18 | + self.shortcut = nn.Sequential( | ||
19 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
20 | + nn.BatchNorm2d(self.expansion*planes) | ||
21 | + ) | ||
22 | + | ||
23 | + def forward(self, x): | ||
24 | + out = F.relu(self.bn1(self.conv1(x))) | ||
25 | + out = self.bn2(self.conv2(out)) | ||
26 | + out += self.shortcut(x) | ||
27 | + out = F.relu(out) | ||
28 | + return out | ||
29 | + | ||
30 | + | ||
31 | +class Bottleneck(nn.Module): | ||
32 | + expansion = 4 | ||
33 | + | ||
34 | + def __init__(self, in_planes, planes, stride=1): | ||
35 | + super(Bottleneck, self).__init__() | ||
36 | + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) | ||
37 | + self.bn1 = nn.BatchNorm2d(planes) | ||
38 | + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||
39 | + self.bn2 = nn.BatchNorm2d(planes) | ||
40 | + self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) | ||
41 | + self.bn3 = nn.BatchNorm2d(self.expansion*planes) | ||
42 | + | ||
43 | + self.shortcut = nn.Sequential() | ||
44 | + if stride != 1 or in_planes != self.expansion*planes: | ||
45 | + self.shortcut = nn.Sequential( | ||
46 | + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), | ||
47 | + nn.BatchNorm2d(self.expansion*planes) | ||
48 | + ) | ||
49 | + | ||
50 | + def forward(self, x): | ||
51 | + out = F.relu(self.bn1(self.conv1(x))) | ||
52 | + out = F.relu(self.bn2(self.conv2(out))) | ||
53 | + out = self.bn3(self.conv3(out)) | ||
54 | + out += self.shortcut(x) | ||
55 | + out = F.relu(out) | ||
56 | + return out | ||
57 | + | ||
58 | + | ||
59 | +class ResNet(nn.Module): | ||
60 | + def __init__(self, block, num_blocks, num_classes=10): | ||
61 | + super(ResNet, self).__init__() | ||
62 | + self.in_planes = 64 | ||
63 | + | ||
64 | + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False) | ||
65 | + self.bn1 = nn.BatchNorm2d(64) | ||
66 | + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) | ||
67 | + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) | ||
68 | + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) | ||
69 | + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) | ||
70 | + self.linear = nn.Linear(512*block.expansion, num_classes) | ||
71 | + | ||
72 | + def _make_layer(self, block, planes, num_blocks, stride): | ||
73 | + strides = [stride] + [1]*(num_blocks-1) | ||
74 | + layers = [] | ||
75 | + for stride in strides: | ||
76 | + layers.append(block(self.in_planes, planes, stride)) | ||
77 | + self.in_planes = planes * block.expansion | ||
78 | + return nn.Sequential(*layers) | ||
79 | + | ||
80 | + def forward(self, x): | ||
81 | + out = F.relu(self.bn1(self.conv1(x))) | ||
82 | + out = self.layer1(out) | ||
83 | + out = self.layer2(out) | ||
84 | + out = self.layer3(out) | ||
85 | + out = self.layer4(out) | ||
86 | + out = F.avg_pool2d(out, 4) | ||
87 | + out = out.view(out.size(0), -1) | ||
88 | + out = self.linear(out) | ||
89 | + return out | ||
90 | + | ||
91 | + | ||
92 | +def ResNet18(): | ||
93 | + return ResNet(BasicBlock, [2,2,2,2]) | ||
94 | + | ||
95 | +def ResNet34(): | ||
96 | + return ResNet(BasicBlock, [3,4,6,3]) | ||
97 | + | ||
98 | +def ResNet50(): | ||
99 | + return ResNet(Bottleneck, [3,4,6,3]) | ||
100 | + | ||
101 | +def ResNet101(): | ||
102 | + return ResNet(Bottleneck, [3,4,23,3]) | ||
103 | + | ||
104 | +def ResNet152(): | ||
105 | + return ResNet(Bottleneck, [3,8,36,3]) | ||
106 | + | ||
107 | + | ||
108 | +def test(): | ||
109 | + net = ResNet18() | ||
110 | + y = net(torch.randn(1,3,32,32)) | ||
111 | + print(y.size()) |
code/FAA2/networks/grayResNet2.py
0 → 100644
This diff is collapsed. Click to expand it.
... | @@ -6,6 +6,7 @@ import pickle as cp | ... | @@ -6,6 +6,7 @@ import pickle as cp |
6 | import glob | 6 | import glob |
7 | import numpy as np | 7 | import numpy as np |
8 | import pandas as pd | 8 | import pandas as pd |
9 | + | ||
9 | from natsort import natsorted | 10 | from natsort import natsorted |
10 | from PIL import Image | 11 | from PIL import Image |
11 | import torch | 12 | import torch |
... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split | ... | @@ -21,6 +22,7 @@ from sklearn.model_selection import train_test_split |
21 | from sklearn.model_selection import KFold | 22 | from sklearn.model_selection import KFold |
22 | 23 | ||
23 | from networks import basenet | 24 | from networks import basenet |
25 | +from networks import grayResNet, grayResNet2 | ||
24 | 26 | ||
25 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 27 | DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
26 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' | 28 | TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): | ... | @@ -55,40 +57,6 @@ def split_dataset(args, dataset, k): |
55 | 57 | ||
56 | return Dm_indexes, Da_indexes | 58 | return Dm_indexes, Da_indexes |
57 | 59 | ||
58 | -def split_dataset2222(args, dataset, k): | ||
59 | - # load dataset | ||
60 | - X = list(range(len(dataset))) | ||
61 | - | ||
62 | - # split to k-fold | ||
63 | - #assert len(X) == len(Y) | ||
64 | - | ||
65 | - def _it_to_list(_it): | ||
66 | - return list(zip(*list(_it))) | ||
67 | - | ||
68 | - x_train = () | ||
69 | - x_test = () | ||
70 | - | ||
71 | - for i in range(k): | ||
72 | - #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | ||
73 | - xtr, xte = train_test_split(X, random_state=None, test_size=0.1) | ||
74 | - x_train.append(np.array(xtr)) | ||
75 | - x_test.append(np.array(xte)) | ||
76 | - | ||
77 | - y_train = np.array([0]* len(x_train)) | ||
78 | - y_test = np.array([0]* len(x_test)) | ||
79 | - | ||
80 | - x_train = tuple(x_train) | ||
81 | - x_test = tuple(x_test) | ||
82 | - | ||
83 | - trainset = (zip(x_train, y_train),) | ||
84 | - testset = (zip(x_test, y_test),) | ||
85 | - | ||
86 | - Dm_indexes, Da_indexes = trainset, testset | ||
87 | - | ||
88 | - print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
89 | - print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes) | ||
90 | - | ||
91 | - return Dm_indexes, Da_indexes | ||
92 | 60 | ||
93 | def concat_image_features(image, features, max_features=3): | 61 | def concat_image_features(image, features, max_features=3): |
94 | _, h, w = image.shape | 62 | _, h, w = image.shape |
... | @@ -159,8 +127,22 @@ def parse_args(kwargs): | ... | @@ -159,8 +127,22 @@ def parse_args(kwargs): |
159 | 127 | ||
160 | 128 | ||
161 | def select_model(args): | 129 | def select_model(args): |
162 | - if args.network in models.__dict__: | 130 | + # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), |
163 | - backbone = models.__dict__[args.network]() | 131 | + # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} |
132 | + | ||
133 | + | ||
134 | + # grayResNet2 | ||
135 | + resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
136 | + 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
137 | + | ||
138 | + if args.network in resnet_dict: | ||
139 | + backbone = resnet_dict[args.network] | ||
140 | + #testing | ||
141 | + # print("\nRESNET50 LAYERS\n") | ||
142 | + # for layer in backbone.children(): | ||
143 | + # print(layer) | ||
144 | + # print("LAYER THE END\n") | ||
145 | + | ||
164 | model = basenet.BaseNet(backbone, args) | 146 | model = basenet.BaseNet(backbone, args) |
165 | else: | 147 | else: |
166 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | 148 | Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ... | ... |
-
Please register or login to post a comment