조현아

classifier

# eval.py
import os
import fire
import json
from pprint import pprint

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from utils import *

# command
# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/'

def eval(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    kwargs = json.loads(open(kwargs_path).read())
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    # map_location lets a checkpoint trained on GPU load on CPU as well
    model.load_state_dict(torch.load(weight_path, map_location=device))

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    writer = SummaryWriter(log_dir=model_path)
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)

    print('\n[+] Test results')
    print('    Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print('    Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print('    Loss : {:.3f}'.format(_test_res[2].data))
    print('    Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))

    writer.close()

if __name__ == '__main__':
    fire.Fire(eval)


# networks/basenet.py
import torch.nn as nn

class BaseNet(nn.Module):
    def __init__(self, backbone, args):
        super(BaseNet, self).__init__()

        # Split the backbone: keep the first child separate so its feature
        # map can be returned alongside the logits (used for visualization
        # in train_step), and reuse the backbone's final fc as the head.
        self.first = nn.Sequential(*list(backbone.children())[:1])
        self.after = nn.Sequential(*list(backbone.children())[1:-1])
        self.fc = list(backbone.children())[-1]

        self.img_size = (240, 240)

    def forward(self, x):
        f = self.first(x)
        x = self.after(f)
        x = x.reshape(x.size(0), -1)  # flatten to (batch, features) for fc
        x = self.fc(x)
        return x, f
21 +"""
22 + print("before reshape:\n", x.size())
23 + #[128, 2048, 4, 4]
24 + # #cifar [128, 2048, 1, 1]
25 + x = x.reshape(x.size(0), -1)
26 + print("after reshape:\n", x.size())
27 + #[128, 32768]
28 + #cifar [128, 2048]
29 + #RuntimeError: size mismatch, m1: [128 x 32768], m2: [2048 x 10]
30 + print("fc :\n", self.fc)
31 + #Linear(in_features=2048, out_features=10, bias=True)
32 + #cifar Linear(in_features=2048, out_features=1000, bias=True)
33 +"""


# requirements.txt
future
tb-nightly
hyperopt
pillow==6.2.1
natsort
fire
torchvision==0.2.2
torch==1.1.0
pandas
scikit-learn  # PyPI name for the sklearn import

# train.py
import os
import fire
import time
import json
import random
from pprint import pprint

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

import utils  # keep a module handle so utils.current_epoch stays live
from networks import *
from utils import *

# python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam
# nohup python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam &


def train(**kwargs):
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name)
    os.makedirs(os.path.join(log_dir, 'model'), exist_ok=True)
    json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
        # add_graph needs an example input to trace the model; a
        # single-channel image at the model's input size is assumed here
        writer.add_graph(model, torch.zeros(1, 1, *model.img_size).cuda())

    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
    steps_per_epoch = max(1, len(train_dataset) // args.batch_size)
    max_epoch = args.max_step // steps_per_epoch
    best_acc = -1

    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    if args.use_cuda:
        print('\n[+] Using GPU: {}'.format(torch.cuda.get_device_name(0)))

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(train_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

        if step % args.print_step == 0:
            # current_epoch is updated inside utils.get_inf_dataloader, so read
            # it through the module, not the stale star-imported name
            print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                step, args.max_step, utils.current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
            writer.add_scalar('train/acc5', _train_res[1], global_step=step)
            writer.add_scalar('train/loss', _train_res[2], global_step=step)
            writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
            writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
            print('    Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('    Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('    Loss : {}'.format(_train_res[2].data))
            print('    FW Time : {:.3f}ms'.format(_train_res[3]*1000))
            print('    BW Time : {:.3f}ms'.format(_train_res[4]*1000))

        if step % args.val_step == args.val_step-1:
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
            print('\n[+] Valid results')
            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
            writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
            writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
            print('    Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
            print('    Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
            print('    Loss : {}'.format(_valid_res[2].data))

            if _valid_res[0] >= best_acc:
                best_acc = _valid_res[0]
                # unwrap DataParallel so eval.py can load the checkpoint
                # into a bare model
                state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                torch.save(state_dict, os.path.join(log_dir, 'model', 'model.pt'))
                print('\n[+] Model saved')

    writer.close()


if __name__ == '__main__':
    fire.Fire(train)


# utils.py
import os
import time
import importlib
import collections
import pickle as cp
import glob
import numpy as np
import pandas as pd

from natsort import natsorted
from PIL import Image
import torch
import torchvision
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

from networks import *


TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
# validation/test data; get_dataset() uses these for the 'val' and 'test'
# splits -- adjust to your layout
VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/'
VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv'

current_epoch = 0


def split_dataset(args, dataset, k):
    # load dataset
    X = list(range(len(dataset)))
    Y = dataset.targets

    # split to k-fold
    assert len(X) == len(Y)

    def _it_to_list(_it):
        return list(zip(*list(_it)))

    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))

    return Dm_indexes, Da_indexes
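
# Example (illustrative): split_dataset(args, dataset, k=5) returns two
# 5-tuples of index arrays -- Dm_indexes[i] and Da_indexes[i] hold the
# 90% / 10% (test_size=0.1) split for fold i, stratified by dataset.targets.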


def get_model_name(args):
    from datetime import datetime, timedelta, timezone
    now = datetime.now(timezone.utc)
    tz = timezone(timedelta(hours=9))  # UTC+9
    now = now.astimezone(tz)
    date_time = now.strftime("%B_%d_%H:%M:%S")
    model_name = '__'.join([date_time, args.network, str(args.seed)])
    return model_name
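
# Example output (as consumed by eval.py's --model_path):
#   'April_16_00:26:10__resnet50__None'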


def dict_to_namedtuple(d):
    Args = collections.namedtuple('Args', sorted(d.keys()))

    for k, v in d.items():
        if type(v) is dict:
            d[k] = dict_to_namedtuple(v)

        elif type(v) is str:
            # interpret numeric/bool strings; fall back to the raw string
            try:
                d[k] = eval(v)
            except Exception:
                d[k] = v

    args = Args(**d)
    return args
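
# Example (illustrative):
#   dict_to_namedtuple({'learning_rate': '0.0001', 'network': 'resnet50'})
#   -> Args(learning_rate=0.0001, network='resnet50')
# ('resnet50' raises NameError inside eval(), so it stays a string.)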


def parse_args(kwargs):
    # fill in defaults for anything the caller did not pass
    kwargs.setdefault('dataset', 'BraTS')
    kwargs.setdefault('network', 'resnet50')
    kwargs.setdefault('optimizer', 'adam')
    kwargs.setdefault('learning_rate', 0.0001)
    kwargs.setdefault('seed', None)
    kwargs.setdefault('use_cuda', True)
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
    kwargs.setdefault('num_workers', 4)
    kwargs.setdefault('print_step', 500)
    kwargs.setdefault('val_step', 500)
    kwargs.setdefault('scheduler', 'exp')
    kwargs.setdefault('batch_size', 128)
    kwargs.setdefault('start_step', 0)
    kwargs.setdefault('max_step', 5000)
    kwargs.setdefault('augment_path', None)

    # to named tuple
    args = dict_to_namedtuple(kwargs)
    return args, kwargs
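
# Example (illustrative): unspecified keys take the defaults above.
#   args, kwargs = parse_args({'network': 'resnet50', 'batch_size': 64})
#   args.batch_size  # 64
#   args.max_step    # 5000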


def select_model(args):
    # grayResNet2: map names to constructors so only the requested
    # backbone is instantiated
    resnet_dict = {'resnet18': grayResNet2.resnet18, 'resnet34': grayResNet2.resnet34,
                   'resnet50': grayResNet2.resnet50, 'resnet101': grayResNet2.resnet101,
                   'resnet152': grayResNet2.resnet152}

    if args.network in resnet_dict:
        backbone = resnet_dict[args.network]()
        model = basenet.BaseNet(backbone, args)
    else:
        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
        model = Net(args)

    return model


def select_optimizer(args, model):
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001)
    elif args.optimizer == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    else:
        raise Exception('Unknown Optimizer')
    return optimizer


def select_scheduler(args, optimizer):
    if not args.scheduler or args.scheduler == 'None':
        return None
    elif args.scheduler == 'clr':
        return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False)
    elif args.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)
    else:
        raise Exception('Unknown Scheduler')


class CustomDataset(Dataset):
    def __init__(self, data_path, csv_path, transform=None):
        self.path = data_path
        self.imgs = natsorted(os.listdir(data_path))
        self.len = len(self.imgs)  # must come after self.imgs is set
        self.transform = transform

        # look up each image's target by filename in the csv
        df = pd.read_csv(csv_path)
        targets_list = []

        for fname in self.imgs:
            row = df.loc[df['filename'] == fname]
            targets_list.append(row.iloc[0, 1])

        self.targets = targets_list

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img_loc = os.path.join(self.path, self.imgs[idx])
        targets = self.targets[idx]
        image = Image.open(img_loc)
        if self.transform is not None:
            image = self.transform(image)
        return image, targets
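
# Expected CSV layout (inferred from the lookup above): a 'filename' column
# first, with the class target in the second column (row.iloc[0, 1]).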


def get_dataset(args, transform, split='train'):
    assert split in ['train', 'val', 'test', 'trainval']

    if split == 'train':
        dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform)
    else:  # val / test
        dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform)

    return dataset


def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=shuffle,
                                              num_workers=args.num_workers,
                                              pin_memory=pin_memory)
    return data_loader


def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
    # identical to get_dataloader; kept as a separate name for call sites
    return get_dataloader(args, dataset, shuffle, pin_memory)


def get_inf_dataloader(args, dataset):
    # infinite generator over the dataset; bumps the module-level
    # current_epoch each time the underlying loader is exhausted
    global current_epoch
    data_loader = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        try:
            batch = next(data_loader)
        except StopIteration:
            current_epoch += 1
            data_loader = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(data_loader)

        yield batch
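
# Usage sketch (as in train.py): the generator never raises StopIteration,
# so a step-based loop can simply keep calling next():
#   train_loader = iter(get_inf_dataloader(args, train_dataset))
#   images, target = next(train_loader)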


def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    model.train()
    images, target = batch

    if device:
        images = images.to(device)
        target = target.to(device)
    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # compute output
    start_t = time.time()
    output, first = model(images)
    forward_t = time.time() - start_t
    loss = criterion(output, target)

    # measure accuracy and record loss; accuracy() returns correct counts,
    # so divide by the batch size to get fractions
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    acc1 /= images.size(0)
    acc5 /= images.size(0)

    # compute gradient and do an optimizer step
    optimizer.zero_grad()
    start_t = time.time()
    loss.backward()
    backward_t = time.time() - start_t
    optimizer.step()
    if scheduler:
        scheduler.step()

    if writer and step % args.print_step == 0:
        n_imgs = min(images.size(0), 10)
        for j in range(n_imgs):
            # tag each sample separately so the images do not overwrite
            # one another at the same global_step
            tag = 'train/{}/{}'.format(step, j)
            writer.add_image(tag,
                             concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, forward_t, backward_t


def accuracy(output, target, topk=(1,)):
    """Counts the correct top-k predictions for each k in topk.

    Returns tensors of shape [1]; callers (e.g. train_step) divide by the
    batch size to turn the counts into accuracies.
    """
    with torch.no_grad():
        maxk = max(topk)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res
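
# Worked example (illustrative): for a batch of 8 with 6 correct top-1 and
# 7 correct top-5 predictions,
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))  # tensors [6.], [7.]
# train_step then divides by images.size(0) to get 0.75 and 0.875.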