hanbin9775
......@@ -2,7 +2,7 @@ import os
import random
import numpy as np
import scipy.misc as misc
import skimage.measure as measure
import skimage.metrics as metrics
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
......@@ -92,8 +92,9 @@ class Solver():
self.step += 1
if cfg.verbose and self.step % cfg.print_interval == 0:
if cfg.scale > 0:
psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
self.writer.add_scalar("Urban100", psnr, self.step)
psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
self.writer.add_scalar("PSNR", psnr, self.step)
self.writer.add_scalar("SSIM", ssim, self.step)
else:
psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
......@@ -107,6 +108,7 @@ class Solver():
def evaluate(self, test_data_dir, scale=2, num_step=0):
cfg = self.cfg
mean_psnr = 0
mean_ssim = 0
self.refiner.eval()
test_data = TestDataset(test_data_dir, scale=scale)
......@@ -149,15 +151,16 @@ class Solver():
hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
# evaluate PSNR
# evaluate PSNR and SSIM
# this evaluation is different to MATLAB version
# we evaluate PSNR in RGB channel not Y in YCbCR
bnd = scale
im1 = hr[bnd:-bnd, bnd:-bnd]
im2 = sr[bnd:-bnd, bnd:-bnd]
im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
mean_psnr += psnr(im1, im2) / len(test_data)
mean_ssim += ssim(im1, im2) / len(test_data)
return mean_psnr
return mean_psnr, mean_ssim
def load(self, path):
self.refiner.load_state_dict(torch.load(path))
......@@ -177,14 +180,15 @@ class Solver():
lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
return lr
def psnr(im1, im2):
def im2double(im):
def im2double(im):
min_val, max_val = 0, 255
out = (im.astype(np.float64)-min_val) / (max_val-min_val)
return out
im1 = im2double(im1)
im2 = im2double(im2)
psnr = measure.compare_psnr(im1, im2, data_range=1)
def psnr(im1, im2):
psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
return psnr
def ssim(im1, im2):
ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
return ssim
......
......@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg):
def main(cfg):
module = importlib.import_module("model.{}".format(cfg.model))
net = module.Net(multi_scale=True,
net = module.Net(multi_scale=False,
scale=cfg.scale,
group=cfg.group)
print(json.dumps(vars(cfg), indent=4, sort_keys=True))
state_dict = torch.load(cfg.ckpt_path)
# print(state_dict.keys())
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k
......@@ -88,11 +90,13 @@ def main(cfg):
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
net.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
dataset = TestDataset(cfg.test_data_dir, cfg.scale)
with torch.no_grad():
sample(net, device, dataset, cfg)
......
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -2,10 +2,22 @@
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 1,
"id": "automotive-circus",
"metadata": {},
"outputs": [],
"outputs": [
{
"output_type": "error",
"ename": "ModuleNotFoundError",
"evalue": "No module named 'cv2'",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-1-03d1a01a87c6>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mgt_list\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_1080p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
]
}
],
"source": [
"from glob import glob\n",
"import cv2\n",
......
......@@ -5,7 +5,19 @@
"execution_count": 1,
"id": "ahead-paste",
"metadata": {},
"outputs": [],
"outputs": [
{
"output_type": "error",
"ename": "ModuleNotFoundError",
"evalue": "No module named 'cv2'",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-1-ff55b1ddb4f1>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mglob\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mglob\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mimages\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msorted\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mglob\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../bbb_sunflower_540p/*.png\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
]
}
],
"source": [
"from glob import glob\n",
"import cv2\n",
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ahead-paste",
"metadata": {},
"outputs": [],
"source": [
"from glob import glob\n",
"import cv2\n",
"\n",
"images = sorted(glob(\"./tennis_test_1080p/*.png\"))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "rapid-tension",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"Path(\"./dataset/Urban100/x2\").mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "visible-texas",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 125/125 [00:18<00:00, 6.61it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"for image in tqdm(images):\n",
" hr = cv2.imread(image, cv2.IMREAD_COLOR)\n",
" lr = cv2.resize(hr, dsize=(960, 540), interpolation=cv2.INTER_CUBIC)\n",
"\n",
" cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_HR.png\", hr)\n",
" cv2.imwrite(\"./dataset/Urban100/x2/\" + Path(image).stem + \"_LR.png\", lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fallen-religion",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
No preview for this file type