Showing 5 changed files with 476 additions and 0 deletions
code/classifier/eval.py
0 → 100644
import os
import fire
import json
from pprint import pprint

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from util import *

# command
# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/'

def evaluate(model_path):
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    kwargs = json.loads(open(kwargs_path).read())
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create network')
    model = select_model(args)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    print('\n[+] Load model')
    weight_path = os.path.join(model_path, 'model', 'model.pt')
    # map_location lets a checkpoint trained on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(weight_path, map_location=device))

    print('\n[+] Load dataset')
    test_transform = get_valid_transform(args, model)
    test_dataset = get_dataset(args, test_transform, 'test')
    test_loader = iter(get_dataloader(args, test_dataset))

    print('\n[+] Start testing')
    writer = SummaryWriter(log_dir=model_path)
    _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)

    print('\n[+] Test results')
    print('  Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
    print('  Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
    print('  Loss : {:.3f}'.format(_test_res[2].data))
    print('  Infer Time(per image) : {:.3f}ms'.format(_test_res[3]*1000 / len(test_dataset)))

    writer.close()

if __name__ == '__main__':
    fire.Fire(evaluate)
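A note on what eval.py expects on disk: train.py (later in this diff) writes kwargs.json into the log directory and the best weights to model/model.pt, so a checkpoint directory looks roughly like this (the directory name is just the timestamped example from the comment above):

    logs/April_16_00:26:10__resnet50__None/
    ├── kwargs.json   # the exact kwargs train.py was launched with
    └── model/
        └── model.pt  # best-validation-accuracy weights

parse_args re-applies its defaults when the JSON is loaded, so kwargs.json only needs the keys that were overridden at training time.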
code/classifier/networks/basenet.py
0 → 100644
import torch.nn as nn

class BaseNet(nn.Module):
    def __init__(self, backbone, args):
        super(BaseNet, self).__init__()

        # Separate layers: the first child (the conv stem), everything up to
        # the classifier, and the final fully-connected layer itself.
        self.first = nn.Sequential(*list(backbone.children())[:1])
        self.after = nn.Sequential(*list(backbone.children())[1:-1])
        self.fc = list(backbone.children())[-1]

        self.img_size = (240, 240)

    def forward(self, x):
        f = self.first(x)
        x = self.after(f)
        # Flatten before the classifier, e.g. [N, 2048, 1, 1] -> [N, 2048].
        # If the spatial dims are larger (e.g. [N, 2048, 4, 4]), the flattened
        # width no longer matches fc.in_features and a size-mismatch error
        # follows (m1: [N x 32768] vs m2: [2048 x num_classes]).
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x, f
code/classifier/requirements.txt
0 → 100644
code/classifier/train.py
0 → 100644
import os
import fire
import time
import json
import random
from pprint import pprint

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from networks import *
import util
from util import *

# python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam
# nohup python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam &


def train(**kwargs):
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name)
    os.makedirs(os.path.join(log_dir, 'model'))
    json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
    writer = SummaryWriter(log_dir=log_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    # writer.add_graph(model, example_input) needs an example input tensor
    # alongside the model; re-enable once one is available.

    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
    steps_per_epoch = max(1, len(train_dataset) // args.batch_size)
    max_epoch = args.max_step // steps_per_epoch
    best_acc = -1

    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    if args.use_cuda:
        print('\n[+] Using GPU: {}'.format(torch.cuda.get_device_name(0)))

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(train_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)

        if step % args.print_step == 0:
            print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                step, args.max_step, util.current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
            writer.add_scalar('train/acc5', _train_res[1], global_step=step)
            writer.add_scalar('train/loss', _train_res[2], global_step=step)
            writer.add_scalar('train/forward_time', _train_res[3], global_step=step)
            writer.add_scalar('train/backward_time', _train_res[4], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_train_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_train_res[2].data))
            print('  FW Time : {:.3f}ms'.format(_train_res[3]*1000))
            print('  BW Time : {:.3f}ms'.format(_train_res[4]*1000))

        if step % args.val_step == args.val_step-1:
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
            print('\n[+] Valid results')
            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
            writer.add_scalar('valid/acc5', _valid_res[1], global_step=step)
            writer.add_scalar('valid/loss', _valid_res[2], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
            print('  Acc@5 : {:.3f}%'.format(_valid_res[1].data.cpu().numpy()[0]*100))
            print('  Loss : {}'.format(_valid_res[2].data))

            if _valid_res[0] >= best_acc:
                best_acc = _valid_res[0]
                # unwrap DataParallel so the checkpoint loads into a plain model
                state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                torch.save(state, os.path.join(log_dir, 'model', 'model.pt'))
                print('\n[+] Model saved')

    writer.close()


if __name__ == '__main__':
    fire.Fire(train)
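A quick sanity check on the step-to-epoch bookkeeping (the dataset size here is made up):

    steps_per_epoch = 12800 // 128   # = 100 with the default batch_size
    max_epoch       = 5000 // 100    # = 50 epochs at the default max_step

So with the default print_step and val_step of 500, the loop logs and validates once every 5 of these hypothetical epochs.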
code/classifier/util.py
0 → 100644
import os
import ast
import time
import importlib
import collections
import pandas as pd

from natsort import natsorted
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedShuffleSplit

from networks import *


TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
# get_dataset() references these for the 'val'/'test' splits; adjust to the local layout
VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/'
VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv'

current_epoch = 0
def split_dataset(args, dataset, k):
    # load dataset
    X = list(range(len(dataset)))
    Y = dataset.targets

    # split to k-fold
    assert len(X) == len(Y)

    def _it_to_list(_it):
        return list(zip(*list(_it)))

    sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1)
    Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y))

    return Dm_indexes, Da_indexes
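# Sketch (illustrative): with k=5 this yields five stratified
# (train, held-out) index pairs, each held-out fold a random 10% of the data:
#   Dm, Da = split_dataset(args, dataset, 5)
#   assert len(Dm) == len(Da) == 5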


def get_model_name(args):
    from datetime import datetime, timedelta, timezone
    # timestamp in KST (UTC+9)
    now = datetime.now(timezone.utc)
    tz = timezone(timedelta(hours=9))
    now = now.astimezone(tz)
    date_time = now.strftime("%B_%d_%H:%M:%S")
    model_name = '__'.join([date_time, args.network, str(args.seed)])
    return model_name


def dict_to_namedtuple(d):
    Args = collections.namedtuple('Args', sorted(d.keys()))

    for k, v in d.items():
        if type(v) is dict:
            d[k] = dict_to_namedtuple(v)
        elif type(v) is str:
            try:
                # coerce numeric/boolean strings; leave everything else a string
                d[k] = ast.literal_eval(v)
            except (ValueError, SyntaxError):
                d[k] = v

    args = Args(**d)
    return args
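# Example of the string coercion above (illustrative values):
#   dict_to_namedtuple({'learning_rate': '0.0001', 'use_cuda': 'True', 'network': 'resnet50'})
# yields Args(learning_rate=0.0001, network='resnet50', use_cuda=True) --
# numeric and boolean strings are parsed, anything else stays a string.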


def parse_args(kwargs):
    # combine with default args
    defaults = {
        'dataset': 'BraTS',
        'network': 'resnet50',
        'optimizer': 'adam',
        'learning_rate': 0.0001,
        'seed': None,
        'use_cuda': True,
        'num_workers': 4,
        'print_step': 500,
        'val_step': 500,
        'scheduler': 'exp',
        'batch_size': 128,
        'start_step': 0,
        'max_step': 5000,
        'augment_path': None,
    }
    for key, value in defaults.items():
        kwargs.setdefault(key, value)
    kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()

    # to named tuple
    args = dict_to_namedtuple(kwargs)
    return args, kwargs
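# Minimal call, leaning on the defaults above (illustrative):
#   args, kwargs = parse_args({'network': 'resnet18'})
#   args.batch_size  -> 128
#   args.max_step    -> 5000
#   args.use_cuda    -> True only if torch.cuda.is_available()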


def select_model(args):
    # map names to grayResNet2 constructors (not instances), so only the
    # chosen backbone is actually built
    resnet_dict = {'resnet18': grayResNet2.resnet18, 'resnet34': grayResNet2.resnet34,
                   'resnet50': grayResNet2.resnet50, 'resnet101': grayResNet2.resnet101,
                   'resnet152': grayResNet2.resnet152}

    if args.network in resnet_dict:
        backbone = resnet_dict[args.network]()
        model = basenet.BaseNet(backbone, args)
    else:
        # fall back to a module named after the network, exposing a Net class
        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
        model = Net(args)

    return model


def select_optimizer(args, model):
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001)
    elif args.optimizer == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    else:
        raise Exception('Unknown Optimizer')
    return optimizer


def select_scheduler(args, optimizer):
    if not args.scheduler or args.scheduler == 'None':
        return None
    elif args.scheduler == 'clr':
        return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False)
    elif args.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1)
    else:
        raise Exception('Unknown Scheduler')
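# Decay sanity check for the default 'exp' scheduler: it is stepped once per
# training iteration (see train_step), so over the default max_step=5000 the
# learning rate shrinks by roughly 0.9999283 ** 5000 ≈ 0.70, i.e. about a
# 30% reduction across the run.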


class CustomDataset(Dataset):
    def __init__(self, data_path, csv_path, transform=None):
        self.path = data_path
        self.imgs = natsorted(os.listdir(data_path))
        self.len = len(self.imgs)  # must come after self.imgs exists
        self.transform = transform

        # match each image file to its label row in the csv
        df = pd.read_csv(csv_path)
        targets_list = []

        for fname in self.imgs:
            row = df.loc[df['filename'] == fname]
            targets_list.append(row.iloc[0, 1])

        self.targets = targets_list

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img_loc = os.path.join(self.path, self.imgs[idx])
        targets = self.targets[idx]
        image = Image.open(img_loc)
        if self.transform is not None:
            image = self.transform(image)
        return image, targets
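# Assumed csv layout (only 'filename' is confirmed by the code; the label
# column is simply whatever sits second in the file):
#   filename,target
#   img_0001.png,0
#   img_0002.png,2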


def get_dataset(args, transform, split='train'):
    assert split in ['train', 'val', 'test', 'trainval']

    if split in ['train']:
        dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform)
    else:  # val / test
        dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform)

    return dataset


def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=shuffle,
                                              num_workers=args.num_workers,
                                              pin_memory=pin_memory)
    return data_loader


def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=shuffle,
                                              num_workers=args.num_workers,
                                              pin_memory=pin_memory)
    return data_loader


def get_inf_dataloader(args, dataset):
    global current_epoch
    data_loader = iter(get_dataloader(args, dataset, shuffle=True))

    while True:
        try:
            batch = next(data_loader)
        except StopIteration:
            # dataset exhausted: count an epoch and restart with a fresh shuffle
            current_epoch += 1
            data_loader = iter(get_dataloader(args, dataset, shuffle=True))
            batch = next(data_loader)

        yield batch
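# Usage sketch: the generator never raises StopIteration; it rebuilds the
# shuffled loader at the end of each pass and bumps current_epoch, so a
# step-based loop can just call next() forever:
#
#   loader = get_inf_dataloader(args, train_dataset)
#   for step in range(args.max_step):
#       images, target = next(loader)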


def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
    model.train()
    images, target = batch

    if device:
        images = images.to(device)
        target = target.to(device)
    elif args.use_cuda:
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

    # compute output
    start_t = time.time()
    output, first = model(images)
    forward_t = time.time() - start_t
    loss = criterion(output, target)

    # measure accuracy and record loss; accuracy() returns correct counts,
    # so divide by the batch size to get fractions
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    acc1 /= images.size(0)
    acc5 /= images.size(0)

    # compute gradient and take an optimizer step
    optimizer.zero_grad()
    start_t = time.time()
    loss.backward()
    backward_t = time.time() - start_t
    optimizer.step()
    if scheduler:
        scheduler.step()

    if writer and step % args.print_step == 0:
        n_imgs = min(images.size(0), 10)
        tag = 'train/' + str(step)
        for j in range(n_imgs):
            # concat_image_features is provided elsewhere in the repo
            writer.add_image(tag,
                concat_image_features(images[j], first[j]), global_step=step)

    return acc1, acc5, loss, forward_t, backward_t


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape rather than view: correct[:k] is non-contiguous after t()
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res
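As a worked example of the accuracy helper (hypothetical tensors): for a batch of 4 samples it returns correct-prediction counts, not percentages.

    output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
    target = torch.tensor([1, 0, 0, 1])
    acc1, = accuracy(output, target, topk=(1,))  # tensor([2.]) -- two rows correct

train_step divides these counts by images.size(0) before logging, which is why the print statements multiply by 100 to report percentages.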