Showing
7 changed files
with
150 additions
and
34 deletions
| ... | @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter | ... | @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter |
| 10 | from utils import * | 10 | from utils import * |
| 11 | 11 | ||
| 12 | # command | 12 | # command |
| 13 | -# python "eval.py" --model_path='logs/' | 13 | +# python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/' |
| 14 | 14 | ||
| 15 | def eval(model_path): | 15 | def eval(model_path): |
| 16 | print('\n[+] Parse arguments') | 16 | print('\n[+] Parse arguments') |
| ... | @@ -34,8 +34,16 @@ def eval(model_path): | ... | @@ -34,8 +34,16 @@ def eval(model_path): |
| 34 | 34 | ||
| 35 | print('\n[+] Load dataset') | 35 | print('\n[+] Load dataset') |
| 36 | test_transform = get_valid_transform(args, model) | 36 | test_transform = get_valid_transform(args, model) |
| 37 | + #print('\nTEST Transform\n', test_transform) | ||
| 37 | test_dataset = get_dataset(args, test_transform, 'test') | 38 | test_dataset = get_dataset(args, test_transform, 'test') |
| 38 | - #print("len(dataset): ", len(test_dataset), type(test_dataset)) # 590 | 39 | + |
| 40 | + """ | ||
| 41 | + test_transform | ||
| 42 | + Compose( | ||
| 43 | + Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR) | ||
| 44 | + ToTensor() | ||
| 45 | + ) | ||
| 46 | + """ | ||
| 39 | 47 | ||
| 40 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 48 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
| 41 | 49 | ... | ... |
| ... | @@ -17,15 +17,15 @@ from utils import * | ... | @@ -17,15 +17,15 @@ from utils import * |
| 17 | DEFALUT_CANDIDATES = [ | 17 | DEFALUT_CANDIDATES = [ |
| 18 | ShearXY, | 18 | ShearXY, |
| 19 | TranslateXY, | 19 | TranslateXY, |
| 20 | - # Rotate, | 20 | + Rotate, |
| 21 | # AutoContrast, | 21 | # AutoContrast, |
| 22 | # Invert, | 22 | # Invert, |
| 23 | - Equalize, | 23 | + Equalize, # Histogram Equalize --> white tumor |
| 24 | - Solarize, | 24 | + #Solarize, |
| 25 | Posterize, | 25 | Posterize, |
| 26 | - Contrast, | 26 | + # Contrast, |
| 27 | # Color, | 27 | # Color, |
| 28 | - Brightness, | 28 | + # Brightness, |
| 29 | Sharpness, | 29 | Sharpness, |
| 30 | Cutout, | 30 | Cutout, |
| 31 | # SamplePairing, | 31 | # SamplePairing, |
| ... | @@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset | ... | @@ -154,8 +154,9 @@ def search_subpolicies_hyperopt(args, transform_candidates, child_model, dataset |
| 154 | subpolicy = transforms.Compose([ | 154 | subpolicy = transforms.Compose([ |
| 155 | ## baseline augmentation | 155 | ## baseline augmentation |
| 156 | transforms.Pad(4), | 156 | transforms.Pad(4), |
| 157 | - transforms.RandomCrop(32), | 157 | + # transforms.RandomCrop(240), #32 ->240 |
| 158 | transforms.RandomHorizontalFlip(), | 158 | transforms.RandomHorizontalFlip(), |
| 159 | + transforms.Resize([240, 240]), | ||
| 159 | ## policy | 160 | ## policy |
| 160 | *subpolicy, | 161 | *subpolicy, |
| 161 | ## to tensor | 162 | ## to tensor |
| ... | @@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat | ... | @@ -191,8 +192,8 @@ def process_fn(args_str, model, dataset, Dm_indx, Da_indx, T, transform_candidat |
| 191 | 192 | ||
| 192 | return _transform | 193 | return _transform |
| 193 | 194 | ||
| 194 | -#fast_auto_augment(args, model, K=4, B=1, num_process=4) | 195 | +#fast_auto_augment(args, model, K=4, B=100, num_process=4) |
| 195 | -def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=10, num_process=5): | 196 | +def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N=2, num_process=5): |
| 196 | args_str = json.dumps(args._asdict()) | 197 | args_str = json.dumps(args._asdict()) |
| 197 | dataset = get_dataset(args, None, 'trainval') | 198 | dataset = get_dataset(args, None, 'trainval') |
| 198 | num_process = min(torch.cuda.device_count(), num_process) | 199 | num_process = min(torch.cuda.device_count(), num_process) |
| ... | @@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N | ... | @@ -215,6 +216,6 @@ def fast_auto_augment(args, model, transform_candidates=None, K=5, B=100, T=2, N |
| 215 | for future in futures: | 216 | for future in futures: |
| 216 | transform.extend(future.result()) | 217 | transform.extend(future.result()) |
| 217 | 218 | ||
| 218 | - transform = transforms.RandomChoice(transform) | 219 | + #transform = transforms.RandomChoice(transform) |
| 219 | 220 | ||
| 220 | return transform | 221 | return transform | ... | ... |
code/FAA2_VM/getAugmented.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | + | ||
| 7 | +import torch | ||
| 8 | +import torch.nn as nn | ||
| 9 | +from torch.utils.tensorboard import SummaryWriter | ||
| 10 | + | ||
| 11 | +from utils import * | ||
| 12 | + | ||
| 13 | +# command | ||
| 14 | +# python getAugmented.py --model_path='logs/April_16_21:50:17__resnet50__None/' | ||
| 15 | + | ||
| 16 | +def eval(model_path): | ||
| 17 | + print('\n[+] Parse arguments') | ||
| 18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 20 | + args, kwargs = parse_args(kwargs) | ||
| 21 | + pprint(args) | ||
| 22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 23 | + | ||
| 24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 25 | + | ||
| 26 | + writer = SummaryWriter(log_dir=model_path) | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + print('\n[+] Load transform') | ||
| 30 | + # list | ||
| 31 | + with open(cp_path, 'rb') as f: | ||
| 32 | + aug_transform_list = pickle.load(f) | ||
| 33 | + | ||
| 34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
| 35 | + | ||
| 36 | + | ||
| 37 | + print('\n[+] Load dataset') | ||
| 38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
| 39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
| 40 | + | ||
| 41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 42 | + | ||
| 43 | + for i, (images, target) in enumerate(loader): | ||
| 44 | + images = images.view(240, 240) | ||
| 45 | + | ||
| 46 | + # concat image | ||
| 47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
| 48 | + | ||
| 49 | + if i % 1000 == 0: | ||
| 50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
| 51 | + | ||
| 52 | + break | ||
| 53 | + # break | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + # print(augmented_image_list) | ||
| 57 | + | ||
| 58 | + | ||
| 59 | + print('\n[+] Write on tensorboard') | ||
| 60 | + if writer: | ||
| 61 | + for i, data in enumerate(augmented_image_list): | ||
| 62 | + tag = 'img/' + str(i) | ||
| 63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
| 64 | + break | ||
| 65 | + | ||
| 66 | + writer.close() | ||
| 67 | + | ||
| 68 | + | ||
| 69 | + # if writer: | ||
| 70 | + # for j in range(): | ||
| 71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
| 72 | + # # writer.add_image(tag, | ||
| 73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
| 74 | + # # if j > 0: | ||
| 75 | + # # fore = concat_image_features(fore, images[j]) | ||
| 76 | + | ||
| 77 | + # writer.add_image(tag, fore, global_step=0) | ||
| 78 | + # img_count = img_count + 1 | ||
| 79 | + | ||
| 80 | + # writer.close() | ||
| 81 | + | ||
| 82 | +if __name__ == '__main__': | ||
| 83 | + fire.Fire(eval) |
code/FAA2_VM/read_cp.py
0 → 100644
| ... | @@ -40,12 +40,12 @@ class TranslateXY(BaseTransform): | ... | @@ -40,12 +40,12 @@ class TranslateXY(BaseTransform): |
| 40 | return t(img) | 40 | return t(img) |
| 41 | 41 | ||
| 42 | 42 | ||
| 43 | -# class Rotate(BaseTransform): | 43 | +class Rotate(BaseTransform): |
| 44 | 44 | ||
| 45 | -# def transform(self, img): | 45 | + def transform(self, img): |
| 46 | -# degrees = self.mag * 360 | 46 | + degrees = self.mag * 360 |
| 47 | -# t = transforms.RandomRotation(degrees, Image.BILINEAR) | 47 | + t = transforms.RandomRotation(degrees, Image.BILINEAR) |
| 48 | -# return t(img) | 48 | + return t(img) |
| 49 | 49 | ||
| 50 | 50 | ||
| 51 | class AutoContrast(BaseTransform): | 51 | class AutoContrast(BaseTransform): |
| ... | @@ -55,10 +55,10 @@ class AutoContrast(BaseTransform): | ... | @@ -55,10 +55,10 @@ class AutoContrast(BaseTransform): |
| 55 | return ImageOps.autocontrast(img, cutoff=cutoff) | 55 | return ImageOps.autocontrast(img, cutoff=cutoff) |
| 56 | 56 | ||
| 57 | 57 | ||
| 58 | -# class Invert(BaseTransform): | 58 | +class Invert(BaseTransform): |
| 59 | 59 | ||
| 60 | -# def transform(self, img): | 60 | + def transform(self, img): |
| 61 | -# return ImageOps.invert(img) | 61 | + return ImageOps.invert(img) |
| 62 | 62 | ||
| 63 | 63 | ||
| 64 | class Equalize(BaseTransform): | 64 | class Equalize(BaseTransform): |
| ... | @@ -88,11 +88,11 @@ class Contrast(BaseTransform): | ... | @@ -88,11 +88,11 @@ class Contrast(BaseTransform): |
| 88 | return ImageEnhance.Contrast(img).enhance(factor) | 88 | return ImageEnhance.Contrast(img).enhance(factor) |
| 89 | 89 | ||
| 90 | 90 | ||
| 91 | -# class Color(BaseTransform): | 91 | +class Color(BaseTransform): |
| 92 | 92 | ||
| 93 | -# def transform(self, img): | 93 | + def transform(self, img): |
| 94 | -# factor = self.mag * 10 | 94 | + factor = self.mag * 10 |
| 95 | -# return ImageEnhance.Color(img).enhance(factor) | 95 | + return ImageEnhance.Color(img).enhance(factor) |
| 96 | 96 | ||
| 97 | 97 | ||
| 98 | class Brightness(BaseTransform): | 98 | class Brightness(BaseTransform): |
| ... | @@ -159,7 +159,10 @@ class CutoutOp(object): | ... | @@ -159,7 +159,10 @@ class CutoutOp(object): |
| 159 | # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32) | 159 | # print("\nnp.asarray(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #(32, 32, 32) |
| 160 | # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32) | 160 | # img = Image.fromarray(mask*np.asarray(img)) #(32, 32, 32) |
| 161 | 161 | ||
| 162 | - mask = np.reshape(mask, (32,32)) | 162 | + #mask = np.reshape(mask, (32, 32)) # (32, 32) -> (240, 240) |
| 163 | + | ||
| 164 | + # getAugmented.py | ||
| 165 | + mask = np.reshape(mask, (240, 240)) | ||
| 163 | 166 | ||
| 164 | #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32) | 167 | #print("\n(img) max: \n", np.amax(np.asarray(img)), np.asarray(img).shape) #[0, 255] (32, 32) |
| 165 | # print("\nmask: ", mask.shape) #(32, 32) | 168 | # print("\nmask: ", mask.shape) #(32, 32) | ... | ... |
| ... | @@ -59,11 +59,15 @@ def split_dataset(args, dataset, k): | ... | @@ -59,11 +59,15 @@ def split_dataset(args, dataset, k): |
| 59 | return Dm_indexes, Da_indexes | 59 | return Dm_indexes, Da_indexes |
| 60 | 60 | ||
| 61 | 61 | ||
| 62 | -#(images[j], first[j]), global_step=step) | 62 | +# concat_image_features(images[j], first[j]) |
| 63 | def concat_image_features(image, features, max_features=3): | 63 | def concat_image_features(image, features, max_features=3): |
| 64 | _, h, w = image.shape | 64 | _, h, w = image.shape |
| 65 | + #print("\nfsize: ", features.size()) # (1, 240, 240) | ||
| 66 | + # features.size(0) = 64 | ||
| 67 | + #print(features.size(0)) | ||
| 68 | + #max_features = min(features.size(0), max_features) | ||
| 65 | 69 | ||
| 66 | - max_features = min(features.size(0), max_features) | 70 | + max_features = features.size(0) |
| 67 | image_feature = image.clone() | 71 | image_feature = image.clone() |
| 68 | 72 | ||
| 69 | for i in range(max_features): | 73 | for i in range(max_features): |
| ... | @@ -139,12 +143,12 @@ def parse_args(kwargs): | ... | @@ -139,12 +143,12 @@ def parse_args(kwargs): |
| 139 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | 143 | kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True |
| 140 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | 144 | kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() |
| 141 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | 145 | kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 |
| 142 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | 146 | + kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 50 |
| 143 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | 147 | + kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 50 |
| 144 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | 148 | kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' |
| 145 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 64 | 149 | + kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 8 |
| 146 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | 150 | kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 |
| 147 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | 151 | + kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 100 |
| 148 | kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False | 152 | kwargs['fast_auto_augment'] = kwargs['fast_auto_augment'] if 'fast_auto_augment' in kwargs else False |
| 149 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | 153 | kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None |
| 150 | 154 | ||
| ... | @@ -225,6 +229,7 @@ class CustomDataset(Dataset): | ... | @@ -225,6 +229,7 @@ class CustomDataset(Dataset): |
| 225 | 229 | ||
| 226 | if self.transform is not None: | 230 | if self.transform is not None: |
| 227 | tensor_image = self.transform(image) ## | 231 | tensor_image = self.transform(image) ## |
| 232 | + | ||
| 228 | return tensor_image, targets | 233 | return tensor_image, targets |
| 229 | 234 | ||
| 230 | def get_dataset(args, transform, split='train'): | 235 | def get_dataset(args, transform, split='train'): |
| ... | @@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ... | @@ -276,6 +281,14 @@ def get_dataloader(args, dataset, shuffle=False, pin_memory=True): |
| 276 | pin_memory=pin_memory) | 281 | pin_memory=pin_memory) |
| 277 | return data_loader | 282 | return data_loader |
| 278 | 283 | ||
| 284 | +def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 285 | + data_loader = torch.utils.data.DataLoader(dataset, | ||
| 286 | + # batch_size=args.batch_size, | ||
| 287 | + shuffle=shuffle, | ||
| 288 | + num_workers=args.num_workers, | ||
| 289 | + pin_memory=pin_memory) | ||
| 290 | + return data_loader | ||
| 291 | + | ||
| 279 | 292 | ||
| 280 | def get_inf_dataloader(args, dataset): | 293 | def get_inf_dataloader(args, dataset): |
| 281 | global current_epoch | 294 | global current_epoch |
| ... | @@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None): | ... | @@ -304,7 +317,7 @@ def get_train_transform(args, model, log_dir=None): |
| 304 | os.system('cp {} {}'.format( | 317 | os.system('cp {} {}'.format( |
| 305 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) | 318 | args.augment_path, os.path.join(log_dir, 'augmentation.cp'))) |
| 306 | else: | 319 | else: |
| 307 | - transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ## | 320 | + transform = fast_auto_augment(args, model, K=4, B=100, num_process=4) ## |
| 308 | if log_dir: | 321 | if log_dir: |
| 309 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) | 322 | cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb')) |
| 310 | 323 | ||
| ... | @@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): | ... | @@ -436,7 +449,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None): |
| 436 | samples += images.size(0) | 449 | samples += images.size(0) |
| 437 | 450 | ||
| 438 | if writer: | 451 | if writer: |
| 439 | - # print("\n3 images.size(0): ", images.size(0)) | 452 | + # print("\n images.size(0): ", images.size(0)) # batch size (last = n(imgs)%batch_size) |
| 440 | n_imgs = min(images.size(0), 10) | 453 | n_imgs = min(images.size(0), 10) |
| 441 | for j in range(n_imgs): | 454 | for j in range(n_imgs): |
| 442 | tag = 'valid/' + str(img_count) | 455 | tag = 'valid/' + str(img_count) | ... | ... |
-
Please register or login to post a comment