# getAugmented_all.py
import os
import fire
import json
from pprint import pprint
import pickle

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from utils import *

# Usage:
# python getAugmented_all.py --model_path='logs/April_24_21:05:15__resnet50__None/'

def _read_run_kwargs(model_path):
    """Return the JSON object stored in ``<model_path>/kwargs.json``."""
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the handle is closed even if parsing fails.
    with open(kwargs_path, encoding='utf-8') as f:
        return json.load(f)


def _load_transforms(model_path):
    """Unpickle and return the augmentation transforms in ``augmentation.cp``."""
    cp_path = os.path.join(model_path, 'augmentation.cp')
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this script at checkpoints you produced yourself.
    with open(cp_path, 'rb') as f:
        return pickle.load(f)


def _concat_augmentations(args, aug_transform_list, num_samples, image_size):
    """Build one strip per training sample by concatenating its augmentations.

    For each transform the training set is rebuilt with that transform and
    iterated through ``get_aug_dataloader``. The i-th loaded image is appended
    along dim 1 (width) of the i-th strip. When the loader yields one image
    per sample, strip i ends up with shape
    ``(image_size, image_size * len(aug_transform_list))``.
    """
    # One independent empty (image_size, 0) tensor per sample, grown by cat.
    strips = [torch.empty(image_size, 0) for _ in range(num_samples)]
    for aug_transform in aug_transform_list:
        dataset = get_dataset(args, aug_transform, 'train')
        loader = get_aug_dataloader(args, dataset)
        for i, (images, _target) in enumerate(loader):
            # NOTE(review): assumes every batch holds exactly one single-channel
            # image_size x image_size image, and that the loader does not
            # shuffle, so index i is the same sample for every transform --
            # confirm against get_aug_dataloader.
            images = images.view(image_size, image_size)
            strips[i] = torch.cat([strips[i], images], dim=1)
    return strips


def eval(model_path, image_size=240):
    """Write every augmentation of each training image to TensorBoard.

    Loads the run configuration from ``<model_path>/kwargs.json`` (passed
    through ``parse_args``) and the pickled augmentation transforms from
    ``<model_path>/augmentation.cp``. For each training sample, the images
    produced by all transforms are concatenated side by side. Each strip is
    logged with ``SummaryWriter.add_image`` under tag ``img/<index>`` at
    step 0, using ``model_path`` as the log directory.

    The name shadows the builtin ``eval``. It is kept because ``fire.Fire``
    exposes this function as the script's command-line interface.

    Args:
        model_path: Directory of a previous run, e.g. ``logs/<run_name>/``.
        image_size: Height and width each loaded image is reshaped to. The
            default of 240 matches the value previously hard-coded here.
    """
    print('\n[+] Parse arguments')
    args, _ = parse_args(_read_run_kwargs(model_path))
    pprint(args)

    print('\n[+] Load transform')
    aug_transform_list = _load_transforms(model_path)
    if not aug_transform_list:
        # Without transforms every strip would be zero-width; skip logging
        # instead of trying to write empty images.
        print('\n[!] No augmentation transforms found, nothing to write')
        return

    # Size of the un-augmented training set = number of strips to build.
    num_samples = len(get_dataset(args, None, 'train'))

    print('\n[+] Load dataset')
    strips = _concat_augmentations(
        args, aug_transform_list, num_samples, image_size
    )

    print('\n[+] Write on tensorboard')
    # Created only once there is something to log; try/finally guarantees
    # the writer is closed even if add_image raises.
    writer = SummaryWriter(log_dir=model_path)
    try:
        for i, strip in enumerate(strips):
            # add_image defaults to CHW layout: prepend a single channel axis.
            writer.add_image('img/' + str(i), strip.unsqueeze(0), global_step=0)
    finally:
        writer.close()
   

if __name__ == '__main__':
    # Hand eval() to python-fire so its parameters become command-line
    # flags (e.g. --model_path=...).
    fire.Fire(component=eval)