Showing
6 changed files
with
1402 additions
and
78 deletions
... | @@ -151,7 +151,7 @@ def compute_psnr(a, b): | ... | @@ -151,7 +151,7 @@ def compute_psnr(a, b): |
151 | mse = torch.mean((a - b)**2).item() | 151 | mse = torch.mean((a - b)**2).item() |
152 | return -10 * math.log10(mse) | 152 | return -10 * math.log10(mse) |
153 | 153 | ||
154 | -def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): | 154 | +def _encode(seq, path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): |
155 | compressai.set_entropy_coder(coder) | 155 | compressai.set_entropy_coder(coder) |
156 | enc_start = time.time() | 156 | enc_start = time.time() |
157 | 157 | ||
... | @@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou | ... | @@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou |
182 | strings.append([s[0]]) | 182 | strings.append([s[0]]) |
183 | 183 | ||
184 | with torch.no_grad(): | 184 | with torch.no_grad(): |
185 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 185 | + recon_out = net.decompress(strings, shape) |
186 | x_recon = crop(recon_out["x_hat"], (h, w)) | 186 | x_recon = crop(recon_out["x_hat"], (h, w)) |
187 | 187 | ||
188 | psnr=compute_psnr(x, x_recon) | 188 | psnr=compute_psnr(x, x_recon) |
189 | 189 | ||
190 | - if i==False: | 190 | + #if i==False: |
191 | - diff=x-ref | 191 | + # diff=x-ref |
192 | - diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | 192 | + # diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 |
193 | - diff_img = torch2img(diff1) | 193 | + # diff_img = torch2img(diff1) |
194 | - diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png") | 194 | + # diff_img.save("../Data/train/"+seq+str(ff)+"_train_v1_q"+str(quality)+".png") |
195 | 195 | ||
196 | enc_time = time.time() - enc_start | 196 | enc_time = time.time() - enc_start |
197 | size = filesize(output) | 197 | size = filesize(output) |
... | @@ -336,15 +336,15 @@ def encode(argv): | ... | @@ -336,15 +336,15 @@ def encode(argv): |
336 | total_psnr=0.0 | 336 | total_psnr=0.0 |
337 | total_bpp=0.0 | 337 | total_bpp=0.0 |
338 | total_time=0.0 | 338 | total_time=0.0 |
339 | - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | 339 | + img_path =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" |
340 | - img=args.image+"_frame"+str(0)+".png" | 340 | + img=img_path+"_frame"+str(0)+".png" |
341 | - total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | 341 | + total_psnr, total_bpp, ref,total_time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) |
342 | for ff in range(1, args.frame): | 342 | for ff in range(1, args.frame): |
343 | with Path(log_path).open("a") as f: | 343 | with Path(log_path).open("a") as f: |
344 | f.write(f" {ff:3d} | ") | 344 | f.write(f" {ff:3d} | ") |
345 | - img=args.image+"_frame"+str(ff)+".png" | 345 | + img=img_path+"_frame"+str(ff)+".png" |
346 | 346 | ||
347 | - psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | 347 | + psnr, total_bpp, ref,time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) |
348 | total_psnr+=psnr | 348 | total_psnr+=psnr |
349 | total_time+=time | 349 | total_time+=time |
350 | 350 | ... | ... |
... | @@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
213 | strings.append([s[0]]) | 213 | strings.append([s[0]]) |
214 | 214 | ||
215 | with torch.no_grad(): | 215 | with torch.no_grad(): |
216 | - recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | 216 | + recon_out1 = net.decompress(strings,shape) |
217 | x_hat1 = crop(recon_out1["x_hat"], (h, w)) | 217 | x_hat1 = crop(recon_out1["x_hat"], (h, w)) |
218 | 218 | ||
219 | with torch.no_grad(): | 219 | with torch.no_grad(): |
... | @@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
231 | strings.append([s[0]]) | 231 | strings.append([s[0]]) |
232 | 232 | ||
233 | with torch.no_grad(): | 233 | with torch.no_grad(): |
234 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 234 | + recon_out = net.decompress(strings, shape) |
235 | x_hat2 = crop(recon_out["x_hat"], (h, w)) | 235 | x_hat2 = crop(recon_out["x_hat"], (h, w)) |
236 | x_recon=ref+x_hat1-x_hat2 | 236 | x_recon=ref+x_hat1-x_hat2 |
237 | 237 | ... | ... |
... | @@ -17,6 +17,7 @@ import struct | ... | @@ -17,6 +17,7 @@ import struct |
17 | import sys | 17 | import sys |
18 | import time | 18 | import time |
19 | import math | 19 | import math |
20 | +from pytorch_msssim import ms_ssim | ||
20 | 21 | ||
21 | from pathlib import Path | 22 | from pathlib import Path |
22 | 23 | ||
... | @@ -27,7 +28,12 @@ from PIL import Image | ... | @@ -27,7 +28,12 @@ from PIL import Image |
27 | from torchvision.transforms import ToPILImage, ToTensor | 28 | from torchvision.transforms import ToPILImage, ToTensor |
28 | 29 | ||
29 | import compressai | 30 | import compressai |
30 | - | 31 | +from compressai.transforms.functional import ( |
32 | + rgb2ycbcr, | ||
33 | + ycbcr2rgb, | ||
34 | + yuv_420_to_444, | ||
35 | + yuv_444_to_420, | ||
36 | +) | ||
31 | from compressai.zoo import models | 37 | from compressai.zoo import models |
32 | 38 | ||
33 | model_ids = {k: i for i, k in enumerate(models.keys())} | 39 | model_ids = {k: i for i, k in enumerate(models.keys())} |
... | @@ -151,13 +157,28 @@ def compute_psnr(a, b): | ... | @@ -151,13 +157,28 @@ def compute_psnr(a, b): |
151 | mse = torch.mean((a - b)**2).item() | 157 | mse = torch.mean((a - b)**2).item() |
152 | return -10 * math.log10(mse) | 158 | return -10 * math.log10(mse) |
153 | 159 | ||
154 | -def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | 160 | +def compute_msssim(a, b): |
161 | + return ms_ssim(a, b, data_range=1.).item() | ||
162 | + | ||
163 | +def ycbcr_psnr(a, b): | ||
164 | + yuv_a=rgb2ycbcr(a) | ||
165 | + yuv_b=rgb2ycbcr(b) | ||
166 | + a_y, a_cb, a_cr = yuv_a.chunk(3, -3) | ||
167 | + b_y, b_cb, b_cr = yuv_b.chunk(3, -3) | ||
168 | + y=compute_psnr(a_y, b_y) | ||
169 | + cb=compute_psnr(a_cb, b_cb) | ||
170 | + cr=compute_psnr(a_cr, b_cr) | ||
171 | + return (4*y+cb+cr)/6 | ||
172 | + | ||
173 | +def _encode(checkpoint, path, seq, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | ||
155 | compressai.set_entropy_coder(coder) | 174 | compressai.set_entropy_coder(coder) |
156 | enc_start = time.time() | 175 | enc_start = time.time() |
157 | 176 | ||
158 | - img = load_image(image) | 177 | + img = load_image(image+"_frame"+str(ff)+".png") |
159 | start = time.time() | 178 | start = time.time() |
160 | - net = models[model](quality=quality, metric=metric, pretrained=True).eval() | 179 | + net = models[model](quality=quality, metric=metric, pretrained=True) |
180 | + | ||
181 | + net.eval() | ||
161 | load_time = time.time() - start | 182 | load_time = time.time() - start |
162 | 183 | ||
163 | x = img2torch(img) | 184 | x = img2torch(img) |
... | @@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
182 | strings.append([s[0]]) | 203 | strings.append([s[0]]) |
183 | 204 | ||
184 | with torch.no_grad(): | 205 | with torch.no_grad(): |
185 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 206 | + recon_out = net.decompress(strings,shape) |
186 | x_recon = crop(recon_out["x_hat"], (h, w)) | 207 | x_recon = crop(recon_out["x_hat"], (h, w)) |
187 | 208 | ||
188 | psnr=compute_psnr(x, x_recon) | 209 | psnr=compute_psnr(x, x_recon) |
210 | + ssim=compute_msssim(x, x_recon) | ||
211 | + ycbcr=ycbcr_psnr(x, x_recon) | ||
212 | + else: | ||
213 | + if checkpoint: # load from previous checkpoint | ||
214 | + checkpoint = torch.load(checkpoint) | ||
215 | + #state_dict = load_state_dict(checkpoint["state_dict"]) | ||
216 | + net=models[model](quality=quality, metric=metric) | ||
217 | + net.load_state_dict(checkpoint["state_dict"]) | ||
218 | + net.update(force=True) | ||
189 | else: | 219 | else: |
220 | + net = models[model](quality=quality, metric=metric, pretrained=True) | ||
221 | + | ||
190 | diff=x-ref | 222 | diff=x-ref |
191 | - #1 | ||
192 | diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | 223 | diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 |
193 | 224 | ||
194 | - #2 | ||
195 | - ''' | ||
196 | - diff1=torch.clamp(diff, min=0.0, max=1.0) | ||
197 | - diff2=-torch.clamp(diff, min=-1.0, max=0.0) | ||
198 | - | ||
199 | - diff1=pad(diff1, p) | ||
200 | - diff2=pad(diff2, p) | ||
201 | - ''' | ||
202 | - #1 | ||
203 | - | ||
204 | - with torch.no_grad(): | ||
205 | - out1 = net.compress(diff1) | ||
206 | - shape1 = out1["shape"] | ||
207 | - strings = [] | ||
208 | - | ||
209 | - with Path(output).open("ab") as f: | ||
210 | - # write shape and number of encoded latents | ||
211 | - write_uints(f, (shape1[0], shape1[1], len(out1["strings"]))) | ||
212 | - | ||
213 | - for s in out1["strings"]: | ||
214 | - write_uints(f, (len(s[0]),)) | ||
215 | - write_bytes(f, s[0]) | ||
216 | - strings.append([s[0]]) | ||
217 | - | ||
218 | - with torch.no_grad(): | ||
219 | - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | ||
220 | - x_hat1 = crop(recon_out["x_hat"], (h, w)) | ||
221 | 225 | ||
222 | - #2 | ||
223 | - ''' | ||
224 | with torch.no_grad(): | 226 | with torch.no_grad(): |
225 | out1 = net.compress(diff1) | 227 | out1 = net.compress(diff1) |
226 | shape1 = out1["shape"] | 228 | shape1 = out1["shape"] |
... | @@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
236 | strings.append([s[0]]) | 238 | strings.append([s[0]]) |
237 | 239 | ||
238 | with torch.no_grad(): | 240 | with torch.no_grad(): |
239 | - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | 241 | + recon_out = net.decompress(strings, shape1) |
240 | x_hat1 = crop(recon_out["x_hat"], (h, w)) | 242 | x_hat1 = crop(recon_out["x_hat"], (h, w)) |
241 | - with torch.no_grad(): | ||
242 | - out = net.compress(diff2) | ||
243 | - shape = out["shape"] | ||
244 | - strings = [] | ||
245 | 243 | ||
246 | - with Path(output).open("ab") as f: | ||
247 | - # write shape and number of encoded latents | ||
248 | - write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
249 | - | ||
250 | - for s in out["strings"]: | ||
251 | - write_uints(f, (len(s[0]),)) | ||
252 | - write_bytes(f, s[0]) | ||
253 | - strings.append([s[0]]) | ||
254 | - | ||
255 | - with torch.no_grad(): | ||
256 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
257 | - x_hat2 = crop(recon_out["x_hat"], (h, w)) | ||
258 | - x_recon=ref+x_hat1-x_hat2 | ||
259 | - ''' | ||
260 | 244 | ||
261 | x_recon=ref+x_hat1-0.5 | 245 | x_recon=ref+x_hat1-0.5 |
262 | psnr=compute_psnr(x, x_recon) | 246 | psnr=compute_psnr(x, x_recon) |
247 | + ssim=compute_msssim(x, x_recon) | ||
248 | + ycbcr=ycbcr_psnr(x, x_recon) | ||
263 | diff_img = torch2img(diff1) | 249 | diff_img = torch2img(diff1) |
264 | - diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png") | 250 | +# diff_img.save(path+"recon/"+seq+str(ff)+"_q"+str(quality)+".png") |
251 | +# diff_img.save("../Data/train/"+seq+str(ff)+"_train8_q"+str(quality)+".png") | ||
265 | 252 | ||
266 | enc_time = time.time() - enc_start | 253 | enc_time = time.time() - enc_start |
267 | size = filesize(output) | 254 | size = filesize(output) |
... | @@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
269 | with Path(log_path).open("a") as f: | 256 | with Path(log_path).open("a") as f: |
270 | f.write( f" {bpp-total_bpp:.4f} | " | 257 | f.write( f" {bpp-total_bpp:.4f} | " |
271 | f" {psnr:.4f} |" | 258 | f" {psnr:.4f} |" |
259 | + f" {ssim:.4f} |" | ||
260 | + f" {ycbcr:.4f} |" | ||
272 | f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") | 261 | f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") |
273 | recon_img = torch2img(x_recon) | 262 | recon_img = torch2img(x_recon) |
274 | recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") | 263 | recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") |
275 | 264 | ||
276 | - return psnr, bpp, x_recon, enc_time | 265 | + return psnr, bpp, x_recon, enc_time, ssim, ycbcr |
277 | 266 | ||
278 | 267 | ||
279 | def _decode(inputpath, coder, show, frame, output=None): | 268 | def _decode(inputpath, coder, show, frame, output=None): |
... | @@ -381,13 +370,19 @@ def encode(argv): | ... | @@ -381,13 +370,19 @@ def encode(argv): |
381 | default=768, | 370 | default=768, |
382 | help="hight setting (default: %(default))", | 371 | help="hight setting (default: %(default))", |
383 | ) | 372 | ) |
373 | + parser.add_argument( | ||
374 | + "-check", | ||
375 | + "--checkpoint", | ||
376 | + type=str, | ||
377 | + help="Path to a checkpoint", | ||
378 | + ) | ||
384 | parser.add_argument("-o", "--output", help="Output path") | 379 | parser.add_argument("-o", "--output", help="Output path") |
385 | args = parser.parse_args(argv) | 380 | args = parser.parse_args(argv) |
386 | path="examples/"+args.image+"/" | 381 | path="examples/"+args.image+"/" |
387 | if not args.output: | 382 | if not args.output: |
388 | #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") | 383 | #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") |
389 | - args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin" | 384 | + args.output = path+args.image+"_q"+str(args.quality)+"_train_ssim.bin" |
390 | - log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt" | 385 | + log_path=path+args.image+"_q"+str(args.quality)+"_train_ssim.txt" |
391 | 386 | ||
392 | header = get_header(args.model, args.metric, args.quality) | 387 | header = get_header(args.model, args.metric, args.quality) |
393 | with Path(args.output).open("wb") as f: | 388 | with Path(args.output).open("wb") as f: |
... | @@ -400,32 +395,43 @@ def encode(argv): | ... | @@ -400,32 +395,43 @@ def encode(argv): |
400 | f"frames : {args.frame}\n") | 395 | f"frames : {args.frame}\n") |
401 | f.write( f"frame | bpp | " | 396 | f.write( f"frame | bpp | " |
402 | f" psnr |" | 397 | f" psnr |" |
398 | + f" ssim |" | ||
403 | f" Encoded time (model loading)\n" | 399 | f" Encoded time (model loading)\n" |
404 | f" {0:3d} | ") | 400 | f" {0:3d} | ") |
405 | 401 | ||
406 | total_psnr=0.0 | 402 | total_psnr=0.0 |
403 | + total_ssim=0.0 | ||
404 | + total_ycbcr=0.0 | ||
407 | total_bpp=0.0 | 405 | total_bpp=0.0 |
408 | total_time=0.0 | 406 | total_time=0.0 |
409 | - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | 407 | + img =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" |
410 | - img=args.image+"_frame"+str(0)+".png" | 408 | + total_psnr, total_bpp, ref, total_time, total_ssim, total_ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) |
411 | - total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | ||
412 | for ff in range(1, args.frame): | 409 | for ff in range(1, args.frame): |
413 | with Path(log_path).open("a") as f: | 410 | with Path(log_path).open("a") as f: |
414 | f.write(f" {ff:3d} | ") | 411 | f.write(f" {ff:3d} | ") |
415 | - img=args.image+"_frame"+str(ff)+".png" | 412 | + if ff%25==0: |
416 | - | 413 | + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, ref, total_bpp, ff, args.output, log_path) |
417 | - psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | 414 | + else: |
415 | + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | ||
418 | total_psnr+=psnr | 416 | total_psnr+=psnr |
417 | + total_ssim+=ssim | ||
418 | + total_ycbcr+=ycbcr | ||
419 | total_time+=time | 419 | total_time+=time |
420 | 420 | ||
421 | total_psnr/=args.frame | 421 | total_psnr/=args.frame |
422 | + total_ssim/=args.frame | ||
423 | + total_ycbcr/=args.frame | ||
422 | total_bpp/=args.frame | 424 | total_bpp/=args.frame |
423 | 425 | ||
424 | with Path(log_path).open("a") as f: | 426 | with Path(log_path).open("a") as f: |
425 | f.write( f"\n Total Encoded time: {total_time:.2f}s\n" | 427 | f.write( f"\n Total Encoded time: {total_time:.2f}s\n" |
426 | f"\n Total PSNR: {total_psnr:.6f}\n" | 428 | f"\n Total PSNR: {total_psnr:.6f}\n" |
429 | + f"\n Total SSIM: {total_ssim:.6f}\n" | ||
430 | + f"\n Total ycbcr: {total_ycbcr:.6f}\n" | ||
427 | f" Total BPP: {total_bpp:.6f}\n") | 431 | f" Total BPP: {total_bpp:.6f}\n") |
428 | print(total_psnr) | 432 | print(total_psnr) |
433 | + print(total_ssim) | ||
434 | + print(total_ycbcr) | ||
429 | print(total_bpp) | 435 | print(total_bpp) |
430 | 436 | ||
431 | 437 | ... | ... |
Our Encoder/train_RGB.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import math | ||
17 | +import random | ||
18 | +import shutil | ||
19 | +import sys | ||
20 | +import time | ||
21 | + | ||
22 | +import torch | ||
23 | +import torch.nn as nn | ||
24 | +import torch.optim as optim | ||
25 | + | ||
26 | +from torch.utils.data import DataLoader | ||
27 | +from torchvision import transforms | ||
28 | + | ||
29 | +from compressai.datasets import ImageFolder | ||
30 | +from compressai.zoo import models | ||
31 | +import csv | ||
32 | +import cv2 | ||
33 | +import numpy as np | ||
34 | + | ||
35 | +class RateDistortionLoss(nn.Module): | ||
36 | + """Custom rate distortion loss with a Lagrangian parameter.""" | ||
37 | + | ||
38 | + def __init__(self, lmbda=1e-2): | ||
39 | + super().__init__() | ||
40 | + self.mse = nn.MSELoss() | ||
41 | + self.lmbda = lmbda | ||
42 | + | ||
43 | + def forward(self, output, target): | ||
44 | + N, _, H, W = target.size() | ||
45 | + out = {} | ||
46 | + num_pixels = N * H * W | ||
47 | + | ||
48 | + out["bpp_loss"] = sum( | ||
49 | + (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) | ||
50 | + for likelihoods in output["likelihoods"].values() | ||
51 | + ) | ||
52 | + out["mse_loss"] = self.mse(output["x_hat"], target) | ||
53 | + out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"] | ||
54 | + | ||
55 | + return out | ||
56 | + | ||
57 | + | ||
58 | +class AverageMeter: | ||
59 | + """Compute running average.""" | ||
60 | + | ||
61 | + def __init__(self): | ||
62 | + self.val = 0 | ||
63 | + self.avg = 0 | ||
64 | + self.sum = 0 | ||
65 | + self.count = 0 | ||
66 | + | ||
67 | + def update(self, val, n=1): | ||
68 | + self.val = val | ||
69 | + self.sum += val * n | ||
70 | + self.count += n | ||
71 | + self.avg = self.sum / self.count | ||
72 | + | ||
73 | + | ||
74 | +class CustomDataParallel(nn.DataParallel): | ||
75 | + """Custom DataParallel to access the module methods.""" | ||
76 | + | ||
77 | + def __getattr__(self, key): | ||
78 | + try: | ||
79 | + return super().__getattr__(key) | ||
80 | + except AttributeError: | ||
81 | + return getattr(self.module, key) | ||
82 | + | ||
83 | + | ||
84 | +def configure_optimizers(net, args): | ||
85 | + """Separate parameters for the main optimizer and the auxiliary optimizer. | ||
86 | + Return two optimizers""" | ||
87 | + | ||
88 | + parameters = set( | ||
89 | + n | ||
90 | + for n, p in net.named_parameters() | ||
91 | + if not n.endswith(".quantiles") and p.requires_grad | ||
92 | + ) | ||
93 | + aux_parameters = set( | ||
94 | + n | ||
95 | + for n, p in net.named_parameters() | ||
96 | + if n.endswith(".quantiles") and p.requires_grad | ||
97 | + ) | ||
98 | + | ||
99 | + # Make sure we don't have an intersection of parameters | ||
100 | + params_dict = dict(net.named_parameters()) | ||
101 | + inter_params = parameters & aux_parameters | ||
102 | + union_params = parameters | aux_parameters | ||
103 | + | ||
104 | + assert len(inter_params) == 0 | ||
105 | + assert len(union_params) - len(params_dict.keys()) == 0 | ||
106 | + | ||
107 | + optimizer = optim.Adam( | ||
108 | + (params_dict[n] for n in sorted(list(parameters))), | ||
109 | + lr=args.learning_rate, | ||
110 | + ) | ||
111 | + aux_optimizer = optim.Adam( | ||
112 | + (params_dict[n] for n in sorted(list(aux_parameters))), | ||
113 | + lr=args.aux_learning_rate, | ||
114 | + ) | ||
115 | + return optimizer, aux_optimizer | ||
116 | + | ||
117 | + | ||
118 | +def train_one_epoch( | ||
119 | + model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm | ||
120 | +): | ||
121 | + model.train() | ||
122 | + device = next(model.parameters()).device | ||
123 | + | ||
124 | + loss = AverageMeter() | ||
125 | + bpp_loss = AverageMeter() | ||
126 | + mse_loss = AverageMeter() | ||
127 | + a_aux_loss = AverageMeter() | ||
128 | + | ||
129 | + for i, d in enumerate(train_dataloader): | ||
130 | + d = d.to(device) | ||
131 | + | ||
132 | + optimizer.zero_grad() | ||
133 | + aux_optimizer.zero_grad() | ||
134 | + | ||
135 | + out_net = model(d) | ||
136 | + | ||
137 | + out_criterion = criterion(out_net, d) | ||
138 | + | ||
139 | + bpp_loss.update(out_criterion["bpp_loss"]) | ||
140 | + loss.update(out_criterion["loss"]) | ||
141 | + mse_loss.update(out_criterion["mse_loss"]) | ||
142 | + | ||
143 | + out_criterion["loss"].backward() | ||
144 | + | ||
145 | + if clip_max_norm > 0: | ||
146 | + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm) | ||
147 | + optimizer.step() | ||
148 | + | ||
149 | + aux_loss = model.aux_loss() | ||
150 | + a_aux_loss.update(aux_loss) | ||
151 | + aux_loss.backward() | ||
152 | + aux_optimizer.step() | ||
153 | + | ||
154 | + if i % 10 == 0: | ||
155 | + print( | ||
156 | + f"Train epoch {epoch}: [" | ||
157 | + f"{i*len(d)}/{len(train_dataloader.dataset)}" | ||
158 | + f" ({100. * i / len(train_dataloader):.0f}%)]" | ||
159 | + f'\tLoss: {out_criterion["loss"].item():.3f} |' | ||
160 | + f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |' | ||
161 | + f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |' | ||
162 | + f"\tAux loss: {aux_loss.item():.2f}" | ||
163 | + ) | ||
164 | + return loss.avg, bpp_loss.avg, a_aux_loss.avg | ||
165 | + | ||
166 | + | ||
167 | +def test_epoch(epoch, test_dataloader, model, criterion): | ||
168 | + model.eval() | ||
169 | + device = next(model.parameters()).device | ||
170 | + | ||
171 | + loss = AverageMeter() | ||
172 | + bpp_loss = AverageMeter() | ||
173 | + mse_loss = AverageMeter() | ||
174 | + aux_loss = AverageMeter() | ||
175 | + | ||
176 | + with torch.no_grad(): | ||
177 | + for d in test_dataloader: | ||
178 | + d = d.to(device) | ||
179 | + out_net = model(d) | ||
180 | + out_criterion = criterion(out_net, d) | ||
181 | + | ||
182 | + aux_loss.update(model.aux_loss()) | ||
183 | + bpp_loss.update(out_criterion["bpp_loss"]) | ||
184 | + loss.update(out_criterion["loss"]) | ||
185 | + mse_loss.update(out_criterion["mse_loss"]) | ||
186 | + | ||
187 | + print( | ||
188 | + f"Test epoch {epoch}: Average losses:" | ||
189 | + f"\tLoss: {loss.avg:.3f} |" | ||
190 | + f"\tMSE loss: {mse_loss.avg:.3f} |" | ||
191 | + f"\tBpp loss: {bpp_loss.avg:.2f} |" | ||
192 | + f"\tAux loss: {aux_loss.avg:.2f}\n" | ||
193 | + ) | ||
194 | + | ||
195 | + return loss.avg, bpp_loss.avg, aux_loss.avg | ||
196 | + | ||
197 | +def save_checkpoint(state, is_best, q, filename="checkpoint"): | ||
198 | + torch.save(state, filename+q+".pth.tar") | ||
199 | + if is_best: | ||
200 | + shutil.copyfile( filename+q+".pth.tar", "checkpoint_best_loss"+q+".pth.tar") | ||
201 | + | ||
202 | + | ||
203 | +def parse_args(argv): | ||
204 | + parser = argparse.ArgumentParser(description="Example training script.") | ||
205 | + parser.add_argument( | ||
206 | + "-m", | ||
207 | + "--model", | ||
208 | + default="bmshj2018-hyperprior", | ||
209 | + choices=models.keys(), | ||
210 | + help="Model architecture (default: %(default)s)", | ||
211 | + ) | ||
212 | + parser.add_argument( | ||
213 | + "-d", "--dataset", type=str, required=True, help="Training dataset" | ||
214 | + ) | ||
215 | + parser.add_argument( | ||
216 | + "-e", | ||
217 | + "--epochs", | ||
218 | + default=100, | ||
219 | + type=int, | ||
220 | + help="Number of epochs (default: %(default)s)", | ||
221 | + ) | ||
222 | + parser.add_argument( | ||
223 | + "-lr", | ||
224 | + "--learning-rate", | ||
225 | + default=1e-4, | ||
226 | + type=float, | ||
227 | + help="Learning rate (default: %(default)s)", | ||
228 | + ) | ||
229 | + parser.add_argument( | ||
230 | + "-n", | ||
231 | + "--num-workers", | ||
232 | + type=int, | ||
233 | + default=0, | ||
234 | + help="Dataloaders threads (default: %(default)s)", | ||
235 | + ) | ||
236 | + parser.add_argument( | ||
237 | + "--lambda", | ||
238 | + dest="lmbda", | ||
239 | + type=float, | ||
240 | + default=1e-2, | ||
241 | + help="Bit-rate distortion parameter (default: %(default)s)", | ||
242 | + ) | ||
243 | + parser.add_argument( | ||
244 | + "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)" | ||
245 | + ) | ||
246 | + parser.add_argument( | ||
247 | + "--test-batch-size", | ||
248 | + type=int, | ||
249 | + default=64, | ||
250 | + help="Test batch size (default: %(default)s)", | ||
251 | + ) | ||
252 | + parser.add_argument( | ||
253 | + "--aux-learning-rate", | ||
254 | + default=1e-3, | ||
255 | + help="Auxiliary loss learning rate (default: %(default)s)", | ||
256 | + ) | ||
257 | + parser.add_argument( | ||
258 | + "--patch-size", | ||
259 | + type=int, | ||
260 | + nargs=2, | ||
261 | + default=(256, 256), | ||
262 | + help="Size of the patches to be cropped (default: %(default)s)", | ||
263 | + ) | ||
264 | + parser.add_argument( | ||
265 | + "-q", | ||
266 | + "--quality", | ||
267 | + type=int, | ||
268 | + default=3, | ||
269 | + help="Quality (default: %(default)s)", | ||
270 | + ) | ||
271 | + parser.add_argument("--cuda", action="store_true", help="Use cuda") | ||
272 | + parser.add_argument("--save", action="store_true", help="Save model to disk") | ||
273 | + parser.add_argument( | ||
274 | + "--seed", type=float, help="Set random seed for reproducibility" | ||
275 | + ) | ||
276 | + parser.add_argument( | ||
277 | + "--clip_max_norm", | ||
278 | + default=1.0, | ||
279 | + type=float, | ||
280 | + help="gradient clipping max norm (default: %(default)s", | ||
281 | + ) | ||
282 | + parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint") | ||
283 | + args = parser.parse_args(argv) | ||
284 | + return args | ||
285 | + | ||
286 | +class CSVLogger(): | ||
287 | + def __init__(self, fieldnames, filename='log.csv'): | ||
288 | + | ||
289 | + self.filename = filename | ||
290 | + self.csv_file = open(filename, 'a') | ||
291 | + | ||
292 | + # Write model configuration at top of csv | ||
293 | + writer = csv.writer(self.csv_file) | ||
294 | + | ||
295 | + self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) | ||
296 | + # self.writer.writeheader() | ||
297 | + | ||
298 | + # self.csv_file.flush() | ||
299 | + | ||
300 | + def writerow(self, row): | ||
301 | + self.writer.writerow(row) | ||
302 | + self.csv_file.flush() | ||
303 | + | ||
304 | + def close(self): | ||
305 | + self.csv_file.close() | ||
306 | + | ||
307 | +class Blur(object): | ||
308 | + def __init__(self, k, sig): | ||
309 | + self.k = k | ||
310 | + self.sig = sig | ||
311 | + | ||
312 | + def __call__(self, img): | ||
313 | + r=np.random.rand(1) | ||
314 | + if r<0.5: | ||
315 | + img=cv2.GaussianBlur(img.numpy(), (self.k,self.k), self.sig) | ||
316 | + img=torch.from_numpy(img) | ||
317 | + return img | ||
318 | + | ||
319 | +def main(argv): | ||
320 | + args = parse_args(argv) | ||
321 | + | ||
322 | + if args.seed is not None: | ||
323 | + torch.manual_seed(args.seed) | ||
324 | + random.seed(args.seed) | ||
325 | + | ||
326 | + train_transforms = transforms.Compose( | ||
327 | + [transforms.RandomCrop(args.patch_size), | ||
328 | + transforms.RandomRotation(30), | ||
329 | + transforms.RandomHorizontalFlip(), | ||
330 | + transforms.ToTensor()] | ||
331 | + ) | ||
332 | + #train_transforms.transforms.append(Blur(k=3, sig=5)) | ||
333 | + | ||
334 | + test_transforms = transforms.Compose( | ||
335 | + [transforms.CenterCrop(args.patch_size), transforms.ToTensor()] | ||
336 | + ) | ||
337 | + | ||
338 | + train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms) | ||
339 | + test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms) | ||
340 | + | ||
341 | + device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" | ||
342 | + print(torch.cuda.is_available()) | ||
343 | + print(device) | ||
344 | + train_dataloader = DataLoader( | ||
345 | + train_dataset, | ||
346 | + batch_size=args.batch_size, | ||
347 | + num_workers=args.num_workers, | ||
348 | + shuffle=True, | ||
349 | + pin_memory=(device == "cuda"), | ||
350 | + ) | ||
351 | + | ||
352 | + test_dataloader = DataLoader( | ||
353 | + test_dataset, | ||
354 | + batch_size=args.test_batch_size, | ||
355 | + num_workers=args.num_workers, | ||
356 | + shuffle=False, | ||
357 | + pin_memory=(device == "cuda"), | ||
358 | + ) | ||
359 | + | ||
360 | + net = models[args.model](quality=args.quality, pretrained=False) | ||
361 | + net = net.to(device) | ||
362 | + | ||
363 | + #if args.cuda and torch.cuda.device_count() > 1: | ||
364 | + # net = CustomDataParallel(net) | ||
365 | + | ||
366 | + optimizer, aux_optimizer = configure_optimizers(net, args) | ||
367 | +# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20) | ||
368 | + lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 700], gamma=0.2) | ||
369 | + criterion = RateDistortionLoss(lmbda=args.lmbda) | ||
370 | + | ||
371 | + filename = "train"+str(args.quality)+".csv" | ||
372 | + csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename) | ||
373 | + | ||
374 | + last_epoch = 0 | ||
375 | + if args.checkpoint: # load from previous checkpoint | ||
376 | + print("Loading", args.checkpoint) | ||
377 | + checkpoint = torch.load(args.checkpoint, map_location=device) | ||
378 | + last_epoch = checkpoint["epoch"] + 1 | ||
379 | + net.load_state_dict(checkpoint["state_dict"]) | ||
380 | + optimizer.load_state_dict(checkpoint["optimizer"]) | ||
381 | + aux_optimizer.load_state_dict(checkpoint["aux_optimizer"]) | ||
382 | + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) | ||
383 | + | ||
384 | + for g in optimizer.param_groups: | ||
385 | + g['lr'] = 0.00001 | ||
386 | + for g in aux_optimizer.param_groups: | ||
387 | + g['lr'] = 0.00001 | ||
388 | + | ||
389 | + best_loss = float("inf") | ||
390 | + for epoch in range(last_epoch, args.epochs): | ||
391 | + start = time.time() | ||
392 | + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") | ||
393 | + train_loss, train_bpp_loss, train_aux = train_one_epoch( | ||
394 | + net, | ||
395 | + criterion, | ||
396 | + train_dataloader, | ||
397 | + optimizer, | ||
398 | + aux_optimizer, | ||
399 | + epoch, | ||
400 | + args.clip_max_norm, | ||
401 | + ) | ||
402 | + loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion) | ||
403 | + lr_scheduler.step(loss) | ||
404 | + | ||
405 | + row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())} | ||
406 | + csv_logger.writerow(row)### | ||
407 | + | ||
408 | + is_best = loss < best_loss | ||
409 | + best_loss = min(loss, best_loss) | ||
410 | + | ||
411 | + | ||
412 | + if args.save: | ||
413 | + save_checkpoint( | ||
414 | + { | ||
415 | + "epoch": epoch, | ||
416 | + "state_dict": net.state_dict(), | ||
417 | + "loss": loss, | ||
418 | + "optimizer": optimizer.state_dict(), | ||
419 | + "aux_optimizer": aux_optimizer.state_dict(), | ||
420 | + "lr_scheduler": lr_scheduler.state_dict(), | ||
421 | + }, | ||
422 | + is_best, | ||
423 | + str(args.quality) | ||
424 | + ) | ||
425 | + print(f"Total TIme: {time.time() - start}") | ||
426 | + csv_logger.close()### | ||
427 | + | ||
428 | + | ||
# Allow running this training script directly from the command line.
if __name__ == "__main__":
    main(sys.argv[1:])
Our Encoder/train_RGB_MS-SSIMloss.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import math | ||
17 | +import random | ||
18 | +import shutil | ||
19 | +import sys | ||
20 | +import time | ||
21 | + | ||
22 | +import torch | ||
23 | +import torch.nn as nn | ||
24 | +import torch.optim as optim | ||
25 | + | ||
26 | +from torch.utils.data import DataLoader | ||
27 | +from torchvision import transforms | ||
28 | + | ||
29 | +from compressai.datasets import ImageFolder | ||
30 | +from compressai.zoo import models | ||
31 | +import csv | ||
32 | +import cv2 | ||
33 | +import numpy as np | ||
34 | +from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM | ||
35 | + | ||
class RateDistortionLoss(nn.Module):
    """Rate-distortion loss: ``lmbda * 255**2 * (1 - MS-SSIM) + bpp``."""

    def __init__(self, lmbda=1e-2):
        super().__init__()
        # NOTE: attribute keeps the name ``mse`` from the original code,
        # but the distortion metric used here is MS-SSIM, not squared error.
        self.mse = ms_ssim
        self.lmbda = lmbda

    def forward(self, output, target):
        batch, _, height, width = target.size()
        pixel_count = batch * height * width

        out = {}
        # Estimated bits (negative log2-likelihood) normalised per pixel.
        out["bpp_loss"] = sum(
            torch.log(lk).sum() / (-math.log(2) * pixel_count)
            for lk in output["likelihoods"].values()
        )
        # MS-SSIM is a similarity in [0, 1]; 1 - similarity is the distortion.
        out["mse_loss"] = 1 - self.mse(output["x_hat"], target, data_range=1.)
        out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
        return out
57 | + | ||
58 | + | ||
class AverageMeter:
    """Track a running (optionally weighted) average of a scalar quantity."""

    def __init__(self):
        self.val = 0    # most recently observed value
        self.avg = 0    # running average
        self.sum = 0    # weighted sum of all observed values
        self.count = 0  # total weight observed so far

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
73 | + | ||
74 | + | ||
class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that forwards unknown attribute lookups to the wrapped module."""

    def __getattr__(self, key):
        try:
            # Regular DataParallel attributes (module, device_ids, ...).
            return super().__getattr__(key)
        except AttributeError:
            # Fall back to the wrapped model's own attributes/methods.
            return getattr(self.module, key)
83 | + | ||
84 | + | ||
def configure_optimizers(net, args):
    """Build the main and auxiliary Adam optimizers.

    Parameters whose name ends in ``.quantiles`` (entropy-bottleneck
    quantiles) are trained by the auxiliary optimizer; every other
    trainable parameter is trained by the main optimizer.

    Returns ``(optimizer, aux_optimizer)``.
    """
    main_names = set()
    aux_names = set()
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        (aux_names if name.endswith(".quantiles") else main_names).add(name)

    params_dict = dict(net.named_parameters())

    # Sanity checks: groups are disjoint and together cover all parameters.
    assert not (main_names & aux_names)
    assert len(main_names | aux_names) == len(params_dict)

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(main_names)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_names)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
117 | + | ||
118 | + | ||
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    """Run one training epoch.

    Returns ``(avg_loss, avg_bpp_loss, avg_aux_loss)`` as detached tensors
    (callers read them with ``.item()``).
    """
    model.train()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    a_aux_loss = AverageMeter()

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)

        # BUG FIX: detach before accumulating. The original stored the raw
        # loss tensors in the meters, keeping every iteration's autograd
        # graph alive for the whole epoch (unbounded memory growth).
        bpp_loss.update(out_criterion["bpp_loss"].detach())
        loss.update(out_criterion["loss"].detach())
        mse_loss.update(out_criterion["mse_loss"].detach())

        out_criterion["loss"].backward()

        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # Auxiliary loss trains the entropy-bottleneck quantiles only.
        aux_loss = model.aux_loss()
        a_aux_loss.update(aux_loss.detach())
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )
    return loss.avg, bpp_loss.avg, a_aux_loss.avg
166 | + | ||
167 | + | ||
def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate the model on the test set; returns (avg loss, avg bpp, avg aux)."""
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    aux_loss = AverageMeter()

    # No gradients needed for evaluation.
    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            out_net = model(batch)
            out_criterion = criterion(out_net, batch)

            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            loss.update(out_criterion["loss"])
            mse_loss.update(out_criterion["mse_loss"])

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMSE loss: {mse_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )
    return loss.avg, bpp_loss.avg, aux_loss.avg
197 | + | ||
def save_checkpoint(state, is_best, q, filename="checkpoint_msssim"):
    """Persist ``state``; additionally copy it to the best-loss file when ``is_best``."""
    ckpt_path = filename + q + ".pth.tar"
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, "checkpoint_best_loss_msssim" + q + ".pth.tar")
202 | + | ||
203 | + | ||
def parse_args(argv):
    """Parse command-line arguments for the training script and return the namespace."""
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-hyperprior",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        # BUG FIX: without type=float a CLI-supplied value stayed a string
        # and broke optim.Adam(lr=...) in configure_optimizers.
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        type=int,
        default=3,
        help="Quality (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument("--save", action="store_true", help="Save model to disk")
    parser.add_argument(
        # NOTE(review): seed is parsed as float but torch.manual_seed expects
        # an int; kept as-is for backward compatibility.
        "--seed", type=float, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args
286 | + | ||
class CSVLogger():
    """Minimal append-only CSV logger for per-epoch training metrics."""

    def __init__(self, fieldnames, filename='log.csv'):
        self.filename = filename
        # newline='' is required by the csv module to avoid spurious blank
        # lines on platforms with \r\n translation. (Also removed a dead
        # local ``csv.writer`` that was created and never used.)
        self.csv_file = open(filename, 'a', newline='')
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        # Header intentionally not written: the file is opened in append
        # mode and may already contain rows from a previous run.

    def writerow(self, row):
        """Write one row (a dict keyed by the configured fieldnames) and flush."""
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        """Close the underlying file handle."""
        self.csv_file.close()
307 | + | ||
class Blur(object):
    """With probability 0.5, apply a Gaussian blur (kernel ``k``, sigma ``sig``)
    to a tensor image; otherwise return the image unchanged."""

    def __init__(self, k, sig):
        self.k = k
        self.sig = sig

    def __call__(self, img):
        if np.random.rand(1) < 0.5:
            blurred = cv2.GaussianBlur(img.numpy(), (self.k, self.k), self.sig)
            img = torch.from_numpy(blurred)
        return img
319 | + | ||
def main(argv):
    # Entry point: parse CLI args, build datasets/loaders and model, train
    # with an MS-SSIM rate-distortion loss, log per-epoch metrics to CSV,
    # and checkpoint the model each epoch.
    args = parse_args(argv)

    # Optional seeding (CUDA kernels may still be nondeterministic).
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Random crop/rotation/flip augmentation for training patches.
    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.RandomRotation(30),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()]
    )
    #train_transforms.transforms.append(Blur(k=3, sig=5))

    # Deterministic center crop for evaluation.
    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(torch.cuda.is_available())
    print(device)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # Model from the compressai zoo, trained from scratch at the requested quality.
    net = models[args.model](quality=args.quality, pretrained=False)
    net = net.to(device)

    #if args.cuda and torch.cuda.device_count() > 1:
    #    net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    # Learning rate is reduced when the monitored (test) loss plateaus.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # Per-epoch metrics are appended (no header) to a quality-specific CSV.
    filename = "train_msssim"+str(args.quality)+".csv"
    csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint (resume training)
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
#        for g in optimizer.param_groups:
#            g['lr'] = 0.0001
#        for g in aux_optimizer.param_groups:
#            g['lr'] = 0.0001
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        start = time.time()
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_loss, train_bpp_loss, train_aux = train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
        # ReduceLROnPlateau steps on the monitored test loss.
        lr_scheduler.step(loss)

        row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())}
        csv_logger.writerow(row)

        # Track the best test loss to decide which checkpoint to mirror.
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)


        if args.save:
            # Saves every epoch; also copies to the "best" file when improved.
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                str(args.quality)
            )
        print(f"Total TIme: {time.time() - start}")
    csv_logger.close()
426 | + | ||
427 | + | ||
# Allow running this training script directly from the command line.
if __name__ == "__main__":
    main(sys.argv[1:])
Our Encoder/train_YCbCr.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import math | ||
17 | +import random | ||
18 | +import shutil | ||
19 | +import sys | ||
20 | +import time | ||
21 | + | ||
22 | +import torch | ||
23 | +import torch.nn as nn | ||
24 | +import torch.optim as optim | ||
25 | + | ||
26 | +from torch.utils.data import DataLoader | ||
27 | +from torchvision import transforms | ||
28 | + | ||
29 | +from compressai.datasets import ImageFolder | ||
30 | +from compressai.zoo import models | ||
31 | +import csv | ||
32 | +import cv2 | ||
33 | +import numpy as np | ||
34 | +from compressai.transforms.functional import ( | ||
35 | + rgb2ycbcr, | ||
36 | + ycbcr2rgb, | ||
37 | + yuv_420_to_444, | ||
38 | + yuv_444_to_420, | ||
39 | +) | ||
40 | + | ||
class RateDistortionLoss(nn.Module):
    """Rate-distortion loss with a YCbCr-weighted (4:1:1) MSE distortion term."""

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        batch, _, height, width = target.size()
        pixel_count = batch * height * width

        out = {}
        # Bits-per-pixel estimate from the latent likelihoods.
        out["bpp_loss"] = sum(
            torch.log(lk).sum() / (-math.log(2) * pixel_count)
            for lk in output["likelihoods"].values()
        )

        # Split the channel dimension into Y, Cb, Cr and weight luma 4x
        # (the 4:1:1 weighting mentioned in the original comment).
        o_y, o_cb, o_cr = output["x_hat"].chunk(3, -3)
        t_y, t_cb, t_cr = target.chunk(3, -3)
        out["mse_loss"] = (
            4 * self.mse(o_y, t_y) + self.mse(o_cb, t_cb) + self.mse(o_cr, t_cr)
        ) / 6
        out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]

        return out
68 | + | ||
69 | + | ||
70 | + | ||
class AverageMeter:
    """Running average of a scalar metric."""

    def __init__(self):
        self.val = 0    # last value seen
        self.avg = 0    # current running average
        self.sum = 0    # weighted sum
        self.count = 0  # total weight

    def update(self, val, n=1):
        """Fold ``val`` (with weight ``n``) into the running statistics."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
85 | + | ||
86 | + | ||
class CustomDataParallel(nn.DataParallel):
    """DataParallel that delegates missing attribute lookups to the underlying module."""

    def __getattr__(self, key):
        try:
            return super().__getattr__(key)
        except AttributeError:
            # Not a DataParallel attribute: resolve it on the wrapped model,
            # so methods like ``aux_loss`` remain callable after wrapping.
            return getattr(self.module, key)
95 | + | ||
96 | + | ||
def configure_optimizers(net, args):
    """Split parameters into main/auxiliary groups and build an Adam optimizer for each.

    ``*.quantiles`` parameters (entropy bottleneck) go to the auxiliary
    optimizer; all other trainable parameters go to the main one.
    Returns ``(optimizer, aux_optimizer)``.
    """
    params_dict = dict(net.named_parameters())

    aux_names = {
        n for n, p in params_dict.items()
        if p.requires_grad and n.endswith(".quantiles")
    }
    main_names = {
        n for n, p in params_dict.items()
        if p.requires_grad and not n.endswith(".quantiles")
    }

    # The two groups must be disjoint and together cover every parameter.
    assert not (main_names & aux_names)
    assert len(main_names | aux_names) == len(params_dict)

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(main_names)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_names)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
129 | + | ||
130 | + | ||
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    """Run one training epoch.

    Returns ``(avg_loss, avg_bpp_loss, avg_aux_loss)`` as detached tensors
    (callers read them with ``.item()``).
    """
    model.train()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    a_aux_loss = AverageMeter()

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)

        # BUG FIX: detach before accumulating. The original stored the raw
        # loss tensors in the meters, keeping every iteration's autograd
        # graph alive for the whole epoch (unbounded memory growth).
        bpp_loss.update(out_criterion["bpp_loss"].detach())
        loss.update(out_criterion["loss"].detach())
        mse_loss.update(out_criterion["mse_loss"].detach())

        out_criterion["loss"].backward()

        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # Auxiliary loss trains the entropy-bottleneck quantiles only.
        aux_loss = model.aux_loss()
        a_aux_loss.update(aux_loss.detach())
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )
    return loss.avg, bpp_loss.avg, a_aux_loss.avg
178 | + | ||
179 | + | ||
def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate on the test set; returns (avg loss, avg bpp loss, avg aux loss)."""
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    aux_loss = AverageMeter()

    # Evaluation only: no gradient tracking.
    with torch.no_grad():
        for sample in test_dataloader:
            sample = sample.to(device)
            out_net = model(sample)
            out_criterion = criterion(out_net, sample)

            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            loss.update(out_criterion["loss"])
            mse_loss.update(out_criterion["mse_loss"])

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMSE loss: {mse_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )
    return loss.avg, bpp_loss.avg, aux_loss.avg
209 | + | ||
def save_checkpoint(state, is_best, q, filename="checkpoint"):
    """Save the training state; mirror it to the best-loss file when ``is_best``."""
    ckpt_path = filename + q + ".pth.tar"
    torch.save(state, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, "checkpoint_best_loss" + q + ".pth.tar")
214 | + | ||
215 | + | ||
def parse_args(argv):
    """Parse command-line arguments for the training script and return the namespace."""
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-hyperprior",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        # BUG FIX: without type=float a CLI-supplied value stayed a string
        # and broke optim.Adam(lr=...) in configure_optimizers.
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        type=int,
        default=3,
        help="Quality (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument("--save", action="store_true", help="Save model to disk")
    parser.add_argument(
        # NOTE(review): seed is parsed as float but torch.manual_seed expects
        # an int; kept as-is for backward compatibility.
        "--seed", type=float, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args
298 | + | ||
class CSVLogger():
    """Minimal append-only CSV logger for per-epoch training metrics."""

    def __init__(self, fieldnames, filename='log.csv'):
        self.filename = filename
        # newline='' is required by the csv module to avoid spurious blank
        # lines on platforms with \r\n translation. (Also removed a dead
        # local ``csv.writer`` that was created and never used.)
        self.csv_file = open(filename, 'a', newline='')
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        # Header intentionally not written: the file is opened in append
        # mode and may already contain rows from a previous run.

    def writerow(self, row):
        """Write one row (a dict keyed by the configured fieldnames) and flush."""
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        """Close the underlying file handle."""
        self.csv_file.close()
319 | + | ||
class Blur(object):
    """Randomly (p=0.5) Gaussian-blur a tensor image with kernel ``k`` and sigma ``sig``."""

    def __init__(self, k, sig):
        self.k = k
        self.sig = sig

    def __call__(self, img):
        apply_blur = np.random.rand(1) < 0.5
        if apply_blur:
            img = torch.from_numpy(
                cv2.GaussianBlur(img.numpy(), (self.k, self.k), self.sig)
            )
        return img
331 | + | ||
class RGB2YCbCr(object):
    """Transform that converts an RGB tensor image to YCbCr."""

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: The same image converted to YCbCr.
        """
        return rgb2ycbcr(img)
344 | + | ||
def main(argv):
    """Train an image-compression model end to end.

    Parses CLI options, builds the train/test datasets and dataloaders,
    constructs the network, optimizers, LR scheduler and RD loss,
    optionally resumes from a checkpoint, then runs the train/eval loop,
    logging each epoch to a CSV file and (optionally) saving checkpoints.

    Args:
        argv: command-line argument list (without the program name).
    """
    args = parse_args(argv)

    # Seed both torch and Python's RNG for reproducible runs.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),  # crop to the training patch size
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()]  # convert the PIL/numpy image to a torch tensor
    )
    # Append the RGB->YCbCr conversion after ToTensor so it sees a tensor.
    train_transforms.transforms.append(RGB2YCbCr())

# print(train_transforms.shape)
# train_transforms=rgb2ycbcr(train_transforms)
    #train_transforms.transforms.append(Blur(k=3, sig=5))

    # Test images are only center-cropped (no augmentation) and tensorized.
    # NOTE(review): the test pipeline has no RGB2YCbCr step, unlike training —
    # confirm whether that asymmetry is intentional.
    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(torch.cuda.is_available())
    print(device)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        # pin host memory only when transferring to a CUDA device
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # Build the model from the registry; training from scratch (no pretrained weights).
    net = models[args.model](quality=args.quality, pretrained=False)
    net = net.to(device)

    #if args.cuda and torch.cuda.device_count() > 1:
    #    net = CustomDataParallel(net)

    # Separate main and auxiliary (entropy-bottleneck) optimizers.
    optimizer, aux_optimizer = configure_optimizers(net, args)
# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20)
    # Step the LR down by 5x at epochs 400 and 700.
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 700], gamma=0.2)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # Per-quality CSV log of train/test losses for each epoch.
    filename = "train"+str(args.quality)+".csv"
    csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        # Resume from the epoch after the one stored in the checkpoint.
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
# for g in optimizer.param_groups:
#     g['lr'] = 0.0001
# for g in aux_optimizer.param_groups:
#     g['lr'] = 0.0001
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        start = time.time()
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_loss, train_bpp_loss, train_aux = train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step()

        # Losses are assumed to be 0-dim tensors (.item() is called) —
        # NOTE(review): confirm train_one_epoch/test_epoch return tensors.
        row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())}
        csv_logger.writerow(row)  # one CSV row per epoch

        # Track the best test loss so save_checkpoint can mark the best model.
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)


        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                str(args.quality)
            )
        print(f"Total TIme: {time.time() - start}")
    csv_logger.close()  # flush and close the CSV log once training finishes
456 | + | ||
457 | + | ||
if __name__ == "__main__":
    # Script entry point: forward CLI args (excluding the program name) to main().
    main(sys.argv[1:])
-
Please register or login to post a comment