김재형

Add PSNR and SSIM evaluation code for CARN training

@@ -2,7 +2,7 @@ import os
import random
import numpy as np
import scipy.misc as misc
-import skimage.measure as measure
+import skimage.metrics as metrics
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
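For reference, the skimage.measure.compare_psnr / compare_ssim helpers that the removed import relied on were deprecated in scikit-image 0.16 and later removed; their replacements live in skimage.metrics, as imported above. A minimal standalone sanity check of the renamed calls, using toy arrays that are not part of this commit:

import numpy as np
import skimage.metrics as metrics

# toy grayscale image pair in [0, 1], purely illustrative
a = np.random.rand(32, 32)
b = np.clip(a + 0.05 * np.random.randn(32, 32), 0.0, 1.0)

print(metrics.peak_signal_noise_ratio(a, b, data_range=1))  # replaces measure.compare_psnr
print(metrics.structural_similarity(a, b, data_range=1))    # replaces measure.compare_ssim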
@@ -13,39 +13,39 @@ from dataset import TrainDataset, TestDataset
class Solver():
def __init__(self, model, cfg):
if cfg.scale > 0:
self.refiner = model(scale=cfg.scale,
group=cfg.group)
else:
self.refiner = model(multi_scale=True,
group=cfg.group)
if cfg.loss_fn in ["MSE"]:
self.loss_fn = nn.MSELoss()
elif cfg.loss_fn in ["L1"]:
self.loss_fn = nn.L1Loss()
elif cfg.loss_fn in ["SmoothL1"]:
self.loss_fn = nn.SmoothL1Loss()
self.optim = optim.Adam(
filter(lambda p: p.requires_grad, self.refiner.parameters()),
cfg.lr)
self.train_data = TrainDataset(cfg.train_data_path,
scale=cfg.scale,
size=cfg.patch_size)
self.train_loader = DataLoader(self.train_data,
batch_size=cfg.batch_size,
num_workers=1,
shuffle=True, drop_last=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.refiner = self.refiner.to(self.device)
self.loss_fn = self.loss_fn
self.cfg = cfg
self.step = 0
self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
if cfg.verbose:
num_params = 0
@@ -57,9 +57,9 @@ class Solver():
def fit(self):
cfg = self.cfg
refiner = nn.DataParallel(self.refiner,
device_ids=range(cfg.num_gpu))
learning_rate = cfg.lr
while True:
for inputs in self.train_loader:
@@ -73,13 +73,13 @@ class Solver():
# i know this is stupid but just temporary
scale = random.randint(2, 4)
hr, lr = inputs[scale-2][0], inputs[scale-2][1]
hr = hr.to(self.device)
lr = lr.to(self.device)
sr = refiner(lr, scale)
loss = self.loss_fn(sr, hr)
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(self.refiner.parameters(), cfg.clip)
@@ -88,18 +88,19 @@ class Solver():
learning_rate = self.decay_learning_rate()
for param_group in self.optim.param_groups:
param_group["lr"] = learning_rate
self.step += 1
if cfg.verbose and self.step % cfg.print_interval == 0:
if cfg.scale > 0:
-psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
-self.writer.add_scalar("Urban100", psnr, self.step)
+psnr, ssim = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
+self.writer.add_scalar("PSNR", psnr, self.step)
+self.writer.add_scalar("SSIM", ssim, self.step)
else:
psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
self.writer.add_scalar("Urban100_4x", psnr[2], self.step)
self.save(cfg.ckpt_dir, cfg.ckpt_name)
if self.step > cfg.max_steps: break
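The verbose branch above now logs two scalars per evaluation step. A minimal tensorboardX sketch of that logging pattern, with a made-up log directory and made-up metric values:

from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/carn_demo")  # hypothetical run directory
# made-up (psnr, ssim) pairs standing in for evaluate() results
for step, (psnr, ssim) in enumerate([(27.1, 0.83), (27.4, 0.84)], start=1):
    writer.add_scalar("PSNR", psnr, step)
    writer.add_scalar("SSIM", ssim, step)
writer.close()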
@@ -107,8 +108,9 @@ class Solver():
def evaluate(self, test_data_dir, scale=2, num_step=0):
cfg = self.cfg
mean_psnr = 0
+mean_ssim = 0
self.refiner.eval()
test_data = TestDataset(test_data_dir, scale=scale)
test_loader = DataLoader(test_data,
batch_size=1,
@@ -131,13 +133,13 @@ class Solver():
lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
lr_patch = lr_patch.to(self.device)
# run refine process in here!
sr = self.refiner(lr_patch, scale).data
h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
# merge splited patch images
result = torch.FloatTensor(3, h, w).to(self.device)
result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
@@ -148,16 +150,17 @@ class Solver():
hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
-# evaluate PSNR
+# evaluate PSNR and SSIM
# this evaluation is different to MATLAB version
# we evaluate PSNR in RGB channel not Y in YCbCR
bnd = scale
-im1 = hr[bnd:-bnd, bnd:-bnd]
-im2 = sr[bnd:-bnd, bnd:-bnd]
+im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
+im2 = im2double(sr[bnd:-bnd, bnd:-bnd])
mean_psnr += psnr(im1, im2) / len(test_data)
+mean_ssim += ssim(im1, im2) / len(test_data)
-return mean_psnr
+return mean_psnr, mean_ssim
def load(self, path):
self.refiner.load_state_dict(torch.load(path))
@@ -177,14 +180,15 @@ class Solver():
lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
return lr
+def im2double(im):
+min_val, max_val = 0, 255
+out = (im.astype(np.float64)-min_val) / (max_val-min_val)
+return out
def psnr(im1, im2):
-def im2double(im):
-min_val, max_val = 0, 255
-out = (im.astype(np.float64)-min_val) / (max_val-min_val)
-return out
im1 = im2double(im1)
im2 = im2double(im2)
-psnr = measure.compare_psnr(im1, im2, data_range=1)
+psnr = metrics.peak_signal_noise_ratio(im1, im2, data_range=1)
return psnr
+def ssim(im1, im2):
+ssim = metrics.structural_similarity(im1, im2, data_range=1, multichannel=True)
+return ssim
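Taken together, the per-image evaluation added here crops a scale-sized border, maps both crops to [0, 1] floats with im2double, and averages PSNR/SSIM over the test set. A standalone sketch of that flow under the same assumptions, with random uint8 arrays standing in for real HR/SR images:

import numpy as np
import skimage.metrics as metrics

def im2double(im):
    # uint8 [0, 255] -> float64 [0, 1], mirroring the helper above
    return im.astype(np.float64) / 255.0

rng = np.random.default_rng(0)
scale = 2
hr = rng.integers(0, 256, (64, 64, 3), dtype=np.uint8)  # stand-in ground truth
sr = np.clip(hr.astype(np.int16) + rng.integers(-4, 5, hr.shape), 0, 255).astype(np.uint8)  # stand-in SR output

bnd = scale
im1 = im2double(hr[bnd:-bnd, bnd:-bnd])
im2 = im2double(sr[bnd:-bnd, bnd:-bnd])

print(metrics.peak_signal_noise_ratio(im1, im2, data_range=1))
# multichannel=True matches the scikit-image versions this commit targets;
# newer releases use channel_axis=-1 instead.
print(metrics.structural_similarity(im1, im2, data_range=1, multichannel=True))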