# getAugmented_saveimg.py (2.72 KB)
import os
import fire
import json
from pprint import pprint
import pickle
import random
import numpy as np
import cv2

import torch
import torch.nn as nn
from torchvision.utils import save_image

from utils import *

# command
# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/'

# Default location of the normal-brain frames used as compositing backgrounds.
_DEFAULT_NORMAL_DIR = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/Normal_frames_all'


def _compose_synthesized(lesion, normal, threshold=5):
    """Paste a lesion image onto a normal-brain image and return it as uint8.

    Args:
        lesion: 2-D array of lesion intensities. It is multiplied by 255 here,
            so it is presumably scaled to [0, 1] -- TODO confirm against the
            dataset transform.
        normal: 2-D uint8 array with the same shape as ``lesion``. It is not
            modified.
        threshold: scaled lesion values below this are treated as background
            and set to 0.

    Returns:
        A uint8 array in which every pixel with a non-zero (thresholded) lesion
        value holds the lesion value and every other pixel holds the value
        from ``normal``.
    """
    lesion_255 = np.asarray(lesion) * 255
    lesion_255[lesion_255 < threshold] = 0

    # Blank the background wherever the lesion is present so the sum below
    # never mixes the two sources in a single pixel.
    background = normal.copy()
    background[lesion_255 != 0] = 0
    composite = background + lesion_255

    # The sum is a float array; round and clamp explicitly so the pixels that
    # get written do not depend on how cv2.imwrite treats non-uint8 input.
    return np.clip(np.rint(composite), 0, 255).astype(np.uint8)


def eval(model_path, normal_dir=_DEFAULT_NORMAL_DIR, max_images=1):
    """Save augmented lesion images composited onto random normal-brain frames.

    Reads ``kwargs.json`` and ``augmentation.cp`` from ``model_path``, builds a
    ``RandomChoice`` transform from the pickled transform list, iterates the
    augmented 'train' dataset and writes ``aug_<i>.png`` files into
    ``<model_path>/aug_synthesized``.

    Args:
        model_path: run directory containing ``kwargs.json`` and
            ``augmentation.cp``; the output directory is created inside it.
        normal_dir: directory of normal-brain images; one is picked at random
            for every saved image.
        max_images: stop after this many images have been saved. The default
            of 1 matches the script's previous behaviour (it stopped right
            after the first image). Pass ``None`` to process the whole loader.

    Raises:
        ValueError: if ``max_images`` is given and is smaller than 1.
        FileNotFoundError: if ``normal_dir`` contains no entries.
        IOError: if the chosen normal-brain file cannot be decoded as an image.
    """
    if max_images is not None and max_images < 1:
        raise ValueError('max_images must be >= 1 or None')

    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, kwargs = parse_args(kwargs)
    pprint(args)

    print('\n[+] Load transform')
    cp_path = os.path.join(model_path, 'augmentation.cp')
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # script at checkpoints produced by a trusted training run.
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)
    transform = transforms.RandomChoice(aug_transform_list)

    print('\n[+] Load dataset')
    dataset = get_dataset(args, transform, 'train')
    loader = get_aug_dataloader(args, dataset)

    print('\n[+] Save 1 random policy')
    save_dir = os.path.join(model_path, 'aug_synthesized')
    os.makedirs(save_dir, exist_ok=True)

    # The directory listing does not change while we run, so read it once.
    normal_files = os.listdir(normal_dir)
    if not normal_files:
        raise FileNotFoundError('no normal-brain images found in ' + normal_dir)

    for i, (image, _target) in enumerate(loader):
        # view() requires each loader item to hold exactly 240*240 values,
        # i.e. presumably one single-channel image per batch -- TODO confirm.
        image = image.view(240, 240)

        # Pick a random normal-brain frame and bring it to the lesion's size
        # (the original author noted the raw frames as (256, 224)).
        nor_path = os.path.join(normal_dir, random.choice(normal_files))
        nor_img = cv2.imread(nor_path, cv2.IMREAD_GRAYSCALE)
        if nor_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('could not read normal-brain image: ' + nor_path)
        nor_img = cv2.resize(nor_img, (240, 240))

        syn_image = _compose_synthesized(image, nor_img)
        cv2.imwrite(os.path.join(save_dir, 'aug_' + str(i) + '.png'), syn_image)

        if i % 100 == 0:
            print("\n saved images: ", i)

        # Stop before fetching another batch once the quota is reached.
        if max_images is not None and i + 1 >= max_images:
            break

    print('\n[+] Finished to save')

# Script entry point: python-fire turns eval into a command-line interface,
# so its parameters are passed as flags, e.g. --model_path=<run directory>.
if __name__ == '__main__':
    fire.Fire(eval)