hanbin9775
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
2 import random 2 import random
3 import numpy as np 3 import numpy as np
4 import scipy.misc as misc 4 import scipy.misc as misc
5 -import skimage.measure as measure 5 +import skimage.metrics as metrics
6 from tensorboardX import SummaryWriter 6 from tensorboardX import SummaryWriter
7 import torch 7 import torch
8 import torch.nn as nn 8 import torch.nn as nn
...@@ -13,39 +13,39 @@ from dataset import TrainDataset, TestDataset ...@@ -13,39 +13,39 @@ from dataset import TrainDataset, TestDataset
13 class Solver(): 13 class Solver():
14 def __init__(self, model, cfg): 14 def __init__(self, model, cfg):
15 if cfg.scale > 0: 15 if cfg.scale > 0:
16 - self.refiner = model(scale=cfg.scale, 16 + self.refiner = model(scale=cfg.scale,
17 group=cfg.group) 17 group=cfg.group)
18 else: 18 else:
19 - self.refiner = model(multi_scale=True, 19 + self.refiner = model(multi_scale=True,
20 group=cfg.group) 20 group=cfg.group)
21 - 21 +
22 - if cfg.loss_fn in ["MSE"]: 22 + if cfg.loss_fn in ["MSE"]:
23 self.loss_fn = nn.MSELoss() 23 self.loss_fn = nn.MSELoss()
24 - elif cfg.loss_fn in ["L1"]: 24 + elif cfg.loss_fn in ["L1"]:
25 self.loss_fn = nn.L1Loss() 25 self.loss_fn = nn.L1Loss()
26 elif cfg.loss_fn in ["SmoothL1"]: 26 elif cfg.loss_fn in ["SmoothL1"]:
27 self.loss_fn = nn.SmoothL1Loss() 27 self.loss_fn = nn.SmoothL1Loss()
28 28
29 self.optim = optim.Adam( 29 self.optim = optim.Adam(
30 - filter(lambda p: p.requires_grad, self.refiner.parameters()), 30 + filter(lambda p: p.requires_grad, self.refiner.parameters()),
31 cfg.lr) 31 cfg.lr)
32 - 32 +
33 - self.train_data = TrainDataset(cfg.train_data_path, 33 + self.train_data = TrainDataset(cfg.train_data_path,
34 - scale=cfg.scale, 34 + scale=cfg.scale,
35 size=cfg.patch_size) 35 size=cfg.patch_size)
36 self.train_loader = DataLoader(self.train_data, 36 self.train_loader = DataLoader(self.train_data,
37 batch_size=cfg.batch_size, 37 batch_size=cfg.batch_size,
38 num_workers=1, 38 num_workers=1,
39 shuffle=True, drop_last=True) 39 shuffle=True, drop_last=True)
40 - 40 +
41 - 41 +
42 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43 self.refiner = self.refiner.to(self.device) 43 self.refiner = self.refiner.to(self.device)
44 self.loss_fn = self.loss_fn 44 self.loss_fn = self.loss_fn
45 45
46 self.cfg = cfg 46 self.cfg = cfg
47 self.step = 0 47 self.step = 0
48 - 48 +
49 self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name)) 49 self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
50 if cfg.verbose: 50 if cfg.verbose:
51 num_params = 0 51 num_params = 0
...@@ -57,9 +57,9 @@ class Solver(): ...@@ -57,9 +57,9 @@ class Solver():
57 57
58 def fit(self): 58 def fit(self):
59 cfg = self.cfg 59 cfg = self.cfg
60 - refiner = nn.DataParallel(self.refiner, 60 + refiner = nn.DataParallel(self.refiner,
61 device_ids=range(cfg.num_gpu)) 61 device_ids=range(cfg.num_gpu))
62 - 62 +
63 learning_rate = cfg.lr 63 learning_rate = cfg.lr
64 while True: 64 while True:
65 for inputs in self.train_loader: 65 for inputs in self.train_loader:
...@@ -73,13 +73,13 @@ class Solver(): ...@@ -73,13 +73,13 @@ class Solver():
73 # i know this is stupid but just temporary 73 # i know this is stupid but just temporary
74 scale = random.randint(2, 4) 74 scale = random.randint(2, 4)
75 hr, lr = inputs[scale-2][0], inputs[scale-2][1] 75 hr, lr = inputs[scale-2][0], inputs[scale-2][1]
76 - 76 +
77 hr = hr.to(self.device) 77 hr = hr.to(self.device)
78 lr = lr.to(self.device) 78 lr = lr.to(self.device)
79 - 79 +
80 sr = refiner(lr, scale) 80 sr = refiner(lr, scale)
81 loss = self.loss_fn(sr, hr) 81 loss = self.loss_fn(sr, hr)
82 - 82 +
83 self.optim.zero_grad() 83 self.optim.zero_grad()
84 loss.backward() 84 loss.backward()
85 nn.utils.clip_grad_norm(self.refiner.parameters(), cfg.clip) 85 nn.utils.clip_grad_norm(self.refiner.parameters(), cfg.clip)
...@@ -88,18 +88,19 @@ class Solver(): ...@@ -88,18 +88,19 @@ class Solver():
88 learning_rate = self.decay_learning_rate() 88 learning_rate = self.decay_learning_rate()
89 for param_group in self.optim.param_groups: 89 for param_group in self.optim.param_groups:
90 param_group["lr"] = learning_rate 90 param_group["lr"] = learning_rate
91 - 91 +
92 self.step += 1 92 self.step += 1
93 if cfg.verbose and self.step % cfg.print_interval == 0: 93 if cfg.verbose and self.step % cfg.print_interval == 0:
94 if cfg.scale > 0: 94 if cfg.scale > 0:
95 - psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step) 95 + psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
96 - self.writer.add_scalar("Urban100", psnr, self.step) 96 + self.writer.add_scalar("PSNR", psnr, self.step)
97 - else: 97 + self.writer.add_scalar("SSIM", ssim, self.step)
98 + else:
98 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)] 99 psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
99 self.writer.add_scalar("Urban100_2x", psnr[0], self.step) 100 self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
100 self.writer.add_scalar("Urban100_3x", psnr[1], self.step) 101 self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
101 self.writer.add_scalar("Urban100_4x", psnr[2], self.step) 102 self.writer.add_scalar("Urban100_4x", psnr[2], self.step)
102 - 103 +
103 self.save(cfg.ckpt_dir, cfg.ckpt_name) 104 self.save(cfg.ckpt_dir, cfg.ckpt_name)
104 105
105 if self.step > cfg.max_steps: break 106 if self.step > cfg.max_steps: break
...@@ -107,8 +108,9 @@ class Solver(): ...@@ -107,8 +108,9 @@ class Solver():
107 def evaluate(self, test_data_dir, scale=2, num_step=0): 108 def evaluate(self, test_data_dir, scale=2, num_step=0):
108 cfg = self.cfg 109 cfg = self.cfg
109 mean_psnr = 0 110 mean_psnr = 0
111 + mean_ssim = 0
110 self.refiner.eval() 112 self.refiner.eval()
111 - 113 +
112 test_data = TestDataset(test_data_dir, scale=scale) 114 test_data = TestDataset(test_data_dir, scale=scale)
113 test_loader = DataLoader(test_data, 115 test_loader = DataLoader(test_data,
114 batch_size=1, 116 batch_size=1,
...@@ -131,13 +133,13 @@ class Solver(): ...@@ -131,13 +133,13 @@ class Solver():
131 lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop]) 133 lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
132 lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w]) 134 lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
133 lr_patch = lr_patch.to(self.device) 135 lr_patch = lr_patch.to(self.device)
134 - 136 +
135 # run refine process in here! 137 # run refine process in here!
136 sr = self.refiner(lr_patch, scale).data 138 sr = self.refiner(lr_patch, scale).data
137 - 139 +
138 h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale 140 h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
139 w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale 141 w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
140 - 142 +
141 # merge splited patch images 143 # merge splited patch images
142 result = torch.FloatTensor(3, h, w).to(self.device) 144 result = torch.FloatTensor(3, h, w).to(self.device)
143 result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half]) 145 result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
...@@ -148,16 +150,17 @@ class Solver(): ...@@ -148,16 +150,17 @@ class Solver():
148 150
149 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 151 hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
150 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() 152 sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
151 - 153 +
152 - # evaluate PSNR 154 + # evaluate PSNR and SSIM
153 # this evaluation is different to MATLAB version 155 # this evaluation is different to MATLAB version
154 - # we evaluate PSNR in RGB channel not Y in YCbCR 156 + # we evaluate PSNR in RGB channel not Y in YCbCR
155 bnd = scale 157 bnd = scale
156 - im1 = hr[bnd:-bnd, bnd:-bnd] 158 + im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
157 - im2 = sr[bnd:-bnd, bnd:-bnd] 159 + im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
158 mean_psnr += psnr(im1, im2) / len(test_data) 160 mean_psnr += psnr(im1, im2) / len(test_data)
161 + mean_ssim += ssim(im1, im2) / len(test_data)
159 162
160 - return mean_psnr 163 + return mean_psnr, mean_ssim
161 164
162 def load(self, path): 165 def load(self, path):
163 self.refiner.load_state_dict(torch.load(path)) 166 self.refiner.load_state_dict(torch.load(path))
...@@ -177,14 +180,15 @@ class Solver(): ...@@ -177,14 +180,15 @@ class Solver():
177 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay)) 180 lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
178 return lr 181 return lr
179 182
183 +def im2double(im):
184 + min_val, max_val = 0, 255
185 + out = (im.astype(np.float64)-min_val) / (max_val-min_val)
186 + return out
180 187
181 def psnr(im1, im2): 188 def psnr(im1, im2):
182 - def im2double(im): 189 + psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
183 - min_val, max_val = 0, 255
184 - out = (im.astype(np.float64)-min_val) / (max_val-min_val)
185 - return out
186 -
187 - im1 = im2double(im1)
188 - im2 = im2double(im2)
189 - psnr = measure.compare_psnr(im1, im2, data_range=1)
190 return psnr 190 return psnr
191 +
192 +def ssim(im1, im2):
193 + ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
194 + return ssim
......
...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): ...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg):
76 76
77 def main(cfg): 77 def main(cfg):
78 module = importlib.import_module("model.{}".format(cfg.model)) 78 module = importlib.import_module("model.{}".format(cfg.model))
79 - net = module.Net(multi_scale=True, 79 + net = module.Net(multi_scale=False,
80 + scale=cfg.scale,
80 group=cfg.group) 81 group=cfg.group)
81 print(json.dumps(vars(cfg), indent=4, sort_keys=True)) 82 print(json.dumps(vars(cfg), indent=4, sort_keys=True))
82 83
83 state_dict = torch.load(cfg.ckpt_path) 84 state_dict = torch.load(cfg.ckpt_path)
85 + # print(state_dict.keys())
84 new_state_dict = OrderedDict() 86 new_state_dict = OrderedDict()
85 for k, v in state_dict.items(): 87 for k, v in state_dict.items():
86 name = k 88 name = k
...@@ -88,12 +90,14 @@ def main(cfg): ...@@ -88,12 +90,14 @@ def main(cfg):
88 new_state_dict[name] = v 90 new_state_dict[name] = v
89 91
90 net.load_state_dict(new_state_dict) 92 net.load_state_dict(new_state_dict)
93 + net.eval()
91 94
92 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 net = net.to(device) 96 net = net.to(device)
94 97
95 dataset = TestDataset(cfg.test_data_dir, cfg.scale) 98 dataset = TestDataset(cfg.test_data_dir, cfg.scale)
96 - sample(net, device, dataset, cfg) 99 + with torch.no_grad():
100 + sample(net, device, dataset, cfg)
97 101
98 102
99 if __name__ == "__main__": 103 if __name__ == "__main__":
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -2,10 +2,22 @@ ...@@ -2,10 +2,22 @@
2 "cells": [ 2 "cells": [
3 { 3 {
4 "cell_type": "code", 4 "cell_type": "code",
5 - "execution_count": 15, 5 + "execution_count": 1,
6 "id": "automotive-circus", 6 "id": "automotive-circus",
7 "metadata": {}, 7 "metadata": {},
8 - "outputs": [], 8 + "outputs": [
9 + {
10 + "output_type": "error",
11 + "ename": "ModuleNotFoundError",
12 + "evalue": "No module named 'cv2'",
13 + "traceback": [
14 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
15 + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16 + "\u001b[1;32m<ipython-input-1-03d1a01a87c6>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mgt_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_1080p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
17 + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
18 + ]
19 + }
20 + ],
9 "source": [ 21 "source": [
10 "from glob import glob\n", 22 "from glob import glob\n",
11 "import cv2\n", 23 "import cv2\n",
...@@ -69,4 +81,4 @@ ...@@ -69,4 +81,4 @@
69 }, 81 },
70 "nbformat": 4, 82 "nbformat": 4,
71 "nbformat_minor": 5 83 "nbformat_minor": 5
72 -} 84 +}
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -5,7 +5,19 @@ ...@@ -5,7 +5,19 @@
5 "execution_count": 1, 5 "execution_count": 1,
6 "id": "ahead-paste", 6 "id": "ahead-paste",
7 "metadata": {}, 7 "metadata": {},
8 - "outputs": [], 8 + "outputs": [
9 + {
10 + "output_type": "error",
11 + "ename": "ModuleNotFoundError",
12 + "evalue": "No module named 'cv2'",
13 + "traceback": [
14 + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
15 + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
16 + "\u001b[1;32m<ipython-input-1-ff55b1ddb4f1>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mimages\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_540p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
17 + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
18 + ]
19 + }
20 + ],
9 "source": [ 21 "source": [
10 "from glob import glob\n", 22 "from glob import glob\n",
11 "import cv2\n", 23 "import cv2\n",
...@@ -79,4 +91,4 @@ ...@@ -79,4 +91,4 @@
79 }, 91 },
80 "nbformat": 4, 92 "nbformat": 4,
81 "nbformat_minor": 5 93 "nbformat_minor": 5
82 -} 94 +}
...\ No newline at end of file ...\ No newline at end of file
......
1 +{
2 + "cells": [
3 + {
4 + "cell_type": "code",
5 + "execution_count": 1,
6 + "id": "ahead-paste",
7 + "metadata": {},
8 + "outputs": [],
9 + "source": [
10 + "from glob import glob\n",
11 + "import cv2\n",
12 + "\n",
13 + "images = sorted(glob(\"./tennis_test_1080p/*.png\"))"
14 + ]
15 + },
16 + {
17 + "cell_type": "code",
18 + "execution_count": 2,
19 + "id": "rapid-tension",
20 + "metadata": {},
21 + "outputs": [],
22 + "source": [
23 + "from pathlib import Path\n",
24 + "Path(\"./dataset/Urban100/x2\").mkdir(parents=True, exist_ok=True)"
25 + ]
26 + },
27 + {
28 + "cell_type": "code",
29 + "execution_count": 3,
30 + "id": "visible-texas",
31 + "metadata": {},
32 + "outputs": [
33 + {
34 + "name": "stderr",
35 + "output_type": "stream",
36 + "text": [
37 + "100%|██████████| 125/125 [00:18<00:00, 6.61it/s]\n"
38 + ]
39 + }
40 + ],
41 + "source": [
42 + "from tqdm import tqdm\n",
43 + "for image in tqdm(images):\n",
44 + " hr = cv2.imread(image, cv2.IMREAD_COLOR)\n",
45 + " lr = cv2.resize(hr, dsize=(960, 540), interpolation=cv2.INTER_CUBIC)\n",
46 + "\n",
47 + " cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_HR.png\", hr)\n",
48 + " cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_LR.png\", lr)"
49 + ]
50 + },
51 + {
52 + "cell_type": "code",
53 + "execution_count": null,
54 + "id": "fallen-religion",
55 + "metadata": {},
56 + "outputs": [],
57 + "source": []
58 + }
59 + ],
60 + "metadata": {
61 + "kernelspec": {
62 + "display_name": "Python 3",
63 + "language": "python",
64 + "name": "python3"
65 + },
66 + "language_info": {
67 + "codemirror_mode": {
68 + "name": "ipython",
69 + "version": 3
70 + },
71 + "file_extension": ".py",
72 + "mimetype": "text/x-python",
73 + "name": "python",
74 + "nbconvert_exporter": "python",
75 + "pygments_lexer": "ipython3",
76 + "version": "3.7.7"
77 + }
78 + },
79 + "nbformat": 4,
80 + "nbformat_minor": 5
81 +}
No preview for this file type