# getAugmented.py
import os
import fire
import json
from pprint import pprint
import pickle

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from utils import *

# Usage:
#   python getAugmented.py --model_path='logs/April_16_21:50:17__resnet50__None/'

def eval(model_path, image_size=240):
    """Log a strip of augmented test images to TensorBoard.

    Reads ``kwargs.json`` and ``augmentation.cp`` from ``model_path``. For
    every pickled augmentation transform the test dataset is rebuilt with
    that transform, and the first batch yielded by its loader is appended to
    the right-hand side of the strip for index 0. The finished strip is then
    written under the tag ``img/0`` with ``model_path`` as the log directory.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is the
    entry point this script hands to ``fire.Fire``.

    Args:
        model_path: Run directory holding ``kwargs.json`` and
            ``augmentation.cp``; also used as the TensorBoard log directory.
        image_size: Side length used to reshape each batch to
            ``(image_size, image_size)``. The loader must therefore yield
            exactly ``image_size ** 2`` values per batch (presumably one
            single-channel image -- confirm against ``get_aug_dataloader``).
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the handle is closed even if parsing fails.
    with open(kwargs_path, encoding='utf-8') as f:
        kwargs = json.load(f)
    args, kwargs = parse_args(kwargs)
    pprint(args)

    cp_path = os.path.join(model_path, 'augmentation.cp')

    writer = SummaryWriter(log_dir=model_path)
    try:
        print('\n[+] Load transform')
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # script at checkpoints produced by a trusted training run.
        with open(cp_path, 'rb') as f:
            aug_transform_list = pickle.load(f)

        # One zero-width strip per test sample. Sharing a single placeholder
        # tensor is safe because entries are replaced, never mutated in place.
        num_samples = len(get_dataset(args, None, 'test'))
        augmented_image_list = [torch.empty(image_size, 0)] * num_samples

        print('\n[+] Load dataset')
        for aug_transform in aug_transform_list:
            dataset = get_dataset(args, aug_transform, 'test')
            loader = get_aug_dataloader(args, dataset)

            for i, (images, _) in enumerate(loader):
                images = images.view(image_size, image_size)

                # Append this augmentation to the right of the strip.
                augmented_image_list[i] = torch.cat(
                    [augmented_image_list[i], images], dim=1)

                if i % 1000 == 0:
                    # Width grows by image_size for every transform applied.
                    print("\n images size: ", augmented_image_list[i].size())

                # Only the first batch of each transform is visualised.
                break

        print('\n[+] Write on tensorboard')
        for i, data in enumerate(augmented_image_list):
            if data.numel() == 0:
                # Nothing was concatenated for this index; an empty tensor
                # cannot be reshaped with an inferred (-1) dimension.
                continue
            tag = 'img/' + str(i)
            # (1, H, W) matches add_image's default 'CHW' layout.
            writer.add_image(tag, data.view(1, image_size, -1), global_step=0)
            # Only the first populated strip is written.
            break
    finally:
        # Flush and release the event file even if something above failed.
        writer.close()


    # if writer:
    #     for j in range():
    #         tag = 'img/' + str(img_count) + '_' + str(j) 
    #         # writer.add_image(tag,
    #         #         concat_image_features(images[j], first[j]), global_step=step) 
    #         # if j > 0:
    #         #     fore = concat_image_features(fore, images[j])

    #     writer.add_image(tag, fore, global_step=0)
    #     img_count = img_count + 1

    # writer.close()

# Script entry point: python-fire exposes the parameters of `eval` as
# command-line flags (e.g. --model_path=...).
if __name__ == "__main__":
    fire.Fire(eval)