Showing
7 changed files
with
153 additions
and
37 deletions
... | @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter | ... | @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter |
10 | from utils import * | 10 | from utils import * |
11 | 11 | ||
12 | # command | 12 | # command |
13 | -# python "eval.py" --model_path='logs/' | 13 | +# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/' |
14 | 14 | ||
15 | def eval(model_path): | 15 | def eval(model_path): |
16 | print('\n[+] Parse arguments') | 16 | print('\n[+] Parse arguments') |
... | @@ -34,9 +34,17 @@ def eval(model_path): | ... | @@ -34,9 +34,17 @@ def eval(model_path): |
34 | 34 | ||
35 | print('\n[+] Load dataset') | 35 | print('\n[+] Load dataset') |
36 | test_transform = get_valid_transform(args, model) | 36 | test_transform = get_valid_transform(args, model) |
37 | - test_dataset = get_dataset(args, test_transform, 'test') | 37 | + #print('\nTEST Transform\n', test_transform) |
38 | - #print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590 | 38 | + test_dataset = get_dataset(args, test_transform, 'test') |
39 | - | 39 | + |
40 | + """ | ||
41 | + test_transform | ||
42 | + Compose( | ||
43 | + Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR) | ||
44 | + ToTensor() | ||
45 | + ) | ||
46 | + """ | ||
47 | + | ||
40 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 48 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
41 | 49 | ||
42 | print('\n[+] Start testing') | 50 | print('\n[+] Start testing') | ... | ... |
... | @@ -17,15 +17,15 @@ from utils import * | ... | @@ -17,15 +17,15 @@ from utils import * |
17 | DEFALUT_CANDIDATES = [ | 17 | DEFALUT_CANDIDATES = [ |
18 | ShearXY, | 18 | ShearXY, |
19 | TranslateXY, | 19 | TranslateXY, |
20 | - # Rotate, | 20 | + Rotate, |
21 | # AutoContrast, | 21 | # AutoContrast, |
22 | # Invert, | 22 | # Invert, |
23 | - Equalize, | 23 | + Equalize, # Histogram Equalize --> white tumor |
24 | - Solarize, | 24 | + #Solarize, |
25 | Posterize, | 25 | Posterize, |
26 | - Contrast, | 26 | + # Contrast, |
27 | # Color, | 27 | # Color, |
28 | - Brightness, | 28 | + # Brightness, |
29 | Sharpness, | 29 | Sharpness, |
30 | Cutout, | 30 | Cutout, |
31 | # SamplePairing, | 31 | # SamplePairing, |
... | @@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset | ... | @@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset |
154 | subpolicy = transforms.Compose([ | 154 | subpolicy = transforms.Compose([ |
155 | ## baseline augmentation | 155 | ## baseline augmentation |
156 | transforms.Pad(4), | 156 | transforms.Pad(4), |
157 | - transforms.RandomCrop(32), | 157 | + # transforms.RandomCrop(240), #32 ->240 |
158 | transforms.RandomHorizontalFlip(), | 158 | transforms.RandomHorizontalFlip(), |
159 | + transforms.Resize([240, 240]), | ||
159 | ## policy | 160 | ## policy |
160 | *subpolicy, | 161 | *subpolicy, |
161 | ## to tensor | 162 | ## to tensor |
... | @@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
191 | 192 | ||
192 | return _transform | 193 | return _transform |
193 | 194 | ||
194 | -#fast_auto_augment(args, model, K=4, B=1, num_process=4) | 195 | +#fast_auto_augment(args, model, K=4, B=100, num_process=4) |
195 | -def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | 196 | +def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=2, num_process=5): |
196 | args_str = json.dumps(args._asdict()) | 197 | args_str = json.dumps(args._asdict()) |
197 | dataset = get_dataset(args, None, 'trainval') | 198 | dataset = get_dataset(args, None, 'trainval') |
198 | num_process = min(torch.cuda.device_count(), num_process) | 199 | num_process = min(torch.cuda.device_count(), num_process) |
... | @@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N | ... | @@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N |
215 | for future in futures: | 216 | for future in futures: |
216 | transform.extend(future.result()) | 217 | transform.extend(future.result()) |
217 | 218 | ||
218 | - transform = transforms.RandomChoice(transform) | 219 | + #transform = transforms.RandomChoice(transform) |
219 | 220 | ||
220 | return transform | 221 | return transform | ... | ... |
code/FAA2_VM/getAugmented.py
0 → 100644
1 | +import os | ||
2 | +import fire | ||
3 | +import json | ||
4 | +from pprint import pprint | ||
5 | +import pickle | ||
6 | + | ||
7 | +import torch | ||
8 | +import torch.nn as nn | ||
9 | +from torch.utils.tensorboard import SummaryWriter | ||
10 | + | ||
11 | +from utils import * | ||
12 | + | ||
13 | +# command | ||
14 | +# python getAugmented.py --model_path='logs/April_16_21:50:17__resnet50__None/' | ||
15 | + | ||
16 | +def eval(model_path): | ||
17 | + print('\n[+] Parse arguments') | ||
18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
20 | + args, kwargs = parse_args(kwargs) | ||
21 | + pprint(args) | ||
22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
23 | + | ||
24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
25 | + | ||
26 | + writer = SummaryWriter(log_dir=model_path) | ||
27 | + | ||
28 | + | ||
29 | + print('\n[+] Load transform') | ||
30 | + # list | ||
31 | + with open(cp_path, 'rb') as f: | ||
32 | + aug_transform_list = pickle.load(f) | ||
33 | + | ||
34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
35 | + | ||
36 | + | ||
37 | + print('\n[+] Load dataset') | ||
38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
40 | + | ||
41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
42 | + | ||
43 | + for i, (images, target) in enumerate(loader): | ||
44 | + images = images.view(240, 240) | ||
45 | + | ||
46 | + # concat image | ||
47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
48 | + | ||
49 | + if i % 1000 == 0: | ||
50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
51 | + | ||
52 | + break | ||
53 | + # break | ||
54 | + | ||
55 | + | ||
56 | + # print(augmented_image_list) | ||
57 | + | ||
58 | + | ||
59 | + print('\n[+] Write on tensorboard') | ||
60 | + if writer: | ||
61 | + for i, data in enumerate(augmented_image_list): | ||
62 | + tag = 'img/' + str(i) | ||
63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
64 | + break | ||
65 | + | ||
66 | + writer.close() | ||
67 | + | ||
68 | + | ||
69 | + # if writer: | ||
70 | + # for j in range(): | ||
71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
72 | + # # writer.add_image(tag, | ||
73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
74 | + # # if j > 0: | ||
75 | + # # fore = concat_image_features(fore, images[j]) | ||
76 | + | ||
77 | + # writer.add_image(tag, fore, global_step=0) | ||
78 | + # img_count = img_count + 1 | ||
79 | + | ||
80 | + # writer.close() | ||
81 | + | ||
82 | +if __name__ == '__main__': | ||
83 | + fire.Fire(eval) |
code/FAA2_VM/read_cp.py
0 → 100644
... | @@ -4,7 +4,7 @@ hyperopt | ... | @@ -4,7 +4,7 @@ hyperopt |
4 | pillow==6.2.1 | 4 | pillow==6.2.1 |
5 | natsort | 5 | natsort |
6 | fire | 6 | fire |
7 | -torch | 7 | +torchvision==0.2.2 |
8 | -torchvision==0.4.1 | 8 | +torch==1.1.0 |
9 | pandas | 9 | pandas |
10 | -sklearn | 10 | +sklearn |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
... | @@ -40,12 +40,12 @@ class TranslateXY(BaseTransform): | ... | @@ -40,12 +40,12 @@ class TranslateXY(BaseTransform): |
40 | return t(img) | 40 | return t(img) |
41 | 41 | ||
42 | 42 | ||
43 | -# class Rotate(BaseTransform): | 43 | +class Rotate(BaseTransform): |
44 | 44 | ||
45 | -# def transform(self, img): | 45 | + def transform(self, img): |
46 | -# degrees = self.mag * 360 | 46 | + degrees = self.mag * 360 |
47 | -# t = transforms.RandomRotation(degrees, Image.BILINEAR) | 47 | + t = transforms.RandomRotation(degrees, Image.BILINEAR) |
48 | -# return t(img) | 48 | + return t(img) |
49 | 49 | ||
50 | 50 | ||
51 | class AutoContrast(BaseTransform): | 51 | class AutoContrast(BaseTransform): |
... | @@ -55,10 +55,10 @@ class AutoContrast(BaseTransform): | ... | @@ -55,10 +55,10 @@ class AutoContrast(BaseTransform): |
55 | return ImageOps.autocontrast(img, cutoff=cutoff) | 55 | return ImageOps.autocontrast(img, cutoff=cutoff) |
56 | 56 | ||
57 | 57 | ||
58 | -# class Invert(BaseTransform): | 58 | +class Invert(BaseTransform): |
59 | 59 | ||
60 | -# def transform(self, img): | 60 | + def transform(self, img): |
61 | -# return ImageOps.invert(img) | 61 | + return ImageOps.invert(img) |
62 | 62 | ||
63 | 63 | ||
64 | class Equalize(BaseTransform): | 64 | class Equalize(BaseTransform): |
... | @@ -88,11 +88,11 @@ class Contrast(BaseTransform): | ... | @@ -88,11 +88,11 @@ class Contrast(BaseTransform): |
88 | return ImageEnhance.Contrast(img).enhance(factor) | 88 | return ImageEnhance.Contrast(img).enhance(factor) |
89 | 89 | ||
90 | 90 | ||
91 | -# class Color(BaseTransform): | 91 | +class Color(BaseTransform): |
92 | 92 | ||
93 | -# def transform(self, img): | 93 | + def transform(self, img): |
94 | -# factor = self.mag * 10 | 94 | + factor = self.mag * 10 |
95 | -# return ImageEnhance.Color(img).enhance(factor) | 95 | + return ImageEnhance.Color(img).enhance(factor) |
96 | 96 | ||
97 | 97 | ||
98 | class Brightness(BaseTransform): | 98 | class Brightness(BaseTransform): |
... | @@ -159,7 +159,10 @@ class CutoutOp(object): | ... | @@ -159,7 +159,10 @@ class CutoutOp(object): |
159 | # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32) | 159 | # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32) |
160 | # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32) | 160 | # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32) |
161 | 161 | ||
162 | - mask = np.reshape(mask, (32,32)) | 162 | + #mask = np.reshape(mask, (32, 32)) # (32, 32) -> (240, 240) |
163 | + | ||
164 | + # getAugmented.py | ||
165 | + mask = np.reshape(mask, (240, 240)) | ||
163 | 166 | ||
164 | #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32) | 167 | #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32) |
165 | # print("\nmask: ", mask.shape) #(32, 32) | 168 | # print("\nmask: ", mask.shape) #(32, 32) | ... | ... |
... | @@ -59,11 +59,15 @@ def split_dataset(args, dataset, k): | ... | @@ -59,11 +59,15 @@ def split_dataset(args, dataset, k): |
59 | return Dm_indexes, Da_indexes | 59 | return Dm_indexes, Da_indexes |
60 | 60 | ||
61 | 61 | ||
62 | -#(images[j], first[j]), global_step=step) | 62 | +# concat_image_features(images[j], first[j]) |
63 | def concat_image_features(image, features, max_features=3): | 63 | def concat_image_features(image, features, max_features=3): |
64 | _, h, w = image.shape | 64 | _, h, w = image.shape |
65 | + #print("\nfsize: ", features.size()) # (1, 240, 240) | ||
66 | + # features.size(0) = 64 | ||
67 | + #print(features.size(0)) | ||
68 | + #max_features = min(features.size(0), max_features) | ||
65 | 69 | ||
66 | - max_features = min(features.size(0), max_features) | 70 | + max_features = features.size(0) |
67 | image_feature = image.clone() | 71 | image_feature = image.clone() |
68 | 72 | ||
69 | for i in range(max_features): | 73 | for i in range(max_features): |
... | @@ -139,12 +143,12 @@ def parse_args(kwargs): | ... | @@ -139,12 +143,12 @@ def parse_args(kwargs): |
139 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | 143 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True |
140 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | 144 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() |
141 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | 145 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 |
142 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | 146 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 50 |
143 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | 147 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 50 |
144 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | 148 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' |
145 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 64 | 149 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 8 |
146 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | 150 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 |
147 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | 151 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 100 |
148 | kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | 152 | kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False |
149 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | 153 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None |
150 | 154 | ||
... | @@ -225,6 +229,7 @@ class CustomDataset(Dataset): | ... | @@ -225,6 +229,7 @@ class CustomDataset(Dataset): |
225 | 229 | ||
226 | if self.transform is not None: | 230 | if self.transform is not None: |
227 | tensor_image = self.transform(image) ## | 231 | tensor_image = self.transform(image) ## |
232 | + | ||
228 | return tensor_image, targets | 233 | return tensor_image, targets |
229 | 234 | ||
230 | def get_dataset(args, transform, split='train'): | 235 | def get_dataset(args, transform, split='train'): |
... | @@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ... | @@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True): |
276 | pin_memory=pin_memory) | 281 | pin_memory=pin_memory) |
277 | return data_loader | 282 | return data_loader |
278 | 283 | ||
284 | +def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
285 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
286 | + # batch_size=args.batch_size, | ||
287 | + shuffle=shuffle, | ||
288 | + num_workers=args.num_workers, | ||
289 | + pin_memory=pin_memory) | ||
290 | + return data_loader | ||
291 | + | ||
279 | 292 | ||
280 | def get_inf_dataloader(args, dataset): | 293 | def get_inf_dataloader(args, dataset): |
281 | global current_epoch | 294 | global current_epoch |
... | @@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None): | ... | @@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None): |
304 | os.system('cp {} {}'.format( | 317 | os.system('cp {} {}'.format( |
305 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | 318 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) |
306 | else: | 319 | else: |
307 | - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ## | 320 | + transform = fast_auto_augment(args, model, K=4, B=100, num_process=4) ## |
308 | if log_dir: | 321 | if log_dir: |
309 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | 322 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) |
310 | 323 | ||
... | @@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
436 | samples += images.size(0) | 449 | samples += images.size(0) |
437 | 450 | ||
438 | if writer: | 451 | if writer: |
439 | - # print("\n3 images.size(0): ", images.size(0)) | 452 | + # print("\n images.size(0): ", images.size(0)) # batch size (last = n(imgs)%batch_size) |
440 | n_imgs = min(images.size(0), 10) | 453 | n_imgs = min(images.size(0), 10) |
441 | for j in range(n_imgs): | 454 | for j in range(n_imgs): |
442 | tag = 'valid/' + str(img_count) | 455 | tag = 'valid/' + str(img_count) | ... | ... |
-
Please register or login to post a comment