김재형

Add PSNR and SSIM evaluation code for CARN training

@@ -2,7 +2,7 @@ import os
 import random
 import numpy as np
 import scipy.misc as misc
-import skimage.measure as measure
+import skimage.metrics as metrics
 from tensorboardX import SummaryWriter
 import torch
 import torch.nn as nn
@@ -13,39 +13,39 @@ from dataset import TrainDataset, TestDataset
 class Solver():
     def __init__(self, model, cfg):
         if cfg.scale > 0:
-            self.refiner = model(scale=cfg.scale, 
+            self.refiner = model(scale=cfg.scale,
                                  group=cfg.group)
         else:
-            self.refiner = model(multi_scale=True, 
+            self.refiner = model(multi_scale=True,
                                  group=cfg.group)
-        
-        if cfg.loss_fn in ["MSE"]: 
+
+        if cfg.loss_fn in ["MSE"]:
             self.loss_fn = nn.MSELoss()
-        elif cfg.loss_fn in ["L1"]: 
+        elif cfg.loss_fn in ["L1"]:
             self.loss_fn = nn.L1Loss()
         elif cfg.loss_fn in ["SmoothL1"]:
             self.loss_fn = nn.SmoothL1Loss()
 
         self.optim = optim.Adam(
-            filter(lambda p: p.requires_grad, self.refiner.parameters()), 
+            filter(lambda p: p.requires_grad, self.refiner.parameters()),
             cfg.lr)
-        
-        self.train_data = TrainDataset(cfg.train_data_path, 
-                                       scale=cfg.scale, 
+
+        self.train_data = TrainDataset(cfg.train_data_path,
+                                       scale=cfg.scale,
                                        size=cfg.patch_size)
         self.train_loader = DataLoader(self.train_data,
                                        batch_size=cfg.batch_size,
                                        num_workers=1,
                                        shuffle=True, drop_last=True)
-        
-        
+
+
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.refiner = self.refiner.to(self.device)
         self.loss_fn = self.loss_fn
 
         self.cfg = cfg
         self.step = 0
-        
+
         self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
         if cfg.verbose:
             num_params = 0
@@ -57,9 +57,9 @@ class Solver():
 
     def fit(self):
         cfg = self.cfg
-        refiner = nn.DataParallel(self.refiner, 
+        refiner = nn.DataParallel(self.refiner,
                                   device_ids=range(cfg.num_gpu))
-        
+
         learning_rate = cfg.lr
         while True:
             for inputs in self.train_loader:
@@ -73,13 +73,13 @@ class Solver():
                     # i know this is stupid but just temporary
                     scale = random.randint(2, 4)
                     hr, lr = inputs[scale-2][0], inputs[scale-2][1]
-                
+
                 hr = hr.to(self.device)
                 lr = lr.to(self.device)
-                
+
                 sr = refiner(lr, scale)
                 loss = self.loss_fn(sr, hr)
-                
+
                 self.optim.zero_grad()
                 loss.backward()
                 nn.utils.clip_grad_norm(self.refiner.parameters(), cfg.clip)
@@ -88,18 +88,19 @@ class Solver():
                 learning_rate = self.decay_learning_rate()
                 for param_group in self.optim.param_groups:
                     param_group["lr"] = learning_rate
-                
+
                 self.step += 1
                 if cfg.verbose and self.step % cfg.print_interval == 0:
                     if cfg.scale > 0:
-                        psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
-                        self.writer.add_scalar("Urban100", psnr, self.step)
-                    else: 
+                        psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
+                        self.writer.add_scalar("PSNR", psnr, self.step)
+                        self.writer.add_scalar("SSIM", ssim, self.step)
+                    else:
                         psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
                         self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
                         self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
                         self.writer.add_scalar("Urban100_4x", psnr[2], self.step)
-                
+
                     self.save(cfg.ckpt_dir, cfg.ckpt_name)
 
                 if self.step > cfg.max_steps: break
@@ -107,8 +108,9 @@ class Solver():
     def evaluate(self, test_data_dir, scale=2, num_step=0):
         cfg = self.cfg
         mean_psnr = 0
+        mean_ssim = 0
         self.refiner.eval()
-        
+
         test_data = TestDataset(test_data_dir, scale=scale)
         test_loader = DataLoader(test_data,
                                  batch_size=1,
@@ -131,13 +133,13 @@ class Solver():
             lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
             lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
             lr_patch = lr_patch.to(self.device)
-            
+
             # run refine process in here!
             sr = self.refiner(lr_patch, scale).data
-            
+
             h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
             w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
-            
+
             # merge splited patch images
             result = torch.FloatTensor(3, h, w).to(self.device)
             result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
@@ -148,16 +150,17 @@ class Solver():
 
             hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
             sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
-            
-            # evaluate PSNR
+
+            # evaluate PSNR and SSIM
             # this evaluation is different to MATLAB version
-            # we evaluate PSNR in RGB channel not Y in YCbCR 
+            # we evaluate PSNR in RGB channel not Y in YCbCR
             bnd = scale
-            im1 = hr[bnd:-bnd, bnd:-bnd]
-            im2 = sr[bnd:-bnd, bnd:-bnd]
+            im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
+            im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
             mean_psnr += psnr(im1, im2) / len(test_data)
+            mean_ssim += ssim(im1, im2) / len(test_data)
 
-        return mean_psnr
+        return mean_psnr, mean_ssim
 
     def load(self, path):
         self.refiner.load_state_dict(torch.load(path))
@@ -177,14 +180,15 @@ class Solver():
         lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
         return lr
 
+def im2double(im):
+    min_val, max_val = 0, 255
+    out = (im.astype(np.float64)-min_val) / (max_val-min_val)
+    return out
 
 def psnr(im1, im2):
-    def im2double(im):
-        min_val, max_val = 0, 255
-        out = (im.astype(np.float64)-min_val) / (max_val-min_val)
-        return out
-
-    im1 = im2double(im1)
-    im2 = im2double(im2)
-    psnr = measure.compare_psnr(im1, im2, data_range=1)
+    psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
     return psnr
+
+def ssim(im1, im2):
+    ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
+    return ssim
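For reference, a minimal sketch of how the new metric helpers can be exercised on their own. It assumes scikit-image >= 0.16, where skimage.metrics replaced the deprecated skimage.measure.compare_* functions; the dummy arrays and their shapes are illustrative only, not part of the commit:

import numpy as np
import skimage.metrics as metrics

def im2double(im):
    # map uint8 [0, 255] to float64 [0, 1], mirroring the patched solver
    min_val, max_val = 0, 255
    return (im.astype(np.float64) - min_val) / (max_val - min_val)

# hypothetical HR/SR crops; HWC uint8 arrays stand in for the solver's images
hr = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
sr = np.clip(hr.astype(np.int16) + np.random.randint(-5, 6, hr.shape),
             0, 255).astype(np.uint8)

im1, im2 = im2double(hr), im2double(sr)
print(metrics.peak_signal_noise_ratio(im1, im2, data_range=1))
# multichannel=True averages SSIM over the trailing channel axis
# (scikit-image >= 0.19 replaces this keyword with channel_axis=-1)
print(metrics.structural_similarity(im1, im2, data_range=1, multichannel=True))

Because evaluate() divides each per-image score by len(test_data) as it accumulates, it returns dataset-wide mean PSNR/SSIM, which fit() logs to TensorBoard under the "PSNR" and "SSIM" scalar tags.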