조현아

Log augmented data to TensorBoard

@@ -33,11 +33,12 @@ def eval(model_path):
     print('\n[+] Load dataset')
     test_transform = get_valid_transform(args, model)
     test_dataset = get_dataset(args, test_transform, 'test')
-    test_loader = iter(get_dataloader(args, test_dataset))
+    print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590
+
+    test_loader = iter(get_dataloader(args, test_dataset)) ###
 
     print('\n[+] Start testing')
-    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs', model_name)
-    writer = SummaryWriter(log_dir=log_dir)
+    writer = SummaryWriter(log_dir=model_path)
     _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)
 
     print('\n[+] Valid results')
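Pointing the writer at `model_path` drops the hard-coded Google Drive path and puts the event files next to the checkpoint being evaluated, so a single `--logdir` covers both. A minimal sketch of the writer lifecycle around this hunk, reusing the names from the diff; the `close()` call is an assumption, not shown here:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=model_path)  # model_path is eval()'s argument in this diff
_test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)
writer.close()  # assumed cleanup; the hunk does not show it
```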
...
@@ -55,7 +55,7 @@ def split_dataset(args, dataset, k):
 
     return Dm_indexes, Da_indexes
 
-
+#(images[j], first[j]), global_step=step)
 def concat_image_features(image, features, max_features=3):
     _, h, w = image.shape
 
@@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3):
 
 
         image_feature = torch.cat((image_feature, feature), 2) ### dim = 2
+        #print("\nimg feature size: ", image_feature.size()) #[1, 240, 720]
 
     return image_feature
 
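The restored size note `[1, 240, 720]` is consistent with one 240×240 slice plus two 240×240 feature maps stacked along dim 2 (width). A self-contained sketch of that concatenation; the tensor names are illustrative, not from the repo:

```python
import torch

image = torch.rand(1, 240, 240)           # C x H x W slice
features = [torch.rand(1, 240, 240)] * 2  # feature maps already resized to match
image_feature = image
for feature in features:
    image_feature = torch.cat((image_feature, feature), 2)  # stack along width
print(image_feature.size())  # torch.Size([1, 240, 720])
```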
@@ -149,10 +150,6 @@ def parse_args(kwargs):
 
 
 def select_model(args):
-    # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(),
-    #                'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()}
-
-
     # grayResNet2
     resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(),
                    'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()}
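As written, the dict still instantiates all five ResNet variants before one is picked, which the commented-out `grayResNet` version did as well. A lighter variant, sketched under the assumption that `args.network` holds the lookup key (the selection line is outside this hunk), stores constructors instead:

```python
# Store constructors, not instances, so only the chosen variant is built.
resnet_dict = {'resnet18': grayResNet2.resnet18, 'resnet34': grayResNet2.resnet34,
               'resnet50': grayResNet2.resnet50, 'resnet101': grayResNet2.resnet101,
               'resnet152': grayResNet2.resnet152}
model = resnet_dict[args.network]()  # assumes args.network holds the key
```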
@@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'):
     elif args.dataset == 'BraTS':
         if split in ['train']:
             dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
-        else:
+        else: #test
             dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)
 
 
@@ -382,7 +379,7 @@ def get_valid_transform(args, model):
             transforms.ToTensor()
         ])
     elif args.dataset == 'BraTS':
-        resize_h, resize_w = 256, 256
+        resize_h, resize_w = 240, 240
         val_transform = transforms.Compose([
             transforms.Resize([resize_h, resize_w]),
             transforms.ToTensor()
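240×240 is the native in-plane size of a BraTS slice, so the valid-set resize becomes a no-op and matches the `[4, 1, 240, 240]` batch shapes printed in `validate()` below. A quick shape check; the blank grayscale image is a stand-in for a real slice:

```python
from PIL import Image
from torchvision import transforms

val_transform = transforms.Compose([
    transforms.Resize([240, 240]),
    transforms.ToTensor(),
])
slice_img = Image.new('L', (240, 240))  # stand-in for one grayscale BraTS slice
print(val_transform(slice_img).shape)   # torch.Size([1, 240, 240])
```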
@@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer
 
     if writer and step % args.print_step == 0:
         n_imgs = min(images.size(0), 10)
+        tag = 'train/' + str(step)
         for j in range(n_imgs):
-            writer.add_image('train/input_image',
+            writer.add_image(tag,
                 concat_image_features(images[j], first[j]), global_step=step)
 
     return acc1, acc5, loss, forward_t, backward_t
 
-
+# validate(args, model, criterion, test_loader, step=0, writer=writer)
 def validate(args, model, criterion, valid_loader, step, writer, device=None):
     # switch to evaluate mode
     model.eval()
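Moving from the fixed tag `'train/input_image'` to `'train/' + str(step)` gives each logged step its own image card in TensorBoard rather than one card with a step slider; the cost is one tag per logged step over a long run. Both variants, taken from this hunk, side by side:

```python
# Previous behavior: one tag, scrub through steps with the slider.
writer.add_image('train/input_image',
                 concat_image_features(images[j], first[j]), global_step=step)

# New behavior: one tag per logged step, all visible at once.
writer.add_image('train/' + str(step),
                 concat_image_features(images[j], first[j]), global_step=step)
```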
@@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
     samples = 0
     infer_t = 0
 
+    img_count = 0
+
     with torch.no_grad():
-        for i, (images, target) in enumerate(valid_loader): ##
+        for i, (images, target) in enumerate(valid_loader): ## loop [0, 148]
 
+            #print("\n1 images size: ", images.size()) #[4, 1, 240, 240]
             start_t = time.time()
             if device:
                 images = images.to(device)
                 target = target.to(device)
 
-            elif args.use_cuda is not None:
+            elif args.use_cuda is not None: #
                 images = images.cuda(non_blocking=True)
                 target = target.cuda(non_blocking=True)
+                #print("\n2 images size: ", images.size()) #[4, 1, 240, 240]
 
             # compute output
+            # first = nn.Sequential(*list(backbone.children())[:1])
             output, first = model(images)
             loss = criterion(output, target)
             infer_t += time.time() - start_t
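The restored comment hints at where `first` comes from: the model also returns the output of its first child module. For a stock ResNet that returns only logits, a forward hook is one way to recover the same tensor; this is a sketch under that assumption, not the repo's actual wiring (`grayResNet2` takes 1-channel input, the stock model below takes 3):

```python
import torch
import torchvision

backbone = torchvision.models.resnet18()
captured = {}

def hook(module, inputs, output):
    captured['first'] = output.detach()  # activations of the first child (conv1)

list(backbone.children())[0].register_forward_hook(hook)
logits = backbone(torch.rand(4, 3, 240, 240))
print(captured['first'].shape)  # torch.Size([4, 64, 120, 120])
```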
@@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
             acc5 += _acc5
             samples += images.size(0)
 
+            if writer:
+                # print("\n3 images.size(0): ", images.size(0))
+                n_imgs = min(images.size(0), 10)
+                for j in range(n_imgs):
+                    tag = 'valid/' + str(img_count)
+                    writer.add_image(tag,
+                        concat_image_features(images[j], first[j]), global_step=step)
+                    img_count = img_count + 1
+
     acc1 /= samples
     acc5 /= samples
 
-    if writer:
-        n_imgs = min(images.size(0), 10)
-        for j in range(n_imgs):
-            writer.add_image('valid/input_image',
-                concat_image_features(images[j], first[j]), global_step=step)
+
 
     return acc1, acc5, loss, infer_t
 
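With logging moved inside the batch loop and `img_count` keying the tag, every validation image gets its own `valid/<n>` entry, so the whole 590-sample set is browsable instead of only the final batch, which is all the old post-loop block could ever see. If per-image tags become unwieldy, `torchvision.utils.make_grid` is a common alternative; a sketch, not what this commit does:

```python
import torchvision

# One tag per batch instead of one per image (alternative, not this commit).
grid = torchvision.utils.make_grid(images.cpu(), nrow=4)  # tile the 4-image batch
writer.add_image('valid/batch_' + str(i), grid, global_step=step)
```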
...