Showing
2 changed files
with
27 additions
and
18 deletions
... | @@ -33,11 +33,12 @@ def eval(model_path): | ... | @@ -33,11 +33,12 @@ def eval(model_path): |
33 | print('\n[+] Load dataset') | 33 | print('\n[+] Load dataset') |
34 | test_transform = get_valid_transform(args, model) | 34 | test_transform = get_valid_transform(args, model) |
35 | test_dataset = get_dataset(args, test_transform, 'test') | 35 | test_dataset = get_dataset(args, test_transform, 'test') |
36 | - test_loader = iter(get_dataloader(args, test_dataset)) | 36 | + print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590 |
37 | + | ||
38 | + test_loader = iter(get_dataloader(args, test_dataset)) ### | ||
37 | 39 | ||
38 | print('\n[+] Start testing') | 40 | print('\n[+] Start testing') |
39 | - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs', model_name) | 41 | + writer = SummaryWriter(log_dir=model_path) |
40 | - writer = SummaryWriter(log_dir=log_dir) | ||
41 | _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) | 42 | _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer) |
42 | 43 | ||
43 | print('\n[+] Valid results') | 44 | print('\n[+] Valid results') | ... | ... |
... | @@ -55,7 +55,7 @@ def split_dataset(args, dataset, k): | ... | @@ -55,7 +55,7 @@ def split_dataset(args, dataset, k): |
55 | 55 | ||
56 | return Dm_indexes, Da_indexes | 56 | return Dm_indexes, Da_indexes |
57 | 57 | ||
58 | - | 58 | +#(images[j], first[j]), global_step=step) |
59 | def concat_image_features(image, features, max_features=3): | 59 | def concat_image_features(image, features, max_features=3): |
60 | _, h, w = image.shape | 60 | _, h, w = image.shape |
61 | 61 | ||
... | @@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3): | ... | @@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3): |
93 | 93 | ||
94 | 94 | ||
95 | image_feature = torch.cat((image_feature, feature), 2) ### dim = 2 | 95 | image_feature = torch.cat((image_feature, feature), 2) ### dim = 2 |
96 | + #print("\nimg feature size: ", image_feature.size()) #[1, 240, 720] | ||
96 | 97 | ||
97 | return image_feature | 98 | return image_feature |
98 | 99 | ||
... | @@ -149,10 +150,6 @@ def parse_args(kwargs): | ... | @@ -149,10 +150,6 @@ def parse_args(kwargs): |
149 | 150 | ||
150 | 151 | ||
151 | def select_model(args): | 152 | def select_model(args): |
152 | - # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(), | ||
153 | - # 'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()} | ||
154 | - | ||
155 | - | ||
156 | # grayResNet2 | 153 | # grayResNet2 |
157 | resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | 154 | resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), |
158 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | 155 | 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} |
... | @@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'): | ... | @@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'): |
285 | elif args.dataset == 'BraTS': | 282 | elif args.dataset == 'BraTS': |
286 | if split in ['train']: | 283 | if split in ['train']: |
287 | dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) | 284 | dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) |
288 | - else: | 285 | + else: #test |
289 | dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) | 286 | dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) |
290 | 287 | ||
291 | 288 | ||
... | @@ -382,7 +379,7 @@ def get_valid_transform(args, model): | ... | @@ -382,7 +379,7 @@ def get_valid_transform(args, model): |
382 | transforms.ToTensor() | 379 | transforms.ToTensor() |
383 | ]) | 380 | ]) |
384 | elif args.dataset == 'BraTS': | 381 | elif args.dataset == 'BraTS': |
385 | - resize_h, resize_w = 256, 256 | 382 | + resize_h, resize_w = 240, 240 |
386 | val_transform = transforms.Compose([ | 383 | val_transform = transforms.Compose([ |
387 | transforms.Resize([resize_h, resize_w]), | 384 | transforms.Resize([resize_h, resize_w]), |
388 | transforms.ToTensor() | 385 | transforms.ToTensor() |
... | @@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer | ... | @@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer |
426 | 423 | ||
427 | if writer and step % args.print_step == 0: | 424 | if writer and step % args.print_step == 0: |
428 | n_imgs = min(images.size(0), 10) | 425 | n_imgs = min(images.size(0), 10) |
426 | + tag = 'train/' + str(step) | ||
429 | for j in range(n_imgs): | 427 | for j in range(n_imgs): |
430 | - writer.add_image('train/input_image', | 428 | + writer.add_image(tag, |
431 | concat_image_features(images[j], first[j]), global_step=step) | 429 | concat_image_features(images[j], first[j]), global_step=step) |
432 | 430 | ||
433 | return acc1, acc5, loss, forward_t, backward_t | 431 | return acc1, acc5, loss, forward_t, backward_t |
434 | 432 | ||
435 | - | 433 | +# validate(args, model, criterion, test_loader, step=0, writer=writer) |
436 | def validate(args, model, criterion, valid_loader, step, writer, device=None): | 434 | def validate(args, model, criterion, valid_loader, step, writer, device=None): |
437 | # switch to evaluate mode | 435 | # switch to evaluate mode |
438 | model.eval() | 436 | model.eval() |
... | @@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
441 | samples = 0 | 439 | samples = 0 |
442 | infer_t = 0 | 440 | infer_t = 0 |
443 | 441 | ||
442 | + img_count = 0 | ||
443 | + | ||
444 | with torch.no_grad(): | 444 | with torch.no_grad(): |
445 | - for i, (images, target) in enumerate(valid_loader): ## | 445 | + for i, (images, target) in enumerate(valid_loader): ## loop [0, 148] |
446 | 446 | ||
447 | + #print("\n1 images size: ", images.size()) #[4, 1, 240, 240] | ||
447 | start_t = time.time() | 448 | start_t = time.time() |
448 | if device: | 449 | if device: |
449 | images = images.to(device) | 450 | images = images.to(device) |
450 | target = target.to(device) | 451 | target = target.to(device) |
451 | 452 | ||
452 | - elif args.use_cuda is not None: | 453 | + elif args.use_cuda is not None: # |
453 | images = images.cuda(non_blocking=True) | 454 | images = images.cuda(non_blocking=True) |
454 | target = target.cuda(non_blocking=True) | 455 | target = target.cuda(non_blocking=True) |
456 | + #print("\n2 images size: ", images.size()) #[4, 1, 240, 240] | ||
455 | 457 | ||
456 | # compute output | 458 | # compute output |
459 | + # first = nn.Sequential(*list(backbone.children())[:1]) | ||
457 | output, first = model(images) | 460 | output, first = model(images) |
458 | loss = criterion(output, target) | 461 | loss = criterion(output, target) |
459 | infer_t += time.time() - start_t | 462 | infer_t += time.time() - start_t |
... | @@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
464 | acc5 += _acc5 | 467 | acc5 += _acc5 |
465 | samples += images.size(0) | 468 | samples += images.size(0) |
466 | 469 | ||
467 | - acc1 /= samples | ||
468 | - acc5 /= samples | ||
469 | - | ||
470 | if writer: | 470 | if writer: |
471 | + # print("\n3 images.size(0): ", images.size(0)) | ||
471 | n_imgs = min(images.size(0), 10) | 472 | n_imgs = min(images.size(0), 10) |
472 | for j in range(n_imgs): | 473 | for j in range(n_imgs): |
473 | - writer.add_image('valid/input_image', | 474 | + tag = 'valid/' + str(img_count) |
475 | + writer.add_image(tag, | ||
474 | concat_image_features(images[j], first[j]), global_step=step) | 476 | concat_image_features(images[j], first[j]), global_step=step) |
477 | + img_count = img_count + 1 | ||
478 | + | ||
479 | + acc1 /= samples | ||
480 | + acc5 /= samples | ||
481 | + | ||
482 | + | ||
475 | 483 | ||
476 | return acc1, acc5, loss, infer_t | 484 | return acc1, acc5, loss, infer_t |
477 | 485 | ... | ... |
-
Please register or login to post a comment