조현아

vm get augmented data

...@@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter ...@@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter
10 from utils import * 10 from utils import *
11 11
12 # command 12 # command
13 -# python "eval.py" --model_path='logs/' 13 +# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/'
14 14
15 def eval(model_path): 15 def eval(model_path):
16 print('\n[+] Parse arguments') 16 print('\n[+] Parse arguments')
...@@ -34,9 +34,17 @@ def eval(model_path): ...@@ -34,9 +34,17 @@ def eval(model_path):
34 34
35 print('\n[+] Load dataset') 35 print('\n[+] Load dataset')
36 test_transform = get_valid_transform(args, model) 36 test_transform = get_valid_transform(args, model)
37 - test_dataset = get_dataset(args, test_transform, 'test') 37 + #print('\nTEST Transform\n', test_transform)
38 - #print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590 38 + test_dataset = get_dataset(args, test_transform, 'test')
39 - 39 +
40 + """
41 + test_transform
42 + Compose(
43 + Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR)
44 + ToTensor()
45 + )
46 + """
47 +
40 test_loader = iter(get_dataloader(args, test_dataset)) ### 48 test_loader = iter(get_dataloader(args, test_dataset)) ###
41 49
42 print('\n[+] Start testing') 50 print('\n[+] Start testing')
......
...@@ -17,15 +17,15 @@ from utils import * ...@@ -17,15 +17,15 @@ from utils import *
17 DEFALUT_CANDIDATES = [ 17 DEFALUT_CANDIDATES = [
18 ShearXY, 18 ShearXY,
19 TranslateXY, 19 TranslateXY,
20 - # Rotate, 20 + Rotate,
21 # AutoContrast, 21 # AutoContrast,
22 # Invert, 22 # Invert,
23 - Equalize, 23 + Equalize, # Histogram Equalize --> white tumor
24 - Solarize, 24 + #Solarize,
25 Posterize, 25 Posterize,
26 - Contrast, 26 + # Contrast,
27 # Color, 27 # Color,
28 - Brightness, 28 + # Brightness,
29 Sharpness, 29 Sharpness,
30 Cutout, 30 Cutout,
31 # SamplePairing, 31 # SamplePairing,
...@@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset ...@@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset
154 subpolicy = transforms.Compose([ 154 subpolicy = transforms.Compose([
155 ## baseline augmentation 155 ## baseline augmentation
156 transforms.Pad(4), 156 transforms.Pad(4),
157 - transforms.RandomCrop(32), 157 + # transforms.RandomCrop(240), #32 ->240
158 transforms.RandomHorizontalFlip(), 158 transforms.RandomHorizontalFlip(),
159 + transforms.Resize([240, 240]),
159 ## policy 160 ## policy
160 *subpolicy, 161 *subpolicy,
161 ## to tensor 162 ## to tensor
...@@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat ...@@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat
191 192
192 return _transform 193 return _transform
193 194
194 -#fast_auto_augment(args, model, K=4, B=1, num_process=4) 195 +#fast_auto_augment(args, model, K=4, B=100, num_process=4)
195 -def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): 196 +def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=2, num_process=5):
196 args_str = json.dumps(args._asdict()) 197 args_str = json.dumps(args._asdict())
197 dataset = get_dataset(args, None, 'trainval') 198 dataset = get_dataset(args, None, 'trainval')
198 num_process = min(torch.cuda.device_count(), num_process) 199 num_process = min(torch.cuda.device_count(), num_process)
...@@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N ...@@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N
215 for future in futures: 216 for future in futures:
216 transform.extend(future.result()) 217 transform.extend(future.result())
217 218
218 - transform = transforms.RandomChoice(transform) 219 + #transform = transforms.RandomChoice(transform)
219 220
220 return transform 221 return transform
......
1 +import os
2 +import fire
3 +import json
4 +from pprint import pprint
5 +import pickle
6 +
7 +import torch
8 +import torch.nn as nn
9 +from torch.utils.tensorboard import SummaryWriter
10 +
11 +from utils import *
12 +
13 +# command
14 +# python getAugmented.py --model_path='logs/April_16_21:50:17__resnet50__None/'
15 +
def eval(model_path, img_size=240):
    """Visualize searched augmentation policies on the test split.

    Loads the pickled list of augmentation transforms produced by the
    fast-autoaugment search, applies each one to the test dataset, and
    concatenates each image's augmented variants side by side so they can
    be compared in TensorBoard.

    Args:
        model_path: log directory containing 'kwargs.json' and 'augmentation.cp'.
        img_size: assumed height/width of each loaded image (the dataset
            pipeline resizes to 240x240 — TODO confirm against utils).

    NOTE(review): the function name shadows the builtin `eval`; kept for
    CLI compatibility with `fire.Fire(eval)` below.
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Use a context manager so the config file handle is not leaked.
    with open(kwargs_path) as f:
        kwargs = json.loads(f.read())
    args, kwargs = parse_args(kwargs)
    pprint(args)

    cp_path = os.path.join(model_path, 'augmentation.cp')
    writer = SummaryWriter(log_dir=model_path)

    print('\n[+] Load transform')
    # augmentation.cp holds a pickled list of transform pipelines (one per
    # searched sub-policy).  pickle.load is safe here only because the file
    # is produced by this project itself — never load untrusted pickles.
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)

    # One image strip per test sample.  Each slot starts as an empty
    # (img_size, 0) tensor so torch.cat can append one (img_size, img_size)
    # column per augmentation.  Built with a comprehension (not `[t] * n`)
    # so every slot is an independent tensor, not n aliases of one object.
    num_test_images = len(get_dataset(args, None, 'test'))
    augmented_image_list = [torch.Tensor(img_size, 0) for _ in range(num_test_images)]

    print('\n[+] Load dataset')
    for aug_transform in aug_transform_list:
        dataset = get_dataset(args, aug_transform, 'test')
        loader = iter(get_aug_dataloader(args, dataset))

        for i, (images, target) in enumerate(loader):
            # Batch size is 1 (get_aug_dataloader uses the DataLoader
            # default), so flatten to a single (img_size, img_size) image.
            # Assumes a single-channel image — TODO confirm channel count.
            images = images.view(img_size, img_size)

            # Append this policy's version of image i to its strip.
            augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim=1)

            if i % 1000 == 0:
                print("\n images size: ", augmented_image_list[i].size())

            # NOTE(review): debug short-circuit — only the first test image
            # is processed for each transform.  Remove to cover the split.
            break

    print('\n[+] Write on tensorboard')
    if writer:
        for i, data in enumerate(augmented_image_list):
            tag = 'img/' + str(i)
            # view(1, H, -1): add a channel dimension; the width grows by
            # img_size for every augmentation policy that was applied.
            writer.add_image(tag, data.view(1, img_size, -1), global_step=0)
            # NOTE(review): debug short-circuit — only the first strip is logged.
            break

    writer.close()


if __name__ == '__main__':
    fire.Fire(eval)
# Quick sanity check: load and inspect the pickled augmentation policy
# produced by the fast-autoaugment search.
import pickle

with open('logs/April_16_21:50:17__resnet50__None/augmentation.cp', 'rb') as handle:
    policy = pickle.load(handle)

# Show the stored object and its Python type (expected: a list of transforms).
print(policy)
print(type(policy))
...\ No newline at end of file ...\ No newline at end of file
...@@ -4,7 +4,7 @@ hyperopt ...@@ -4,7 +4,7 @@ hyperopt
4 pillow==6.2.1 4 pillow==6.2.1
5 natsort 5 natsort
6 fire 6 fire
7 -torch 7 +torchvision==0.2.2
8 -torchvision==0.4.1 8 +torch==1.1.0
9 pandas 9 pandas
10 -sklearn 10 +sklearn
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -40,12 +40,12 @@ class TranslateXY(BaseTransform): ...@@ -40,12 +40,12 @@ class TranslateXY(BaseTransform):
40 return t(img) 40 return t(img)
41 41
42 42
class Rotate(BaseTransform):
    """Rotate the image by a random angle scaled by the magnitude."""

    def transform(self, img):
        # self.mag scales the rotation range up to a full 360 degrees —
        # presumably mag is in [0, 1]; TODO confirm magnitude range.
        max_degrees = self.mag * 360
        rotator = transforms.RandomRotation(max_degrees, Image.BILINEAR)
        return rotator(img)
49 49
50 50
51 class AutoContrast(BaseTransform): 51 class AutoContrast(BaseTransform):
...@@ -55,10 +55,10 @@ class AutoContrast(BaseTransform): ...@@ -55,10 +55,10 @@ class AutoContrast(BaseTransform):
55 return ImageOps.autocontrast(img, cutoff=cutoff) 55 return ImageOps.autocontrast(img, cutoff=cutoff)
56 56
57 57
class Invert(BaseTransform):
    """Produce the photographic negative of the image."""

    def transform(self, img):
        # Delegates directly to PIL; this transform has no magnitude.
        return ImageOps.invert(img)
62 62
63 63
64 class Equalize(BaseTransform): 64 class Equalize(BaseTransform):
...@@ -88,11 +88,11 @@ class Contrast(BaseTransform): ...@@ -88,11 +88,11 @@ class Contrast(BaseTransform):
88 return ImageEnhance.Contrast(img).enhance(factor) 88 return ImageEnhance.Contrast(img).enhance(factor)
89 89
90 90
class Color(BaseTransform):
    """Adjust color saturation; the magnitude scales the enhancement factor."""

    def transform(self, img):
        # NOTE(review): factor = mag * 10 can push saturation far beyond the
        # usual ~[0, 2] enhancement range — confirm this is intended.
        enhancer = ImageEnhance.Color(img)
        return enhancer.enhance(self.mag * 10)
96 96
97 97
98 class Brightness(BaseTransform): 98 class Brightness(BaseTransform):
...@@ -159,7 +159,10 @@ class CutoutOp(object): ...@@ -159,7 +159,10 @@ class CutoutOp(object):
159 # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32) 159 # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32)
160 # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32) 160 # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32)
161 161
162 - mask = np.reshape(mask, (32,32)) 162 + #mask = np.reshape(mask, (32, 32)) # (32, 32) -> (240, 240)
163 +
164 + # getAugmented.py
165 + mask = np.reshape(mask, (240, 240))
163 166
164 #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32) 167 #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32)
165 # print("\nmask: ", mask.shape) #(32, 32) 168 # print("\nmask: ", mask.shape) #(32, 32)
......
...@@ -59,11 +59,15 @@ def split_dataset(args, dataset, k): ...@@ -59,11 +59,15 @@ def split_dataset(args, dataset, k):
59 return Dm_indexes, Da_indexes 59 return Dm_indexes, Da_indexes
60 60
61 61
62 -#(images[j], first[j]), global_step=step) 62 +# concat_image_features(images[j], first[j])
63 def concat_image_features(image, features, max_features=3): 63 def concat_image_features(image, features, max_features=3):
64 _, h, w = image.shape 64 _, h, w = image.shape
65 + #print("\nfsize: ", features.size()) # (1, 240, 240)
66 + # features.size(0) = 64
67 + #print(features.size(0))
68 + #max_features = min(features.size(0), max_features)
65 69
66 - max_features = min(features.size(0), max_features) 70 + max_features = features.size(0)
67 image_feature = image.clone() 71 image_feature = image.clone()
68 72
69 for i in range(max_features): 73 for i in range(max_features):
...@@ -139,12 +143,12 @@ def parse_args(kwargs): ...@@ -139,12 +143,12 @@ def parse_args(kwargs):
139 kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True 143 kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True
140 kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() 144 kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
141 kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 145 kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4
142 - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 146 + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 50
143 - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 147 + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 50
144 kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' 148 kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp'
145 - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 64 149 + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 8
146 kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 150 kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0
147 - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 151 + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 100
148 kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False 152 kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False
149 kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None 153 kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None
150 154
...@@ -225,6 +229,7 @@ class CustomDataset(Dataset): ...@@ -225,6 +229,7 @@ class CustomDataset(Dataset):
225 229
226 if self.transform is not None: 230 if self.transform is not None:
227 tensor_image = self.transform(image) ## 231 tensor_image = self.transform(image) ##
232 +
228 return tensor_image, targets 233 return tensor_image, targets
229 234
230 def get_dataset(args, transform, split='train'): 235 def get_dataset(args, transform, split='train'):
...@@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True): ...@@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True):
276 pin_memory=pin_memory) 281 pin_memory=pin_memory)
277 return data_loader 282 return data_loader
278 283
def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True):
    """Build a DataLoader for visualizing augmented samples one at a time.

    Unlike get_dataloader, batch_size is deliberately left at the DataLoader
    default (1) so every batch holds exactly one image — presumably so each
    augmented image can be logged individually; confirm against callers.
    """
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
279 292
280 def get_inf_dataloader(args, dataset): 293 def get_inf_dataloader(args, dataset):
281 global current_epoch 294 global current_epoch
...@@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None): ...@@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None):
304 os.system('cp {} {}'.format( 317 os.system('cp {} {}'.format(
305 args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) 318 args.augment_path, os.path.join(log_dir, 'augmentation.cp')))
306 else: 319 else:
307 - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ## 320 + transform = fast_auto_augment(args, model, K=4, B=100, num_process=4) ##
308 if log_dir: 321 if log_dir:
309 cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) 322 cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb'))
310 323
...@@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): ...@@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
436 samples += images.size(0) 449 samples += images.size(0)
437 450
438 if writer: 451 if writer:
439 - # print("\n3 images.size(0): ", images.size(0)) 452 + # print("\n images.size(0): ", images.size(0)) # batch size (last = n(imgs)%batch_size)
440 n_imgs = min(images.size(0), 10) 453 n_imgs = min(images.size(0), 10)
441 for j in range(n_imgs): 454 for j in range(n_imgs):
442 tag = 'valid/' + str(img_count) 455 tag = 'valid/' + str(img_count)
......