이민호

upload 보고서, codes

Showing 156 changed files with 1547 additions and 0 deletions
1 +*.pyc
2 +__pycache__/
3 +/data/
...\ No newline at end of file ...\ No newline at end of file
1 +from __future__ import absolute_import
2 +from __future__ import division
3 +from __future__ import print_function
4 +
5 +import argparse
6 +from functools import partial
7 +from multiprocessing import Pool
8 +import os
9 +import re
10 +
11 +import cropper
12 +import numpy as np
13 +import tqdm
14 +
15 +
16 +# ==============================================================================
17 +# = param =
18 +# ==============================================================================
19 +
20 +parser = argparse.ArgumentParser()
21 +# main
22 +parser.add_argument('--img_dir', dest='img_dir', default='./data/img_celeba')
23 +parser.add_argument('--save_dir', dest='save_dir', default='./data/aligned')
24 +parser.add_argument('--landmark_file', dest='landmark_file', default='./data/landmark.txt')
25 +parser.add_argument('--standard_landmark_file', dest='standard_landmark_file', default='./data/standard_landmark_68pts.txt')
26 +parser.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572)
27 +parser.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572)
28 +parser.add_argument('--move_h', dest='move_h', type=float, default=0.25)
29 +parser.add_argument('--move_w', dest='move_w', type=float, default=0.)
30 +parser.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg')
31 +parser.add_argument('--n_worker', dest='n_worker', type=int, default=8)
32 +# others
33 +parser.add_argument('--face_factor', dest='face_factor', type=float, help='The factor of face area relative to the output image.', default=0.45)
34 +parser.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'], default='similarity')
35 +parser.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5], help='The order of interpolation.', default=3)
36 +parser.add_argument('--mode', dest='mode', choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge')
37 +args = parser.parse_args()
38 +
39 +
40 +# ==============================================================================
41 +# = opencv first =
42 +# ==============================================================================
43 +
44 +_DEAFAULT_JPG_QUALITY = 95
45 +try:
46 + import cv2
47 + imread = cv2.imread
48 + imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEAFAULT_JPG_QUALITY])
49 + align_crop = cropper.align_crop_opencv
50 + print('Use OpenCV')
51 +except:
52 + import skimage.io as io
53 + imread = io.imread
54 + imwrite = partial(io.imsave, quality=_DEAFAULT_JPG_QUALITY)
55 + align_crop = cropper.align_crop_skimage
56 + print('Importing OpenCv fails. Use scikit-image')
57 +
58 +
59 +# ==============================================================================
60 +# = run =
61 +# ==============================================================================
62 +
63 +# count landmarks
64 +with open(args.landmark_file) as f:
65 + line = f.readline()
66 +n_landmark = len(re.split('[ ]+', line)[1:]) // 2
67 +
68 +# load standard landmark
69 +standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=np.float).reshape(n_landmark, 2)
70 +standard_landmark[:, 0] += args.move_w
71 +standard_landmark[:, 1] += args.move_h
72 +
73 +# data dir
74 +save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format))
75 +data_dir = os.path.join(save_dir, 'data')
76 +if not os.path.isdir(data_dir):
77 + os.makedirs(data_dir)
78 +
79 +
80 +def work(name, landmark) -> str: # a single work
81 + for _ in range(3): # try three times
82 + try:
83 + img = imread(os.path.join(args.img_dir, name))
84 + img_crop, tformed_landmarks = align_crop(img,
85 + landmark,
86 + standard_landmark,
87 + crop_size=(args.crop_size_h, args.crop_size_w),
88 + face_factor=args.face_factor,
89 + align_type=args.align_type,
90 + order=args.order,
91 + mode=args.mode)
92 +
93 + name = os.path.splitext(name)[0] + '.' + args.save_format
94 + path = os.path.join(data_dir, name)
95 + if not os.path.isdir(os.path.split(path)[0]):
96 + os.makedirs(os.path.split(path)[0])
97 + imwrite(path, img_crop)
98 +
99 + tformed_landmarks.shape = -1
100 + name_landmark_str = ('%s' + ' %.1f' * n_landmark * 2) % ((name, ) + tuple(tformed_landmarks))
101 + return name_landmark_str
102 + except:
103 + print('%s fails!' % name)
104 +
105 +
106 +if __name__ == "__main__":
107 + img_names = np.genfromtxt(args.landmark_file, dtype=np.str, usecols=0)
108 + landmarks = np.genfromtxt(args.landmark_file, dtype=np.float,
109 + usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2)
110 +
111 + n_pics = len(img_names)
112 +
113 + landmarks_path = os.path.join(save_dir, 'landmark.txt')
114 + f = open(landmarks_path, 'w')
115 + pool = Pool(args.n_worker)
116 + bar = tqdm.tqdm(total=n_pics)
117 +
118 + tasks = []
119 + for i in range(n_pics):
120 + tasks.append(pool.apply_async(work, (img_names[i], landmarks[i]), callback=lambda _: bar.update()))
121 +
122 + try:
123 + result = tasks.pop(0).get()
124 + if result is not None and result != "":
125 + f.write(result + '\n')
126 + except:
127 + pass
128 +
129 + pool.close()
130 + pool.join()
131 + bar.close()
132 + f.close()
...\ No newline at end of file ...\ No newline at end of file
1 +import numpy as np
2 +
3 +
4 +def align_crop_opencv(img,
5 + src_landmarks,
6 + standard_landmarks,
7 + crop_size=512,
8 + face_factor=0.7,
9 + align_type='similarity',
10 + order=3,
11 + mode='edge'):
12 + """Align and crop a face image by landmarks.
13 +
14 + Arguments:
15 + img : Face image to be aligned and cropped.
16 + src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
17 + standard_landmarks : Standard shape, should be normalized.
18 + crop_size : Output image size, should be 1. int for (crop_size, crop_size)
19 + or 2. (int, int) for (crop_size_h, crop_size_w).
20 + face_factor : The factor of face area relative to the output image.
21 + align_type : 'similarity' or 'affine'.
22 + order : The order of interpolation. The order has to be in the range 0-5:
23 + - 0: INTER_NEAREST
24 + - 1: INTER_LINEAR
25 + - 2: INTER_AREA
26 + - 3: INTER_CUBIC
27 + - 4: INTER_LANCZOS4
28 + - 5: INTER_LANCZOS4
29 + mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
30 + Points outside the boundaries of the input are filled according
31 + to the given mode.
32 + """
33 + # set OpenCV
34 + import cv2
35 + inter = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA,
36 + 3: cv2.INTER_CUBIC, 4: cv2.INTER_LANCZOS4, 5: cv2.INTER_LANCZOS4}
37 + border = {'constant': cv2.BORDER_CONSTANT, 'edge': cv2.BORDER_REPLICATE,
38 + 'symmetric': cv2.BORDER_REFLECT, 'reflect': cv2.BORDER_REFLECT101,
39 + 'wrap': cv2.BORDER_WRAP}
40 +
41 + # check
42 + assert align_type in ['affine', 'similarity'], 'Invalid `align_type`! Allowed: %s!' % ['affine', 'similarity']
43 + assert order in [0, 1, 2, 3, 4, 5], 'Invalid `order`! Allowed: %s!' % [0, 1, 2, 3, 4, 5]
44 + assert mode in ['constant', 'edge', 'symmetric', 'reflect', 'wrap'], 'Invalid `mode`! Allowed: %s!' % ['constant', 'edge', 'symmetric', 'reflect', 'wrap']
45 +
46 + # crop size
47 + if isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:
48 + crop_size_h = crop_size[0]
49 + crop_size_w = crop_size[1]
50 + elif isinstance(crop_size, int):
51 + crop_size_h = crop_size_w = crop_size
52 + else:
53 + raise Exception('Invalid `crop_size`! `crop_size` should be 1. int for (crop_size, crop_size) or 2. (int, int) for (crop_size_h, crop_size_w)!')
54 +
55 + # estimate transform matrix
56 + trg_landmarks = standard_landmarks * max(crop_size_h, crop_size_w) * face_factor + np.array([crop_size_w // 2, crop_size_h // 2])
57 + if align_type == 'affine':
58 + tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0]
59 + else:
60 + tform = cv2.estimateAffinePartial2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0]
61 +
62 + # warp image by given transform
63 + output_shape = (crop_size_h, crop_size_w)
64 + img_crop = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + inter[order], borderMode=border[mode])
65 +
66 + # get transformed landmarks
67 + tformed_landmarks = cv2.transform(np.expand_dims(src_landmarks, axis=0), cv2.invertAffineTransform(tform))[0]
68 +
69 + return img_crop, tformed_landmarks
70 +
71 +
72 +def align_crop_skimage(img,
73 + src_landmarks,
74 + standard_landmarks,
75 + crop_size=512,
76 + face_factor=0.7,
77 + align_type='similarity',
78 + order=3,
79 + mode='edge'):
80 + """Align and crop a face image by landmarks.
81 +
82 + Arguments:
83 + img : Face image to be aligned and cropped.
84 + src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
85 + standard_landmarks : Standard shape, should be normalized.
86 + crop_size : Output image size, should be 1. int for (crop_size, crop_size)
87 + or 2. (int, int) for (crop_size_h, crop_size_w).
88 + face_factor : The factor of face area relative to the output image.
89 + align_type : 'similarity' or 'affine'.
90 + order : The order of interpolation. The order has to be in the range 0-5:
91 + - 0: INTER_NEAREST
92 + - 1: INTER_LINEAR
93 + - 2: INTER_AREA
94 + - 3: INTER_CUBIC
95 + - 4: INTER_LANCZOS4
96 + - 5: INTER_LANCZOS4
97 + mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
98 + Points outside the boundaries of the input are filled according
99 + to the given mode.
100 + """
101 + raise NotImplementedError("'align_crop_skimage' is not implemented!")
This diff could not be displayed because it is too large.
1 +# Auto detect text files and perform LF normalization
2 +* text=auto
1 +*.pyc
2 +docs
3 +data
4 +lfw
5 +lfw_40
6 +.idea
7 +loss
8 +vgg_face_dataset
9 +saved_network
10 +loss
11 +z_detect_face.py
12 +z_main.py
13 +*.npy
14 +*.Lnk
15 +data1
16 +data1_masked
17 +scratch.py
18 +subset
19 +subset_masked
20 +vgg_face_dataset
21 +*.mp4
22 +ML_examples
23 +*.pptx
24 +datasets
25 +*.dat
26 +*.docx
27 +
1 +theme: jekyll-theme-cayman
...\ No newline at end of file ...\ No newline at end of file
1 +# Author: aqeelanwar
2 +# Created: 27 April,2020, 10:22 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +import argparse
6 +import dlib
7 +from utils.aux_functions import *
8 +
9 +
10 +# Command-line input setup
11 +parser = argparse.ArgumentParser(
12 + description="MaskTheFace - Python code to mask faces dataset"
13 +)
14 +parser.add_argument(
15 + "--path",
16 + type=str,
17 + default="",
18 + help="Path to either the folder containing images or the image itself",
19 +)
20 +parser.add_argument(
21 + "--mask_type",
22 + type=str,
23 + default="surgical",
24 + choices=["surgical", "N95", "KN95", "cloth", "gas", "inpaint", "random", "all"],
25 + help="Type of the mask to be applied. Available options: all, surgical_blue, surgical_green, N95, cloth",
26 +)
27 +
28 +parser.add_argument(
29 + "--pattern",
30 + type=str,
31 + default="",
32 + help="Type of the pattern. Available options in masks/textures",
33 +)
34 +
35 +parser.add_argument(
36 + "--pattern_weight",
37 + type=float,
38 + default=0.5,
39 + help="Weight of the pattern. Must be between 0 and 1",
40 +)
41 +
42 +parser.add_argument(
43 + "--color",
44 + type=str,
45 + default="#0473e2",
46 + help="Hex color value that need to be overlayed to the mask",
47 +)
48 +
49 +parser.add_argument(
50 + "--color_weight",
51 + type=float,
52 + default=0.5,
53 + help="Weight of the color intensity. Must be between 0 and 1",
54 +)
55 +
56 +parser.add_argument(
57 + "--code",
58 + type=str,
59 + # default="cloth-masks/textures/check/check_4.jpg, cloth-#e54294, cloth-#ff0000, cloth, cloth-masks/textures/others/heart_1.png, cloth-masks/textures/fruits/pineapple.png, N95, surgical_blue, surgical_green",
60 + default="",
61 + help="Generate specific formats",
62 +)
63 +
64 +
65 +parser.add_argument(
66 + "--verbose", dest="verbose", action="store_true", help="Turn verbosity on"
67 +)
68 +parser.add_argument(
69 + "--write_original_image",
70 + dest="write_original_image",
71 + action="store_true",
72 + help="If true, original image is also stored in the masked folder",
73 +)
74 +parser.set_defaults(feature=False)
75 +
76 +args = parser.parse_args()
77 +args.write_path = args.path + "_masked"
78 +
79 +# Set up dlib face detector and predictor
80 +args.detector = dlib.get_frontal_face_detector()
81 +path_to_dlib_model = "dlib_models/shape_predictor_68_face_landmarks.dat"
82 +if not os.path.exists(path_to_dlib_model):
83 + download_dlib_model()
84 +
85 +args.predictor = dlib.shape_predictor(path_to_dlib_model)
86 +
87 +# Extract data from code
88 +mask_code = "".join(args.code.split()).split(",")
89 +args.code_count = np.zeros(len(mask_code))
90 +args.mask_dict_of_dict = {}
91 +
92 +
93 +for i, entry in enumerate(mask_code):
94 + mask_dict = {}
95 + mask_color = ""
96 + mask_texture = ""
97 + mask_type = entry.split("-")[0]
98 + if len(entry.split("-")) == 2:
99 + mask_variation = entry.split("-")[1]
100 + if "#" in mask_variation:
101 + mask_color = mask_variation
102 + else:
103 + mask_texture = mask_variation
104 + mask_dict["type"] = mask_type
105 + mask_dict["color"] = mask_color
106 + mask_dict["texture"] = mask_texture
107 + args.mask_dict_of_dict[i] = mask_dict
108 +
109 +# Check if path is file or directory or none
110 +is_directory, is_file, is_other = check_path(args.path)
111 +display_MaskTheFace()
112 +
113 +if is_directory:
114 + path, dirs, files = os.walk(args.path).__next__()
115 + file_count = len(files)
116 + dirs_count = len(dirs)
117 + if len(files) > 0:
118 + print_orderly("Masking image files", 60)
119 +
120 + # Process files in the directory if any
121 + for f in tqdm(files):
122 + image_path = path + "/" + f
123 +
124 + write_path = path + "_masked"
125 + if not os.path.isdir(write_path):
126 + os.makedirs(write_path)
127 +
128 + if is_image(image_path):
129 + # Proceed if file is image
130 + if args.verbose:
131 + str_p = "Processing: " + image_path
132 + tqdm.write(str_p)
133 +
134 + split_path = f.rsplit(".")
135 + masked_image, mask, mask_binary_array, original_image = mask_image(
136 + image_path, args
137 + )
138 + for i in range(len(mask)):
139 + w_path = (
140 + write_path
141 + + "/"
142 + + split_path[0]
143 + + "_"
144 + + "masked"
145 + + "."
146 + + split_path[1]
147 + )
148 + img = masked_image[i]
149 + binary_img = mask_binary_array[i]
150 + cv2.imwrite(w_path, img)
151 + cv2.imwrite(
152 + path + "_binary/" + split_path[0] + "_binary" + "." + split_path[1],
153 + binary_img,
154 + )
155 + cv2.imwrite(
156 + path + "_original/" + split_path[0] + "." + split_path[1],
157 + original_image,
158 + )
159 +
160 + print_orderly("Masking image directories", 60)
161 +
162 + # Process directories withing the path provided
163 + for d in tqdm(dirs):
164 + dir_path = args.path + "/" + d
165 + dir_write_path = args.write_path + "/" + d
166 + if not os.path.isdir(dir_write_path):
167 + os.makedirs(dir_write_path)
168 + _, _, files = os.walk(dir_path).__next__()
169 +
170 + # Process each files within subdirectory
171 + for f in files:
172 + image_path = dir_path + "/" + f
173 + if args.verbose:
174 + str_p = "Processing: " + image_path
175 + tqdm.write(str_p)
176 + write_path = dir_write_path
177 + if is_image(image_path):
178 + # Proceed if file is image
179 + split_path = f.rsplit(".")
180 + masked_image, mask, mask_binary, original_image = mask_image(
181 + image_path, args
182 + )
183 + for i in range(len(mask)):
184 + w_path = (
185 + write_path
186 + + "/"
187 + + split_path[0]
188 + + "_"
189 + + "masked"
190 + + "."
191 + + split_path[1]
192 + )
193 + w_path_original = write_path + "/" + f
194 + img = masked_image[i]
195 + binary_img = mask_binary[i]
196 + cv2.imwrite(
197 + path
198 + + "_binary/"
199 + + split_path[0]
200 + + "_binary"
201 + + "."
202 + + split_path[1],
203 + binary_img,
204 + )
205 + # Write the masked image
206 + cv2.imwrite(w_path, img)
207 + if args.write_original_image:
208 + # Write the original image
209 + cv2.imwrite(w_path_original, original_image)
210 +
211 + if args.verbose:
212 + print(args.code_count)
213 +
214 +# Process if the path was a file
215 +elif is_file:
216 + print("Masking image file")
217 + image_path = args.path
218 + write_path = args.path.rsplit(".")[0]
219 + if is_image(image_path):
220 + # Proceed if file is image
221 + # masked_images, mask, mask_binary_array, original_image
222 + masked_image, mask, mask_binary_array, original_image = mask_image(
223 + image_path, args
224 + )
225 + for i in range(len(mask)):
226 + w_path = write_path + "_" + "masked" + "." + args.path.rsplit(".")[1]
227 + img = masked_image[i]
228 + binary_img = mask_binary_array[i]
229 + cv2.imwrite(w_path, img)
230 + cv2.imwrite(write_path + "_binary." + args.path.rsplit(".")[1], binary_img)
231 +else:
232 + print("Path is neither a valid file or a valid directory")
233 +print("Processing Done")
1 +[surgical]
2 +template: masks/templates/surgical.png
3 +mask_a: 21, 97
4 +mask_b: 307, 22
5 +mask_c: 600, 99
6 +mask_d: 25, 322
7 +mask_e: 295, 470
8 +mask_f: 600, 323
9 +
10 +[surgical_left]
11 +template: masks/templates/surgical_left.png
12 +mask_a: 39, 27
13 +mask_b: 130, 9
14 +mask_c: 567, 20
15 +mask_d: 87, 207
16 +mask_e: 168, 302
17 +mask_f: 568, 202
18 +
19 +[surgical_right]
20 +template: masks/templates/surgical_right.png
21 +mask_a: 3, 20
22 +mask_b: 440, 9
23 +mask_c: 531, 27
24 +mask_d: 2, 202
25 +mask_e: 402, 302
26 +mask_f: 483, 207
27 +
28 +[surgical_green]
29 +template: masks/templates/surgical_green.png
30 +mask_a: 21, 97
31 +mask_b: 307, 22
32 +mask_c: 600, 99
33 +mask_d: 25, 322
34 +mask_e: 295, 470
35 +mask_f: 600, 323
36 +
37 +[surgical_green_left]
38 +template: masks/templates/surgical_green_left.png
39 +mask_a: 39, 27
40 +mask_b: 130, 9
41 +mask_c: 567, 20
42 +mask_d: 87, 207
43 +mask_e: 168, 302
44 +mask_f: 568, 202
45 +
46 +[surgical_green_right]
47 +template: masks/templates/surgical_green_right.png
48 +mask_a: 3, 20
49 +mask_b: 440, 9
50 +mask_c: 531, 27
51 +mask_d: 2, 202
52 +mask_e: 402, 302
53 +mask_f: 483, 207
54 +
55 +[surgical_blue]
56 +template: masks/templates/surgical_blue.png
57 +mask_a: 21, 97
58 +mask_b: 307, 22
59 +mask_c: 600, 99
60 +mask_d: 25, 322
61 +mask_e: 295, 470
62 +mask_f: 600, 323
63 +
64 +[surgical_blue_left]
65 +template: masks/templates/surgical_blue_left.png
66 +mask_a: 39, 27
67 +mask_b: 130, 9
68 +mask_c: 567, 20
69 +mask_d: 87, 207
70 +mask_e: 168, 302
71 +mask_f: 568, 202
72 +
73 +[surgical_blue_right]
74 +template: masks/templates/surgical_blue_right.png
75 +mask_a: 3, 20
76 +mask_b: 440, 9
77 +mask_c: 531, 27
78 +mask_d: 2, 202
79 +mask_e: 402, 302
80 +mask_f: 483, 207
81 +
82 +
83 +[N95]
84 +template: masks/templates/N95.png
85 +mask_a: 15, 119
86 +mask_b: 327, 5
87 +mask_c: 640, 93
88 +mask_d: 13, 285
89 +mask_e: 351, 518
90 +mask_f: 645, 285
91 +
92 +;[N95_left]
93 +;template: masks/N95_left.png
94 +;mask_a: 176, 121
95 +;mask_b: 313, 46
96 +;mask_c: 799, 135
97 +;mask_d: 97, 438
98 +;mask_e: 329, 627
99 +;mask_f: 791, 401
100 +
101 +[N95_right]
102 +template: masks/templates/N95_right.png
103 +mask_c: 979, 331
104 +mask_b: 806, 172
105 +mask_a: 12, 222
106 +mask_f: 907, 762
107 +mask_e: 577, 875
108 +mask_d: -4, 632
109 +
110 +[N95_left]
111 +template: masks/templates/N95_left.png
112 +mask_a: 193, 331
113 +mask_b: 366, 172
114 +mask_c: 1160, 222
115 +mask_d: 265, 762
116 +mask_e: 595, 875
117 +mask_f: 1176, 632
118 +
119 +
120 +[cloth_left]
121 +template: masks/templates/cloth_left.png
122 +mask_a: 65, 93
123 +mask_b: 162, 15
124 +mask_c: 672, 75
125 +mask_d: 114, 296
126 +mask_e: 207, 443
127 +mask_f: 671, 341
128 +
129 +[cloth_right]
130 +template: masks/templates/cloth_right.png
131 +mask_a: 98, 93
132 +mask_b: 608, 15
133 +mask_c: 705, 75
134 +mask_d: 99, 296
135 +mask_e: 563, 443
136 +mask_f: 656, 341
137 +
138 +[cloth]
139 +template: masks/templates/cloth.png
140 +mask_a: 122, 90
141 +mask_b: 405, 7
142 +mask_c: 686, 79
143 +mask_d: 165, 323
144 +mask_e: 406, 509
145 +mask_f: 653, 311
146 +
147 +[gas]
148 +template: masks/templates/gas.png
149 +mask_a: 330, 431
150 +mask_b: 873, 117
151 +mask_c: 1494, 434
152 +mask_d: 430, 754
153 +mask_e: 869, 1100
154 +mask_f: 1400, 710
155 +
156 +[gas_left]
157 +template: masks/templates/gas_left.png
158 +mask_a: 239, 238
159 +mask_b: 317, 42
160 +mask_c: 965, 239
161 +mask_d: 224, 404
162 +mask_e: 337, 502
163 +mask_f: 963, 406
164 +
165 +[gas_right]
166 +template: masks/templates/gas_right.png
167 +mask_c: 621, 238
168 +mask_b: 543, 60
169 +mask_a: -105, 239
170 +mask_f: 636, 404
171 +mask_e: 523, 502
172 +mask_d: -103, 406
173 +
174 +[KN95]
175 +template: masks/templates/KN95.png
176 +mask_a: 20, 47
177 +mask_b: 410, 5
178 +mask_c: 760, 55
179 +mask_d: 75, 340
180 +mask_e: 398, 600
181 +mask_f: 671, 320
182 +
183 +[KN95_left]
184 +template: masks/templates/KN95_left.png
185 +mask_a: 52, 258
186 +mask_b: 207, 100
187 +mask_c: 730, 80
188 +mask_d: 210, 408
189 +mask_e: 335, 604
190 +mask_f: 770, 270
191 +
192 +[KN95_right]
193 +template: masks/templates/KN95_right.png
194 +mask_c: 664, 258
195 +mask_b: 509, 100
196 +mask_a: -14, 80
197 +mask_f: 506, 408
198 +mask_e: 381, 604
199 +mask_d: -54, 270
200 +
201 +
202 +[empty]
203 +[empty_left]
204 +[empty_right]
205 +
206 +[inpaint]
207 +[inpaint_left]
208 +[inpaint_right]
209 +
1 +certifi==2020.4.5.1
2 +click==7.1.2
3 +dlib==19.19.0
4 +dotmap==1.3.14
5 +face-recognition==1.3.0
6 +face-recognition-models==0.3.0
7 +numpy==1.18.4
8 +opencv-python==4.2.0.34
9 +Pillow==7.1.2
10 +tqdm==4.46.0
11 +wincertstore==0.2
12 +imutils==0.5.3
13 +requests==2.24.0
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 7/30/2020, 7:43 AM
3 +# Email: aqeel.anwar@gatech.edu
...\ No newline at end of file ...\ No newline at end of file
This diff is collapsed. Click to expand it.
1 +# Author: aqeelanwar
2 +# Created: 6 July,2020, 12:14 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from PIL import ImageColor
6 +import cv2
7 +import numpy as np
8 +
9 +COLOR = [
10 + "#fc1c1a",
11 + "#177ABC",
12 + "#94B6D2",
13 + "#A5AB81",
14 + "#DD8047",
15 + "#6b425e",
16 + "#e26d5a",
17 + "#c92c48",
18 + "#6a506d",
19 + "#ffc900",
20 + "#ffffff",
21 + "#000000",
22 + "#49ff00",
23 +]
24 +
25 +
26 +def color_the_mask(mask_image, color, intensity):
27 + assert 0 <= intensity <= 1, "intensity should be between 0 and 1"
28 + RGB_color = ImageColor.getcolor(color, "RGB")
29 + RGB_color = (RGB_color[2], RGB_color[1], RGB_color[0])
30 + orig_shape = mask_image.shape
31 + bit_mask = mask_image[:, :, 3]
32 + mask_image = mask_image[:, :, 0:3]
33 +
34 + color_image = np.full(mask_image.shape, RGB_color, np.uint8)
35 + mask_color = cv2.addWeighted(mask_image, 1 - intensity, color_image, intensity, 0)
36 + mask_color = cv2.bitwise_and(mask_color, mask_color, mask=bit_mask)
37 + colored_mask = np.zeros(orig_shape, dtype=np.uint8)
38 + colored_mask[:, :, 0:3] = mask_color
39 + colored_mask[:, :, 3] = bit_mask
40 + return colored_mask
41 +
42 +
43 +def texture_the_mask(mask_image, texture_path, intensity):
44 + assert 0 <= intensity <= 1, "intensity should be between 0 and 1"
45 + orig_shape = mask_image.shape
46 + bit_mask = mask_image[:, :, 3]
47 + mask_image = mask_image[:, :, 0:3]
48 + texture_image = cv2.imread(texture_path)
49 + texture_image = cv2.resize(texture_image, (orig_shape[1], orig_shape[0]))
50 +
51 + mask_texture = cv2.addWeighted(
52 + mask_image, 1 - intensity, texture_image, intensity, 0
53 + )
54 + mask_texture = cv2.bitwise_and(mask_texture, mask_texture, mask=bit_mask)
55 + textured_mask = np.zeros(orig_shape, dtype=np.uint8)
56 + textured_mask[:, :, 0:3] = mask_texture
57 + textured_mask[:, :, 3] = bit_mask
58 +
59 + return textured_mask
60 +
61 +
62 +
63 +# cloth_mask = cv2.imread("masks/templates/cloth.png", cv2.IMREAD_UNCHANGED)
64 +# # cloth_mask = color_the_mask(cloth_mask, color=COLOR[0], intensity=0.5)
65 +# path = "masks/textures"
66 +# path, dir, files = os.walk(path).__next__()
67 +# first_frame = True
68 +# col_limit = 6
69 +# i = 0
70 +# # img_concat_row=[]
71 +# img_concat = []
72 +# # for f in files:
73 +# # if "._" not in f:
74 +# # print(f)
75 +# # i += 1
76 +# # texture_image = cv2.imread(os.path.join(path, f))
77 +# # m = texture_the_mask(cloth_mask, texture_image, intensity=0.5)
78 +# # if first_frame:
79 +# # img_concat_row = m
80 +# # first_frame = False
81 +# # else:
82 +# # img_concat_row = cv2.hconcat((img_concat_row, m))
83 +# #
84 +# # if i % col_limit == 0:
85 +# # if len(img_concat) > 0:
86 +# # img_concat = cv2.vconcat((img_concat, img_concat_row))
87 +# # else:
88 +# # img_concat = img_concat_row
89 +# # first_frame = True
90 +#
91 +# ## COlor the mask
92 +# thresholds = np.arange(0.1,0.9,0.05)
93 +# for intensity in thresholds:
94 +# c=COLOR[2]
95 +# # intensity = 0.5
96 +# if "._" not in c:
97 +# print(intensity)
98 +# i += 1
99 +# # texture_image = cv2.imread(os.path.join(path, f))
100 +# m = color_the_mask(cloth_mask, c, intensity=intensity)
101 +# if first_frame:
102 +# img_concat_row = m
103 +# first_frame = False
104 +# else:
105 +# img_concat_row = cv2.hconcat((img_concat_row, m))
106 +#
107 +# if i % col_limit == 0:
108 +# if len(img_concat) > 0:
109 +# img_concat = cv2.vconcat((img_concat, img_concat_row))
110 +# else:
111 +# img_concat = img_concat_row
112 +# first_frame = True
113 +#
114 +#
115 +# cv2.imshow("k", img_concat)
116 +# cv2.imwrite("combine_N95_left.png", img_concat)
117 +# cv2.waitKey(0)
118 +# cc = 1
1 + __ __ _ _______ _ ______
2 +| \/ | | |__ __| | | ____|
3 +| \ / | __ _ ___| | _| | | |__ ___| |__ __ _ ___ ___
4 +| |\/| |/ _` / __| |/ / | | '_ \ / _ \ __/ _` |/ __/ _ \
5 +| | | | (_| \__ \ <| | | | | | __/ | | (_| | (_| __/
6 +|_| |_|\__,_|___/_|\_\_| |_| |_|\___|_| \__,_|\___\___|
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 7/30/2020, 1:44 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +# Code resued from https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
6 +# Make sure you run this from parent folder and not from utils folder i.e.
7 +# python utils/fetch_dataset.py
8 +
9 +import requests, os
10 +from zipfile import ZipFile
11 +import argparse
12 +import urllib
13 +
14 +parser = argparse.ArgumentParser(
15 + description="Download dataset - Python code to download associated datasets"
16 +)
17 +parser.add_argument(
18 + "--dataset",
19 + type=str,
20 + default="mfr2",
21 + help="Name of the dataset - Details on available datasets can be found at GitHub Page",
22 +)
23 +args = parser.parse_args()
24 +
25 +
26 +def download_file_from_google_drive(id, destination):
27 + URL = "https://docs.google.com/uc?export=download"
28 +
29 + session = requests.Session()
30 +
31 + response = session.get(URL, params={"id": id}, stream=True)
32 + token = get_confirm_token(response)
33 +
34 + if token:
35 + params = {"id": id, "confirm": token}
36 + response = session.get(URL, params=params, stream=True)
37 +
38 + save_response_content(response, destination)
39 +
40 +
41 +def get_confirm_token(response):
42 + for key, value in response.cookies.items():
43 + if key.startswith("download_warning"):
44 + return value
45 +
46 + return None
47 +
48 +
49 +def save_response_content(response, destination):
50 + CHUNK_SIZE = 32768
51 + print(destination)
52 + with open(destination, "wb") as f:
53 + for chunk in response.iter_content(CHUNK_SIZE):
54 + if chunk: # filter out keep-alive new chunks
55 + f.write(chunk)
56 +
57 +
58 +def download(t_url):
59 + response = urllib.request.urlopen(t_url)
60 + data = response.read()
61 + txt_str = str(data)
62 + lines = txt_str.split("\\n")
63 + return lines
64 +
65 +
66 +def Convert(lst):
67 + it = iter(lst)
68 + res_dct = dict(zip(it, it))
69 + return res_dct
70 +
71 +
72 +if __name__ == "__main__":
73 + # Fetch the latest download_links.txt file from GitHub
74 + link = "https://raw.githubusercontent.com/aqeelanwar/MaskTheFace/master/datasets/download_links.txt"
75 + links_dict = Convert(
76 + download(link)[0]
77 + .replace(":", "\n")
78 + .replace("b'", "")
79 + .replace("'", "")
80 + .replace(" ", "")
81 + .split("\n")
82 + )
83 + file_id = links_dict[args.dataset]
84 + destination = "datasets\_.zip"
85 + print("Downloading: ", args.dataset)
86 + download_file_from_google_drive(file_id, destination)
87 + print("Extracting: ", args.dataset)
88 + with ZipFile(destination, "r") as zipObj:
89 + # Extract all the contents of zip file in current directory
90 + zipObj.extractall(destination.rsplit(os.path.sep, 1)[0])
91 +
92 + os.remove(destination)
1 +# Author: aqeelanwar
2 +# Created: 4 May,2020, 1:30 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +import numpy as np
6 +from numpy.linalg import eig, inv
7 +
8 +def fitEllipse(x,y):
9 + x = x[:,np.newaxis]
10 + y = y[:,np.newaxis]
11 + D = np.hstack((x*x, x*y, y*y, x, y, np.ones_like(x)))
12 + S = np.dot(D.T,D)
13 + C = np.zeros([6,6])
14 + C[0,2] = C[2,0] = 2; C[1,1] = -1
15 + E, V = eig(np.dot(inv(S), C))
16 + n = np.argmax(np.abs(E))
17 + a = V[:,n]
18 + return a
19 +
20 +def ellipse_center(a):
21 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
22 + num = b*b-a*c
23 + x0=(c*d-b*f)/num
24 + y0=(a*f-b*d)/num
25 + return np.array([x0,y0])
26 +
27 +
28 +def ellipse_angle_of_rotation( a ):
29 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
30 + return 0.5*np.arctan(2*b/(a-c))
31 +
32 +
33 +def ellipse_axis_length( a ):
34 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
35 + up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g)
36 + down1=(b*b-a*c)*( (c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
37 + down2=(b*b-a*c)*( (a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a))
38 + res1=np.sqrt(up/down1)
39 + res2=np.sqrt(up/down2)
40 + return np.array([res1, res2])
41 +
42 +def ellipse_angle_of_rotation2( a ):
43 + b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0]
44 + if b == 0:
45 + if a > c:
46 + return 0
47 + else:
48 + return np.pi/2
49 + else:
50 + if a > c:
51 + return np.arctan(2*b/(a-c))/2
52 + else:
53 + return np.pi/2 + np.arctan(2*b/(a-c))/2
54 +
55 +# a = fitEllipse(x,y)
56 +# center = ellipse_center(a)
57 +# #phi = ellipse_angle_of_rotation(a)
58 +# phi = ellipse_angle_of_rotation2(a)
59 +# axes = ellipse_axis_length(a)
60 +#
61 +# print("center = ", center)
62 +# print("angle of rotation = ", phi)
63 +# print("axes = ", axes)
64 +
1 +# Author: aqeelanwar
2 +# Created: 2 May,2020, 2:49 AM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from tkinter import filedialog
6 +from tkinter import *
7 +import cv2, os
8 +
9 +mouse_pts = []
10 +
11 +
12 +def get_mouse_points(event, x, y, flags, param):
13 + global mouseX, mouseY, mouse_pts
14 + if event == cv2.EVENT_LBUTTONDOWN:
15 + mouseX, mouseY = x, y
16 + cv2.circle(mask_im, (x, y), 10, (0, 255, 255), 10)
17 + if "mouse_pts" not in globals():
18 + mouse_pts = []
19 + mouse_pts.append((x, y))
20 + # print("Point detected")
21 + # print((x,y))
22 +
23 +
24 +root = Tk()
25 +filename = filedialog.askopenfilename(
26 + initialdir="/",
27 + title="Select file",
28 + filetypes=(("PNG files", "*.PNG"), ("png files", "*.png"), ("All files", "*.*")),
29 +)
30 +root.destroy()
31 +filename_split = os.path.split(filename)
32 +folder = filename_split[0]
33 +file = filename_split[1]
34 +file_split = file.split(".")
35 +new_filename = folder + "/" + file_split[0] + "_marked." + file_split[-1]
36 +mask_im = cv2.imread(filename)
37 +cv2.namedWindow("Mask")
38 +cv2.setMouseCallback("Mask", get_mouse_points)
39 +
40 +while True:
41 + cv2.imshow("Mask", mask_im)
42 + cv2.waitKey(1)
43 + if len(mouse_pts) == 6:
44 + cv2.destroyWindow("Mask")
45 + break
46 + first_frame_display = False
47 +points = mouse_pts
48 +print(points)
49 +print("----------------------------------------------------------------")
50 +print("Copy the following code and paste it in masks.cfg")
51 +print("----------------------------------------------------------------")
52 +name_points = ["a", "b", "c", "d", "e", "f"]
53 +
54 +mask_title = "[" + file_split[0] + "]"
55 +print(mask_title)
56 +print("template: ", filename)
57 +for i in range(len(mouse_pts)):
58 + name = (
59 + "mask_"
60 + + name_points[i]
61 + + ": "
62 + + str(mouse_pts[i][0])
63 + + ","
64 + + str(mouse_pts[i][1])
65 + )
66 + print(name)
67 +
68 +cv2.imwrite(new_filename, mask_im)
1 +# Author: Aqeel Anwar(ICSRL)
2 +# Created: 9/20/2019, 12:43 PM
3 +# Email: aqeel.anwar@gatech.edu
4 +
5 +from configparser import ConfigParser
6 +from dotmap import DotMap
7 +
8 +
9 +def ConvertIfStringIsInt(input_string):
10 + try:
11 + float(input_string)
12 +
13 + try:
14 + if int(input_string) == float(input_string):
15 + return int(input_string)
16 + else:
17 + return float(input_string)
18 + except ValueError:
19 + return float(input_string)
20 +
21 + except ValueError:
22 + return input_string
23 +
24 +
25 +def read_cfg(config_filename="masks/masks.cfg", mask_type="surgical", verbose=False):
26 + parser = ConfigParser()
27 + parser.optionxform = str
28 + parser.read(config_filename)
29 + cfg = DotMap()
30 + section_name = mask_type
31 +
32 + if verbose:
33 + hyphens = "-" * int((80 - len(config_filename)) / 2)
34 + print(hyphens + " " + config_filename + " " + hyphens)
35 +
36 + # for section_name in parser.sections():
37 +
38 + if verbose:
39 + print("[" + section_name + "]")
40 + for name, value in parser.items(section_name):
41 + value = ConvertIfStringIsInt(value)
42 + if name != "template":
43 + cfg[name] = tuple(int(s) for s in value.split(","))
44 + else:
45 + cfg[name] = value
46 + spaces = " " * (30 - len(name))
47 + if verbose:
48 + print(name + ":" + spaces + str(cfg[name]))
49 +
50 + return cfg
This diff is collapsed. Click to expand it.
1 +from .config import Config
...\ No newline at end of file ...\ No newline at end of file
1 +import yaml
2 +
3 +class Config():
4 + def __init__(self, yaml_path):
5 + yaml_file = open(yaml_path)
6 + self._attr = yaml.load(yaml_file, Loader=yaml.FullLoader)['settings']
7 +
8 + def __getattr__(self, attr):
9 + try:
10 + return self._attr[attr]
11 + except KeyError:
12 + return None
1 +settings:
2 + root_dir: "./datasets/celeba/images/"
3 + checkpoint_path: "weights"
4 + sample_folder: "sample"
5 +
6 + cuda: True
7 + lr: 0.001
8 + batch_size: 2
9 + num_workers: 4
10 +
11 + step_iters: [10000, 15000, 20000]
12 + gamma: 0.1
13 +
14 + d_num_layers: 3
15 +
16 + visualize_per_iter: 500
17 + save_per_iter: 500
18 + print_per_iter: 10
19 + num_epochs: 100
20 +
21 + lambda_G: 1.0
22 + lambda_rec_1: 100.0
23 + lambda_rec_2: 100.0
24 + lambda_per: 10.0
25 +
26 + img_size: 512
1 +settings:
2 + root_dir: "./datasets/places365_10classes"
3 + checkpoint_path: "/content/drive/MyDrive/weights/Places365 Inpainting/phase 3"
4 + sample_folder: "/content/drive/MyDrive/results/Places365 Inpainting/phase 3"
5 +
6 + cuda: True
7 + lr: 0.0001
8 + batch_size: 8
9 + num_workers: 4
10 +
11 + step_iters: [50000, 75000, 100000]
12 + gamma: 0.1
13 +
14 + d_num_layers: 3
15 +
16 + visualize_per_iter: 500
17 + save_per_iter: 500
18 + print_per_iter: 10
19 + num_epochs: 100
20 +
21 + lambda_G: 0.3
22 + lambda_rec_1: 10.0
23 + lambda_rec_2: 10.0
24 + lambda_per: 1.0
25 +
26 + img_size: 256
27 + max_angle: 4
28 + max_len: 50
29 + max_width: 30
30 + times: 15
1 +settings:
2 + root_dir: "./datasets/celeba/images/"
3 + train_anns: "./datasets/celeba/annotations/train.csv"
4 + val_anns: "./datasets/celeba/annotations/val.csv"
5 +
6 + checkpoint_path: "weights" #"/content/drive/MyDrive/weights/Places365 Inpainting/unet/phase 1"
7 + sample_folder: "sample" #"/content/drive/MyDrive/results/Places365 Inpainting/unet/phase 1"
8 +
9 + cuda: True
10 + lr: 0.001
11 + batch_size: 4
12 + num_workers: 4
13 +
14 + step_iters: [50000, 75000, 100000]
15 + gamma: 0.1
16 +
17 + visualize_per_iter: 1000
18 + save_per_iter: 1000
19 + print_per_iter: 10
20 + num_epochs: 100
21 +
22 + img_size: 512
1 +from .dataset import Places365Dataset, FacemaskDataset
2 +from .dataset_seg import FacemaskSegDataset
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
1 +import os
2 +import csv
3 +
4 +f = open("./datasets/celeba/annotations/train.csv", "w", newline="")
5 +wr = csv.writer(f)
6 +wr.writerow(["_", "img_name", "mask_name"])
7 +
8 +for i in range(23304):
9 + wr.writerow(
10 + [
11 + i,
12 + "celeba512_30k_masked/"
13 + + os.listdir("./datasets/celeba/images/celeba512_30k_masked")[i],
14 + "celeba512_30k_binary/"
15 + + os.listdir("./datasets/celeba/images/celeba512_30k_binary")[i],
16 + ]
17 + )
18 +
19 +f.close()
20 +
21 +f = open("./datasets/celeba/annotations/val.csv", "w", newline="")
22 +wr = csv.writer(f)
23 +wr.writerow(["_", "img_name", "mask_name"])
24 +
25 +for i in range(23304, 29131):
26 + wr.writerow(
27 + [
28 + i,
29 + "celeba512_30k_masked/"
30 + + os.listdir("./datasets/celeba/images/celeba512_30k_masked")[i],
31 + "celeba512_30k_binary/"
32 + + os.listdir("./datasets/celeba/images/celeba512_30k_binary")[i],
33 + ]
34 + )
35 +
36 +f.close()
1 +import os
2 +import torch
3 +import torch.nn as nn
4 +import torch.utils.data as data
5 +import cv2
6 +import numpy as np
7 +from tqdm import tqdm
8 +
9 +class Places365Dataset(data.Dataset):
10 + def __init__(self, cfg):
11 + self.root_dir = cfg.root_dir
12 + self.cfg = cfg
13 + self.load_images()
14 +
15 + def load_images(self):
16 + self.fns =[]
17 + idx = 0
18 + img_paths = os.listdir(self.root_dir)
19 + for cls_id in img_paths:
20 + paths = os.path.join(self.root_dir, cls_id)
21 + file_paths = os.listdir(paths)
22 + for img_name in file_paths:
23 + filename = os.path.join(paths, img_name)
24 + self.fns.append(filename)
25 +
26 + def __getitem__(self, index):
27 + img_path = self.fns[index]
28 + img = cv2.imread(img_path)
29 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
30 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
31 +
32 + mask = self.random_ff_mask(
33 + shape = self.cfg.img_size,
34 + max_angle = self.cfg.max_angle,
35 + max_len = self.cfg.max_len,
36 + max_width = self.cfg.max_width,
37 + times = self.cfg.times)
38 +
39 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
40 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
41 +
42 + return img, mask
43 +
44 + def collate_fn(self, batch):
45 + imgs = torch.stack([i[0] for i in batch])
46 + masks = torch.stack([i[1] for i in batch])
47 + return {
48 + 'imgs': imgs,
49 + 'masks': masks
50 + }
51 +
52 + def __len__(self):
53 + return len(self.fns)
54 +
55 + def random_ff_mask(self, shape = 256 , max_angle = 4, max_len = 50, max_width = 20, times = 15):
56 + """Generate a random free form mask with configuration.
57 + Args:
58 + config: Config should have configuration including IMG_SHAPES,
59 + VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
60 + Returns:
61 + tuple: (top, left, height, width)
62 + """
63 + height = shape
64 + width = shape
65 + mask = np.zeros((height, width), np.float32)
66 + times = np.random.randint(10, times)
67 + for i in range(times):
68 + start_x = np.random.randint(width)
69 + start_y = np.random.randint(height)
70 + for j in range(1 + np.random.randint(5)):
71 + angle = 0.01 + np.random.randint(max_angle)
72 + if i % 2 == 0:
73 + angle = 2 * 3.1415926 - angle
74 + length = 10 + np.random.randint(max_len)
75 + brush_w = 5 + np.random.randint(max_width)
76 + end_x = (start_x + length * np.sin(angle)).astype(np.int32)
77 + end_y = (start_y + length * np.cos(angle)).astype(np.int32)
78 + cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
79 + start_x, start_y = end_x, end_y
80 + return mask.reshape((1, ) + mask.shape).astype(np.float32)
81 +
82 +
83 +class FacemaskDataset(data.Dataset):
84 + def __init__(self, cfg):
85 + self.root_dir = cfg.root_dir
86 + self.cfg = cfg
87 +
88 + self.mask_folder = os.path.join(self.root_dir, 'celeba512_30k_binary')
89 + self.img_folder = os.path.join(self.root_dir, 'celeba512_30k')
90 + self.load_images()
91 +
92 + def load_images(self):
93 + self.fns = []
94 + idx = 0
95 + img_paths = sorted(os.listdir(self.img_folder))
96 + for img_name in img_paths:
97 + mask_name = img_name.split('.')[0]+'_binary.jpg'
98 + img_path = os.path.join(self.img_folder, img_name)
99 + mask_path = os.path.join(self.mask_folder, mask_name)
100 + if os.path.isfile(mask_path):
101 + self.fns.append([img_path, mask_path])
102 +
103 + def __getitem__(self, index):
104 + img_path, mask_path = self.fns[index]
105 + img = cv2.imread(img_path)
106 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
107 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
108 +
109 +
110 + mask = cv2.imread(mask_path, 0)
111 +
112 + mask[mask>0]=1.0
113 + mask = np.expand_dims(mask, axis=0)
114 +
115 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
116 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
117 + return img, mask
118 +
119 + def collate_fn(self, batch):
120 + imgs = torch.stack([i[0] for i in batch])
121 + masks = torch.stack([i[1] for i in batch])
122 + return {
123 + 'imgs': imgs,
124 + 'masks': masks
125 + }
126 +
127 + def __len__(self):
128 + return len(self.fns)
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import torch
3 +import torch.nn as nn
4 +import torch.utils.data as data
5 +import cv2
6 +import numpy as np
7 +from tqdm import tqdm
8 +import pandas as pd
9 +from PIL import Image
10 +
11 +
12 +class FacemaskSegDataset(data.Dataset):
13 + def __init__(self, cfg, train=True):
14 + self.root_dir = cfg.root_dir
15 + self.cfg = cfg
16 + self.train = train
17 +
18 + if self.train:
19 + self.df = pd.read_csv(cfg.train_anns)
20 + else:
21 + self.df = pd.read_csv(cfg.val_anns)
22 +
23 + self.load_images()
24 +
25 + def load_images(self):
26 + self.fns = []
27 + for idx, rows in self.df.iterrows():
28 + _, img_name, mask_name = rows
29 + img_path = os.path.join(self.root_dir, img_name)
30 + mask_path = os.path.join(self.root_dir, mask_name)
31 + img_path = img_path.replace("\\", "/")
32 + mask_path = mask_path.replace("\\", "/")
33 + if os.path.isfile(mask_path):
34 + self.fns.append([img_path, mask_path])
35 +
36 + def __getitem__(self, index):
37 + img_path, mask_path = self.fns[index]
38 + img = cv2.imread(img_path)
39 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
41 + mask = cv2.imread(mask_path, 0)
42 + mask[mask > 0] = 1.0
43 + mask = np.expand_dims(mask, axis=0)
44 +
45 + img = (
46 + torch.from_numpy(img.astype(np.float32) / 255.0)
47 + .permute(2, 0, 1)
48 + .contiguous()
49 + )
50 + mask = torch.from_numpy(mask.astype(np.float32)).contiguous()
51 + return img, mask
52 +
53 + def collate_fn(self, batch):
54 + imgs = torch.stack([i[0] for i in batch])
55 + masks = torch.stack([i[1] for i in batch])
56 + return {"imgs": imgs, "masks": masks}
57 +
58 + def __len__(self):
59 + return len(self.fns)
1 +import torch
2 +import torch.nn as nn
3 +from torchvision.utils import save_image
4 +
5 +import numpy as np
6 +from PIL import Image
7 +import cv2
8 +from models import UNetSemantic, GatedGenerator
9 +import argparse
10 +from configs import Config
11 +
12 +class Predictor():
13 + def __init__(self, cfg):
14 + self.cfg = cfg
15 + self.device = torch.device('cuda:0' if cfg.cuda else 'cpu')
16 + self.masking = UNetSemantic().to(self.device)
17 + self.masking.load_state_dict(torch.load('weights\model_segm_19_135000.pth', map_location='cpu'))
18 + #self.masking.eval()
19 +
20 + self.inpaint = GatedGenerator().to(self.device)
21 + self.inpaint.load_state_dict(torch.load('weights/model_6_100000.pth', map_location='cpu')['G'])
22 + self.inpaint.eval()
23 +
24 + def save_image(self, img_list, save_img_path, nrow):
25 + img_list = [i.clone().cpu() for i in img_list]
26 + imgs = torch.stack(img_list, dim=1)
27 + imgs = imgs.view(-1, *list(imgs.size())[2:])
28 + save_image(imgs, save_img_path, nrow = nrow)
29 + print(f"Save image to {save_img_path}")
30 +
31 + def predict(self, image, outpath='sample/results.png'):
32 + outpath=f'sample/results_{image}.png'
33 + image = 'sample/'+image
34 + img = cv2.imread(image+'_masked.jpg')
35 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
36 + img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
37 + img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
38 + img = img.unsqueeze(0).to(self.device)
39 +
40 + img_ori = cv2.imread(image+'.jpg')
41 + img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
42 + img_ori = cv2.resize(img_ori, (self.cfg.img_size, self.cfg.img_size))
43 + img_ori = torch.from_numpy(img_ori.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
44 + img_ori = img_ori.unsqueeze(0)
45 + with torch.no_grad():
46 + outputs = self.masking(img)
47 + _, out = self.inpaint(img, outputs)
48 + inpaint = img * (1 - outputs) + out * outputs
49 + masks = img * (1 - outputs) + outputs #torch.cat([outputs, outputs, outputs], dim=1)
50 +
51 +
52 +
53 + self.save_image([img, masks, inpaint, img_ori], outpath, nrow=4)
54 +
55 +
56 +
57 +
58 +if __name__ == '__main__':
59 + parser = argparse.ArgumentParser(description='Training custom model')
60 + parser.add_argument('--image', default=None, type=str, help='resume training')
61 + parser.add_argument('config', default='config', type=str, help='config training')
62 + args = parser.parse_args()
63 +
64 + config = Config(f'./configs/{args.config}.yaml')
65 +
66 +
67 + model = Predictor(config)
68 + model.predict(args.image)
...\ No newline at end of file ...\ No newline at end of file
1 +from .loggers import *
...\ No newline at end of file ...\ No newline at end of file
1 +import os
2 +import numpy as np
3 +from torch.utils.tensorboard import SummaryWriter
4 +from datetime import datetime
5 +
6 +class Logger():
7 + """
8 + Logger for Tensorboard visualization
9 + :param log_dir: Path to save checkpoint
10 + """
11 + def __init__(self, log_dir=None):
12 + self.log_dir = log_dir
13 + if self.log_dir is None:
14 + self.log_dir = os.path.join('loggers/runs',datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
15 + if not os.path.exists(self.log_dir):
16 + os.mkdir(self.log_dir)
17 + self.writer = SummaryWriter(log_dir=self.log_dir)
18 + self.iters = {}
19 +
20 + def write(self, tags, values):
21 + """
22 + Write a log to specified directory
23 + :param tags: (str) tag for log
24 + :param values: (number) value for corresponding tag
25 + """
26 + if not isinstance(tags, list):
27 + tags = list(tags)
28 + if not isinstance(values, list):
29 + values = list(values)
30 +
31 + for i, (tag, value) in enumerate(zip(tags,values)):
32 + if tag not in self.iters.keys():
33 + self.iters[tag] = 0
34 + self.writer.add_scalar(tag, value, self.iters[tag])
35 + self.iters[tag] += 1
36 +
37 +
1 +from .adversarial import GANLoss
2 +from .ssim import SSIM
3 +from .dice import DiceLoss
...\ No newline at end of file ...\ No newline at end of file
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff could not be displayed because it is too large.
This diff is collapsed. Click to expand it.