조현아

Log augmented data to TensorBoard

@@ -33,11 +33,12 @@ def eval(model_path):
     print('\n[+] Load dataset')
     test_transform = get_valid_transform(args, model)
     test_dataset = get_dataset(args, test_transform, 'test')
-    test_loader = iter(get_dataloader(args, test_dataset))
+    print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590
+    test_loader = iter(get_dataloader(args, test_dataset)) ###

     print('\n[+] Start testing')
-    writer = SummaryWriter(log_dir=model_path)
+    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs', model_name)
+    writer = SummaryWriter(log_dir=log_dir)
     _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)

     print('\n[+] Valid results')
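For context, a minimal sketch of the logging setup this hunk moves toward: each run writes its TensorBoard events into its own folder under a shared runs/ root, so the eval run appears as a separate entry in TensorBoard's run selector instead of mixing event files into the checkpoint directory (model_path). The runs_root and run_name values below are placeholders, not the repo's actual paths.

# Hedged sketch: one SummaryWriter per evaluation run (assumed names, dummy value).
import os
from torch.utils.tensorboard import SummaryWriter

runs_root = './runs'                     # placeholder; the commit points this at a Drive folder
run_name = 'resnet50_eval'               # placeholder for model_name
writer = SummaryWriter(log_dir=os.path.join(runs_root, run_name))
writer.add_scalar('valid/acc1', 0.87, global_step=0)   # dummy scalar just to create an event file
writer.close()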
@@ -55,7 +55,7 @@ def split_dataset(args, dataset, k):
     return Dm_indexes, Da_indexes

+# (images[j], first[j]), global_step=step)
 def concat_image_features(image, features, max_features=3):
     _, h, w = image.shape
@@ -93,6 +93,7 @@ def concat_image_features(image, features, max_features=3):
         image_feature = torch.cat((image_feature, feature), 2) ### dim = 2
+        # print("\nimg feature size: ", image_feature.size()) # [1, 240, 720]
     return image_feature
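concat_image_features() is what both train_step() and validate() pass to writer.add_image(): it tiles the input image and a few feature maps into one wide C×H×W tensor by concatenating along dim 2 (the width axis), which is consistent with the [1, 240, 720] size in the debug comment (three 240-pixel-wide tiles). A rough, hedged sketch of that idea, assuming C×H×W tensors and bilinear resizing; the repo's exact normalization and feature handling may differ.

# Hedged sketch of the tiling idea, not the repo's implementation.
import torch
import torch.nn.functional as F

def concat_image_features_sketch(image, features, max_features=3):
    """Tile an image and up to max_features feature maps side by side along the width."""
    _, h, w = image.shape
    image_feature = image.clone()
    for i in range(min(features.size(0), max_features)):
        f = features[i:i + 1].unsqueeze(0)                 # [1, 1, fh, fw]
        f = F.interpolate(f, size=(h, w), mode='bilinear',
                          align_corners=False)[0]          # [1, h, w]
        f = (f - f.min()) / (f.max() - f.min() + 1e-5)     # min-max normalize for display
        f = f.expand_as(image)                             # match the image's channel count
        image_feature = torch.cat((image_feature, f), 2)   # dim 2 = width
    return image_feature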
@@ -149,10 +150,6 @@ def parse_args(kwargs):
 def select_model(args):
     # resnet_dict = {'ResNet18':grayResNet.ResNet18(), 'ResNet34':grayResNet.ResNet34(),
     #                'ResNet50':grayResNet.ResNet50(), 'ResNet101':grayResNet.ResNet101(), 'ResNet152':grayResNet.ResNet152()}
     # grayResNet2
     resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(),
                    'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()}
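A small, hypothetical usage sketch of the dictionary-based selection above. It assumes the repo-local grayResNet2 module from the diff; the name of the argument that carries the chosen key is not shown in this hunk, so a plain string stands in for it. Mapping names to constructors (rather than to already-built instances, as in the dict above) also avoids instantiating all five networks just to pick one.

# Hypothetical sketch: choose a grayscale ResNet variant by name.
import grayResNet2   # repo-local module referenced in the diff

resnet_builders = {
    'resnet18': grayResNet2.resnet18,
    'resnet34': grayResNet2.resnet34,
    'resnet50': grayResNet2.resnet50,
    'resnet101': grayResNet2.resnet101,
    'resnet152': grayResNet2.resnet152,
}
network_name = 'resnet50'                  # placeholder for the value parsed from args
model = resnet_builders[network_name]()    # build only the requested network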
@@ -285,7 +282,7 @@ def get_dataset(args, transform, split='train'):
     elif args.dataset == 'BraTS':
         if split in ['train']:
             dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform)
-        else:
+        else: # test
             dataset = CustomDataset(VAL_DATASET_PATH, transform=transform)
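The BraTS branch routes 'train' to TRAIN_DATASET_PATH and everything else (including the 'test' split used by eval()) to VAL_DATASET_PATH, with the same transform interface in both cases. The class below is a hypothetical stand-in for CustomDataset, included only to show where that transform is applied; the repo's actual loading logic is not shown in this diff.

# Hypothetical stand-in for CustomDataset (not the repo's implementation).
import os
from PIL import Image
from torch.utils.data import Dataset

class GraySliceDataset(Dataset):
    """Grayscale slices stored as <root>/<class_name>/<file>; label = class index."""
    def __init__(self, root, transform=None):
        self.samples = []
        for idx, cls in enumerate(sorted(os.listdir(root))):
            cls_dir = os.path.join(root, cls)
            for fname in sorted(os.listdir(cls_dir)):
                self.samples.append((os.path.join(cls_dir, fname), idx))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, target = self.samples[i]
        img = Image.open(path).convert('L')      # single channel
        if self.transform is not None:
            img = self.transform(img)            # e.g. Resize + ToTensor -> [1, H, W]
        return img, target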
@@ -382,7 +379,7 @@ def get_valid_transform(args, model):
             transforms.ToTensor()
         ])
     elif args.dataset == 'BraTS':
-        resize_h, resize_w = 256, 256
+        resize_h, resize_w = 240, 240
         val_transform = transforms.Compose([
             transforms.Resize([resize_h, resize_w]),
             transforms.ToTensor()
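BraTS slices are natively 240×240, so resizing validation inputs to 240 instead of 256 keeps them at the dataset's own resolution and matches the [4, 1, 240, 240] batches printed inside validate() below. A quick sketch of what the composed transform produces, using a blank grayscale image as a stand-in for a real slice:

# Sketch: the BraTS validation transform applied to a dummy 240x240 grayscale image.
from PIL import Image
from torchvision import transforms

val_transform = transforms.Compose([
    transforms.Resize([240, 240]),
    transforms.ToTensor(),            # float tensor in [0, 1], shape [1, 240, 240]
])

dummy = Image.new('L', (240, 240))    # stand-in for one slice
x = val_transform(dummy)
print(x.shape)                        # torch.Size([1, 240, 240]); the DataLoader adds the batch dim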
@@ -426,13 +423,14 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer
     if writer and step % args.print_step == 0:
         n_imgs = min(images.size(0), 10)
+        tag = 'train/' + str(step)
         for j in range(n_imgs):
-            writer.add_image('train/input_image',
+            writer.add_image(tag,
                 concat_image_features(images[j], first[j]), global_step=step)
     return acc1, acc5, loss, forward_t, backward_t

 # validate(args, model, criterion, test_loader, step=0, writer=writer)
 def validate(args, model, criterion, valid_loader, step, writer, device=None):
     # switch to evaluate mode
     model.eval()
@@ -441,19 +439,24 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
     samples = 0
     infer_t = 0
+    img_count = 0
     with torch.no_grad():
-        for i, (images, target) in enumerate(valid_loader): ##
+        for i, (images, target) in enumerate(valid_loader): ## loop [0, 148]
+            # print("\n1 images size: ", images.size()) # [4, 1, 240, 240]
             start_t = time.time()
             if device:
                 images = images.to(device)
                 target = target.to(device)
-            elif args.use_cuda is not None:
+            elif args.use_cuda is not None: #
                 images = images.cuda(non_blocking=True)
                 target = target.cuda(non_blocking=True)
+            # print("\n2 images size: ", images.size()) # [4, 1, 240, 240]

             # compute output
+            # first = nn.Sequential(*list(backbone.children())[:1])
             output, first = model(images)
             loss = criterion(output, target)

             infer_t += time.time() - start_t
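validate() expects the model to return both the logits and a 'first' tensor that later goes through concat_image_features(); the commented-out nn.Sequential line suggests it comes from the first child module of the backbone. One way to obtain such an intermediate activation is a forward hook, sketched below with a torchvision resnet18 standing in for the repo's grayResNet2 models (which may simply return the feature from forward() directly).

# Hedged sketch: capture the stem conv output alongside the logits via a forward hook.
import torch
from torchvision.models import resnet18   # stand-in backbone, not grayResNet2

backbone = resnet18()
features = {}

def save_first(module, inputs, output):
    features['first'] = output.detach()    # e.g. [N, 64, 120, 120] for 240x240 inputs

backbone.conv1.register_forward_hook(save_first)

x = torch.randn(4, 3, 240, 240)            # dummy batch; torchvision's resnet expects 3 channels
output = backbone(x)
first = features['first']
print(output.shape, first.shape)           # [4, 1000] and [4, 64, 120, 120]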
@@ -464,14 +467,19 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
             acc5 += _acc5
             samples += images.size(0)
-    acc1 /= samples
-    acc5 /= samples
     if writer:
+        # print("\n3 images.size(0): ", images.size(0))
         n_imgs = min(images.size(0), 10)
         for j in range(n_imgs):
-            writer.add_image('valid/input_image',
+            tag = 'valid/' + str(img_count)
+            writer.add_image(tag,
                 concat_image_features(images[j], first[j]), global_step=step)
+            img_count = img_count + 1
+    acc1 /= samples
+    acc5 /= samples
     return acc1, acc5, loss, infer_t
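Giving every logged image its own tag ('valid/0', 'valid/1', ...) via img_count makes each sample a separate entry in TensorBoard's Images tab, rather than one 'valid/input_image' slot that only steps through global_step values; the train_step() change above does the analogous thing with a per-step 'train/<step>' tag. A tiny self-contained sketch of the tagging pattern, with dummy tensors and a hypothetical log directory:

# Sketch: one TensorBoard tag per validation image.
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs/demo')   # hypothetical directory
img_count = 0
batch = torch.rand(4, 1, 240, 720)              # stand-in for concatenated image+feature tiles
for j in range(batch.size(0)):
    writer.add_image('valid/' + str(img_count), batch[j], global_step=0)  # CHW tensor
    img_count += 1
writer.close()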