이민호

upload report, codes

Showing 156 changed files with 5303 additions and 0 deletions
1 +*.pyc
2 +__pycache__/
3 +/data/
\ No newline at end of file
1 +from __future__ import absolute_import
2 +from __future__ import division
3 +from __future__ import print_function
4 +
5 +import argparse
6 +from functools import partial
7 +from multiprocessing import Pool
8 +import os
9 +import re
10 +
11 +import cropper
12 +import numpy as np
13 +import tqdm
14 +
15 +
16 +# ==============================================================================
17 +# = param =
18 +# ==============================================================================
19 +
20 +parser = argparse.ArgumentParser()
21 +# main
22 +parser.add_argument('--img_dir', dest='img_dir', default='./data/img_celeba')
23 +parser.add_argument('--save_dir', dest='save_dir', default='./data/aligned')
24 +parser.add_argument('--landmark_file', dest='landmark_file', default='./data/landmark.txt')
25 +parser.add_argument('--standard_landmark_file', dest='standard_landmark_file', default='./data/standard_landmark_68pts.txt')
26 +parser.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572)
27 +parser.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572)
28 +parser.add_argument('--move_h', dest='move_h', type=float, default=0.25)
29 +parser.add_argument('--move_w', dest='move_w', type=float, default=0.)
30 +parser.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg')
31 +parser.add_argument('--n_worker', dest='n_worker', type=int, default=8)
32 +# others
33 +parser.add_argument('--face_factor', dest='face_factor', type=float, help='The factor of face area relative to the output image.', default=0.45)
34 +parser.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'], default='similarity')
35 +parser.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5], help='The order of interpolation.', default=3)
36 +parser.add_argument('--mode', dest='mode', choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge')
37 +args = parser.parse_args()
38 +
39 +
40 +# ==============================================================================
41 +# = opencv first =
42 +# ==============================================================================
43 +
44 +_DEFAULT_JPG_QUALITY = 95
45 +try:
46 +    import cv2
47 +    imread = cv2.imread
48 +    imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEFAULT_JPG_QUALITY])
49 +    align_crop = cropper.align_crop_opencv
50 +    print('Use OpenCV')
51 +except ImportError:
52 +    import skimage.io as io
53 +    imread = io.imread
54 +    imwrite = partial(io.imsave, quality=_DEFAULT_JPG_QUALITY)
55 +    align_crop = cropper.align_crop_skimage
56 +    print('Importing OpenCV failed; using scikit-image')
57 +
58 +
59 +# ==============================================================================
60 +# = run =
61 +# ==============================================================================
62 +
63 +# count landmarks
64 +with open(args.landmark_file) as f:
65 + line = f.readline()
66 +n_landmark = len(re.split('[ ]+', line)[1:]) // 2
67 +
68 +# load standard landmark
69 +standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=float).reshape(n_landmark, 2)
70 +standard_landmark[:, 0] += args.move_w
71 +standard_landmark[:, 1] += args.move_h
72 +
73 +# data dir
74 +save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format))
75 +data_dir = os.path.join(save_dir, 'data')
76 +if not os.path.isdir(data_dir):
77 + os.makedirs(data_dir)
78 +
79 +
80 +def work(name, landmark) -> str: # a single work
81 + for _ in range(3): # try three times
82 + try:
83 + img = imread(os.path.join(args.img_dir, name))
84 + img_crop, tformed_landmarks = align_crop(img,
85 + landmark,
86 + standard_landmark,
87 + crop_size=(args.crop_size_h, args.crop_size_w),
88 + face_factor=args.face_factor,
89 + align_type=args.align_type,
90 + order=args.order,
91 + mode=args.mode)
92 +
93 + name = os.path.splitext(name)[0] + '.' + args.save_format
94 + path = os.path.join(data_dir, name)
95 + if not os.path.isdir(os.path.split(path)[0]):
96 + os.makedirs(os.path.split(path)[0])
97 + imwrite(path, img_crop)
98 +
99 + tformed_landmarks.shape = -1
100 + name_landmark_str = ('%s' + ' %.1f' * n_landmark * 2) % ((name, ) + tuple(tformed_landmarks))
101 + return name_landmark_str
102 + except:
103 + print('%s fails!' % name)
104 +
105 +
106 +if __name__ == "__main__":
107 +    img_names = np.genfromtxt(args.landmark_file, dtype=str, usecols=0)
108 +    landmarks = np.genfromtxt(args.landmark_file, dtype=float,
109 +                              usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2)
110 +
111 + n_pics = len(img_names)
112 +
113 + landmarks_path = os.path.join(save_dir, 'landmark.txt')
114 + f = open(landmarks_path, 'w')
115 + pool = Pool(args.n_worker)
116 + bar = tqdm.tqdm(total=n_pics)
117 +
118 +    tasks = []
119 +    for i in range(n_pics):
120 +        tasks.append(pool.apply_async(work, (img_names[i], landmarks[i]), callback=lambda _: bar.update()))
121 +    for task in tasks:  # collect every result, not just the first one
122 +        try:
123 +            result = task.get()
124 +            if result is not None and result != "":
125 +                f.write(result + '\n')
126 +        except:
127 +            pass
128 +
129 + pool.close()
130 + pool.join()
131 + bar.close()
132 + f.close()
\ No newline at end of file
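For reference, the script above expects a landmark file whose lines are an image name followed by whitespace-separated x/y pairs; a minimal parsing sketch with made-up values (the file name and coordinates below are illustrative):

import re
import numpy as np

# Illustrative landmark.txt line: "<image name> x1 y1 x2 y2 ... xn yn"
sample = "000001.jpg 165.0 184.0 244.0 176.0 196.0 249.0 194.0 271.0 266.0 260.0"

# Count landmarks the same way the script does: split on whitespace, drop the name
n_landmark = len(re.split('[ ]+', sample)[1:]) // 2                       # -> 5
coords = np.array(re.split('[ ]+', sample)[1:], dtype=float).reshape(n_landmark, 2)
print(n_landmark, coords.shape)                                           # 5 (5, 2)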
1 +import numpy as np
2 +
3 +
4 +def align_crop_opencv(img,
5 + src_landmarks,
6 + standard_landmarks,
7 + crop_size=512,
8 + face_factor=0.7,
9 + align_type='similarity',
10 + order=3,
11 + mode='edge'):
12 + """Align and crop a face image by landmarks.
13 +
14 + Arguments:
15 + img : Face image to be aligned and cropped.
16 + src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
17 + standard_landmarks : Standard shape, should be normalized.
18 + crop_size : Output image size, should be 1. int for (crop_size, crop_size)
19 + or 2. (int, int) for (crop_size_h, crop_size_w).
20 + face_factor : The factor of face area relative to the output image.
21 + align_type : 'similarity' or 'affine'.
22 + order : The order of interpolation. The order has to be in the range 0-5:
23 + - 0: INTER_NEAREST
24 + - 1: INTER_LINEAR
25 + - 2: INTER_AREA
26 + - 3: INTER_CUBIC
27 + - 4: INTER_LANCZOS4
28 + - 5: INTER_LANCZOS4
29 + mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
30 + Points outside the boundaries of the input are filled according
31 + to the given mode.
32 + """
33 + # set OpenCV
34 + import cv2
35 + inter = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA,
36 + 3: cv2.INTER_CUBIC, 4: cv2.INTER_LANCZOS4, 5: cv2.INTER_LANCZOS4}
37 + border = {'constant': cv2.BORDER_CONSTANT, 'edge': cv2.BORDER_REPLICATE,
38 + 'symmetric': cv2.BORDER_REFLECT, 'reflect': cv2.BORDER_REFLECT101,
39 + 'wrap': cv2.BORDER_WRAP}
40 +
41 + # check
42 + assert align_type in ['affine', 'similarity'], 'Invalid `align_type`! Allowed: %s!' % ['affine', 'similarity']
43 + assert order in [0, 1, 2, 3, 4, 5], 'Invalid `order`! Allowed: %s!' % [0, 1, 2, 3, 4, 5]
44 + assert mode in ['constant', 'edge', 'symmetric', 'reflect', 'wrap'], 'Invalid `mode`! Allowed: %s!' % ['constant', 'edge', 'symmetric', 'reflect', 'wrap']
45 +
46 + # crop size
47 + if isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:
48 + crop_size_h = crop_size[0]
49 + crop_size_w = crop_size[1]
50 + elif isinstance(crop_size, int):
51 + crop_size_h = crop_size_w = crop_size
52 + else:
53 + raise Exception('Invalid `crop_size`! `crop_size` should be 1. int for (crop_size, crop_size) or 2. (int, int) for (crop_size_h, crop_size_w)!')
54 +
55 + # estimate transform matrix
56 + trg_landmarks = standard_landmarks * max(crop_size_h, crop_size_w) * face_factor + np.array([crop_size_w // 2, crop_size_h // 2])
57 + if align_type == 'affine':
58 +        tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.inf)[0]
59 +    else:
60 +        tform = cv2.estimateAffinePartial2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.inf)[0]
61 +
62 + # warp image by given transform
63 + output_shape = (crop_size_h, crop_size_w)
64 + img_crop = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + inter[order], borderMode=border[mode])
65 +
66 + # get transformed landmarks
67 + tformed_landmarks = cv2.transform(np.expand_dims(src_landmarks, axis=0), cv2.invertAffineTransform(tform))[0]
68 +
69 + return img_crop, tformed_landmarks
70 +
71 +
72 +def align_crop_skimage(img,
73 + src_landmarks,
74 + standard_landmarks,
75 + crop_size=512,
76 + face_factor=0.7,
77 + align_type='similarity',
78 + order=3,
79 + mode='edge'):
80 + """Align and crop a face image by landmarks.
81 +
82 + Arguments:
83 + img : Face image to be aligned and cropped.
84 + src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
85 + standard_landmarks : Standard shape, should be normalized.
86 + crop_size : Output image size, should be 1. int for (crop_size, crop_size)
87 + or 2. (int, int) for (crop_size_h, crop_size_w).
88 + face_factor : The factor of face area relative to the output image.
89 + align_type : 'similarity' or 'affine'.
90 + order : The order of interpolation. The order has to be in the range 0-5:
91 + - 0: INTER_NEAREST
92 + - 1: INTER_LINEAR
93 + - 2: INTER_AREA
94 + - 3: INTER_CUBIC
95 + - 4: INTER_LANCZOS4
96 + - 5: INTER_LANCZOS4
97 + mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
98 + Points outside the boundaries of the input are filled according
99 + to the given mode.
100 + """
101 + raise NotImplementedError("'align_crop_skimage' is not implemented!")
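A minimal usage sketch of align_crop_opencv as defined above; the 5-point landmarks, the normalized standard shape, and the file names are illustrative, and it assumes OpenCV is installed:

import cv2
import numpy as np
from cropper import align_crop_opencv

# Illustrative 5-point source landmarks (eyes, nose tip, mouth corners) in pixel coords
src_landmarks = np.array([[165., 184.], [244., 176.], [196., 249.],
                          [194., 271.], [266., 260.]])
# Normalized standard shape, roughly centered on the origin (illustrative values)
standard_landmarks = np.array([[-0.25, -0.15], [0.25, -0.15], [0., 0.],
                               [-0.15, 0.25], [0.15, 0.25]])

img = cv2.imread('face.jpg')  # illustrative input image
img_crop, tformed = align_crop_opencv(img, src_landmarks, standard_landmarks,
                                      crop_size=(572, 572), face_factor=0.45,
                                      align_type='similarity', order=3, mode='edge')
cv2.imwrite('face_aligned.jpg', img_crop)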
1 +# Auto detect text files and perform LF normalization
2 +* text=auto
1 +*.pyc
2 +docs
3 +data
4 +lfw
5 +lfw_40
6 +.idea
7 +loss
8 +vgg_face_dataset
9 +saved_network
10 +loss
11 +z_detect_face.py
12 +z_main.py
13 +*.npy
14 +*.Lnk
15 +data1
16 +data1_masked
17 +scratch.py
18 +subset
19 +subset_masked
20 +vgg_face_dataset
21 +*.mp4
22 +ML_examples
23 +*.pptx
24 +datasets
25 +*.dat
26 +*.docx
27 +
1 +theme: jekyll-theme-cayman
\ No newline at end of file
1 +# Author: aqeelanwar
2 +# Created: 27 April,2020, 10:22 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +import argparse
6 +import dlib
7 +from utils.aux_functions import *
8 +
9 +
10 +# Command-line input setup
11 +parser = argparse.ArgumentParser(
12 + description="MaskTheFace - Python code to mask faces dataset"
13 +)
14 +parser.add_argument(
15 + "--path",
16 + type=str,
17 + default="",
18 + help="Path to either the folder containing images or the image itself",
19 +)
20 +parser.add_argument(
21 + "--mask_type",
22 + type=str,
23 + default="surgical",
24 + choices=["surgical", "N95", "KN95", "cloth", "gas", "inpaint", "random", "all"],
25 +    help="Type of the mask to be applied. Available options: surgical, N95, KN95, cloth, gas, inpaint, random, all",
26 +)
27 +
28 +parser.add_argument(
29 + "--pattern",
30 + type=str,
31 + default="",
32 + help="Type of the pattern. Available options in masks/textures",
33 +)
34 +
35 +parser.add_argument(
36 + "--pattern_weight",
37 + type=float,
38 + default=0.5,
39 + help="Weight of the pattern. Must be between 0 and 1",
40 +)
41 +
42 +parser.add_argument(
43 + "--color",
44 + type=str,
45 + default="#0473e2",
46 +    help="Hex color value to be overlaid on the mask",
47 +)
48 +
49 +parser.add_argument(
50 + "--color_weight",
51 + type=float,
52 + default=0.5,
53 + help="Weight of the color intensity. Must be between 0 and 1",
54 +)
55 +
56 +parser.add_argument(
57 + "--code",
58 + type=str,
59 + # default="cloth-masks/textures/check/check_4.jpg, cloth-#e54294, cloth-#ff0000, cloth, cloth-masks/textures/others/heart_1.png, cloth-masks/textures/fruits/pineapple.png, N95, surgical_blue, surgical_green",
60 + default="",
61 + help="Generate specific formats",
62 +)
63 +
64 +
65 +parser.add_argument(
66 + "--verbose", dest="verbose", action="store_true", help="Turn verbosity on"
67 +)
68 +parser.add_argument(
69 + "--write_original_image",
70 + dest="write_original_image",
71 + action="store_true",
72 + help="If true, original image is also stored in the masked folder",
73 +)
74 +parser.set_defaults(feature=False)
75 +
76 +args = parser.parse_args()
77 +args.write_path = args.path + "_masked"
78 +
79 +# Set up dlib face detector and predictor
80 +args.detector = dlib.get_frontal_face_detector()
81 +path_to_dlib_model = "dlib_models/shape_predictor_68_face_landmarks.dat"
82 +if not os.path.exists(path_to_dlib_model):
83 + download_dlib_model()
84 +
85 +args.predictor = dlib.shape_predictor(path_to_dlib_model)
86 +
87 +# Extract data from code
88 +mask_code = "".join(args.code.split()).split(",")
89 +args.code_count = np.zeros(len(mask_code))
90 +args.mask_dict_of_dict = {}
91 +
92 +
93 +for i, entry in enumerate(mask_code):
94 + mask_dict = {}
95 + mask_color = ""
96 + mask_texture = ""
97 + mask_type = entry.split("-")[0]
98 + if len(entry.split("-")) == 2:
99 + mask_variation = entry.split("-")[1]
100 + if "#" in mask_variation:
101 + mask_color = mask_variation
102 + else:
103 + mask_texture = mask_variation
104 + mask_dict["type"] = mask_type
105 + mask_dict["color"] = mask_color
106 + mask_dict["texture"] = mask_texture
107 + args.mask_dict_of_dict[i] = mask_dict
108 +
109 +# Check if path is file or directory or none
110 +is_directory, is_file, is_other = check_path(args.path)
111 +display_MaskTheFace()
112 +
113 +if is_directory:
114 + path, dirs, files = os.walk(args.path).__next__()
115 + file_count = len(files)
116 + dirs_count = len(dirs)
117 + if len(files) > 0:
118 + print_orderly("Masking image files", 60)
119 +
120 + # Process files in the directory if any
121 + for f in tqdm(files):
122 + image_path = path + "/" + f
123 +
124 + write_path = path + "_masked"
125 + if not os.path.isdir(write_path):
126 + os.makedirs(write_path)
127 +
128 + if is_image(image_path):
129 + # Proceed if file is image
130 + if args.verbose:
131 + str_p = "Processing: " + image_path
132 + tqdm.write(str_p)
133 +
134 + split_path = f.rsplit(".")
135 + masked_image, mask, mask_binary_array, original_image = mask_image(
136 + image_path, args
137 + )
138 + for i in range(len(mask)):
139 + w_path = (
140 + write_path
141 + + "/"
142 + + split_path[0]
143 + + "_"
144 + + "masked"
145 + + "."
146 + + split_path[1]
147 + )
148 + img = masked_image[i]
149 + binary_img = mask_binary_array[i]
150 + cv2.imwrite(w_path, img)
151 + cv2.imwrite(
152 + path + "_binary/" + split_path[0] + "_binary" + "." + split_path[1],
153 + binary_img,
154 + )
155 + cv2.imwrite(
156 + path + "_original/" + split_path[0] + "." + split_path[1],
157 + original_image,
158 + )
159 +
160 + print_orderly("Masking image directories", 60)
161 +
162 +    # Process directories within the path provided
163 + for d in tqdm(dirs):
164 + dir_path = args.path + "/" + d
165 + dir_write_path = args.write_path + "/" + d
166 + if not os.path.isdir(dir_write_path):
167 + os.makedirs(dir_write_path)
168 + _, _, files = os.walk(dir_path).__next__()
169 +
170 + # Process each files within subdirectory
171 + for f in files:
172 + image_path = dir_path + "/" + f
173 + if args.verbose:
174 + str_p = "Processing: " + image_path
175 + tqdm.write(str_p)
176 + write_path = dir_write_path
177 + if is_image(image_path):
178 + # Proceed if file is image
179 + split_path = f.rsplit(".")
180 + masked_image, mask, mask_binary, original_image = mask_image(
181 + image_path, args
182 + )
183 + for i in range(len(mask)):
184 + w_path = (
185 + write_path
186 + + "/"
187 + + split_path[0]
188 + + "_"
189 + + "masked"
190 + + "."
191 + + split_path[1]
192 + )
193 + w_path_original = write_path + "/" + f
194 + img = masked_image[i]
195 + binary_img = mask_binary[i]
196 + cv2.imwrite(
197 + path
198 + + "_binary/"
199 + + split_path[0]
200 + + "_binary"
201 + + "."
202 + + split_path[1],
203 + binary_img,
204 + )
205 + # Write the masked image
206 + cv2.imwrite(w_path, img)
207 + if args.write_original_image:
208 + # Write the original image
209 + cv2.imwrite(w_path_original, original_image)
210 +
211 + if args.verbose:
212 + print(args.code_count)
213 +
214 +# Process if the path was a file
215 +elif is_file:
216 + print("Masking image file")
217 + image_path = args.path
218 + write_path = args.path.rsplit(".")[0]
219 + if is_image(image_path):
220 + # Proceed if file is image
221 + # masked_images, mask, mask_binary_array, original_image
222 + masked_image, mask, mask_binary_array, original_image = mask_image(
223 + image_path, args
224 + )
225 + for i in range(len(mask)):
226 + w_path = write_path + "_" + "masked" + "." + args.path.rsplit(".")[1]
227 + img = masked_image[i]
228 + binary_img = mask_binary_array[i]
229 + cv2.imwrite(w_path, img)
230 + cv2.imwrite(write_path + "_binary." + args.path.rsplit(".")[1], binary_img)
231 +else:
232 +    print("Path is neither a valid file nor a valid directory")
233 +print("Processing Done")
1 +[surgical]
2 +template: masks/templates/surgical.png
3 +mask_a: 21, 97
4 +mask_b: 307, 22
5 +mask_c: 600, 99
6 +mask_d: 25, 322
7 +mask_e: 295, 470
8 +mask_f: 600, 323
9 +
10 +[surgical_left]
11 +template: masks/templates/surgical_left.png
12 +mask_a: 39, 27
13 +mask_b: 130, 9
14 +mask_c: 567, 20
15 +mask_d: 87, 207
16 +mask_e: 168, 302
17 +mask_f: 568, 202
18 +
19 +[surgical_right]
20 +template: masks/templates/surgical_right.png
21 +mask_a: 3, 20
22 +mask_b: 440, 9
23 +mask_c: 531, 27
24 +mask_d: 2, 202
25 +mask_e: 402, 302
26 +mask_f: 483, 207
27 +
28 +[surgical_green]
29 +template: masks/templates/surgical_green.png
30 +mask_a: 21, 97
31 +mask_b: 307, 22
32 +mask_c: 600, 99
33 +mask_d: 25, 322
34 +mask_e: 295, 470
35 +mask_f: 600, 323
36 +
37 +[surgical_green_left]
38 +template: masks/templates/surgical_green_left.png
39 +mask_a: 39, 27
40 +mask_b: 130, 9
41 +mask_c: 567, 20
42 +mask_d: 87, 207
43 +mask_e: 168, 302
44 +mask_f: 568, 202
45 +
46 +[surgical_green_right]
47 +template: masks/templates/surgical_green_right.png
48 +mask_a: 3, 20
49 +mask_b: 440, 9
50 +mask_c: 531, 27
51 +mask_d: 2, 202
52 +mask_e: 402, 302
53 +mask_f: 483, 207
54 +
55 +[surgical_blue]
56 +template: masks/templates/surgical_blue.png
57 +mask_a: 21, 97
58 +mask_b: 307, 22
59 +mask_c: 600, 99
60 +mask_d: 25, 322
61 +mask_e: 295, 470
62 +mask_f: 600, 323
63 +
64 +[surgical_blue_left]
65 +template: masks/templates/surgical_blue_left.png
66 +mask_a: 39, 27
67 +mask_b: 130, 9
68 +mask_c: 567, 20
69 +mask_d: 87, 207
70 +mask_e: 168, 302
71 +mask_f: 568, 202
72 +
73 +[surgical_blue_right]
74 +template: masks/templates/surgical_blue_right.png
75 +mask_a: 3, 20
76 +mask_b: 440, 9
77 +mask_c: 531, 27
78 +mask_d: 2, 202
79 +mask_e: 402, 302
80 +mask_f: 483, 207
81 +
82 +
83 +[N95]
84 +template: masks/templates/N95.png
85 +mask_a: 15, 119
86 +mask_b: 327, 5
87 +mask_c: 640, 93
88 +mask_d: 13, 285
89 +mask_e: 351, 518
90 +mask_f: 645, 285
91 +
92 +;[N95_left]
93 +;template: masks/N95_left.png
94 +;mask_a: 176, 121
95 +;mask_b: 313, 46
96 +;mask_c: 799, 135
97 +;mask_d: 97, 438
98 +;mask_e: 329, 627
99 +;mask_f: 791, 401
100 +
101 +[N95_right]
102 +template: masks/templates/N95_right.png
103 +mask_c: 979, 331
104 +mask_b: 806, 172
105 +mask_a: 12, 222
106 +mask_f: 907, 762
107 +mask_e: 577, 875
108 +mask_d: -4, 632
109 +
110 +[N95_left]
111 +template: masks/templates/N95_left.png
112 +mask_a: 193, 331
113 +mask_b: 366, 172
114 +mask_c: 1160, 222
115 +mask_d: 265, 762
116 +mask_e: 595, 875
117 +mask_f: 1176, 632
118 +
119 +
120 +[cloth_left]
121 +template: masks/templates/cloth_left.png
122 +mask_a: 65, 93
123 +mask_b: 162, 15
124 +mask_c: 672, 75
125 +mask_d: 114, 296
126 +mask_e: 207, 443
127 +mask_f: 671, 341
128 +
129 +[cloth_right]
130 +template: masks/templates/cloth_right.png
131 +mask_a: 98, 93
132 +mask_b: 608, 15
133 +mask_c: 705, 75
134 +mask_d: 99, 296
135 +mask_e: 563, 443
136 +mask_f: 656, 341
137 +
138 +[cloth]
139 +template: masks/templates/cloth.png
140 +mask_a: 122, 90
141 +mask_b: 405, 7
142 +mask_c: 686, 79
143 +mask_d: 165, 323
144 +mask_e: 406, 509
145 +mask_f: 653, 311
146 +
147 +[gas]
148 +template: masks/templates/gas.png
149 +mask_a: 330, 431
150 +mask_b: 873, 117
151 +mask_c: 1494, 434
152 +mask_d: 430, 754
153 +mask_e: 869, 1100
154 +mask_f: 1400, 710
155 +
156 +[gas_left]
157 +template: masks/templates/gas_left.png
158 +mask_a: 239, 238
159 +mask_b: 317, 42
160 +mask_c: 965, 239
161 +mask_d: 224, 404
162 +mask_e: 337, 502
163 +mask_f: 963, 406
164 +
165 +[gas_right]
166 +template: masks/templates/gas_right.png
167 +mask_c: 621, 238
168 +mask_b: 543, 60
169 +mask_a: -105, 239
170 +mask_f: 636, 404
171 +mask_e: 523, 502
172 +mask_d: -103, 406
173 +
174 +[KN95]
175 +template: masks/templates/KN95.png
176 +mask_a: 20, 47
177 +mask_b: 410, 5
178 +mask_c: 760, 55
179 +mask_d: 75, 340
180 +mask_e: 398, 600
181 +mask_f: 671, 320
182 +
183 +[KN95_left]
184 +template: masks/templates/KN95_left.png
185 +mask_a: 52, 258
186 +mask_b: 207, 100
187 +mask_c: 730, 80
188 +mask_d: 210, 408
189 +mask_e: 335, 604
190 +mask_f: 770, 270
191 +
192 +[KN95_right]
193 +template: masks/templates/KN95_right.png
194 +mask_c: 664, 258
195 +mask_b: 509, 100
196 +mask_a: -14, 80
197 +mask_f: 506, 408
198 +mask_e: 381, 604
199 +mask_d: -54, 270
200 +
201 +
202 +[empty]
203 +[empty_left]
204 +[empty_right]
205 +
206 +[inpaint]
207 +[inpaint_left]
208 +[inpaint_right]
209 +
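Each section above names a template image and six anchor points (mask_a through mask_f); mask_face in utils/aux_functions.py warps the template onto six matching points on the face with a homography. A minimal sketch of that warp using the [surgical] points above; the face points and file names are illustrative:

import cv2
import numpy as np

# Six template points from the [surgical] section, in the a, b, c, f, e, d order
# that mask_face uses for mask_line
mask_line = np.float32([[21, 97], [307, 22], [600, 99],
                        [600, 323], [295, 470], [25, 322]])
# Six corresponding points on the face (illustrative values)
six_points_on_face = np.float32([[120, 260], [200, 240], [285, 255],
                                 [290, 330], [205, 370], [115, 325]])

template = cv2.imread('masks/templates/surgical.png', cv2.IMREAD_UNCHANGED)  # RGBA template
face_img = cv2.imread('face.jpg')  # illustrative input image

M, _ = cv2.findHomography(mask_line, six_points_on_face)
h, w = face_img.shape[:2]
warped = cv2.warpPerspective(template, M, (w, h))

# Composite using the warped template's alpha channel as a binary mask
alpha = warped[:, :, 3]
bg = cv2.bitwise_and(face_img, face_img, mask=cv2.bitwise_not(alpha))
fg = cv2.bitwise_and(warped[:, :, 0:3], warped[:, :, 0:3], mask=alpha)
out = cv2.add(bg, fg)
cv2.imwrite('face_masked.jpg', out)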
1 +certifi==2020.4.5.1
2 +click==7.1.2
3 +dlib==19.19.0
4 +dotmap==1.3.14
5 +face-recognition==1.3.0
6 +face-recognition-models==0.3.0
7 +numpy==1.18.4
8 +opencv-python==4.2.0.34
9 +Pillow==7.1.2
10 +tqdm==4.46.0
11 +wincertstore==0.2
12 +imutils==0.5.3
13 +requests==2.24.0
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 7/30/2020, 7:43 AM
3 +# Email: aqeel.anwar@gatech.edu
\ No newline at end of file
1 +# Author: aqeelanwar
2 +# Created: 27 April,2020, 10:21 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from configparser import ConfigParser
6 +import cv2, math, os
7 +from PIL import Image, ImageDraw
8 +from tqdm import tqdm
9 +from utils.read_cfg import read_cfg
10 +from utils.fit_ellipse import *
11 +import random
12 +from utils.create_mask import texture_the_mask, color_the_mask
13 +from imutils import face_utils
14 +import requests
15 +from zipfile import ZipFile
16 +from tqdm import tqdm
17 +import bz2, shutil
18 +
19 +
20 +def download_dlib_model():
21 + print_orderly("Get dlib model", 60)
22 + dlib_model_link = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
23 + print("Downloading dlib model...")
24 + with requests.get(dlib_model_link, stream=True) as r:
25 + print("Zip file size: ", np.round(len(r.content) / 1024 / 1024, 2), "MB")
26 + destination = (
27 + "dlib_models" + os.path.sep + "shape_predictor_68_face_landmarks.dat.bz2"
28 + )
29 + if not os.path.exists(destination.rsplit(os.path.sep, 1)[0]):
30 + os.mkdir(destination.rsplit(os.path.sep, 1)[0])
31 + print("Saving dlib model...")
32 + with open(destination, "wb") as fd:
33 + for chunk in r.iter_content(chunk_size=32678):
34 + fd.write(chunk)
35 + print("Extracting dlib model...")
36 + with bz2.BZ2File(destination) as fr, open(
37 + "dlib_models/shape_predictor_68_face_landmarks.dat", "wb"
38 + ) as fw:
39 + shutil.copyfileobj(fr, fw)
40 + print("Saved: ", destination)
41 + print_orderly("done", 60)
42 +
43 + os.remove(destination)
44 +
45 +
46 +def get_line(face_landmark, image, type="eye", debug=False):
47 + pil_image = Image.fromarray(image)
48 + d = ImageDraw.Draw(pil_image)
49 + left_eye = face_landmark["left_eye"]
50 + right_eye = face_landmark["right_eye"]
51 + left_eye_mid = np.mean(np.array(left_eye), axis=0)
52 + right_eye_mid = np.mean(np.array(right_eye), axis=0)
53 + eye_line_mid = (left_eye_mid + right_eye_mid) / 2
54 +
55 + if type == "eye":
56 + left_point = left_eye_mid
57 + right_point = right_eye_mid
58 + mid_point = eye_line_mid
59 +
60 + elif type == "nose_mid":
61 + nose_length = (
62 + face_landmark["nose_bridge"][-1][1] - face_landmark["nose_bridge"][0][1]
63 + )
64 + left_point = [left_eye_mid[0], left_eye_mid[1] + nose_length / 2]
65 + right_point = [right_eye_mid[0], right_eye_mid[1] + nose_length / 2]
66 + # mid_point = (
67 + # face_landmark["nose_bridge"][-1][1] + face_landmark["nose_bridge"][0][1]
68 + # ) / 2
69 +
70 + mid_pointY = (
71 + face_landmark["nose_bridge"][-1][1] + face_landmark["nose_bridge"][0][1]
72 + ) / 2
73 + mid_pointX = (
74 + face_landmark["nose_bridge"][-1][0] + face_landmark["nose_bridge"][0][0]
75 + ) / 2
76 + mid_point = (mid_pointX, mid_pointY)
77 +
78 + elif type == "nose_tip":
79 + nose_length = (
80 + face_landmark["nose_bridge"][-1][1] - face_landmark["nose_bridge"][0][1]
81 + )
82 + left_point = [left_eye_mid[0], left_eye_mid[1] + nose_length]
83 + right_point = [right_eye_mid[0], right_eye_mid[1] + nose_length]
84 + mid_point = (
85 + face_landmark["nose_bridge"][-1][1] + face_landmark["nose_bridge"][0][1]
86 + ) / 2
87 +
88 + elif type == "bottom_lip":
89 + bottom_lip = face_landmark["bottom_lip"]
90 + bottom_lip_mid = np.max(np.array(bottom_lip), axis=0)
91 + shiftY = bottom_lip_mid[1] - eye_line_mid[1]
92 + left_point = [left_eye_mid[0], left_eye_mid[1] + shiftY]
93 + right_point = [right_eye_mid[0], right_eye_mid[1] + shiftY]
94 + mid_point = bottom_lip_mid
95 +
96 + elif type == "perp_line":
97 + bottom_lip = face_landmark["bottom_lip"]
98 + bottom_lip_mid = np.mean(np.array(bottom_lip), axis=0)
99 +
100 + left_point = eye_line_mid
101 + left_point = face_landmark["nose_bridge"][0]
102 + right_point = bottom_lip_mid
103 +
104 + mid_point = bottom_lip_mid
105 +
106 + elif type == "nose_long":
107 + nose_bridge = face_landmark["nose_bridge"]
108 + left_point = [nose_bridge[0][0], nose_bridge[0][1]]
109 + right_point = [nose_bridge[-1][0], nose_bridge[-1][1]]
110 +
111 + mid_point = left_point
112 +
113 + # d.line(eye_mid, width=5, fill='red')
114 + y = [left_point[1], right_point[1]]
115 + x = [left_point[0], right_point[0]]
116 + # cv2.imshow('h', image)
117 + # cv2.waitKey(0)
118 + eye_line = fit_line(x, y, image)
119 + d.line(eye_line, width=5, fill="blue")
120 +
121 + # Perpendicular Line
122 + # (midX, midY) and (midX - y2 + y1, midY + x2 - x1)
123 + y = [
124 + (left_point[1] + right_point[1]) / 2,
125 + (left_point[1] + right_point[1]) / 2 + right_point[0] - left_point[0],
126 + ]
127 + x = [
128 + (left_point[0] + right_point[0]) / 2,
129 + (left_point[0] + right_point[0]) / 2 - right_point[1] + left_point[1],
130 + ]
131 + perp_line = fit_line(x, y, image)
132 + if debug:
133 + d.line(perp_line, width=5, fill="red")
134 + pil_image.show()
135 + return eye_line, perp_line, left_point, right_point, mid_point
136 +
137 +
138 +def get_points_on_chin(line, face_landmark, chin_type="chin"):
139 + chin = face_landmark[chin_type]
140 + points_on_chin = []
141 + for i in range(len(chin) - 1):
142 + chin_first_point = [chin[i][0], chin[i][1]]
143 + chin_second_point = [chin[i + 1][0], chin[i + 1][1]]
144 +
145 + flag, x, y = line_intersection(line, (chin_first_point, chin_second_point))
146 + if flag:
147 + points_on_chin.append((x, y))
148 +
149 + return points_on_chin
150 +
151 +
152 +def plot_lines(face_line, image, debug=False):
153 + pil_image = Image.fromarray(image)
154 + if debug:
155 + d = ImageDraw.Draw(pil_image)
156 + d.line(face_line, width=4, fill="white")
157 + pil_image.show()
158 +
159 +
160 +def line_intersection(line1, line2):
161 + # mid = int(len(line1) / 2)
162 + start = 0
163 + end = -1
164 + line1 = ([line1[start][0], line1[start][1]], [line1[end][0], line1[end][1]])
165 +
166 + xdiff = (line1[0][0] - line1[1][0], line2[0][0] - line2[1][0])
167 + ydiff = (line1[0][1] - line1[1][1], line2[0][1] - line2[1][1])
168 + x = []
169 + y = []
170 + flag = False
171 +
172 + def det(a, b):
173 + return a[0] * b[1] - a[1] * b[0]
174 +
175 + div = det(xdiff, ydiff)
176 + if div == 0:
177 + return flag, x, y
178 +
179 + d = (det(*line1), det(*line2))
180 + x = det(d, xdiff) / div
181 + y = det(d, ydiff) / div
182 +
183 + segment_minX = min(line2[0][0], line2[1][0])
184 + segment_maxX = max(line2[0][0], line2[1][0])
185 +
186 + segment_minY = min(line2[0][1], line2[1][1])
187 + segment_maxY = max(line2[0][1], line2[1][1])
188 +
189 + if (
190 + segment_maxX + 1 >= x >= segment_minX - 1
191 + and segment_maxY + 1 >= y >= segment_minY - 1
192 + ):
193 + flag = True
194 +
195 + return flag, x, y
196 +
197 +
198 +def fit_line(x, y, image):
199 + if x[0] == x[1]:
200 + x[0] += 0.1
201 + coefficients = np.polyfit(x, y, 1)
202 + polynomial = np.poly1d(coefficients)
203 + x_axis = np.linspace(0, image.shape[1], 50)
204 + y_axis = polynomial(x_axis)
205 + eye_line = []
206 + for i in range(len(x_axis)):
207 + eye_line.append((x_axis[i], y_axis[i]))
208 +
209 + return eye_line
210 +
211 +
212 +def get_six_points(face_landmark, image):
213 + _, perp_line1, _, _, m = get_line(face_landmark, image, type="nose_mid")
214 + face_b = m
215 +
216 + perp_line, _, _, _, _ = get_line(face_landmark, image, type="perp_line")
217 + points1 = get_points_on_chin(perp_line1, face_landmark)
218 + points = get_points_on_chin(perp_line, face_landmark)
219 + if not points1:
220 + face_e = tuple(np.asarray(points[0]))
221 + elif not points:
222 + face_e = tuple(np.asarray(points1[0]))
223 + else:
224 + face_e = tuple((np.asarray(points[0]) + np.asarray(points1[0])) / 2)
225 + # face_e = points1[0]
226 + nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="nose_long")
227 +
228 + angle = get_angle(perp_line, nose_mid_line)
229 + # print("angle: ", angle)
230 + nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="nose_tip")
231 + points = get_points_on_chin(nose_mid_line, face_landmark)
232 + if len(points) < 2:
233 + face_landmark = get_face_ellipse(face_landmark)
234 + # print("extrapolating chin")
235 + points = get_points_on_chin(
236 + nose_mid_line, face_landmark, chin_type="chin_extrapolated"
237 + )
238 + if len(points) < 2:
239 + points = []
240 + points.append(face_landmark["chin"][0])
241 + points.append(face_landmark["chin"][-1])
242 + face_a = points[0]
243 + face_c = points[-1]
244 + # cv2.imshow('j', image)
245 + # cv2.waitKey(0)
246 + nose_mid_line, _, _, _, _ = get_line(face_landmark, image, type="bottom_lip")
247 + points = get_points_on_chin(nose_mid_line, face_landmark)
248 + face_d = points[0]
249 + face_f = points[-1]
250 +
251 + six_points = np.float32([face_a, face_b, face_c, face_f, face_e, face_d])
252 +
253 + return six_points, angle
254 +
255 +
256 +def get_angle(line1, line2):
257 + delta_y = line1[-1][1] - line1[0][1]
258 + delta_x = line1[-1][0] - line1[0][0]
259 + perp_angle = math.degrees(math.atan2(delta_y, delta_x))
260 + if delta_x < 0:
261 + perp_angle = perp_angle + 180
262 + if perp_angle < 0:
263 + perp_angle += 360
264 + if perp_angle > 180:
265 + perp_angle -= 180
266 +
267 + # print("perp", perp_angle)
268 + delta_y = line2[-1][1] - line2[0][1]
269 + delta_x = line2[-1][0] - line2[0][0]
270 + nose_angle = math.degrees(math.atan2(delta_y, delta_x))
271 +
272 + if delta_x < 0:
273 + nose_angle = nose_angle + 180
274 + if nose_angle < 0:
275 + nose_angle += 360
276 + if nose_angle > 180:
277 + nose_angle -= 180
278 + # print("nose", nose_angle)
279 +
280 + angle = nose_angle - perp_angle
281 + return angle
282 +
283 +
284 +def mask_face(image, face_location, six_points, angle, args, type="surgical"):
285 + debug = False
286 +
287 + # Find the face angle
288 + threshold = 13
289 + if angle < -threshold:
290 + type += "_right"
291 + elif angle > threshold:
292 + type += "_left"
293 +
294 + face_height = face_location[2] - face_location[0]
295 + face_width = face_location[1] - face_location[3]
296 + # image = image_raw[
297 + # face_location[0]-int(face_width/2): face_location[2]+int(face_width/2),
298 + # face_location[3]-int(face_height/2): face_location[1]+int(face_height/2),
299 + # :,
300 + # ]
301 + # cv2.imshow('win', image)
302 + # cv2.waitKey(0)
303 + # Read appropriate mask image
304 + w = image.shape[0]
305 + h = image.shape[1]
306 +    if "empty" not in type and "inpaint" not in type:
307 +        cfg = read_cfg(config_filename="masks/masks.cfg", mask_type=type, verbose=False)
308 +    else:
309 +        if "left" in type:
310 +            mask_str = "surgical_blue_left"
311 +        elif "right" in type:
312 +            mask_str = "surgical_blue_right"
313 +        else:
314 +            mask_str = "surgical_blue"
315 +        cfg = read_cfg(config_filename="masks/masks.cfg", mask_type=mask_str, verbose=False)
316 + img = cv2.imread(cfg.template, cv2.IMREAD_UNCHANGED)
317 +
318 + # Process the mask if necessary
319 + if args.pattern:
320 + # Apply pattern to mask
321 + img = texture_the_mask(img, args.pattern, args.pattern_weight)
322 +
323 + if args.color:
324 + # Apply color to mask
325 + img = color_the_mask(img, args.color, args.color_weight)
326 +
327 + mask_line = np.float32(
328 + [cfg.mask_a, cfg.mask_b, cfg.mask_c, cfg.mask_f, cfg.mask_e, cfg.mask_d]
329 + )
330 + # Warp the mask
331 + M, mask = cv2.findHomography(mask_line, six_points)
332 + dst_mask = cv2.warpPerspective(img, M, (h, w))
333 + dst_mask_points = cv2.perspectiveTransform(mask_line.reshape(-1, 1, 2), M)
334 + mask = dst_mask[:, :, 3]
335 + face_height = face_location[2] - face_location[0]
336 + face_width = face_location[1] - face_location[3]
337 + image_face = image[
338 + face_location[0] + int(face_height / 2) : face_location[2],
339 + face_location[3] : face_location[1],
340 + :,
341 + ]
342 +
343 + image_face = image
344 +
345 + # Adjust Brightness
346 + mask_brightness = get_avg_brightness(img)
347 + img_brightness = get_avg_brightness(image_face)
348 + delta_b = 1 + (img_brightness - mask_brightness) / 255
349 + dst_mask = change_brightness(dst_mask, delta_b)
350 +
351 + # Adjust Saturation
352 + mask_saturation = get_avg_saturation(img)
353 + img_saturation = get_avg_saturation(image_face)
354 + delta_s = 1 - (img_saturation - mask_saturation) / 255
355 + dst_mask = change_saturation(dst_mask, delta_s)
356 +
357 + # Apply mask
358 + mask_inv = cv2.bitwise_not(mask)
359 + img_bg = cv2.bitwise_and(image, image, mask=mask_inv)
360 + img_fg = cv2.bitwise_and(dst_mask, dst_mask, mask=mask)
361 + out_img = cv2.add(img_bg, img_fg[:, :, 0:3])
362 + if "empty" in type or "inpaint" in type:
363 + out_img = img_bg
364 + # Plot key points
365 +
366 + if "inpaint" in type:
367 + out_img = cv2.inpaint(out_img, mask, 3, cv2.INPAINT_TELEA)
368 + # dst_NS = cv2.inpaint(img, mask, 3, cv2.INPAINT_NS)
369 +
370 + if debug:
371 + for i in six_points:
372 + cv2.circle(out_img, (i[0], i[1]), radius=4, color=(0, 0, 255), thickness=-1)
373 +
374 + for i in dst_mask_points:
375 + cv2.circle(
376 + out_img, (i[0][0], i[0][1]), radius=4, color=(0, 255, 0), thickness=-1
377 + )
378 +
379 + return out_img, mask
380 +
381 +
382 +def draw_landmarks(face_landmarks, image):
383 + pil_image = Image.fromarray(image)
384 + d = ImageDraw.Draw(pil_image)
385 + for facial_feature in face_landmarks.keys():
386 + d.line(face_landmarks[facial_feature], width=5, fill="white")
387 + pil_image.show()
388 +
389 +
390 +def get_face_ellipse(face_landmark):
391 + chin = face_landmark["chin"]
392 + x = []
393 + y = []
394 + for point in chin:
395 + x.append(point[0])
396 + y.append(point[1])
397 +
398 + x = np.asarray(x)
399 + y = np.asarray(y)
400 +
401 + a = fitEllipse(x, y)
402 + center = ellipse_center(a)
403 + phi = ellipse_angle_of_rotation(a)
404 + axes = ellipse_axis_length(a)
405 + a, b = axes
406 +
407 + arc = 2.2
408 + R = np.arange(0, arc * np.pi, 0.2)
409 + xx = center[0] + a * np.cos(R) * np.cos(phi) - b * np.sin(R) * np.sin(phi)
410 + yy = center[1] + a * np.cos(R) * np.sin(phi) + b * np.sin(R) * np.cos(phi)
411 + chin_extrapolated = []
412 + for i in range(len(R)):
413 + chin_extrapolated.append((xx[i], yy[i]))
414 + face_landmark["chin_extrapolated"] = chin_extrapolated
415 + return face_landmark
416 +
417 +
418 +def get_avg_brightness(img):
419 + img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
420 + h, s, v = cv2.split(img_hsv)
421 + return np.mean(v)
422 +
423 +
424 +def get_avg_saturation(img):
425 + img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
426 + h, s, v = cv2.split(img_hsv)
427 +    return np.mean(s)
428 +
429 +
430 +def change_brightness(img, value=1.0):
431 + img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
432 + h, s, v = cv2.split(img_hsv)
433 + v = value * v
434 + v[v > 255] = 255
435 + v = np.asarray(v, dtype=np.uint8)
436 + final_hsv = cv2.merge((h, s, v))
437 + img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
438 + return img
439 +
440 +
441 +def change_saturation(img, value=1.0):
442 + img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
443 + h, s, v = cv2.split(img_hsv)
444 + s = value * s
445 + s[s > 255] = 255
446 + s = np.asarray(s, dtype=np.uint8)
447 + final_hsv = cv2.merge((h, s, v))
448 + img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
449 + return img
450 +
451 +
452 +def check_path(path):
453 + is_directory = False
454 + is_file = False
455 + is_other = False
456 + if os.path.isdir(path):
457 + is_directory = True
458 + elif os.path.isfile(path):
459 + is_file = True
460 + else:
461 + is_other = True
462 +
463 + return is_directory, is_file, is_other
464 +
465 +
466 +def shape_to_landmarks(shape):
467 + face_landmarks = {}
468 + face_landmarks["left_eyebrow"] = [
469 + tuple(shape[17]),
470 + tuple(shape[18]),
471 + tuple(shape[19]),
472 + tuple(shape[20]),
473 + tuple(shape[21]),
474 + ]
475 + face_landmarks["right_eyebrow"] = [
476 + tuple(shape[22]),
477 + tuple(shape[23]),
478 + tuple(shape[24]),
479 + tuple(shape[25]),
480 + tuple(shape[26]),
481 + ]
482 + face_landmarks["nose_bridge"] = [
483 + tuple(shape[27]),
484 + tuple(shape[28]),
485 + tuple(shape[29]),
486 + tuple(shape[30]),
487 + ]
488 + face_landmarks["nose_tip"] = [
489 + tuple(shape[31]),
490 + tuple(shape[32]),
491 + tuple(shape[33]),
492 + tuple(shape[34]),
493 + tuple(shape[35]),
494 + ]
495 + face_landmarks["left_eye"] = [
496 + tuple(shape[36]),
497 + tuple(shape[37]),
498 + tuple(shape[38]),
499 + tuple(shape[39]),
500 + tuple(shape[40]),
501 + tuple(shape[41]),
502 + ]
503 + face_landmarks["right_eye"] = [
504 + tuple(shape[42]),
505 + tuple(shape[43]),
506 + tuple(shape[44]),
507 + tuple(shape[45]),
508 + tuple(shape[46]),
509 + tuple(shape[47]),
510 + ]
511 + face_landmarks["top_lip"] = [
512 + tuple(shape[48]),
513 + tuple(shape[49]),
514 + tuple(shape[50]),
515 + tuple(shape[51]),
516 + tuple(shape[52]),
517 + tuple(shape[53]),
518 + tuple(shape[54]),
519 + tuple(shape[60]),
520 + tuple(shape[61]),
521 + tuple(shape[62]),
522 + tuple(shape[63]),
523 + tuple(shape[64]),
524 + ]
525 +
526 + face_landmarks["bottom_lip"] = [
527 + tuple(shape[54]),
528 + tuple(shape[55]),
529 + tuple(shape[56]),
530 + tuple(shape[57]),
531 + tuple(shape[58]),
532 + tuple(shape[59]),
533 + tuple(shape[48]),
534 + tuple(shape[64]),
535 + tuple(shape[65]),
536 + tuple(shape[66]),
537 + tuple(shape[67]),
538 + tuple(shape[60]),
539 + ]
540 +
541 + face_landmarks["chin"] = [
542 + tuple(shape[0]),
543 + tuple(shape[1]),
544 + tuple(shape[2]),
545 + tuple(shape[3]),
546 + tuple(shape[4]),
547 + tuple(shape[5]),
548 + tuple(shape[6]),
549 + tuple(shape[7]),
550 + tuple(shape[8]),
551 + tuple(shape[9]),
552 + tuple(shape[10]),
553 + tuple(shape[11]),
554 + tuple(shape[12]),
555 + tuple(shape[13]),
556 + tuple(shape[14]),
557 + tuple(shape[15]),
558 + tuple(shape[16]),
559 + ]
560 + return face_landmarks
561 +
562 +
563 +def rect_to_bb(rect):
564 +    x1 = rect.left()
565 +    x2 = rect.right()
566 +    y1 = rect.top()
567 +    y2 = rect.bottom()
568 +    return (y1, x2, y2, x1)  # (top, right, bottom, left), the order the callers index
569 +
570 +
571 +def mask_image(image_path, args):
572 + # Read the image
573 + image = cv2.imread(image_path)
574 + original_image = image.copy()
575 + # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
576 + gray = image
577 + face_locations = args.detector(gray, 1)
578 + mask_type = args.mask_type
579 + verbose = args.verbose
580 + if args.code:
581 + ind = random.randint(0, len(args.code_count) - 1)
582 + mask_dict = args.mask_dict_of_dict[ind]
583 + mask_type = mask_dict["type"]
584 + args.color = mask_dict["color"]
585 + args.pattern = mask_dict["texture"]
586 + args.code_count[ind] += 1
587 +
588 + elif mask_type == "random":
589 + available_mask_types = get_available_mask_types()
590 + mask_type = random.choice(available_mask_types)
591 +
592 + if verbose:
593 + tqdm.write("Faces found: {:2d}".format(len(face_locations)))
594 + # Process each face in the image
595 + masked_images = []
596 + mask_binary_array = []
597 + mask = []
598 + for (i, face_location) in enumerate(face_locations):
599 + shape = args.predictor(gray, face_location)
600 + shape = face_utils.shape_to_np(shape)
601 + face_landmarks = shape_to_landmarks(shape)
602 + face_location = rect_to_bb(face_location)
603 + # draw_landmarks(face_landmarks, image)
604 + six_points_on_face, angle = get_six_points(face_landmarks, image)
605 + mask = []
606 + if mask_type != "all":
607 + if len(masked_images) > 0:
608 + image = masked_images.pop(0)
609 + image, mask_binary = mask_face(
610 + image, face_location, six_points_on_face, angle, args, type=mask_type
611 + )
612 +
613 + # compress to face tight
614 + face_height = face_location[2] - face_location[0]
615 + face_width = face_location[1] - face_location[3]
616 + masked_images.append(image)
617 + mask_binary_array.append(mask_binary)
618 + mask.append(mask_type)
619 + else:
620 + available_mask_types = get_available_mask_types()
621 + for m in range(len(available_mask_types)):
622 + if len(masked_images) == len(available_mask_types):
623 + image = masked_images.pop(m)
624 + img, mask_binary = mask_face(
625 + image,
626 + face_location,
627 + six_points_on_face,
628 + angle,
629 + args,
630 + type=available_mask_types[m],
631 + )
632 + masked_images.insert(m, img)
633 + mask_binary_array.insert(m, mask_binary)
634 + mask = available_mask_types
635 + cc = 1
636 +
637 + return masked_images, mask, mask_binary_array, original_image
638 +
639 +
640 +def is_image(path):
641 + try:
642 + extensions = path[-4:]
643 + image_extensions = ["png", "PNG", "jpg", "JPG"]
644 +
645 + if extensions[1:] in image_extensions:
646 + return True
647 + else:
648 + print("Please input image file. png / jpg")
649 + return False
650 + except:
651 + return False
652 +
653 +
654 +def get_available_mask_types(config_filename="masks/masks.cfg"):
655 + parser = ConfigParser()
656 + parser.optionxform = str
657 + parser.read(config_filename)
658 + available_mask_types = parser.sections()
659 + available_mask_types = [
660 + string for string in available_mask_types if "left" not in string
661 + ]
662 + available_mask_types = [
663 + string for string in available_mask_types if "right" not in string
664 + ]
665 +
666 + return available_mask_types
667 +
668 +
669 +def print_orderly(str, n):
670 + # print("")
671 + hyphens = "-" * int((n - len(str)) / 2)
672 + str_p = hyphens + " " + str + " " + hyphens
673 + hyphens_bar = "-" * len(str_p)
674 + print(hyphens_bar)
675 + print(str_p)
676 + print(hyphens_bar)
677 +
678 +
679 +def display_MaskTheFace():
680 + with open("utils/display.txt", "r") as file:
681 + for line in file:
682 + cc = 1
683 + print(line, end="")
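A minimal end-to-end sketch of the per-face flow in this file (detect with dlib, build the landmark dict, compute the six anchor points, warp one mask type); the image path is illustrative, a SimpleNamespace stands in for the parsed args, and it assumes the dlib model and the masks/ directory are present in the working directory:

import cv2
import dlib
from types import SimpleNamespace
from imutils import face_utils
from utils.aux_functions import shape_to_landmarks, rect_to_bb, get_six_points, mask_face

# Stand-in for the argparse namespace mask_face reads (pattern/color fields only)
args = SimpleNamespace(pattern="", pattern_weight=0.5, color="", color_weight=0.5)

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor("dlib_models/shape_predictor_68_face_landmarks.dat")

image = cv2.imread("face.jpg")  # illustrative input image
for rect in detector(image, 1):
    shape = face_utils.shape_to_np(predictor(image, rect))
    face_landmarks = shape_to_landmarks(shape)
    six_points, angle = get_six_points(face_landmarks, image)
    masked, _ = mask_face(image, rect_to_bb(rect), six_points, angle, args, type="surgical")
    cv2.imwrite("face_masked.jpg", masked)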
1 +# Author: aqeelanwar
2 +# Created: 6 July,2020, 12:14 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from PIL import ImageColor
6 +import cv2
7 +import numpy as np
8 +
9 +COLOR = [
10 + "#fc1c1a",
11 + "#177ABC",
12 + "#94B6D2",
13 + "#A5AB81",
14 + "#DD8047",
15 + "#6b425e",
16 + "#e26d5a",
17 + "#c92c48",
18 + "#6a506d",
19 + "#ffc900",
20 + "#ffffff",
21 + "#000000",
22 + "#49ff00",
23 +]
24 +
25 +
26 +def color_the_mask(mask_image, color, intensity):
27 + assert 0 <= intensity <= 1, "intensity should be between 0 and 1"
28 + RGB_color = ImageColor.getcolor(color, "RGB")
29 + RGB_color = (RGB_color[2], RGB_color[1], RGB_color[0])
30 + orig_shape = mask_image.shape
31 + bit_mask = mask_image[:, :, 3]
32 + mask_image = mask_image[:, :, 0:3]
33 +
34 + color_image = np.full(mask_image.shape, RGB_color, np.uint8)
35 + mask_color = cv2.addWeighted(mask_image, 1 - intensity, color_image, intensity, 0)
36 + mask_color = cv2.bitwise_and(mask_color, mask_color, mask=bit_mask)
37 + colored_mask = np.zeros(orig_shape, dtype=np.uint8)
38 + colored_mask[:, :, 0:3] = mask_color
39 + colored_mask[:, :, 3] = bit_mask
40 + return colored_mask
41 +
42 +
43 +def texture_the_mask(mask_image, texture_path, intensity):
44 + assert 0 <= intensity <= 1, "intensity should be between 0 and 1"
45 + orig_shape = mask_image.shape
46 + bit_mask = mask_image[:, :, 3]
47 + mask_image = mask_image[:, :, 0:3]
48 + texture_image = cv2.imread(texture_path)
49 + texture_image = cv2.resize(texture_image, (orig_shape[1], orig_shape[0]))
50 +
51 + mask_texture = cv2.addWeighted(
52 + mask_image, 1 - intensity, texture_image, intensity, 0
53 + )
54 + mask_texture = cv2.bitwise_and(mask_texture, mask_texture, mask=bit_mask)
55 + textured_mask = np.zeros(orig_shape, dtype=np.uint8)
56 + textured_mask[:, :, 0:3] = mask_texture
57 + textured_mask[:, :, 3] = bit_mask
58 +
59 + return textured_mask
60 +
61 +
62 +
63 +# cloth_mask = cv2.imread("masks/templates/cloth.png", cv2.IMREAD_UNCHANGED)
64 +# # cloth_mask = color_the_mask(cloth_mask, color=COLOR[0], intensity=0.5)
65 +# path = "masks/textures"
66 +# path, dir, files = os.walk(path).__next__()
67 +# first_frame = True
68 +# col_limit = 6
69 +# i = 0
70 +# # img_concat_row=[]
71 +# img_concat = []
72 +# # for f in files:
73 +# # if "._" not in f:
74 +# # print(f)
75 +# # i += 1
76 +# # texture_image = cv2.imread(os.path.join(path, f))
77 +# # m = texture_the_mask(cloth_mask, texture_image, intensity=0.5)
78 +# # if first_frame:
79 +# # img_concat_row = m
80 +# # first_frame = False
81 +# # else:
82 +# # img_concat_row = cv2.hconcat((img_concat_row, m))
83 +# #
84 +# # if i % col_limit == 0:
85 +# # if len(img_concat) > 0:
86 +# # img_concat = cv2.vconcat((img_concat, img_concat_row))
87 +# # else:
88 +# # img_concat = img_concat_row
89 +# # first_frame = True
90 +#
91 +# ## COlor the mask
92 +# thresholds = np.arange(0.1,0.9,0.05)
93 +# for intensity in thresholds:
94 +# c=COLOR[2]
95 +# # intensity = 0.5
96 +# if "._" not in c:
97 +# print(intensity)
98 +# i += 1
99 +# # texture_image = cv2.imread(os.path.join(path, f))
100 +# m = color_the_mask(cloth_mask, c, intensity=intensity)
101 +# if first_frame:
102 +# img_concat_row = m
103 +# first_frame = False
104 +# else:
105 +# img_concat_row = cv2.hconcat((img_concat_row, m))
106 +#
107 +# if i % col_limit == 0:
108 +# if len(img_concat) > 0:
109 +# img_concat = cv2.vconcat((img_concat, img_concat_row))
110 +# else:
111 +# img_concat = img_concat_row
112 +# first_frame = True
113 +#
114 +#
115 +# cv2.imshow("k", img_concat)
116 +# cv2.imwrite("combine_N95_left.png", img_concat)
117 +# cv2.waitKey(0)
118 +# cc = 1
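A minimal usage sketch of color_the_mask and texture_the_mask above; the template and texture paths are the ones referenced elsewhere in the repo, and the output file names are illustrative:

import cv2
from utils.create_mask import color_the_mask, texture_the_mask

# The template must keep its alpha channel, so read it with IMREAD_UNCHANGED
cloth = cv2.imread("masks/templates/cloth.png", cv2.IMREAD_UNCHANGED)

# Tint the mask toward a hex color at 50% intensity
tinted = color_the_mask(cloth, "#0473e2", 0.5)

# Blend a texture image over the mask at 50% intensity
textured = texture_the_mask(cloth, "masks/textures/check/check_4.jpg", 0.5)

cv2.imwrite("cloth_tinted.png", tinted)
cv2.imwrite("cloth_textured.png", textured)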
1 + __ __ _ _______ _ ______
2 +| \/ | | |__ __| | | ____|
3 +| \ / | __ _ ___| | _| | | |__ ___| |__ __ _ ___ ___
4 +| |\/| |/ _` / __| |/ / | | '_ \ / _ \ __/ _` |/ __/ _ \
5 +| | | | (_| \__ \ <| | | | | | __/ | | (_| | (_| __/
6 +|_| |_|\__,_|___/_|\_\_| |_| |_|\___|_| \__,_|\___\___|
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 7/30/2020, 1:44 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +# Code reused from https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
6 +# Make sure you run this from parent folder and not from utils folder i.e.
7 +# python utils/fetch_dataset.py
8 +
9 +import requests, os
10 +from zipfile import ZipFile
11 +import argparse
12 +import urllib
13 +
14 +parser = argparse.ArgumentParser(
15 + description="Download dataset - Python code to download associated datasets"
16 +)
17 +parser.add_argument(
18 + "--dataset",
19 + type=str,
20 + default="mfr2",
21 + help="Name of the dataset - Details on available datasets can be found at GitHub Page",
22 +)
23 +args = parser.parse_args()
24 +
25 +
26 +def download_file_from_google_drive(id, destination):
27 + URL = "https://docs.google.com/uc?export=download"
28 +
29 + session = requests.Session()
30 +
31 + response = session.get(URL, params={"id": id}, stream=True)
32 + token = get_confirm_token(response)
33 +
34 + if token:
35 + params = {"id": id, "confirm": token}
36 + response = session.get(URL, params=params, stream=True)
37 +
38 + save_response_content(response, destination)
39 +
40 +
41 +def get_confirm_token(response):
42 + for key, value in response.cookies.items():
43 + if key.startswith("download_warning"):
44 + return value
45 +
46 + return None
47 +
48 +
49 +def save_response_content(response, destination):
50 + CHUNK_SIZE = 32768
51 + print(destination)
52 + with open(destination, "wb") as f:
53 + for chunk in response.iter_content(CHUNK_SIZE):
54 + if chunk: # filter out keep-alive new chunks
55 + f.write(chunk)
56 +
57 +
58 +def download(t_url):
59 + response = urllib.request.urlopen(t_url)
60 + data = response.read()
61 + txt_str = str(data)
62 + lines = txt_str.split("\\n")
63 + return lines
64 +
65 +
66 +def Convert(lst):
67 + it = iter(lst)
68 + res_dct = dict(zip(it, it))
69 + return res_dct
70 +
71 +
72 +if __name__ == "__main__":
73 + # Fetch the latest download_links.txt file from GitHub
74 + link = "https://raw.githubusercontent.com/aqeelanwar/MaskTheFace/master/datasets/download_links.txt"
75 + links_dict = Convert(
76 + download(link)[0]
77 + .replace(":", "\n")
78 + .replace("b'", "")
79 + .replace("'", "")
80 + .replace(" ", "")
81 + .split("\n")
82 + )
83 + file_id = links_dict[args.dataset]
84 +    destination = os.path.join("datasets", "_.zip")
85 + print("Downloading: ", args.dataset)
86 + download_file_from_google_drive(file_id, destination)
87 + print("Extracting: ", args.dataset)
88 + with ZipFile(destination, "r") as zipObj:
89 + # Extract all the contents of zip file in current directory
90 + zipObj.extractall(destination.rsplit(os.path.sep, 1)[0])
91 +
92 + os.remove(destination)
1 +# Author: aqeelanwar
2 +# Created: 4 May,2020, 1:30 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +import numpy as np
6 +from numpy.linalg import eig, inv
7 +
8 +def fitEllipse(x,y):
9 + x = x[:,np.newaxis]
10 + y = y[:,np.newaxis]
11 + D = np.hstack((x*x, x*y, y*y, x, y, np.ones_like(x)))
12 + S = np.dot(D.T,D)
13 + C = np.zeros([6,6])
14 + C[0,2] = C[2,0] = 2; C[1,1] = -1
15 + E, V = eig(np.dot(inv(S), C))
16 + n = np.argmax(np.abs(E))
17 + a = V[:,n]
18 + return a
19 +
20 +def ellipse_center(a):
21 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
22 + num = b*b-a*c
23 + x0=(c*d-b*f)/num
24 + y0=(a*f-b*d)/num
25 + return np.array([x0,y0])
26 +
27 +
28 +def ellipse_angle_of_rotation( a ):
29 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
30 + return 0.5*np.arctan(2*b/(a-c))
31 +
32 +
33 +def ellipse_axis_length( a ):
34 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
35 + up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g)
36 + down1=(b*b-a*c)*( (c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
37 + down2=(b*b-a*c)*( (a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
38 + res1=np.sqrt(up/down1)
39 + res2=np.sqrt(up/down2)
40 + return np.array([res1, res2])
41 +
42 +def ellipse_angle_of_rotation2( a ):
43 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
44 + if b == 0:
45 + if a > c:
46 + return 0
47 + else:
48 + return np.pi/2
49 + else:
50 + if a > c:
51 + return np.arctan(2*b/(a-c))/2
52 + else:
53 + return np.pi/2 + np.arctan(2*b/(a-c))/2
54 +
55 +# a = fitEllipse(x,y)
56 +# center = ellipse_center(a)
57 +# #phi = ellipse_angle_of_rotation(a)
58 +# phi = ellipse_angle_of_rotation2(a)
59 +# axes = ellipse_axis_length(a)
60 +#
61 +# print("center = ", center)
62 +# print("angle of rotation = ", phi)
63 +# print("axes = ", axes)
64 +
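The commented block above sketches the intended call order; a self-contained version on synthetic, lightly noised ellipse points (assuming the module is importable as utils.fit_ellipse):

import numpy as np
from utils.fit_ellipse import fitEllipse, ellipse_center, ellipse_angle_of_rotation2, ellipse_axis_length

# Synthetic points on an ellipse centered at (3, 2), semi-axes 5 and 2, rotated 30 degrees,
# with a little noise so the scatter matrix in fitEllipse stays well-conditioned
rng = np.random.default_rng(0)
t = np.linspace(0, 2 * np.pi, 60)
phi_true = np.deg2rad(30)
x = 3 + 5 * np.cos(t) * np.cos(phi_true) - 2 * np.sin(t) * np.sin(phi_true) + rng.normal(0, 0.01, t.shape)
y = 2 + 5 * np.cos(t) * np.sin(phi_true) + 2 * np.sin(t) * np.cos(phi_true) + rng.normal(0, 0.01, t.shape)

a = fitEllipse(x, y)
print("center =", ellipse_center(a))               # roughly [3, 2]
print("angle  =", ellipse_angle_of_rotation2(a))   # roughly 0.52 rad (30 degrees)
print("axes   =", ellipse_axis_length(a))          # roughly [5, 2] (ordering may vary)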
1 +# Author: aqeelanwar
2 +# Created: 2 May,2020, 2:49 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from tkinter import filedialog
6 +from tkinter import *
7 +import cv2, os
8 +
9 +mouse_pts = []
10 +
11 +
12 +def get_mouse_points(event, x, y, flags, param):
13 + global mouseX, mouseY, mouse_pts
14 + if event == cv2.EVENT_LBUTTONDOWN:
15 + mouseX, mouseY = x, y
16 + cv2.circle(mask_im, (x, y), 10, (0, 255, 255), 10)
17 + if "mouse_pts" not in globals():
18 + mouse_pts = []
19 + mouse_pts.append((x, y))
20 + # print("Point detected")
21 + # print((x,y))
22 +
23 +
24 +root = Tk()
25 +filename = filedialog.askopenfilename(
26 + initialdir="/",
27 + title="Select file",
28 + filetypes=(("PNG files", "*.PNG"), ("png files", "*.png"), ("All files", "*.*")),
29 +)
30 +root.destroy()
31 +filename_split = os.path.split(filename)
32 +folder = filename_split[0]
33 +file = filename_split[1]
34 +file_split = file.split(".")
35 +new_filename = folder + "/" + file_split[0] + "_marked." + file_split[-1]
36 +mask_im = cv2.imread(filename)
37 +cv2.namedWindow("Mask")
38 +cv2.setMouseCallback("Mask", get_mouse_points)
39 +
40 +while True:
41 + cv2.imshow("Mask", mask_im)
42 + cv2.waitKey(1)
43 + if len(mouse_pts) == 6:
44 + cv2.destroyWindow("Mask")
45 + break
46 + first_frame_display = False
47 +points = mouse_pts
48 +print(points)
49 +print("----------------------------------------------------------------")
50 +print("Copy the following code and paste it in masks.cfg")
51 +print("----------------------------------------------------------------")
52 +name_points = ["a", "b", "c", "d", "e", "f"]
53 +
54 +mask_title = "[" + file_split[0] + "]"
55 +print(mask_title)
56 +print("template: ", filename)
57 +for i in range(len(mouse_pts)):
58 + name = (
59 + "mask_"
60 + + name_points[i]
61 + + ": "
62 + + str(mouse_pts[i][0])
63 + + ","
64 + + str(mouse_pts[i][1])
65 + )
66 + print(name)
67 +
68 +cv2.imwrite(new_filename, mask_im)
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 9/20/2019, 12:43 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from configparser import ConfigParser
6 +from dotmap import DotMap
7 +
8 +
9 +def ConvertIfStringIsInt(input_string):
10 + try:
11 + float(input_string)
12 +
13 + try:
14 + if int(input_string) == float(input_string):
15 + return int(input_string)
16 + else:
17 + return float(input_string)
18 + except ValueError:
19 + return float(input_string)
20 +
21 + except ValueError:
22 + return input_string
23 +
24 +
25 +def read_cfg(config_filename="masks/masks.cfg", mask_type="surgical", verbose=False):
26 + parser = ConfigParser()
27 + parser.optionxform = str
28 + parser.read(config_filename)
29 + cfg = DotMap()
30 + section_name = mask_type
31 +
32 + if verbose:
33 + hyphens = "-" * int((80 - len(config_filename)) / 2)
34 + print(hyphens + " " + config_filename + " " + hyphens)
35 +
36 + # for section_name in parser.sections():
37 +
38 + if verbose:
39 + print("[" + section_name + "]")
40 + for name, value in parser.items(section_name):
41 + value = ConvertIfStringIsInt(value)
42 + if name != "template":
43 + cfg[name] = tuple(int(s) for s in value.split(","))
44 + else:
45 + cfg[name] = value
46 + spaces = " " * (30 - len(name))
47 + if verbose:
48 + print(name + ":" + spaces + str(cfg[name]))
49 +
50 + return cfg
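A minimal usage sketch of read_cfg above: template entries stay strings while mask_* entries become integer tuples (assuming the masks/masks.cfg shipped with the repo):

from utils.read_cfg import read_cfg

cfg = read_cfg(config_filename="masks/masks.cfg", mask_type="surgical", verbose=False)
print(cfg.template)   # masks/templates/surgical.png
print(cfg.mask_a)     # (21, 97)
print(cfg.mask_f)     # (600, 323)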
1 +{
2 + "cells": [
3 + {
4 + "cell_type": "code",
5 + "execution_count": 18,
6 + "metadata": {},
7 + "outputs": [],
8 + "source": [
9 + "import torch\n",
10 + "import torch.nn as nn\n",
11 + "from torch.nn import Parameter\n",
12 + "import torch.nn.functional as F\n",
13 + "from torchvision import transforms as tf\n",
14 + "import torch.utils.data as data\n",
15 + "\n",
16 + "import os\n",
17 + "import cv2\n",
18 + "import functools\n",
19 + "import numpy as np\n",
20 + "from PIL import Image\n",
21 + "import matplotlib.pyplot as plt"
22 + ]
23 + },
24 + {
25 + "cell_type": "code",
26 + "execution_count": 1,
27 + "metadata": {},
28 + "outputs": [],
29 + "source": [
30 + "from models import vgg19"
31 + ]
32 + },
33 + {
34 + "cell_type": "code",
35 + "execution_count": 14,
36 + "metadata": {},
37 + "outputs": [],
38 + "source": [
39 + "model = vgg19(pretrained=True).features[:-2]\n",
40 + "\n",
41 + "model = model.eval()"
42 + ]
43 + },
44 + {
45 + "cell_type": "code",
46 + "execution_count": 15,
47 + "metadata": {},
48 + "outputs": [
49 + {
50 + "data": {
51 + "text/plain": [
52 + "Sequential(\n",
53 + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
54 + " (1): ReLU(inplace=True)\n",
55 + " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
56 + " (3): ReLU(inplace=True)\n",
57 + " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
58 + " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
59 + " (6): ReLU(inplace=True)\n",
60 + " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
61 + " (8): ReLU(inplace=True)\n",
62 + " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
63 + " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
64 + " (11): ReLU(inplace=True)\n",
65 + " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
66 + " (13): ReLU(inplace=True)\n",
67 + " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
68 + " (15): ReLU(inplace=True)\n",
69 + " (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
70 + " (17): ReLU(inplace=True)\n",
71 + " (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
72 + " (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
73 + " (20): ReLU(inplace=True)\n",
74 + " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
75 + " (22): ReLU(inplace=True)\n",
76 + " (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
77 + " (24): ReLU(inplace=True)\n",
78 + " (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
79 + " (26): ReLU(inplace=True)\n",
80 + " (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
81 + " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
82 + " (29): ReLU(inplace=True)\n",
83 + " (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
84 + " (31): ReLU(inplace=True)\n",
85 + " (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
86 + " (33): ReLU(inplace=True)\n",
87 + " (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
88 + ")"
89 + ]
90 + },
91 + "execution_count": 15,
92 + "metadata": {},
93 + "output_type": "execute_result"
94 + }
95 + ],
96 + "source": [
97 + "model"
98 + ]
99 + },
100 + {
101 + "cell_type": "code",
102 + "execution_count": 9,
103 + "metadata": {},
104 + "outputs": [],
105 + "source": [
106 + "img = torch.rand(4,3,256,256)"
107 + ]
108 + },
109 + {
110 + "cell_type": "code",
111 + "execution_count": 10,
112 + "metadata": {},
113 + "outputs": [
114 + {
115 + "data": {
116 + "text/plain": [
117 + "torch.Size([4, 512, 8, 8])"
118 + ]
119 + },
120 + "execution_count": 10,
121 + "metadata": {},
122 + "output_type": "execute_result"
123 + }
124 + ],
125 + "source": [
126 + "out = model(img)\n",
127 + "out.shape"
128 + ]
129 + },
130 + {
131 + "cell_type": "code",
132 + "execution_count": 19,
133 + "metadata": {},
134 + "outputs": [],
135 + "source": [
136 + "class GatedConv2d(nn.Module):\n",
137 + " def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, activation = 'lrelu', norm = 'in'):\n",
138 + " super(GatedConv2d, self).__init__()\n",
139 + " self.pad = nn.ZeroPad2d(padding)\n",
140 + " if norm is not None:\n",
141 + " self.norm = nn.InstanceNorm2d(out_channels)\n",
142 + " else:\n",
143 + " self.norm = None\n",
144 + " \n",
145 + " if activation == 'tanh':\n",
146 + " self.activation = nn.Tanh()\n",
147 + " else:\n",
148 + " self.activation = nn.LeakyReLU(0.2, inplace = True)\n",
149 + " \n",
150 + " \n",
151 + " self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)\n",
152 + " self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)\n",
153 + " self.sigmoid = torch.nn.Sigmoid()\n",
154 + " \n",
155 + " def forward(self, x):\n",
156 + " x = self.pad(x)\n",
157 + " conv = self.conv2d(x)\n",
158 + " mask = self.mask_conv2d(x)\n",
159 + " gated_mask = self.sigmoid(mask)\n",
160 + " x = conv * gated_mask\n",
161 + " if self.norm:\n",
162 + " x = self.norm(x)\n",
163 + " if self.activation:\n",
164 + " x = self.activation(x)\n",
165 + " return x\n",
166 + "\n",
167 + "class TransposeGatedConv2d(nn.Module):\n",
168 + " def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, norm=None, scale_factor = 2):\n",
169 + " super(TransposeGatedConv2d, self).__init__()\n",
170 + " # Initialize the conv scheme\n",
171 + " self.scale_factor = scale_factor\n",
172 + " self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, norm=norm)\n",
173 + " \n",
174 + " def forward(self, x):\n",
175 + " x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')\n",
176 + " x = self.gated_conv2d(x)\n",
177 + " return x"
178 + ]
179 + },
180 + {
181 + "cell_type": "code",
182 + "execution_count": 20,
183 + "metadata": {},
184 + "outputs": [],
185 + "source": [
186 + "class GatedGenerator(nn.Module):\n",
187 + " def __init__(self, in_channels=4, latent_channels=64, out_channels=3):\n",
188 + " super(GatedGenerator, self).__init__()\n",
189 + " self.coarse = nn.Sequential(\n",
190 + " # encoder\n",
191 + " GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),\n",
192 + " GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),\n",
193 + " GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),\n",
194 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),\n",
195 + " # Bottleneck\n",
196 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
197 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
198 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),\n",
199 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),\n",
200 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),\n",
201 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),\n",
202 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
203 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
204 + " # decoder\n",
205 + " TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),\n",
206 + " GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),\n",
207 + " TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),\n",
208 + " GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)\n",
209 + " )\n",
210 + " self.refinement = nn.Sequential(\n",
211 + " # encoder\n",
212 + " GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),\n",
213 + " GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),\n",
214 + " GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),\n",
215 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),\n",
216 + " # Bottleneck\n",
217 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
218 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
219 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),\n",
220 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),\n",
221 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),\n",
222 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),\n",
223 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
224 + " GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),\n",
225 + " # decoder\n",
226 + " TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),\n",
227 + " GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),\n",
228 + " TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),\n",
229 + " GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)\n",
230 + " )\n",
231 + " \n",
232 + " def forward(self, img, mask):\n",
233 + " # img: entire img\n",
234 + " # mask: 1 for mask region; 0 for unmask region\n",
235 + " # 1 - mask: unmask\n",
236 + " # img * (1 - mask): ground truth unmask region\n",
237 + " # Coarse\n",
238 + " \n",
239 + " first_masked_img = img * (1 - mask) + mask\n",
240 + " first_in = torch.cat((first_masked_img, mask), 1) # in: [B, 4, H, W]\n",
241 + " first_out = self.coarse(first_in) # out: [B, 3, H, W]\n",
242 + " # Refinement\n",
243 + " second_masked_img = img * (1 - mask) + first_out * mask\n",
244 + " second_in = torch.cat((second_masked_img, mask), 1) # in: [B, 4, H, W]\n",
245 + " second_out = self.refinement(second_in) # out: [B, 3, H, W]\n",
246 + " return first_out, second_out"
247 + ]
248 + },
249 + {
250 + "cell_type": "code",
251 + "execution_count": 21,
252 + "metadata": {},
253 + "outputs": [],
254 + "source": [
255 + "class NLayerDiscriminator(nn.Module):\n",
256 + " def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):\n",
257 + " super(NLayerDiscriminator, self).__init__()\n",
258 + " if type(norm_layer) == functools.partial:\n",
259 + " use_bias = norm_layer.func == nn.InstanceNorm2d\n",
260 + " else:\n",
261 + " use_bias = norm_layer == nn.InstanceNorm2d\n",
262 + "\n",
263 + " kw = 4\n",
264 + " padw = 1\n",
265 + " sequence = [\n",
266 + " nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n",
267 + " nn.LeakyReLU(0.2, True)\n",
268 + " ]\n",
269 + "\n",
270 + " nf_mult = 1\n",
271 + " nf_mult_prev = 1\n",
272 + " for n in range(1, n_layers):\n",
273 + " nf_mult_prev = nf_mult\n",
274 + " nf_mult = min(2**n, 8)\n",
275 + " sequence += [\n",
276 + " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n",
277 + " kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n",
278 + " norm_layer(ndf * nf_mult),\n",
279 + " nn.LeakyReLU(0.2, True)\n",
280 + " ]\n",
281 + "\n",
282 + " nf_mult_prev = nf_mult\n",
283 + " nf_mult = min(2**n_layers, 8)\n",
284 + " sequence += [\n",
285 + " nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n",
286 + " kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n",
287 + " norm_layer(ndf * nf_mult),\n",
288 + " nn.LeakyReLU(0.2, True)\n",
289 + " ]\n",
290 + "\n",
291 + " sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]\n",
292 + "\n",
293 + " if use_sigmoid:\n",
294 + " sequence += [nn.Sigmoid()]\n",
295 + "\n",
296 + " self.model = nn.Sequential(*sequence)\n",
297 + "\n",
298 + " def forward(self, input):\n",
299 + " return self.model(input)"
300 + ]
301 + },
302 + {
303 + "cell_type": "code",
304 + "execution_count": 22,
305 + "metadata": {},
306 + "outputs": [],
307 + "source": [
308 + "class PerceptualNet(nn.Module):\n",
309 + " def __init__(self):\n",
310 + " super(PerceptualNet, self).__init__()\n",
311 + " self.features = nn.Sequential(\n",
312 + " nn.Conv2d(3, 64, 3, 1, 1),\n",
313 + " nn.ReLU(inplace = True),\n",
314 + " nn.Conv2d(64, 64, 3, 1, 1),\n",
315 + " nn.ReLU(inplace = True),\n",
316 + " nn.MaxPool2d(2, 2),\n",
317 + " nn.Conv2d(64, 128, 3, 1, 1),\n",
318 + " nn.ReLU(inplace = True),\n",
319 + " nn.Conv2d(128, 128, 3, 1, 1),\n",
320 + " nn.ReLU(inplace = True),\n",
321 + " nn.MaxPool2d(2, 2),\n",
322 + " nn.Conv2d(128, 256, 3, 1, 1),\n",
323 + " nn.ReLU(inplace = True),\n",
324 + " nn.Conv2d(256, 256, 3, 1, 1),\n",
325 + " nn.ReLU(inplace = True),\n",
326 + " nn.Conv2d(256, 256, 3, 1, 1),\n",
327 + " nn.MaxPool2d(2, 2),\n",
328 + " nn.Conv2d(256, 512, 3, 1, 1),\n",
329 + " nn.ReLU(inplace = True),\n",
330 + " nn.Conv2d(512, 512, 3, 1, 1),\n",
331 + " nn.ReLU(inplace = True),\n",
332 + " nn.Conv2d(512, 512, 3, 1, 1)\n",
333 + " )\n",
334 + "\n",
335 + " def forward(self, x):\n",
336 + " x = self.features(x)\n",
337 + " return x"
338 + ]
339 + },
340 + {
341 + "cell_type": "code",
342 + "execution_count": 6,
343 + "metadata": {},
344 + "outputs": [],
345 + "source": [
346 + "class GANLoss(nn.Module):\n",
347 + " def __init__(self, target_real_label=1.0, target_fake_label=0.0):\n",
348 + " super(GANLoss, self).__init__()\n",
349 + " self.register_buffer('real_label', torch.tensor(target_real_label))\n",
350 + " self.register_buffer('fake_label', torch.tensor(target_fake_label))\n",
351 + " self.loss = nn.BCELoss()\n",
352 + "\n",
353 + " def get_target_tensor(self, input, target_is_real):\n",
354 + " if target_is_real:\n",
355 + " target_tensor = self.real_label\n",
356 + " else:\n",
357 + " target_tensor = self.fake_label\n",
358 + " return target_tensor.expand_as(input)\n",
359 + "\n",
360 + " def __call__(self, input, target_is_real):\n",
361 + " target_tensor = self.get_target_tensor(input, target_is_real)\n",
362 + " return self.loss(input, target_tensor)"
363 + ]
364 + },
365 + {
366 + "cell_type": "code",
367 + "execution_count": 7,
368 + "metadata": {},
369 + "outputs": [],
370 + "source": [
371 + "class InpaintDataset(data.Dataset):\n",
372 + " def __init__(self, img_dir):\n",
373 + " self.img_dir = img_dir\n",
374 + " self.load_images()\n",
375 + " \n",
376 + " def load_images(self):\n",
377 + " self.fns =[]\n",
378 + " img_paths = sorted(os.listdir(self.img_dir))\n",
379 + " for path in img_paths:\n",
380 + " self.fns.append(os.path.join(self.img_dir, path))\n",
381 + " \n",
382 + " def __getitem__(self, index):\n",
383 + " img_path = self.fns[index]\n",
384 + " img = cv2.imread(img_path)\n",
385 + " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
386 + " img = cv2.resize(img, (256,256))\n",
387 + " \n",
388 + " mask = self.random_ff_mask()\n",
389 + " img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()\n",
390 + " mask = torch.from_numpy(mask.astype(np.float32)).contiguous()\n",
391 + " return img, mask\n",
392 + " \n",
393 + " def collate_fn(self, batch):\n",
394 + " imgs = torch.stack([i[0] for i in batch])\n",
395 + " masks = torch.stack([i[1] for i in batch])\n",
396 + " return {\n",
397 + " 'imgs': imgs,\n",
398 + " 'masks': masks\n",
399 + " }\n",
400 + " \n",
401 + " def __len__(self):\n",
402 + " return len(self.fns)\n",
403 + " \n",
404 + " def random_ff_mask(self, shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):\n",
405 +    "        \"\"\"Generate a random free-form mask by drawing random brush strokes.\n",
406 +    "        Args:\n",
407 +    "            shape: height/width of the square mask; max_angle, max_len, max_width:\n",
408 +    "            bounds on stroke angle, length and width; times: max number of strokes.\n",
409 +    "        Returns:\n",
410 +    "            np.ndarray of shape (1, shape, shape), 1.0 inside the strokes.\n",
411 + " \"\"\"\n",
412 + " height = shape\n",
413 + " width = shape\n",
414 + " mask = np.zeros((height, width), np.float32)\n",
415 + " times = np.random.randint(times)\n",
416 + " for i in range(times):\n",
417 + " start_x = np.random.randint(width)\n",
418 + " start_y = np.random.randint(height)\n",
419 + " for j in range(1 + np.random.randint(5)):\n",
420 + " angle = 0.01 + np.random.randint(max_angle)\n",
421 + " if i % 2 == 0:\n",
422 + " angle = 2 * 3.1415926 - angle\n",
423 + " length = 10 + np.random.randint(max_len)\n",
424 + " brush_w = 5 + np.random.randint(max_width)\n",
425 + " end_x = (start_x + length * np.sin(angle)).astype(np.int32)\n",
426 + " end_y = (start_y + length * np.cos(angle)).astype(np.int32)\n",
427 + " cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)\n",
428 + " start_x, start_y = end_x, end_y\n",
429 + " return mask.reshape((1, ) + mask.shape).astype(np.float32)"
430 + ]
431 + },
432 + {
433 + "cell_type": "code",
434 + "execution_count": null,
435 + "metadata": {},
436 + "outputs": [],
437 + "source": [
438 + "dataset = InpaintDataset(img_dir='datasets/places365standard_easyformat/places365_standard/train/waterfall')\n",
439 + "dataloader = data.DataLoader(dataset, batch_size=4, collate_fn = dataset.collate_fn)"
440 + ]
441 + },
442 + {
443 + "cell_type": "code",
444 + "execution_count": null,
445 + "metadata": {},
446 + "outputs": [],
447 + "source": [
448 + "for batch in dataloader:\n",
449 + " imgs = batch['imgs']\n",
450 + " masks = batch['masks']\n",
451 + " \n",
452 + " break"
453 + ]
454 + },
455 + {
456 + "cell_type": "code",
457 + "execution_count": 8,
458 + "metadata": {},
459 + "outputs": [],
460 + "source": [
461 + "device = torch.device('cuda')"
462 + ]
463 + },
464 + {
465 + "cell_type": "code",
466 + "execution_count": 23,
467 + "metadata": {},
468 + "outputs": [
469 + {
470 + "ename": "NameError",
471 + "evalue": "name 'GANLoss' is not defined",
472 + "output_type": "error",
473 + "traceback": [
474 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
475 + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
476 + "\u001b[1;32m<ipython-input-23-bdcc75eef256>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mmodel_D\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mNLayerDiscriminator\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0muse_sigmoid\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mmodel_P\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPerceptualNet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mcriterion_adv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mGANLoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 5\u001b[0m \u001b[0mcriterion_rec\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMSELoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mcriterion_per\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mL1Loss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
477 + "\u001b[1;31mNameError\u001b[0m: name 'GANLoss' is not defined"
478 + ]
479 + }
480 + ],
481 + "source": [
482 + "model_G = GatedGenerator()\n",
483 + "model_D = NLayerDiscriminator(3, use_sigmoid=True)\n",
484 + "model_P = PerceptualNet()\n",
485 + "criterion_adv = GANLoss()\n",
486 + "criterion_rec = nn.MSELoss()\n",
487 + "criterion_per = nn.L1Loss()\n",
488 + "optimizer_D = torch.optim.Adam(model_D.parameters(), lr=1e-4)\n",
489 + "optimizer_G = torch.optim.Adam(model_G.parameters(), lr=1e-4)"
490 + ]
491 + },
492 + {
493 + "cell_type": "code",
494 + "execution_count": 24,
495 + "metadata": {},
496 + "outputs": [],
497 + "source": [
498 + "torch.save({\n",
499 + " 'D': model_D.state_dict(),\n",
500 + " 'G': model_G.state_dict()\n",
501 + "}, 's.pth')"
502 + ]
503 + },
504 + {
505 + "cell_type": "code",
506 + "execution_count": null,
507 + "metadata": {},
508 + "outputs": [],
509 + "source": [
510 + "def count_params(model):\n",
511 + " return sum(p.numel() for p in model.parameters() if p.requires_grad)"
512 + ]
513 + },
514 + {
515 + "cell_type": "code",
516 + "execution_count": null,
517 + "metadata": {},
518 + "outputs": [],
519 + "source": [
520 + "print(count_params(model_G))\n",
521 + "print(count_params(model_D))\n",
522 + "print(count_params(model_P))"
523 + ]
524 + },
525 + {
526 + "cell_type": "code",
527 + "execution_count": 10,
528 + "metadata": {},
529 + "outputs": [],
530 + "source": [
531 + "def random_ff_mask(shape =256 , max_angle = 4, max_len = 40, max_width = 10, times = 15):\n",
532 +    "    \"\"\"Generate a random free-form mask by drawing random brush strokes.\n",
533 +    "    Args:\n",
534 +    "        shape: height/width of the square mask; max_angle, max_len, max_width:\n",
535 +    "        bounds on stroke angle, length and width; times: max number of strokes.\n",
536 +    "    Returns:\n",
537 +    "        np.ndarray of shape (1, shape, shape), 1.0 inside the strokes.\n",
538 + " \"\"\"\n",
539 + " height = shape\n",
540 + " width = shape\n",
541 + " mask = np.zeros((height, width), np.float32)\n",
542 + " times = np.random.randint(times)\n",
543 + " for i in range(times):\n",
544 + " start_x = np.random.randint(width)\n",
545 + " start_y = np.random.randint(height)\n",
546 + " for j in range(1 + np.random.randint(5)):\n",
547 + " angle = 0.01 + np.random.randint(max_angle)\n",
548 + " if i % 2 == 0:\n",
549 + " angle = 2 * 3.1415926 - angle\n",
550 + " length = 10 + np.random.randint(max_len)\n",
551 + " brush_w = 5 + np.random.randint(max_width)\n",
552 + " end_x = (start_x + length * np.sin(angle)).astype(np.int32)\n",
553 + " end_y = (start_y + length * np.cos(angle)).astype(np.int32)\n",
554 + " cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)\n",
555 + " start_x, start_y = end_x, end_y\n",
556 + " return mask.reshape((1, ) + mask.shape).astype(np.float32)"
557 + ]
558 + },
559 + {
560 + "cell_type": "code",
561 + "execution_count": 11,
562 + "metadata": {},
563 + "outputs": [],
564 + "source": [
565 + "img = cv2.imread('datasets/places365standard_easyformat/places365_standard/train/waterfall/00000003.jpg')\n",
566 + "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
567 + "img = cv2.resize(img, (256, 256))\n",
568 + "img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()\n",
569 + "img_tensor = img.unsqueeze(0)\n",
570 + "mask = random_ff_mask()\n",
571 + "mask = torch.from_numpy(mask).contiguous().unsqueeze(0)"
572 + ]
573 + },
574 + {
575 + "cell_type": "code",
576 + "execution_count": null,
577 + "metadata": {},
578 + "outputs": [],
579 + "source": [
580 + "def visualize(img):\n",
581 + " np_img = img.squeeze(0).detach().cpu().numpy()\n",
582 + " return np_img.transpose(1, 2, 0)"
583 + ]
584 + },
585 + {
586 + "cell_type": "code",
587 + "execution_count": null,
588 + "metadata": {},
589 + "outputs": [],
590 + "source": [
591 + "plt.imshow(visualize(first_out_wholeimg))"
592 + ]
593 + },
594 + {
595 + "cell_type": "code",
596 + "execution_count": 12,
597 + "metadata": {},
598 + "outputs": [],
599 + "source": [
600 + "first_out, second_out = model_G(img_tensor, mask)\n",
601 + "\n",
602 + "first_out_wholeimg = img_tensor * (1 - mask) + first_out * mask \n",
603 + "second_out_wholeimg = img_tensor * (1 - mask) + second_out * mask"
604 + ]
605 + },
606 + {
607 + "cell_type": "code",
608 + "execution_count": 13,
609 + "metadata": {},
610 + "outputs": [],
611 + "source": [
612 + "# Train discriminator\n",
613 + "optimizer_D.zero_grad()\n",
614 + "\n",
615 + "fake_D = model_D(second_out_wholeimg.detach())\n",
616 + "real_D = model_D(img_tensor)\n",
617 + "\n",
618 + "loss_fake_D = criterion_adv(fake_D, target_is_real=False)\n",
619 + "loss_real_D = criterion_adv(real_D, target_is_real=True)\n",
620 + "\n",
621 + "loss_D = (loss_fake_D + loss_real_D) *0.5\n",
622 + "\n",
623 + "loss_D.backward()\n",
624 + "optimizer_D.step()"
625 + ]
626 + },
627 + {
628 + "cell_type": "code",
629 + "execution_count": 15,
630 + "metadata": {},
631 + "outputs": [],
632 + "source": [
633 + "# Train Generator\n",
634 + "\n",
635 + "optimizer_G.zero_grad()\n",
636 + "\n",
637 + "fake_D = model_D(second_out_wholeimg)\n",
638 + "G_loss = criterion_adv(fake_D, target_is_real=True)"
639 + ]
640 + },
641 + {
642 + "cell_type": "code",
643 + "execution_count": 16,
644 + "metadata": {},
645 + "outputs": [],
646 + "source": [
647 + "# Reconstruction loss\n",
648 + "\n",
649 + "loss_rec_1 = criterion_rec(first_out_wholeimg, img_tensor)\n",
650 + "loss_rec_2 = criterion_rec(second_out_wholeimg, img_tensor)"
651 + ]
652 + },
653 + {
654 + "cell_type": "code",
655 + "execution_count": 17,
656 + "metadata": {},
657 + "outputs": [],
658 + "source": [
659 + "# Perceptual loss\n",
660 + "\n",
661 + "img_featuremaps = model_P(img_tensor) # feature maps\n",
662 + "second_out_wholeimg_featuremaps = model_P(second_out_wholeimg)\n",
663 + "\n",
664 + "loss_P = criterion_per(second_out_wholeimg_featuremaps, img_featuremaps)"
665 + ]
666 + },
667 + {
668 + "cell_type": "code",
669 + "execution_count": 18,
670 + "metadata": {},
671 + "outputs": [
672 + {
673 + "ename": "NameError",
674 + "evalue": "name 'lambda_G' is not defined",
675 + "output_type": "error",
676 + "traceback": [
677 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
678 + "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
679 + "\u001b[1;32m<ipython-input-18-dbfe5f51e2fc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlambda_G\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mG_loss\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_rec_1\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_rec_1\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_rec_2\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_rec_2\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mlambda_per\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mloss_P\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0moptimizer_G\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
680 + "\u001b[1;31mNameError\u001b[0m: name 'lambda_G' is not defined"
681 + ]
682 + }
683 + ],
684 + "source": [
685 + "loss = lambda_G * G_loss + lambda_rec_1 * loss_rec_1 + lambda_rec_2 * loss_rec_2 + lambda_per * loss_P\n",
686 + "loss.backward()\n",
687 + "optimizer_G.step()"
688 + ]
689 + }
690 + ],
691 + "metadata": {
692 + "kernelspec": {
693 + "display_name": "Python 3",
694 + "language": "python",
695 + "name": "python3"
696 + },
697 + "language_info": {
698 + "codemirror_mode": {
699 + "name": "ipython",
700 + "version": 3
701 + },
702 + "file_extension": ".py",
703 + "mimetype": "text/x-python",
704 + "name": "python",
705 + "nbconvert_exporter": "python",
706 + "pygments_lexer": "ipython3",
707 + "version": "3.7.6"
708 + }
709 + },
710 + "nbformat": 4,
711 + "nbformat_minor": 4
712 +}
1 +from .config import Config
...\ No newline at end of file ...\ No newline at end of file
1 +import yaml
2 +
3 +class Config():
4 + def __init__(self, yaml_path):
5 + yaml_file = open(yaml_path)
6 + self._attr = yaml.load(yaml_file, Loader=yaml.FullLoader)['settings']
7 +
8 + def __getattr__(self, attr):
9 + try:
10 + return self._attr[attr]
11 + except KeyError:
12 + return None
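13 +
14 +# Usage sketch (assumption; the yaml filename is illustrative):
15 +#   cfg = Config('./configs/facemask.yaml')
16 +#   cfg.batch_size   -> value read from the 'settings' block
17 +#   cfg.missing_key  -> None (unknown keys fall back to None via __getattr__)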
1 +settings:
2 + root_dir: "./datasets/celeba/images/"
3 + checkpoint_path: "weights"
4 + sample_folder: "sample"
5 +
6 + cuda: True
7 + lr: 0.001
8 + batch_size: 2
9 + num_workers: 4
10 +
11 + step_iters: [10000, 15000, 20000]
12 + gamma: 0.1
13 +
14 + d_num_layers: 3
15 +
16 + visualize_per_iter: 500
17 + save_per_iter: 500
18 + print_per_iter: 10
19 + num_epochs: 100
20 +
21 + lambda_G: 1.0
22 + lambda_rec_1: 100.0
23 + lambda_rec_2: 100.0
24 + lambda_per: 10.0
25 +
26 + img_size: 512
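27 +
28 +  # lambda_G, lambda_rec_1, lambda_rec_2 and lambda_per weight the adversarial, coarse
29 +  # reconstruction, refined reconstruction and perceptual terms of the generator loss:
30 +  # loss = lambda_G * G_loss + lambda_rec_1 * loss_rec_1 + lambda_rec_2 * loss_rec_2 + lambda_per * loss_P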
1 +settings:
2 + root_dir: "./datasets/places365_10classes"
3 + checkpoint_path: "/content/drive/MyDrive/weights/Places365 Inpainting/phase 3"
4 + sample_folder: "/content/drive/MyDrive/results/Places365 Inpainting/phase 3"
5 +
6 + cuda: True
7 + lr: 0.0001
8 + batch_size: 8
9 + num_workers: 4
10 +
11 + step_iters: [50000, 75000, 100000]
12 + gamma: 0.1
13 +
14 + d_num_layers: 3
15 +
16 + visualize_per_iter: 500
17 + save_per_iter: 500
18 + print_per_iter: 10
19 + num_epochs: 100
20 +
21 + lambda_G: 0.3
22 + lambda_rec_1: 10.0
23 + lambda_rec_2: 10.0
24 + lambda_per: 1.0
25 +
26 + img_size: 256
27 + max_angle: 4
28 + max_len: 50
29 + max_width: 30
30 + times: 15
1 +settings:
2 + root_dir: "./datasets/celeba/images/"
3 + train_anns: "./datasets/celeba/annotations/train.csv"
4 + val_anns: "./datasets/celeba/annotations/val.csv"
5 +
6 + checkpoint_path: "weights" #"/content/drive/MyDrive/weights/Places365 Inpainting/unet/phase 1"
7 + sample_folder: "sample" #"/content/drive/MyDrive/results/Places365 Inpainting/unet/phase 1"
8 +
9 + cuda: True
10 + lr: 0.001
11 + batch_size: 4
12 + num_workers: 4
13 +
14 + step_iters: [50000, 75000, 100000]
15 + gamma: 0.1
16 +
17 + visualize_per_iter: 1000
18 + save_per_iter: 1000
19 + print_per_iter: 10
20 + num_epochs: 100
21 +
22 + img_size: 512
1 +from .dataset import Places365Dataset, FacemaskDataset
2 +from .dataset_seg import FacemaskSegDataset
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
1 +import os
2 +import csv
3 +
4 +# List the image folders once instead of calling os.listdir on every loop iteration.
5 +masked_files = os.listdir("./datasets/celeba/images/celeba512_30k_masked")
6 +binary_files = os.listdir("./datasets/celeba/images/celeba512_30k_binary")
7 +
8 +# The first 23304 masked/binary pairs form the training annotations.
9 +f = open("./datasets/celeba/annotations/train.csv", "w", newline="")
10 +wr = csv.writer(f)
11 +wr.writerow(["_", "img_name", "mask_name"])
12 +
13 +for i in range(23304):
14 +    wr.writerow(
15 +        [
16 +            i,
17 +            "celeba512_30k_masked/" + masked_files[i],
18 +            "celeba512_30k_binary/" + binary_files[i],
19 +        ]
20 +    )
21 +
22 +f.close()
23 +
24 +# The remaining pairs (indices 23304-29130) form the validation annotations.
25 +f = open("./datasets/celeba/annotations/val.csv", "w", newline="")
26 +wr = csv.writer(f)
27 +wr.writerow(["_", "img_name", "mask_name"])
28 +
29 +for i in range(23304, 29131):
30 +    wr.writerow(
31 +        [
32 +            i,
33 +            "celeba512_30k_masked/" + masked_files[i],
34 +            "celeba512_30k_binary/" + binary_files[i],
35 +        ]
36 +    )
37 +
38 +f.close()
1 +import os
2 +import torch
3 +import torch.nn as nn
4 +import torch.utils.data as data
5 +import cv2
6 +import numpy as np
7 +from tqdm import tqdm
8 +
9 +class Places365Dataset(data.Dataset):
10 + def __init__(self, cfg):
11 + self.root_dir = cfg.root_dir
12 + self.cfg = cfg
13 + self.load_images()
14 +
15 + def load_images(self):
16 + self.fns =[]
17 + idx = 0
18 + img_paths = os.listdir(self.root_dir)
19 + for cls_id in img_paths:
20 + paths = os.path.join(self.root_dir, cls_id)
21 + file_paths = os.listdir(paths)
22 + for img_name in file_paths:
23 + filename = os.path.join(paths, img_name)
24 + self.fns.append(filename)
25 +
26 + def __getitem__(self, index):
27 + img_path = self.fns[index]
28 + img = cv2.imread(img_path)
29 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
30 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
31 +
32 + mask = self.random_ff_mask(
33 + shape = self.cfg.img_size,
34 + max_angle = self.cfg.max_angle,
35 + max_len = self.cfg.max_len,
36 + max_width = self.cfg.max_width,
37 + times = self.cfg.times)
38 +
39 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
40 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
41 +
42 + return img, mask
43 +
44 + def collate_fn(self, batch):
45 + imgs = torch.stack([i[0] for i in batch])
46 + masks = torch.stack([i[1] for i in batch])
47 + return {
48 + 'imgs': imgs,
49 + 'masks': masks
50 + }
51 +
52 + def __len__(self):
53 + return len(self.fns)
54 +
55 + def random_ff_mask(self, shape = 256 , max_angle = 4, max_len = 50, max_width = 20, times = 15):
56 +        """Generate a random free-form mask by drawing random brush strokes.
57 +        Args:
58 +            shape: height/width of the square mask.
59 +            max_angle, max_len, max_width, times: bounds on stroke angle, length, width and count.
60 +        Returns:
61 +            np.ndarray of shape (1, shape, shape) with 1.0 inside the strokes.
62 +        """
63 + height = shape
64 + width = shape
65 + mask = np.zeros((height, width), np.float32)
66 + times = np.random.randint(10, times)
67 + for i in range(times):
68 + start_x = np.random.randint(width)
69 + start_y = np.random.randint(height)
70 + for j in range(1 + np.random.randint(5)):
71 + angle = 0.01 + np.random.randint(max_angle)
72 + if i % 2 == 0:
73 + angle = 2 * 3.1415926 - angle
74 + length = 10 + np.random.randint(max_len)
75 + brush_w = 5 + np.random.randint(max_width)
76 + end_x = (start_x + length * np.sin(angle)).astype(np.int32)
77 + end_y = (start_y + length * np.cos(angle)).astype(np.int32)
78 + cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
79 + start_x, start_y = end_x, end_y
80 + return mask.reshape((1, ) + mask.shape).astype(np.float32)
81 +
82 +
83 +class FacemaskDataset(data.Dataset):
84 + def __init__(self, cfg):
85 + self.root_dir = cfg.root_dir
86 + self.cfg = cfg
87 +
88 + self.mask_folder = os.path.join(self.root_dir, 'celeba512_30k_binary')
89 + self.img_folder = os.path.join(self.root_dir, 'celeba512_30k')
90 + self.load_images()
91 +
92 + def load_images(self):
93 + self.fns = []
94 + idx = 0
95 + img_paths = sorted(os.listdir(self.img_folder))
96 + for img_name in img_paths:
97 + mask_name = img_name.split('.')[0]+'_binary.jpg'
98 + img_path = os.path.join(self.img_folder, img_name)
99 + mask_path = os.path.join(self.mask_folder, mask_name)
100 + if os.path.isfile(mask_path):
101 + self.fns.append([img_path, mask_path])
102 +
103 + def __getitem__(self, index):
104 + img_path, mask_path = self.fns[index]
105 + img = cv2.imread(img_path)
106 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
107 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
108 +
109 +
110 + mask = cv2.imread(mask_path, 0)
111 +
112 + mask[mask>0]=1.0
113 + mask = np.expand_dims(mask, axis=0)
114 +
115 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
116 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
117 + return img, mask
118 +
119 + def collate_fn(self, batch):
120 + imgs = torch.stack([i[0] for i in batch])
121 + masks = torch.stack([i[1] for i in batch])
122 + return {
123 + 'imgs': imgs,
124 + 'masks': masks
125 + }
126 +
127 + def __len__(self):
128 + return len(self.fns)
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import torch
3 +import torch.nn as nn
4 +import torch.utils.data as data
5 +import cv2
6 +import numpy as np
7 +from tqdm import tqdm
8 +import pandas as pd
9 +from PIL import Image
10 +
11 +
12 +class FacemaskSegDataset(data.Dataset):
13 + def __init__(self, cfg, train=True):
14 + self.root_dir = cfg.root_dir
15 + self.cfg = cfg
16 + self.train = train
17 +
18 + if self.train:
19 + self.df = pd.read_csv(cfg.train_anns)
20 + else:
21 + self.df = pd.read_csv(cfg.val_anns)
22 +
23 + self.load_images()
24 +
25 + def load_images(self):
26 + self.fns = []
27 + for idx, rows in self.df.iterrows():
28 + _, img_name, mask_name = rows
29 + img_path = os.path.join(self.root_dir, img_name)
30 + mask_path = os.path.join(self.root_dir, mask_name)
31 + img_path = img_path.replace("\\", "/")
32 + mask_path = mask_path.replace("\\", "/")
33 + if os.path.isfile(mask_path):
34 + self.fns.append([img_path, mask_path])
35 +
36 + def __getitem__(self, index):
37 + img_path, mask_path = self.fns[index]
38 + img = cv2.imread(img_path)
39 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
41 + mask = cv2.imread(mask_path, 0)
42 + mask[mask > 0] = 1.0
43 + mask = np.expand_dims(mask, axis=0)
44 +
45 + img = (
46 + torch.from_numpy(img.astype(np.float32) / 255.0)
47 + .permute(2, 0, 1)
48 + .contiguous()
49 + )
50 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
51 + return img, mask
52 +
53 + def collate_fn(self, batch):
54 + imgs = torch.stack([i[0] for i in batch])
55 + masks = torch.stack([i[1] for i in batch])
56 + return {"imgs": imgs, "masks": masks}
57 +
58 + def __len__(self):
59 + return len(self.fns)
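60 +
61 +# Usage sketch (assumption, mirroring the prototyping notebook): wrap the dataset in a DataLoader
62 +# with its own collate_fn so that batches come out as {'imgs': ..., 'masks': ...}:
63 +#   dataset = FacemaskSegDataset(cfg, train=True)
64 +#   loader = data.DataLoader(dataset, batch_size=cfg.batch_size, collate_fn=dataset.collate_fn)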
1 +import torch
2 +import torch.nn as nn
3 +from torchvision.utils import save_image
4 +
5 +import numpy as np
6 +from PIL import Image
7 +import cv2
8 +from models import UNetSemantic, GatedGenerator
9 +import argparse
10 +from configs import Config
11 +
12 +class Predictor():
13 + def __init__(self, cfg):
14 + self.cfg = cfg
15 + self.device = torch.device('cuda:0' if cfg.cuda else 'cpu')
16 + self.masking = UNetSemantic().to(self.device)
17 +        self.masking.load_state_dict(torch.load('weights/model_segm_19_135000.pth', map_location='cpu'))
18 + #self.masking.eval()
19 +
20 + self.inpaint = GatedGenerator().to(self.device)
21 + self.inpaint.load_state_dict(torch.load('weights/model_6_100000.pth', map_location='cpu')['G'])
22 + self.inpaint.eval()
23 +
24 + def save_image(self, img_list, save_img_path, nrow):
25 + img_list = [i.clone().cpu() for i in img_list]
26 + imgs = torch.stack(img_list, dim=1)
27 + imgs = imgs.view(-1, *list(imgs.size())[2:])
28 + save_image(imgs, save_img_path, nrow = nrow)
29 + print(f"Save image to {save_img_path}")
30 +
31 + def predict(self, image, outpath='sample/results.png'):
32 + outpath=f'sample/results_{image}.png'
33 + image = 'sample/'+image
34 + img = cv2.imread(image+'_masked.jpg')
35 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
37 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
38 + img = img.unsqueeze(0).to(self.device)
39 +
40 + img_ori = cv2.imread(image+'.jpg')
41 + img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
42 + img_ori = cv2.resize(img_ori, (self.cfg.img_size, self.cfg.img_size))
43 + img_ori = torch.from_numpy(img_ori.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
44 + img_ori = img_ori.unsqueeze(0)
45 + with torch.no_grad():
46 + outputs = self.masking(img)
47 + _, out = self.inpaint(img, outputs)
48 + inpaint = img * (1 - outputs) + out * outputs
49 + masks = img * (1 - outputs) + outputs #torch.cat([outputs, outputs, outputs], dim=1)
50 +
51 +
52 +
53 + self.save_image([img, masks, inpaint, img_ori], outpath, nrow=4)
54 +
55 +
56 +
57 +
58 +if __name__ == '__main__':
59 +    parser = argparse.ArgumentParser(description='Inpaint the masked region of a face image')
60 +    parser.add_argument('--image', default=None, type=str, help='image name (without extension) inside the sample folder')
61 +    parser.add_argument('config', default='config', type=str, help='config file name (without .yaml) inside the configs folder')
62 + args = parser.parse_args()
63 +
64 + config = Config(f'./configs/{args.config}.yaml')
65 +
66 +
67 + model = Predictor(config)
68 + model.predict(args.image)
...\ No newline at end of file ...\ No newline at end of file
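69 +
70 +# Example invocation (assumption; the script and config names are illustrative):
71 +#   python predict.py facemask --image case1
72 +# expects sample/case1_masked.jpg and sample/case1.jpg, and writes sample/results_case1.png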
1 +from .loggers import *
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import numpy as np
3 +from torch.utils.tensorboard import SummaryWriter
4 +from datetime import datetime
5 +
6 +class Logger():
7 + """
8 + Logger for Tensorboard visualization
9 +    :param log_dir: directory where the Tensorboard event files are written
10 + """
11 + def __init__(self, log_dir=None):
12 + self.log_dir = log_dir
13 + if self.log_dir is None:
14 + self.log_dir = os.path.join('loggers/runs',datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
15 + if not os.path.exists(self.log_dir):
16 +            os.makedirs(self.log_dir)
17 + self.writer = SummaryWriter(log_dir=self.log_dir)
18 + self.iters = {}
19 +
20 + def write(self, tags, values):
21 + """
22 + Write a log to specified directory
23 + :param tags: (str) tag for log
24 + :param values: (number) value for corresponding tag
25 + """
26 +        if not isinstance(tags, list):
27 +            tags = [tags]
28 +        if not isinstance(values, list):
29 +            values = [values]
30 +
31 + for i, (tag, value) in enumerate(zip(tags,values)):
32 + if tag not in self.iters.keys():
33 + self.iters[tag] = 0
34 + self.writer.add_scalar(tag, value, self.iters[tag])
35 + self.iters[tag] += 1
36 +
37 +
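38 +# Usage sketch (assumption): tags and values are paired positionally, e.g.
39 +#   logger = Logger()
40 +#   logger.write(['Train/loss_D', 'Train/loss_G'], [loss_D.item(), loss_G.item()])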
1 +from .adversarial import GANLoss
2 +from .ssim import SSIM
3 +from .dice import DiceLoss
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +import torch.nn as nn
3 +
4 +class GANLoss(nn.Module):
5 + def __init__(self, target_real_label=1.0, target_fake_label=0.0):
6 + super(GANLoss, self).__init__()
7 + self.register_buffer('real_label', torch.tensor(target_real_label))
8 + self.register_buffer('fake_label', torch.tensor(target_fake_label))
9 + self.loss = nn.MSELoss()
10 +
11 + def get_target_tensor(self, input, target_is_real):
12 + if target_is_real:
13 + target_tensor = self.real_label
14 + else:
15 + target_tensor = self.fake_label
16 + return target_tensor.expand_as(input)
17 +
18 + def __call__(self, input, target_is_real):
19 + target_tensor = self.get_target_tensor(input, target_is_real).to(input.device)
20 + return self.loss(input, target_tensor)
...\ No newline at end of file ...\ No newline at end of file
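21 +
22 +# Note: unlike the BCE-based GANLoss prototyped in the notebook, this version uses nn.MSELoss
23 +# (an LSGAN-style objective), so it can be applied to discriminator outputs without a final sigmoid.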
1 +import torch
2 +import torch.nn as nn
3 +
4 +
5 +class DiceLoss(nn.Module):
6 + """
7 + Dice loss of binary class
8 + Args:
9 + smooth: A float number to smooth loss, and avoid NaN error, default: 1
10 + p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
11 + predict: A tensor of shape [N, *]
12 + target: A tensor of shape same with predict
13 + reduction: Reduction method to apply, return mean over batch if 'mean',
14 + return sum if 'sum', return a tensor of shape [N,] if 'none'
15 + Returns:
16 + Loss tensor according to arg reduction
17 + Raise:
18 + Exception if unexpected reduction
19 + """
20 + def __init__(self, smooth=1, p=2, reduction='mean'):
21 + super(DiceLoss, self).__init__()
22 + self.smooth = smooth
23 + self.p = p
24 + self.reduction = reduction
25 +
26 + def forward(self, predict, target):
27 + assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
28 + predict = predict.contiguous().view(predict.shape[0], -1)
29 + target = target.contiguous().view(target.shape[0], -1)
30 +
31 + num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
32 + den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
33 +
34 + loss = 1 - num / den
35 +
36 + if self.reduction == 'mean':
37 + return loss.mean()
38 + elif self.reduction == 'sum':
39 + return loss.sum()
40 + elif self.reduction == 'none':
41 + return loss
42 + else:
43 + raise Exception('Unexpected reduction {}'.format(self.reduction))
1 +#Source: https://github.com/Po-Hsun-Su/pytorch-ssim.git
2 +
3 +import torch
4 +import torch.nn.functional as F
5 +from torch.autograd import Variable
6 +import numpy as np
7 +from math import exp
8 +
9 +def gaussian(window_size, sigma):
10 + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11 + return gauss/gauss.sum()
12 +
13 +def create_window(window_size, channel):
14 + _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15 + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16 + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
17 + return window
18 +
19 +def _ssim(img1, img2, window, window_size, channel, size_average = True):
20 + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
21 + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
22 +
23 + mu1_sq = mu1.pow(2)
24 + mu2_sq = mu2.pow(2)
25 + mu1_mu2 = mu1*mu2
26 +
27 + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
28 + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
29 + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
30 +
31 + C1 = 0.01**2
32 + C2 = 0.03**2
33 +
34 + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
35 +
36 + if size_average:
37 + return ssim_map.mean()
38 + else:
39 + return ssim_map.mean(1).mean(1).mean(1)
40 +
41 +class SSIM(torch.nn.Module):
42 + def __init__(self, window_size = 11, size_average = True):
43 + super(SSIM, self).__init__()
44 + self.window_size = window_size
45 + self.size_average = size_average
46 + self.channel = 1
47 + self.window = create_window(window_size, self.channel)
48 +
49 + def forward(self, img1, img2):
50 + (_, channel, _, _) = img1.size()
51 +
52 + if channel == self.channel and self.window.data.type() == img1.data.type():
53 + window = self.window
54 + else:
55 + window = create_window(self.window_size, channel)
56 +
57 + if img1.is_cuda:
58 + window = window.cuda(img1.get_device())
59 + window = window.type_as(img1)
60 +
61 + self.window = window
62 + self.channel = channel
63 +
64 +
65 + return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
66 +
67 +def ssim(img1, img2, window_size = 11, size_average = True):
68 + (_, channel, _, _) = img1.size()
69 + window = create_window(window_size, channel)
70 +
71 + if img1.is_cuda:
72 + window = window.cuda(img1.get_device())
73 + window = window.type_as(img1)
74 +
75 + return _ssim(img1, img2, window, window_size, channel, size_average)
...\ No newline at end of file ...\ No newline at end of file
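76 +
77 +# Usage sketch (assumption): SSIM is a similarity score (higher is better) over image batches:
78 +#   criterion = SSIM(window_size=11)
79 +#   score = criterion(pred_imgs, target_imgs)   # both of shape [B, C, H, W]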
1 +from .dicecoeff import DiceScore
2 +from .pixelacc import PixelAccuracy
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +import torch.nn as nn
3 +import numpy as np
4 +
5 +class DiceScore():
6 + def __init__(self, num_classes, ignore_index = None, eps=1e-6, thresh = 0.5):
7 + self.thresh = thresh
8 + self.num_classes = num_classes
9 + self.pred_type = "multi" if num_classes > 1 else "binary"
10 +
11 + if num_classes == 1:
12 + self.num_classes+=1
13 +
14 + self.ignore_index = ignore_index
15 + self.eps = eps
16 +
17 + self.scores_list = np.zeros(self.num_classes)
18 + self.reset()
19 +
20 + def compute(self, outputs, targets):
21 + # outputs: (batch, num_classes, W, H)
22 + # targets: (batch, num_classes, W, H)
23 +
24 + batch_size, _ , w, h = outputs.shape
25 + if len(targets.shape) == 3:
26 + targets = targets.unsqueeze(1)
27 +
28 + one_hot_targets = torch.zeros(batch_size, self.num_classes, h, w)
29 + one_hot_predicts = torch.zeros(batch_size, self.num_classes, h, w)
30 +
31 + if self.pred_type == 'binary':
32 + predicts = (outputs > self.thresh).float()
33 + elif self.pred_type =='multi':
34 + predicts = torch.argmax(outputs, dim=1).unsqueeze(1)
35 +
36 + one_hot_targets.scatter_(1, targets.long(), 1)
37 + one_hot_predicts.scatter_(1, predicts.long(), 1)
38 +
39 + for cl in range(self.num_classes):
40 + cl_pred = one_hot_predicts[:,cl,:,:]
41 + cl_target = one_hot_targets[:,cl,:,:]
42 + score = self.binary_compute(cl_pred, cl_target)
43 + self.scores_list[cl] += sum(score)
44 +
45 +
46 + def binary_compute(self, predict, target):
47 + # outputs: (batch, 1, W, H)
48 + # targets: (batch, 1, W, H)
49 +
50 + intersect = (predict * target).sum((-2,-1))
51 + union = (predict + target).sum((-2,-1))
52 + return (2. * intersect + self.eps) / (union +self.eps)
53 +
54 + def reset(self):
55 + self.scores_list = np.zeros(self.num_classes)
56 + self.sample_size = 0
57 +
58 + def update(self, outputs, targets):
59 + self.sample_size += outputs.shape[0]
60 + self.compute(outputs, targets)
61 +
62 + def value(self):
63 + scores_each_class = self.scores_list / self.sample_size #mean over number of samples
64 + if self.pred_type == 'binary':
65 + scores = scores_each_class[1] # ignore background which is label 0
66 + else:
67 + scores = sum(scores_each_class) / self.num_classes
68 + return np.round(scores, decimals=4)
69 +
70 + def summary(self):
71 + class_iou = self.scores_list / self.sample_size #mean
72 +
73 + print(f'{self.value()}')
74 + for i, x in enumerate(class_iou):
75 + print(f'\tClass {i}: {x:.4f}')
76 +
77 + def __str__(self):
78 + return f'Dice Score: {self.value()}'
79 +
80 + def __len__(self):
81 +        return self.sample_size
82 +
83 +
...\ No newline at end of file ...\ No newline at end of file
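84 +
85 +# Usage sketch (assumption): accumulate over validation batches, then read the mean score:
86 +#   metric = DiceScore(num_classes=1, thresh=0.5)
87 +#   metric.update(outputs, masks)   # repeat per batch
88 +#   print(metric.value())           # mean Dice over all seen samples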
1 +import torch
2 +import torch.nn as nn
3 +import numpy as np
4 +
5 +class PixelAccuracy():
6 + def __init__(self, num_classes, ignore_index=None, eps=1e-6, thresh = 0.5):
7 + self.thresh = thresh
8 + self.num_classes = num_classes
9 + self.pred_type = "multi" if num_classes > 1 else "binary"
10 +
11 + if num_classes == 1:
12 + self.num_classes+=1
13 +
14 + self.ignore_index = ignore_index
15 + self.eps = eps
16 +
17 + self.scores_list = np.zeros(self.num_classes)
18 + self.reset()
19 +
20 + def compute(self, outputs, targets):
21 + # outputs: (batch, num_classes, W, H)
22 + # targets: (batch, num_classes, W, H)
23 +
24 + batch_size, _ , w, h = outputs.shape
25 + if len(targets.shape) == 3:
26 + targets = targets.unsqueeze(1)
27 +
28 + one_hot_targets = torch.zeros(batch_size, self.num_classes, h, w)
29 + one_hot_predicts = torch.zeros(batch_size, self.num_classes, h, w)
30 +
31 + if self.pred_type == 'binary':
32 + predicts = (outputs > self.thresh).float()
33 + elif self.pred_type =='multi':
34 + predicts = torch.argmax(outputs, dim=1).unsqueeze(1)
35 +
36 + one_hot_targets.scatter_(1, targets.long(), 1)
37 + one_hot_predicts.scatter_(1, predicts.long(), 1)
38 +
39 + for cl in range(self.num_classes):
40 + cl_pred = one_hot_predicts[:,cl,:,:]
41 + cl_target = one_hot_targets[:,cl,:,:]
42 + score = self.binary_compute(cl_pred, cl_target)
43 + self.scores_list[cl] += sum(score)
44 +
45 + def binary_compute(self, predict, target):
46 + # predict: (batch, 1, W, H)
47 + # targets: (batch, 1, W, H)
48 +
49 + correct = (predict == target).sum((-2,-1))
50 + total = target.shape[-1] * target.shape[-2]
51 + return (correct + self.eps) *1.0 / (total +self.eps)
52 +
53 + def reset(self):
54 + self.scores_list = np.zeros(self.num_classes)
55 + self.sample_size = 0
56 +
57 + def update(self, outputs, targets):
58 + self.sample_size += outputs.shape[0]
59 + self.compute(outputs, targets)
60 +
61 + def value(self):
62 + scores_each_class = self.scores_list / self.sample_size #mean over number of samples
63 + if self.pred_type == 'binary':
64 + scores = scores_each_class[1] # ignore background which is label 0
65 + else:
66 + scores = sum(scores_each_class) / self.num_classes
67 + return np.round(scores, decimals=4)
68 +
69 + def summary(self):
70 + class_iou = self.scores_list / self.sample_size #mean
71 +
72 + print(f'{self.value()}')
73 + for i, x in enumerate(class_iou):
74 + print(f'\tClass {i}: {x:.4f}')
75 +
76 + def __str__(self):
77 + return f'Pixel Accuracy: {self.value()}'
78 +
79 + def __len__(self):
80 +        return self.sample_size
...\ No newline at end of file ...\ No newline at end of file
1 +from .networks import GatedGenerator, NLayerDiscriminator, PerceptualNet
2 +from .unet import UNetSemantic
1 +import torch
2 +import torch.nn as nn
3 +from torch.nn import Parameter
4 +import torch.nn.functional as F
5 +import torch.utils.data as data
6 +import functools
7 +from torchvision.models import vgg19, vgg16
8 +
9 +class GatedConv2d(nn.Module):
10 + def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, activation = 'lrelu', norm = 'in'):
11 + super(GatedConv2d, self).__init__()
12 + self.pad = nn.ZeroPad2d(padding)
13 + if norm is not None:
14 + self.norm = nn.InstanceNorm2d(out_channels)
15 + else:
16 + self.norm = None
17 +
18 + if activation == 'tanh':
19 + self.activation = nn.Tanh()
20 + else:
21 + self.activation = nn.LeakyReLU(0.2, inplace = True)
22 +
23 +
24 + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
25 + self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
26 + self.sigmoid = torch.nn.Sigmoid()
27 +
28 + def forward(self, x):
29 + x = self.pad(x)
30 + conv = self.conv2d(x)
31 + mask = self.mask_conv2d(x)
32 + gated_mask = self.sigmoid(mask)
33 + x = conv * gated_mask
34 + if self.norm:
35 + x = self.norm(x)
36 + if self.activation:
37 + x = self.activation(x)
38 + return x
39 +
40 +class TransposeGatedConv2d(nn.Module):
41 + def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, dilation = 1, norm=None, scale_factor = 2):
42 + super(TransposeGatedConv2d, self).__init__()
43 + # Initialize the conv scheme
44 + self.scale_factor = scale_factor
45 + self.gated_conv2d = GatedConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, norm=norm)
46 +
47 + def forward(self, x):
48 + x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')
49 + x = self.gated_conv2d(x)
50 + return x
51 +
52 +
53 +class GatedGenerator(nn.Module):
54 + def __init__(self, in_channels=4, latent_channels=64, out_channels=3):
55 + super(GatedGenerator, self).__init__()
56 + self.coarse = nn.Sequential(
57 + # encoder
58 + GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),
59 + GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),
60 + GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),
61 + GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),
62 + # Bottleneck
63 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
64 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
65 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),
66 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),
67 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),
68 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),
69 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
70 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
71 + # decoder
72 + TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),
73 + GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),
74 + TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),
75 + GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)
76 + )
77 + self.refinement = nn.Sequential(
78 + # encoder
79 + GatedConv2d(in_channels, latent_channels, 7, 1, 3, norm = None),
80 + GatedConv2d(latent_channels, latent_channels * 2, 4, 2, 1),
81 + GatedConv2d(latent_channels * 2, latent_channels * 4, 3, 1, 1),
82 + GatedConv2d(latent_channels * 4, latent_channels * 4, 4, 2, 1),
83 + # Bottleneck
84 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
85 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
86 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 2, dilation = 2),
87 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 4, dilation = 4),
88 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 8, dilation = 8),
89 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 16, dilation = 16),
90 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
91 + GatedConv2d(latent_channels * 4, latent_channels * 4, 3, 1, 1),
92 + # decoder
93 + TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1),
94 + GatedConv2d(latent_channels * 2, latent_channels * 2, 3, 1, 1),
95 + TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1),
96 + GatedConv2d(latent_channels, out_channels, 7, 1, 3, activation = 'tanh', norm = None)
97 + )
98 +
99 + def forward(self, img, mask):
100 + # img: entire img
101 + # mask: 1 for mask region; 0 for unmask region
102 + # 1 - mask: unmask
103 + # img * (1 - mask): ground truth unmask region
104 + # Coarse
105 +
106 + first_masked_img = img * (1 - mask) + mask
107 + first_in = torch.cat((first_masked_img, mask), 1) # in: [B, 4, H, W]
108 + first_out = self.coarse(first_in) # out: [B, 3, H, W]
109 + # Refinement
110 + second_masked_img = img * (1 - mask) + first_out * mask
111 + second_in = torch.cat((second_masked_img, mask), 1) # in: [B, 4, H, W]
112 + second_out = self.refinement(second_in) # out: [B, 3, H, W]
113 + return first_out, second_out
114 +
115 +
116 +class NLayerDiscriminator(nn.Module):
117 + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
118 + super(NLayerDiscriminator, self).__init__()
119 + if type(norm_layer) == functools.partial:
120 + use_bias = norm_layer.func == nn.InstanceNorm2d
121 + else:
122 + use_bias = norm_layer == nn.InstanceNorm2d
123 +
124 + kw = 4
125 + padw = 1
126 + sequence = [
127 + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
128 + nn.LeakyReLU(0.2, True)
129 + ]
130 +
131 + nf_mult = 1
132 + nf_mult_prev = 1
133 + for n in range(1, n_layers):
134 + nf_mult_prev = nf_mult
135 + nf_mult = min(2**n, 8)
136 + sequence += [
137 + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
138 + kernel_size=kw, stride=2, padding=padw, bias=use_bias),
139 + norm_layer(ndf * nf_mult),
140 + nn.LeakyReLU(0.2, True)
141 + ]
142 +
143 + nf_mult_prev = nf_mult
144 + nf_mult = min(2**n_layers, 8)
145 + sequence += [
146 + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
147 + kernel_size=kw, stride=1, padding=padw, bias=use_bias),
148 + norm_layer(ndf * nf_mult),
149 + nn.LeakyReLU(0.2, True)
150 + ]
151 +
152 + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
153 +
154 + if use_sigmoid:
155 + sequence += [nn.Sigmoid()]
156 +
157 + self.model = nn.Sequential(*sequence)
158 +
159 + def forward(self, input):
160 + return self.model(input)
161 +
162 +class PerceptualNet(nn.Module):
163 + # https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49
164 + def __init__(self, name = "vgg19", resize=True):
165 + super(PerceptualNet, self).__init__()
166 + blocks = []
167 + if name == "vgg19":
168 + blocks.append(vgg19(pretrained=True).features[:4].eval())
169 + blocks.append(vgg19(pretrained=True).features[4:9].eval())
170 + blocks.append(vgg19(pretrained=True).features[9:16].eval())
171 + blocks.append(vgg19(pretrained=True).features[16:23].eval())
172 + elif name == "vgg16":
173 + blocks.append(vgg16(pretrained=True).features[:4].eval())
174 + blocks.append(vgg16(pretrained=True).features[4:9].eval())
175 + blocks.append(vgg16(pretrained=True).features[9:16].eval())
176 + blocks.append(vgg16(pretrained=True).features[16:23].eval())
177 + else:
178 +            raise ValueError("wrong model name")
179 +
180 +        for bl in blocks:
181 +            for p in bl.parameters():
182 +                p.requires_grad = False
183 + self.blocks = torch.nn.ModuleList(blocks)
184 + self.transform = torch.nn.functional.interpolate
185 + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
186 + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
187 + self.resize = resize
188 +
189 + def forward(self, inputs, targets):
190 + if inputs.shape[1] != 3:
191 + inputs = inputs.repeat(1, 3, 1, 1)
192 + targets = targets.repeat(1, 3, 1, 1)
193 + inputs = (inputs-self.mean) / self.std
194 + targets = (targets-self.mean) / self.std
195 + if self.resize:
196 + inputs = self.transform(inputs, mode='bilinear', size=(512, 512), align_corners=False)
197 + targets = self.transform(targets, mode='bilinear', size=(512, 512), align_corners=False)
198 + loss = 0.0
199 + x = inputs
200 + y = targets
201 + for block in self.blocks:
202 + x = block(x)
203 + y = block(y)
204 + loss += torch.nn.functional.l1_loss(x, y)
205 + return loss
206 +
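# Usage sketch (dummy tensors, illustrative only): the perceptual loss is a scalar
# L1 distance accumulated over the selected VGG feature blocks; inputs are expected
# in [0, 1] and are normalized with the ImageNet mean/std above.
#
#     percep = PerceptualNet(name="vgg16", resize=False)
#     pred, target = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
#     loss_p = percep(pred, target)   # scalar tensor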
207 +
208 +
1 +import torch
2 +import torch.nn as nn
3 +from torch.nn import Parameter
4 +import torch.nn.functional as F
5 +import torch.utils.data as data
6 +import functools
7 +
8 +
9 +class conv_block(nn.Module):
10 + """
11 + Convolution Block
12 + """
13 + def __init__(self, in_ch, out_ch):
14 + super(conv_block, self).__init__()
15 +
16 + self.conv = nn.Sequential(
17 + nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
18 + nn.BatchNorm2d(out_ch),
19 + nn.ReLU(inplace=True),
20 + nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
21 + nn.BatchNorm2d(out_ch),
22 + nn.ReLU(inplace=True))
23 +
24 + def forward(self, x):
25 +
26 + x = self.conv(x)
27 + return x
28 +
29 +
30 +class up_conv(nn.Module):
31 + """
32 + Up Convolution Block
33 + """
34 + def __init__(self, in_ch, out_ch):
35 + super(up_conv, self).__init__()
36 + self.up = nn.Sequential(
37 + nn.Upsample(scale_factor=2),
38 + nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
39 + nn.BatchNorm2d(out_ch),
40 + nn.ReLU(inplace=True)
41 + )
42 +
43 + def forward(self, x):
44 + x = self.up(x)
45 + return x
46 +
47 +
48 +
49 +class Recurrent_block(nn.Module):
50 + """
51 + Recurrent Block for R2Unet_CNN
52 + """
53 + def __init__(self, out_ch, t=2):
54 + super(Recurrent_block, self).__init__()
55 +
56 + self.t = t
57 + self.out_ch = out_ch
58 + self.conv = nn.Sequential(
59 + nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
60 + nn.BatchNorm2d(out_ch),
61 + nn.ReLU(inplace=True)
62 + )
63 +
64 +    def forward(self, x):
65 +        for i in range(self.t):
66 +            if i == 0:
67 +                out = self.conv(x)
68 +            out = self.conv(x + out)  # recurrence: input plus the previous activation
69 +        return out
70 +
71 +
72 +class RRCNN_block(nn.Module):
73 + """
74 + Recurrent Residual Convolutional Neural Network Block
75 + """
76 + def __init__(self, in_ch, out_ch, t=2):
77 + super(RRCNN_block, self).__init__()
78 +
79 + self.RCNN = nn.Sequential(
80 + Recurrent_block(out_ch, t=t),
81 + Recurrent_block(out_ch, t=t)
82 + )
83 + self.Conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
84 +
85 + def forward(self, x):
86 + x1 = self.Conv(x)
87 + x2 = self.RCNN(x1)
88 + out = x1 + x2
89 + return out
90 +
91 +class Attention_block(nn.Module):
92 + """
93 + Attention Block
94 + """
95 +
96 + def __init__(self, F_g, F_l, F_int):
97 + super(Attention_block, self).__init__()
98 +
99 + self.W_g = nn.Sequential(
100 + nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
101 + nn.BatchNorm2d(F_int)
102 + )
103 +
104 + self.W_x = nn.Sequential(
105 + nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
106 + nn.BatchNorm2d(F_int)
107 + )
108 +
109 + self.psi = nn.Sequential(
110 + nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
111 + nn.BatchNorm2d(1),
112 + nn.Sigmoid()
113 + )
114 +
115 + self.relu = nn.ReLU(inplace=True)
116 +
117 + def forward(self, g, x):
118 + g1 = self.W_g(g)
119 + x1 = self.W_x(x)
120 + psi = self.relu(g1 + x1)
121 + psi = self.psi(psi)
122 + out = x * psi
123 + return out
124 +
125 +class SE_Block(nn.Module):
126 + "credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4"
127 + def __init__(self, c, r=16):
128 + super().__init__()
129 + self.squeeze = nn.AdaptiveAvgPool2d(1)
130 + self.excitation = nn.Sequential(
131 + nn.Linear(c, c // r, bias=False),
132 + nn.ReLU(inplace=True),
133 + nn.Linear(c // r, c, bias=False),
134 + nn.Sigmoid()
135 + )
136 +
137 + def forward(self, x):
138 + bs, c, _, _ = x.shape
139 + y = self.squeeze(x).view(bs, c)
140 + y = self.excitation(y).view(bs, c, 1, 1)
141 + return x * y.expand_as(x)
142 +
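# Usage sketch (illustrative): squeeze-and-excitation re-weights channels, so the
# output shape equals the input shape.
#
#     se = SE_Block(c=64)
#     se(torch.randn(2, 64, 32, 32)).shape   # -> torch.Size([2, 64, 32, 32])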
143 +class AtrousConv(nn.Module):
144 + def __init__(self, in_ch):
145 + super().__init__()
146 + self.atrous_conv = nn.Sequential(
147 + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=2, padding=2),
148 + nn.BatchNorm2d(in_ch),
149 + nn.ReLU(),
150 +
151 + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=4, padding=4),
152 + nn.BatchNorm2d(in_ch),
153 + nn.ReLU(),
154 +
155 + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=8, padding=8),
156 + nn.BatchNorm2d(in_ch),
157 + nn.ReLU(),
158 +
159 + nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, dilation=16, padding=16),
160 + nn.BatchNorm2d(in_ch),
161 + nn.ReLU(),
162 + )
163 +
164 + def forward(self, x):
165 + return self.atrous_conv(x)
166 +
167 +
168 +class UNetSemantic(nn.Module):
169 + """
170 + UNet - Basic Implementation
171 + Paper : https://arxiv.org/abs/1505.04597
172 + """
173 + def __init__(self, in_ch=3, out_ch=1):
174 + super(UNetSemantic, self).__init__()
175 +
176 + n1 = 32
177 + filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
178 +
179 + self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
180 + self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
181 + self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
182 + self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
183 +
184 + self.Conv1 = conv_block(in_ch, filters[0])
185 + self.Conv2 = conv_block(filters[0], filters[1])
186 + self.Conv3 = conv_block(filters[1], filters[2])
187 + self.Conv4 = conv_block(filters[2], filters[3])
188 + self.Conv5 = conv_block(filters[3], filters[4])
189 +
190 + self.Up5 = up_conv(filters[4], filters[3])
191 + self.Up_conv5 = conv_block(filters[4], filters[3])
192 +
193 + self.Up4 = up_conv(filters[3], filters[2])
194 + self.Up_conv4 = conv_block(filters[3], filters[2])
195 +
196 + self.Up3 = up_conv(filters[2], filters[1])
197 + self.Up_conv3 = conv_block(filters[2], filters[1])
198 +
199 + self.Up2 = up_conv(filters[1], filters[0])
200 + self.Up_conv2 = conv_block(filters[1], filters[0])
201 +
202 + self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
203 + self.se1 = SE_Block(filters[0])
204 + self.se2 = SE_Block(filters[1])
205 + self.se3 = SE_Block(filters[2])
206 + self.active = torch.nn.Sigmoid()
207 +
208 + def forward(self, x):
209 +
210 + e1 = self.Conv1(x)
211 + e1 = self.se1(e1)
212 +
213 + e2 = self.Maxpool1(e1)
214 + e2 = self.Conv2(e2)
215 + e2 = self.se2(e2)
216 +
217 + e3 = self.Maxpool2(e2)
218 + e3 = self.Conv3(e3)
219 + e3 = self.se3(e3)
220 +
221 + e4 = self.Maxpool3(e3)
222 + e4 = self.Conv4(e4)
223 +
224 + e5 = self.Maxpool4(e4)
225 + e5 = self.Conv5(e5)
226 +
227 + d5 = self.Up5(e5)
228 + d5 = torch.cat((e4, d5), dim=1)
229 +
230 + d5 = self.Up_conv5(d5)
231 +
232 + d4 = self.Up4(d5)
233 + d4 = torch.cat((e3, d4), dim=1)
234 + d4 = self.Up_conv4(d4)
235 +
236 + d3 = self.Up3(d4)
237 + d3 = torch.cat((e2, d3), dim=1)
238 + d3 = self.Up_conv3(d3)
239 +
240 + d2 = self.Up2(d3)
241 + d2 = torch.cat((e1, d2), dim=1)
242 + d2 = self.Up_conv2(d2)
243 +
244 + out = self.Conv(d2)
245 +
246 + out = self.active(out)
247 +
248 + return out
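# Usage sketch (illustrative): UNetSemantic maps an RGB image to a single-channel
# mask in [0, 1]; height and width must be divisible by 16 because of the four
# max-pooling stages.
#
#     net = UNetSemantic(in_ch=3, out_ch=1)
#     net(torch.randn(1, 3, 256, 256)).shape   # -> torch.Size([1, 1, 256, 256])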
...\ No newline at end of file ...\ No newline at end of file
1 +import argparse
2 +from configs import Config
3 +from trainer import Trainer
4 +from unet_trainer import UNetTrainer
5 +
6 +
7 +def main(args, cfg):
8 + if args.config == "segm":
9 + trainer = UNetTrainer(args, cfg)
10 + else:
11 + trainer = Trainer(args, cfg)
12 + trainer.fit()
13 +
14 +
15 +if __name__ == "__main__":
16 + parser = argparse.ArgumentParser(description="Training custom model")
17 +    parser.add_argument("--resume", default=None, type=str, help="path to a checkpoint to resume training from")
18 +    parser.add_argument("config", default="config", type=str, help="name of the YAML config under ./configs (without the .yaml extension)")
19 + args = parser.parse_args()
20 +
21 + config = Config(f"./configs/{args.config}.yaml")
22 + main(args, config)
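# Example invocations (config names other than "segm" are assumptions; any YAML file
# under ./configs/ can be passed by its basename, and --resume takes a checkpoint path):
#
#     python train.py segm                                    # train the UNet segmentation model
#     python train.py facemask                                # any other name runs the inpainting Trainer
#     python train.py segm --resume checkpoints/model_segm_3_5000.pth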
1 +import os
2 +import cv2
3 +import time
4 +import numpy as np
5 +from PIL import Image
6 +import matplotlib.pyplot as plt
7 +
8 +import torch
9 +import torch.nn as nn
10 +import torch.utils.data as data
11 +from torch.optim.lr_scheduler import StepLR
12 +from torchvision.utils import save_image
13 +
14 +
15 +from models import *
16 +from losses import *
17 +from datasets import Places365Dataset, FacemaskDataset
18 +
19 +
20 +def adjust_learning_rate(optimizer, gamma, num_steps=1):
21 + for i in range(num_steps):
22 + for param_group in optimizer.param_groups:
23 + param_group['lr'] *= gamma
24 +
25 +def get_epoch_iters(path):
26 + path = os.path.basename(path)
27 + tokens = path[:-4].split('_')
28 + try:
29 + if tokens[-1] == 'interrupted':
30 + epoch_idx = int(tokens[-3])
31 + iter_idx = int(tokens[-2])
32 + else:
33 + epoch_idx = int(tokens[-2])
34 + iter_idx = int(tokens[-1])
35 +    except (IndexError, ValueError):
36 + return 0, 0
37 +
38 + return epoch_idx, iter_idx
39 +
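# Parsing sketch (filenames are illustrative): checkpoints saved by this trainer are
# named "model_<epoch>_<iters>.pth", so
#     get_epoch_iters("model_4_12000.pth")              -> (4, 12000)
#     get_epoch_iters("model_4_12000_interrupted.pth")  -> (4, 12000)
# and any filename that does not fit this pattern falls back to (0, 0).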
40 +def load_checkpoint(model_G, model_D, path):
41 + state = torch.load(path,map_location='cpu')
42 + model_G.load_state_dict(state['G'])
43 + model_D.load_state_dict(state['D'])
44 + print('Loaded checkpoint successfully')
45 +
46 +class Trainer():
47 + def __init__(self, args, cfg):
48 +
49 + if args.resume is not None:
50 + epoch, iters = get_epoch_iters(args.resume)
51 + else:
52 + epoch = 0
53 + iters = 0
54 +
55 + if not os.path.exists(cfg.checkpoint_path):
56 + os.makedirs(cfg.checkpoint_path)
57 + if not os.path.exists(cfg.sample_folder):
58 + os.makedirs(cfg.sample_folder)
59 +
60 + self.cfg = cfg
61 + self.step_iters = cfg.step_iters
62 + self.gamma = cfg.gamma
63 + self.visualize_per_iter = cfg.visualize_per_iter
64 + self.print_per_iter = cfg.print_per_iter
65 + self.save_per_iter = cfg.save_per_iter
66 +
67 + self.start_iter = iters
68 + self.iters = 0
69 + self.num_epochs = cfg.num_epochs
70 + self.device = torch.device('cuda' if cfg.cuda else 'cpu')
71 +
72 + trainset = FacemaskDataset(cfg) # Places365Dataset(cfg) #
73 +
74 + self.trainloader = data.DataLoader(
75 + trainset,
76 + batch_size=cfg.batch_size,
77 + num_workers = cfg.num_workers,
78 + pin_memory = True,
79 + shuffle=True,
80 + collate_fn = trainset.collate_fn)
81 +
82 + self.epoch = int(self.start_iter / len(self.trainloader))
83 + self.iters = self.start_iter
84 + self.num_iters = (self.num_epochs+1) * len(self.trainloader)
85 +
86 + self.model_G = GatedGenerator().to(self.device)
87 +        self.model_D = NLayerDiscriminator(3, n_layers=cfg.d_num_layers, use_sigmoid=False).to(self.device)  # 3-channel RGB input; cfg value sets the layer count
88 + self.model_P = PerceptualNet(name = "vgg16", resize=False).to(self.device)
89 +
90 + if args.resume is not None:
91 + load_checkpoint(self.model_G, self.model_D, args.resume)
92 +
93 + self.criterion_adv = GANLoss(target_real_label=0.9, target_fake_label=0.1)
94 + self.criterion_rec = nn.SmoothL1Loss()
95 + self.criterion_ssim = SSIM(window_size = 11)
96 + self.criterion_per = nn.SmoothL1Loss()
97 +
98 + self.optimizer_D = torch.optim.Adam(self.model_D.parameters(), lr=cfg.lr)
99 + self.optimizer_G = torch.optim.Adam(self.model_G.parameters(), lr=cfg.lr)
100 +
101 + def validate(self, sample_folder, sample_name, img_list):
102 + save_img_path = os.path.join(sample_folder, sample_name+'.png')
103 + img_list = [i.clone().cpu() for i in img_list]
104 + imgs = torch.stack(img_list, dim=1)
105 +
106 +        # imgs shape: B x 5 x C x H x W
107 +
108 + imgs = imgs.view(-1, *list(imgs.size())[2:])
109 + save_image(imgs, save_img_path, nrow= 5)
110 +        print(f"Saved image to {save_img_path}")
111 +
112 + def fit(self):
113 + self.model_G.train()
114 + self.model_D.train()
115 +
116 + running_loss = {
117 + 'D': 0,
118 + 'G': 0,
119 + 'P': 0,
120 + 'R_1': 0,
121 + 'R_2': 0,
122 + 'T': 0,
123 + }
124 +
125 + running_time = 0
126 + step = 0
127 + try:
128 + for epoch in range(self.epoch, self.num_epochs):
129 + self.epoch = epoch
130 + for i, batch in enumerate(self.trainloader):
131 + start_time = time.time()
132 + imgs = batch['imgs'].to(self.device)
133 + masks = batch['masks'].to(self.device)
134 +
135 + # Train discriminator
136 + self.optimizer_D.zero_grad()
137 + self.optimizer_G.zero_grad()
138 +
139 + first_out, second_out = self.model_G(imgs, masks)
140 +
141 + first_out_wholeimg = imgs * (1 - masks) + first_out * masks
142 + second_out_wholeimg = imgs * (1 - masks) + second_out * masks
143 +
144 + masks = masks.cpu()
145 +
146 + fake_D = self.model_D(second_out_wholeimg.detach())
147 + real_D = self.model_D(imgs)
148 +
149 + loss_fake_D = self.criterion_adv(fake_D, target_is_real=False)
150 + loss_real_D = self.criterion_adv(real_D, target_is_real=True)
151 +
152 + loss_D = (loss_fake_D + loss_real_D) * 0.5
153 +
154 + loss_D.backward()
155 + self.optimizer_D.step()
156 +
157 + real_D = None
158 +
159 + # Train Generator
160 + self.optimizer_D.zero_grad()
161 + self.optimizer_G.zero_grad()
162 +
163 + fake_D = self.model_D(second_out_wholeimg)
164 + loss_G = self.criterion_adv(fake_D, target_is_real=True)
165 +
166 + fake_D = None
167 +
168 + # Reconstruction loss
169 + loss_l1_1 = self.criterion_rec(first_out_wholeimg, imgs)
170 + loss_l1_2 = self.criterion_rec(second_out_wholeimg, imgs)
171 + loss_ssim_1 = self.criterion_ssim(first_out_wholeimg, imgs)
172 + loss_ssim_2 = self.criterion_ssim(second_out_wholeimg, imgs)
173 +
174 + loss_rec_1 = 0.5 * loss_l1_1 + 0.5 * (1 - loss_ssim_1)
175 + loss_rec_2 = 0.5 * loss_l1_2 + 0.5 * (1 - loss_ssim_2)
176 +
177 + # Perceptual loss
178 + loss_P = self.model_P(second_out_wholeimg, imgs)
179 +
180 + loss = self.cfg.lambda_G * loss_G + self.cfg.lambda_rec_1 * loss_rec_1 + self.cfg.lambda_rec_2 * loss_rec_2 + self.cfg.lambda_per * loss_P
181 + loss.backward()
182 + self.optimizer_G.step()
183 +
184 + end_time = time.time()
185 +
186 + imgs = imgs.cpu()
187 +                    # Accumulate losses and timing for logging
188 + running_time += (end_time - start_time)
189 + running_loss['D'] += loss_D.item()
190 + running_loss['G'] += (self.cfg.lambda_G * loss_G.item())
191 + running_loss['P'] += (self.cfg.lambda_per * loss_P.item())
192 + running_loss['R_1'] += (self.cfg.lambda_rec_1 * loss_rec_1.item())
193 + running_loss['R_2'] += (self.cfg.lambda_rec_2 * loss_rec_2.item())
194 + running_loss['T'] += loss.item()
195 +
196 +
197 + if self.iters % self.print_per_iter == 0:
198 + for key in running_loss.keys():
199 + running_loss[key] /= self.print_per_iter
200 + running_loss[key] = np.round(running_loss[key], 5)
201 + loss_string = '{}'.format(running_loss)[1:-1].replace("'",'').replace(",",' ||')
202 + print("[{}|{}] [{}|{}] || {} || Time: {:10.4f}s".format(self.epoch, self.num_epochs, self.iters, self.num_iters, loss_string, running_time))
203 +
204 + running_loss = {
205 + 'D': 0,
206 + 'G': 0,
207 + 'P': 0,
208 + 'R_1': 0,
209 + 'R_2': 0,
210 + 'T': 0,
211 + }
212 + running_time = 0
213 +
214 + if self.iters % self.save_per_iter == 0:
215 + torch.save({
216 + 'D': self.model_D.state_dict(),
217 + 'G': self.model_G.state_dict(),
218 + }, os.path.join(self.cfg.checkpoint_path, f"model_{self.epoch}_{self.iters}.pth"))
219 +
220 + # Step learning rate
221 +                    if step < len(self.step_iters) and self.iters == self.step_iters[step]:
222 + adjust_learning_rate(self.optimizer_D, self.gamma)
223 + adjust_learning_rate(self.optimizer_G, self.gamma)
224 + step+=1
225 +
226 + # Visualize sample
227 + if self.iters % self.visualize_per_iter == 0:
228 + masked_imgs = imgs * (1 - masks) + masks
229 +
230 + img_list = [imgs, masked_imgs, first_out, second_out, second_out_wholeimg]
231 + #name_list = ['gt', 'mask', 'masked_img', 'first_out', 'second_out']
232 + filename = f"{self.epoch}_{str(self.iters)}"
233 + self.validate(self.cfg.sample_folder, filename , img_list)
234 +
235 + self.iters += 1
236 +
237 + except KeyboardInterrupt:
238 + torch.save({
239 + 'D': self.model_D.state_dict(),
240 + 'G': self.model_G.state_dict(),
241 + }, os.path.join(self.cfg.checkpoint_path, f"model_{self.epoch}_{self.iters}.pth"))
242 +
243 +
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import cv2
3 +import time
4 +import numpy as np
5 +from PIL import Image
6 +import matplotlib.pyplot as plt
7 +from tqdm import tqdm
8 +
9 +import torch
10 +import torch.nn as nn
11 +import torch.utils.data as data
12 +from torch.optim.lr_scheduler import StepLR
13 +from torchvision.utils import save_image
14 +
15 +
16 +from models import UNetSemantic
17 +from losses import DiceLoss
18 +from datasets import FacemaskSegDataset
19 +from metrics import *
20 +
21 +
22 +def adjust_learning_rate(optimizer, gamma, num_steps=1):
23 + for i in range(num_steps):
24 + for param_group in optimizer.param_groups:
25 + param_group["lr"] *= gamma
26 +
27 +
28 +def get_epoch_iters(path):
29 + path = os.path.basename(path)
30 + tokens = path[:-4].split("_")
31 + try:
32 + if tokens[-1] == "interrupted":
33 + epoch_idx = int(tokens[-3])
34 + iter_idx = int(tokens[-2])
35 + else:
36 + epoch_idx = int(tokens[-2])
37 + iter_idx = int(tokens[-1])
38 +    except (IndexError, ValueError):
39 + return 0, 0
40 +
41 + return epoch_idx, iter_idx
42 +
43 +
44 +def load_checkpoint(model, path):
45 + state = torch.load(path, map_location="cpu")
46 + model.load_state_dict(state)
47 + print("Loaded checkpoint successfully")
48 +
49 +
50 +class UNetTrainer:
51 + def __init__(self, args, cfg):
52 +
53 + if args.resume is not None:
54 + epoch, iters = get_epoch_iters(args.resume)
55 + else:
56 + epoch = 0
57 + iters = 0
58 +
59 + self.cfg = cfg
60 + self.step_iters = cfg.step_iters
61 + self.gamma = cfg.gamma
62 + self.visualize_per_iter = cfg.visualize_per_iter
63 + self.print_per_iter = cfg.print_per_iter
64 + self.save_per_iter = cfg.save_per_iter
65 +
66 + self.start_iter = iters
67 + self.iters = 0
68 + self.num_epochs = cfg.num_epochs
69 + self.device = torch.device("cuda:0" if cfg.cuda else "cpu")
70 +
71 + trainset = FacemaskSegDataset(cfg)
72 + valset = FacemaskSegDataset(cfg, train=False)
73 +
74 + self.trainloader = data.DataLoader(
75 + trainset,
76 + batch_size=cfg.batch_size,
77 + num_workers=cfg.num_workers,
78 + pin_memory=True,
79 + shuffle=True,
80 + collate_fn=trainset.collate_fn,
81 + )
82 +
83 + self.valloader = data.DataLoader(
84 + valset,
85 + batch_size=cfg.batch_size,
86 + num_workers=cfg.num_workers,
87 + pin_memory=True,
88 + shuffle=True,
89 + collate_fn=valset.collate_fn,
90 + )
91 +
92 + self.epoch = int(self.start_iter / len(self.trainloader))
93 + self.iters = self.start_iter
94 + self.num_iters = (self.num_epochs + 1) * len(self.trainloader)
95 +
96 + self.model = UNetSemantic().to(self.device)
97 + self.criterion_dice = DiceLoss()
98 + self.criterion_bce = nn.BCELoss()
99 +
100 + if args.resume is not None:
101 + load_checkpoint(self.model, args.resume)
102 +
103 + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
104 +
105 + def validate(self, sample_folder, sample_name, img_list):
106 + save_img_path = os.path.join(sample_folder, sample_name + ".png")
107 + img_list = [i.clone().cpu() for i in img_list]
108 + imgs = torch.stack(img_list, dim=1)
109 +
110 +        # imgs shape: B x N x C x H x W, where N = len(img_list)
111 +
112 + imgs = imgs.view(-1, *list(imgs.size())[2:])
113 + save_image(imgs, save_img_path, nrow=3)
114 +        print(f"Saved image to {save_img_path}")
115 +
116 + def train_epoch(self):
117 + self.model.train()
118 + running_loss = {
119 + "DICE": 0,
120 + "BCE": 0,
121 + "T": 0,
122 + }
123 + running_time = 0
124 +
125 + for idx, batch in enumerate(self.trainloader):
126 + self.optimizer.zero_grad()
127 + inputs = batch["imgs"].to(self.device)
128 + targets = batch["masks"].to(self.device)
129 +
130 + start_time = time.time()
131 +
132 + outputs = self.model(inputs)
133 +
134 + loss_bce = self.criterion_bce(outputs, targets)
135 + loss_dice = self.criterion_dice(outputs, targets)
136 + loss = loss_bce + loss_dice
137 + loss.backward()
138 + self.optimizer.step()
139 +
140 + end_time = time.time()
141 +
142 + running_loss["T"] += loss.item()
143 + running_loss["DICE"] += loss_dice.item()
144 + running_loss["BCE"] += loss_bce.item()
145 + running_time += end_time - start_time
146 +
147 + if self.iters % self.print_per_iter == 0:
148 + for key in running_loss.keys():
149 + running_loss[key] /= self.print_per_iter
150 + running_loss[key] = np.round(running_loss[key], 5)
151 + loss_string = (
152 + "{}".format(running_loss)[1:-1].replace("'", "").replace(",", " ||")
153 + )
154 + running_time = np.round(running_time, 5)
155 + print(
156 + "[{}/{}][{}/{}] || {} || Time: {}s".format(
157 + self.epoch,
158 + self.num_epochs,
159 + self.iters,
160 + self.num_iters,
161 + loss_string,
162 + running_time,
163 + )
164 + )
165 + running_time = 0
166 + running_loss = {
167 + "DICE": 0,
168 + "BCE": 0,
169 + "T": 0,
170 + }
171 +
172 + if self.iters % self.save_per_iter == 0:
173 + save_path = os.path.join(
174 + self.cfg.checkpoint_path,
175 + f"model_segm_{self.epoch}_{self.iters}.pth",
176 + )
177 + torch.save(self.model.state_dict(), save_path)
178 +                print(f"Saved model to {save_path}")
179 + self.iters += 1
180 +
181 + def validate_epoch(self):
182 + # Validate
183 +
184 + self.model.eval()
185 + metrics = [DiceScore(1), PixelAccuracy(1)]
186 + running_loss = {
187 + "DICE": 0,
188 + "BCE": 0,
189 + "T": 0,
190 + }
191 +
192 + running_time = 0
193 + print(
194 + "=============================EVALUATION==================================="
195 + )
196 + with torch.no_grad():
197 + start_time = time.time()
198 + for idx, batch in enumerate(tqdm(self.valloader)):
199 +
200 + inputs = batch["imgs"].to(self.device)
201 + targets = batch["masks"].to(self.device)
202 + outputs = self.model(inputs)
203 + loss_bce = self.criterion_bce(outputs, targets)
204 + loss_dice = self.criterion_dice(outputs, targets)
205 + loss = loss_bce + loss_dice
206 + running_loss["T"] += loss.item()
207 + running_loss["DICE"] += loss_dice.item()
208 + running_loss["BCE"] += loss_bce.item()
209 + for metric in metrics:
210 + metric.update(outputs.cpu(), targets.cpu())
211 +
212 + end_time = time.time()
213 + running_time += end_time - start_time
214 + running_time = np.round(running_time, 5)
215 + for key in running_loss.keys():
216 + running_loss[key] /= len(self.valloader)
217 + running_loss[key] = np.round(running_loss[key], 5)
218 +
219 + loss_string = (
220 + "{}".format(running_loss)[1:-1].replace("'", "").replace(",", " ||")
221 + )
222 +
223 + print(
224 + "[{}/{}] || Validation || {} || Time: {}s".format(
225 + self.epoch, self.num_epochs, loss_string, running_time
226 + )
227 + )
228 + for metric in metrics:
229 + print(metric)
230 + print(
231 + "=========================================================================="
232 + )
233 +
234 + def fit(self):
235 + try:
236 + for epoch in range(self.epoch, self.num_epochs + 1):
237 + self.epoch = epoch
238 + self.train_epoch()
239 + self.validate_epoch()
240 + except KeyboardInterrupt:
241 + torch.save(
242 + self.model.state_dict(),
243 + os.path.join(
244 + self.cfg.checkpoint_path,
245 + f"model_segm_{self.epoch}_{self.iters}.pth",
246 + ),
247 + )
248 + print("Model saved!")
249 +
1 +# Repo-specific
2 +data/masks/*
3 +.vscode*
4 +
5 +# Byte-compiled / optimized / DLL files
6 +__pycache__/
7 +*.py[cod]
8 +*$py.class
9 +
10 +# C extensions
11 +*.so
12 +
13 +# Distribution / packaging
14 +.Python
15 +build/
16 +develop-eggs/
17 +dist/
18 +downloads/
19 +eggs/
20 +.eggs/
21 +lib/
22 +lib64/
23 +parts/
24 +sdist/
25 +var/
26 +wheels/
27 +*.egg-info/
28 +.installed.cfg
29 +*.egg
30 +MANIFEST
31 +
32 +# PyInstaller
33 +# Usually these files are written by a python script from a template
34 +# before PyInstaller builds the exe, so as to inject date/other infos into it.
35 +*.manifest
36 +*.spec
37 +
38 +# Installer logs
39 +pip-log.txt
40 +pip-delete-this-directory.txt
41 +
42 +# Unit test / coverage reports
43 +htmlcov/
44 +.tox/
45 +.coverage
46 +.coverage.*
47 +.cache
48 +nosetests.xml
49 +coverage.xml
50 +*.cover
51 +.hypothesis/
52 +.pytest_cache/
53 +
54 +# Translations
55 +*.mo
56 +*.pot
57 +
58 +# Django stuff:
59 +*.log
60 +local_settings.py
61 +db.sqlite3
62 +
63 +# Flask stuff:
64 +instance/
65 +.webassets-cache
66 +
67 +# Scrapy stuff:
68 +.scrapy
69 +
70 +# Sphinx documentation
71 +docs/_build/
72 +
73 +# PyBuilder
74 +target/
75 +
76 +# Jupyter Notebook
77 +.ipynb_checkpoints
78 +
79 +# pyenv
80 +.python-version
81 +
82 +# celery beat schedule file
83 +celerybeat-schedule
84 +
85 +# SageMath parsed files
86 +*.sage.py
87 +
88 +# Environments
89 +.env
90 +.venv
91 +env/
92 +venv/
93 +ENV/
94 +env.bak/
95 +venv.bak/
96 +
97 +# Spyder project settings
98 +.spyderproject
99 +.spyproject
100 +
101 +# Rope project settings
102 +.ropeproject
103 +
104 +# mkdocs documentation
105 +/site
106 +
107 +# mypy
108 +.mypy_cache/
109 +
110 +backup*
111 +pexels_royalty_free_photos*
...\ No newline at end of file ...\ No newline at end of file
1 +
2 +from keras.utils import conv_utils
3 +from keras import backend as K
4 +from keras.engine import InputSpec
5 +from keras.layers import Conv2D
6 +
7 +
8 +class PConv2D(Conv2D):
9 + def __init__(self, *args, n_channels=3, mono=False, **kwargs):
10 + super().__init__(*args, **kwargs)
11 + self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
12 +
13 + def build(self, input_shape):
14 + """Adapted from original _Conv() layer of Keras
15 + param input_shape: list of dimensions for [img, mask]
16 + """
17 +
18 + if self.data_format == 'channels_first':
19 + channel_axis = 1
20 + else:
21 + channel_axis = -1
22 +
23 + if input_shape[0][channel_axis] is None:
24 + raise ValueError('The channel dimension of the inputs should be defined. Found `None`.')
25 +
26 + self.input_dim = input_shape[0][channel_axis]
27 +
28 + # Image kernel
29 + kernel_shape = self.kernel_size + (self.input_dim, self.filters)
30 + self.kernel = self.add_weight(shape=kernel_shape,
31 + initializer=self.kernel_initializer,
32 + name='img_kernel',
33 + regularizer=self.kernel_regularizer,
34 + constraint=self.kernel_constraint)
35 + # Mask kernel
36 + self.kernel_mask = K.ones(shape=self.kernel_size + (self.input_dim, self.filters))
37 +
38 + # Calculate padding size to achieve zero-padding
39 + self.pconv_padding = (
40 + (int((self.kernel_size[0]-1)/2), int((self.kernel_size[0]-1)/2)),
41 +            (int((self.kernel_size[1]-1)/2), int((self.kernel_size[1]-1)/2)),
42 + )
43 +
44 + # Window size - used for normalization
45 + self.window_size = self.kernel_size[0] * self.kernel_size[1]
46 +
47 + if self.use_bias:
48 + self.bias = self.add_weight(shape=(self.filters,),
49 + initializer=self.bias_initializer,
50 + name='bias',
51 + regularizer=self.bias_regularizer,
52 + constraint=self.bias_constraint)
53 + else:
54 + self.bias = None
55 + self.built = True
56 +
57 + def call(self, inputs, mask=None):
58 + '''
59 +        We will be using the Keras conv2d method; essentially all we have
60 +        to do here is multiply the mask with the input X before we apply the
61 + convolutions. For the mask itself, we apply convolutions with all weights
62 + set to 1.
63 + Subsequently, we clip mask values to between 0 and 1
64 + '''
65 +
66 + # Both image and mask must be supplied
67 + if type(inputs) is not list or len(inputs) != 2:
68 + raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))
69 +
70 + # Padding done explicitly so that padding becomes part of the masked partial convolution
71 + images = K.spatial_2d_padding(inputs[0], self.pconv_padding, self.data_format)
72 + masks = K.spatial_2d_padding(inputs[1], self.pconv_padding, self.data_format)
73 +
74 + # Apply convolutions to mask
75 + mask_output = K.conv2d(
76 + masks, self.kernel_mask,
77 + strides=self.strides,
78 + padding='valid',
79 + data_format=self.data_format,
80 + dilation_rate=self.dilation_rate
81 + )
82 +
83 + # Apply convolutions to image
84 + img_output = K.conv2d(
85 + (images*masks), self.kernel,
86 + strides=self.strides,
87 + padding='valid',
88 + data_format=self.data_format,
89 + dilation_rate=self.dilation_rate
90 + )
91 +
92 + # Calculate the mask ratio on each pixel in the output mask
93 + mask_ratio = self.window_size / (mask_output + 1e-8)
94 +
95 + # Clip output to be between 0 and 1
96 + mask_output = K.clip(mask_output, 0, 1)
97 +
98 + # Remove ratio values where there are holes
99 + mask_ratio = mask_ratio * mask_output
100 +
101 +        # Normalize image output
102 + img_output = img_output * mask_ratio
103 +
104 + # Apply bias only to the image (if chosen to do so)
105 + if self.use_bias:
106 + img_output = K.bias_add(
107 + img_output,
108 + self.bias,
109 + data_format=self.data_format)
110 +
111 + # Apply activations on the image
112 + if self.activation is not None:
113 + img_output = self.activation(img_output)
114 +
115 + return [img_output, mask_output]
116 +
117 + def compute_output_shape(self, input_shape):
118 + if self.data_format == 'channels_last':
119 + space = input_shape[0][1:-1]
120 + new_space = []
121 + for i in range(len(space)):
122 + new_dim = conv_utils.conv_output_length(
123 + space[i],
124 + self.kernel_size[i],
125 + padding='same',
126 + stride=self.strides[i],
127 + dilation=self.dilation_rate[i])
128 + new_space.append(new_dim)
129 + new_shape = (input_shape[0][0],) + tuple(new_space) + (self.filters,)
130 + return [new_shape, new_shape]
131 + if self.data_format == 'channels_first':
132 + space = input_shape[2:]
133 + new_space = []
134 + for i in range(len(space)):
135 + new_dim = conv_utils.conv_output_length(
136 + space[i],
137 + self.kernel_size[i],
138 + padding='same',
139 + stride=self.strides[i],
140 + dilation=self.dilation_rate[i])
141 + new_space.append(new_dim)
142 + new_shape = (input_shape[0], self.filters) + tuple(new_space)
143 + return [new_shape, new_shape]
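# Normalization sketch (numbers are illustrative): for a 3x3 kernel the window size
# is 9. If only 3 of the 9 positions under the kernel are valid (mask == 1), the
# image convolution at that location is rescaled by 9 / 3 = 3 before the bias is
# added, and the output mask value is clipped to 1 so the position counts as valid
# for the next layer; locations with no valid pixels keep mask 0 and output 0.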
1 +import os
2 +import sys
3 +import numpy as np
4 +from datetime import datetime
5 +
6 +import tensorflow as tf
7 +from keras.models import Model
8 +from keras.models import load_model
9 +from keras.optimizers import Adam
10 +from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, BatchNormalization, Activation, Lambda
11 +from keras.layers.merge import Concatenate
12 +from keras.applications import VGG16
13 +from keras import backend as K
14 +from keras.utils.multi_gpu_utils import multi_gpu_model
15 +
16 +from libs.pconv_layer import PConv2D
17 +
18 +
19 +class PConvUnet(object):
20 +
21 + def __init__(self, img_rows=512, img_cols=512, vgg_weights="imagenet", inference_only=False, net_name='default', gpus=1, vgg_device=None):
22 + """Create the PConvUnet. If variable image size, set img_rows and img_cols to None
23 +
24 + Args:
25 + img_rows (int): image height.
26 + img_cols (int): image width.
27 + vgg_weights (str): which weights to pass to the vgg network.
28 + inference_only (bool): initialize BN layers for inference.
29 + net_name (str): Name of this network (used in logging).
30 + gpus (int): How many GPUs to use for training.
31 + vgg_device (str): In case of training with multiple GPUs, specify which device to run VGG inference on.
32 + e.g. if training on 8 GPUs, vgg inference could be off-loaded exclusively to one GPU, instead of
33 + running on one of the GPUs which is also training the UNet.
34 + """
35 +
36 + # Settings
37 + self.img_rows = img_rows
38 + self.img_cols = img_cols
39 + self.img_overlap = 30
40 + self.inference_only = inference_only
41 + self.net_name = net_name
42 + self.gpus = gpus
43 + self.vgg_device = vgg_device
44 +
45 + # Scaling for VGG input
46 + self.mean = [0.485, 0.456, 0.406]
47 + self.std = [0.229, 0.224, 0.225]
48 +
49 + # Assertions
50 +        assert self.img_rows >= 256, 'Height must be at least 256 pixels'
51 +        assert self.img_cols >= 256, 'Width must be at least 256 pixels'
52 +
53 + # Set current epoch
54 + self.current_epoch = 0
55 +
56 + # VGG layers to extract features from (first maxpooling layers, see pp. 7 of paper)
57 + self.vgg_layers = [3, 6, 10]
58 +
59 + # Instantiate the vgg network
60 + if self.vgg_device:
61 + with tf.device(self.vgg_device):
62 + self.vgg = self.build_vgg(vgg_weights)
63 + else:
64 + self.vgg = self.build_vgg(vgg_weights)
65 +
66 + # Create UNet-like model
67 + if self.gpus <= 1:
68 + self.model, inputs_mask = self.build_pconv_unet()
69 + self.compile_pconv_unet(self.model, inputs_mask)
70 + else:
71 + with tf.device("/cpu:0"):
72 + self.model, inputs_mask = self.build_pconv_unet()
73 + self.model = multi_gpu_model(self.model, gpus=self.gpus)
74 + self.compile_pconv_unet(self.model, inputs_mask)
75 +
76 + def build_vgg(self, weights="imagenet"):
77 + """
78 + Load pre-trained VGG16 from keras applications
79 + Extract features to be used in loss function from last conv layer, see architecture at:
80 + https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py
81 + """
82 +
83 + # Input image to extract features from
84 + img = Input(shape=(self.img_rows, self.img_cols, 3))
85 +
86 + # Mean center and rescale by variance as in PyTorch
87 + processed = Lambda(lambda x: (x-self.mean) / self.std)(img)
88 +
89 + # If inference only, just return empty model
90 + if self.inference_only:
91 + model = Model(inputs=img, outputs=[img for _ in range(len(self.vgg_layers))])
92 + model.trainable = False
93 + model.compile(loss='mse', optimizer='adam')
94 + return model
95 +
96 + # Get the vgg network from Keras applications
97 + if weights in ['imagenet', None]:
98 + vgg = VGG16(weights=weights, include_top=False)
99 + else:
100 + vgg = VGG16(weights=None, include_top=False)
101 + vgg.load_weights(weights, by_name=True)
102 +
103 + # Output the first three pooling layers
104 + vgg.outputs = [vgg.layers[i].output for i in self.vgg_layers]
105 +
106 + # Create model and compile
107 + model = Model(inputs=img, outputs=vgg(processed))
108 + model.trainable = False
109 + model.compile(loss='mse', optimizer='adam')
110 +
111 + return model
112 +
113 + def build_pconv_unet(self, train_bn=True):
114 +
115 + # INPUTS
116 + inputs_img = Input((self.img_rows, self.img_cols, 3), name='inputs_img')
117 + inputs_mask = Input((self.img_rows, self.img_cols, 3), name='inputs_mask')
118 +
119 + # ENCODER
120 + def encoder_layer(img_in, mask_in, filters, kernel_size, bn=True):
121 + conv, mask = PConv2D(filters, kernel_size, strides=2, padding='same')([img_in, mask_in])
122 + if bn:
123 + conv = BatchNormalization(name='EncBN'+str(encoder_layer.counter))(conv, training=train_bn)
124 + conv = Activation('relu')(conv)
125 + encoder_layer.counter += 1
126 + return conv, mask
127 + encoder_layer.counter = 0
128 +
129 + e_conv1, e_mask1 = encoder_layer(inputs_img, inputs_mask, 64, 7, bn=False)
130 + e_conv2, e_mask2 = encoder_layer(e_conv1, e_mask1, 128, 5)
131 + e_conv3, e_mask3 = encoder_layer(e_conv2, e_mask2, 256, 5)
132 + e_conv4, e_mask4 = encoder_layer(e_conv3, e_mask3, 512, 3)
133 + e_conv5, e_mask5 = encoder_layer(e_conv4, e_mask4, 512, 3)
134 + e_conv6, e_mask6 = encoder_layer(e_conv5, e_mask5, 512, 3)
135 + e_conv7, e_mask7 = encoder_layer(e_conv6, e_mask6, 512, 3)
136 + e_conv8, e_mask8 = encoder_layer(e_conv7, e_mask7, 512, 3)
137 +
138 + # DECODER
139 + def decoder_layer(img_in, mask_in, e_conv, e_mask, filters, kernel_size, bn=True):
140 + up_img = UpSampling2D(size=(2,2))(img_in)
141 + up_mask = UpSampling2D(size=(2,2))(mask_in)
142 + concat_img = Concatenate(axis=3)([e_conv,up_img])
143 + concat_mask = Concatenate(axis=3)([e_mask,up_mask])
144 + conv, mask = PConv2D(filters, kernel_size, padding='same')([concat_img, concat_mask])
145 + if bn:
146 + conv = BatchNormalization()(conv)
147 + conv = LeakyReLU(alpha=0.2)(conv)
148 + return conv, mask
149 +
150 + d_conv9, d_mask9 = decoder_layer(e_conv8, e_mask8, e_conv7, e_mask7, 512, 3)
151 + d_conv10, d_mask10 = decoder_layer(d_conv9, d_mask9, e_conv6, e_mask6, 512, 3)
152 + d_conv11, d_mask11 = decoder_layer(d_conv10, d_mask10, e_conv5, e_mask5, 512, 3)
153 + d_conv12, d_mask12 = decoder_layer(d_conv11, d_mask11, e_conv4, e_mask4, 512, 3)
154 + d_conv13, d_mask13 = decoder_layer(d_conv12, d_mask12, e_conv3, e_mask3, 256, 3)
155 + d_conv14, d_mask14 = decoder_layer(d_conv13, d_mask13, e_conv2, e_mask2, 128, 3)
156 + d_conv15, d_mask15 = decoder_layer(d_conv14, d_mask14, e_conv1, e_mask1, 64, 3)
157 + d_conv16, d_mask16 = decoder_layer(d_conv15, d_mask15, inputs_img, inputs_mask, 3, 3, bn=False)
158 + outputs = Conv2D(3, 1, activation = 'sigmoid', name='outputs_img')(d_conv16)
159 +
160 + # Setup the model inputs / outputs
161 + model = Model(inputs=[inputs_img, inputs_mask], outputs=outputs)
162 +
163 + return model, inputs_mask
164 +
165 + def compile_pconv_unet(self, model, inputs_mask, lr=0.0002):
166 + model.compile(
167 + optimizer = Adam(lr=lr),
168 + loss=self.loss_total(inputs_mask),
169 + metrics=[self.PSNR]
170 + )
171 +
172 + def loss_total(self, mask):
173 + """
174 + Creates a loss function which sums all the loss components
175 + and multiplies by their weights. See paper eq. 7.
176 + """
177 + def loss(y_true, y_pred):
178 +
179 + # Compute predicted image with non-hole pixels set to ground truth
180 + y_comp = mask * y_true + (1-mask) * y_pred
181 +
182 + # Compute the vgg features.
183 + if self.vgg_device:
184 + with tf.device(self.vgg_device):
185 + vgg_out = self.vgg(y_pred)
186 + vgg_gt = self.vgg(y_true)
187 + vgg_comp = self.vgg(y_comp)
188 + else:
189 + vgg_out = self.vgg(y_pred)
190 + vgg_gt = self.vgg(y_true)
191 + vgg_comp = self.vgg(y_comp)
192 +
193 + # Compute loss components
194 + l1 = self.loss_valid(mask, y_true, y_pred)
195 + l2 = self.loss_hole(mask, y_true, y_pred)
196 + l3 = self.loss_perceptual(vgg_out, vgg_gt, vgg_comp)
197 + l4 = self.loss_style(vgg_out, vgg_gt)
198 + l5 = self.loss_style(vgg_comp, vgg_gt)
199 + l6 = self.loss_tv(mask, y_comp)
200 +
201 + # Return loss function
202 + return l1 + 6*l2 + 0.05*l3 + 120*(l4+l5) + 0.1*l6
203 +
204 + return loss
205 +
206 + def loss_hole(self, mask, y_true, y_pred):
207 + """Pixel L1 loss within the hole / mask"""
208 + return self.l1((1-mask) * y_true, (1-mask) * y_pred)
209 +
210 + def loss_valid(self, mask, y_true, y_pred):
211 + """Pixel L1 loss outside the hole / mask"""
212 + return self.l1(mask * y_true, mask * y_pred)
213 +
214 + def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp):
215 + """Perceptual loss based on VGG16, see. eq. 3 in paper"""
216 + loss = 0
217 + for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):
218 + loss += self.l1(o, g) + self.l1(c, g)
219 + return loss
220 +
221 + def loss_style(self, output, vgg_gt):
222 + """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
223 + loss = 0
224 + for o, g in zip(output, vgg_gt):
225 + loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
226 + return loss
227 +
228 + def loss_tv(self, mask, y_comp):
229 + """Total variation loss, used for smoothing the hole region, see. eq. 6"""
230 +
231 + # Create dilated hole region using a 3x3 kernel of all 1s.
232 + kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))
233 + dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')
234 +
235 + # Cast values to be [0., 1.], and compute dilated hole region of y_comp
236 + dilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')
237 + P = dilated_mask * y_comp
238 +
239 + # Calculate total variation loss
240 + a = self.l1(P[:,1:,:,:], P[:,:-1,:,:])
241 + b = self.l1(P[:,:,1:,:], P[:,:,:-1,:])
242 + return a+b
243 +
244 + def fit_generator(self, generator, *args, **kwargs):
245 + """Fit the U-Net to a (images, targets) generator
246 +
247 + Args:
248 + generator (generator): generator supplying input image & mask, as well as targets.
249 + *args: arguments to be passed to fit_generator
250 + **kwargs: keyword arguments to be passed to fit_generator
251 + """
252 + self.model.fit_generator(
253 + generator,
254 + *args, **kwargs
255 + )
256 +
257 + def summary(self):
258 + """Get summary of the UNet model"""
259 + print(self.model.summary())
260 +
261 + def load(self, filepath, train_bn=True, lr=0.0002):
262 +
263 + # Create UNet-like model
264 + self.model, inputs_mask = self.build_pconv_unet(train_bn)
265 + self.compile_pconv_unet(self.model, inputs_mask, lr)
266 +
267 + # Load weights into model
268 + epoch = int(os.path.basename(filepath).split('.')[1].split('-')[0])
269 + assert epoch > 0, "Could not parse weight file. Should include the epoch"
270 + self.current_epoch = epoch
271 + self.model.load_weights(filepath)
272 +
273 + @staticmethod
274 + def PSNR(y_true, y_pred):
275 + """
276 +        PSNR is Peak Signal to Noise Ratio, see https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
277 + The equation is:
278 + PSNR = 20 * log10(MAX_I) - 10 * log10(MSE)
279 +
280 +        Our input is scaled to be within the range -2.11 to 2.64 (imagenet value scaling). We use the difference between these
281 + two values (4.75) as MAX_I
282 + """
283 + #return 20 * K.log(4.75) / K.log(10.0) - 10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0)
284 + return - 10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0)
285 +
286 + @staticmethod
287 + def current_timestamp():
288 + return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
289 +
290 + @staticmethod
291 + def l1(y_true, y_pred):
292 + """Calculate the L1 loss used in all loss calculations"""
293 + if K.ndim(y_true) == 4:
294 + return K.mean(K.abs(y_pred - y_true), axis=[1,2,3])
295 + elif K.ndim(y_true) == 3:
296 + return K.mean(K.abs(y_pred - y_true), axis=[1,2])
297 + else:
298 + raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")
299 +
300 + @staticmethod
301 + def gram_matrix(x, norm_by_channels=False):
302 + """Calculate gram matrix used in style loss"""
303 +
304 + # Assertions on input
305 + assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
306 + assert K.image_data_format() == 'channels_last', "Please use channels-last format"
307 +
308 + # Permute channels and get resulting shape
309 + x = K.permute_dimensions(x, (0, 3, 1, 2))
310 + shape = K.shape(x)
311 + B, C, H, W = shape[0], shape[1], shape[2], shape[3]
312 +
313 + # Reshape x and do batch dot product
314 + features = K.reshape(x, K.stack([B, C, H*W]))
315 + gram = K.batch_dot(features, features, axes=2)
316 +
317 + # Normalize with channels, height and width
318 + gram = gram / K.cast(C * H * W, x.dtype)
319 +
320 + return gram
321 +
322 + # Prediction functions
323 + ######################
324 + def predict(self, sample, **kwargs):
325 + """Run prediction using this model"""
326 + return self.model.predict(sample, **kwargs)
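# Inference sketch (variable names are illustrative): the model expects a masked
# image and its mask, both scaled to [0, 1] with shape (B, 512, 512, 3), and returns
# the completed images in the same shape; trained weights would be loaded with load().
#
#     model = PConvUnet(vgg_weights=None, inference_only=True)
#     pred = model.predict([masked_batch, mask_batch])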
1 +import os
2 +from random import randint, seed
3 +import itertools
4 +import numpy as np
5 +import cv2
6 +
7 +
8 +class MaskGenerator():
9 +
10 + def __init__(self, height, width, channels=3, rand_seed=None, filepath=None):
11 + """Convenience functions for generating masks to be used for inpainting training
12 +
13 + Arguments:
14 + height {int} -- Mask height
15 +            width {int} -- Mask width
16 +
17 + Keyword Arguments:
18 + channels {int} -- Channels to output (default: {3})
19 + rand_seed {[type]} -- Random seed (default: {None})
20 + filepath {[type]} -- Load masks from filepath. If None, generate masks with OpenCV (default: {None})
21 + """
22 +
23 + self.height = height
24 + self.width = width
25 + self.channels = channels
26 + self.filepath = filepath
27 +
28 + # If filepath supplied, load the list of masks within the directory
29 + self.mask_files = []
30 + if self.filepath:
31 + filenames = [f for f in os.listdir(self.filepath)]
32 + self.mask_files = [f for f in filenames if any(filetype in f.lower() for filetype in ['.jpeg', '.png', '.jpg'])]
33 + print(">> Found {} masks in {}".format(len(self.mask_files), self.filepath))
34 +
35 + # Seed for reproducibility
36 + if rand_seed:
37 + seed(rand_seed)
38 +
39 + def _generate_mask(self):
40 +        """Generates a random irregular mask with lines, circles and ellipses"""
41 +
42 + img = np.zeros((self.height, self.width, self.channels), np.uint8)
43 +
44 + # Set size scale
45 + size = int((self.width + self.height) * 0.03)
46 + if self.width < 64 or self.height < 64:
47 + raise Exception("Width and Height of mask must be at least 64!")
48 +
49 + # Draw random lines
50 + for _ in range(randint(1, 20)):
51 + x1, x2 = randint(1, self.width), randint(1, self.width)
52 + y1, y2 = randint(1, self.height), randint(1, self.height)
53 + thickness = randint(3, size)
54 + cv2.line(img,(x1,y1),(x2,y2),(1,1,1),thickness)
55 +
56 + # Draw random circles
57 + for _ in range(randint(1, 20)):
58 + x1, y1 = randint(1, self.width), randint(1, self.height)
59 + radius = randint(3, size)
60 + cv2.circle(img,(x1,y1),radius,(1,1,1), -1)
61 +
62 + # Draw random ellipses
63 + for _ in range(randint(1, 20)):
64 + x1, y1 = randint(1, self.width), randint(1, self.height)
65 + s1, s2 = randint(1, self.width), randint(1, self.height)
66 + a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)
67 + thickness = randint(3, size)
68 + cv2.ellipse(img, (x1,y1), (s1,s2), a1, a2, a3,(1,1,1), thickness)
69 +
70 + return 1-img
71 +
72 + def _load_mask(self, rotation=True, dilation=True, cropping=True):
73 + """Loads a mask from disk, and optionally augments it"""
74 +
75 + # Read image
76 + mask = cv2.imread(os.path.join(self.filepath, np.random.choice(self.mask_files, 1, replace=False)[0]))
77 +
78 + # Random rotation
79 + if rotation:
80 + rand = np.random.randint(-180, 180)
81 + M = cv2.getRotationMatrix2D((mask.shape[1]/2, mask.shape[0]/2), rand, 1.5)
82 + mask = cv2.warpAffine(mask, M, (mask.shape[1], mask.shape[0]))
83 +
84 +        # Random dilation of the hole region (eroding the 1-valued background grows the 0-valued holes)
85 + if dilation:
86 + rand = np.random.randint(5, 47)
87 + kernel = np.ones((rand, rand), np.uint8)
88 + mask = cv2.erode(mask, kernel, iterations=1)
89 +
90 + # Random cropping
91 + if cropping:
92 + x = np.random.randint(0, mask.shape[1] - self.width)
93 + y = np.random.randint(0, mask.shape[0] - self.height)
94 + mask = mask[y:y+self.height, x:x+self.width]
95 +
96 + return (mask > 1).astype(np.uint8)
97 +
98 + def sample(self, random_seed=None):
99 + """Retrieve a random mask"""
100 + if random_seed:
101 + seed(random_seed)
102 + if self.filepath and len(self.mask_files) > 0:
103 + return self._load_mask()
104 + else:
105 + return self._generate_mask()
106 +
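# Usage sketch (illustrative): sampled masks are uint8 arrays of shape
# (height, width, channels) with 1 for kept pixels and 0 inside the random holes.
#
#     gen = MaskGenerator(512, 512, 3)
#     mask = gen.sample()   # ones everywhere except the drawn lines, circles and ellipses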
107 +
108 +class ImageChunker(object):
109 +
110 + def __init__(self, rows, cols, overlap):
111 + self.rows = rows
112 + self.cols = cols
113 + self.overlap = overlap
114 +
115 + def perform_chunking(self, img_size, chunk_size):
116 + """
117 + Given an image dimension img_size, return list of (start, stop)
118 + tuples to perform chunking of chunk_size
119 + """
120 + chunks, i = [], 0
121 + while True:
122 + chunks.append((i*(chunk_size - self.overlap/2), i*(chunk_size - self.overlap/2)+chunk_size))
123 + i+=1
124 + if chunks[-1][1] > img_size:
125 + break
126 + n_count = len(chunks)
127 + chunks[-1] = tuple(x - (n_count*chunk_size - img_size - (n_count-1)*self.overlap/2) for x in chunks[-1])
128 + chunks = [(int(x), int(y)) for x, y in chunks]
129 + return chunks
130 +
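    # Worked example (illustrative numbers): with rows = cols = 512 and overlap = 30,
    # an image of height 768 is split as
    #     perform_chunking(768, 512)  ->  [(0, 512), (256, 768)]
    # i.e. two 512-pixel windows, the second shifted back so both fit inside the
    # image (they overlap by 256 pixels).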
131 + def get_chunks(self, img, scale=1):
132 + """
133 + Get width and height lists of (start, stop) tuples for chunking of img.
134 + """
135 + x_chunks, y_chunks = [(0, self.rows)], [(0, self.cols)]
136 + if img.shape[0] > self.rows:
137 + x_chunks = self.perform_chunking(img.shape[0], self.rows)
138 + else:
139 + x_chunks = [(0, img.shape[0])]
140 + if img.shape[1] > self.cols:
141 + y_chunks = self.perform_chunking(img.shape[1], self.cols)
142 + else:
143 + y_chunks = [(0, img.shape[1])]
144 + return x_chunks, y_chunks
145 +
146 + def dimension_preprocess(self, img, padding=True):
147 + """
148 + In case of prediction on image of different size than 512x512,
149 + this function is used to add padding and chunk up the image into pieces
150 + of 512x512, which can then later be reconstructed into the original image
151 + using the dimension_postprocess() function.
152 + """
153 +
154 + # Assert single image input
155 + assert len(img.shape) == 3, "Image dimension expected to be (H, W, C)"
156 +
157 + # Check if we are adding padding for too small images
158 + if padding:
159 +
160 + # Check if height is too small
161 + if img.shape[0] < self.rows:
162 + padding = np.ones((self.rows - img.shape[0], img.shape[1], img.shape[2]))
163 + img = np.concatenate((img, padding), axis=0)
164 +
165 + # Check if width is too small
166 + if img.shape[1] < self.cols:
167 + padding = np.ones((img.shape[0], self.cols - img.shape[1], img.shape[2]))
168 + img = np.concatenate((img, padding), axis=1)
169 +
170 + # Get chunking of the image
171 + x_chunks, y_chunks = self.get_chunks(img)
172 +
173 + # Chunk up the image
174 + images = []
175 + for x in x_chunks:
176 + for y in y_chunks:
177 + images.append(
178 + img[x[0]:x[1], y[0]:y[1], :]
179 + )
180 + images = np.array(images)
181 + return images
182 +
183 + def dimension_postprocess(self, chunked_images, original_image, scale=1, padding=True):
184 + """
185 + In case of prediction on image of different size than 512x512,
186 + the dimension_preprocess function is used to add padding and chunk
187 + up the image into pieces of 512x512, and this function is used to
188 + reconstruct these pieces into the original image.
189 + """
190 +
191 + # Assert input dimensions
192 + assert len(original_image.shape) == 3, "Image dimension expected to be (H, W, C)"
193 + assert len(chunked_images.shape) == 4, "Chunked images dimension expected to be (B, H, W, C)"
194 +
195 + # Check if we are adding padding for too small images
196 + if padding:
197 +
198 + # Check if height is too small
199 + if original_image.shape[0] < self.rows:
200 + new_images = []
201 + for img in chunked_images:
202 + new_images.append(img[0:scale*original_image.shape[0], :, :])
203 + chunked_images = np.array(new_images)
204 +
205 + # Check if width is too small
206 + if original_image.shape[1] < self.cols:
207 + new_images = []
208 + for img in chunked_images:
209 + new_images.append(img[:, 0:scale*original_image.shape[1], :])
210 + chunked_images = np.array(new_images)
211 +
212 + # Put reconstruction into this array
213 + new_shape = (
214 + original_image.shape[0]*scale,
215 + original_image.shape[1]*scale,
216 + original_image.shape[2]
217 + )
218 + reconstruction = np.zeros(new_shape)
219 +
220 + # Get the chunks for this image
221 + x_chunks, y_chunks = self.get_chunks(original_image)
222 +
223 + i = 0
224 + s = scale
225 + for x in x_chunks:
226 + for y in y_chunks:
227 +
228 + prior_fill = reconstruction != 0
229 + chunk = np.zeros(new_shape)
230 + chunk[x[0]*s:x[1]*s, y[0]*s:y[1]*s, :] += chunked_images[i]
231 + chunk_fill = chunk != 0
232 +
233 + reconstruction += chunk
234 + reconstruction[prior_fill & chunk_fill] = reconstruction[prior_fill & chunk_fill] / 2
235 +
236 + i += 1
237 +
238 + return reconstruction
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import gc
3 +import datetime
4 +import numpy as np
5 +import pandas as pd
6 +import cv2
7 +
8 +from argparse import ArgumentParser
9 +from copy import deepcopy
10 +from tqdm import tqdm
11 +
12 +from keras.preprocessing.image import ImageDataGenerator
13 +from keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback
14 +from keras import backend as K
15 +from keras.utils import Sequence
16 +from keras_tqdm import TQDMCallback
17 +
18 +import matplotlib.pyplot as plt
19 +from matplotlib.ticker import NullFormatter
20 +
21 +from libs.pconv_model import PConvUnet
22 +from libs.util import MaskGenerator
23 +
24 +
25 +# Sample call
26 +r"""
27 +# Train on CelebaHQ
28 +python main.py --name CelebHQ --train C:\Documents\Kaggle\celebaHQ-512\train\ --validation C:\Documents\Kaggle\celebaHQ-512\val\ --test C:\Documents\Kaggle\celebaHQ-512\test\ --checkpoint "C:\Users\Mathias Felix Gruber\Documents\GitHub\PConv-Keras\data\logs\imagenet_phase1_paperMasks\weights.35-0.70.h5"
29 +"""
30 +
31 +
32 +def parse_args():
33 + parser = ArgumentParser(description="Training script for PConv inpainting")
34 +
35 + parser.add_argument(
36 + "-stage",
37 + "--stage",
38 + type=str,
39 + default="train",
40 + help="Which stage of training to run",
41 + choices=["train", "finetune"],
42 + )
43 +
44 + parser.add_argument(
45 + "-train", "--train", type=str, help="Folder with training images"
46 + )
47 +
48 + parser.add_argument(
49 + "-validation", "--validation", type=str, help="Folder with validation images"
50 + )
51 +
52 + parser.add_argument("-test", "--test", type=str, help="Folder with testing images")
53 +
54 + parser.add_argument(
55 + "-name",
56 + "--name",
57 + type=str,
58 + default="myDataset",
59 + help="Dataset name, e.g. 'imagenet'",
60 + )
61 +
62 + parser.add_argument(
63 + "-batch_size",
64 + "--batch_size",
65 + type=int,
66 + default=4,
67 + help="What batch-size should we use",
68 + )
69 +
70 + parser.add_argument(
71 + "-test_path",
72 + "--test_path",
73 + type=str,
74 + default="./data/test_samples/",
75 + help="Where to output test images during training",
76 + )
77 +
78 + parser.add_argument(
79 + "-weight_path",
80 + "--weight_path",
81 + type=str,
82 + default="./data/logs/",
83 + help="Where to output weights during training",
84 + )
85 +
86 + parser.add_argument(
87 + "-log_path",
88 + "--log_path",
89 + type=str,
90 + default="./data/logs/",
91 + help="Where to output tensorboard logs during training",
92 + )
93 +
94 + parser.add_argument(
95 + "-vgg_path",
96 + "--vgg_path",
97 + type=str,
98 + default="./data/logs/pytorch_to_keras_vgg16.h5",
99 + help="VGG16 weights trained on PyTorch with pixel scaling 1/255.",
100 + )
101 +
102 + parser.add_argument(
103 + "-checkpoint",
104 + "--checkpoint",
105 + type=str,
106 + help="Previous weights to be loaded onto model",
107 + )
108 +
109 + return parser.parse_args()
110 +
111 +
112 +class AugmentingDataGenerator(ImageDataGenerator):
113 + """Wrapper for ImageDataGenerator to return mask & image"""
114 +
115 + def flow_from_directory(self, directory, mask_generator, *args, **kwargs):
116 + generator = super().flow_from_directory(
117 + directory, class_mode=None, *args, **kwargs
118 + )
119 + seed = None if "seed" not in kwargs else kwargs["seed"]
120 + while True:
121 +
122 +            # Get augmented image samples
123 + ori = next(generator)
124 +
125 + # Get masks for each image sample
126 + mask = np.stack(
127 + [mask_generator.sample(seed) for _ in range(ori.shape[0])], axis=0
128 + )
129 +
130 +            # Apply masks to all image samples
131 + masked = deepcopy(ori)
132 + masked[mask == 0] = 1
133 +
134 +            # Yield ([masked, mask], ori) training batches
135 + # print(masked.shape, ori.shape)
136 + gc.collect()
137 + yield [masked, mask], ori
138 +
139 +
140 +# Run script
141 +if __name__ == "__main__":
142 +
143 + # Parse command-line arguments
144 + args = parse_args()
145 +
146 + if args.stage == "finetune" and not args.checkpoint:
147 + raise AttributeError(
148 + "If you are finetuning your model, you must supply a checkpoint file"
149 + )
150 +
151 + # Create training generator
152 + train_datagen = AugmentingDataGenerator(
153 + rotation_range=10,
154 + width_shift_range=0.1,
155 + height_shift_range=0.1,
156 + rescale=1.0 / 255,
157 + horizontal_flip=True,
158 + )
159 + train_generator = train_datagen.flow_from_directory(
160 + args.train,
161 + MaskGenerator(512, 512, 3),
162 + target_size=(512, 512),
163 + batch_size=args.batch_size,
164 + )
165 +
166 + # Create validation generator
167 + val_datagen = AugmentingDataGenerator(rescale=1.0 / 255)
168 + val_generator = val_datagen.flow_from_directory(
169 + args.validation,
170 + MaskGenerator(512, 512, 3),
171 + target_size=(512, 512),
172 + batch_size=args.batch_size,
173 + classes=["val"],
174 + seed=42,
175 + )
176 +
177 + # Create testing generator
178 + test_datagen = AugmentingDataGenerator(rescale=1.0 / 255)
179 + test_generator = test_datagen.flow_from_directory(
180 + args.test,
181 + MaskGenerator(512, 512, 3),
182 + target_size=(512, 512),
183 + batch_size=args.batch_size,
184 + seed=42,
185 + )
186 +
187 +    # Pick out an example to be sent to the test samples folder
188 + test_data = next(test_generator)
189 + (masked, mask), ori = test_data
190 +
191 + def plot_callback(model, path):
192 + """Called at the end of each epoch, displaying our previous test images,
193 + as well as their masked predictions and saving them to disk"""
194 +
195 + # Get samples & Display them
196 + pred_img = model.predict([masked, mask])
197 + pred_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
198 +
199 + # Clear current output and display test images
200 + for i in range(len(ori)):
201 + _, axes = plt.subplots(1, 3, figsize=(20, 5))
202 + axes[0].imshow(masked[i, :, :, :])
203 + axes[1].imshow(pred_img[i, :, :, :] * 1.0)
204 + axes[2].imshow(ori[i, :, :, :])
205 + axes[0].set_title("Masked Image")
206 + axes[1].set_title("Predicted Image")
207 + axes[2].set_title("Original Image")
208 +
209 +            plt.savefig(os.path.join(path, "img_{}_{}.png".format(i, pred_time)))  # no leading slash, otherwise os.path.join discards path
210 + plt.close()
211 +
212 + # Load the model
213 + if args.vgg_path:
214 + model = PConvUnet(vgg_weights=args.vgg_path)
215 + else:
216 + model = PConvUnet()
217 +
218 + # Loading of checkpoint
219 + if args.checkpoint:
220 + if args.stage == "train":
221 + model.load(args.checkpoint)
222 + elif args.stage == "finetune":
223 + model.load(args.checkpoint, train_bn=False, lr=0.00005)
224 +
225 + # Fit model
226 + model.fit_generator(
227 + train_generator,
228 + steps_per_epoch=10000,
229 + validation_data=val_generator,
230 + validation_steps=1000,
231 + epochs=100,
232 + verbose=0,
233 + callbacks=[
234 + TensorBoard(
235 + log_dir=os.path.join(args.log_path, args.name + "_phase1"),
236 + write_graph=False,
237 + ),
238 + ModelCheckpoint(
239 + os.path.join(
240 + args.log_path,
241 + args.name + "_phase1",
242 + "weights.{epoch:02d}-{loss:.2f}.h5",
243 + ),
244 + monitor="val_loss",
245 + save_best_only=True,
246 + save_weights_only=True,
247 + ),
248 + LambdaCallback(
249 + on_epoch_end=lambda epoch, logs: plot_callback(model, args.test_path)
250 + ),
251 + TQDMCallback(),
252 + ],
253 + )
254 +
1 +h5py==2.8.0
2 +Keras==2.2.4
3 +Keras-Applications==1.0.6
4 +Keras-Preprocessing==1.0.5
5 +keras-tqdm==2.0.1
6 +matplotlib==3.0.2
7 +numpy==1.15.4
8 +pandas==0.23.4
9 +scipy==1.1.0
10 +seaborn==0.9.0
11 +tables==3.4.4
12 +tensorboard==1.12.2
13 +tensorflow==1.12.0
14 +tqdm==4.28.1
...\ No newline at end of file ...\ No newline at end of file