Showing 156 changed files with 5303 additions and 0 deletions
Code/HD-CelebA-Cropper/.gitignore
0 → 100644
Code/HD-CelebA-Cropper/align.py
0 → 100644
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from functools import partial
from multiprocessing import Pool
import os
import re

import cropper
import numpy as np
import tqdm


# ==============================================================================
# =                                    param                                   =
# ==============================================================================

parser = argparse.ArgumentParser()
# main
parser.add_argument('--img_dir', dest='img_dir', default='./data/img_celeba')
parser.add_argument('--save_dir', dest='save_dir', default='./data/aligned')
parser.add_argument('--landmark_file', dest='landmark_file', default='./data/landmark.txt')
parser.add_argument('--standard_landmark_file', dest='standard_landmark_file', default='./data/standard_landmark_68pts.txt')
parser.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572)
parser.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572)
parser.add_argument('--move_h', dest='move_h', type=float, default=0.25)
parser.add_argument('--move_w', dest='move_w', type=float, default=0.)
parser.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg')
parser.add_argument('--n_worker', dest='n_worker', type=int, default=8)
# others
parser.add_argument('--face_factor', dest='face_factor', type=float, help='The factor of face area relative to the output image.', default=0.45)
parser.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'], default='similarity')
parser.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5], help='The order of interpolation.', default=3)
parser.add_argument('--mode', dest='mode', choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge')
args = parser.parse_args()

# ==============================================================================
# =                                opencv first                                =
# ==============================================================================

_DEFAULT_JPG_QUALITY = 95
try:
    import cv2
    imread = cv2.imread
    imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEFAULT_JPG_QUALITY])
    align_crop = cropper.align_crop_opencv
    print('Use OpenCV')
except ImportError:
    import skimage.io as io
    imread = io.imread
    imwrite = partial(io.imsave, quality=_DEFAULT_JPG_QUALITY)
    align_crop = cropper.align_crop_skimage
    print('Importing OpenCV failed. Using scikit-image.')


# ==============================================================================
# =                                     run                                    =
# ==============================================================================

# count landmarks
with open(args.landmark_file) as f:
    line = f.readline()
n_landmark = len(re.split('[ ]+', line)[1:]) // 2

# load standard landmark
standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=float).reshape(n_landmark, 2)
standard_landmark[:, 0] += args.move_w
standard_landmark[:, 1] += args.move_h

# data dir
save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format))
data_dir = os.path.join(save_dir, 'data')
if not os.path.isdir(data_dir):
    os.makedirs(data_dir)


def work(name, landmark) -> str:  # a single work item
    for _ in range(3):  # try up to three times
        try:
            img = imread(os.path.join(args.img_dir, name))
            img_crop, tformed_landmarks = align_crop(img,
                                                     landmark,
                                                     standard_landmark,
                                                     crop_size=(args.crop_size_h, args.crop_size_w),
                                                     face_factor=args.face_factor,
                                                     align_type=args.align_type,
                                                     order=args.order,
                                                     mode=args.mode)

            name = os.path.splitext(name)[0] + '.' + args.save_format
            path = os.path.join(data_dir, name)
            if not os.path.isdir(os.path.split(path)[0]):
                os.makedirs(os.path.split(path)[0])
            imwrite(path, img_crop)

            tformed_landmarks.shape = -1
            name_landmark_str = ('%s' + ' %.1f' * n_landmark * 2) % ((name,) + tuple(tformed_landmarks))
            return name_landmark_str
        except Exception:
            print('%s fails!' % name)


if __name__ == "__main__":
    img_names = np.genfromtxt(args.landmark_file, dtype=str, usecols=0)
    landmarks = np.genfromtxt(args.landmark_file, dtype=float,
                              usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2)

    n_pics = len(img_names)

    landmarks_path = os.path.join(save_dir, 'landmark.txt')
    f = open(landmarks_path, 'w')
    pool = Pool(args.n_worker)
    bar = tqdm.tqdm(total=n_pics)

    tasks = []
    for i in range(n_pics):
        tasks.append(pool.apply_async(work, (img_names[i], landmarks[i]), callback=lambda _: bar.update()))

    # collect the result of every task, not just the first one
    for task in tasks:
        try:
            result = task.get()
            if result is not None and result != "":
                f.write(result + '\n')
        except Exception:
            pass

    pool.close()
    pool.join()
    bar.close()
    f.close()
\ No newline at end of file
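For reference, a minimal sketch of the landmark-file format that align.py assumes: each line holds an image name followed by space-separated x/y pairs, which is what the re.split landmark count and the np.genfromtxt calls above rely on. The file name and coordinate values below are invented for illustration.

    # Illustrative only: parse one line of landmark.txt (toy values, 3 points)
    line = '000001.jpg 165.0 184.0 244.0 176.0 196.0 249.0'
    fields = line.split()
    name, coords = fields[0], [float(v) for v in fields[1:]]
    n_landmark = len(coords) // 2
    landmark = [(coords[2 * i], coords[2 * i + 1]) for i in range(n_landmark)]
    print(name, landmark)  # 000001.jpg [(165.0, 184.0), (244.0, 176.0), (196.0, 249.0)]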
Code/HD-CelebA-Cropper/cropper.py
0 → 100644
import numpy as np


def align_crop_opencv(img,
                      src_landmarks,
                      standard_landmarks,
                      crop_size=512,
                      face_factor=0.7,
                      align_type='similarity',
                      order=3,
                      mode='edge'):
    """Align and crop a face image by landmarks.

    Arguments:
        img                : Face image to be aligned and cropped.
        src_landmarks      : [[x_1, y_1], ..., [x_n, y_n]].
        standard_landmarks : Standard shape, should be normalized.
        crop_size          : Output image size, should be 1. int for (crop_size, crop_size)
                             or 2. (int, int) for (crop_size_h, crop_size_w).
        face_factor        : The factor of face area relative to the output image.
        align_type         : 'similarity' or 'affine'.
        order              : The order of interpolation. The order has to be in the range 0-5:
                                 - 0: INTER_NEAREST
                                 - 1: INTER_LINEAR
                                 - 2: INTER_AREA
                                 - 3: INTER_CUBIC
                                 - 4: INTER_LANCZOS4
                                 - 5: INTER_LANCZOS4
        mode               : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
                             Points outside the boundaries of the input are filled according
                             to the given mode.
    """
    # set OpenCV
    import cv2
    inter = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA,
             3: cv2.INTER_CUBIC, 4: cv2.INTER_LANCZOS4, 5: cv2.INTER_LANCZOS4}
    border = {'constant': cv2.BORDER_CONSTANT, 'edge': cv2.BORDER_REPLICATE,
              'symmetric': cv2.BORDER_REFLECT, 'reflect': cv2.BORDER_REFLECT101,
              'wrap': cv2.BORDER_WRAP}

    # check
    assert align_type in ['affine', 'similarity'], 'Invalid `align_type`! Allowed: %s!' % ['affine', 'similarity']
    assert order in [0, 1, 2, 3, 4, 5], 'Invalid `order`! Allowed: %s!' % [0, 1, 2, 3, 4, 5]
    assert mode in ['constant', 'edge', 'symmetric', 'reflect', 'wrap'], 'Invalid `mode`! Allowed: %s!' % ['constant', 'edge', 'symmetric', 'reflect', 'wrap']

    # crop size
    if isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:
        crop_size_h = crop_size[0]
        crop_size_w = crop_size[1]
    elif isinstance(crop_size, int):
        crop_size_h = crop_size_w = crop_size
    else:
        raise Exception('Invalid `crop_size`! `crop_size` should be 1. int for (crop_size, crop_size) or 2. (int, int) for (crop_size_h, crop_size_w)!')

    # estimate transform matrix
    trg_landmarks = standard_landmarks * max(crop_size_h, crop_size_w) * face_factor + np.array([crop_size_w // 2, crop_size_h // 2])
    if align_type == 'affine':
        tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.inf)[0]
    else:
        tform = cv2.estimateAffinePartial2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.inf)[0]

    # warp image by given transform
    output_shape = (crop_size_h, crop_size_w)
    img_crop = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + inter[order], borderMode=border[mode])

    # get transformed landmarks
    tformed_landmarks = cv2.transform(np.expand_dims(src_landmarks, axis=0), cv2.invertAffineTransform(tform))[0]

    return img_crop, tformed_landmarks


def align_crop_skimage(img,
                       src_landmarks,
                       standard_landmarks,
                       crop_size=512,
                       face_factor=0.7,
                       align_type='similarity',
                       order=3,
                       mode='edge'):
    """Align and crop a face image by landmarks.

    Arguments:
        img                : Face image to be aligned and cropped.
        src_landmarks      : [[x_1, y_1], ..., [x_n, y_n]].
        standard_landmarks : Standard shape, should be normalized.
        crop_size          : Output image size, should be 1. int for (crop_size, crop_size)
                             or 2. (int, int) for (crop_size_h, crop_size_w).
        face_factor        : The factor of face area relative to the output image.
        align_type         : 'similarity' or 'affine'.
        order              : The order of interpolation. The order has to be in the range 0-5:
                                 - 0: INTER_NEAREST
                                 - 1: INTER_LINEAR
                                 - 2: INTER_AREA
                                 - 3: INTER_CUBIC
                                 - 4: INTER_LANCZOS4
                                 - 5: INTER_LANCZOS4
        mode               : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
                             Points outside the boundaries of the input are filled according
                             to the given mode.
    """
    raise NotImplementedError("'align_crop_skimage' is not implemented!")
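A minimal usage sketch for align_crop_opencv, assuming an input image on disk and five invented landmark points; real inputs come from landmark.txt and standard_landmark_68pts.txt as wired up in align.py.

    import cv2
    import numpy as np
    import cropper

    img = cv2.imread('face.jpg')  # path assumed for illustration
    # Detected landmarks in image coordinates, one (x, y) row per point (toy values)
    src = np.array([[165, 184], [244, 176], [205, 230], [170, 270], [240, 265]], dtype=np.float64)
    # Standard shape, normalized around the origin (toy values)
    std = np.array([[-0.25, -0.2], [0.25, -0.2], [0.0, 0.05], [-0.2, 0.3], [0.2, 0.3]], dtype=np.float64)

    crop, tformed = cropper.align_crop_opencv(img, src, std,
                                              crop_size=(572, 572),
                                              face_factor=0.45,
                                              align_type='similarity',
                                              order=3, mode='edge')
    cv2.imwrite('face_aligned.jpg', crop)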
Code/HD-CelebA-Cropper/scores.txt
0 → 100644
This diff could not be displayed because it is too large.
Code/MaskTheFace/.gitattributes
0 → 100644
Code/MaskTheFace/.gitignore
0 → 100644
*.pyc
docs
data
lfw
lfw_40
.idea
loss
vgg_face_dataset
saved_network
loss
z_detect_face.py
z_main.py
*.npy
*.Lnk
data1
data1_masked
scratch.py
subset
subset_masked
vgg_face_dataset
*.mp4
ML_examples
*.pptx
datasets
*.dat
*.docx

Code/MaskTheFace/_config.yml
0 → 100644
theme: jekyll-theme-cayman
\ No newline at end of file
Code/MaskTheFace/images/000001.jpg
0 → 100644
11.2 KB
Code/MaskTheFace/images/000001_binary.jpg
0 → 100644
2.41 KB
Code/MaskTheFace/images/000001_masked.jpg
0 → 100644
21.4 KB
Code/MaskTheFace/mask_the_face.py
0 → 100644
# Author: aqeelanwar
# Created: 27 April,2020, 10:22 PM
# Email: aqeel.anwar@gatech.edu

import argparse
import dlib
from utils.aux_functions import *


# Command-line input setup
parser = argparse.ArgumentParser(
    description="MaskTheFace - Python code to mask faces dataset"
)
parser.add_argument(
    "--path",
    type=str,
    default="",
    help="Path to either the folder containing images or the image itself",
)
parser.add_argument(
    "--mask_type",
    type=str,
    default="surgical",
    choices=["surgical", "N95", "KN95", "cloth", "gas", "inpaint", "random", "all"],
    help="Type of the mask to be applied. Available options: surgical, N95, KN95, cloth, gas, inpaint, random, all",
)

parser.add_argument(
    "--pattern",
    type=str,
    default="",
    help="Type of the pattern. Available options in masks/textures",
)

parser.add_argument(
    "--pattern_weight",
    type=float,
    default=0.5,
    help="Weight of the pattern. Must be between 0 and 1",
)

parser.add_argument(
    "--color",
    type=str,
    default="#0473e2",
    help="Hex color value to be overlaid on the mask",
)

parser.add_argument(
    "--color_weight",
    type=float,
    default=0.5,
    help="Weight of the color intensity. Must be between 0 and 1",
)

parser.add_argument(
    "--code",
    type=str,
    # default="cloth-masks/textures/check/check_4.jpg, cloth-#e54294, cloth-#ff0000, cloth, cloth-masks/textures/others/heart_1.png, cloth-masks/textures/fruits/pineapple.png, N95, surgical_blue, surgical_green",
    default="",
    help="Generate specific formats",
)


parser.add_argument(
    "--verbose", dest="verbose", action="store_true", help="Turn verbosity on"
)
parser.add_argument(
    "--write_original_image",
    dest="write_original_image",
    action="store_true",
    help="If set, the original image is also stored in the masked folder",
)
parser.set_defaults(feature=False)

args = parser.parse_args()
args.write_path = args.path + "_masked"

# Set up dlib face detector and predictor
args.detector = dlib.get_frontal_face_detector()
path_to_dlib_model = "dlib_models/shape_predictor_68_face_landmarks.dat"
if not os.path.exists(path_to_dlib_model):
    download_dlib_model()

args.predictor = dlib.shape_predictor(path_to_dlib_model)

# Extract data from code
mask_code = "".join(args.code.split()).split(",")
args.code_count = np.zeros(len(mask_code))
args.mask_dict_of_dict = {}


for i, entry in enumerate(mask_code):
    mask_dict = {}
    mask_color = ""
    mask_texture = ""
    mask_type = entry.split("-")[0]
    if len(entry.split("-")) == 2:
        mask_variation = entry.split("-")[1]
        if "#" in mask_variation:
            mask_color = mask_variation
        else:
            mask_texture = mask_variation
    mask_dict["type"] = mask_type
    mask_dict["color"] = mask_color
    mask_dict["texture"] = mask_texture
    args.mask_dict_of_dict[i] = mask_dict

# Check whether the path is a file, a directory, or neither
is_directory, is_file, is_other = check_path(args.path)
display_MaskTheFace()

if is_directory:
    path, dirs, files = os.walk(args.path).__next__()
    file_count = len(files)
    dirs_count = len(dirs)
    if len(files) > 0:
        print_orderly("Masking image files", 60)

    # Process files in the directory if any
    for f in tqdm(files):
        image_path = path + "/" + f

        write_path = path + "_masked"
        if not os.path.isdir(write_path):
            os.makedirs(write_path)

        if is_image(image_path):
            # Proceed if file is image
            if args.verbose:
                str_p = "Processing: " + image_path
                tqdm.write(str_p)

            split_path = f.rsplit(".", 1)  # split only at the last dot
            masked_image, mask, mask_binary_array, original_image = mask_image(
                image_path, args
            )
            for i in range(len(mask)):
                w_path = (
                    write_path
                    + "/"
                    + split_path[0]
                    + "_"
                    + "masked"
                    + "."
                    + split_path[1]
                )
                img = masked_image[i]
                binary_img = mask_binary_array[i]
                cv2.imwrite(w_path, img)
                # Make sure the auxiliary output folders exist before writing
                os.makedirs(path + "_binary", exist_ok=True)
                os.makedirs(path + "_original", exist_ok=True)
                cv2.imwrite(
                    path + "_binary/" + split_path[0] + "_binary" + "." + split_path[1],
                    binary_img,
                )
                cv2.imwrite(
                    path + "_original/" + split_path[0] + "." + split_path[1],
                    original_image,
                )

    print_orderly("Masking image directories", 60)

    # Process directories within the path provided
    for d in tqdm(dirs):
        dir_path = args.path + "/" + d
        dir_write_path = args.write_path + "/" + d
        if not os.path.isdir(dir_write_path):
            os.makedirs(dir_write_path)
        _, _, files = os.walk(dir_path).__next__()

        # Process each file within the subdirectory
        for f in files:
            image_path = dir_path + "/" + f
            if args.verbose:
                str_p = "Processing: " + image_path
                tqdm.write(str_p)
            write_path = dir_write_path
            if is_image(image_path):
                # Proceed if file is image
                split_path = f.rsplit(".", 1)  # split only at the last dot
                masked_image, mask, mask_binary, original_image = mask_image(
                    image_path, args
                )
                for i in range(len(mask)):
                    w_path = (
                        write_path
                        + "/"
                        + split_path[0]
                        + "_"
                        + "masked"
                        + "."
                        + split_path[1]
                    )
                    w_path_original = write_path + "/" + f
                    img = masked_image[i]
                    binary_img = mask_binary[i]
                    # Make sure the binary output folder exists before writing
                    os.makedirs(path + "_binary", exist_ok=True)
                    cv2.imwrite(
                        path
                        + "_binary/"
                        + split_path[0]
                        + "_binary"
                        + "."
                        + split_path[1],
                        binary_img,
                    )
                    # Write the masked image
                    cv2.imwrite(w_path, img)
                    if args.write_original_image:
                        # Write the original image
                        cv2.imwrite(w_path_original, original_image)

    if args.verbose:
        print(args.code_count)

# Process if the path was a file
elif is_file:
    print("Masking image file")
    image_path = args.path
    write_path = args.path.rsplit(".", 1)[0]
    if is_image(image_path):
        # Proceed if file is image
        # masked_images, mask, mask_binary_array, original_image
        masked_image, mask, mask_binary_array, original_image = mask_image(
            image_path, args
        )
        for i in range(len(mask)):
            w_path = write_path + "_" + "masked" + "." + args.path.rsplit(".", 1)[1]
            img = masked_image[i]
            binary_img = mask_binary_array[i]
            cv2.imwrite(w_path, img)
            cv2.imwrite(write_path + "_binary." + args.path.rsplit(".", 1)[1], binary_img)
else:
    print("Path is neither a valid file nor a valid directory")
print("Processing Done")
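The script above is CLI-driven, but mask_image only needs an args object with the right fields. A hedged sketch of calling it from Python, using an argparse.Namespace whose fields mirror the parser options (the image path points at the sample image included in this change; everything else uses the parser defaults):

    import argparse, os
    import cv2
    import dlib
    from utils.aux_functions import mask_image, download_dlib_model

    args = argparse.Namespace(
        mask_type="surgical", verbose=False, code="",
        pattern="", pattern_weight=0.5,
        color="#0473e2", color_weight=0.5,
    )
    path_to_dlib_model = "dlib_models/shape_predictor_68_face_landmarks.dat"
    if not os.path.exists(path_to_dlib_model):
        download_dlib_model()
    args.detector = dlib.get_frontal_face_detector()
    args.predictor = dlib.shape_predictor(path_to_dlib_model)

    masked, masks, binaries, original = mask_image("images/000001.jpg", args)
    if masked:
        cv2.imwrite("000001_masked.jpg", masked[0])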
Code/MaskTheFace/masks/masks.cfg
0 → 100644
[surgical]
template: masks/templates/surgical.png
mask_a: 21, 97
mask_b: 307, 22
mask_c: 600, 99
mask_d: 25, 322
mask_e: 295, 470
mask_f: 600, 323

[surgical_left]
template: masks/templates/surgical_left.png
mask_a: 39, 27
mask_b: 130, 9
mask_c: 567, 20
mask_d: 87, 207
mask_e: 168, 302
mask_f: 568, 202

[surgical_right]
template: masks/templates/surgical_right.png
mask_a: 3, 20
mask_b: 440, 9
mask_c: 531, 27
mask_d: 2, 202
mask_e: 402, 302
mask_f: 483, 207

[surgical_green]
template: masks/templates/surgical_green.png
mask_a: 21, 97
mask_b: 307, 22
mask_c: 600, 99
mask_d: 25, 322
mask_e: 295, 470
mask_f: 600, 323

[surgical_green_left]
template: masks/templates/surgical_green_left.png
mask_a: 39, 27
mask_b: 130, 9
mask_c: 567, 20
mask_d: 87, 207
mask_e: 168, 302
mask_f: 568, 202

[surgical_green_right]
template: masks/templates/surgical_green_right.png
mask_a: 3, 20
mask_b: 440, 9
mask_c: 531, 27
mask_d: 2, 202
mask_e: 402, 302
mask_f: 483, 207

[surgical_blue]
template: masks/templates/surgical_blue.png
mask_a: 21, 97
mask_b: 307, 22
mask_c: 600, 99
mask_d: 25, 322
mask_e: 295, 470
mask_f: 600, 323

[surgical_blue_left]
template: masks/templates/surgical_blue_left.png
mask_a: 39, 27
mask_b: 130, 9
mask_c: 567, 20
mask_d: 87, 207
mask_e: 168, 302
mask_f: 568, 202

[surgical_blue_right]
template: masks/templates/surgical_blue_right.png
mask_a: 3, 20
mask_b: 440, 9
mask_c: 531, 27
mask_d: 2, 202
mask_e: 402, 302
mask_f: 483, 207


[N95]
template: masks/templates/N95.png
mask_a: 15, 119
mask_b: 327, 5
mask_c: 640, 93
mask_d: 13, 285
mask_e: 351, 518
mask_f: 645, 285

;[N95_left]
;template: masks/N95_left.png
;mask_a: 176, 121
;mask_b: 313, 46
;mask_c: 799, 135
;mask_d: 97, 438
;mask_e: 329, 627
;mask_f: 791, 401

[N95_right]
template: masks/templates/N95_right.png
mask_c: 979, 331
mask_b: 806, 172
mask_a: 12, 222
mask_f: 907, 762
mask_e: 577, 875
mask_d: -4, 632

[N95_left]
template: masks/templates/N95_left.png
mask_a: 193, 331
mask_b: 366, 172
mask_c: 1160, 222
mask_d: 265, 762
mask_e: 595, 875
mask_f: 1176, 632


[cloth_left]
template: masks/templates/cloth_left.png
mask_a: 65, 93
mask_b: 162, 15
mask_c: 672, 75
mask_d: 114, 296
mask_e: 207, 443
mask_f: 671, 341

[cloth_right]
template: masks/templates/cloth_right.png
mask_a: 98, 93
mask_b: 608, 15
mask_c: 705, 75
mask_d: 99, 296
mask_e: 563, 443
mask_f: 656, 341

[cloth]
template: masks/templates/cloth.png
mask_a: 122, 90
mask_b: 405, 7
mask_c: 686, 79
mask_d: 165, 323
mask_e: 406, 509
mask_f: 653, 311

[gas]
template: masks/templates/gas.png
mask_a: 330, 431
mask_b: 873, 117
mask_c: 1494, 434
mask_d: 430, 754
mask_e: 869, 1100
mask_f: 1400, 710

[gas_left]
template: masks/templates/gas_left.png
mask_a: 239, 238
mask_b: 317, 42
mask_c: 965, 239
mask_d: 224, 404
mask_e: 337, 502
mask_f: 963, 406

[gas_right]
template: masks/templates/gas_right.png
mask_c: 621, 238
mask_b: 543, 60
mask_a: -105, 239
mask_f: 636, 404
mask_e: 523, 502
mask_d: -103, 406

[KN95]
template: masks/templates/KN95.png
mask_a: 20, 47
mask_b: 410, 5
mask_c: 760, 55
mask_d: 75, 340
mask_e: 398, 600
mask_f: 671, 320

[KN95_left]
template: masks/templates/KN95_left.png
mask_a: 52, 258
mask_b: 207, 100
mask_c: 730, 80
mask_d: 210, 408
mask_e: 335, 604
mask_f: 770, 270

[KN95_right]
template: masks/templates/KN95_right.png
mask_c: 664, 258
mask_b: 509, 100
mask_a: -14, 80
mask_f: 506, 408
mask_e: 381, 604
mask_d: -54, 270


[empty]
[empty_left]
[empty_right]

[inpaint]
[inpaint_left]
[inpaint_right]

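Each section above names a template image and six anchor points (a through f) on it. The repo reads these through utils/read_cfg.py, which is not shown in this diff; as a sketch of what any reader of this file has to do, plain ConfigParser works (section name chosen for illustration):

    from configparser import ConfigParser

    parser = ConfigParser()
    parser.read("masks/masks.cfg")
    section = parser["surgical"]
    template = section["template"]                       # masks/templates/surgical.png
    # The six anchors are stored as "x, y" strings
    points = {k: tuple(int(v) for v in section[k].split(","))
              for k in ("mask_a", "mask_b", "mask_c", "mask_d", "mask_e", "mask_f")}
    print(template, points["mask_a"])                    # ... (21, 97)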
Code/MaskTheFace/masks/templates/KN95.png
0 → 100644
236 KB
Code/MaskTheFace/masks/templates/N95.png
0 → 100644
326 KB
Code/MaskTheFace/masks/templates/cloth.png
0 → 100644
308 KB
Code/MaskTheFace/masks/templates/gas.png
0 → 100644
1.11 MB
(Additional binary image files, mask template and texture variants, not displayed in this view.)
Code/MaskTheFace/requirements.txt
0 → 100644
Code/MaskTheFace/utils/__init__.py
0 → 100644
Code/MaskTheFace/utils/aux_functions.py
0 → 100644
# Author: aqeelanwar
# Created: 27 April,2020, 10:21 PM
# Email: aqeel.anwar@gatech.edu

from configparser import ConfigParser
import cv2, math, os
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
from utils.read_cfg import read_cfg
from utils.fit_ellipse import *
import random
from utils.create_mask import texture_the_mask, color_the_mask
from imutils import face_utils
import requests
from zipfile import ZipFile
import bz2, shutil


def download_dlib_model():
    print_orderly("Get dlib model", 60)
    dlib_model_link = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
    print("Downloading dlib model...")
    with requests.get(dlib_model_link, stream=True) as r:
        print("Zip file size: ", np.round(len(r.content) / 1024 / 1024, 2), "MB")
        destination = (
            "dlib_models" + os.path.sep + "shape_predictor_68_face_landmarks.dat.bz2"
        )
        if not os.path.exists(destination.rsplit(os.path.sep, 1)[0]):
            os.mkdir(destination.rsplit(os.path.sep, 1)[0])
        print("Saving dlib model...")
        with open(destination, "wb") as fd:
            for chunk in r.iter_content(chunk_size=32678):
                fd.write(chunk)
    print("Extracting dlib model...")
    with bz2.BZ2File(destination) as fr, open(
        "dlib_models/shape_predictor_68_face_landmarks.dat", "wb"
    ) as fw:
        shutil.copyfileobj(fr, fw)
    print("Saved: ", destination)
    print_orderly("done", 60)

    os.remove(destination)


def get_line(face_landmark, image, type="eye", debug=False):
    pil_image = Image.fromarray(image)
    d = ImageDraw.Draw(pil_image)
    left_eye = face_landmark["left_eye"]
    right_eye = face_landmark["right_eye"]
    left_eye_mid = np.mean(np.array(left_eye), axis=0)
    right_eye_mid = np.mean(np.array(right_eye), axis=0)
    eye_line_mid = (left_eye_mid + right_eye_mid) / 2

    if type == "eye":
        left_point = left_eye_mid
        right_point = right_eye_mid
        mid_point = eye_line_mid

    elif type == "nose_mid":
        nose_length = (
            face_landmark["nose_bridge"][-1][1] - face_landmark["nose_bridge"][0][1]
        )
        left_point = [left_eye_mid[0], left_eye_mid[1] + nose_length / 2]
        right_point = [right_eye_mid[0], right_eye_mid[1] + nose_length / 2]

        mid_pointY = (
            face_landmark["nose_bridge"][-1][1] + face_landmark["nose_bridge"][0][1]
        ) / 2
        mid_pointX = (
            face_landmark["nose_bridge"][-1][0] + face_landmark["nose_bridge"][0][0]
        ) / 2
        mid_point = (mid_pointX, mid_pointY)

    elif type == "nose_tip":
        nose_length = (
            face_landmark["nose_bridge"][-1][1] - face_landmark["nose_bridge"][0][1]
        )
        left_point = [left_eye_mid[0], left_eye_mid[1] + nose_length]
        right_point = [right_eye_mid[0], right_eye_mid[1] + nose_length]
        mid_point = (
            face_landmark["nose_bridge"][-1][1] + face_landmark["nose_bridge"][0][1]
        ) / 2

    elif type == "bottom_lip":
        bottom_lip = face_landmark["bottom_lip"]
        bottom_lip_mid = np.max(np.array(bottom_lip), axis=0)
        shiftY = bottom_lip_mid[1] - eye_line_mid[1]
        left_point = [left_eye_mid[0], left_eye_mid[1] + shiftY]
        right_point = [right_eye_mid[0], right_eye_mid[1] + shiftY]
        mid_point = bottom_lip_mid

    elif type == "perp_line":
        bottom_lip = face_landmark["bottom_lip"]
        bottom_lip_mid = np.mean(np.array(bottom_lip), axis=0)

        left_point = face_landmark["nose_bridge"][0]
        right_point = bottom_lip_mid

        mid_point = bottom_lip_mid

    elif type == "nose_long":
        nose_bridge = face_landmark["nose_bridge"]
        left_point = [nose_bridge[0][0], nose_bridge[0][1]]
        right_point = [nose_bridge[-1][0], nose_bridge[-1][1]]

        mid_point = left_point

    y = [left_point[1], right_point[1]]
    x = [left_point[0], right_point[0]]
    eye_line = fit_line(x, y, image)
    d.line(eye_line, width=5, fill="blue")

    # Perpendicular line through the midpoint:
    # (midX, midY) and (midX - y2 + y1, midY + x2 - x1)
    y = [
        (left_point[1] + right_point[1]) / 2,
        (left_point[1] + right_point[1]) / 2 + right_point[0] - left_point[0],
    ]
    x = [
        (left_point[0] + right_point[0]) / 2,
        (left_point[0] + right_point[0]) / 2 - right_point[1] + left_point[1],
    ]
    perp_line = fit_line(x, y, image)
    if debug:
        d.line(perp_line, width=5, fill="red")
        pil_image.show()
    return eye_line, perp_line, left_point, right_point, mid_point


def get_points_on_chin(line, face_landmark, chin_type="chin"):
    chin = face_landmark[chin_type]
    points_on_chin = []
    for i in range(len(chin) - 1):
        chin_first_point = [chin[i][0], chin[i][1]]
        chin_second_point = [chin[i + 1][0], chin[i + 1][1]]

        flag, x, y = line_intersection(line, (chin_first_point, chin_second_point))
        if flag:
            points_on_chin.append((x, y))

    return points_on_chin


def plot_lines(face_line, image, debug=False):
    pil_image = Image.fromarray(image)
    if debug:
        d = ImageDraw.Draw(pil_image)
        d.line(face_line, width=4, fill="white")
        pil_image.show()


def line_intersection(line1, line2):
    start = 0
    end = -1
    line1 = ([line1[start][0], line1[start][1]], [line1[end][0], line1[end][1]])

    xdiff = (line1[0][0] - line1[1][0], line2[0][0] - line2[1][0])
    ydiff = (line1[0][1] - line1[1][1], line2[0][1] - line2[1][1])
    x = []
    y = []
    flag = False

    def det(a, b):
        return a[0] * b[1] - a[1] * b[0]

    div = det(xdiff, ydiff)
    if div == 0:
        return flag, x, y

    d = (det(*line1), det(*line2))
    x = det(d, xdiff) / div
    y = det(d, ydiff) / div

    segment_minX = min(line2[0][0], line2[1][0])
    segment_maxX = max(line2[0][0], line2[1][0])

    segment_minY = min(line2[0][1], line2[1][1])
    segment_maxY = max(line2[0][1], line2[1][1])

    if (
        segment_maxX + 1 >= x >= segment_minX - 1
        and segment_maxY + 1 >= y >= segment_minY - 1
    ):
        flag = True

    return flag, x, y


def fit_line(x, y, image):
    if x[0] == x[1]:
        x[0] += 0.1
    coefficients = np.polyfit(x, y, 1)
    polynomial = np.poly1d(coefficients)
    x_axis = np.linspace(0, image.shape[1], 50)
    y_axis = polynomial(x_axis)
    eye_line = []
    for i in range(len(x_axis)):
        eye_line.append((x_axis[i], y_axis[i]))

    return eye_line


def get_six_points(face_landmark, image):
    _, perp_line1, _, _, m = get_line(face_landmark, image, type="nose_mid")
    face_b = m

    perp_line, _, _, _, _ = get_line(face_landmark, image, type="perp_line")
    points1 = get_points_on_chin(perp_line1, face_landmark)
    points = get_points_on_chin(perp_line, face_landmark)
    if not points1:
        face_e = tuple(np.asarray(points[0]))
    elif not points:
        face_e = tuple(np.asarray(points1[0]))
    else:
        face_e = tuple((np.asarray(points[0]) + np.asarray(points1[0])) / 2)
    nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="nose_long")

    angle = get_angle(perp_line, nose_mid_line)
    nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="nose_tip")
    points = get_points_on_chin(nose_mid_line, face_landmark)
    if len(points) < 2:
        # extrapolate the chin as an ellipse if the nose-tip line misses it
        face_landmark = get_face_ellipse(face_landmark)
        points = get_points_on_chin(
            nose_mid_line, face_landmark, chin_type="chin_extrapolated"
        )
        if len(points) < 2:
            points = []
            points.append(face_landmark["chin"][0])
            points.append(face_landmark["chin"][-1])
    face_a = points[0]
    face_c = points[-1]
    nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="bottom_lip")
    points = get_points_on_chin(nose_mid_line, face_landmark)
    face_d = points[0]
    face_f = points[-1]

    six_points = np.float32([face_a, face_b, face_c, face_f, face_e, face_d])

    return six_points, angle


def get_angle(line1, line2):
    delta_y = line1[-1][1] - line1[0][1]
    delta_x = line1[-1][0] - line1[0][0]
    perp_angle = math.degrees(math.atan2(delta_y, delta_x))
    if delta_x < 0:
        perp_angle = perp_angle + 180
    if perp_angle < 0:
        perp_angle += 360
    if perp_angle > 180:
        perp_angle -= 180

    delta_y = line2[-1][1] - line2[0][1]
    delta_x = line2[-1][0] - line2[0][0]
    nose_angle = math.degrees(math.atan2(delta_y, delta_x))

    if delta_x < 0:
        nose_angle = nose_angle + 180
    if nose_angle < 0:
        nose_angle += 360
    if nose_angle > 180:
        nose_angle -= 180

    angle = nose_angle - perp_angle
    return angle


def mask_face(image, face_location, six_points, angle, args, type="surgical"):
    debug = False

    # Find the face angle
    threshold = 13
    if angle < -threshold:
        type += "_right"
    elif angle > threshold:
        type += "_left"

    face_height = face_location[2] - face_location[0]
    face_width = face_location[1] - face_location[3]
    # Read appropriate mask image
    # (note: w is the image height and h the width; (h, w) below is the
    # (width, height) pair that cv2.warpPerspective expects)
    w = image.shape[0]
    h = image.shape[1]
    if "empty" not in type and "inpaint" not in type:
        cfg = read_cfg(config_filename="masks/masks.cfg", mask_type=type, verbose=False)
    else:
        if "left" in type:
            template_type = "surgical_blue_left"
        elif "right" in type:
            template_type = "surgical_blue_right"
        else:
            template_type = "surgical_blue"
        cfg = read_cfg(config_filename="masks/masks.cfg", mask_type=template_type, verbose=False)
    img = cv2.imread(cfg.template, cv2.IMREAD_UNCHANGED)

    # Process the mask if necessary
    if args.pattern:
        # Apply pattern to mask
        img = texture_the_mask(img, args.pattern, args.pattern_weight)

    if args.color:
        # Apply color to mask
        img = color_the_mask(img, args.color, args.color_weight)

    mask_line = np.float32(
        [cfg.mask_a, cfg.mask_b, cfg.mask_c, cfg.mask_f, cfg.mask_e, cfg.mask_d]
    )
    # Warp the mask
    M, mask = cv2.findHomography(mask_line, six_points)
    dst_mask = cv2.warpPerspective(img, M, (h, w))
    dst_mask_points = cv2.perspectiveTransform(mask_line.reshape(-1, 1, 2), M)
    mask = dst_mask[:, :, 3]
    # Use the full image for the brightness/saturation statistics
    image_face = image

    # Adjust Brightness
    mask_brightness = get_avg_brightness(img)
    img_brightness = get_avg_brightness(image_face)
    delta_b = 1 + (img_brightness - mask_brightness) / 255
    dst_mask = change_brightness(dst_mask, delta_b)

    # Adjust Saturation
    mask_saturation = get_avg_saturation(img)
    img_saturation = get_avg_saturation(image_face)
    delta_s = 1 - (img_saturation - mask_saturation) / 255
    dst_mask = change_saturation(dst_mask, delta_s)

    # Apply mask
    mask_inv = cv2.bitwise_not(mask)
    img_bg = cv2.bitwise_and(image, image, mask=mask_inv)
    img_fg = cv2.bitwise_and(dst_mask, dst_mask, mask=mask)
    out_img = cv2.add(img_bg, img_fg[:, :, 0:3])
    if "empty" in type or "inpaint" in type:
        out_img = img_bg

    if "inpaint" in type:
        out_img = cv2.inpaint(out_img, mask, 3, cv2.INPAINT_TELEA)

    if debug:
        # Plot key points
        for i in six_points:
            cv2.circle(out_img, (i[0], i[1]), radius=4, color=(0, 0, 255), thickness=-1)

        for i in dst_mask_points:
            cv2.circle(
                out_img, (i[0][0], i[0][1]), radius=4, color=(0, 255, 0), thickness=-1
            )

    return out_img, mask


def draw_landmarks(face_landmarks, image):
    pil_image = Image.fromarray(image)
    d = ImageDraw.Draw(pil_image)
    for facial_feature in face_landmarks.keys():
        d.line(face_landmarks[facial_feature], width=5, fill="white")
    pil_image.show()


def get_face_ellipse(face_landmark):
    chin = face_landmark["chin"]
    x = []
    y = []
    for point in chin:
        x.append(point[0])
        y.append(point[1])

    x = np.asarray(x)
    y = np.asarray(y)

    a = fitEllipse(x, y)
    center = ellipse_center(a)
    phi = ellipse_angle_of_rotation(a)
    axes = ellipse_axis_length(a)
    a, b = axes

    arc = 2.2
    R = np.arange(0, arc * np.pi, 0.2)
    xx = center[0] + a * np.cos(R) * np.cos(phi) - b * np.sin(R) * np.sin(phi)
    yy = center[1] + a * np.cos(R) * np.sin(phi) + b * np.sin(R) * np.cos(phi)
    chin_extrapolated = []
    for i in range(len(R)):
        chin_extrapolated.append((xx[i], yy[i]))
    face_landmark["chin_extrapolated"] = chin_extrapolated
    return face_landmark


def get_avg_brightness(img):
    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(img_hsv)
    return np.mean(v)


def get_avg_saturation(img):
    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(img_hsv)
    return np.mean(s)  # mean saturation (the original returned the v channel by mistake)


def change_brightness(img, value=1.0):
    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(img_hsv)
    v = value * v
    v[v > 255] = 255
    v = np.asarray(v, dtype=np.uint8)
    final_hsv = cv2.merge((h, s, v))
    img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
    return img


def change_saturation(img, value=1.0):
    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(img_hsv)
    s = value * s
    s[s > 255] = 255
    s = np.asarray(s, dtype=np.uint8)
    final_hsv = cv2.merge((h, s, v))
    img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
    return img


def check_path(path):
    is_directory = False
    is_file = False
    is_other = False
    if os.path.isdir(path):
        is_directory = True
    elif os.path.isfile(path):
        is_file = True
    else:
        is_other = True

    return is_directory, is_file, is_other


def shape_to_landmarks(shape):
    face_landmarks = {}
    face_landmarks["left_eyebrow"] = [
        tuple(shape[17]),
        tuple(shape[18]),
        tuple(shape[19]),
        tuple(shape[20]),
        tuple(shape[21]),
    ]
    face_landmarks["right_eyebrow"] = [
        tuple(shape[22]),
        tuple(shape[23]),
        tuple(shape[24]),
        tuple(shape[25]),
        tuple(shape[26]),
    ]
    face_landmarks["nose_bridge"] = [
        tuple(shape[27]),
        tuple(shape[28]),
        tuple(shape[29]),
        tuple(shape[30]),
    ]
    face_landmarks["nose_tip"] = [
        tuple(shape[31]),
        tuple(shape[32]),
        tuple(shape[33]),
        tuple(shape[34]),
        tuple(shape[35]),
    ]
    face_landmarks["left_eye"] = [
        tuple(shape[36]),
        tuple(shape[37]),
        tuple(shape[38]),
        tuple(shape[39]),
        tuple(shape[40]),
        tuple(shape[41]),
    ]
    face_landmarks["right_eye"] = [
        tuple(shape[42]),
        tuple(shape[43]),
        tuple(shape[44]),
        tuple(shape[45]),
        tuple(shape[46]),
        tuple(shape[47]),
    ]
    face_landmarks["top_lip"] = [
        tuple(shape[48]),
        tuple(shape[49]),
        tuple(shape[50]),
        tuple(shape[51]),
        tuple(shape[52]),
        tuple(shape[53]),
        tuple(shape[54]),
        tuple(shape[60]),
        tuple(shape[61]),
        tuple(shape[62]),
        tuple(shape[63]),
        tuple(shape[64]),
    ]

    face_landmarks["bottom_lip"] = [
        tuple(shape[54]),
        tuple(shape[55]),
        tuple(shape[56]),
        tuple(shape[57]),
        tuple(shape[58]),
        tuple(shape[59]),
        tuple(shape[48]),
        tuple(shape[64]),
        tuple(shape[65]),
        tuple(shape[66]),
        tuple(shape[67]),
        tuple(shape[60]),
    ]

    face_landmarks["chin"] = [
        tuple(shape[0]),
        tuple(shape[1]),
        tuple(shape[2]),
        tuple(shape[3]),
        tuple(shape[4]),
        tuple(shape[5]),
        tuple(shape[6]),
        tuple(shape[7]),
        tuple(shape[8]),
        tuple(shape[9]),
        tuple(shape[10]),
        tuple(shape[11]),
        tuple(shape[12]),
        tuple(shape[13]),
        tuple(shape[14]),
        tuple(shape[15]),
        tuple(shape[16]),
    ]
    return face_landmarks


def rect_to_bb(rect):
    x1 = rect.left()
    x2 = rect.right()
    y1 = rect.top()
    y2 = rect.bottom()
    return (y1, x2, y2, x1)  # (top, right, bottom, left), which is how callers index it


def mask_image(image_path, args):
    # Read the image
    image = cv2.imread(image_path)
    original_image = image.copy()
    # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = image
    face_locations = args.detector(gray, 1)
    mask_type = args.mask_type
    verbose = args.verbose
    if args.code:
        ind = random.randint(0, len(args.code_count) - 1)
        mask_dict = args.mask_dict_of_dict[ind]
        mask_type = mask_dict["type"]
        args.color = mask_dict["color"]
        args.pattern = mask_dict["texture"]
        args.code_count[ind] += 1

    elif mask_type == "random":
        available_mask_types = get_available_mask_types()
        mask_type = random.choice(available_mask_types)

    if verbose:
        tqdm.write("Faces found: {:2d}".format(len(face_locations)))
    # Process each face in the image
    masked_images = []
    mask_binary_array = []
    mask = []
    for (i, face_location) in enumerate(face_locations):
        shape = args.predictor(gray, face_location)
        shape = face_utils.shape_to_np(shape)
        face_landmarks = shape_to_landmarks(shape)
        face_location = rect_to_bb(face_location)
        # draw_landmarks(face_landmarks, image)
        six_points_on_face, angle = get_six_points(face_landmarks, image)
        mask = []
        if mask_type != "all":
            if len(masked_images) > 0:
                image = masked_images.pop(0)
            image, mask_binary = mask_face(
                image, face_location, six_points_on_face, angle, args, type=mask_type
            )

            # compress to face tight
            face_height = face_location[2] - face_location[0]
            face_width = face_location[1] - face_location[3]
            masked_images.append(image)
            mask_binary_array.append(mask_binary)
            mask.append(mask_type)
        else:
            available_mask_types = get_available_mask_types()
            for m in range(len(available_mask_types)):
                if len(masked_images) == len(available_mask_types):
                    image = masked_images.pop(m)
                img, mask_binary = mask_face(
                    image,
                    face_location,
                    six_points_on_face,
                    angle,
                    args,
                    type=available_mask_types[m],
                )
                masked_images.insert(m, img)
                mask_binary_array.insert(m, mask_binary)
            mask = available_mask_types

    return masked_images, mask, mask_binary_array, original_image


def is_image(path):
    try:
        extension = os.path.splitext(path)[1][1:]  # robust to dots elsewhere in the name
        image_extensions = ["png", "PNG", "jpg", "JPG"]

        if extension in image_extensions:
            return True
        else:
            print("Please input image file. png / jpg")
            return False
    except Exception:
        return False


def get_available_mask_types(config_filename="masks/masks.cfg"):
    parser = ConfigParser()
    parser.optionxform = str
    parser.read(config_filename)
    available_mask_types = parser.sections()
    available_mask_types = [
        string for string in available_mask_types if "left" not in string
    ]
    available_mask_types = [
        string for string in available_mask_types if "right" not in string
    ]

    return available_mask_types


def print_orderly(str, n):
    hyphens = "-" * int((n - len(str)) / 2)
    str_p = hyphens + " " + str + " " + hyphens
    hyphens_bar = "-" * len(str_p)
    print(hyphens_bar)
    print(str_p)
    print(hyphens_bar)


def display_MaskTheFace():
    with open("utils/display.txt", "r") as file:
        for line in file:
            print(line, end="")
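The heart of mask_face above is a six-point homography from the template anchors in masks.cfg to the six points found on the face. A stripped-down, self-contained sketch of that warp, with a synthetic RGBA template and invented face points standing in for the outputs of read_cfg and get_six_points:

    import cv2
    import numpy as np

    template = np.zeros((500, 620, 4), dtype=np.uint8)      # toy RGBA mask template
    template[100:400, 50:570, 3] = 255                      # opaque region = mask shape
    mask_line = np.float32([[21, 97], [307, 22], [600, 99],
                            [600, 323], [295, 470], [25, 322]])    # template anchors (a, b, c, f, e, d)
    six_points = np.float32([[120, 260], [210, 230], [300, 265],
                             [295, 350], [205, 395], [125, 345]])  # invented face points

    M, _ = cv2.findHomography(mask_line, six_points)
    warped = cv2.warpPerspective(template, M, (640, 480))   # destination image size (w, h)
    alpha = warped[:, :, 3]                                 # binary mask used for compositing
    print(alpha.shape, alpha.max())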
Code/MaskTheFace/utils/create_mask.py
0 → 100644
1 | +# Author: aqeelanwar | ||
2 | +# Created: 6 July,2020, 12:14 AM | ||
3 | +# Email: aqeel.anwar@gatech.edu | ||
4 | + | ||
5 | +from PIL import ImageColor | ||
6 | +import cv2 | ||
7 | +import numpy as np | ||
8 | + | ||
9 | +COLOR = [ | ||
10 | + "#fc1c1a", | ||
11 | + "#177ABC", | ||
12 | + "#94B6D2", | ||
13 | + "#A5AB81", | ||
14 | + "#DD8047", | ||
15 | + "#6b425e", | ||
16 | + "#e26d5a", | ||
17 | + "#c92c48", | ||
18 | + "#6a506d", | ||
19 | + "#ffc900", | ||
20 | + "#ffffff", | ||
21 | + "#000000", | ||
22 | + "#49ff00", | ||
23 | +] | ||
24 | + | ||
25 | + | ||
26 | +def color_the_mask(mask_image, color, intensity): | ||
27 | + assert 0 <= intensity <= 1, "intensity should be between 0 and 1" | ||
28 | + RGB_color = ImageColor.getcolor(color, "RGB") | ||
29 | + RGB_color = (RGB_color[2], RGB_color[1], RGB_color[0]) | ||
30 | + orig_shape = mask_image.shape | ||
31 | + bit_mask = mask_image[:, :, 3] | ||
32 | + mask_image = mask_image[:, :, 0:3] | ||
33 | + | ||
34 | + color_image = np.full(mask_image.shape, RGB_color, np.uint8) | ||
35 | + mask_color = cv2.addWeighted(mask_image, 1 - intensity, color_image, intensity, 0) | ||
36 | + mask_color = cv2.bitwise_and(mask_color, mask_color, mask=bit_mask) | ||
37 | + colored_mask = np.zeros(orig_shape, dtype=np.uint8) | ||
38 | + colored_mask[:, :, 0:3] = mask_color | ||
39 | + colored_mask[:, :, 3] = bit_mask | ||
40 | + return colored_mask | ||
41 | + | ||
42 | + | ||
43 | +def texture_the_mask(mask_image, texture_path, intensity): | ||
44 | + assert 0 <= intensity <= 1, "intensity should be between 0 and 1" | ||
45 | + orig_shape = mask_image.shape | ||
46 | + bit_mask = mask_image[:, :, 3] | ||
47 | + mask_image = mask_image[:, :, 0:3] | ||
48 | + texture_image = cv2.imread(texture_path) | ||
49 | + texture_image = cv2.resize(texture_image, (orig_shape[1], orig_shape[0])) | ||
50 | + | ||
51 | + mask_texture = cv2.addWeighted( | ||
52 | + mask_image, 1 - intensity, texture_image, intensity, 0 | ||
53 | + ) | ||
54 | + mask_texture = cv2.bitwise_and(mask_texture, mask_texture, mask=bit_mask) | ||
55 | + textured_mask = np.zeros(orig_shape, dtype=np.uint8) | ||
56 | + textured_mask[:, :, 0:3] = mask_texture | ||
57 | + textured_mask[:, :, 3] = bit_mask | ||
58 | + | ||
59 | + return textured_mask | ||
60 | + | ||
61 | + | ||
62 | + | ||
63 | +# cloth_mask = cv2.imread("masks/templates/cloth.png", cv2.IMREAD_UNCHANGED) | ||
64 | +# # cloth_mask = color_the_mask(cloth_mask, color=COLOR[0], intensity=0.5) | ||
65 | +# path = "masks/textures" | ||
66 | +# path, dir, files = os.walk(path).__next__() | ||
67 | +# first_frame = True | ||
68 | +# col_limit = 6 | ||
69 | +# i = 0 | ||
70 | +# # img_concat_row=[] | ||
71 | +# img_concat = [] | ||
72 | +# # for f in files: | ||
73 | +# # if "._" not in f: | ||
74 | +# # print(f) | ||
75 | +# # i += 1 | ||
76 | +# # texture_image = cv2.imread(os.path.join(path, f)) | ||
77 | +# # m = texture_the_mask(cloth_mask, texture_image, intensity=0.5) | ||
78 | +# # if first_frame: | ||
79 | +# # img_concat_row = m | ||
80 | +# # first_frame = False | ||
81 | +# # else: | ||
82 | +# # img_concat_row = cv2.hconcat((img_concat_row, m)) | ||
83 | +# # | ||
84 | +# # if i % col_limit == 0: | ||
85 | +# # if len(img_concat) > 0: | ||
86 | +# # img_concat = cv2.vconcat((img_concat, img_concat_row)) | ||
87 | +# # else: | ||
88 | +# # img_concat = img_concat_row | ||
89 | +# # first_frame = True | ||
90 | +# | ||
91 | +# ## Color the mask | ||
92 | +# thresholds = np.arange(0.1,0.9,0.05) | ||
93 | +# for intensity in thresholds: | ||
94 | +# c=COLOR[2] | ||
95 | +# # intensity = 0.5 | ||
96 | +# if "._" not in c: | ||
97 | +# print(intensity) | ||
98 | +# i += 1 | ||
99 | +# # texture_image = cv2.imread(os.path.join(path, f)) | ||
100 | +# m = color_the_mask(cloth_mask, c, intensity=intensity) | ||
101 | +# if first_frame: | ||
102 | +# img_concat_row = m | ||
103 | +# first_frame = False | ||
104 | +# else: | ||
105 | +# img_concat_row = cv2.hconcat((img_concat_row, m)) | ||
106 | +# | ||
107 | +# if i % col_limit == 0: | ||
108 | +# if len(img_concat) > 0: | ||
109 | +# img_concat = cv2.vconcat((img_concat, img_concat_row)) | ||
110 | +# else: | ||
111 | +# img_concat = img_concat_row | ||
112 | +# first_frame = True | ||
113 | +# | ||
114 | +# | ||
115 | +# cv2.imshow("k", img_concat) | ||
116 | +# cv2.imwrite("combine_N95_left.png", img_concat) | ||
117 | +# cv2.waitKey(0) | ||
118 | +# cc = 1 |
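For reference, a minimal usage sketch of the two helpers above (the texture file name and chosen color are illustrative; templates must be 4-channel PNGs, since both functions read channel 3 as the alpha mask):

# Hypothetical usage -- file names are illustrative, not fixed by this module.
mask = cv2.imread("masks/templates/cloth.png", cv2.IMREAD_UNCHANGED)  # BGRA template
colored = color_the_mask(mask, color=COLOR[1], intensity=0.5)
textured = texture_the_mask(mask, "masks/textures/check_1.png", intensity=0.5)
cv2.imwrite("cloth_colored.png", colored)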
Code/MaskTheFace/utils/display.txt
0 → 100644
Code/MaskTheFace/utils/fetch_dataset.py
0 → 100644
1 | +# Author: Aqeel Anwar(ICSRL) | ||
2 | +# Created: 7/30/2020, 1:44 PM | ||
3 | +# Email: aqeel.anwar@gatech.edu | ||
4 | + | ||
5 | +# Code reused from https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url | ||
6 | +# Make sure you run this from parent folder and not from utils folder i.e. | ||
7 | +# python utils/fetch_dataset.py | ||
8 | + | ||
9 | +import requests, os | ||
10 | +from zipfile import ZipFile | ||
11 | +import argparse | ||
12 | +import urllib.request | ||
13 | + | ||
14 | +parser = argparse.ArgumentParser( | ||
15 | + description="Download dataset - Python code to download associated datasets" | ||
16 | +) | ||
17 | +parser.add_argument( | ||
18 | + "--dataset", | ||
19 | + type=str, | ||
20 | + default="mfr2", | ||
21 | + help="Name of the dataset - Details on available datasets can be found at GitHub Page", | ||
22 | +) | ||
23 | +args = parser.parse_args() | ||
24 | + | ||
25 | + | ||
26 | +def download_file_from_google_drive(id, destination): | ||
27 | + URL = "https://docs.google.com/uc?export=download" | ||
28 | + | ||
29 | + session = requests.Session() | ||
30 | + | ||
31 | + response = session.get(URL, params={"id": id}, stream=True) | ||
32 | + token = get_confirm_token(response) | ||
33 | + | ||
34 | + if token: | ||
35 | + params = {"id": id, "confirm": token} | ||
36 | + response = session.get(URL, params=params, stream=True) | ||
37 | + | ||
38 | + save_response_content(response, destination) | ||
39 | + | ||
40 | + | ||
41 | +def get_confirm_token(response): | ||
42 | + for key, value in response.cookies.items(): | ||
43 | + if key.startswith("download_warning"): | ||
44 | + return value | ||
45 | + | ||
46 | + return None | ||
47 | + | ||
48 | + | ||
49 | +def save_response_content(response, destination): | ||
50 | + CHUNK_SIZE = 32768 | ||
51 | + print(destination) | ||
52 | + with open(destination, "wb") as f: | ||
53 | + for chunk in response.iter_content(CHUNK_SIZE): | ||
54 | + if chunk: # filter out keep-alive new chunks | ||
55 | + f.write(chunk) | ||
56 | + | ||
57 | + | ||
58 | +def download(t_url): | ||
59 | + response = urllib.request.urlopen(t_url) | ||
60 | + data = response.read() | ||
61 | + txt_str = str(data) | ||
62 | + lines = txt_str.split("\\n") | ||
63 | + return lines | ||
64 | + | ||
65 | + | ||
66 | +def Convert(lst): | ||
67 | + it = iter(lst) | ||
68 | + res_dct = dict(zip(it, it)) | ||
69 | + return res_dct | ||
70 | + | ||
71 | + | ||
72 | +if __name__ == "__main__": | ||
73 | + # Fetch the latest download_links.txt file from GitHub | ||
74 | + link = "https://raw.githubusercontent.com/aqeelanwar/MaskTheFace/master/datasets/download_links.txt" | ||
75 | + links_dict = Convert( | ||
76 | + download(link)[0] | ||
77 | + .replace(":", "\n") | ||
78 | + .replace("b'", "") | ||
79 | + .replace("'", "") | ||
80 | + .replace(" ", "") | ||
81 | + .split("\n") | ||
82 | + ) | ||
83 | + file_id = links_dict[args.dataset] | ||
84 | +    destination = os.path.join("datasets", "_.zip") | ||
85 | + print("Downloading: ", args.dataset) | ||
86 | + download_file_from_google_drive(file_id, destination) | ||
87 | + print("Extracting: ", args.dataset) | ||
88 | +    with ZipFile(destination, "r") as zipObj: | ||
89 | +        # Extract all contents of the zip file into the datasets folder | ||
90 | +        zipObj.extractall(os.path.dirname(destination)) | ||
91 | + | ||
92 | + os.remove(destination) |
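The replace/split chain above implies download_links.txt holds one `name: file_id` pair per line; under that assumption, Convert() simply zips consecutive tokens into a dict:

# Illustrative only -- the ids below are fake placeholders.
tokens = ["mfr2", "1AbCFakeId", "celeba", "2DeFFakeId"]
print(Convert(tokens))  # {'mfr2': '1AbCFakeId', 'celeba': '2DeFFakeId'}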
Code/MaskTheFace/utils/fit_ellipse.py
0 → 100644
1 | +# Author: aqeelanwar | ||
2 | +# Created: 4 May,2020, 1:30 AM | ||
3 | +# Email: aqeel.anwar@gatech.edu | ||
4 | + | ||
5 | +import numpy as np | ||
6 | +from numpy.linalg import eig, inv | ||
7 | + | ||
8 | +def fitEllipse(x,y): | ||
9 | + x = x[:,np.newaxis] | ||
10 | + y = y[:,np.newaxis] | ||
11 | + D = np.hstack((x*x, x*y, y*y, x, y, np.ones_like(x))) | ||
12 | + S = np.dot(D.T,D) | ||
13 | + C = np.zeros([6,6]) | ||
14 | + C[0,2] = C[2,0] = 2; C[1,1] = -1 | ||
15 | + E, V = eig(np.dot(inv(S), C)) | ||
16 | + n = np.argmax(np.abs(E)) | ||
17 | + a = V[:,n] | ||
18 | + return a | ||
19 | + | ||
20 | +def ellipse_center(a): | ||
21 | + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] | ||
22 | + num = b*b-a*c | ||
23 | + x0=(c*d-b*f)/num | ||
24 | + y0=(a*f-b*d)/num | ||
25 | + return np.array([x0,y0]) | ||
26 | + | ||
27 | + | ||
28 | +def ellipse_angle_of_rotation( a ): | ||
29 | + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] | ||
30 | + return 0.5*np.arctan(2*b/(a-c)) | ||
31 | + | ||
32 | + | ||
33 | +def ellipse_axis_length( a ): | ||
34 | + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] | ||
35 | + up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g) | ||
36 | + down1=(b*b-a*c)*( (c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) | ||
37 | + down2=(b*b-a*c)*( (a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) | ||
38 | + res1=np.sqrt(up/down1) | ||
39 | + res2=np.sqrt(up/down2) | ||
40 | + return np.array([res1, res2]) | ||
41 | + | ||
42 | +def ellipse_angle_of_rotation2( a ): | ||
43 | + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] | ||
44 | + if b == 0: | ||
45 | + if a > c: | ||
46 | + return 0 | ||
47 | + else: | ||
48 | + return np.pi/2 | ||
49 | + else: | ||
50 | + if a > c: | ||
51 | + return np.arctan(2*b/(a-c))/2 | ||
52 | + else: | ||
53 | + return np.pi/2 + np.arctan(2*b/(a-c))/2 | ||
54 | + | ||
55 | +# a = fitEllipse(x,y) | ||
56 | +# center = ellipse_center(a) | ||
57 | +# #phi = ellipse_angle_of_rotation(a) | ||
58 | +# phi = ellipse_angle_of_rotation2(a) | ||
59 | +# axes = ellipse_axis_length(a) | ||
60 | +# | ||
61 | +# print("center = ", center) | ||
62 | +# print("angle of rotation = ", phi) | ||
63 | +# print("axes = ", axes) | ||
64 | + |
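This looks like the standard direct least-squares conic fit (the classic Fitzgibbon-style recipe); a runnable synthetic check of the commented-out usage at the bottom might read (all numbers illustrative):

# Points on a rotated ellipse. A touch of noise keeps the scatter matrix S
# well-conditioned; on exact conic data inv(S) is near-singular.
rng = np.random.default_rng(0)
t = np.linspace(0, 2 * np.pi, 200)
a0, b0, phi0, cx, cy = 3.0, 1.5, 0.4, 1.0, -2.0
x = cx + a0 * np.cos(t) * np.cos(phi0) - b0 * np.sin(t) * np.sin(phi0) + rng.normal(0, 0.01, t.size)
y = cy + a0 * np.cos(t) * np.sin(phi0) + b0 * np.sin(t) * np.cos(phi0) + rng.normal(0, 0.01, t.size)

coeffs = fitEllipse(x, y)
print(ellipse_center(coeffs))              # approx [ 1. -2.]
print(ellipse_angle_of_rotation2(coeffs))  # approx 0.4 rad
print(ellipse_axis_length(coeffs))         # approx [3.0 1.5] (order may swap)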
Code/MaskTheFace/utils/point_the_mask.py
0 → 100644
1 | +# Author: aqeelanwar | ||
2 | +# Created: 2 May,2020, 2:49 AM | ||
3 | +# Email: aqeel.anwar@gatech.edu | ||
4 | + | ||
5 | +from tkinter import filedialog | ||
6 | +from tkinter import * | ||
7 | +import cv2, os | ||
8 | + | ||
9 | +mouse_pts = [] | ||
10 | + | ||
11 | + | ||
12 | +def get_mouse_points(event, x, y, flags, param): | ||
13 | + global mouseX, mouseY, mouse_pts | ||
14 | + if event == cv2.EVENT_LBUTTONDOWN: | ||
15 | + mouseX, mouseY = x, y | ||
16 | + cv2.circle(mask_im, (x, y), 10, (0, 255, 255), 10) | ||
17 | + if "mouse_pts" not in globals(): | ||
18 | + mouse_pts = [] | ||
19 | + mouse_pts.append((x, y)) | ||
20 | + # print("Point detected") | ||
21 | + # print((x,y)) | ||
22 | + | ||
23 | + | ||
24 | +root = Tk() | ||
25 | +filename = filedialog.askopenfilename( | ||
26 | + initialdir="/", | ||
27 | + title="Select file", | ||
28 | + filetypes=(("PNG files", "*.PNG"), ("png files", "*.png"), ("All files", "*.*")), | ||
29 | +) | ||
30 | +root.destroy() | ||
31 | +filename_split = os.path.split(filename) | ||
32 | +folder = filename_split[0] | ||
33 | +file = filename_split[1] | ||
34 | +file_split = file.split(".") | ||
35 | +new_filename = folder + "/" + file_split[0] + "_marked." + file_split[-1] | ||
36 | +mask_im = cv2.imread(filename) | ||
37 | +cv2.namedWindow("Mask") | ||
38 | +cv2.setMouseCallback("Mask", get_mouse_points) | ||
39 | + | ||
40 | +while True: | ||
41 | +    cv2.imshow("Mask", mask_im) | ||
42 | +    cv2.waitKey(1) | ||
43 | +    if len(mouse_pts) == 6: | ||
44 | +        cv2.destroyWindow("Mask") | ||
45 | +        break | ||
46 | + | ||
47 | +points = mouse_pts | ||
48 | +print(points) | ||
49 | +print("----------------------------------------------------------------") | ||
50 | +print("Copy the following code and paste it in masks.cfg") | ||
51 | +print("----------------------------------------------------------------") | ||
52 | +name_points = ["a", "b", "c", "d", "e", "f"] | ||
53 | + | ||
54 | +mask_title = "[" + file_split[0] + "]" | ||
55 | +print(mask_title) | ||
56 | +print("template: ", filename) | ||
57 | +for i in range(len(mouse_pts)): | ||
58 | + name = ( | ||
59 | + "mask_" | ||
60 | + + name_points[i] | ||
61 | + + ": " | ||
62 | + + str(mouse_pts[i][0]) | ||
63 | + + "," | ||
64 | + + str(mouse_pts[i][1]) | ||
65 | + ) | ||
66 | + print(name) | ||
67 | + | ||
68 | +cv2.imwrite(new_filename, mask_im) |
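The printed block pastes straight into masks.cfg; for a template named cloth.png it comes out roughly like this (the script prints the absolute path picked in the dialog, and the coordinates below are illustrative click positions):

[cloth]
template:  masks/templates/cloth.png
mask_a: 21,97
mask_b: 307,22
mask_c: 600,99
mask_d: 25,322
mask_e: 295,470
mask_f: 600,323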
Code/MaskTheFace/utils/read_cfg.py
0 → 100644
1 | +# Author: Aqeel Anwar(ICSRL) | ||
2 | +# Created: 9/20/2019, 12:43 PM | ||
3 | +# Email: aqeel.anwar@gatech.edu | ||
4 | + | ||
5 | +from configparser import ConfigParser | ||
6 | +from dotmap import DotMap | ||
7 | + | ||
8 | + | ||
9 | +def ConvertIfStringIsInt(input_string): | ||
10 | + try: | ||
11 | + float(input_string) | ||
12 | + | ||
13 | + try: | ||
14 | + if int(input_string) == float(input_string): | ||
15 | + return int(input_string) | ||
16 | + else: | ||
17 | + return float(input_string) | ||
18 | + except ValueError: | ||
19 | + return float(input_string) | ||
20 | + | ||
21 | + except ValueError: | ||
22 | + return input_string | ||
23 | + | ||
24 | + | ||
25 | +def read_cfg(config_filename="masks/masks.cfg", mask_type="surgical", verbose=False): | ||
26 | + parser = ConfigParser() | ||
27 | + parser.optionxform = str | ||
28 | + parser.read(config_filename) | ||
29 | + cfg = DotMap() | ||
30 | + section_name = mask_type | ||
31 | + | ||
32 | + if verbose: | ||
33 | + hyphens = "-" * int((80 - len(config_filename)) / 2) | ||
34 | + print(hyphens + " " + config_filename + " " + hyphens) | ||
35 | + | ||
36 | + # for section_name in parser.sections(): | ||
37 | + | ||
38 | + if verbose: | ||
39 | + print("[" + section_name + "]") | ||
40 | + for name, value in parser.items(section_name): | ||
41 | + value = ConvertIfStringIsInt(value) | ||
42 | + if name != "template": | ||
43 | + cfg[name] = tuple(int(s) for s in value.split(",")) | ||
44 | + else: | ||
45 | + cfg[name] = value | ||
46 | + spaces = " " * (30 - len(name)) | ||
47 | + if verbose: | ||
48 | + print(name + ":" + spaces + str(cfg[name])) | ||
49 | + | ||
50 | + return cfg |
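A minimal usage sketch ('surgical' is the function's own default section; keys come back on a DotMap, so attribute access works):

cfg = read_cfg(config_filename="masks/masks.cfg", mask_type="surgical", verbose=True)
print(cfg.template)  # template path string
print(cfg.mask_a)    # (x, y) tuple parsed from a "mask_a: x,y" entry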
1 | +{ | ||
2 | + "cells": [ | ||
3 | + { | ||
4 | + "cell_type": "code", | ||
5 | + "execution_count": 18, | ||
6 | + "metadata": {}, | ||
7 | + "outputs": [], | ||
8 | + "source": [ | ||
9 | + "import torch\n", | ||
10 | + "import torch.nn as nn\n", | ||
11 | + "from torch.nn import Parameter\n", | ||
12 | + "import torch.nn.functional as F\n", | ||
13 | + "from torchvision import transforms as tf\n", | ||
14 | + "import torch.utils.data as data\n", | ||
15 | + "\n", | ||
16 | + "import os\n", | ||
17 | + "import cv2\n", | ||
18 | + "import functools\n", | ||
19 | + "import numpy as np\n", | ||
20 | + "from PIL import Image\n", | ||
21 | + "import matplotlib.pyplot as plt" | ||
22 | + ] | ||
23 | + }, | ||
24 | + { | ||
25 | + "cell_type": "code", | ||
26 | + "execution_count": 1, | ||
27 | + "metadata": {}, | ||
28 | + "outputs": [], | ||
29 | + "source": [ | ||
30 | + "from models import vgg19" | ||
31 | + ] | ||
32 | + }, | ||
33 | + { | ||
34 | + "cell_type": "code", | ||
35 | + "execution_count": 14, | ||
36 | + "metadata": {}, | ||
37 | + "outputs": [], | ||
38 | + "source": [ | ||
39 | + "model = vgg19(pretrained=True).features[:-2]\n", | ||
40 | + "\n", | ||
41 | + "model = model.eval()" | ||
42 | + ] | ||
43 | + }, | ||
44 | + { | ||
45 | + "cell_type": "code", | ||
46 | + "execution_count": 15, | ||
47 | + "metadata": {}, | ||
48 | + "outputs": [ | ||
49 | + { | ||
50 | + "data": { | ||
51 | + "text/plain": [ | ||
52 | + "Sequential(\n", | ||
53 | + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
54 | + " (1): ReLU(inplace=True)\n", | ||
55 | + " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
56 | + " (3): ReLU(inplace=True)\n", | ||
57 | + " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
58 | + " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
59 | + " (6): ReLU(inplace=True)\n", | ||
60 | + " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
61 | + " (8): ReLU(inplace=True)\n", | ||
62 | + " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
63 | + " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
64 | + " (11): ReLU(inplace=True)\n", | ||
65 | + " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
66 | + " (13): ReLU(inplace=True)\n", | ||
67 | + " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
68 | + " (15): ReLU(inplace=True)\n", | ||
69 | + " (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
70 | + " (17): ReLU(inplace=True)\n", | ||
71 | + " (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
72 | + " (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
73 | + " (20): ReLU(inplace=True)\n", | ||
74 | + " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
75 | + " (22): ReLU(inplace=True)\n", | ||
76 | + " (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
77 | + " (24): ReLU(inplace=True)\n", | ||
78 | + " (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
79 | + " (26): ReLU(inplace=True)\n", | ||
80 | + " (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | ||
81 | + " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
82 | + " (29): ReLU(inplace=True)\n", | ||
83 | + " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
84 | + " (31): ReLU(inplace=True)\n", | ||
85 | + " (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
86 | + " (33): ReLU(inplace=True)\n", | ||
87 | + " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | ||
88 | + ")" | ||
89 | + ] | ||
90 | + }, | ||
91 | + "execution_count": 15, | ||
92 | + "metadata": {}, | ||
93 | + "output_type": "execute_result" | ||
94 | + } | ||
95 | + ], | ||
96 | + "source": [ | ||
97 | + "model" | ||
98 | + ] | ||
99 | + }, | ||
100 | + { | ||
101 | + "cell_type": "code", | ||
102 | + "execution_count": 9, | ||
103 | + "metadata": {}, | ||
104 | + "outputs": [], | ||
105 | + "source": [ | ||
106 | + "img = torch.rand(4,3,256,256)" | ||
107 | + ] | ||
108 | + }, | ||
109 | + { | ||
110 | + "cell_type": "code", | ||
111 | + "execution_count": 10, | ||
112 | + "metadata": {}, | ||
113 | + "outputs": [ | ||
114 | + { | ||
115 | + "data": { | ||
116 | + "text/plain": [ | ||
117 | + "torch.Size([4, 512, 8, 8])" | ||
118 | + ] | ||
119 | + }, | ||
120 | + "execution_count": 10, | ||
121 | + "metadata": {}, | ||
122 | + "output_type": "execute_result" | ||
123 | + } | ||
124 | + ], | ||
125 | + "source": [ | ||
126 | + "out = model(img)\n", | ||
127 | + "out.shape" | ||
128 | + ] | ||
129 | + }, | ||
130 | + { | ||
131 | + "cell_type": "code", | ||
132 | + "execution_count": 19, | ||
133 | + "metadata": {}, | ||
134 | + "outputs": [], | ||
135 | + "source": [ | ||
136 | + "class GatedConv2d(nn.Module):\n", | ||
137 | + " def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, activation = 'lrelu', norm = 'in'):\n", | ||
138 | + " super(GatedConv2d, self).__init__()\n", | ||
139 | + " self.pad = nn.ZeroPad2d(padding)\n", | ||
140 | + " if norm is not None:\n", | ||
141 | + " self.norm = nn.InstanceNorm2d(out_channels)\n", | ||
142 | + " else:\n", | ||
143 | + " self.norm = None\n", | ||
144 | + " \n", | ||
145 | + " if activation == 'tanh':\n", | ||
146 | + " self.activation = nn.Tanh()\n", | ||
147 | + " else:\n", | ||
148 | + " self.activation = nn.LeakyReLU(0.2, inplace = True)\n", | ||
149 | + " \n", | ||
150 | + " \n", | ||
151 | + " self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)\n", | ||
152 | + " self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)\n", | ||
153 | + " self.sigmoid = torch.nn.Sigmoid()\n", | ||
154 | + " \n", | ||
155 | + " def forward(self, x):\n", | ||
156 | + " x = self.pad(x)\n", | ||
157 | + " conv = self.conv2d(x)\n", | ||
158 | + " mask = self.mask_conv2d(x)\n", | ||
159 | + " gated_mask = self.sigmoid(mask)\n", | ||
160 | + " x = conv * gated_mask\n", | ||
161 | + " if self.norm:\n", | ||
162 | + " x = self.norm(x)\n", | ||
163 | + " if self.activation:\n", | ||
164 | + " x = self.activation(x)\n", | ||
165 | + " return x\n", | ||
166 | + "\n", | ||
167 | + "class TransposeGatedConv2d(nn.Module):\n", | ||
168 | + " def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, norm=None, scale_factor = 2):\n", | ||
169 | + " super(TransposeGatedConv2d, self).__init__()\n", | ||
170 | + " # Initialize the conv scheme\n", | ||
171 | + " self.scale_factor = scale_factor\n", | ||
172 | + " self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, norm=norm)\n", | ||
173 | + " \n", | ||
174 | + " def forward(self, x):\n", | ||
175 | + " x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')\n", | ||
176 | + " x = self.gated_conv2d(x)\n", | ||
177 | + " return x" | ||
178 | + ] | ||
179 | + }, | ||
180 | + { | ||
181 | + "cell_type": "code", | ||
182 | + "execution_count": 20, | ||
183 | + "metadata": {}, | ||
184 | + "outputs": [], | ||
185 | + "source": [ | ||
186 | + "class GatedGenerator(nn.Module):\n", | ||
187 | + " def __init__(self, in_channels=4, latent_channels=64, out_channels=3):\n", | ||
188 | + " super(GatedGenerator, self).__init__()\n", | ||
189 | + " self.coarse = nn.Sequential(\n", | ||
190 | + " # encoder\n", | ||
191 | + " GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),\n", | ||
192 | + " GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),\n", | ||
193 | + " GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),\n", | ||
194 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),\n", | ||
195 | + " # Bottleneck\n", | ||
196 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
197 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
198 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),\n", | ||
199 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),\n", | ||
200 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),\n", | ||
201 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),\n", | ||
202 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
203 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
204 | + " # decoder\n", | ||
205 | + " TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),\n", | ||
206 | + " GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),\n", | ||
207 | + " TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),\n", | ||
208 | + " GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)\n", | ||
209 | + " )\n", | ||
210 | + " self.refinement = nn.Sequential(\n", | ||
211 | + " # encoder\n", | ||
212 | + " GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),\n", | ||
213 | + " GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),\n", | ||
214 | + " GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),\n", | ||
215 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),\n", | ||
216 | + " # Bottleneck\n", | ||
217 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
218 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
219 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),\n", | ||
220 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),\n", | ||
221 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),\n", | ||
222 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),\n", | ||
223 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
224 | + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n", | ||
225 | + " # decoder\n", | ||
226 | + " TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),\n", | ||
227 | + " GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),\n", | ||
228 | + " TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),\n", | ||
229 | + " GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)\n", | ||
230 | + " )\n", | ||
231 | + " \n", | ||
232 | + " def forward(self, img, mask):\n", | ||
233 | + " # img: entire img\n", | ||
234 | + " # mask: 1 for mask region; 0 for unmask region\n", | ||
235 | + " # 1 - mask: unmask\n", | ||
236 | + " # img * (1 - mask): ground truth unmask region\n", | ||
237 | + " # Coarse\n", | ||
238 | + " \n", | ||
239 | + " first_masked_img = img * (1 - mask) + mask\n", | ||
240 | + " first_in = torch.cat((first_masked_img, mask), 1) # in: [B, 4, H, W]\n", | ||
241 | + " first_out = self.coarse(first_in) # out: [B, 3, H, W]\n", | ||
242 | + " # Refinement\n", | ||
243 | + " second_masked_img = img * (1 - mask) + first_out * mask\n", | ||
244 | + " second_in = torch.cat((second_masked_img, mask), 1) # in: [B, 4, H, W]\n", | ||
245 | + " second_out = self.refinement(second_in) # out: [B, 3, H, W]\n", | ||
246 | + " return first_out, second_out" | ||
247 | + ] | ||
248 | + }, | ||
249 | + { | ||
250 | + "cell_type": "code", | ||
251 | + "execution_count": 21, | ||
252 | + "metadata": {}, | ||
253 | + "outputs": [], | ||
254 | + "source": [ | ||
255 | + "class NLayerDiscriminator(nn.Module):\n", | ||
256 | + " def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):\n", | ||
257 | + " super(NLayerDiscriminator, self).__init__()\n", | ||
258 | + " if type(norm_layer) == functools.partial:\n", | ||
259 | + " use_bias = norm_layer.func == nn.InstanceNorm2d\n", | ||
260 | + " else:\n", | ||
261 | + " use_bias = norm_layer == nn.InstanceNorm2d\n", | ||
262 | + "\n", | ||
263 | + " kw = 4\n", | ||
264 | + " padw = 1\n", | ||
265 | + " sequence = [\n", | ||
266 | + " nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n", | ||
267 | + " nn.LeakyReLU(0.2, True)\n", | ||
268 | + " ]\n", | ||
269 | + "\n", | ||
270 | + " nf_mult = 1\n", | ||
271 | + " nf_mult_prev = 1\n", | ||
272 | + " for n in range(1, n_layers):\n", | ||
273 | + " nf_mult_prev = nf_mult\n", | ||
274 | + " nf_mult = min(2**n, 8)\n", | ||
275 | + " sequence += [\n", | ||
276 | + " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n", | ||
277 | + " kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n", | ||
278 | + " norm_layer(ndf * nf_mult),\n", | ||
279 | + " nn.LeakyReLU(0.2, True)\n", | ||
280 | + " ]\n", | ||
281 | + "\n", | ||
282 | + " nf_mult_prev = nf_mult\n", | ||
283 | + " nf_mult = min(2**n_layers, 8)\n", | ||
284 | + " sequence += [\n", | ||
285 | + " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n", | ||
286 | + " kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n", | ||
287 | + " norm_layer(ndf * nf_mult),\n", | ||
288 | + " nn.LeakyReLU(0.2, True)\n", | ||
289 | + " ]\n", | ||
290 | + "\n", | ||
291 | + " sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]\n", | ||
292 | + "\n", | ||
293 | + " if use_sigmoid:\n", | ||
294 | + " sequence += [nn.Sigmoid()]\n", | ||
295 | + "\n", | ||
296 | + " self.model = nn.Sequential(*sequence)\n", | ||
297 | + "\n", | ||
298 | + " def forward(self, input):\n", | ||
299 | + " return self.model(input)" | ||
300 | + ] | ||
301 | + }, | ||
302 | + { | ||
303 | + "cell_type": "code", | ||
304 | + "execution_count": 22, | ||
305 | + "metadata": {}, | ||
306 | + "outputs": [], | ||
307 | + "source": [ | ||
308 | + "class PerceptualNet(nn.Module):\n", | ||
309 | + " def __init__(self):\n", | ||
310 | + " super(PerceptualNet, self).__init__()\n", | ||
311 | + " self.features = nn.Sequential(\n", | ||
312 | + " nn.Conv2d(3, 64, 3, 1, 1),\n", | ||
313 | + " nn.ReLU(inplace = True),\n", | ||
314 | + " nn.Conv2d(64, 64, 3, 1, 1),\n", | ||
315 | + " nn.ReLU(inplace = True),\n", | ||
316 | + " nn.MaxPool2d(2, 2),\n", | ||
317 | + " nn.Conv2d(64, 128, 3, 1, 1),\n", | ||
318 | + " nn.ReLU(inplace = True),\n", | ||
319 | + " nn.Conv2d(128, 128, 3, 1, 1),\n", | ||
320 | + " nn.ReLU(inplace = True),\n", | ||
321 | + " nn.MaxPool2d(2, 2),\n", | ||
322 | + " nn.Conv2d(128, 256, 3, 1, 1),\n", | ||
323 | + " nn.ReLU(inplace = True),\n", | ||
324 | + " nn.Conv2d(256, 256, 3, 1, 1),\n", | ||
325 | + " nn.ReLU(inplace = True),\n", | ||
326 | + " nn.Conv2d(256, 256, 3, 1, 1),\n", | ||
327 | + " nn.MaxPool2d(2, 2),\n", | ||
328 | + " nn.Conv2d(256, 512, 3, 1, 1),\n", | ||
329 | + " nn.ReLU(inplace = True),\n", | ||
330 | + " nn.Conv2d(512, 512, 3, 1, 1),\n", | ||
331 | + " nn.ReLU(inplace = True),\n", | ||
332 | + " nn.Conv2d(512, 512, 3, 1, 1)\n", | ||
333 | + " )\n", | ||
334 | + "\n", | ||
335 | + " def forward(self, x):\n", | ||
336 | + " x = self.features(x)\n", | ||
337 | + " return x" | ||
338 | + ] | ||
339 | + }, | ||
340 | + { | ||
341 | + "cell_type": "code", | ||
342 | + "execution_count": 6, | ||
343 | + "metadata": {}, | ||
344 | + "outputs": [], | ||
345 | + "source": [ | ||
346 | + "class GANLoss(nn.Module):\n", | ||
347 | + " def __init__(self, target_real_label=1.0, target_fake_label=0.0):\n", | ||
348 | + " super(GANLoss, self).__init__()\n", | ||
349 | + " self.register_buffer('real_label', torch.tensor(target_real_label))\n", | ||
350 | + " self.register_buffer('fake_label', torch.tensor(target_fake_label))\n", | ||
351 | + " self.loss = nn.BCELoss()\n", | ||
352 | + "\n", | ||
353 | + " def get_target_tensor(self, input, target_is_real):\n", | ||
354 | + " if target_is_real:\n", | ||
355 | + " target_tensor = self.real_label\n", | ||
356 | + " else:\n", | ||
357 | + " target_tensor = self.fake_label\n", | ||
358 | + " return target_tensor.expand_as(input)\n", | ||
359 | + "\n", | ||
360 | + " def __call__(self, input, target_is_real):\n", | ||
361 | + " target_tensor = self.get_target_tensor(input, target_is_real)\n", | ||
362 | + " return self.loss(input, target_tensor)" | ||
363 | + ] | ||
364 | + }, | ||
365 | + { | ||
366 | + "cell_type": "code", | ||
367 | + "execution_count": 7, | ||
368 | + "metadata": {}, | ||
369 | + "outputs": [], | ||
370 | + "source": [ | ||
371 | + "class InpaintDataset(data.Dataset):\n", | ||
372 | + " def __init__(self, img_dir):\n", | ||
373 | + " self.img_dir = img_dir\n", | ||
374 | + " self.load_images()\n", | ||
375 | + " \n", | ||
376 | + " def load_images(self):\n", | ||
377 | + " self.fns =[]\n", | ||
378 | + " img_paths = sorted(os.listdir(self.img_dir))\n", | ||
379 | + " for path in img_paths:\n", | ||
380 | + " self.fns.append(os.path.join(self.img_dir, path))\n", | ||
381 | + " \n", | ||
382 | + " def __getitem__(self, index):\n", | ||
383 | + " img_path = self.fns[index]\n", | ||
384 | + " img = cv2.imread(img_path)\n", | ||
385 | + " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", | ||
386 | + " img = cv2.resize(img, (256,256))\n", | ||
387 | + " \n", | ||
388 | + " mask = self.random_ff_mask()\n", | ||
389 | + " img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()\n", | ||
390 | + " mask = torch.from_numpy(mask.astype(np.float32)).contiguous()\n", | ||
391 | + " return img, mask\n", | ||
392 | + " \n", | ||
393 | + " def collate_fn(self, batch):\n", | ||
394 | + " imgs = torch.stack([i[0] for i in batch])\n", | ||
395 | + " masks = torch.stack([i[1] for i in batch])\n", | ||
396 | + " return {\n", | ||
397 | + " 'imgs': imgs,\n", | ||
398 | + " 'masks': masks\n", | ||
399 | + " }\n", | ||
400 | + " \n", | ||
401 | + " def __len__(self):\n", | ||
402 | + " return len(self.fns)\n", | ||
403 | + " \n", | ||
404 | + " def random_ff_mask(self, shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):\n", | ||
405 | + " \"\"\"Generate a random free form mask with configuration.\n", | ||
406 | + " Args:\n", | ||
407 | + " config: Config should have configuration including IMG_SHAPES,\n", | ||
408 | + " VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.\n", | ||
409 | + " Returns:\n", | ||
410 | + " tuple: (top, left, height, width)\n", | ||
411 | + " \"\"\"\n", | ||
412 | + " height = shape\n", | ||
413 | + " width = shape\n", | ||
414 | + " mask = np.zeros((height, width), np.float32)\n", | ||
415 | + " times = np.random.randint(times)\n", | ||
416 | + " for i in range(times):\n", | ||
417 | + " start_x = np.random.randint(width)\n", | ||
418 | + " start_y = np.random.randint(height)\n", | ||
419 | + " for j in range(1 + np.random.randint(5)):\n", | ||
420 | + " angle = 0.01 + np.random.randint(max_angle)\n", | ||
421 | + " if i % 2 == 0:\n", | ||
422 | + " angle = 2 * 3.1415926 - angle\n", | ||
423 | + " length = 10 + np.random.randint(max_len)\n", | ||
424 | + " brush_w = 5 + np.random.randint(max_width)\n", | ||
425 | + " end_x = (start_x + length * np.sin(angle)).astype(np.int32)\n", | ||
426 | + " end_y = (start_y + length * np.cos(angle)).astype(np.int32)\n", | ||
427 | + " cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)\n", | ||
428 | + " start_x, start_y = end_x, end_y\n", | ||
429 | + " return mask.reshape((1, ) + mask.shape).astype(np.float32)" | ||
430 | + ] | ||
431 | + }, | ||
432 | + { | ||
433 | + "cell_type": "code", | ||
434 | + "execution_count": null, | ||
435 | + "metadata": {}, | ||
436 | + "outputs": [], | ||
437 | + "source": [ | ||
438 | + "dataset = InpaintDataset(img_dir='datasets/places365standard_easyformat/places365_standard/train/waterfall')\n", | ||
439 | + "dataloader = data.DataLoader(dataset, batch_size=4, collate_fn = dataset.collate_fn)" | ||
440 | + ] | ||
441 | + }, | ||
442 | + { | ||
443 | + "cell_type": "code", | ||
444 | + "execution_count": null, | ||
445 | + "metadata": {}, | ||
446 | + "outputs": [], | ||
447 | + "source": [ | ||
448 | + "for batch in dataloader:\n", | ||
449 | + " imgs = batch['imgs']\n", | ||
450 | + " masks = batch['masks']\n", | ||
451 | + " \n", | ||
452 | + " break" | ||
453 | + ] | ||
454 | + }, | ||
455 | + { | ||
456 | + "cell_type": "code", | ||
457 | + "execution_count": 8, | ||
458 | + "metadata": {}, | ||
459 | + "outputs": [], | ||
460 | + "source": [ | ||
461 | + "device = torch.device('cuda')" | ||
462 | + ] | ||
463 | + }, | ||
464 | + { | ||
465 | + "cell_type": "code", | ||
466 | + "execution_count": 23, | ||
467 | + "metadata": {}, | ||
468 | + "outputs": [ | ||
469 | + { | ||
470 | + "ename": "NameError", | ||
471 | + "evalue": "name 'GANLoss' is not defined", | ||
472 | + "output_type": "error", | ||
473 | + "traceback": [ | ||
474 | + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | ||
475 | + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", | ||
476 | + "\u001b[1;32m<ipython-input-23-bdcc75eef256>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mmodel_D\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mNLayerDiscriminator\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0muse_sigmoid\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mmodel_P\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPerceptualNet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mcriterion_adv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mGANLoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[0mcriterion_rec\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMSELoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mcriterion_per\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mL1Loss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | ||
477 | + "\u001b[1;31mNameError\u001b[0m: name 'GANLoss' is not defined" | ||
478 | + ] | ||
479 | + } | ||
480 | + ], | ||
481 | + "source": [ | ||
482 | + "model_G = GatedGenerator()\n", | ||
483 | + "model_D = NLayerDiscriminator(3, use_sigmoid=True)\n", | ||
484 | + "model_P = PerceptualNet()\n", | ||
485 | + "criterion_adv = GANLoss()\n", | ||
486 | + "criterion_rec = nn.MSELoss()\n", | ||
487 | + "criterion_per = nn.L1Loss()\n", | ||
488 | + "optimizer_D = torch.optim.Adam(model_D.parameters(), lr=1e-4)\n", | ||
489 | + "optimizer_G = torch.optim.Adam(model_G.parameters(), lr=1e-4)" | ||
490 | + ] | ||
491 | + }, | ||
492 | + { | ||
493 | + "cell_type": "code", | ||
494 | + "execution_count": 24, | ||
495 | + "metadata": {}, | ||
496 | + "outputs": [], | ||
497 | + "source": [ | ||
498 | + "torch.save({\n", | ||
499 | + " 'D': model_D.state_dict(),\n", | ||
500 | + " 'G': model_G.state_dict()\n", | ||
501 | + "}, 's.pth')" | ||
502 | + ] | ||
503 | + }, | ||
504 | + { | ||
505 | + "cell_type": "code", | ||
506 | + "execution_count": null, | ||
507 | + "metadata": {}, | ||
508 | + "outputs": [], | ||
509 | + "source": [ | ||
510 | + "def count_params(model):\n", | ||
511 | + " return sum(p.numel() for p in model.parameters() if p.requires_grad)" | ||
512 | + ] | ||
513 | + }, | ||
514 | + { | ||
515 | + "cell_type": "code", | ||
516 | + "execution_count": null, | ||
517 | + "metadata": {}, | ||
518 | + "outputs": [], | ||
519 | + "source": [ | ||
520 | + "print(count_params(model_G))\n", | ||
521 | + "print(count_params(model_D))\n", | ||
522 | + "print(count_params(model_P))" | ||
523 | + ] | ||
524 | + }, | ||
525 | + { | ||
526 | + "cell_type": "code", | ||
527 | + "execution_count": 10, | ||
528 | + "metadata": {}, | ||
529 | + "outputs": [], | ||
530 | + "source": [ | ||
531 | + "def random_ff_mask(shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):\n", | ||
532 | + " \"\"\"Generate a random free form mask with configuration.\n", | ||
533 | + " Args:\n", | ||
534 | + " config: Config should have configuration including IMG_SHAPES,\n", | ||
535 | + " VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.\n", | ||
536 | + " Returns:\n", | ||
537 | + " tuple: (top, left, height, width)\n", | ||
538 | + " \"\"\"\n", | ||
539 | + " height = shape\n", | ||
540 | + " width = shape\n", | ||
541 | + " mask = np.zeros((height, width), np.float32)\n", | ||
542 | + " times = np.random.randint(times)\n", | ||
543 | + " for i in range(times):\n", | ||
544 | + " start_x = np.random.randint(width)\n", | ||
545 | + " start_y = np.random.randint(height)\n", | ||
546 | + " for j in range(1 + np.random.randint(5)):\n", | ||
547 | + " angle = 0.01 + np.random.randint(max_angle)\n", | ||
548 | + " if i % 2 == 0:\n", | ||
549 | + " angle = 2 * 3.1415926 - angle\n", | ||
550 | + " length = 10 + np.random.randint(max_len)\n", | ||
551 | + " brush_w = 5 + np.random.randint(max_width)\n", | ||
552 | + " end_x = (start_x + length * np.sin(angle)).astype(np.int32)\n", | ||
553 | + " end_y = (start_y + length * np.cos(angle)).astype(np.int32)\n", | ||
554 | + " cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)\n", | ||
555 | + " start_x, start_y = end_x, end_y\n", | ||
556 | + " return mask.reshape((1, ) + mask.shape).astype(np.float32)" | ||
557 | + ] | ||
558 | + }, | ||
559 | + { | ||
560 | + "cell_type": "code", | ||
561 | + "execution_count": 11, | ||
562 | + "metadata": {}, | ||
563 | + "outputs": [], | ||
564 | + "source": [ | ||
565 | + "img = cv2.imread('datasets/places365standard_easyformat/places365_standard/train/waterfall/00000003.jpg')\n", | ||
566 | + "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", | ||
567 | + "img = cv2.resize(img, (256, 256))\n", | ||
568 | + "img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()\n", | ||
569 | + "img_tensor = img.unsqueeze(0)\n", | ||
570 | + "mask = random_ff_mask()\n", | ||
571 | + "mask = torch.from_numpy(mask).contiguous().unsqueeze(0)" | ||
572 | + ] | ||
573 | + }, | ||
574 | + { | ||
575 | + "cell_type": "code", | ||
576 | + "execution_count": null, | ||
577 | + "metadata": {}, | ||
578 | + "outputs": [], | ||
579 | + "source": [ | ||
580 | + "def visualize(img):\n", | ||
581 | + " np_img = img.squeeze(0).detach().cpu().numpy()\n", | ||
582 | + " return np_img.transpose(1, 2, 0)" | ||
583 | + ] | ||
584 | + }, | ||
585 | + { | ||
586 | + "cell_type": "code", | ||
587 | + "execution_count": null, | ||
588 | + "metadata": {}, | ||
589 | + "outputs": [], | ||
590 | + "source": [ | ||
591 | + "plt.imshow(visualize(first_out_wholeimg))" | ||
592 | + ] | ||
593 | + }, | ||
594 | + { | ||
595 | + "cell_type": "code", | ||
596 | + "execution_count": 12, | ||
597 | + "metadata": {}, | ||
598 | + "outputs": [], | ||
599 | + "source": [ | ||
600 | + "first_out, second_out = model_G(img_tensor, mask)\n", | ||
601 | + "\n", | ||
602 | + "first_out_wholeimg = img_tensor * (1 - mask) + first_out * mask \n", | ||
603 | + "second_out_wholeimg = img_tensor * (1 - mask) + second_out * mask" | ||
604 | + ] | ||
605 | + }, | ||
606 | + { | ||
607 | + "cell_type": "code", | ||
608 | + "execution_count": 13, | ||
609 | + "metadata": {}, | ||
610 | + "outputs": [], | ||
611 | + "source": [ | ||
612 | + "# Train discriminator\n", | ||
613 | + "optimizer_D.zero_grad()\n", | ||
614 | + "\n", | ||
615 | + "fake_D = model_D(second_out_wholeimg.detach())\n", | ||
616 | + "real_D = model_D(img_tensor)\n", | ||
617 | + "\n", | ||
618 | + "loss_fake_D = criterion_adv(fake_D, target_is_real=False)\n", | ||
619 | + "loss_real_D = criterion_adv(real_D, target_is_real=True)\n", | ||
620 | + "\n", | ||
621 | + "loss_D = (loss_fake_D + loss_real_D) *0.5\n", | ||
622 | + "\n", | ||
623 | + "loss_D.backward()\n", | ||
624 | + "optimizer_D.step()" | ||
625 | + ] | ||
626 | + }, | ||
627 | + { | ||
628 | + "cell_type": "code", | ||
629 | + "execution_count": 15, | ||
630 | + "metadata": {}, | ||
631 | + "outputs": [], | ||
632 | + "source": [ | ||
633 | + "# Train Generator\n", | ||
634 | + "\n", | ||
635 | + "optimizer_G.zero_grad()\n", | ||
636 | + "\n", | ||
637 | + "fake_D = model_D(second_out_wholeimg)\n", | ||
638 | + "G_loss = criterion_adv(fake_D, target_is_real=True)" | ||
639 | + ] | ||
640 | + }, | ||
641 | + { | ||
642 | + "cell_type": "code", | ||
643 | + "execution_count": 16, | ||
644 | + "metadata": {}, | ||
645 | + "outputs": [], | ||
646 | + "source": [ | ||
647 | + "# Reconstruction loss\n", | ||
648 | + "\n", | ||
649 | + "loss_rec_1 = criterion_rec(first_out_wholeimg, img_tensor)\n", | ||
650 | + "loss_rec_2 = criterion_rec(second_out_wholeimg, img_tensor)" | ||
651 | + ] | ||
652 | + }, | ||
653 | + { | ||
654 | + "cell_type": "code", | ||
655 | + "execution_count": 17, | ||
656 | + "metadata": {}, | ||
657 | + "outputs": [], | ||
658 | + "source": [ | ||
659 | + "# Perceptual loss\n", | ||
660 | + "\n", | ||
661 | + "img_featuremaps = model_P(img_tensor) # feature maps\n", | ||
662 | + "second_out_wholeimg_featuremaps = model_P(second_out_wholeimg)\n", | ||
663 | + "\n", | ||
664 | + "loss_P = criterion_per(second_out_wholeimg_featuremaps, img_featuremaps)" | ||
665 | + ] | ||
666 | + }, | ||
667 | + { | ||
668 | + "cell_type": "code", | ||
669 | + "execution_count": 18, | ||
670 | + "metadata": {}, | ||
671 | + "outputs": [ | ||
672 | + { | ||
673 | + "ename": "NameError", | ||
674 | + "evalue": "name 'lambda_G' is not defined", | ||
675 | + "output_type": "error", | ||
676 | + "traceback": [ | ||
677 | + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | ||
678 | + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", | ||
679 | + "\u001b[1;32m<ipython-input-18-dbfe5f51e2fc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlambda_G\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mG_loss\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_rec_1\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_rec_1\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_rec_2\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_rec_2\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_per\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_P\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0moptimizer_G\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | ||
680 | + "\u001b[1;31mNameError\u001b[0m: name 'lambda_G' is not defined" | ||
681 | + ] | ||
682 | + } | ||
683 | + ], | ||
684 | + "source": [ | ||
685 | + "loss = lambda_G * G_loss + lambda_rec_1 * loss_rec_1 + lambda_rec_2 * loss_rec_2 + lambda_per * loss_P\n", | ||
686 | + "loss.backward()\n", | ||
687 | + "optimizer_G.step()" | ||
688 | + ] | ||
689 | + } | ||
690 | + ], | ||
691 | + "metadata": { | ||
692 | + "kernelspec": { | ||
693 | + "display_name": "Python 3", | ||
694 | + "language": "python", | ||
695 | + "name": "python3" | ||
696 | + }, | ||
697 | + "language_info": { | ||
698 | + "codemirror_mode": { | ||
699 | + "name": "ipython", | ||
700 | + "version": 3 | ||
701 | + }, | ||
702 | + "file_extension": ".py", | ||
703 | + "mimetype": "text/x-python", | ||
704 | + "name": "python", | ||
705 | + "nbconvert_exporter": "python", | ||
706 | + "pygments_lexer": "ipython3", | ||
707 | + "version": "3.7.6" | ||
708 | + } | ||
709 | + }, | ||
710 | + "nbformat": 4, | ||
711 | + "nbformat_minor": 4 | ||
712 | +} |
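The last cell of the notebook fails with a NameError because the loss weights are never defined in it. The repo's configs do define them, so a cell like this, run beforehand, makes the combined loss executable (values copied from configs/facemask.yaml; places365.yaml uses 0.3/10.0/10.0/1.0 instead):

# Loss weights taken from configs/facemask.yaml
lambda_G = 1.0
lambda_rec_1 = 100.0
lambda_rec_2 = 100.0
lambda_per = 10.0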
Code/image-inpainting/configs/__init__.py
0 → 100644
1 | +from .config import Config | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Code/image-inpainting/configs/config.py
0 → 100644
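The body of configs/config.py is not displayed in this diff. A minimal sketch consistent with how Config is used elsewhere in the PR (Config('./configs/facemask.yaml') in infer.py, attribute access such as cfg.img_size in the datasets) could be:

import yaml

class Config:
    # Sketch under assumptions: each YAML below nests its keys under a
    # top-level "settings" block, and callers read them back as attributes.
    def __init__(self, yaml_path):
        with open(yaml_path) as f:
            settings = yaml.safe_load(f)["settings"]
        for key, value in settings.items():
            setattr(self, key, value)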
Code/image-inpainting/configs/facemask.yaml
0 → 100644
1 | +settings: | ||
2 | + root_dir: "./datasets/celeba/images/" | ||
3 | + checkpoint_path: "weights" | ||
4 | + sample_folder: "sample" | ||
5 | + | ||
6 | + cuda: True | ||
7 | + lr: 0.001 | ||
8 | + batch_size: 2 | ||
9 | + num_workers: 4 | ||
10 | + | ||
11 | + step_iters: [10000, 15000, 20000] | ||
12 | + gamma: 0.1 | ||
13 | + | ||
14 | + d_num_layers: 3 | ||
15 | + | ||
16 | + visualize_per_iter: 500 | ||
17 | + save_per_iter: 500 | ||
18 | + print_per_iter: 10 | ||
19 | + num_epochs: 100 | ||
20 | + | ||
21 | + lambda_G: 1.0 | ||
22 | + lambda_rec_1: 100.0 | ||
23 | + lambda_rec_2: 100.0 | ||
24 | + lambda_per: 10.0 | ||
25 | + | ||
26 | + img_size: 512 |
Code/image-inpainting/configs/places365.yaml
0 → 100644
1 | +settings: | ||
2 | + root_dir: "./datasets/places365_10classes" | ||
3 | + checkpoint_path: "/content/drive/MyDrive/weights/Places365 Inpainting/phase 3" | ||
4 | + sample_folder: "/content/drive/MyDrive/results/Places365 Inpainting/phase 3" | ||
5 | + | ||
6 | + cuda: True | ||
7 | + lr: 0.0001 | ||
8 | + batch_size: 8 | ||
9 | + num_workers: 4 | ||
10 | + | ||
11 | + step_iters: [50000, 75000, 100000] | ||
12 | + gamma: 0.1 | ||
13 | + | ||
14 | + d_num_layers: 3 | ||
15 | + | ||
16 | + visualize_per_iter: 500 | ||
17 | + save_per_iter: 500 | ||
18 | + print_per_iter: 10 | ||
19 | + num_epochs: 100 | ||
20 | + | ||
21 | + lambda_G: 0.3 | ||
22 | + lambda_rec_1: 10.0 | ||
23 | + lambda_rec_2: 10.0 | ||
24 | + lambda_per: 1.0 | ||
25 | + | ||
26 | + img_size: 256 | ||
27 | + max_angle: 4 | ||
28 | + max_len: 50 | ||
29 | + max_width: 30 | ||
30 | + times: 15 |
Code/image-inpainting/configs/segm.yaml
0 → 100644
1 | +settings: | ||
2 | + root_dir: "./datasets/celeba/images/" | ||
3 | + train_anns: "./datasets/celeba/annotations/train.csv" | ||
4 | + val_anns: "./datasets/celeba/annotations/val.csv" | ||
5 | + | ||
6 | + checkpoint_path: "weights" #"/content/drive/MyDrive/weights/Places365 Inpainting/unet/phase 1" | ||
7 | + sample_folder: "sample" #"/content/drive/MyDrive/results/Places365 Inpainting/unet/phase 1" | ||
8 | + | ||
9 | + cuda: True | ||
10 | + lr: 0.001 | ||
11 | + batch_size: 4 | ||
12 | + num_workers: 4 | ||
13 | + | ||
14 | + step_iters: [50000, 75000, 100000] | ||
15 | + gamma: 0.1 | ||
16 | + | ||
17 | + visualize_per_iter: 1000 | ||
18 | + save_per_iter: 1000 | ||
19 | + print_per_iter: 10 | ||
20 | + num_epochs: 100 | ||
21 | + | ||
22 | + img_size: 512 |
Code/image-inpainting/datasets/__init__.py
0 → 100644
1 | +import os | ||
2 | +import csv | ||
3 | + | ||
4 | +# List both folders once and sort them so image/mask rows pair up | ||
5 | +# deterministically (calling os.listdir inside the loop was O(n^2) and | ||
6 | +# relied on an unspecified directory order). | ||
7 | +masked = sorted(os.listdir("./datasets/celeba/images/celeba512_30k_masked")) | ||
8 | +binary = sorted(os.listdir("./datasets/celeba/images/celeba512_30k_binary")) | ||
9 | + | ||
10 | +with open("./datasets/celeba/annotations/train.csv", "w", newline="") as f: | ||
11 | +    wr = csv.writer(f) | ||
12 | +    wr.writerow(["_", "img_name", "mask_name"]) | ||
13 | +    for i in range(23304): | ||
14 | +        wr.writerow([i, "celeba512_30k_masked/" + masked[i], "celeba512_30k_binary/" + binary[i]]) | ||
15 | + | ||
16 | +with open("./datasets/celeba/annotations/val.csv", "w", newline="") as f: | ||
17 | +    wr = csv.writer(f) | ||
18 | +    wr.writerow(["_", "img_name", "mask_name"]) | ||
19 | +    for i in range(23304, 29131): | ||
20 | +        wr.writerow([i, "celeba512_30k_masked/" + masked[i], "celeba512_30k_binary/" + binary[i]]) |
Code/image-inpainting/datasets/dataset.py
0 → 100644
1 | +import os | ||
2 | +import torch | ||
3 | +import torch.nn as nn | ||
4 | +import torch.utils.data as data | ||
5 | +import cv2 | ||
6 | +import numpy as np | ||
7 | +from tqdm import tqdm | ||
8 | + | ||
9 | +class Places365Dataset(data.Dataset): | ||
10 | + def __init__(self, cfg): | ||
11 | + self.root_dir = cfg.root_dir | ||
12 | + self.cfg = cfg | ||
13 | + self.load_images() | ||
14 | + | ||
15 | + def load_images(self): | ||
16 | + self.fns =[] | ||
17 | + idx = 0 | ||
18 | + img_paths = os.listdir(self.root_dir) | ||
19 | + for cls_id in img_paths: | ||
20 | + paths = os.path.join(self.root_dir, cls_id) | ||
21 | + file_paths = os.listdir(paths) | ||
22 | + for img_name in file_paths: | ||
23 | + filename = os.path.join(paths, img_name) | ||
24 | + self.fns.append(filename) | ||
25 | + | ||
26 | + def __getitem__(self, index): | ||
27 | + img_path = self.fns[index] | ||
28 | + img = cv2.imread(img_path) | ||
29 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
30 | + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size)) | ||
31 | + | ||
32 | + mask = self.random_ff_mask( | ||
33 | + shape = self.cfg.img_size, | ||
34 | + max_angle = self.cfg.max_angle, | ||
35 | + max_len = self.cfg.max_len, | ||
36 | + max_width = self.cfg.max_width, | ||
37 | + times = self.cfg.times) | ||
38 | + | ||
39 | + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous() | ||
40 | + mask = torch.from_numpy(mask.astype(np.float32)).contiguous() | ||
41 | + | ||
42 | + return img, mask | ||
43 | + | ||
44 | + def collate_fn(self, batch): | ||
45 | + imgs = torch.stack([i[0] for i in batch]) | ||
46 | + masks = torch.stack([i[1] for i in batch]) | ||
47 | + return { | ||
48 | + 'imgs': imgs, | ||
49 | + 'masks': masks | ||
50 | + } | ||
51 | + | ||
52 | + def __len__(self): | ||
53 | + return len(self.fns) | ||
54 | + | ||
55 | + def random_ff_mask(self, shape = 256 , max_angle = 4, max_len = 50, max_width = 20, times = 15): | ||
56 | + """Generate a random free form mask with configuration. | ||
57 | + Args: | ||
58 | + config: Config should have configuration including IMG_SHAPES, | ||
59 | + VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH. | ||
60 | + Returns: | ||
61 | + tuple: (top, left, height, width) | ||
62 | + """ | ||
63 | + height = shape | ||
64 | + width = shape | ||
65 | + mask = np.zeros((height, width), np.float32) | ||
66 | + times = np.random.randint(10, times) | ||
67 | + for i in range(times): | ||
68 | + start_x = np.random.randint(width) | ||
69 | + start_y = np.random.randint(height) | ||
70 | + for j in range(1 + np.random.randint(5)): | ||
71 | + angle = 0.01 + np.random.randint(max_angle) | ||
72 | + if i % 2 == 0: | ||
73 | + angle = 2 * 3.1415926 - angle | ||
74 | + length = 10 + np.random.randint(max_len) | ||
75 | + brush_w = 5 + np.random.randint(max_width) | ||
76 | + end_x = (start_x + length * np.sin(angle)).astype(np.int32) | ||
77 | + end_y = (start_y + length * np.cos(angle)).astype(np.int32) | ||
78 | + cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w) | ||
79 | + start_x, start_y = end_x, end_y | ||
80 | + return mask.reshape((1, ) + mask.shape).astype(np.float32) | ||
81 | + | ||
82 | + | ||
83 | +class FacemaskDataset(data.Dataset): | ||
84 | + def __init__(self, cfg): | ||
85 | + self.root_dir = cfg.root_dir | ||
86 | + self.cfg = cfg | ||
87 | + | ||
88 | + self.mask_folder = os.path.join(self.root_dir, 'celeba512_30k_binary') | ||
89 | + self.img_folder = os.path.join(self.root_dir, 'celeba512_30k') | ||
90 | + self.load_images() | ||
91 | + | ||
92 | + def load_images(self): | ||
93 | + self.fns = [] | ||
94 | + idx = 0 | ||
95 | + img_paths = sorted(os.listdir(self.img_folder)) | ||
96 | + for img_name in img_paths: | ||
97 | + mask_name = img_name.split('.')[0]+'_binary.jpg' | ||
98 | + img_path = os.path.join(self.img_folder, img_name) | ||
99 | + mask_path = os.path.join(self.mask_folder, mask_name) | ||
100 | + if os.path.isfile(mask_path): | ||
101 | + self.fns.append([img_path, mask_path]) | ||
102 | + | ||
103 | + def __getitem__(self, index): | ||
104 | + img_path, mask_path = self.fns[index] | ||
105 | + img = cv2.imread(img_path) | ||
106 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
107 | + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size)) | ||
108 | + | ||
109 | + | ||
110 | + mask = cv2.imread(mask_path, 0) | ||
111 | + | ||
112 | + mask[mask>0]=1.0 | ||
113 | + mask = np.expand_dims(mask, axis=0) | ||
114 | + | ||
115 | + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous() | ||
116 | + mask = torch.from_numpy(mask.astype(np.float32)).contiguous() | ||
117 | + return img, mask | ||
118 | + | ||
119 | + def collate_fn(self, batch): | ||
120 | + imgs = torch.stack([i[0] for i in batch]) | ||
121 | + masks = torch.stack([i[1] for i in batch]) | ||
122 | + return { | ||
123 | + 'imgs': imgs, | ||
124 | + 'masks': masks | ||
125 | + } | ||
126 | + | ||
127 | + def __len__(self): | ||
128 | + return len(self.fns) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
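Mirroring the DataLoader call in the notebook, either dataset class is consumed together with its own collate_fn, e.g.:

import torch.utils.data as data

# Assumes cfg was loaded from configs/facemask.yaml (see the Config sketch above).
dataset = FacemaskDataset(cfg)
loader = data.DataLoader(dataset, batch_size=cfg.batch_size, collate_fn=dataset.collate_fn)
for batch in loader:
    imgs, masks = batch['imgs'], batch['masks']  # [B,3,H,W] floats, [B,1,H,W] binary
    break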
1 | +import os | ||
2 | +import torch | ||
3 | +import torch.nn as nn | ||
4 | +import torch.utils.data as data | ||
5 | +import cv2 | ||
6 | +import numpy as np | ||
7 | +from tqdm import tqdm | ||
8 | +import pandas as pd | ||
9 | +from PIL import Image | ||
10 | + | ||
11 | + | ||
12 | +class FacemaskSegDataset(data.Dataset): | ||
13 | + def __init__(self, cfg, train=True): | ||
14 | + self.root_dir = cfg.root_dir | ||
15 | + self.cfg = cfg | ||
16 | + self.train = train | ||
17 | + | ||
18 | + if self.train: | ||
19 | + self.df = pd.read_csv(cfg.train_anns) | ||
20 | + else: | ||
21 | + self.df = pd.read_csv(cfg.val_anns) | ||
22 | + | ||
23 | + self.load_images() | ||
24 | + | ||
25 | + def load_images(self): | ||
26 | + self.fns = [] | ||
27 | + for idx, rows in self.df.iterrows(): | ||
28 | + _, img_name, mask_name = rows | ||
29 | + img_path = os.path.join(self.root_dir, img_name) | ||
30 | + mask_path = os.path.join(self.root_dir, mask_name) | ||
31 | + img_path = img_path.replace("\\", "/") | ||
32 | + mask_path = mask_path.replace("\\", "/") | ||
33 | + if os.path.isfile(mask_path): | ||
34 | + self.fns.append([img_path, mask_path]) | ||
35 | + | ||
36 | + def __getitem__(self, index): | ||
37 | + img_path, mask_path = self.fns[index] | ||
38 | + img = cv2.imread(img_path) | ||
39 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
40 | + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size)) | ||
41 | + mask = cv2.imread(mask_path, 0) | ||
42 | + mask[mask > 0] = 1.0 | ||
43 | + mask = np.expand_dims(mask, axis=0) | ||
44 | + | ||
45 | + img = ( | ||
46 | + torch.from_numpy(img.astype(np.float32) / 255.0) | ||
47 | + .permute(2, 0, 1) | ||
48 | + .contiguous() | ||
49 | + ) | ||
50 | + mask = torch.from_numpy(mask.astype(np.float32)).contiguous() | ||
51 | + return img, mask | ||
52 | + | ||
53 | + def collate_fn(self, batch): | ||
54 | + imgs = torch.stack([i[0] for i in batch]) | ||
55 | + masks = torch.stack([i[1] for i in batch]) | ||
56 | + return {"imgs": imgs, "masks": masks} | ||
57 | + | ||
58 | + def __len__(self): | ||
59 | + return len(self.fns) |
Code/image-inpainting/infer.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +from torchvision.utils import save_image | ||
4 | + | ||
5 | +import numpy as np | ||
6 | +from PIL import Image | ||
7 | +import cv2 | ||
8 | +from models import UNetSemantic, GatedGenerator | ||
9 | +import argparse | ||
10 | +from configs import Config | ||
11 | + | ||
12 | +class Predictor(): | ||
13 | + def __init__(self, cfg): | ||
14 | + self.cfg = cfg | ||
15 | + self.device = torch.device('cuda:0' if cfg.cuda else 'cpu') | ||
16 | + self.masking = UNetSemantic().to(self.device) | ||
17 | +        self.masking.load_state_dict(torch.load('weights/model_segm_19_135000.pth', map_location='cpu')) | ||
18 | + #self.masking.eval() | ||
19 | + | ||
20 | + self.inpaint = GatedGenerator().to(self.device) | ||
21 | + self.inpaint.load_state_dict(torch.load('weights/model_6_100000.pth', map_location='cpu')['G']) | ||
22 | + self.inpaint.eval() | ||
23 | + | ||
24 | + def save_image(self, img_list, save_img_path, nrow): | ||
25 | + img_list = [i.clone().cpu() for i in img_list] | ||
26 | + imgs = torch.stack(img_list, dim=1) | ||
27 | + imgs = imgs.view(-1, *list(imgs.size())[2:]) | ||
28 | + save_image(imgs, save_img_path, nrow = nrow) | ||
29 | + print(f"Save image to {save_img_path}") | ||
30 | + | ||
31 | + def predict(self, image, outpath='sample/results.png'): | ||
32 | + outpath=f'sample/results_{image}.png' | ||
33 | + image = 'sample/'+image | ||
34 | + img = cv2.imread(image+'_masked.jpg') | ||
35 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
36 | + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size)) | ||
37 | + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous() | ||
38 | + img = img.unsqueeze(0).to(self.device) | ||
39 | + | ||
40 | + img_ori = cv2.imread(image+'.jpg') | ||
41 | + img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB) | ||
42 | + img_ori = cv2.resize(img_ori, (self.cfg.img_size, self.cfg.img_size)) | ||
43 | + img_ori = torch.from_numpy(img_ori.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous() | ||
44 | + img_ori = img_ori.unsqueeze(0) | ||
45 | + with torch.no_grad(): | ||
46 | + outputs = self.masking(img) | ||
47 | + _, out = self.inpaint(img, outputs) | ||
48 | + inpaint = img * (1 - outputs) + out * outputs | ||
49 | + masks = img * (1 - outputs) + outputs #torch.cat([outputs, outputs, outputs], dim=1) | ||
50 | + | ||
51 | + | ||
52 | + | ||
53 | + self.save_image([img, masks, inpaint, img_ori], outpath, nrow=4) | ||
54 | + | ||
55 | + | ||
56 | + | ||
57 | + | ||
58 | +if __name__ == '__main__': | ||
59 | +    parser = argparse.ArgumentParser(description='Inference with a trained model') | ||
60 | +    parser.add_argument('--image', default=None, type=str, help='base name of the image in sample/ (without extension)') | ||
61 | +    parser.add_argument('config', default='config', type=str, help='name of the config file in configs/ (without .yaml)') | ||
62 | + args = parser.parse_args() | ||
63 | + | ||
64 | + config = Config(f'./configs/{args.config}.yaml') | ||
65 | + | ||
66 | + | ||
67 | + model = Predictor(config) | ||
68 | + model.predict(args.image) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
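Given the file layout hard-coded in predict() above, a typical invocation looks like (the image and config names are hypothetical):

    python infer.py config --image face1

which reads sample/face1_masked.jpg and sample/face1.jpg, runs the segmenter and the inpainter, and writes a 4-column grid (masked input, predicted mask overlay, inpainted result, ground truth) to sample/results_face1.png.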
Code/image-inpainting/loggers/__init__.py
0 → 100644
1 | +from .loggers import * | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Code/image-inpainting/loggers/loggers.py
0 → 100644
1 | +import os | ||
2 | +import numpy as np | ||
3 | +from torch.utils.tensorboard import SummaryWriter | ||
4 | +from datetime import datetime | ||
5 | + | ||
6 | +class Logger(): | ||
7 | + """ | ||
8 | + Logger for Tensorboard visualization | ||
9 | + :param log_dir: Path to save checkpoint | ||
10 | + """ | ||
11 | + def __init__(self, log_dir=None): | ||
12 | + self.log_dir = log_dir | ||
13 | + if self.log_dir is None: | ||
14 | + self.log_dir = os.path.join('loggers/runs',datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) | ||
15 | + if not os.path.exists(self.log_dir): | ||
16 | +            os.makedirs(self.log_dir)  # create intermediate directories (loggers/runs) as well | ||
17 | + self.writer = SummaryWriter(log_dir=self.log_dir) | ||
18 | + self.iters = {} | ||
19 | + | ||
20 | + def write(self, tags, values): | ||
21 | + """ | ||
22 | + Write a log to specified directory | ||
23 | + :param tags: (str) tag for log | ||
24 | + :param values: (number) value for corresponding tag | ||
25 | + """ | ||
26 | +        if not isinstance(tags, list): | ||
27 | +            tags = [tags]  # wrap a single tag; list(str) would split it into characters | ||
28 | +        if not isinstance(values, list): | ||
29 | +            values = [values]  # wrap a single value; list(number) would raise TypeError | ||
30 | + | ||
31 | +        for tag, value in zip(tags, values): | ||
32 | + if tag not in self.iters.keys(): | ||
33 | + self.iters[tag] = 0 | ||
34 | + self.writer.add_scalar(tag, value, self.iters[tag]) | ||
35 | + self.iters[tag] += 1 | ||
36 | + | ||
37 | + |
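A short usage sketch (tag names are illustrative). Each tag keeps its own step counter in self.iters, so scalars logged at different frequencies stay aligned on their own axes, and write accepts either lists or single tag/value pairs:

    logger = Logger()                          # defaults to loggers/runs/<timestamp>
    logger.write(['loss/G', 'loss/D'], [1.32, 0.71])
    logger.write('metric/dice', 0.85)          # single tag and value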
Code/image-inpainting/loggers/runs/your training logs are here, use tensorboard to view .txt
0 → 100644
File mode changed
Code/image-inpainting/losses/__init__.py
0 → 100644
Code/image-inpainting/losses/adversarial.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | + | ||
4 | +class GANLoss(nn.Module): | ||
5 | + def __init__(self, target_real_label=1.0, target_fake_label=0.0): | ||
6 | + super(GANLoss, self).__init__() | ||
7 | + self.register_buffer('real_label', torch.tensor(target_real_label)) | ||
8 | + self.register_buffer('fake_label', torch.tensor(target_fake_label)) | ||
9 | + self.loss = nn.MSELoss() | ||
10 | + | ||
11 | + def get_target_tensor(self, input, target_is_real): | ||
12 | + if target_is_real: | ||
13 | + target_tensor = self.real_label | ||
14 | + else: | ||
15 | + target_tensor = self.fake_label | ||
16 | + return target_tensor.expand_as(input) | ||
17 | + | ||
18 | + def __call__(self, input, target_is_real): | ||
19 | + target_tensor = self.get_target_tensor(input, target_is_real).to(input.device) | ||
20 | + return self.loss(input, target_tensor) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
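A sketch of the intended usage (D, real_images and fake_images are placeholders): with nn.MSELoss underneath, this is the least-squares GAN objective, and the 0.9/0.1 targets that trainer.py passes below add label smoothing:

    criterion = GANLoss(target_real_label=0.9, target_fake_label=0.1)
    d_real = D(real_images)                      # discriminator scores, any shape
    d_fake = D(fake_images.detach())             # detach so only D is updated
    loss_D = 0.5 * (criterion(d_real, True) + criterion(d_fake, False))
    loss_G = criterion(D(fake_images), True)     # generator tries to look real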
Code/image-inpainting/losses/dice.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | + | ||
4 | + | ||
5 | +class DiceLoss(nn.Module): | ||
6 | + """ | ||
7 | + Dice loss of binary class | ||
8 | + Args: | ||
9 | + smooth: A float number to smooth loss, and avoid NaN error, default: 1 | ||
10 | + p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 | ||
11 | + predict: A tensor of shape [N, *] | ||
12 | + target: A tensor of shape same with predict | ||
13 | + reduction: Reduction method to apply, return mean over batch if 'mean', | ||
14 | + return sum if 'sum', return a tensor of shape [N,] if 'none' | ||
15 | + Returns: | ||
16 | + Loss tensor according to arg reduction | ||
17 | + Raise: | ||
18 | + Exception if unexpected reduction | ||
19 | + """ | ||
20 | + def __init__(self, smooth=1, p=2, reduction='mean'): | ||
21 | + super(DiceLoss, self).__init__() | ||
22 | + self.smooth = smooth | ||
23 | + self.p = p | ||
24 | + self.reduction = reduction | ||
25 | + | ||
26 | + def forward(self, predict, target): | ||
27 | + assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" | ||
28 | + predict = predict.contiguous().view(predict.shape[0], -1) | ||
29 | + target = target.contiguous().view(target.shape[0], -1) | ||
30 | + | ||
31 | + num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth | ||
32 | + den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth | ||
33 | + | ||
34 | + loss = 1 - num / den | ||
35 | + | ||
36 | + if self.reduction == 'mean': | ||
37 | + return loss.mean() | ||
38 | + elif self.reduction == 'sum': | ||
39 | + return loss.sum() | ||
40 | + elif self.reduction == 'none': | ||
41 | + return loss | ||
42 | + else: | ||
43 | + raise Exception('Unexpected reduction {}'.format(self.reduction)) |
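A quick worked check of the formula above. Note that this variant keeps a single intersection term in the numerator instead of the conventional factor of 2, so even a perfect prediction tends toward a loss of 0.5 rather than 0 as the mask grows; it still decreases monotonically with overlap, which is what training needs:

    import torch
    predict = torch.tensor([[0.9, 0.1, 0.8, 0.7]])
    target  = torch.tensor([[1.0, 0.0, 1.0, 1.0]])
    # num  = (0.9 + 0.0 + 0.8 + 0.7) + 1 (smooth)           = 3.4
    # den  = (0.81+0.01+0.64+0.49) + (1+0+1+1) + 1 (smooth) = 5.95
    # loss = 1 - 3.4 / 5.95                                 ≈ 0.4286
    print(DiceLoss()(predict, target))                      # tensor(0.4286)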
Code/image-inpainting/losses/ssim.py
0 → 100644
1 | +#Source: https://github.com/Po-Hsun-Su/pytorch-ssim.git | ||
2 | + | ||
3 | +import torch | ||
4 | +import torch.nn.functional as F | ||
5 | + | ||
6 | +import numpy as np | ||
7 | +from math import exp | ||
8 | + | ||
9 | +def gaussian(window_size, sigma): | ||
10 | + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) | ||
11 | + return gauss/gauss.sum() | ||
12 | + | ||
13 | +def create_window(window_size, channel): | ||
14 | + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) | ||
15 | + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | ||
16 | +    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()  # the Variable wrapper is obsolete in modern PyTorch | ||
17 | + return window | ||
18 | + | ||
19 | +def _ssim(img1, img2, window, window_size, channel, size_average = True): | ||
20 | + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) | ||
21 | + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) | ||
22 | + | ||
23 | + mu1_sq = mu1.pow(2) | ||
24 | + mu2_sq = mu2.pow(2) | ||
25 | + mu1_mu2 = mu1*mu2 | ||
26 | + | ||
27 | + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq | ||
28 | + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq | ||
29 | + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 | ||
30 | + | ||
31 | + C1 = 0.01**2 | ||
32 | + C2 = 0.03**2 | ||
33 | + | ||
34 | + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | ||
35 | + | ||
36 | + if size_average: | ||
37 | + return ssim_map.mean() | ||
38 | + else: | ||
39 | + return ssim_map.mean(1).mean(1).mean(1) | ||
40 | + | ||
41 | +class SSIM(torch.nn.Module): | ||
42 | + def __init__(self, window_size = 11, size_average = True): | ||
43 | + super(SSIM, self).__init__() | ||
44 | + self.window_size = window_size | ||
45 | + self.size_average = size_average | ||
46 | + self.channel = 1 | ||
47 | + self.window = create_window(window_size, self.channel) | ||
48 | + | ||
49 | + def forward(self, img1, img2): | ||
50 | + (_, channel, _, _) = img1.size() | ||
51 | + | ||
52 | + if channel == self.channel and self.window.data.type() == img1.data.type(): | ||
53 | + window = self.window | ||
54 | + else: | ||
55 | + window = create_window(self.window_size, channel) | ||
56 | + | ||
57 | + if img1.is_cuda: | ||
58 | + window = window.cuda(img1.get_device()) | ||
59 | + window = window.type_as(img1) | ||
60 | + | ||
61 | + self.window = window | ||
62 | + self.channel = channel | ||
63 | + | ||
64 | + | ||
65 | + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) | ||
66 | + | ||
67 | +def ssim(img1, img2, window_size = 11, size_average = True): | ||
68 | + (_, channel, _, _) = img1.size() | ||
69 | + window = create_window(window_size, channel) | ||
70 | + | ||
71 | + if img1.is_cuda: | ||
72 | + window = window.cuda(img1.get_device()) | ||
73 | + window = window.type_as(img1) | ||
74 | + | ||
75 | + return _ssim(img1, img2, window, window_size, channel, size_average) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
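A short sketch of the module in use (tensors are placeholders); trainer.py below turns the similarity into a loss as 1 - SSIM:

    import torch
    ssim_module = SSIM(window_size=11)
    x = torch.rand(2, 3, 64, 64)
    y = torch.rand(2, 3, 64, 64)
    print(ssim_module(x, x))       # ~1.0 for identical images
    loss = 1 - ssim_module(x, y)   # dissimilarity, used as a reconstruction term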
Code/image-inpainting/metrics/__init__.py
0 → 100644
Code/image-inpainting/metrics/dicecoeff.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import numpy as np | ||
4 | + | ||
5 | +class DiceScore(): | ||
6 | + def __init__(self, num_classes, ignore_index = None, eps=1e-6, thresh = 0.5): | ||
7 | + self.thresh = thresh | ||
8 | + self.num_classes = num_classes | ||
9 | + self.pred_type = "multi" if num_classes > 1 else "binary" | ||
10 | + | ||
11 | + if num_classes == 1: | ||
12 | + self.num_classes+=1 | ||
13 | + | ||
14 | + self.ignore_index = ignore_index | ||
15 | + self.eps = eps | ||
16 | + | ||
17 | + self.scores_list = np.zeros(self.num_classes) | ||
18 | + self.reset() | ||
19 | + | ||
20 | + def compute(self, outputs, targets): | ||
21 | + # outputs: (batch, num_classes, W, H) | ||
22 | + # targets: (batch, num_classes, W, H) | ||
23 | + | ||
24 | +        batch_size, _, h, w = outputs.shape  # tensors are (B, C, H, W) | ||
25 | + if len(targets.shape) == 3: | ||
26 | + targets = targets.unsqueeze(1) | ||
27 | + | ||
28 | + one_hot_targets = torch.zeros(batch_size, self.num_classes, h, w) | ||
29 | + one_hot_predicts = torch.zeros(batch_size, self.num_classes, h, w) | ||
30 | + | ||
31 | + if self.pred_type == 'binary': | ||
32 | + predicts = (outputs > self.thresh).float() | ||
33 | + elif self.pred_type =='multi': | ||
34 | + predicts = torch.argmax(outputs, dim=1).unsqueeze(1) | ||
35 | + | ||
36 | + one_hot_targets.scatter_(1, targets.long(), 1) | ||
37 | + one_hot_predicts.scatter_(1, predicts.long(), 1) | ||
38 | + | ||
39 | + for cl in range(self.num_classes): | ||
40 | + cl_pred = one_hot_predicts[:,cl,:,:] | ||
41 | + cl_target = one_hot_targets[:,cl,:,:] | ||
42 | + score = self.binary_compute(cl_pred, cl_target) | ||
43 | + self.scores_list[cl] += sum(score) | ||
44 | + | ||
45 | + | ||
46 | + def binary_compute(self, predict, target): | ||
47 | + # outputs: (batch, 1, W, H) | ||
48 | + # targets: (batch, 1, W, H) | ||
49 | + | ||
50 | + intersect = (predict * target).sum((-2,-1)) | ||
51 | + union = (predict + target).sum((-2,-1)) | ||
52 | + return (2. * intersect + self.eps) / (union +self.eps) | ||
53 | + | ||
54 | + def reset(self): | ||
55 | + self.scores_list = np.zeros(self.num_classes) | ||
56 | + self.sample_size = 0 | ||
57 | + | ||
58 | + def update(self, outputs, targets): | ||
59 | + self.sample_size += outputs.shape[0] | ||
60 | + self.compute(outputs, targets) | ||
61 | + | ||
62 | + def value(self): | ||
63 | + scores_each_class = self.scores_list / self.sample_size #mean over number of samples | ||
64 | + if self.pred_type == 'binary': | ||
65 | + scores = scores_each_class[1] # ignore background which is label 0 | ||
66 | + else: | ||
67 | + scores = sum(scores_each_class) / self.num_classes | ||
68 | + return np.round(scores, decimals=4) | ||
69 | + | ||
70 | + def summary(self): | ||
71 | + class_iou = self.scores_list / self.sample_size #mean | ||
72 | + | ||
73 | + print(f'{self.value()}') | ||
74 | + for i, x in enumerate(class_iou): | ||
75 | + print(f'\tClass {i}: {x:.4f}') | ||
76 | + | ||
77 | + def __str__(self): | ||
78 | + return f'Dice Score: {self.value()}' | ||
79 | + | ||
80 | +    def __len__(self): | ||
81 | +        return self.sample_size  # an int; calling len() on it would raise TypeError | ||
82 | + | ||
83 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Code/image-inpainting/metrics/pixelacc.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import numpy as np | ||
4 | + | ||
5 | +class PixelAccuracy(): | ||
6 | + def __init__(self, num_classes, ignore_index=None, eps=1e-6, thresh = 0.5): | ||
7 | + self.thresh = thresh | ||
8 | + self.num_classes = num_classes | ||
9 | + self.pred_type = "multi" if num_classes > 1 else "binary" | ||
10 | + | ||
11 | + if num_classes == 1: | ||
12 | + self.num_classes+=1 | ||
13 | + | ||
14 | + self.ignore_index = ignore_index | ||
15 | + self.eps = eps | ||
16 | + | ||
17 | + self.scores_list = np.zeros(self.num_classes) | ||
18 | + self.reset() | ||
19 | + | ||
20 | + def compute(self, outputs, targets): | ||
21 | + # outputs: (batch, num_classes, W, H) | ||
22 | + # targets: (batch, num_classes, W, H) | ||
23 | + | ||
24 | +        batch_size, _, h, w = outputs.shape  # tensors are (B, C, H, W) | ||
25 | + if len(targets.shape) == 3: | ||
26 | + targets = targets.unsqueeze(1) | ||
27 | + | ||
28 | + one_hot_targets = torch.zeros(batch_size, self.num_classes, h, w) | ||
29 | + one_hot_predicts = torch.zeros(batch_size, self.num_classes, h, w) | ||
30 | + | ||
31 | + if self.pred_type == 'binary': | ||
32 | + predicts = (outputs > self.thresh).float() | ||
33 | + elif self.pred_type =='multi': | ||
34 | + predicts = torch.argmax(outputs, dim=1).unsqueeze(1) | ||
35 | + | ||
36 | + one_hot_targets.scatter_(1, targets.long(), 1) | ||
37 | + one_hot_predicts.scatter_(1, predicts.long(), 1) | ||
38 | + | ||
39 | + for cl in range(self.num_classes): | ||
40 | + cl_pred = one_hot_predicts[:,cl,:,:] | ||
41 | + cl_target = one_hot_targets[:,cl,:,:] | ||
42 | + score = self.binary_compute(cl_pred, cl_target) | ||
43 | + self.scores_list[cl] += sum(score) | ||
44 | + | ||
45 | + def binary_compute(self, predict, target): | ||
46 | + # predict: (batch, 1, W, H) | ||
47 | + # targets: (batch, 1, W, H) | ||
48 | + | ||
49 | + correct = (predict == target).sum((-2,-1)) | ||
50 | + total = target.shape[-1] * target.shape[-2] | ||
51 | + return (correct + self.eps) *1.0 / (total +self.eps) | ||
52 | + | ||
53 | + def reset(self): | ||
54 | + self.scores_list = np.zeros(self.num_classes) | ||
55 | + self.sample_size = 0 | ||
56 | + | ||
57 | + def update(self, outputs, targets): | ||
58 | + self.sample_size += outputs.shape[0] | ||
59 | + self.compute(outputs, targets) | ||
60 | + | ||
61 | + def value(self): | ||
62 | + scores_each_class = self.scores_list / self.sample_size #mean over number of samples | ||
63 | + if self.pred_type == 'binary': | ||
64 | + scores = scores_each_class[1] # ignore background which is label 0 | ||
65 | + else: | ||
66 | + scores = sum(scores_each_class) / self.num_classes | ||
67 | + return np.round(scores, decimals=4) | ||
68 | + | ||
69 | + def summary(self): | ||
70 | + class_iou = self.scores_list / self.sample_size #mean | ||
71 | + | ||
72 | + print(f'{self.value()}') | ||
73 | + for i, x in enumerate(class_iou): | ||
74 | + print(f'\tClass {i}: {x:.4f}') | ||
75 | + | ||
76 | + def __str__(self): | ||
77 | + return f'Pixel Accuracy: {self.value()}' | ||
78 | + | ||
79 | +    def __len__(self): | ||
80 | +        return self.sample_size  # an int; calling len() on it would raise TypeError |
... | \ No newline at end of file | ... | \ No newline at end of file |
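Both metric classes share the same update/value/reset protocol, which is how unet_trainer.py consumes them; a minimal sketch with random binary masks (shapes are illustrative):

    import torch
    metrics = [DiceScore(num_classes=1), PixelAccuracy(num_classes=1)]
    outputs = torch.rand(4, 1, 64, 64)                    # sigmoid probabilities
    targets = (torch.rand(4, 1, 64, 64) > 0.5).float()    # binary ground truth
    for metric in metrics:
        metric.update(outputs, targets)
        print(metric)       # e.g. 'Dice Score: 0.5012', 'Pixel Accuracy: 0.5003'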
Code/image-inpainting/models/__init__.py
0 → 100644
Code/image-inpainting/models/networks.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +from torch.nn import Parameter | ||
4 | +import torch.nn.functional as F | ||
5 | +import torch.utils.data as data | ||
6 | +import functools | ||
7 | +from torchvision.models import vgg19, vgg16 | ||
8 | + | ||
9 | +class GatedConv2d(nn.Module): | ||
10 | + def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, activation = 'lrelu', norm = 'in'): | ||
11 | + super(GatedConv2d, self).__init__() | ||
12 | + self.pad = nn.ZeroPad2d(padding) | ||
13 | + if norm is not None: | ||
14 | + self.norm = nn.InstanceNorm2d(out_channels) | ||
15 | + else: | ||
16 | + self.norm = None | ||
17 | + | ||
18 | + if activation == 'tanh': | ||
19 | + self.activation = nn.Tanh() | ||
20 | + else: | ||
21 | + self.activation = nn.LeakyReLU(0.2, inplace = True) | ||
22 | + | ||
23 | + | ||
24 | + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation) | ||
25 | + self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation) | ||
26 | + self.sigmoid = torch.nn.Sigmoid() | ||
27 | + | ||
28 | + def forward(self, x): | ||
29 | + x = self.pad(x) | ||
30 | + conv = self.conv2d(x) | ||
31 | + mask = self.mask_conv2d(x) | ||
32 | + gated_mask = self.sigmoid(mask) | ||
33 | + x = conv * gated_mask | ||
34 | + if self.norm: | ||
35 | + x = self.norm(x) | ||
36 | + if self.activation: | ||
37 | + x = self.activation(x) | ||
38 | + return x | ||
39 | + | ||
40 | +class TransposeGatedConv2d(nn.Module): | ||
41 | + def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, norm=None, scale_factor = 2): | ||
42 | + super(TransposeGatedConv2d, self).__init__() | ||
43 | + # Initialize the conv scheme | ||
44 | + self.scale_factor = scale_factor | ||
45 | + self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, norm=norm) | ||
46 | + | ||
47 | + def forward(self, x): | ||
48 | + x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest') | ||
49 | + x = self.gated_conv2d(x) | ||
50 | + return x | ||
51 | + | ||
52 | + | ||
53 | +class GatedGenerator(nn.Module): | ||
54 | + def __init__(self, in_channels=4, latent_channels=64, out_channels=3): | ||
55 | + super(GatedGenerator, self).__init__() | ||
56 | + self.coarse = nn.Sequential( | ||
57 | + # encoder | ||
58 | + GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None), | ||
59 | + GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1), | ||
60 | + GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1), | ||
61 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1), | ||
62 | + # Bottleneck | ||
63 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
64 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
65 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2), | ||
66 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4), | ||
67 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8), | ||
68 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16), | ||
69 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
70 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
71 | + # decoder | ||
72 | + TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1), | ||
73 | + GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1), | ||
74 | + TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1), | ||
75 | + GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None) | ||
76 | + ) | ||
77 | + self.refinement = nn.Sequential( | ||
78 | + # encoder | ||
79 | + GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None), | ||
80 | + GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1), | ||
81 | + GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1), | ||
82 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1), | ||
83 | + # Bottleneck | ||
84 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
85 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
86 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2), | ||
87 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4), | ||
88 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8), | ||
89 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16), | ||
90 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
91 | + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1), | ||
92 | + # decoder | ||
93 | + TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1), | ||
94 | + GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1), | ||
95 | + TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1), | ||
96 | + GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None) | ||
97 | + ) | ||
98 | + | ||
99 | + def forward(self, img, mask): | ||
100 | + # img: entire img | ||
101 | + # mask: 1 for mask region; 0 for unmask region | ||
102 | + # 1 - mask: unmask | ||
103 | + # img * (1 - mask): ground truth unmask region | ||
104 | + # Coarse | ||
105 | + | ||
106 | + first_masked_img = img * (1 - mask) + mask | ||
107 | + first_in = torch.cat((first_masked_img, mask), 1) # in: [B, 4, H, W] | ||
108 | + first_out = self.coarse(first_in) # out: [B, 3, H, W] | ||
109 | + # Refinement | ||
110 | + second_masked_img = img * (1 - mask) + first_out * mask | ||
111 | + second_in = torch.cat((second_masked_img, mask), 1) # in: [B, 4, H, W] | ||
112 | + second_out = self.refinement(second_in) # out: [B, 3, H, W] | ||
113 | + return first_out, second_out | ||
114 | + | ||
115 | + | ||
116 | +class NLayerDiscriminator(nn.Module): | ||
117 | + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): | ||
118 | + super(NLayerDiscriminator, self).__init__() | ||
119 | + if type(norm_layer) == functools.partial: | ||
120 | + use_bias = norm_layer.func == nn.InstanceNorm2d | ||
121 | + else: | ||
122 | + use_bias = norm_layer == nn.InstanceNorm2d | ||
123 | + | ||
124 | + kw = 4 | ||
125 | + padw = 1 | ||
126 | + sequence = [ | ||
127 | + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), | ||
128 | + nn.LeakyReLU(0.2, True) | ||
129 | + ] | ||
130 | + | ||
131 | + nf_mult = 1 | ||
132 | + nf_mult_prev = 1 | ||
133 | + for n in range(1, n_layers): | ||
134 | + nf_mult_prev = nf_mult | ||
135 | + nf_mult = min(2**n, 8) | ||
136 | + sequence += [ | ||
137 | + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, | ||
138 | + kernel_size=kw, stride=2, padding=padw, bias=use_bias), | ||
139 | + norm_layer(ndf * nf_mult), | ||
140 | + nn.LeakyReLU(0.2, True) | ||
141 | + ] | ||
142 | + | ||
143 | + nf_mult_prev = nf_mult | ||
144 | + nf_mult = min(2**n_layers, 8) | ||
145 | + sequence += [ | ||
146 | + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, | ||
147 | + kernel_size=kw, stride=1, padding=padw, bias=use_bias), | ||
148 | + norm_layer(ndf * nf_mult), | ||
149 | + nn.LeakyReLU(0.2, True) | ||
150 | + ] | ||
151 | + | ||
152 | + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] | ||
153 | + | ||
154 | + if use_sigmoid: | ||
155 | + sequence += [nn.Sigmoid()] | ||
156 | + | ||
157 | + self.model = nn.Sequential(*sequence) | ||
158 | + | ||
159 | + def forward(self, input): | ||
160 | + return self.model(input) | ||
161 | + | ||
162 | +class PerceptualNet(nn.Module): | ||
163 | + # https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49 | ||
164 | + def __init__(self, name = "vgg19", resize=True): | ||
165 | + super(PerceptualNet, self).__init__() | ||
166 | + blocks = [] | ||
167 | + if name == "vgg19": | ||
168 | + blocks.append(vgg19(pretrained=True).features[:4].eval()) | ||
169 | + blocks.append(vgg19(pretrained=True).features[4:9].eval()) | ||
170 | + blocks.append(vgg19(pretrained=True).features[9:16].eval()) | ||
171 | + blocks.append(vgg19(pretrained=True).features[16:23].eval()) | ||
172 | + elif name == "vgg16": | ||
173 | + blocks.append(vgg16(pretrained=True).features[:4].eval()) | ||
174 | + blocks.append(vgg16(pretrained=True).features[4:9].eval()) | ||
175 | + blocks.append(vgg16(pretrained=True).features[9:16].eval()) | ||
176 | + blocks.append(vgg16(pretrained=True).features[16:23].eval()) | ||
177 | + else: | ||
178 | +            raise ValueError('wrong model name') | ||
179 | + | ||
180 | + for bl in blocks: | ||
181 | +            for p in bl.parameters(): | ||
182 | +                p.requires_grad = False  # freeze the VGG feature extractor | ||
183 | + self.blocks = torch.nn.ModuleList(blocks) | ||
184 | + self.transform = torch.nn.functional.interpolate | ||
185 | +        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))  # fixed ImageNet statistics, not trainable parameters | ||
186 | +        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) | ||
187 | + self.resize = resize | ||
188 | + | ||
189 | + def forward(self, inputs, targets): | ||
190 | + if inputs.shape[1] != 3: | ||
191 | + inputs = inputs.repeat(1, 3, 1, 1) | ||
192 | + targets = targets.repeat(1, 3, 1, 1) | ||
193 | + inputs = (inputs-self.mean) / self.std | ||
194 | + targets = (targets-self.mean) / self.std | ||
195 | + if self.resize: | ||
196 | + inputs = self.transform(inputs, mode='bilinear', size=(512, 512), align_corners=False) | ||
197 | + targets = self.transform(targets, mode='bilinear', size=(512, 512), align_corners=False) | ||
198 | + loss = 0.0 | ||
199 | + x = inputs | ||
200 | + y = targets | ||
201 | + for block in self.blocks: | ||
202 | + x = block(x) | ||
203 | + y = block(y) | ||
204 | + loss += torch.nn.functional.l1_loss(x, y) | ||
205 | + return loss | ||
206 | + | ||
207 | + | ||
208 | + |
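A shape sketch of the two-stage generator and the PatchGAN discriminator above (sizes are illustrative; spatial dimensions must be divisible by 4 because of the two stride-2 encoder convolutions):

    import torch
    G = GatedGenerator()                                  # 4 input channels: RGB + mask
    D = NLayerDiscriminator(input_nc=3, n_layers=3)
    img  = torch.rand(1, 3, 256, 256)
    mask = (torch.rand(1, 1, 256, 256) > 0.9).float()     # 1 marks the hole to fill
    coarse, refined = G(img, mask)                        # both [1, 3, 256, 256]
    composite = img * (1 - mask) + refined * mask         # keep known pixels, paste the fill
    print(D(composite).shape)                             # [1, 1, 30, 30] patch-wise scores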
Code/image-inpainting/models/unet.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +from torch.nn import Parameter | ||
4 | +import torch.nn.functional as F | ||
5 | +import torch.utils.data as data | ||
6 | +import functools | ||
7 | + | ||
8 | + | ||
9 | +class conv_block(nn.Module): | ||
10 | + """ | ||
11 | + Convolution Block | ||
12 | + """ | ||
13 | + def __init__(self, in_ch, out_ch): | ||
14 | + super(conv_block, self).__init__() | ||
15 | + | ||
16 | + self.conv = nn.Sequential( | ||
17 | + nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), | ||
18 | + nn.BatchNorm2d(out_ch), | ||
19 | + nn.ReLU(inplace=True), | ||
20 | + nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), | ||
21 | + nn.BatchNorm2d(out_ch), | ||
22 | + nn.ReLU(inplace=True)) | ||
23 | + | ||
24 | + def forward(self, x): | ||
25 | + | ||
26 | + x = self.conv(x) | ||
27 | + return x | ||
28 | + | ||
29 | + | ||
30 | +class up_conv(nn.Module): | ||
31 | + """ | ||
32 | + Up Convolution Block | ||
33 | + """ | ||
34 | + def __init__(self, in_ch, out_ch): | ||
35 | + super(up_conv, self).__init__() | ||
36 | + self.up = nn.Sequential( | ||
37 | + nn.Upsample(scale_factor=2), | ||
38 | + nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), | ||
39 | + nn.BatchNorm2d(out_ch), | ||
40 | + nn.ReLU(inplace=True) | ||
41 | + ) | ||
42 | + | ||
43 | + def forward(self, x): | ||
44 | + x = self.up(x) | ||
45 | + return x | ||
46 | + | ||
47 | + | ||
48 | + | ||
49 | +class Recurrent_block(nn.Module): | ||
50 | + """ | ||
51 | + Recurrent Block for R2Unet_CNN | ||
52 | + """ | ||
53 | + def __init__(self, out_ch, t=2): | ||
54 | + super(Recurrent_block, self).__init__() | ||
55 | + | ||
56 | + self.t = t | ||
57 | + self.out_ch = out_ch | ||
58 | + self.conv = nn.Sequential( | ||
59 | + nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), | ||
60 | + nn.BatchNorm2d(out_ch), | ||
61 | + nn.ReLU(inplace=True) | ||
62 | + ) | ||
63 | + | ||
64 | +    def forward(self, x): | ||
65 | +        for i in range(self.t): | ||
66 | +            if i == 0: | ||
67 | +                out = self.conv(x) | ||
68 | +            out = self.conv(x + out)  # recurrent connection: re-inject the input at every step | ||
69 | +        return out | ||
70 | + | ||
71 | + | ||
72 | +class RRCNN_block(nn.Module): | ||
73 | + """ | ||
74 | + Recurrent Residual Convolutional Neural Network Block | ||
75 | + """ | ||
76 | + def __init__(self, in_ch, out_ch, t=2): | ||
77 | + super(RRCNN_block, self).__init__() | ||
78 | + | ||
79 | + self.RCNN = nn.Sequential( | ||
80 | + Recurrent_block(out_ch, t=t), | ||
81 | + Recurrent_block(out_ch, t=t) | ||
82 | + ) | ||
83 | + self.Conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0) | ||
84 | + | ||
85 | + def forward(self, x): | ||
86 | + x1 = self.Conv(x) | ||
87 | + x2 = self.RCNN(x1) | ||
88 | + out = x1 + x2 | ||
89 | + return out | ||
90 | + | ||
91 | +class Attention_block(nn.Module): | ||
92 | + """ | ||
93 | + Attention Block | ||
94 | + """ | ||
95 | + | ||
96 | + def __init__(self, F_g, F_l, F_int): | ||
97 | + super(Attention_block, self).__init__() | ||
98 | + | ||
99 | + self.W_g = nn.Sequential( | ||
100 | + nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), | ||
101 | + nn.BatchNorm2d(F_int) | ||
102 | + ) | ||
103 | + | ||
104 | + self.W_x = nn.Sequential( | ||
105 | + nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), | ||
106 | + nn.BatchNorm2d(F_int) | ||
107 | + ) | ||
108 | + | ||
109 | + self.psi = nn.Sequential( | ||
110 | + nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), | ||
111 | + nn.BatchNorm2d(1), | ||
112 | + nn.Sigmoid() | ||
113 | + ) | ||
114 | + | ||
115 | + self.relu = nn.ReLU(inplace=True) | ||
116 | + | ||
117 | + def forward(self, g, x): | ||
118 | + g1 = self.W_g(g) | ||
119 | + x1 = self.W_x(x) | ||
120 | + psi = self.relu(g1 + x1) | ||
121 | + psi = self.psi(psi) | ||
122 | + out = x * psi | ||
123 | + return out | ||
124 | + | ||
125 | +class SE_Block(nn.Module): | ||
126 | + "credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4" | ||
127 | + def __init__(self, c, r=16): | ||
128 | + super().__init__() | ||
129 | + self.squeeze = nn.AdaptiveAvgPool2d(1) | ||
130 | + self.excitation = nn.Sequential( | ||
131 | + nn.Linear(c, c // r, bias=False), | ||
132 | + nn.ReLU(inplace=True), | ||
133 | + nn.Linear(c // r, c, bias=False), | ||
134 | + nn.Sigmoid() | ||
135 | + ) | ||
136 | + | ||
137 | + def forward(self, x): | ||
138 | + bs, c, _, _ = x.shape | ||
139 | + y = self.squeeze(x).view(bs, c) | ||
140 | + y = self.excitation(y).view(bs, c, 1, 1) | ||
141 | + return x * y.expand_as(x) | ||
142 | + | ||
143 | +class AtrousConv(nn.Module): | ||
144 | + def __init__(self, in_ch): | ||
145 | + super().__init__() | ||
146 | + self.atrous_conv = nn.Sequential( | ||
147 | + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=2, padding=2), | ||
148 | + nn.BatchNorm2d(in_ch), | ||
149 | + nn.ReLU(), | ||
150 | + | ||
151 | + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=4, padding=4), | ||
152 | + nn.BatchNorm2d(in_ch), | ||
153 | + nn.ReLU(), | ||
154 | + | ||
155 | + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=8, padding=8), | ||
156 | + nn.BatchNorm2d(in_ch), | ||
157 | + nn.ReLU(), | ||
158 | + | ||
159 | + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=16, padding=16), | ||
160 | + nn.BatchNorm2d(in_ch), | ||
161 | + nn.ReLU(), | ||
162 | + ) | ||
163 | + | ||
164 | + def forward(self, x): | ||
165 | + return self.atrous_conv(x) | ||
166 | + | ||
167 | + | ||
168 | +class UNetSemantic(nn.Module): | ||
169 | + """ | ||
170 | + UNet - Basic Implementation | ||
171 | + Paper : https://arxiv.org/abs/1505.04597 | ||
172 | + """ | ||
173 | + def __init__(self, in_ch=3, out_ch=1): | ||
174 | + super(UNetSemantic, self).__init__() | ||
175 | + | ||
176 | + n1 = 32 | ||
177 | + filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] | ||
178 | + | ||
179 | + self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) | ||
180 | + self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) | ||
181 | + self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2) | ||
182 | + self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2) | ||
183 | + | ||
184 | + self.Conv1 = conv_block(in_ch, filters[0]) | ||
185 | + self.Conv2 = conv_block(filters[0], filters[1]) | ||
186 | + self.Conv3 = conv_block(filters[1], filters[2]) | ||
187 | + self.Conv4 = conv_block(filters[2], filters[3]) | ||
188 | + self.Conv5 = conv_block(filters[3], filters[4]) | ||
189 | + | ||
190 | + self.Up5 = up_conv(filters[4], filters[3]) | ||
191 | + self.Up_conv5 = conv_block(filters[4], filters[3]) | ||
192 | + | ||
193 | + self.Up4 = up_conv(filters[3], filters[2]) | ||
194 | + self.Up_conv4 = conv_block(filters[3], filters[2]) | ||
195 | + | ||
196 | + self.Up3 = up_conv(filters[2], filters[1]) | ||
197 | + self.Up_conv3 = conv_block(filters[2], filters[1]) | ||
198 | + | ||
199 | + self.Up2 = up_conv(filters[1], filters[0]) | ||
200 | + self.Up_conv2 = conv_block(filters[1], filters[0]) | ||
201 | + | ||
202 | + self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0) | ||
203 | + self.se1 = SE_Block(filters[0]) | ||
204 | + self.se2 = SE_Block(filters[1]) | ||
205 | + self.se3 = SE_Block(filters[2]) | ||
206 | + self.active = torch.nn.Sigmoid() | ||
207 | + | ||
208 | + def forward(self, x): | ||
209 | + | ||
210 | + e1 = self.Conv1(x) | ||
211 | + e1 = self.se1(e1) | ||
212 | + | ||
213 | + e2 = self.Maxpool1(e1) | ||
214 | + e2 = self.Conv2(e2) | ||
215 | + e2 = self.se2(e2) | ||
216 | + | ||
217 | + e3 = self.Maxpool2(e2) | ||
218 | + e3 = self.Conv3(e3) | ||
219 | + e3 = self.se3(e3) | ||
220 | + | ||
221 | + e4 = self.Maxpool3(e3) | ||
222 | + e4 = self.Conv4(e4) | ||
223 | + | ||
224 | + e5 = self.Maxpool4(e4) | ||
225 | + e5 = self.Conv5(e5) | ||
226 | + | ||
227 | + d5 = self.Up5(e5) | ||
228 | + d5 = torch.cat((e4, d5), dim=1) | ||
229 | + | ||
230 | + d5 = self.Up_conv5(d5) | ||
231 | + | ||
232 | + d4 = self.Up4(d5) | ||
233 | + d4 = torch.cat((e3, d4), dim=1) | ||
234 | + d4 = self.Up_conv4(d4) | ||
235 | + | ||
236 | + d3 = self.Up3(d4) | ||
237 | + d3 = torch.cat((e2, d3), dim=1) | ||
238 | + d3 = self.Up_conv3(d3) | ||
239 | + | ||
240 | + d2 = self.Up2(d3) | ||
241 | + d2 = torch.cat((e1, d2), dim=1) | ||
242 | + d2 = self.Up_conv2(d2) | ||
243 | + | ||
244 | + out = self.Conv(d2) | ||
245 | + | ||
246 | + out = self.active(out) | ||
247 | + | ||
248 | + return out | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
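A quick sanity check of the segmenter (the input size is illustrative; H and W must be divisible by 16 because of the four 2x max-poolings):

    import torch
    net = UNetSemantic(in_ch=3, out_ch=1)
    x = torch.rand(1, 3, 224, 224)
    y = net(x)
    print(y.shape)                 # [1, 1, 224, 224], sigmoid probabilities in [0, 1]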
Code/image-inpainting/train.py
0 → 100644
1 | +import argparse | ||
2 | +from configs import Config | ||
3 | +from trainer import Trainer | ||
4 | +from unet_trainer import UNetTrainer | ||
5 | + | ||
6 | + | ||
7 | +def main(args, cfg): | ||
8 | + if args.config == "segm": | ||
9 | + trainer = UNetTrainer(args, cfg) | ||
10 | + else: | ||
11 | + trainer = Trainer(args, cfg) | ||
12 | + trainer.fit() | ||
13 | + | ||
14 | + | ||
15 | +if __name__ == "__main__": | ||
16 | + parser = argparse.ArgumentParser(description="Training custom model") | ||
17 | + parser.add_argument("--resume", default=None, type=str, help="resume training") | ||
18 | + parser.add_argument("config", default="config", type=str, help="config training") | ||
19 | + args = parser.parse_args() | ||
20 | + | ||
21 | + config = Config(f"./configs/{args.config}.yaml") | ||
22 | + main(args, config) |
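The positional config argument does double duty: it selects configs/<name>.yaml and picks the trainer, with the literal name segm training the U-Net mask segmenter and any other name training the inpainting GAN. Typical invocations (the facemask config name and checkpoint path are hypothetical):

    python train.py segm
    python train.py facemask
    python train.py segm --resume weights/model_segm_3_5000.pth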
Code/image-inpainting/trainer.py
0 → 100644
1 | +import os | ||
2 | +import cv2 | ||
3 | +import time | ||
4 | +import numpy as np | ||
5 | +from PIL import Image | ||
6 | +import matplotlib.pyplot as plt | ||
7 | + | ||
8 | +import torch | ||
9 | +import torch.nn as nn | ||
10 | +import torch.utils.data as data | ||
11 | +from torch.optim.lr_scheduler import StepLR | ||
12 | +from torchvision.utils import save_image | ||
13 | + | ||
14 | + | ||
15 | +from models import * | ||
16 | +from losses import * | ||
17 | +from datasets import Places365Dataset, FacemaskDataset | ||
18 | + | ||
19 | + | ||
20 | +def adjust_learning_rate(optimizer, gamma, num_steps=1): | ||
21 | + for i in range(num_steps): | ||
22 | + for param_group in optimizer.param_groups: | ||
23 | + param_group['lr'] *= gamma | ||
24 | + | ||
25 | +def get_epoch_iters(path): | ||
26 | + path = os.path.basename(path) | ||
27 | + tokens = path[:-4].split('_') | ||
28 | + try: | ||
29 | + if tokens[-1] == 'interrupted': | ||
30 | + epoch_idx = int(tokens[-3]) | ||
31 | + iter_idx = int(tokens[-2]) | ||
32 | + else: | ||
33 | + epoch_idx = int(tokens[-2]) | ||
34 | + iter_idx = int(tokens[-1]) | ||
35 | +    except (ValueError, IndexError): | ||
36 | + return 0, 0 | ||
37 | + | ||
38 | + return epoch_idx, iter_idx | ||
39 | + | ||
40 | +def load_checkpoint(model_G, model_D, path): | ||
41 | + state = torch.load(path,map_location='cpu') | ||
42 | + model_G.load_state_dict(state['G']) | ||
43 | + model_D.load_state_dict(state['D']) | ||
44 | + print('Loaded checkpoint successfully') | ||
45 | + | ||
46 | +class Trainer(): | ||
47 | + def __init__(self, args, cfg): | ||
48 | + | ||
49 | + if args.resume is not None: | ||
50 | + epoch, iters = get_epoch_iters(args.resume) | ||
51 | + else: | ||
52 | + epoch = 0 | ||
53 | + iters = 0 | ||
54 | + | ||
55 | + if not os.path.exists(cfg.checkpoint_path): | ||
56 | + os.makedirs(cfg.checkpoint_path) | ||
57 | + if not os.path.exists(cfg.sample_folder): | ||
58 | + os.makedirs(cfg.sample_folder) | ||
59 | + | ||
60 | + self.cfg = cfg | ||
61 | + self.step_iters = cfg.step_iters | ||
62 | + self.gamma = cfg.gamma | ||
63 | + self.visualize_per_iter = cfg.visualize_per_iter | ||
64 | + self.print_per_iter = cfg.print_per_iter | ||
65 | + self.save_per_iter = cfg.save_per_iter | ||
66 | + | ||
67 | + self.start_iter = iters | ||
68 | + self.iters = 0 | ||
69 | + self.num_epochs = cfg.num_epochs | ||
70 | + self.device = torch.device('cuda' if cfg.cuda else 'cpu') | ||
71 | + | ||
72 | + trainset = FacemaskDataset(cfg) # Places365Dataset(cfg) # | ||
73 | + | ||
74 | + self.trainloader = data.DataLoader( | ||
75 | + trainset, | ||
76 | + batch_size=cfg.batch_size, | ||
77 | + num_workers = cfg.num_workers, | ||
78 | + pin_memory = True, | ||
79 | + shuffle=True, | ||
80 | + collate_fn = trainset.collate_fn) | ||
81 | + | ||
82 | + self.epoch = int(self.start_iter / len(self.trainloader)) | ||
83 | + self.iters = self.start_iter | ||
84 | + self.num_iters = (self.num_epochs+1) * len(self.trainloader) | ||
85 | + | ||
86 | + self.model_G = GatedGenerator().to(self.device) | ||
87 | +        self.model_D = NLayerDiscriminator(3, n_layers=cfg.d_num_layers, use_sigmoid=False).to(self.device)  # first positional argument is input_nc (RGB = 3) | ||
88 | + self.model_P = PerceptualNet(name = "vgg16", resize=False).to(self.device) | ||
89 | + | ||
90 | + if args.resume is not None: | ||
91 | + load_checkpoint(self.model_G, self.model_D, args.resume) | ||
92 | + | ||
93 | + self.criterion_adv = GANLoss(target_real_label=0.9, target_fake_label=0.1) | ||
94 | + self.criterion_rec = nn.SmoothL1Loss() | ||
95 | + self.criterion_ssim = SSIM(window_size = 11) | ||
96 | + self.criterion_per = nn.SmoothL1Loss() | ||
97 | + | ||
98 | + self.optimizer_D = torch.optim.Adam(self.model_D.parameters(), lr=cfg.lr) | ||
99 | + self.optimizer_G = torch.optim.Adam(self.model_G.parameters(), lr=cfg.lr) | ||
100 | + | ||
101 | + def validate(self, sample_folder, sample_name, img_list): | ||
102 | + save_img_path = os.path.join(sample_folder, sample_name+'.png') | ||
103 | + img_list = [i.clone().cpu() for i in img_list] | ||
104 | + imgs = torch.stack(img_list, dim=1) | ||
105 | + | ||
106 | +        # imgs shape: B x 5 x C x H x W | ||
107 | + | ||
108 | + imgs = imgs.view(-1, *list(imgs.size())[2:]) | ||
109 | + save_image(imgs, save_img_path, nrow= 5) | ||
110 | + print(f"Save image to {save_img_path}") | ||
111 | + | ||
112 | + def fit(self): | ||
113 | + self.model_G.train() | ||
114 | + self.model_D.train() | ||
115 | + | ||
116 | + running_loss = { | ||
117 | + 'D': 0, | ||
118 | + 'G': 0, | ||
119 | + 'P': 0, | ||
120 | + 'R_1': 0, | ||
121 | + 'R_2': 0, | ||
122 | + 'T': 0, | ||
123 | + } | ||
124 | + | ||
125 | + running_time = 0 | ||
126 | + step = 0 | ||
127 | + try: | ||
128 | + for epoch in range(self.epoch, self.num_epochs): | ||
129 | + self.epoch = epoch | ||
130 | + for i, batch in enumerate(self.trainloader): | ||
131 | + start_time = time.time() | ||
132 | + imgs = batch['imgs'].to(self.device) | ||
133 | + masks = batch['masks'].to(self.device) | ||
134 | + | ||
135 | + # Train discriminator | ||
136 | + self.optimizer_D.zero_grad() | ||
137 | + self.optimizer_G.zero_grad() | ||
138 | + | ||
139 | + first_out, second_out = self.model_G(imgs, masks) | ||
140 | + | ||
141 | + first_out_wholeimg = imgs * (1 - masks) + first_out * masks | ||
142 | + second_out_wholeimg = imgs * (1 - masks) + second_out * masks | ||
143 | + | ||
144 | + masks = masks.cpu() | ||
145 | + | ||
146 | + fake_D = self.model_D(second_out_wholeimg.detach()) | ||
147 | + real_D = self.model_D(imgs) | ||
148 | + | ||
149 | + loss_fake_D = self.criterion_adv(fake_D, target_is_real=False) | ||
150 | + loss_real_D = self.criterion_adv(real_D, target_is_real=True) | ||
151 | + | ||
152 | + loss_D = (loss_fake_D + loss_real_D) * 0.5 | ||
153 | + | ||
154 | + loss_D.backward() | ||
155 | + self.optimizer_D.step() | ||
156 | + | ||
157 | + real_D = None | ||
158 | + | ||
159 | + # Train Generator | ||
160 | + self.optimizer_D.zero_grad() | ||
161 | + self.optimizer_G.zero_grad() | ||
162 | + | ||
163 | + fake_D = self.model_D(second_out_wholeimg) | ||
164 | + loss_G = self.criterion_adv(fake_D, target_is_real=True) | ||
165 | + | ||
166 | + fake_D = None | ||
167 | + | ||
168 | + # Reconstruction loss | ||
169 | + loss_l1_1 = self.criterion_rec(first_out_wholeimg, imgs) | ||
170 | + loss_l1_2 = self.criterion_rec(second_out_wholeimg, imgs) | ||
171 | + loss_ssim_1 = self.criterion_ssim(first_out_wholeimg, imgs) | ||
172 | + loss_ssim_2 = self.criterion_ssim(second_out_wholeimg, imgs) | ||
173 | + | ||
174 | + loss_rec_1 = 0.5 * loss_l1_1 + 0.5 * (1 - loss_ssim_1) | ||
175 | + loss_rec_2 = 0.5 * loss_l1_2 + 0.5 * (1 - loss_ssim_2) | ||
176 | + | ||
177 | + # Perceptual loss | ||
178 | + loss_P = self.model_P(second_out_wholeimg, imgs) | ||
179 | + | ||
180 | + loss = self.cfg.lambda_G * loss_G + self.cfg.lambda_rec_1 * loss_rec_1 + self.cfg.lambda_rec_2 * loss_rec_2 + self.cfg.lambda_per * loss_P | ||
181 | + loss.backward() | ||
182 | + self.optimizer_G.step() | ||
183 | + | ||
184 | + end_time = time.time() | ||
185 | + | ||
186 | + imgs = imgs.cpu() | ||
187 | + # Visualize number | ||
188 | + running_time += (end_time - start_time) | ||
189 | + running_loss['D'] += loss_D.item() | ||
190 | + running_loss['G'] += (self.cfg.lambda_G * loss_G.item()) | ||
191 | + running_loss['P'] += (self.cfg.lambda_per * loss_P.item()) | ||
192 | + running_loss['R_1'] += (self.cfg.lambda_rec_1 * loss_rec_1.item()) | ||
193 | + running_loss['R_2'] += (self.cfg.lambda_rec_2 * loss_rec_2.item()) | ||
194 | + running_loss['T'] += loss.item() | ||
195 | + | ||
196 | + | ||
197 | + if self.iters % self.print_per_iter == 0: | ||
198 | + for key in running_loss.keys(): | ||
199 | + running_loss[key] /= self.print_per_iter | ||
200 | + running_loss[key] = np.round(running_loss[key], 5) | ||
201 | + loss_string = '{}'.format(running_loss)[1:-1].replace("'",'').replace(",",' ||') | ||
202 | + print("[{}|{}] [{}|{}] || {} || Time: {:10.4f}s".format(self.epoch, self.num_epochs, self.iters, self.num_iters, loss_string, running_time)) | ||
203 | + | ||
204 | + running_loss = { | ||
205 | + 'D': 0, | ||
206 | + 'G': 0, | ||
207 | + 'P': 0, | ||
208 | + 'R_1': 0, | ||
209 | + 'R_2': 0, | ||
210 | + 'T': 0, | ||
211 | + } | ||
212 | + running_time = 0 | ||
213 | + | ||
214 | + if self.iters % self.save_per_iter == 0: | ||
215 | + torch.save({ | ||
216 | + 'D': self.model_D.state_dict(), | ||
217 | + 'G': self.model_G.state_dict(), | ||
218 | + }, os.path.join(self.cfg.checkpoint_path, f"model_{self.epoch}_{self.iters}.pth")) | ||
219 | + | ||
220 | + # Step learning rate | ||
221 | +                    if step < len(self.step_iters) and self.iters == self.step_iters[step]: | ||
222 | + adjust_learning_rate(self.optimizer_D, self.gamma) | ||
223 | + adjust_learning_rate(self.optimizer_G, self.gamma) | ||
224 | + step+=1 | ||
225 | + | ||
226 | + # Visualize sample | ||
227 | + if self.iters % self.visualize_per_iter == 0: | ||
228 | + masked_imgs = imgs * (1 - masks) + masks | ||
229 | + | ||
230 | + img_list = [imgs, masked_imgs, first_out, second_out, second_out_wholeimg] | ||
231 | + #name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out'] | ||
232 | + filename = f"{self.epoch}_{str(self.iters)}" | ||
233 | + self.validate(self.cfg.sample_folder, filename , img_list) | ||
234 | + | ||
235 | + self.iters += 1 | ||
236 | + | ||
237 | + except KeyboardInterrupt: | ||
238 | + torch.save({ | ||
239 | + 'D': self.model_D.state_dict(), | ||
240 | + 'G': self.model_G.state_dict(), | ||
241 | + }, os.path.join(self.cfg.checkpoint_path, f"model_{self.epoch}_{self.iters}.pth")) | ||
242 | + | ||
243 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
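In symbols, the generator objective assembled above is

    L_total = lambda_G * L_adv + lambda_rec_1 * (0.5 * L1_1 + 0.5 * (1 - SSIM_1))
            + lambda_rec_2 * (0.5 * L1_2 + 0.5 * (1 - SSIM_2)) + lambda_per * L_perc

where the _1/_2 terms score the coarse and refined composites against the ground truth, L_perc is the VGG feature distance from PerceptualNet, and the lambda weights come from the YAML config.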
Code/image-inpainting/unet_trainer.py
0 → 100644
1 | +import os | ||
2 | +import cv2 | ||
3 | +import time | ||
4 | +import numpy as np | ||
5 | +from PIL import Image | ||
6 | +import matplotlib.pyplot as plt | ||
7 | +from tqdm import tqdm | ||
8 | + | ||
9 | +import torch | ||
10 | +import torch.nn as nn | ||
11 | +import torch.utils.data as data | ||
12 | +from torch.optim.lr_scheduler import StepLR | ||
13 | +from torchvision.utils import save_image | ||
14 | + | ||
15 | + | ||
16 | +from models import UNetSemantic | ||
17 | +from losses import DiceLoss | ||
18 | +from datasets import FacemaskSegDataset | ||
19 | +from metrics import * | ||
20 | + | ||
21 | + | ||
22 | +def adjust_learning_rate(optimizer, gamma, num_steps=1): | ||
23 | + for i in range(num_steps): | ||
24 | + for param_group in optimizer.param_groups: | ||
25 | + param_group["lr"] *= gamma | ||
26 | + | ||
27 | + | ||
28 | +def get_epoch_iters(path): | ||
29 | + path = os.path.basename(path) | ||
30 | + tokens = path[:-4].split("_") | ||
31 | + try: | ||
32 | + if tokens[-1] == "interrupted": | ||
33 | + epoch_idx = int(tokens[-3]) | ||
34 | + iter_idx = int(tokens[-2]) | ||
35 | + else: | ||
36 | + epoch_idx = int(tokens[-2]) | ||
37 | + iter_idx = int(tokens[-1]) | ||
38 | +    except (ValueError, IndexError): | ||
39 | + return 0, 0 | ||
40 | + | ||
41 | + return epoch_idx, iter_idx | ||
42 | + | ||
43 | + | ||
44 | +def load_checkpoint(model, path): | ||
45 | + state = torch.load(path, map_location="cpu") | ||
46 | + model.load_state_dict(state) | ||
47 | + print("Loaded checkpoint successfully") | ||
48 | + | ||
49 | + | ||
50 | +class UNetTrainer: | ||
51 | + def __init__(self, args, cfg): | ||
52 | + | ||
53 | + if args.resume is not None: | ||
54 | + epoch, iters = get_epoch_iters(args.resume) | ||
55 | + else: | ||
56 | + epoch = 0 | ||
57 | + iters = 0 | ||
58 | + | ||
59 | + self.cfg = cfg | ||
60 | + self.step_iters = cfg.step_iters | ||
61 | + self.gamma = cfg.gamma | ||
62 | + self.visualize_per_iter = cfg.visualize_per_iter | ||
63 | + self.print_per_iter = cfg.print_per_iter | ||
64 | + self.save_per_iter = cfg.save_per_iter | ||
65 | + | ||
66 | + self.start_iter = iters | ||
67 | + self.iters = 0 | ||
68 | + self.num_epochs = cfg.num_epochs | ||
69 | + self.device = torch.device("cuda:0" if cfg.cuda else "cpu") | ||
70 | + | ||
71 | + trainset = FacemaskSegDataset(cfg) | ||
72 | + valset = FacemaskSegDataset(cfg, train=False) | ||
73 | + | ||
74 | + self.trainloader = data.DataLoader( | ||
75 | + trainset, | ||
76 | + batch_size=cfg.batch_size, | ||
77 | + num_workers=cfg.num_workers, | ||
78 | + pin_memory=True, | ||
79 | + shuffle=True, | ||
80 | + collate_fn=trainset.collate_fn, | ||
81 | + ) | ||
82 | + | ||
83 | + self.valloader = data.DataLoader( | ||
84 | + valset, | ||
85 | + batch_size=cfg.batch_size, | ||
86 | + num_workers=cfg.num_workers, | ||
87 | + pin_memory=True, | ||
88 | + shuffle=True, | ||
89 | + collate_fn=valset.collate_fn, | ||
90 | + ) | ||
91 | + | ||
92 | + self.epoch = int(self.start_iter / len(self.trainloader)) | ||
93 | + self.iters = self.start_iter | ||
94 | + self.num_iters = (self.num_epochs + 1) * len(self.trainloader) | ||
95 | + | ||
96 | + self.model = UNetSemantic().to(self.device) | ||
97 | + self.criterion_dice = DiceLoss() | ||
98 | + self.criterion_bce = nn.BCELoss() | ||
99 | + | ||
100 | + if args.resume is not None: | ||
101 | + load_checkpoint(self.model, args.resume) | ||
102 | + | ||
103 | + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr) | ||
104 | + | ||
105 | + def validate(self, sample_folder, sample_name, img_list): | ||
106 | + save_img_path = os.path.join(sample_folder, sample_name + ".png") | ||
107 | + img_list = [i.clone().cpu() for i in img_list] | ||
108 | + imgs = torch.stack(img_list, dim=1) | ||
109 | + | ||
110 | +        # imgs shape: B x N x C x H x W (N images per sample, 3 per grid row) | ||
111 | + | ||
112 | + imgs = imgs.view(-1, *list(imgs.size())[2:]) | ||
113 | + save_image(imgs, save_img_path, nrow=3) | ||
114 | + print(f"Save image to {save_img_path}") | ||
115 | + | ||
116 | + def train_epoch(self): | ||
117 | + self.model.train() | ||
118 | + running_loss = { | ||
119 | + "DICE": 0, | ||
120 | + "BCE": 0, | ||
121 | + "T": 0, | ||
122 | + } | ||
123 | + running_time = 0 | ||
124 | + | ||
125 | + for idx, batch in enumerate(self.trainloader): | ||
126 | + self.optimizer.zero_grad() | ||
127 | + inputs = batch["imgs"].to(self.device) | ||
128 | + targets = batch["masks"].to(self.device) | ||
129 | + | ||
130 | + start_time = time.time() | ||
131 | + | ||
132 | + outputs = self.model(inputs) | ||
133 | + | ||
134 | + loss_bce = self.criterion_bce(outputs, targets) | ||
135 | + loss_dice = self.criterion_dice(outputs, targets) | ||
136 | + loss = loss_bce + loss_dice | ||
137 | + loss.backward() | ||
138 | + self.optimizer.step() | ||
139 | + | ||
140 | + end_time = time.time() | ||
141 | + | ||
142 | + running_loss["T"] += loss.item() | ||
143 | + running_loss["DICE"] += loss_dice.item() | ||
144 | + running_loss["BCE"] += loss_bce.item() | ||
145 | + running_time += end_time - start_time | ||
146 | + | ||
147 | + if self.iters % self.print_per_iter == 0: | ||
148 | + for key in running_loss.keys(): | ||
149 | + running_loss[key] /= self.print_per_iter | ||
150 | + running_loss[key] = np.round(running_loss[key], 5) | ||
151 | + loss_string = ( | ||
152 | + "{}".format(running_loss)[1:-1].replace("'", "").replace(",", " ||") | ||
153 | + ) | ||
154 | + running_time = np.round(running_time, 5) | ||
155 | + print( | ||
156 | + "[{}/{}][{}/{}] || {} || Time: {}s".format( | ||
157 | + self.epoch, | ||
158 | + self.num_epochs, | ||
159 | + self.iters, | ||
160 | + self.num_iters, | ||
161 | + loss_string, | ||
162 | + running_time, | ||
163 | + ) | ||
164 | + ) | ||
165 | + running_time = 0 | ||
166 | + running_loss = { | ||
167 | + "DICE": 0, | ||
168 | + "BCE": 0, | ||
169 | + "T": 0, | ||
170 | + } | ||
171 | + | ||
172 | + if self.iters % self.save_per_iter == 0: | ||
173 | + save_path = os.path.join( | ||
174 | + self.cfg.checkpoint_path, | ||
175 | + f"model_segm_{self.epoch}_{self.iters}.pth", | ||
176 | + ) | ||
177 | + torch.save(self.model.state_dict(), save_path) | ||
178 | + print(f"Save model at {save_path}") | ||
179 | + self.iters += 1 | ||
180 | + | ||
181 | + def validate_epoch(self): | ||
182 | + # Validate | ||
183 | + | ||
184 | + self.model.eval() | ||
185 | + metrics = [DiceScore(1), PixelAccuracy(1)] | ||
186 | + running_loss = { | ||
187 | + "DICE": 0, | ||
188 | + "BCE": 0, | ||
189 | + "T": 0, | ||
190 | + } | ||
191 | + | ||
192 | + running_time = 0 | ||
193 | + print( | ||
194 | + "=============================EVALUATION===================================" | ||
195 | + ) | ||
196 | + with torch.no_grad(): | ||
197 | + start_time = time.time() | ||
198 | + for idx, batch in enumerate(tqdm(self.valloader)): | ||
199 | + | ||
200 | + inputs = batch["imgs"].to(self.device) | ||
201 | + targets = batch["masks"].to(self.device) | ||
202 | + outputs = self.model(inputs) | ||
203 | + loss_bce = self.criterion_bce(outputs, targets) | ||
204 | + loss_dice = self.criterion_dice(outputs, targets) | ||
205 | + loss = loss_bce + loss_dice | ||
206 | + running_loss["T"] += loss.item() | ||
207 | + running_loss["DICE"] += loss_dice.item() | ||
208 | + running_loss["BCE"] += loss_bce.item() | ||
209 | + for metric in metrics: | ||
210 | + metric.update(outputs.cpu(), targets.cpu()) | ||
211 | + | ||
212 | + end_time = time.time() | ||
213 | + running_time += end_time - start_time | ||
214 | + running_time = np.round(running_time, 5) | ||
215 | + for key in running_loss.keys(): | ||
216 | + running_loss[key] /= len(self.valloader) | ||
217 | + running_loss[key] = np.round(running_loss[key], 5) | ||
218 | + | ||
219 | + loss_string = ( | ||
220 | + "{}".format(running_loss)[1:-1].replace("'", "").replace(",", " ||") | ||
221 | + ) | ||
222 | + | ||
223 | + print( | ||
224 | + "[{}/{}] || Validation || {} || Time: {}s".format( | ||
225 | + self.epoch, self.num_epochs, loss_string, running_time | ||
226 | + ) | ||
227 | + ) | ||
228 | + for metric in metrics: | ||
229 | + print(metric) | ||
230 | + print( | ||
231 | + "==========================================================================" | ||
232 | + ) | ||
233 | + | ||
234 | + def fit(self): | ||
235 | + try: | ||
236 | + for epoch in range(self.epoch, self.num_epochs + 1): | ||
237 | + self.epoch = epoch | ||
238 | + self.train_epoch() | ||
239 | + self.validate_epoch() | ||
240 | + except KeyboardInterrupt: | ||
241 | + torch.save( | ||
242 | + self.model.state_dict(), | ||
243 | + os.path.join( | ||
244 | + self.cfg.checkpoint_path, | ||
245 | + f"model_segm_{self.epoch}_{self.iters}.pth", | ||
246 | + ), | ||
247 | + ) | ||
248 | + print("Model saved!") | ||
249 | + |
Code/pConv-Keras/.gitignore
0 → 100644
1 | +# Repo-specific | ||
2 | +data/masks/* | ||
3 | +.vscode* | ||
4 | + | ||
5 | +# Byte-compiled / optimized / DLL files | ||
6 | +__pycache__/ | ||
7 | +*.py[cod] | ||
8 | +*$py.class | ||
9 | + | ||
10 | +# C extensions | ||
11 | +*.so | ||
12 | + | ||
13 | +# Distribution / packaging | ||
14 | +.Python | ||
15 | +build/ | ||
16 | +develop-eggs/ | ||
17 | +dist/ | ||
18 | +downloads/ | ||
19 | +eggs/ | ||
20 | +.eggs/ | ||
21 | +lib/ | ||
22 | +lib64/ | ||
23 | +parts/ | ||
24 | +sdist/ | ||
25 | +var/ | ||
26 | +wheels/ | ||
27 | +*.egg-info/ | ||
28 | +.installed.cfg | ||
29 | +*.egg | ||
30 | +MANIFEST | ||
31 | + | ||
32 | +# PyInstaller | ||
33 | +# Usually these files are written by a python script from a template | ||
34 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
35 | +*.manifest | ||
36 | +*.spec | ||
37 | + | ||
38 | +# Installer logs | ||
39 | +pip-log.txt | ||
40 | +pip-delete-this-directory.txt | ||
41 | + | ||
42 | +# Unit test / coverage reports | ||
43 | +htmlcov/ | ||
44 | +.tox/ | ||
45 | +.coverage | ||
46 | +.coverage.* | ||
47 | +.cache | ||
48 | +nosetests.xml | ||
49 | +coverage.xml | ||
50 | +*.cover | ||
51 | +.hypothesis/ | ||
52 | +.pytest_cache/ | ||
53 | + | ||
54 | +# Translations | ||
55 | +*.mo | ||
56 | +*.pot | ||
57 | + | ||
58 | +# Django stuff: | ||
59 | +*.log | ||
60 | +local_settings.py | ||
61 | +db.sqlite3 | ||
62 | + | ||
63 | +# Flask stuff: | ||
64 | +instance/ | ||
65 | +.webassets-cache | ||
66 | + | ||
67 | +# Scrapy stuff: | ||
68 | +.scrapy | ||
69 | + | ||
70 | +# Sphinx documentation | ||
71 | +docs/_build/ | ||
72 | + | ||
73 | +# PyBuilder | ||
74 | +target/ | ||
75 | + | ||
76 | +# Jupyter Notebook | ||
77 | +.ipynb_checkpoints | ||
78 | + | ||
79 | +# pyenv | ||
80 | +.python-version | ||
81 | + | ||
82 | +# celery beat schedule file | ||
83 | +celerybeat-schedule | ||
84 | + | ||
85 | +# SageMath parsed files | ||
86 | +*.sage.py | ||
87 | + | ||
88 | +# Environments | ||
89 | +.env | ||
90 | +.venv | ||
91 | +env/ | ||
92 | +venv/ | ||
93 | +ENV/ | ||
94 | +env.bak/ | ||
95 | +venv.bak/ | ||
96 | + | ||
97 | +# Spyder project settings | ||
98 | +.spyderproject | ||
99 | +.spyproject | ||
100 | + | ||
101 | +# Rope project settings | ||
102 | +.ropeproject | ||
103 | + | ||
104 | +# mkdocs documentation | ||
105 | +/site | ||
106 | + | ||
107 | +# mypy | ||
108 | +.mypy_cache/ | ||
109 | + | ||
110 | +backup* | ||
111 | +pexels_royalty_free_photos* | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Code/pConv-Keras/libs/__init__.py
0 → 100644
File mode changed
Code/pConv-Keras/libs/pconv_layer.py
0 → 100644
1 | + | ||
2 | +from keras.utils import conv_utils | ||
3 | +from keras import backend as K | ||
4 | +from keras.engine import InputSpec | ||
5 | +from keras.layers import Conv2D | ||
6 | + | ||
7 | + | ||
8 | +class PConv2D(Conv2D): | ||
9 | + def __init__(self, *args, n_channels=3, mono=False, **kwargs): | ||
10 | + super().__init__(*args, **kwargs) | ||
11 | + self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)] | ||
12 | + | ||
13 | + def build(self, input_shape): | ||
14 | + """Adapted from original _Conv() layer of Keras | ||
15 | + param input_shape: list of dimensions for [img, mask] | ||
16 | + """ | ||
17 | + | ||
18 | + if self.data_format == 'channels_first': | ||
19 | + channel_axis = 1 | ||
20 | + else: | ||
21 | + channel_axis = -1 | ||
22 | + | ||
23 | + if input_shape[0][channel_axis] is None: | ||
24 | + raise ValueError('The channel dimension of the inputs should be defined. Found `None`.') | ||
25 | + | ||
26 | + self.input_dim = input_shape[0][channel_axis] | ||
27 | + | ||
28 | + # Image kernel | ||
29 | + kernel_shape = self.kernel_size + (self.input_dim, self.filters) | ||
30 | + self.kernel = self.add_weight(shape=kernel_shape, | ||
31 | + initializer=self.kernel_initializer, | ||
32 | + name='img_kernel', | ||
33 | + regularizer=self.kernel_regularizer, | ||
34 | + constraint=self.kernel_constraint) | ||
35 | + # Mask kernel | ||
36 | + self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters)) | ||
37 | + | ||
38 | + # Calculate padding size to achieve zero-padding | ||
39 | + self.pconv_padding = ( | ||
40 | + (int((self.kernel_size[0]-1)/2), int((self.kernel_size[0]-1)/2)), | ||
41 | +            (int((self.kernel_size[1]-1)/2), int((self.kernel_size[1]-1)/2)),  # width dim uses kernel_size[1] | ||
42 | + ) | ||
43 | + | ||
44 | + # Window size - used for normalization | ||
45 | + self.window_size = self.kernel_size[0] * self.kernel_size[1] | ||
46 | + | ||
47 | + if self.use_bias: | ||
48 | + self.bias = self.add_weight(shape=(self.filters,), | ||
49 | + initializer=self.bias_initializer, | ||
50 | + name='bias', | ||
51 | + regularizer=self.bias_regularizer, | ||
52 | + constraint=self.bias_constraint) | ||
53 | + else: | ||
54 | + self.bias = None | ||
55 | + self.built = True | ||
56 | + | ||
57 | + def call(self, inputs, mask=None): | ||
58 | + ''' | ||
59 | +        We will be using the Keras conv2d method; essentially all we have | ||
60 | +        to do here is multiply the mask with the input X before we apply the | ||
61 | +        convolutions. For the mask itself, we apply convolutions with all weights | ||
62 | +        set to 1. | ||
63 | +        Subsequently, we clip mask values to between 0 and 1. | ||
64 | + ''' | ||
65 | + | ||
66 | + # Both image and mask must be supplied | ||
67 | + if type(inputs) is not list or len(inputs) != 2: | ||
68 | + raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs)) | ||
69 | + | ||
70 | + # Padding done explicitly so that padding becomes part of the masked partial convolution | ||
71 | + images = K.spatial_2d_padding(inputs[0], self.pconv_padding, self.data_format) | ||
72 | + masks = K.spatial_2d_padding(inputs[1], self.pconv_padding, self.data_format) | ||
73 | + | ||
74 | + # Apply convolutions to mask | ||
75 | + mask_output = K.conv2d( | ||
76 | + masks, self.kernel_mask, | ||
77 | + strides=self.strides, | ||
78 | + padding='valid', | ||
79 | + data_format=self.data_format, | ||
80 | + dilation_rate=self.dilation_rate | ||
81 | + ) | ||
82 | + | ||
83 | + # Apply convolutions to image | ||
84 | + img_output = K.conv2d( | ||
85 | + (images*masks), self.kernel, | ||
86 | + strides=self.strides, | ||
87 | + padding='valid', | ||
88 | + data_format=self.data_format, | ||
89 | + dilation_rate=self.dilation_rate | ||
90 | + ) | ||
91 | + | ||
92 | + # Calculate the mask ratio on each pixel in the output mask | ||
93 | + mask_ratio = self.window_size / (mask_output + 1e-8) | ||
94 | + | ||
95 | + # Clip output to be between 0 and 1 | ||
96 | + mask_output = K.clip(mask_output, 0, 1) | ||
97 | + | ||
98 | + # Remove ratio values where there are holes | ||
99 | + mask_ratio = mask_ratio * mask_output | ||
100 | + | ||
101 | +        # Normalize image output | ||
102 | + img_output = img_output * mask_ratio | ||
103 | + | ||
104 | + # Apply bias only to the image (if chosen to do so) | ||
105 | + if self.use_bias: | ||
106 | + img_output = K.bias_add( | ||
107 | + img_output, | ||
108 | + self.bias, | ||
109 | + data_format=self.data_format) | ||
110 | + | ||
111 | + # Apply activations on the image | ||
112 | + if self.activation is not None: | ||
113 | + img_output = self.activation(img_output) | ||
114 | + | ||
115 | + return [img_output, mask_output] | ||
116 | + | ||
117 | + def compute_output_shape(self, input_shape): | ||
118 | + if self.data_format == 'channels_last': | ||
119 | + space = input_shape[0][1:-1] | ||
120 | + new_space = [] | ||
121 | + for i in range(len(space)): | ||
122 | + new_dim = conv_utils.conv_output_length( | ||
123 | + space[i], | ||
124 | + self.kernel_size[i], | ||
125 | + padding='same', | ||
126 | + stride=self.strides[i], | ||
127 | + dilation=self.dilation_rate[i]) | ||
128 | + new_space.append(new_dim) | ||
129 | + new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,) | ||
130 | + return [new_shape, new_shape] | ||
131 | + if self.data_format == 'channels_first': | ||
132 | +            space = input_shape[0][2:]  # input_shape is a list of [img, mask] shapes | ||
133 | + new_space = [] | ||
134 | + for i in range(len(space)): | ||
135 | + new_dim = conv_utils.conv_output_length( | ||
136 | + space[i], | ||
137 | + self.kernel_size[i], | ||
138 | + padding='same', | ||
139 | + stride=self.strides[i], | ||
140 | + dilation=self.dilation_rate[i]) | ||
141 | + new_space.append(new_dim) | ||
142 | +            new_shape = (input_shape[0][0], self.filters) + tuple(new_space) | ||
143 | + return [new_shape, new_shape] |
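Since `PConv2D` consumes and produces an `[img, mask]` pair, both tensors have to be threaded through the network together. A minimal wiring sketch in the Keras functional API, mirroring how `build_pconv_unet` in pconv_model.py uses the layer (shapes are illustrative):

```python
from keras.layers import Input
from keras.models import Model

from libs.pconv_layer import PConv2D

# Image and binary mask enter every PConv2D layer as a pair,
# and the layer returns both the convolved image and the updated mask.
img_in = Input(shape=(512, 512, 3), name='img')
mask_in = Input(shape=(512, 512, 3), name='mask')

conv1, mask1 = PConv2D(64, 7, strides=2, padding='same')([img_in, mask_in])
conv2, mask2 = PConv2D(128, 5, strides=2, padding='same')([conv1, mask1])

model = Model(inputs=[img_in, mask_in], outputs=conv2)
model.summary()
```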
Code/pConv-Keras/libs/pconv_model.py
0 → 100644
1 | +import os | ||
2 | +import sys | ||
3 | +import numpy as np | ||
4 | +from datetime import datetime | ||
5 | + | ||
6 | +import tensorflow as tf | ||
7 | +from keras.models import Model | ||
8 | +from keras.models import load_model | ||
9 | +from keras.optimizers import Adam | ||
10 | +from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation, Lambda | ||
11 | +from keras.layers.merge import Concatenate | ||
12 | +from keras.applications import VGG16 | ||
13 | +from keras import backend as K | ||
14 | +from keras.utils.multi_gpu_utils import multi_gpu_model | ||
15 | + | ||
16 | +from libs.pconv_layer import PConv2D | ||
17 | + | ||
18 | + | ||
19 | +class PConvUnet(object): | ||
20 | + | ||
21 | + def __init__(self, img_rows=512, img_cols=512, vgg_weights="imagenet", inference_only=False, net_name='default', gpus=1, vgg_device=None): | ||
22 | + """Create the PConvUnet. If variable image size, set img_rows and img_cols to None | ||
23 | + | ||
24 | + Args: | ||
25 | + img_rows (int): image height. | ||
26 | + img_cols (int): image width. | ||
27 | + vgg_weights (str): which weights to pass to the vgg network. | ||
28 | + inference_only (bool): initialize BN layers for inference. | ||
29 | + net_name (str): Name of this network (used in logging). | ||
30 | + gpus (int): How many GPUs to use for training. | ||
31 | + vgg_device (str): In case of training with multiple GPUs, specify which device to run VGG inference on. | ||
32 | + e.g. if training on 8 GPUs, vgg inference could be off-loaded exclusively to one GPU, instead of | ||
33 | + running on one of the GPUs which is also training the UNet. | ||
34 | + """ | ||
35 | + | ||
36 | + # Settings | ||
37 | + self.img_rows = img_rows | ||
38 | + self.img_cols = img_cols | ||
39 | + self.img_overlap = 30 | ||
40 | + self.inference_only = inference_only | ||
41 | + self.net_name = net_name | ||
42 | + self.gpus = gpus | ||
43 | + self.vgg_device = vgg_device | ||
44 | + | ||
45 | + # Scaling for VGG input | ||
46 | + self.mean = [0.485, 0.456, 0.406] | ||
47 | + self.std = [0.229, 0.224, 0.225] | ||
48 | + | ||
49 | + # Assertions | ||
50 | +        assert self.img_rows >= 256, 'Height must be at least 256 pixels' | ||
51 | +        assert self.img_cols >= 256, 'Width must be at least 256 pixels' | ||
52 | + | ||
53 | + # Set current epoch | ||
54 | + self.current_epoch = 0 | ||
55 | + | ||
56 | +        # VGG layers to extract features from (first maxpooling layers, see p. 7 of the paper) | ||
57 | + self.vgg_layers = [3, 6, 10] | ||
58 | + | ||
59 | + # Instantiate the vgg network | ||
60 | + if self.vgg_device: | ||
61 | + with tf.device(self.vgg_device): | ||
62 | + self.vgg = self.build_vgg(vgg_weights) | ||
63 | + else: | ||
64 | + self.vgg = self.build_vgg(vgg_weights) | ||
65 | + | ||
66 | + # Create UNet-like model | ||
67 | + if self.gpus <= 1: | ||
68 | + self.model, inputs_mask = self.build_pconv_unet() | ||
69 | + self.compile_pconv_unet(self.model, inputs_mask) | ||
70 | + else: | ||
71 | + with tf.device("/cpu:0"): | ||
72 | + self.model, inputs_mask = self.build_pconv_unet() | ||
73 | + self.model = multi_gpu_model(self.model, gpus=self.gpus) | ||
74 | + self.compile_pconv_unet(self.model, inputs_mask) | ||
75 | + | ||
76 | + def build_vgg(self, weights="imagenet"): | ||
77 | + """ | ||
78 | + Load pre-trained VGG16 from keras applications | ||
79 | + Extract features to be used in loss function from last conv layer, see architecture at: | ||
80 | + https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py | ||
81 | + """ | ||
82 | + | ||
83 | + # Input image to extract features from | ||
84 | + img = Input(shape=(self.img_rows, self.img_cols, 3)) | ||
85 | + | ||
86 | + # Mean center and rescale by variance as in PyTorch | ||
87 | + processed = Lambda(lambda x: (x-self.mean) / self.std)(img) | ||
88 | + | ||
89 | + # If inference only, just return empty model | ||
90 | + if self.inference_only: | ||
91 | + model = Model(inputs=img, outputs=[img for _ in range(len(self.vgg_layers))]) | ||
92 | + model.trainable = False | ||
93 | + model.compile(loss='mse', optimizer='adam') | ||
94 | + return model | ||
95 | + | ||
96 | + # Get the vgg network from Keras applications | ||
97 | + if weights in ['imagenet', None]: | ||
98 | + vgg = VGG16(weights=weights, include_top=False) | ||
99 | + else: | ||
100 | + vgg = VGG16(weights=None, include_top=False) | ||
101 | + vgg.load_weights(weights, by_name=True) | ||
102 | + | ||
103 | + # Output the first three pooling layers | ||
104 | + vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers] | ||
105 | + | ||
106 | + # Create model and compile | ||
107 | + model = Model(inputs=img, outputs=vgg(processed)) | ||
108 | + model.trainable = False | ||
109 | + model.compile(loss='mse', optimizer='adam') | ||
110 | + | ||
111 | + return model | ||
112 | + | ||
113 | + def build_pconv_unet(self, train_bn=True): | ||
114 | + | ||
115 | + # INPUTS | ||
116 | + inputs_img = Input((self.img_rows, self.img_cols, 3), name='inputs_img') | ||
117 | + inputs_mask = Input((self.img_rows, self.img_cols, 3), name='inputs_mask') | ||
118 | + | ||
119 | + # ENCODER | ||
120 | + def encoder_layer(img_in, mask_in, filters, kernel_size, bn=True): | ||
121 | + conv, mask = PConv2D(filters, kernel_size, strides=2, padding='same')([img_in, mask_in]) | ||
122 | + if bn: | ||
123 | + conv = BatchNormalization(name='EncBN'+str(encoder_layer.counter))(conv, training=train_bn) | ||
124 | + conv = Activation('relu')(conv) | ||
125 | + encoder_layer.counter += 1 | ||
126 | + return conv, mask | ||
127 | + encoder_layer.counter = 0 | ||
128 | + | ||
129 | + e_conv1, e_mask1 = encoder_layer(inputs_img, inputs_mask, 64, 7, bn=False) | ||
130 | + e_conv2, e_mask2 = encoder_layer(e_conv1, e_mask1, 128, 5) | ||
131 | + e_conv3, e_mask3 = encoder_layer(e_conv2, e_mask2, 256, 5) | ||
132 | + e_conv4, e_mask4 = encoder_layer(e_conv3, e_mask3, 512, 3) | ||
133 | + e_conv5, e_mask5 = encoder_layer(e_conv4, e_mask4, 512, 3) | ||
134 | + e_conv6, e_mask6 = encoder_layer(e_conv5, e_mask5, 512, 3) | ||
135 | + e_conv7, e_mask7 = encoder_layer(e_conv6, e_mask6, 512, 3) | ||
136 | + e_conv8, e_mask8 = encoder_layer(e_conv7, e_mask7, 512, 3) | ||
137 | + | ||
138 | + # DECODER | ||
139 | + def decoder_layer(img_in, mask_in, e_conv, e_mask, filters, kernel_size, bn=True): | ||
140 | + up_img = UpSampling2D(size=(2,2))(img_in) | ||
141 | + up_mask = UpSampling2D(size=(2,2))(mask_in) | ||
142 | + concat_img = Concatenate(axis=3)([e_conv,up_img]) | ||
143 | + concat_mask = Concatenate(axis=3)([e_mask,up_mask]) | ||
144 | + conv, mask = PConv2D(filters, kernel_size, padding='same')([concat_img, concat_mask]) | ||
145 | + if bn: | ||
146 | + conv = BatchNormalization()(conv) | ||
147 | + conv = LeakyReLU(alpha=0.2)(conv) | ||
148 | + return conv, mask | ||
149 | + | ||
150 | + d_conv9, d_mask9 = decoder_layer(e_conv8, e_mask8, e_conv7, e_mask7, 512, 3) | ||
151 | + d_conv10, d_mask10 = decoder_layer(d_conv9, d_mask9, e_conv6, e_mask6, 512, 3) | ||
152 | + d_conv11, d_mask11 = decoder_layer(d_conv10, d_mask10, e_conv5, e_mask5, 512, 3) | ||
153 | + d_conv12, d_mask12 = decoder_layer(d_conv11, d_mask11, e_conv4, e_mask4, 512, 3) | ||
154 | + d_conv13, d_mask13 = decoder_layer(d_conv12, d_mask12, e_conv3, e_mask3, 256, 3) | ||
155 | + d_conv14, d_mask14 = decoder_layer(d_conv13, d_mask13, e_conv2, e_mask2, 128, 3) | ||
156 | + d_conv15, d_mask15 = decoder_layer(d_conv14, d_mask14, e_conv1, e_mask1, 64, 3) | ||
157 | + d_conv16, d_mask16 = decoder_layer(d_conv15, d_mask15, inputs_img, inputs_mask, 3, 3, bn=False) | ||
158 | +        outputs = Conv2D(3, 1, activation='sigmoid', name='outputs_img')(d_conv16) | ||
159 | + | ||
160 | + # Setup the model inputs / outputs | ||
161 | + model = Model(inputs=[inputs_img, inputs_mask], outputs=outputs) | ||
162 | + | ||
163 | + return model, inputs_mask | ||
164 | + | ||
165 | + def compile_pconv_unet(self, model, inputs_mask, lr=0.0002): | ||
166 | + model.compile( | ||
167 | + optimizer = Adam(lr=lr), | ||
168 | + loss=self.loss_total(inputs_mask), | ||
169 | + metrics=[self.PSNR] | ||
170 | + ) | ||
171 | + | ||
172 | + def loss_total(self, mask): | ||
173 | + """ | ||
174 | + Creates a loss function which sums all the loss components | ||
175 | + and multiplies by their weights. See paper eq. 7. | ||
176 | + """ | ||
177 | + def loss(y_true, y_pred): | ||
178 | + | ||
179 | + # Compute predicted image with non-hole pixels set to ground truth | ||
180 | + y_comp = mask * y_true + (1-mask) * y_pred | ||
181 | + | ||
182 | + # Compute the vgg features. | ||
183 | + if self.vgg_device: | ||
184 | + with tf.device(self.vgg_device): | ||
185 | + vgg_out = self.vgg(y_pred) | ||
186 | + vgg_gt = self.vgg(y_true) | ||
187 | + vgg_comp = self.vgg(y_comp) | ||
188 | + else: | ||
189 | + vgg_out = self.vgg(y_pred) | ||
190 | + vgg_gt = self.vgg(y_true) | ||
191 | + vgg_comp = self.vgg(y_comp) | ||
192 | + | ||
193 | + # Compute loss components | ||
194 | + l1 = self.loss_valid(mask, y_true, y_pred) | ||
195 | + l2 = self.loss_hole(mask, y_true, y_pred) | ||
196 | + l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp) | ||
197 | + l4 = self.loss_style(vgg_out, vgg_gt) | ||
198 | + l5 = self.loss_style(vgg_comp, vgg_gt) | ||
199 | + l6 = self.loss_tv(mask, y_comp) | ||
200 | + | ||
201 | + # Return loss function | ||
202 | + return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6 | ||
203 | + | ||
204 | + return loss | ||
205 | + | ||
206 | + def loss_hole(self, mask, y_true, y_pred): | ||
207 | + """Pixel L1 loss within the hole / mask""" | ||
208 | + return self.l1((1-mask) * y_true, (1-mask) * y_pred) | ||
209 | + | ||
210 | + def loss_valid(self, mask, y_true, y_pred): | ||
211 | + """Pixel L1 loss outside the hole / mask""" | ||
212 | + return self.l1(mask * y_true, mask * y_pred) | ||
213 | + | ||
214 | + def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp): | ||
215 | +        """Perceptual loss based on VGG16, see eq. 3 in the paper""" | ||
216 | + loss = 0 | ||
217 | + for o, c, g in zip(vgg_out, vgg_comp, vgg_gt): | ||
218 | + loss += self.l1(o, g) + self.l1(c, g) | ||
219 | + return loss | ||
220 | + | ||
221 | + def loss_style(self, output, vgg_gt): | ||
222 | + """Style loss based on output/computation, used for both eq. 4 & 5 in paper""" | ||
223 | + loss = 0 | ||
224 | + for o, g in zip(output, vgg_gt): | ||
225 | + loss += self.l1(self.gram_matrix(o), self.gram_matrix(g)) | ||
226 | + return loss | ||
227 | + | ||
228 | + def loss_tv(self, mask, y_comp): | ||
229 | +        """Total variation loss, used for smoothing the hole region, see eq. 6""" | ||
230 | + | ||
231 | + # Create dilated hole region using a 3x3 kernel of all 1s. | ||
232 | + kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3])) | ||
233 | + dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same') | ||
234 | + | ||
235 | + # Cast values to be [0., 1.], and compute dilated hole region of y_comp | ||
236 | + dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32') | ||
237 | + P = dilated_mask * y_comp | ||
238 | + | ||
239 | + # Calculate total variation loss | ||
240 | + a = self.l1(P[:,1:,:,:], P[:,:-1,:,:]) | ||
241 | + b = self.l1(P[:,:,1:,:], P[:,:,:-1,:]) | ||
242 | + return a+b | ||
243 | + | ||
244 | + def fit_generator(self, generator, *args, **kwargs): | ||
245 | + """Fit the U-Net to a (images, targets) generator | ||
246 | + | ||
247 | + Args: | ||
248 | + generator (generator): generator supplying input image & mask, as well as targets. | ||
249 | + *args: arguments to be passed to fit_generator | ||
250 | + **kwargs: keyword arguments to be passed to fit_generator | ||
251 | + """ | ||
252 | + self.model.fit_generator( | ||
253 | + generator, | ||
254 | + *args, **kwargs | ||
255 | + ) | ||
256 | + | ||
257 | + def summary(self): | ||
258 | + """Get summary of the UNet model""" | ||
259 | + print(self.model.summary()) | ||
260 | + | ||
261 | + def load(self, filepath, train_bn=True, lr=0.0002): | ||
262 | + | ||
263 | + # Create UNet-like model | ||
264 | + self.model, inputs_mask = self.build_pconv_unet(train_bn) | ||
265 | + self.compile_pconv_unet(self.model, inputs_mask, lr) | ||
266 | + | ||
267 | + # Load weights into model | ||
268 | + epoch = int(os.path.basename(filepath).split('.')[1].split('-')[0]) | ||
269 | + assert epoch > 0, "Could not parse weight file. Should include the epoch" | ||
270 | + self.current_epoch = epoch | ||
271 | + self.model.load_weights(filepath) | ||
272 | + | ||
273 | + @staticmethod | ||
274 | + def PSNR(y_true, y_pred): | ||
275 | + """ | ||
276 | +        PSNR is Peak Signal to Noise Ratio, see https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | ||
277 | + The equation is: | ||
278 | + PSNR = 20 * log10(MAX_I) - 10 * log10(MSE) | ||
279 | + | ||
280 | +        Our input is scaled to be within the range -2.11 to 2.64 (ImageNet value scaling). We use the difference between these | ||
281 | +        two values (4.75) as MAX_I | ||
282 | + """ | ||
283 | + #return 20 * K.log(4.75) / K.log(10.0) - 10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0) | ||
284 | + return - 10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0) | ||
285 | + | ||
286 | + @staticmethod | ||
287 | + def current_timestamp(): | ||
288 | + return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') | ||
289 | + | ||
290 | + @staticmethod | ||
291 | + def l1(y_true, y_pred): | ||
292 | + """Calculate the L1 loss used in all loss calculations""" | ||
293 | + if K.ndim(y_true) == 4: | ||
294 | + return K.mean(K.abs(y_pred - y_true), axis=[1,2,3]) | ||
295 | + elif K.ndim(y_true) == 3: | ||
296 | + return K.mean(K.abs(y_pred - y_true), axis=[1,2]) | ||
297 | + else: | ||
298 | + raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network") | ||
299 | + | ||
300 | + @staticmethod | ||
301 | + def gram_matrix(x, norm_by_channels=False): | ||
302 | + """Calculate gram matrix used in style loss""" | ||
303 | + | ||
304 | + # Assertions on input | ||
305 | + assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor' | ||
306 | + assert K.image_data_format() == 'channels_last', "Please use channels-last format" | ||
307 | + | ||
308 | + # Permute channels and get resulting shape | ||
309 | + x = K.permute_dimensions(x, (0, 3, 1, 2)) | ||
310 | + shape = K.shape(x) | ||
311 | + B, C, H, W = shape[0], shape[1], shape[2], shape[3] | ||
312 | + | ||
313 | + # Reshape x and do batch dot product | ||
314 | + features = K.reshape(x, K.stack([B, C, H*W])) | ||
315 | + gram = K.batch_dot(features, features, axes=2) | ||
316 | + | ||
317 | + # Normalize with channels, height and width | ||
318 | + gram = gram / K.cast(C * H * W, x.dtype) | ||
319 | + | ||
320 | + return gram | ||
321 | + | ||
322 | + # Prediction functions | ||
323 | + ###################### | ||
324 | + def predict(self, sample, **kwargs): | ||
325 | + """Run prediction using this model""" | ||
326 | + return self.model.predict(sample, **kwargs) |
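For reference, the weighted sum returned by `loss_total` is eq. 7 of the partial-convolutions paper (Liu et al., 2018), with the code's `l1`..`l6` mapping onto the paper's named terms:

```
L_total = L_valid + 6*L_hole + 0.05*L_perceptual + 120*(L_style_out + L_style_comp) + 0.1*L_tv
           (l1)      (l2)       (l3)                   (l4)           (l5)             (l6)
```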
Code/pConv-Keras/libs/util.py
0 → 100644
1 | +import os | ||
2 | +from random import randint, seed | ||
3 | +import itertools | ||
4 | +import numpy as np | ||
5 | +import cv2 | ||
6 | + | ||
7 | + | ||
8 | +class MaskGenerator(): | ||
9 | + | ||
10 | + def __init__(self, height, width, channels=3, rand_seed=None, filepath=None): | ||
11 | + """Convenience functions for generating masks to be used for inpainting training | ||
12 | + | ||
13 | + Arguments: | ||
14 | + height {int} -- Mask height | ||
15 | + width {width} -- Mask width | ||
16 | + | ||
17 | + Keyword Arguments: | ||
18 | + channels {int} -- Channels to output (default: {3}) | ||
19 | + rand_seed {[type]} -- Random seed (default: {None}) | ||
20 | + filepath {[type]} -- Load masks from filepath. If None, generate masks with OpenCV (default: {None}) | ||
21 | + """ | ||
22 | + | ||
23 | + self.height = height | ||
24 | + self.width = width | ||
25 | + self.channels = channels | ||
26 | + self.filepath = filepath | ||
27 | + | ||
28 | + # If filepath supplied, load the list of masks within the directory | ||
29 | + self.mask_files = [] | ||
30 | + if self.filepath: | ||
31 | +            filenames = os.listdir(self.filepath) | ||
32 | + self.mask_files = [f for f in filenames if any(filetype in f.lower() for filetype in ['.jpeg', '.png', '.jpg'])] | ||
33 | + print(">> Found {} masks in {}".format(len(self.mask_files), self.filepath)) | ||
34 | + | ||
35 | + # Seed for reproducibility | ||
36 | + if rand_seed: | ||
37 | + seed(rand_seed) | ||
38 | + | ||
39 | + def _generate_mask(self): | ||
40 | +        """Generates a random irregular mask with lines, circles and ellipses""" | ||
41 | + | ||
42 | + img = np.zeros((self.height, self.width, self.channels), np.uint8) | ||
43 | + | ||
44 | + # Set size scale | ||
45 | + size = int((self.width + self.height) * 0.03) | ||
46 | + if self.width < 64 or self.height < 64: | ||
47 | + raise Exception("Width and Height of mask must be at least 64!") | ||
48 | + | ||
49 | + # Draw random lines | ||
50 | + for _ in range(randint(1, 20)): | ||
51 | + x1, x2 = randint(1, self.width), randint(1, self.width) | ||
52 | + y1, y2 = randint(1, self.height), randint(1, self.height) | ||
53 | + thickness = randint(3, size) | ||
54 | + cv2.line(img,(x1,y1),(x2,y2),(1,1,1),thickness) | ||
55 | + | ||
56 | + # Draw random circles | ||
57 | + for _ in range(randint(1, 20)): | ||
58 | + x1, y1 = randint(1, self.width), randint(1, self.height) | ||
59 | + radius = randint(3, size) | ||
60 | + cv2.circle(img,(x1,y1),radius,(1,1,1), -1) | ||
61 | + | ||
62 | + # Draw random ellipses | ||
63 | + for _ in range(randint(1, 20)): | ||
64 | + x1, y1 = randint(1, self.width), randint(1, self.height) | ||
65 | + s1, s2 = randint(1, self.width), randint(1, self.height) | ||
66 | + a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180) | ||
67 | + thickness = randint(3, size) | ||
68 | + cv2.ellipse(img, (x1,y1), (s1,s2), a1, a2, a3,(1,1,1), thickness) | ||
69 | + | ||
70 | + return 1-img | ||
71 | + | ||
72 | + def _load_mask(self, rotation=True, dilation=True, cropping=True): | ||
73 | + """Loads a mask from disk, and optionally augments it""" | ||
74 | + | ||
75 | + # Read image | ||
76 | + mask = cv2.imread(os.path.join(self.filepath, np.random.choice(self.mask_files, 1, replace=False)[0])) | ||
77 | + | ||
78 | + # Random rotation | ||
79 | + if rotation: | ||
80 | + rand = np.random.randint(-180, 180) | ||
81 | + M = cv2.getRotationMatrix2D((mask.shape[1]/2, mask.shape[0]/2), rand, 1.5) | ||
82 | + mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0])) | ||
83 | + | ||
84 | +        # Random dilation of the hole region (eroding the mask enlarges the holes) | ||
85 | + if dilation: | ||
86 | + rand = np.random.randint(5, 47) | ||
87 | + kernel = np.ones((rand, rand), np.uint8) | ||
88 | + mask = cv2.erode(mask, kernel, iterations=1) | ||
89 | + | ||
90 | + # Random cropping | ||
91 | + if cropping: | ||
92 | + x = np.random.randint(0, mask.shape[1] - self.width) | ||
93 | + y = np.random.randint(0, mask.shape[0] - self.height) | ||
94 | + mask = mask[y:y+self.height, x:x+self.width] | ||
95 | + | ||
96 | + return (mask > 1).astype(np.uint8) | ||
97 | + | ||
98 | + def sample(self, random_seed=None): | ||
99 | + """Retrieve a random mask""" | ||
100 | + if random_seed: | ||
101 | + seed(random_seed) | ||
102 | + if self.filepath and len(self.mask_files) > 0: | ||
103 | + return self._load_mask() | ||
104 | + else: | ||
105 | + return self._generate_mask() | ||
106 | + | ||
107 | + | ||
108 | +class ImageChunker(object): | ||
109 | + | ||
110 | + def __init__(self, rows, cols, overlap): | ||
111 | + self.rows = rows | ||
112 | + self.cols = cols | ||
113 | + self.overlap = overlap | ||
114 | + | ||
115 | + def perform_chunking(self, img_size, chunk_size): | ||
116 | + """ | ||
117 | + Given an image dimension img_size, return list of (start, stop) | ||
118 | + tuples to perform chunking of chunk_size | ||
119 | + """ | ||
120 | + chunks, i = [], 0 | ||
121 | + while True: | ||
122 | + chunks.append((i*(chunk_size - self.overlap/2), i*(chunk_size - self.overlap/2)+chunk_size)) | ||
123 | + i+=1 | ||
124 | + if chunks[-1][1] > img_size: | ||
125 | + break | ||
126 | + n_count = len(chunks) | ||
127 | + chunks[-1] = tuple(x - (n_count*chunk_size - img_size - (n_count-1)*self.overlap/2) for x in chunks[-1]) | ||
128 | + chunks = [(int(x), int(y)) for x, y in chunks] | ||
129 | + return chunks | ||
130 | + | ||
131 | + def get_chunks(self, img, scale=1): | ||
132 | + """ | ||
133 | + Get width and height lists of (start, stop) tuples for chunking of img. | ||
134 | + """ | ||
135 | + x_chunks, y_chunks = [(0, self.rows)], [(0, self.cols)] | ||
136 | + if img.shape[0] > self.rows: | ||
137 | + x_chunks = self.perform_chunking(img.shape[0], self.rows) | ||
138 | + else: | ||
139 | + x_chunks = [(0, img.shape[0])] | ||
140 | + if img.shape[1] > self.cols: | ||
141 | + y_chunks = self.perform_chunking(img.shape[1], self.cols) | ||
142 | + else: | ||
143 | + y_chunks = [(0, img.shape[1])] | ||
144 | + return x_chunks, y_chunks | ||
145 | + | ||
146 | + def dimension_preprocess(self, img, padding=True): | ||
147 | + """ | ||
148 | +        When predicting on an image of a different size than 512x512, | ||
149 | +        this function is used to add padding and chunk the image into pieces | ||
150 | +        of 512x512, which can later be reconstructed into the original image | ||
151 | +        using the dimension_postprocess() function. | ||
152 | + """ | ||
153 | + | ||
154 | + # Assert single image input | ||
155 | + assert len(img.shape) == 3, "Image dimension expected to be (H, W, C)" | ||
156 | + | ||
157 | + # Check if we are adding padding for too small images | ||
158 | + if padding: | ||
159 | + | ||
160 | + # Check if height is too small | ||
161 | + if img.shape[0] < self.rows: | ||
162 | + padding = np.ones((self.rows - img.shape[0], img.shape[1], img.shape[2])) | ||
163 | + img = np.concatenate((img, padding), axis=0) | ||
164 | + | ||
165 | + # Check if width is too small | ||
166 | + if img.shape[1] < self.cols: | ||
167 | + padding = np.ones((img.shape[0], self.cols - img.shape[1], img.shape[2])) | ||
168 | + img = np.concatenate((img, padding), axis=1) | ||
169 | + | ||
170 | + # Get chunking of the image | ||
171 | + x_chunks, y_chunks = self.get_chunks(img) | ||
172 | + | ||
173 | + # Chunk up the image | ||
174 | + images = [] | ||
175 | + for x in x_chunks: | ||
176 | + for y in y_chunks: | ||
177 | + images.append( | ||
178 | + img[x[0]:x[1], y[0]:y[1], :] | ||
179 | + ) | ||
180 | + images = np.array(images) | ||
181 | + return images | ||
182 | + | ||
183 | + def dimension_postprocess(self, chunked_images, original_image, scale=1, padding=True): | ||
184 | + """ | ||
185 | +        When predicting on an image of a different size than 512x512, | ||
186 | +        dimension_preprocess() is used to add padding and chunk the image | ||
187 | +        into pieces of 512x512; this function reconstructs those pieces | ||
188 | +        into the original image. | ||
189 | + """ | ||
190 | + | ||
191 | + # Assert input dimensions | ||
192 | + assert len(original_image.shape) == 3, "Image dimension expected to be (H, W, C)" | ||
193 | + assert len(chunked_images.shape) == 4, "Chunked images dimension expected to be (B, H, W, C)" | ||
194 | + | ||
195 | + # Check if we are adding padding for too small images | ||
196 | + if padding: | ||
197 | + | ||
198 | + # Check if height is too small | ||
199 | + if original_image.shape[0] < self.rows: | ||
200 | + new_images = [] | ||
201 | + for img in chunked_images: | ||
202 | + new_images.append(img[0:scale*original_image.shape[0], :, :]) | ||
203 | + chunked_images = np.array(new_images) | ||
204 | + | ||
205 | + # Check if width is too small | ||
206 | + if original_image.shape[1] < self.cols: | ||
207 | + new_images = [] | ||
208 | + for img in chunked_images: | ||
209 | + new_images.append(img[:, 0:scale*original_image.shape[1], :]) | ||
210 | + chunked_images = np.array(new_images) | ||
211 | + | ||
212 | + # Put reconstruction into this array | ||
213 | + new_shape = ( | ||
214 | + original_image.shape[0]*scale, | ||
215 | + original_image.shape[1]*scale, | ||
216 | + original_image.shape[2] | ||
217 | + ) | ||
218 | + reconstruction = np.zeros(new_shape) | ||
219 | + | ||
220 | + # Get the chunks for this image | ||
221 | + x_chunks, y_chunks = self.get_chunks(original_image) | ||
222 | + | ||
223 | + i = 0 | ||
224 | + s = scale | ||
225 | + for x in x_chunks: | ||
226 | + for y in y_chunks: | ||
227 | + | ||
228 | + prior_fill = reconstruction != 0 | ||
229 | + chunk = np.zeros(new_shape) | ||
230 | + chunk[x[0]*s:x[1]*s, y[0]*s:y[1]*s, :] += chunked_images[i] | ||
231 | + chunk_fill = chunk != 0 | ||
232 | + | ||
233 | + reconstruction += chunk | ||
234 | + reconstruction[prior_fill & chunk_fill] = reconstruction[prior_fill & chunk_fill] / 2 | ||
235 | + | ||
236 | + i += 1 | ||
237 | + | ||
238 | + return reconstruction | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
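A short usage sketch of `MaskGenerator` in its OpenCV mode (no `filepath` supplied), showing the convention used throughout the repo that 1 marks valid pixels and 0 marks holes:

```python
from libs.util import MaskGenerator

# Draw a random irregular mask; rand_seed makes the result reproducible
gen = MaskGenerator(512, 512, channels=3, rand_seed=42)
mask = gen.sample()                        # uint8 array of 0s (holes) and 1s (valid)
print(mask.shape, mask.min(), mask.max())  # (512, 512, 3) 0 1
```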
Code/pConv-Keras/main.py
0 → 100644
1 | +import os | ||
2 | +import gc | ||
3 | +import datetime | ||
4 | +import numpy as np | ||
5 | +import pandas as pd | ||
6 | +import cv2 | ||
7 | + | ||
8 | +from argparse import ArgumentParser | ||
9 | +from copy import deepcopy | ||
10 | +from tqdm import tqdm | ||
11 | + | ||
12 | +from keras.preprocessing.image import ImageDataGenerator | ||
13 | +from keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback | ||
14 | +from keras import backend as K | ||
15 | +from keras.utils import Sequence | ||
16 | +from keras_tqdm import TQDMCallback | ||
17 | + | ||
18 | +import matplotlib.pyplot as plt | ||
19 | +from matplotlib.ticker import NullFormatter | ||
20 | + | ||
21 | +from libs.pconv_model import PConvUnet | ||
22 | +from libs.util import MaskGenerator | ||
23 | + | ||
24 | + | ||
25 | +# Sample call | ||
26 | +r""" | ||
27 | +# Train on CelebaHQ | ||
28 | +python main.py --name CelebHQ --train C:\Documents\Kaggle\celebaHQ-512\train\ --validation C:\Documents\Kaggle\celebaHQ-512\val\ --test C:\Documents\Kaggle\celebaHQ-512\test\ --checkpoint "C:\Users\Mathias Felix Gruber\Documents\GitHub\PConv-Keras\data\logs\imagenet_phase1_paperMasks\weights.35-0.70.h5" | ||
29 | +""" | ||
30 | + | ||
31 | + | ||
32 | +def parse_args(): | ||
33 | + parser = ArgumentParser(description="Training script for PConv inpainting") | ||
34 | + | ||
35 | + parser.add_argument( | ||
36 | + "-stage", | ||
37 | + "--stage", | ||
38 | + type=str, | ||
39 | + default="train", | ||
40 | + help="Which stage of training to run", | ||
41 | + choices=["train", "finetune"], | ||
42 | + ) | ||
43 | + | ||
44 | + parser.add_argument( | ||
45 | + "-train", "--train", type=str, help="Folder with training images" | ||
46 | + ) | ||
47 | + | ||
48 | + parser.add_argument( | ||
49 | + "-validation", "--validation", type=str, help="Folder with validation images" | ||
50 | + ) | ||
51 | + | ||
52 | + parser.add_argument("-test", "--test", type=str, help="Folder with testing images") | ||
53 | + | ||
54 | + parser.add_argument( | ||
55 | + "-name", | ||
56 | + "--name", | ||
57 | + type=str, | ||
58 | + default="myDataset", | ||
59 | + help="Dataset name, e.g. 'imagenet'", | ||
60 | + ) | ||
61 | + | ||
62 | + parser.add_argument( | ||
63 | + "-batch_size", | ||
64 | + "--batch_size", | ||
65 | + type=int, | ||
66 | + default=4, | ||
67 | + help="What batch-size should we use", | ||
68 | + ) | ||
69 | + | ||
70 | + parser.add_argument( | ||
71 | + "-test_path", | ||
72 | + "--test_path", | ||
73 | + type=str, | ||
74 | + default="./data/test_samples/", | ||
75 | + help="Where to output test images during training", | ||
76 | + ) | ||
77 | + | ||
78 | + parser.add_argument( | ||
79 | + "-weight_path", | ||
80 | + "--weight_path", | ||
81 | + type=str, | ||
82 | + default="./data/logs/", | ||
83 | + help="Where to output weights during training", | ||
84 | + ) | ||
85 | + | ||
86 | + parser.add_argument( | ||
87 | + "-log_path", | ||
88 | + "--log_path", | ||
89 | + type=str, | ||
90 | + default="./data/logs/", | ||
91 | + help="Where to output tensorboard logs during training", | ||
92 | + ) | ||
93 | + | ||
94 | + parser.add_argument( | ||
95 | + "-vgg_path", | ||
96 | + "--vgg_path", | ||
97 | + type=str, | ||
98 | + default="./data/logs/pytorch_to_keras_vgg16.h5", | ||
99 | + help="VGG16 weights trained on PyTorch with pixel scaling 1/255.", | ||
100 | + ) | ||
101 | + | ||
102 | + parser.add_argument( | ||
103 | + "-checkpoint", | ||
104 | + "--checkpoint", | ||
105 | + type=str, | ||
106 | + help="Previous weights to be loaded onto model", | ||
107 | + ) | ||
108 | + | ||
109 | + return parser.parse_args() | ||
110 | + | ||
111 | + | ||
112 | +class AugmentingDataGenerator(ImageDataGenerator): | ||
113 | + """Wrapper for ImageDataGenerator to return mask & image""" | ||
114 | + | ||
115 | + def flow_from_directory(self, directory, mask_generator, *args, **kwargs): | ||
116 | + generator = super().flow_from_directory( | ||
117 | + directory, class_mode=None, *args, **kwargs | ||
118 | + ) | ||
119 | + seed = None if "seed" not in kwargs else kwargs["seed"] | ||
120 | + while True: | ||
121 | + | ||
122 | +            # Get augmented image samples | ||
123 | + ori = next(generator) | ||
124 | + | ||
125 | + # Get masks for each image sample | ||
126 | + mask = np.stack( | ||
127 | + [mask_generator.sample(seed) for _ in range(ori.shape[0])], axis=0 | ||
128 | + ) | ||
129 | + | ||
130 | +            # Apply masks to all image samples | ||
131 | + masked = deepcopy(ori) | ||
132 | + masked[mask == 0] = 1 | ||
133 | + | ||
134 | +            # Yield ([masked, mask], ori) training batches | ||
135 | + # print(masked.shape, ori.shape) | ||
136 | + gc.collect() | ||
137 | + yield [masked, mask], ori | ||
138 | + | ||
139 | + | ||
140 | +# Run script | ||
141 | +if __name__ == "__main__": | ||
142 | + | ||
143 | + # Parse command-line arguments | ||
144 | + args = parse_args() | ||
145 | + | ||
146 | + if args.stage == "finetune" and not args.checkpoint: | ||
147 | + raise AttributeError( | ||
148 | + "If you are finetuning your model, you must supply a checkpoint file" | ||
149 | + ) | ||
150 | + | ||
151 | + # Create training generator | ||
152 | + train_datagen = AugmentingDataGenerator( | ||
153 | + rotation_range=10, | ||
154 | + width_shift_range=0.1, | ||
155 | + height_shift_range=0.1, | ||
156 | + rescale=1.0 / 255, | ||
157 | + horizontal_flip=True, | ||
158 | + ) | ||
159 | + train_generator = train_datagen.flow_from_directory( | ||
160 | + args.train, | ||
161 | + MaskGenerator(512, 512, 3), | ||
162 | + target_size=(512, 512), | ||
163 | + batch_size=args.batch_size, | ||
164 | + ) | ||
165 | + | ||
166 | + # Create validation generator | ||
167 | + val_datagen = AugmentingDataGenerator(rescale=1.0 / 255) | ||
168 | + val_generator = val_datagen.flow_from_directory( | ||
169 | + args.validation, | ||
170 | + MaskGenerator(512, 512, 3), | ||
171 | + target_size=(512, 512), | ||
172 | + batch_size=args.batch_size, | ||
173 | + classes=["val"], | ||
174 | + seed=42, | ||
175 | + ) | ||
176 | + | ||
177 | + # Create testing generator | ||
178 | + test_datagen = AugmentingDataGenerator(rescale=1.0 / 255) | ||
179 | + test_generator = test_datagen.flow_from_directory( | ||
180 | + args.test, | ||
181 | + MaskGenerator(512, 512, 3), | ||
182 | + target_size=(512, 512), | ||
183 | + batch_size=args.batch_size, | ||
184 | + seed=42, | ||
185 | + ) | ||
186 | + | ||
187 | +    # Pick out an example to be sent to the test samples folder | ||
188 | + test_data = next(test_generator) | ||
189 | + (masked, mask), ori = test_data | ||
190 | + | ||
191 | + def plot_callback(model, path): | ||
192 | + """Called at the end of each epoch, displaying our previous test images, | ||
193 | + as well as their masked predictions and saving them to disk""" | ||
194 | + | ||
195 | + # Get samples & Display them | ||
196 | + pred_img = model.predict([masked, mask]) | ||
197 | + pred_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") | ||
198 | + | ||
199 | + # Clear current output and display test images | ||
200 | + for i in range(len(ori)): | ||
201 | + _, axes = plt.subplots(1, 3, figsize=(20, 5)) | ||
202 | + axes[0].imshow(masked[i, :, :, :]) | ||
203 | + axes[1].imshow(pred_img[i, :, :, :] * 1.0) | ||
204 | + axes[2].imshow(ori[i, :, :, :]) | ||
205 | + axes[0].set_title("Masked Image") | ||
206 | + axes[1].set_title("Predicted Image") | ||
207 | + axes[2].set_title("Original Image") | ||
208 | + | ||
209 | +            plt.savefig(os.path.join(path, "img_{}_{}.png".format(i, pred_time)))  # a leading '/' would make os.path.join discard 'path' | ||
210 | + plt.close() | ||
211 | + | ||
212 | + # Load the model | ||
213 | + if args.vgg_path: | ||
214 | + model = PConvUnet(vgg_weights=args.vgg_path) | ||
215 | + else: | ||
216 | + model = PConvUnet() | ||
217 | + | ||
218 | + # Loading of checkpoint | ||
219 | + if args.checkpoint: | ||
220 | + if args.stage == "train": | ||
221 | + model.load(args.checkpoint) | ||
222 | + elif args.stage == "finetune": | ||
223 | + model.load(args.checkpoint, train_bn=False, lr=0.00005) | ||
224 | + | ||
225 | + # Fit model | ||
226 | + model.fit_generator( | ||
227 | + train_generator, | ||
228 | + steps_per_epoch=10000, | ||
229 | + validation_data=val_generator, | ||
230 | + validation_steps=1000, | ||
231 | + epochs=100, | ||
232 | + verbose=0, | ||
233 | + callbacks=[ | ||
234 | + TensorBoard( | ||
235 | + log_dir=os.path.join(args.log_path, args.name + "_phase1"), | ||
236 | + write_graph=False, | ||
237 | + ), | ||
238 | + ModelCheckpoint( | ||
239 | + os.path.join( | ||
240 | + args.log_path, | ||
241 | + args.name + "_phase1", | ||
242 | + "weights.{epoch:02d}-{loss:.2f}.h5", | ||
243 | + ), | ||
244 | + monitor="val_loss", | ||
245 | + save_best_only=True, | ||
246 | + save_weights_only=True, | ||
247 | + ), | ||
248 | + LambdaCallback( | ||
249 | + on_epoch_end=lambda epoch, logs: plot_callback(model, args.test_path) | ||
250 | + ), | ||
251 | + TQDMCallback(), | ||
252 | + ], | ||
253 | + ) | ||
254 | + |
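A hypothetical end-to-end inference sketch for images larger than 512x512, combining `PConvUnet` with `ImageChunker` from libs/util.py. The checkpoint file name follows the `weights.{epoch:02d}-{loss:.2f}.h5` pattern used by the `ModelCheckpoint` above but is otherwise made up, and `overlap=30` mirrors the `img_overlap` attribute set in `PConvUnet.__init__`:

```python
import numpy as np
from libs.pconv_model import PConvUnet
from libs.util import ImageChunker

# Hypothetical checkpoint path; the epoch is parsed from the file name by load()
model = PConvUnet(vgg_weights=None, inference_only=True)
model.load("data/logs/myDataset_phase1/weights.10-1.23.h5", train_bn=False)

big_img = np.random.rand(700, 900, 3)    # stand-in for a real image in [0, 1]
big_mask = np.ones_like(big_img)
big_mask[300:400, 400:600, :] = 0        # hole region
big_img[big_mask == 0] = 1               # paint holes white, as in training

# Chunk into overlapping 512x512 tiles, predict, then stitch back together
chunker = ImageChunker(512, 512, overlap=30)
chunked_imgs = chunker.dimension_preprocess(big_img)    # (N, 512, 512, 3)
chunked_masks = chunker.dimension_preprocess(big_mask)
pred = model.predict([chunked_imgs, chunked_masks])
result = chunker.dimension_postprocess(pred, big_img)   # back to (700, 900, 3)
```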
This diff could not be displayed because it is too large.
Code/pConv-Keras/requirements.txt
0 → 100644
1 | +h5py==2.8.0 | ||
2 | +Keras==2.2.4 | ||
3 | +Keras-Applications==1.0.6 | ||
4 | +Keras-Preprocessing==1.0.5 | ||
5 | +keras-tqdm==2.0.1 | ||
6 | +matplotlib==3.0.2 | ||
7 | +numpy==1.15.4 | ||
8 | +pandas==0.23.4 | ||
9 | +scipy==1.1.0 | ||
10 | +seaborn==0.9.0 | ||
11 | +tables==3.4.4 | ||
12 | +tensorboard==1.12.2 | ||
13 | +tensorflow==1.12.0 | ||
14 | +tqdm==4.28.1 | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
Docs/주간보고서/10주차_주간보고서_0507.docx
0 → 100644
No preview for this file type
Docs/주간보고서/11주차_주간보고서_0514.docx
0 → 100644
No preview for this file type
Docs/주간보고서/12주차_주간보고서_0521.docx
0 → 100644
No preview for this file type
Docs/주간보고서/14주차_주간보고서_0604.docx
0 → 100644
No preview for this file type
Docs/주간보고서/8주차_주간보고서_0423.docx
0 → 100644
No preview for this file type
Docs/최종보고서/최종보고서_2015104198_이민호.pdf
0 → 100644
No preview for this file type
Docs/최종보고서/캡스톤디자인2_최종발표_이민호.pptx
0 → 100644
No preview for this file type