choiseungmi

code update

...@@ -151,7 +151,7 @@ def compute_psnr(a, b): ...@@ -151,7 +151,7 @@ def compute_psnr(a, b):
151 mse = torch.mean((a - b)**2).item() 151 mse = torch.mean((a - b)**2).item()
152 return -10 * math.log10(mse) 152 return -10 * math.log10(mse)
153 153
154 -def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): 154 +def _encode(seq, path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path):
155 compressai.set_entropy_coder(coder) 155 compressai.set_entropy_coder(coder)
156 enc_start = time.time() 156 enc_start = time.time()
157 157
...@@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou ...@@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou
182 strings.append([s[0]]) 182 strings.append([s[0]])
183 183
184 with torch.no_grad(): 184 with torch.no_grad():
185 - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) 185 + recon_out = net.decompress(strings, shape)
186 x_recon = crop(recon_out["x_hat"], (h, w)) 186 x_recon = crop(recon_out["x_hat"], (h, w))
187 187
188 psnr=compute_psnr(x, x_recon) 188 psnr=compute_psnr(x, x_recon)
189 189
190 - if i==False: 190 + #if i==False:
191 - diff=x-ref 191 + # diff=x-ref
192 - diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 192 + # diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
193 - diff_img = torch2img(diff1) 193 + # diff_img = torch2img(diff1)
194 - diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png") 194 + # diff_img.save("../Data/train/"+seq+str(ff)+"_train_v1_q"+str(quality)+".png")
195 195
196 enc_time = time.time() - enc_start 196 enc_time = time.time() - enc_start
197 size = filesize(output) 197 size = filesize(output)
...@@ -336,15 +336,15 @@ def encode(argv): ...@@ -336,15 +336,15 @@ def encode(argv):
336 total_psnr=0.0 336 total_psnr=0.0
337 total_bpp=0.0 337 total_bpp=0.0
338 total_time=0.0 338 total_time=0.0
339 - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" 339 + img_path =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
340 - img=args.image+"_frame"+str(0)+".png" 340 + img=img_path+"_frame"+str(0)+".png"
341 - total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) 341 + total_psnr, total_bpp, ref,total_time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
342 for ff in range(1, args.frame): 342 for ff in range(1, args.frame):
343 with Path(log_path).open("a") as f: 343 with Path(log_path).open("a") as f:
344 f.write(f" {ff:3d} | ") 344 f.write(f" {ff:3d} | ")
345 - img=args.image+"_frame"+str(ff)+".png" 345 + img=img_path+"_frame"+str(ff)+".png"
346 346
347 - psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) 347 + psnr, total_bpp, ref,time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
348 total_psnr+=psnr 348 total_psnr+=psnr
349 total_time+=time 349 total_time+=time
350 350
......
...@@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o ...@@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
213 strings.append([s[0]]) 213 strings.append([s[0]])
214 214
215 with torch.no_grad(): 215 with torch.no_grad():
216 - recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) 216 + recon_out1 = net.decompress(strings,shape)
217 x_hat1 = crop(recon_out1["x_hat"], (h, w)) 217 x_hat1 = crop(recon_out1["x_hat"], (h, w))
218 218
219 with torch.no_grad(): 219 with torch.no_grad():
...@@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o ...@@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
231 strings.append([s[0]]) 231 strings.append([s[0]])
232 232
233 with torch.no_grad(): 233 with torch.no_grad():
234 - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) 234 + recon_out = net.decompress(strings, shape)
235 x_hat2 = crop(recon_out["x_hat"], (h, w)) 235 x_hat2 = crop(recon_out["x_hat"], (h, w))
236 x_recon=ref+x_hat1-x_hat2 236 x_recon=ref+x_hat1-x_hat2
237 237
......
...@@ -17,6 +17,7 @@ import struct ...@@ -17,6 +17,7 @@ import struct
17 import sys 17 import sys
18 import time 18 import time
19 import math 19 import math
20 +from pytorch_msssim import ms_ssim
20 21
21 from pathlib import Path 22 from pathlib import Path
22 23
...@@ -27,7 +28,12 @@ from PIL import Image ...@@ -27,7 +28,12 @@ from PIL import Image
27 from torchvision.transforms import ToPILImage, ToTensor 28 from torchvision.transforms import ToPILImage, ToTensor
28 29
29 import compressai 30 import compressai
30 - 31 +from compressai.transforms.functional import (
32 + rgb2ycbcr,
33 + ycbcr2rgb,
34 + yuv_420_to_444,
35 + yuv_444_to_420,
36 +)
31 from compressai.zoo import models 37 from compressai.zoo import models
32 38
33 model_ids = {k: i for i, k in enumerate(models.keys())} 39 model_ids = {k: i for i, k in enumerate(models.keys())}
...@@ -151,13 +157,28 @@ def compute_psnr(a, b): ...@@ -151,13 +157,28 @@ def compute_psnr(a, b):
151 mse = torch.mean((a - b)**2).item() 157 mse = torch.mean((a - b)**2).item()
152 return -10 * math.log10(mse) 158 return -10 * math.log10(mse)
153 159
154 -def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): 160 +def compute_msssim(a, b):
161 + return ms_ssim(a, b, data_range=1.).item()
162 +
163 +def ycbcr_psnr(a, b):
164 + yuv_a=rgb2ycbcr(a)
165 + yuv_b=rgb2ycbcr(b)
166 + a_y, a_cb, a_cr = yuv_a.chunk(3, -3)
167 + b_y, b_cb, b_cr = yuv_b.chunk(3, -3)
168 + y=compute_psnr(a_y, b_y)
169 + cb=compute_psnr(a_cb, b_cb)
170 + cr=compute_psnr(a_cr, b_cr)
171 + return (4*y+cb+cr)/6
172 +
173 +def _encode(checkpoint, path, seq, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
155 compressai.set_entropy_coder(coder) 174 compressai.set_entropy_coder(coder)
156 enc_start = time.time() 175 enc_start = time.time()
157 176
158 - img = load_image(image) 177 + img = load_image(image+"_frame"+str(ff)+".png")
159 start = time.time() 178 start = time.time()
160 - net = models[model](quality=quality, metric=metric, pretrained=True).eval() 179 + net = models[model](quality=quality, metric=metric, pretrained=True)
180 +
181 + net.eval()
161 load_time = time.time() - start 182 load_time = time.time() - start
162 183
163 x = img2torch(img) 184 x = img2torch(img)
...@@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o ...@@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
182 strings.append([s[0]]) 203 strings.append([s[0]])
183 204
184 with torch.no_grad(): 205 with torch.no_grad():
185 - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) 206 + recon_out = net.decompress(strings,shape)
186 x_recon = crop(recon_out["x_hat"], (h, w)) 207 x_recon = crop(recon_out["x_hat"], (h, w))
187 208
188 psnr=compute_psnr(x, x_recon) 209 psnr=compute_psnr(x, x_recon)
210 + ssim=compute_msssim(x, x_recon)
211 + ycbcr=ycbcr_psnr(x, x_recon)
189 else: 212 else:
213 + if checkpoint: # load from previous checkpoint
214 + checkpoint = torch.load(checkpoint)
215 + #state_dict = load_state_dict(checkpoint["state_dict"])
216 + net=models[model](quality=quality, metric=metric)
217 + net.load_state_dict(checkpoint["state_dict"])
218 + net.update(force=True)
219 + else:
220 + net = models[model](quality=quality, metric=metric, pretrained=True)
221 +
190 diff=x-ref 222 diff=x-ref
191 - #1
192 diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 223 diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
224 +
193 225
194 - #2
195 - '''
196 - diff1=torch.clamp(diff, min=0.0, max=1.0)
197 - diff2=-torch.clamp(diff, min=-1.0, max=0.0)
198 -
199 - diff1=pad(diff1, p)
200 - diff2=pad(diff2, p)
201 - '''
202 - #1
203 -
204 - with torch.no_grad():
205 - out1 = net.compress(diff1)
206 - shape1 = out1["shape"]
207 - strings = []
208 -
209 - with Path(output).open("ab") as f:
210 - # write shape and number of encoded latents
211 - write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
212 -
213 - for s in out1["strings"]:
214 - write_uints(f, (len(s[0]),))
215 - write_bytes(f, s[0])
216 - strings.append([s[0]])
217 -
218 - with torch.no_grad():
219 - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
220 - x_hat1 = crop(recon_out["x_hat"], (h, w))
221 -
222 - #2
223 - '''
224 with torch.no_grad(): 226 with torch.no_grad():
225 out1 = net.compress(diff1) 227 out1 = net.compress(diff1)
226 shape1 = out1["shape"] 228 shape1 = out1["shape"]
...@@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o ...@@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
236 strings.append([s[0]]) 238 strings.append([s[0]])
237 239
238 with torch.no_grad(): 240 with torch.no_grad():
239 - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) 241 + recon_out = net.decompress(strings, shape1)
240 x_hat1 = crop(recon_out["x_hat"], (h, w)) 242 x_hat1 = crop(recon_out["x_hat"], (h, w))
241 - with torch.no_grad(): 243 +
242 - out = net.compress(diff2)
243 - shape = out["shape"]
244 - strings = []
245 -
246 - with Path(output).open("ab") as f:
247 - # write shape and number of encoded latents
248 - write_uints(f, (shape[0], shape[1], len(out["strings"])))
249 -
250 - for s in out["strings"]:
251 - write_uints(f, (len(s[0]),))
252 - write_bytes(f, s[0])
253 - strings.append([s[0]])
254 -
255 - with torch.no_grad():
256 - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
257 - x_hat2 = crop(recon_out["x_hat"], (h, w))
258 - x_recon=ref+x_hat1-x_hat2
259 - '''
260 244
261 x_recon=ref+x_hat1-0.5 245 x_recon=ref+x_hat1-0.5
262 psnr=compute_psnr(x, x_recon) 246 psnr=compute_psnr(x, x_recon)
247 + ssim=compute_msssim(x, x_recon)
248 + ycbcr=ycbcr_psnr(x, x_recon)
263 diff_img = torch2img(diff1) 249 diff_img = torch2img(diff1)
264 - diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png") 250 +# diff_img.save(path+"recon/"+seq+str(ff)+"_q"+str(quality)+".png")
251 +# diff_img.save("../Data/train/"+seq+str(ff)+"_train8_q"+str(quality)+".png")
265 252
266 enc_time = time.time() - enc_start 253 enc_time = time.time() - enc_start
267 size = filesize(output) 254 size = filesize(output)
...@@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o ...@@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
269 with Path(log_path).open("a") as f: 256 with Path(log_path).open("a") as f:
270 f.write( f" {bpp-total_bpp:.4f} | " 257 f.write( f" {bpp-total_bpp:.4f} | "
271 f" {psnr:.4f} |" 258 f" {psnr:.4f} |"
259 + f" {ssim:.4f} |"
260 + f" {ycbcr:.4f} |"
272 f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") 261 f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
273 recon_img = torch2img(x_recon) 262 recon_img = torch2img(x_recon)
274 recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") 263 recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png")
275 264
276 - return psnr, bpp, x_recon, enc_time 265 + return psnr, bpp, x_recon, enc_time, ssim, ycbcr
277 266
278 267
279 def _decode(inputpath, coder, show, frame, output=None): 268 def _decode(inputpath, coder, show, frame, output=None):
...@@ -381,13 +370,19 @@ def encode(argv): ...@@ -381,13 +370,19 @@ def encode(argv):
381 default=768, 370 default=768,
382 help="hight setting (default: %(default))", 371 help="hight setting (default: %(default))",
383 ) 372 )
373 + parser.add_argument(
374 + "-check",
375 + "--checkpoint",
376 + type=str,
377 + help="Path to a checkpoint",
378 + )
384 parser.add_argument("-o", "--output", help="Output path") 379 parser.add_argument("-o", "--output", help="Output path")
385 args = parser.parse_args(argv) 380 args = parser.parse_args(argv)
386 path="examples/"+args.image+"/" 381 path="examples/"+args.image+"/"
387 if not args.output: 382 if not args.output:
388 #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") 383 #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
389 - args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin" 384 + args.output = path+args.image+"_q"+str(args.quality)+"_train_ssim.bin"
390 - log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt" 385 + log_path=path+args.image+"_q"+str(args.quality)+"_train_ssim.txt"
391 386
392 header = get_header(args.model, args.metric, args.quality) 387 header = get_header(args.model, args.metric, args.quality)
393 with Path(args.output).open("wb") as f: 388 with Path(args.output).open("wb") as f:
...@@ -400,32 +395,43 @@ def encode(argv): ...@@ -400,32 +395,43 @@ def encode(argv):
400 f"frames : {args.frame}\n") 395 f"frames : {args.frame}\n")
401 f.write( f"frame | bpp | " 396 f.write( f"frame | bpp | "
402 f" psnr |" 397 f" psnr |"
398 + f" ssim |"
403 f" Encoded time (model loading)\n" 399 f" Encoded time (model loading)\n"
404 f" {0:3d} | ") 400 f" {0:3d} | ")
405 401
406 total_psnr=0.0 402 total_psnr=0.0
403 + total_ssim=0.0
404 + total_ycbcr=0.0
407 total_bpp=0.0 405 total_bpp=0.0
408 total_time=0.0 406 total_time=0.0
409 - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" 407 + img =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
410 - img=args.image+"_frame"+str(0)+".png" 408 + total_psnr, total_bpp, ref, total_time, total_ssim, total_ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
411 - total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
412 for ff in range(1, args.frame): 409 for ff in range(1, args.frame):
413 with Path(log_path).open("a") as f: 410 with Path(log_path).open("a") as f:
414 f.write(f" {ff:3d} | ") 411 f.write(f" {ff:3d} | ")
415 - img=args.image+"_frame"+str(ff)+".png" 412 + if ff%25==0:
416 - 413 + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, ref, total_bpp, ff, args.output, log_path)
417 - psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) 414 + else:
415 + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
418 total_psnr+=psnr 416 total_psnr+=psnr
417 + total_ssim+=ssim
418 + total_ycbcr+=ycbcr
419 total_time+=time 419 total_time+=time
420 420
421 total_psnr/=args.frame 421 total_psnr/=args.frame
422 + total_ssim/=args.frame
423 + total_ycbcr/=args.frame
422 total_bpp/=args.frame 424 total_bpp/=args.frame
423 425
424 with Path(log_path).open("a") as f: 426 with Path(log_path).open("a") as f:
425 f.write( f"\n Total Encoded time: {total_time:.2f}s\n" 427 f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
426 f"\n Total PSNR: {total_psnr:.6f}\n" 428 f"\n Total PSNR: {total_psnr:.6f}\n"
429 + f"\n Total SSIM: {total_ssim:.6f}\n"
430 + f"\n Total ycbcr: {total_ycbcr:.6f}\n"
427 f" Total BPP: {total_bpp:.6f}\n") 431 f" Total BPP: {total_bpp:.6f}\n")
428 print(total_psnr) 432 print(total_psnr)
433 + print(total_ssim)
434 + print(total_ycbcr)
429 print(total_bpp) 435 print(total_bpp)
430 436
431 437
......
1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import math
17 +import random
18 +import shutil
19 +import sys
20 +import time
21 +
22 +import torch
23 +import torch.nn as nn
24 +import torch.optim as optim
25 +
26 +from torch.utils.data import DataLoader
27 +from torchvision import transforms
28 +
29 +from compressai.datasets import ImageFolder
30 +from compressai.zoo import models
31 +import csv
32 +import cv2
33 +import numpy as np
34 +
35 +class RateDistortionLoss(nn.Module):
36 + """Custom rate distortion loss with a Lagrangian parameter."""
37 +
38 + def __init__(self, lmbda=1e-2):
39 + super().__init__()
40 + self.mse = nn.MSELoss()
41 + self.lmbda = lmbda
42 +
43 + def forward(self, output, target):
44 + N, _, H, W = target.size()
45 + out = {}
46 + num_pixels = N * H * W
47 +
48 + out["bpp_loss"] = sum(
49 + (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
50 + for likelihoods in output["likelihoods"].values()
51 + )
52 + out["mse_loss"] = self.mse(output["x_hat"], target)
53 + out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
54 +
55 + return out
56 +
57 +
58 +class AverageMeter:
59 + """Compute running average."""
60 +
61 + def __init__(self):
62 + self.val = 0
63 + self.avg = 0
64 + self.sum = 0
65 + self.count = 0
66 +
67 + def update(self, val, n=1):
68 + self.val = val
69 + self.sum += val * n
70 + self.count += n
71 + self.avg = self.sum / self.count
72 +
73 +
74 +class CustomDataParallel(nn.DataParallel):
75 + """Custom DataParallel to access the module methods."""
76 +
77 + def __getattr__(self, key):
78 + try:
79 + return super().__getattr__(key)
80 + except AttributeError:
81 + return getattr(self.module, key)
82 +
83 +
84 +def configure_optimizers(net, args):
85 + """Separate parameters for the main optimizer and the auxiliary optimizer.
86 + Return two optimizers"""
87 +
88 + parameters = set(
89 + n
90 + for n, p in net.named_parameters()
91 + if not n.endswith(".quantiles") and p.requires_grad
92 + )
93 + aux_parameters = set(
94 + n
95 + for n, p in net.named_parameters()
96 + if n.endswith(".quantiles") and p.requires_grad
97 + )
98 +
99 + # Make sure we don't have an intersection of parameters
100 + params_dict = dict(net.named_parameters())
101 + inter_params = parameters & aux_parameters
102 + union_params = parameters | aux_parameters
103 +
104 + assert len(inter_params) == 0
105 + assert len(union_params) - len(params_dict.keys()) == 0
106 +
107 + optimizer = optim.Adam(
108 + (params_dict[n] for n in sorted(list(parameters))),
109 + lr=args.learning_rate,
110 + )
111 + aux_optimizer = optim.Adam(
112 + (params_dict[n] for n in sorted(list(aux_parameters))),
113 + lr=args.aux_learning_rate,
114 + )
115 + return optimizer, aux_optimizer
116 +
117 +
118 +def train_one_epoch(
119 + model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
120 +):
121 + model.train()
122 + device = next(model.parameters()).device
123 +
124 + loss = AverageMeter()
125 + bpp_loss = AverageMeter()
126 + mse_loss = AverageMeter()
127 + a_aux_loss = AverageMeter()
128 +
129 + for i, d in enumerate(train_dataloader):
130 + d = d.to(device)
131 +
132 + optimizer.zero_grad()
133 + aux_optimizer.zero_grad()
134 +
135 + out_net = model(d)
136 +
137 + out_criterion = criterion(out_net, d)
138 +
139 + bpp_loss.update(out_criterion["bpp_loss"])
140 + loss.update(out_criterion["loss"])
141 + mse_loss.update(out_criterion["mse_loss"])
142 +
143 + out_criterion["loss"].backward()
144 +
145 + if clip_max_norm > 0:
146 + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
147 + optimizer.step()
148 +
149 + aux_loss = model.aux_loss()
150 + a_aux_loss.update(aux_loss)
151 + aux_loss.backward()
152 + aux_optimizer.step()
153 +
154 + if i % 10 == 0:
155 + print(
156 + f"Train epoch {epoch}: ["
157 + f"{i*len(d)}/{len(train_dataloader.dataset)}"
158 + f" ({100. * i / len(train_dataloader):.0f}%)]"
159 + f'\tLoss: {out_criterion["loss"].item():.3f} |'
160 + f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
161 + f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
162 + f"\tAux loss: {aux_loss.item():.2f}"
163 + )
164 + return loss.avg, bpp_loss.avg, a_aux_loss.avg
165 +
166 +
167 +def test_epoch(epoch, test_dataloader, model, criterion):
168 + model.eval()
169 + device = next(model.parameters()).device
170 +
171 + loss = AverageMeter()
172 + bpp_loss = AverageMeter()
173 + mse_loss = AverageMeter()
174 + aux_loss = AverageMeter()
175 +
176 + with torch.no_grad():
177 + for d in test_dataloader:
178 + d = d.to(device)
179 + out_net = model(d)
180 + out_criterion = criterion(out_net, d)
181 +
182 + aux_loss.update(model.aux_loss())
183 + bpp_loss.update(out_criterion["bpp_loss"])
184 + loss.update(out_criterion["loss"])
185 + mse_loss.update(out_criterion["mse_loss"])
186 +
187 + print(
188 + f"Test epoch {epoch}: Average losses:"
189 + f"\tLoss: {loss.avg:.3f} |"
190 + f"\tMSE loss: {mse_loss.avg:.3f} |"
191 + f"\tBpp loss: {bpp_loss.avg:.2f} |"
192 + f"\tAux loss: {aux_loss.avg:.2f}\n"
193 + )
194 +
195 + return loss.avg, bpp_loss.avg, aux_loss.avg
196 +
197 +def save_checkpoint(state, is_best, q, filename="checkpoint"):
198 + torch.save(state, filename+q+".pth.tar")
199 + if is_best:
200 + shutil.copyfile( filename+q+".pth.tar", "checkpoint_best_loss"+q+".pth.tar")
201 +
202 +
203 +def parse_args(argv):
204 + parser = argparse.ArgumentParser(description="Example training script.")
205 + parser.add_argument(
206 + "-m",
207 + "--model",
208 + default="bmshj2018-hyperprior",
209 + choices=models.keys(),
210 + help="Model architecture (default: %(default)s)",
211 + )
212 + parser.add_argument(
213 + "-d", "--dataset", type=str, required=True, help="Training dataset"
214 + )
215 + parser.add_argument(
216 + "-e",
217 + "--epochs",
218 + default=100,
219 + type=int,
220 + help="Number of epochs (default: %(default)s)",
221 + )
222 + parser.add_argument(
223 + "-lr",
224 + "--learning-rate",
225 + default=1e-4,
226 + type=float,
227 + help="Learning rate (default: %(default)s)",
228 + )
229 + parser.add_argument(
230 + "-n",
231 + "--num-workers",
232 + type=int,
233 + default=0,
234 + help="Dataloaders threads (default: %(default)s)",
235 + )
236 + parser.add_argument(
237 + "--lambda",
238 + dest="lmbda",
239 + type=float,
240 + default=1e-2,
241 + help="Bit-rate distortion parameter (default: %(default)s)",
242 + )
243 + parser.add_argument(
244 + "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
245 + )
246 + parser.add_argument(
247 + "--test-batch-size",
248 + type=int,
249 + default=64,
250 + help="Test batch size (default: %(default)s)",
251 + )
252 + parser.add_argument(
253 + "--aux-learning-rate",
254 + default=1e-3,
255 + help="Auxiliary loss learning rate (default: %(default)s)",
256 + )
257 + parser.add_argument(
258 + "--patch-size",
259 + type=int,
260 + nargs=2,
261 + default=(256, 256),
262 + help="Size of the patches to be cropped (default: %(default)s)",
263 + )
264 + parser.add_argument(
265 + "-q",
266 + "--quality",
267 + type=int,
268 + default=3,
269 + help="Quality (default: %(default)s)",
270 + )
271 + parser.add_argument("--cuda", action="store_true", help="Use cuda")
272 + parser.add_argument("--save", action="store_true", help="Save model to disk")
273 + parser.add_argument(
274 + "--seed", type=float, help="Set random seed for reproducibility"
275 + )
276 + parser.add_argument(
277 + "--clip_max_norm",
278 + default=1.0,
279 + type=float,
280 + help="gradient clipping max norm (default: %(default)s",
281 + )
282 + parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
283 + args = parser.parse_args(argv)
284 + return args
285 +
286 +class CSVLogger():
287 + def __init__(self, fieldnames, filename='log.csv'):
288 +
289 + self.filename = filename
290 + self.csv_file = open(filename, 'a')
291 +
292 + # Write model configuration at top of csv
293 + writer = csv.writer(self.csv_file)
294 +
295 + self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
296 + # self.writer.writeheader()
297 +
298 + # self.csv_file.flush()
299 +
300 + def writerow(self, row):
301 + self.writer.writerow(row)
302 + self.csv_file.flush()
303 +
304 + def close(self):
305 + self.csv_file.close()
306 +
307 +class Blur(object):
308 + def __init__(self, k, sig):
309 + self.k = k
310 + self.sig = sig
311 +
312 + def __call__(self, img):
313 + r=np.random.rand(1)
314 + if r<0.5:
315 + img=cv2.GaussianBlur(img.numpy(), (self.k,self.k), self.sig)
316 + img=torch.from_numpy(img)
317 + return img
318 +
319 +def main(argv):
320 + args = parse_args(argv)
321 +
322 + if args.seed is not None:
323 + torch.manual_seed(args.seed)
324 + random.seed(args.seed)
325 +
326 + train_transforms = transforms.Compose(
327 + [transforms.RandomCrop(args.patch_size),
328 + transforms.RandomRotation(30),
329 + transforms.RandomHorizontalFlip(),
330 + transforms.ToTensor()]
331 + )
332 + #train_transforms.transforms.append(Blur(k=3, sig=5))
333 +
334 + test_transforms = transforms.Compose(
335 + [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
336 + )
337 +
338 + train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
339 + test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
340 +
341 + device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
342 + print(torch.cuda.is_available())
343 + print(device)
344 + train_dataloader = DataLoader(
345 + train_dataset,
346 + batch_size=args.batch_size,
347 + num_workers=args.num_workers,
348 + shuffle=True,
349 + pin_memory=(device == "cuda"),
350 + )
351 +
352 + test_dataloader = DataLoader(
353 + test_dataset,
354 + batch_size=args.test_batch_size,
355 + num_workers=args.num_workers,
356 + shuffle=False,
357 + pin_memory=(device == "cuda"),
358 + )
359 +
360 + net = models[args.model](quality=args.quality, pretrained=False)
361 + net = net.to(device)
362 +
363 + #if args.cuda and torch.cuda.device_count() > 1:
364 + # net = CustomDataParallel(net)
365 +
366 + optimizer, aux_optimizer = configure_optimizers(net, args)
367 +# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20)
368 + lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 700], gamma=0.2)
369 + criterion = RateDistortionLoss(lmbda=args.lmbda)
370 +
371 + filename = "train"+str(args.quality)+".csv"
372 + csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename)
373 +
374 + last_epoch = 0
375 + if args.checkpoint: # load from previous checkpoint
376 + print("Loading", args.checkpoint)
377 + checkpoint = torch.load(args.checkpoint, map_location=device)
378 + last_epoch = checkpoint["epoch"] + 1
379 + net.load_state_dict(checkpoint["state_dict"])
380 + optimizer.load_state_dict(checkpoint["optimizer"])
381 + aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
382 + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
383 +
384 + for g in optimizer.param_groups:
385 + g['lr'] = 0.00001
386 + for g in aux_optimizer.param_groups:
387 + g['lr'] = 0.00001
388 +
389 + best_loss = float("inf")
390 + for epoch in range(last_epoch, args.epochs):
391 + start = time.time()
392 + print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
393 + train_loss, train_bpp_loss, train_aux = train_one_epoch(
394 + net,
395 + criterion,
396 + train_dataloader,
397 + optimizer,
398 + aux_optimizer,
399 + epoch,
400 + args.clip_max_norm,
401 + )
402 + loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
403 + lr_scheduler.step(loss)
404 +
405 + row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())}
406 + csv_logger.writerow(row)###
407 +
408 + is_best = loss < best_loss
409 + best_loss = min(loss, best_loss)
410 +
411 +
412 + if args.save:
413 + save_checkpoint(
414 + {
415 + "epoch": epoch,
416 + "state_dict": net.state_dict(),
417 + "loss": loss,
418 + "optimizer": optimizer.state_dict(),
419 + "aux_optimizer": aux_optimizer.state_dict(),
420 + "lr_scheduler": lr_scheduler.state_dict(),
421 + },
422 + is_best,
423 + str(args.quality)
424 + )
425 + print(f"Total TIme: {time.time() - start}")
426 + csv_logger.close()###
427 +
428 +
429 +if __name__ == "__main__":
430 + main(sys.argv[1:])
1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import math
17 +import random
18 +import shutil
19 +import sys
20 +import time
21 +
22 +import torch
23 +import torch.nn as nn
24 +import torch.optim as optim
25 +
26 +from torch.utils.data import DataLoader
27 +from torchvision import transforms
28 +
29 +from compressai.datasets import ImageFolder
30 +from compressai.zoo import models
31 +import csv
32 +import cv2
33 +import numpy as np
34 +from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
35 +
36 +class RateDistortionLoss(nn.Module):
37 + """Custom rate distortion loss with a Lagrangian parameter."""
38 +
39 + def __init__(self, lmbda=1e-2):
40 + super().__init__()
41 + self.mse = ms_ssim
42 + self.lmbda = lmbda
43 +
44 + def forward(self, output, target):
45 + N, _, H, W = target.size()
46 + out = {}
47 + num_pixels = N * H * W
48 +
49 + out["bpp_loss"] = sum(
50 + (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
51 + for likelihoods in output["likelihoods"].values()
52 + )
53 + out["mse_loss"] = 1 - self.mse(output["x_hat"], target, data_range=1.)
54 + out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
55 +
56 + return out
57 +
58 +
59 +class AverageMeter:
60 + """Compute running average."""
61 +
62 + def __init__(self):
63 + self.val = 0
64 + self.avg = 0
65 + self.sum = 0
66 + self.count = 0
67 +
68 + def update(self, val, n=1):
69 + self.val = val
70 + self.sum += val * n
71 + self.count += n
72 + self.avg = self.sum / self.count
73 +
74 +
75 +class CustomDataParallel(nn.DataParallel):
76 + """Custom DataParallel to access the module methods."""
77 +
78 + def __getattr__(self, key):
79 + try:
80 + return super().__getattr__(key)
81 + except AttributeError:
82 + return getattr(self.module, key)
83 +
84 +
def configure_optimizers(net, args):
    """Build the main and auxiliary Adam optimizers for *net*.

    Parameters whose names end in ``.quantiles`` belong to the entropy
    bottleneck's auxiliary loss; every other trainable parameter goes to
    the main optimizer. Returns ``(optimizer, aux_optimizer)``.
    """
    params_dict = dict(net.named_parameters())

    aux_parameters = {
        name
        for name, param in params_dict.items()
        if name.endswith(".quantiles") and param.requires_grad
    }
    parameters = {
        name
        for name, param in params_dict.items()
        if not name.endswith(".quantiles") and param.requires_grad
    }

    # Sanity check: the two groups must partition the full parameter set.
    assert not (parameters & aux_parameters)
    assert len(parameters | aux_parameters) == len(params_dict)

    optimizer = optim.Adam(
        (params_dict[name] for name in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[name] for name in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
117 +
118 +
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    """Train *model* for one epoch.

    Runs the main loss backward pass (with optional gradient clipping),
    then the auxiliary entropy-bottleneck loss, per batch. Prints progress
    every 10 batches.

    Returns:
        Tuple ``(avg_loss, avg_bpp_loss, avg_aux_loss)`` — detached tensors,
        so ``.item()`` still works for callers.
    """
    model.train()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    a_aux_loss = AverageMeter()

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)

        # BUG FIX: detach before accumulating. Storing graph-attached tensors
        # in the meters kept every iteration's autograd graph alive for the
        # whole epoch, leaking memory.
        bpp_loss.update(out_criterion["bpp_loss"].detach())
        loss.update(out_criterion["loss"].detach())
        mse_loss.update(out_criterion["mse_loss"].detach())

        out_criterion["loss"].backward()

        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        a_aux_loss.update(aux_loss.detach())
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )
    return loss.avg, bpp_loss.avg, a_aux_loss.avg
166 +
167 +
def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate *model* on the test set under ``torch.no_grad``.

    Returns the tuple ``(avg_loss, avg_bpp_loss, avg_aux_loss)``.
    """
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    aux_loss = AverageMeter()

    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            out_net = model(batch)
            out_criterion = criterion(out_net, batch)

            # Accumulate per-batch metrics; no detach needed under no_grad.
            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            loss.update(out_criterion["loss"])
            mse_loss.update(out_criterion["mse_loss"])

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMSE loss: {mse_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )

    return loss.avg, bpp_loss.avg, aux_loss.avg
197 +
def save_checkpoint(state, is_best, q, filename="checkpoint_msssim"):
    """Save *state* to ``<filename><q>.pth.tar``; mirror it to the
    best-loss checkpoint file when *is_best* is true."""
    target = filename + q + ".pth.tar"
    torch.save(state, target)
    if is_best:
        shutil.copyfile(target, "checkpoint_best_loss_msssim" + q + ".pth.tar")
202 +
203 +
def parse_args(argv):
    """Parse command-line options for the training script and return the namespace."""
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-hyperprior",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        # BUG FIX: without type=float a user-supplied value stayed a string
        # and later broke optim.Adam(lr=...).
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        type=int,
        default=3,
        help="Quality (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument("--save", action="store_true", help="Save model to disk")
    parser.add_argument(
        "--seed", type=float, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        # Fixed unbalanced parenthesis in the help text.
        help="gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args
286 +
class CSVLogger():
    """Append-mode CSV logger for per-epoch training metrics.

    Rows are dicts keyed by *fieldnames*; each row is flushed to disk
    immediately so progress survives a crash.
    """

    def __init__(self, fieldnames, filename='log.csv'):
        self.filename = filename
        # newline='' is required by the csv module so it controls line
        # endings itself (avoids blank lines on \r\n platforms).
        self.csv_file = open(filename, 'a', newline='')
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)

    def writerow(self, row):
        """Write one metrics row (dict keyed by the fieldnames) and flush."""
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        """Close the underlying file; the logger is unusable afterwards."""
        self.csv_file.close()
307 +
class Blur(object):
    """Transform that applies a Gaussian blur with probability 0.5.

    Kernel size *k* and sigma *sig* are fixed at construction time.
    # NOTE(review): cv2.GaussianBlur is applied to the raw tensor-as-numpy
    # array; confirm the expected layout matches what cv2 assumes.
    """

    def __init__(self, k, sig):
        self.k = k
        self.sig = sig

    def __call__(self, img):
        # Coin flip: blur roughly half of the samples.
        if np.random.rand(1) < 0.5:
            blurred = cv2.GaussianBlur(img.numpy(), (self.k, self.k), self.sig)
            img = torch.from_numpy(blurred)
        return img
319 +
def main(argv):
    """Entry point: parse args, build data/model/optimizers, then run the
    train/test loop with CSV logging and checkpointing (MS-SSIM variant)."""
    args = parse_args(argv)

    if args.seed is not None:
        # NOTE(review): --seed is parsed as float but torch.manual_seed
        # expects an int-like value — confirm intended.
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Training-time augmentation: random crop/rotation/flip, then to-tensor.
    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()]
    )
    #train_transforms.transforms.append(Blur(k=3, sig=5))

    # Deterministic center crop for evaluation.
    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(torch.cuda.is_available())
    print(device)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # Build the model from the compressai zoo at the requested quality level.
    net = models[args.model](quality=args.quality, pretrained=False)
    net = net.to(device)

    #if args.cuda and torch.cuda.device_count() > 1:
    #    net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    # Reduce the LR when the test loss stops improving for 20 epochs.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # One CSV log file per quality level.
    filename = "train_msssim"+str(args.quality)+".csv"
    csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint and resume
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
# for g in optimizer.param_groups:
#     g['lr'] = 0.0001
# for g in aux_optimizer.param_groups:
#     g['lr'] = 0.0001
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        start = time.time()
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_loss, train_bpp_loss, train_aux = train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
        # Plateau scheduler steps on the test loss.
        lr_scheduler.step(loss)

        # Metrics are tensors; .item() extracts the Python scalars for the CSV.
        row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())}
        csv_logger.writerow(row)###

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)


        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                str(args.quality)
            )
        print(f"Total TIme: {time.time() - start}")
    csv_logger.close()###


if __name__ == "__main__":
    main(sys.argv[1:])
1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import math
17 +import random
18 +import shutil
19 +import sys
20 +import time
21 +
22 +import torch
23 +import torch.nn as nn
24 +import torch.optim as optim
25 +
26 +from torch.utils.data import DataLoader
27 +from torchvision import transforms
28 +
29 +from compressai.datasets import ImageFolder
30 +from compressai.zoo import models
31 +import csv
32 +import cv2
33 +import numpy as np
34 +from compressai.transforms.functional import (
35 + rgb2ycbcr,
36 + ycbcr2rgb,
37 + yuv_420_to_444,
38 + yuv_444_to_420,
39 +)
40 +
class RateDistortionLoss(nn.Module):
    """Rate-distortion loss with channel-weighted MSE distortion.

    The MSE term weights the first channel chunk (Y) four times as heavily
    as each of the chroma chunks (Cb, Cr), i.e. a 4:1:1 emphasis.
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda  # Lagrangian trade-off between rate and distortion

    def forward(self, output, target):
        batch, _, height, width = target.size()
        num_pixels = batch * height * width

        out = {}
        # Rate term: average bits per pixel implied by the latent likelihoods.
        out["bpp_loss"] = sum(
            likelihoods.log().sum() / (-math.log(2) * num_pixels)
            for likelihoods in output["likelihoods"].values()
        )

        # Split the channel dimension into (Y, Cb, Cr) and weight luma 4x.
        rec_y, rec_cb, rec_cr = output["x_hat"].chunk(3, -3)
        tgt_y, tgt_cb, tgt_cr = target.chunk(3, -3)
        weighted = (
            4 * self.mse(rec_y, tgt_y)
            + self.mse(rec_cb, tgt_cb)
            + self.mse(rec_cr, tgt_cr)
        )
        out["mse_loss"] = weighted / 6
        out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]

        return out
68 +
69 +
70 +
class AverageMeter:
    """Track the latest value plus a running sum/count and expose the mean."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold *n* samples with value *val* into the running statistics."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
85 +
86 +
class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods.

    ``nn.DataParallel`` hides the wrapped model's attributes behind
    ``self.module``; this subclass forwards any attribute lookup that the
    wrapper itself cannot resolve to the underlying module, so calls like
    ``model.aux_loss()`` keep working after wrapping.
    """

    def __getattr__(self, key):
        try:
            # First let nn.Module/DataParallel resolve the name normally.
            return super().__getattr__(key)
        except AttributeError:
            # Fall back to the wrapped model's attribute of the same name.
            return getattr(self.module, key)
95 +
96 +
def configure_optimizers(net, args):
    """Build the main and auxiliary Adam optimizers for *net*.

    Parameters whose names end in ``.quantiles`` belong to the entropy
    bottleneck's auxiliary loss; every other trainable parameter goes to
    the main optimizer. Returns ``(optimizer, aux_optimizer)``.
    """
    params_dict = dict(net.named_parameters())

    aux_parameters = {
        name
        for name, param in params_dict.items()
        if name.endswith(".quantiles") and param.requires_grad
    }
    parameters = {
        name
        for name, param in params_dict.items()
        if not name.endswith(".quantiles") and param.requires_grad
    }

    # Sanity check: the two groups must partition the full parameter set.
    assert not (parameters & aux_parameters)
    assert len(parameters | aux_parameters) == len(params_dict)

    optimizer = optim.Adam(
        (params_dict[name] for name in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[name] for name in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
129 +
130 +
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    """Train *model* for one epoch.

    Runs the main loss backward pass (with optional gradient clipping),
    then the auxiliary entropy-bottleneck loss, per batch. Prints progress
    every 10 batches.

    Returns:
        Tuple ``(avg_loss, avg_bpp_loss, avg_aux_loss)`` — detached tensors,
        so ``.item()`` still works for callers.
    """
    model.train()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    a_aux_loss = AverageMeter()

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)

        # BUG FIX: detach before accumulating. Storing graph-attached tensors
        # in the meters kept every iteration's autograd graph alive for the
        # whole epoch, leaking memory.
        bpp_loss.update(out_criterion["bpp_loss"].detach())
        loss.update(out_criterion["loss"].detach())
        mse_loss.update(out_criterion["mse_loss"].detach())

        out_criterion["loss"].backward()

        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        a_aux_loss.update(aux_loss.detach())
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )
    return loss.avg, bpp_loss.avg, a_aux_loss.avg
178 +
179 +
def test_epoch(epoch, test_dataloader, model, criterion):
    """Evaluate *model* on the test set under ``torch.no_grad``.

    Returns the tuple ``(avg_loss, avg_bpp_loss, avg_aux_loss)``.
    """
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    aux_loss = AverageMeter()

    with torch.no_grad():
        for batch in test_dataloader:
            batch = batch.to(device)
            out_net = model(batch)
            out_criterion = criterion(out_net, batch)

            # Accumulate per-batch metrics; no detach needed under no_grad.
            aux_loss.update(model.aux_loss())
            bpp_loss.update(out_criterion["bpp_loss"])
            loss.update(out_criterion["loss"])
            mse_loss.update(out_criterion["mse_loss"])

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMSE loss: {mse_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )

    return loss.avg, bpp_loss.avg, aux_loss.avg
209 +
def save_checkpoint(state, is_best, q, filename="checkpoint"):
    """Save *state* to ``<filename><q>.pth.tar``; mirror it to the
    best-loss checkpoint file when *is_best* is true."""
    target = filename + q + ".pth.tar"
    torch.save(state, target)
    if is_best:
        shutil.copyfile(target, "checkpoint_best_loss" + q + ".pth.tar")
214 +
215 +
def parse_args(argv):
    """Parse command-line options for the training script and return the namespace."""
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-hyperprior",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        # BUG FIX: without type=float a user-supplied value stayed a string
        # and later broke optim.Adam(lr=...).
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        type=int,
        default=3,
        help="Quality (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument("--save", action="store_true", help="Save model to disk")
    parser.add_argument(
        "--seed", type=float, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        # Fixed unbalanced parenthesis in the help text.
        help="gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args
298 +
class CSVLogger():
    """Append-mode CSV logger for per-epoch training metrics.

    Rows are dicts keyed by *fieldnames*; each row is flushed to disk
    immediately so progress survives a crash.
    """

    def __init__(self, fieldnames, filename='log.csv'):
        self.filename = filename
        # newline='' is required by the csv module so it controls line
        # endings itself (avoids blank lines on \r\n platforms).
        self.csv_file = open(filename, 'a', newline='')
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)

    def writerow(self, row):
        """Write one metrics row (dict keyed by the fieldnames) and flush."""
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        """Close the underlying file; the logger is unusable afterwards."""
        self.csv_file.close()
319 +
class Blur(object):
    """Transform that applies a Gaussian blur with probability 0.5.

    Kernel size *k* and sigma *sig* are fixed at construction time.
    # NOTE(review): cv2.GaussianBlur is applied to the raw tensor-as-numpy
    # array; confirm the expected layout matches what cv2 assumes.
    """

    def __init__(self, k, sig):
        self.k = k
        self.sig = sig

    def __call__(self, img):
        # Coin flip: blur roughly half of the samples.
        if np.random.rand(1) < 0.5:
            blurred = cv2.GaussianBlur(img.numpy(), (self.k, self.k), self.sig)
            img = torch.from_numpy(blurred)
        return img
331 +
class RGB2YCbCr(object):
    """Transform wrapper around compressai's ``rgb2ycbcr`` conversion."""

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W) in RGB.
        Returns:
            Tensor: The same image converted to YCbCr.
        """
        return rgb2ycbcr(img)
344 +
def main(argv):
    """Entry point: parse args, build the RGB->YCbCr data pipeline, model
    and optimizers, then run the train/test loop with CSV logging and
    checkpointing (YCbCr weighted-MSE variant)."""
    args = parse_args(argv)

    if args.seed is not None:
        # NOTE(review): --seed is parsed as float but torch.manual_seed
        # expects an int-like value — confirm intended.
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),  # crop to the training patch size
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()]  # convert the PIL/numpy image to a torch tensor
    )
    # Append the RGB->YCbCr conversion so training happens in YCbCr space.
    train_transforms.transforms.append(RGB2YCbCr())

# print(train_transforms.shape)
# train_transforms=rgb2ycbcr(train_transforms)
    #train_transforms.transforms.append(Blur(k=3, sig=5))

    # NOTE(review): test_transforms does NOT apply RGB2YCbCr, so test images
    # stay RGB while training images are YCbCr — confirm this is intended.
    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    print(torch.cuda.is_available())
    print(device)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # Build the model from the compressai zoo at the requested quality level.
    net = models[args.model](quality=args.quality, pretrained=False)
    net = net.to(device)

    #if args.cuda and torch.cuda.device_count() > 1:
    #    net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=20)
    # Step-decay schedule: multiply the LR by 0.2 at epochs 400 and 700.
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 700], gamma=0.2)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # One CSV log file per quality level.
    filename = "train"+str(args.quality)+".csv"
    csv_logger = CSVLogger(fieldnames=['epoch', 'train_loss', 'train_bpp_loss','train_aux', 'test_loss', 'test_bpp_loss', 'test_aux'], filename=filename)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint and resume
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
# for g in optimizer.param_groups:
#     g['lr'] = 0.0001
# for g in aux_optimizer.param_groups:
#     g['lr'] = 0.0001
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        start = time.time()
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_loss, train_bpp_loss, train_aux = train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
        # MultiStepLR steps by epoch count (no metric argument).
        lr_scheduler.step()

        # Metrics are tensors; .item() extracts the Python scalars for the CSV.
        row = {'epoch': str(epoch), 'train_loss': str(train_loss.item()),'train_bpp_loss': str(train_bpp_loss.item()),'train_aux': str(train_aux.item()), 'test_loss': str(loss.item()), 'test_bpp_loss': str(bpp_loss.item()), 'test_aux': str(aux.item())}
        csv_logger.writerow(row)###

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)


        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
                str(args.quality)
            )
        print(f"Total TIme: {time.time() - start}")
    csv_logger.close()###


if __name__ == "__main__":
    main(sys.argv[1:])