# getRDAugmented_saveimg.py (3.2 KB)

import os
import fire
import json
from pprint import pprint
import pickle
import random
import numpy as np
import cv2

import torch
import torch.nn as nn
from torchvision.utils import save_image
import torchvision.transforms as transforms

from transforms import *
from utils import *

# command
# python getRDAugmented_saveimg.py --model_path='logs/April_26_17:36:17_NL_resnet50__None'


# Pool of augmentation operations that get_next_subpolicy() samples from.
# Each entry is called as ``op(prob, mag)``; commented-out entries are
# deliberately excluded from the search space.
DEFALUT_CANDIDATES = [
    ShearXY,
    TranslateXY,
    # Rotate,
    AutoContrast,
    # Invert,
    # Equalize, # Histogram Equalize --> white tumor
    # Solarize,
    Posterize,
    Contrast,
    # Color,
    Brightness,
    Sharpness,
    Cutout,
]

# Correctly spelled alias. The original (misspelled) name is kept so existing
# callers and default arguments that reference it keep working.
DEFAULT_CANDIDATES = DEFALUT_CANDIDATES

def get_next_subpolicy(transform_candidates=DEFALUT_CANDIDATES, op_per_subpolicy=2):
    """Build one randomly sampled augmentation sub-policy.

    Each of the ``op_per_subpolicy`` operations is drawn independently (with
    replacement) from the candidate pool and instantiated as ``op(prob, mag)``.
    Both ``prob`` and ``mag`` come from ``random.random()``, so they lie in
    [0, 1).

    Args:
        transform_candidates: sequence of operation factories. A falsy value
            (``None`` or an empty sequence) falls back to ``DEFALUT_CANDIDATES``.
        op_per_subpolicy: how many random operations the sub-policy contains.

    Returns:
        A ``transforms.Compose`` that pads by 4 pixels, randomly flips
        horizontally, applies the sampled operations in order, resizes to
        240x240 and converts the result to a tensor.
    """
    pool = transform_candidates or DEFALUT_CANDIDATES
    pool_size = len(pool)

    # The RNG is consumed in the order index -> prob -> mag for every op:
    # the subscript is evaluated before the call arguments, and the arguments
    # are evaluated left to right.
    sampled_ops = [
        pool[random.randrange(pool_size)](random.random(), random.random())
        for _ in range(op_per_subpolicy)
    ]

    pipeline = [transforms.Pad(4), transforms.RandomHorizontalFlip()]
    pipeline.extend(sampled_ops)
    pipeline += [transforms.Resize([240, 240]), transforms.ToTensor()]
    return transforms.Compose(pipeline)

# Directory of non-lesion background images used when synthesizing images;
# can be overridden with --normal_dir on the command line.
_NORMAL_DIR = '/root/volume/2016104167/data/MICCAI_BraTS_2019_Data_Training/NonLesion_flair_frame_all'


def eval(model_path, n_subpolicies=16, save_images=False,
         normal_dir=_NORMAL_DIR, max_images=1000):
    """Sample a random augmentation policy for a training run and print it.

    NOTE: this name shadows the builtin ``eval``; it is kept because the
    ``fire.Fire(eval)`` entry point exposes it as the command-line interface.

    Args:
        model_path: run directory containing ``kwargs.json`` (the saved
            arguments, re-parsed with ``parse_args``).
        n_subpolicies: number of random sub-policies. The resulting
            ``RandomChoice`` transform picks one of them per call.
        save_images: when True, also synthesize and save images with this
            policy via ``_save_synthesized_images``. Defaults to False, which
            only prints the policy.
        normal_dir: directory of non-lesion background images. Used only
            when ``save_images`` is True.
        max_images: upper bound on the number of images saved. Used only
            when ``save_images`` is True.
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the file handle is closed deterministically.
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, _ = parse_args(kwargs)
    pprint(args)

    print('\n[+] Load transform')
    aug_transform_list = [
        get_next_subpolicy(DEFALUT_CANDIDATES) for _ in range(n_subpolicies)
    ]
    transform = transforms.RandomChoice(aug_transform_list)
    print(transform)

    if save_images:
        _save_synthesized_images(args, transform, model_path,
                                 normal_dir=normal_dir, max_images=max_images)


def _save_synthesized_images(args, transform, model_path,
                             normal_dir=_NORMAL_DIR, max_images=1000,
                             threshold=10):
    """Paste augmented images onto random non-lesion backgrounds and save as PNG.

    Pixels of the augmented image brighter than ``threshold`` (on a 0-255
    scale) replace the background pixel; every other pixel shows the
    background. Files are written to
    ``<model_path>/RD_aug_synthesized/aug_<i>.png``.

    Returns:
        The number of images saved.

    Raises:
        FileNotFoundError: if ``normal_dir`` contains no entries.
        ValueError: if a background file cannot be decoded as an image.
        OSError: if an output image cannot be written.
    """
    from itertools import islice

    print('\n[+] Load dataset')
    dataset = get_dataset(args, transform, 'train')
    loader = iter(get_aug_dataloader(args, dataset))

    print('\n[+] Save 1 random policy')
    save_dir = os.path.join(model_path, 'RD_aug_synthesized')
    os.makedirs(save_dir, exist_ok=True)

    # List the background pool once instead of on every iteration.
    normal_files = os.listdir(normal_dir)
    if not normal_files:
        raise FileNotFoundError(f'no background images found in {normal_dir}')

    n_saved = 0
    # islice stops after max_images items without fetching an extra batch.
    for i, (image, _target) in enumerate(islice(loader, max_images)):
        # NOTE(review): assumes the loader yields one single-channel 240x240
        # image per batch; view() raises otherwise. TODO confirm batch size.
        foreground = image.view(240, 240).cpu().numpy() * 255

        # Random non-lesion background, resized to the foreground size.
        nor_file = random.choice(normal_files)
        background = cv2.imread(os.path.join(normal_dir, nor_file), cv2.IMREAD_GRAYSCALE)
        if background is None:
            raise ValueError(f'could not read background image {nor_file!r}')
        background = cv2.resize(background, (240, 240))

        # A single mask, so a pixel is either foreground or background. The
        # previous separate '< 10' / '> 10' tests both skipped pixels equal to
        # 10, which then got the image and background added together.
        syn_image = np.where(foreground > threshold, foreground, background)
        # The PNG encoder works on 8-bit data; round and saturate explicitly.
        syn_image = np.clip(np.rint(syn_image), 0, 255).astype(np.uint8)

        out_path = os.path.join(save_dir, 'aug_' + str(i) + '.png')
        if not cv2.imwrite(out_path, syn_image):
            raise OSError(f'failed to write {out_path}')
        n_saved = i + 1

    print("\n saved images: ", n_saved)
    print('\n[+] Finished to save')
    return n_saved

if __name__ == '__main__':
    # Hand `eval` to Python Fire, which maps its parameters to command-line
    # flags (e.g. --model_path=...).
    fire.Fire(component=eval)