*.swp
.DS_Store
__pycache__
MIT License
Copyright (c) 2018 Namhyuk Ahn
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network
Namhyuk Ahn, Byungkon Kang, Kyung-Ah Sohn.<br>
European Conference on Computer Vision (ECCV), 2018.
[[arXiv](https://arxiv.org/abs/1803.08664)]
<img src="assets/benchmark.png">
### Abstract
In recent years, deep learning methods have been successfully applied to single-image super-resolution tasks. Despite their great performances, deep learning methods cannot be easily applied to real-world applications due to the requirement of heavy computation. In this paper, we address this issue by proposing an accurate and lightweight deep learning model for image super-resolution. In detail, we design an architecture that implements a cascading mechanism upon a residual network. We also present a variant model of the proposed cascading residual network to further improve efficiency. Our extensive experiments show that even with much fewer parameters and operations, our models achieve performance comparable to that of state-of-the-art methods.
### FAQs
1. Unable to reproduce the PSNR/SSIM reported in the paper: see [issue#6](https://github.com/nmhkahn/CARN-pytorch/issues/6).
### Requirements
- Python 3
- [PyTorch](https://github.com/pytorch/pytorch) (0.4.0), [torchvision](https://github.com/pytorch/vision)
- Numpy, Scipy
- Pillow, Scikit-image
- h5py
- importlib
### Dataset
We use the DIV2K dataset for training and the Set5, Set14, B100 and Urban100 datasets for benchmark testing. Follow the steps below to prepare the datasets.
1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K) and unzip it into the `dataset` directory as below:
```
dataset
└── DIV2K
├── DIV2K_train_HR
├── DIV2K_train_LR_bicubic
├── DIV2K_valid_HR
└── DIV2K_valid_LR_bicubic
```
2. To accelerate training, we first convert the training images to HDF5 format as follows (the h5py module must be installed).
```shell
$ cd dataset && python div2h5.py
```
3. Other benchmark datasets can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1t2le0-Wz7GZQ4M2mJqmRamw5o4ce2AVw?usp=sharing). As with DIV2K, place all datasets in the `dataset` directory.
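If step 2 succeeds, `dataset/DIV2K_train.h5` contains one group per resolution (`HR`, `X2`, `X3`, `X4`), each holding the images as numbered datasets. A quick sanity-check sketch (the path assumes the default layout above):
```python
import h5py

with h5py.File("dataset/DIV2K_train.h5", "r") as f:
    print(list(f.keys()))      # expected: ['HR', 'X2', 'X3', 'X4']
    print(len(f["HR"]))        # number of training images (800 for DIV2K)
    print(f["HR"]["0"].shape)  # first HR image, stored as an HxWx3 uint8 array
```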
### Test Pretrained Models
We provide pretrained models in the `checkpoint` directory. To test CARN on a benchmark dataset:
```shell
$ python carn/sample.py --model carn \
--test_data_dir dataset/<dataset> \
--scale [2|3|4] \
--ckpt_path ./checkpoint/<path>.pth \
--sample_dir <sample_dir>
```
and for CARN-M,
```shell
$ python carn/sample.py --model carn_m \
--test_data_dir dataset/<dataset> \
--scale [2|3|4] \
--ckpt_path ./checkpoint/<path>.pth \
--sample_dir <sample_dir> \
--group 4
```
We provide our results on the four benchmark datasets (Set5, Set14, B100 and Urban100): [Google Drive](https://drive.google.com/drive/folders/1R4vZMs3Adf8UlYbIzStY98qlsl5y1wxH?usp=sharing)
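The saved SR/HR pairs can also be re-scored offline. Below is a minimal sketch (the directory layout follows the output of `sample.py`; the `carn`/`Urban100` path components are examples). Like `solver.py`, it measures PSNR on RGB channels rather than on the Y channel of the MATLAB evaluation protocol:
```python
import glob
import numpy as np
from PIL import Image
from skimage.measure import compare_psnr

scale = 4
sr_paths = sorted(glob.glob("sample/carn/Urban100/x4/SR/*.png"))
hr_paths = sorted(glob.glob("sample/carn/Urban100/x4/HR/*.png"))
psnrs = []
for sr_path, hr_path in zip(sr_paths, hr_paths):
    sr = np.asarray(Image.open(sr_path), dtype=np.float64) / 255
    hr = np.asarray(Image.open(hr_path), dtype=np.float64) / 255
    # crop `scale` boundary pixels, as solver.py does
    sr, hr = sr[scale:-scale, scale:-scale], hr[scale:-scale, scale:-scale]
    psnrs.append(compare_psnr(hr, sr, data_range=1))
print("mean PSNR: {:.2f} dB".format(np.mean(psnrs)))
```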
### Training Models
Here are our settings for training CARN and CARN-M. Note: we use two GPUs to accommodate the large batch size; if an OOM error arises, reduce the batch size.
```shell
# For CARN
$ python carn/train.py --patch_size 64 \
--batch_size 64 \
--max_steps 600000 \
--decay 400000 \
--model carn \
--ckpt_name carn \
--ckpt_dir checkpoint/carn \
--scale 0 \
--num_gpu 2
# For CARN-M
$ python carn/train.py --patch_size 64 \
--batch_size 64 \
--max_steps 600000 \
--decay 400000 \
--model carn_m \
--ckpt_name carn_m \
--ckpt_dir checkpoint/carn_m \
--scale 0 \
--group 4 \
--num_gpu 2
```
In the `--scale` argument, 2, 3 or 4 selects single-scale training and 0 selects multi-scale learning. `--group` sets the group size of the group convolution. The differences from the previous version are: 1) we increase both the batch size and the patch size to 64; 2) instead of the `reduce_upsample` argument, which replaced the 3x3 convolutions of the upsample block with 1x1 convolutions, we apply group convolution in the same way as in the efficient residual block.
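As a rough illustration of why group convolution makes the model lighter (a standalone sketch, not repository code): with `groups=g`, a convolution's weight count drops by roughly a factor of `g`.
```python
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

dense = nn.Conv2d(64, 64, 3, 1, 1)              # 64*64*3*3 + 64 = 36,928 params
grouped = nn.Conv2d(64, 64, 3, 1, 1, groups=4)  # 64*16*3*3 + 64 = 9,280 params
print(n_params(dense), n_params(grouped))
```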
### Results
**Note:** As pointed out in [#2](https://github.com/nmhkahn/CARN-pytorch/issues/2), the previous Urban100 benchmark dataset was incorrect: the HR image resolutions did not match the original dataset at the x2 and x3 scales. We have corrected this problem; the provided dataset and results are the fixed ones.
<img src="assets/table.png">
<img src="assets/visual.png">
### Citation
```
@article{ahn2018fast,
title={Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network},
author={Ahn, Namhyuk and Kang, Byungkon and Sohn, Kyung-Ah},
journal={arXiv preprint arXiv:1803.08664},
year={2018}
}
```
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
def random_crop(hr, lr, size, scale):
h, w = lr.shape[:-1]
x = random.randint(0, w-size)
y = random.randint(0, h-size)
hsize = size*scale
hx, hy = x*scale, y*scale
crop_lr = lr[y:y+size, x:x+size].copy()
crop_hr = hr[hy:hy+hsize, hx:hx+hsize].copy()
return crop_hr, crop_lr
def random_flip_and_rotate(im1, im2):
if random.random() < 0.5:
im1 = np.flipud(im1)
im2 = np.flipud(im2)
if random.random() < 0.5:
im1 = np.fliplr(im1)
im2 = np.fliplr(im2)
angle = random.choice([0, 1, 2, 3])
im1 = np.rot90(im1, angle)
im2 = np.rot90(im2, angle)
    # copy is required: flipud/fliplr/rot90 return views with negative strides,
    # which torchvision's ToTensor cannot convert
return im1.copy(), im2.copy()
class TrainDataset(data.Dataset):
def __init__(self, path, size, scale):
super(TrainDataset, self).__init__()
self.size = size
h5f = h5py.File(path, "r")
self.hr = [v[:] for v in h5f["HR"].values()]
# perform multi-scale training
if scale == 0:
self.scale = [2, 3, 4]
self.lr = [[v[:] for v in h5f["X{}".format(i)].values()] for i in self.scale]
else:
self.scale = [scale]
self.lr = [[v[:] for v in h5f["X{}".format(scale)].values()]]
h5f.close()
self.transform = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, index):
size = self.size
item = [(self.hr[index], self.lr[i][index]) for i, _ in enumerate(self.lr)]
item = [random_crop(hr, lr, size, self.scale[i]) for i, (hr, lr) in enumerate(item)]
item = [random_flip_and_rotate(hr, lr) for hr, lr in item]
return [(self.transform(hr), self.transform(lr)) for hr, lr in item]
def __len__(self):
return len(self.hr)
class TestDataset(data.Dataset):
def __init__(self, dirname, scale):
super(TestDataset, self).__init__()
self.name = dirname.split("/")[-1]
self.scale = scale
if "DIV" in self.name:
self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname),
"X{}/*.png".format(scale)))
else:
all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
self.hr = [name for name in all_files if "HR" in name]
self.lr = [name for name in all_files if "LR" in name]
self.hr.sort()
self.lr.sort()
self.transform = transforms.Compose([
transforms.ToTensor()
])
def __getitem__(self, index):
hr = Image.open(self.hr[index])
lr = Image.open(self.lr[index])
hr = hr.convert("RGB")
lr = lr.convert("RGB")
filename = self.hr[index].split("/")[-1]
return self.transform(hr), self.transform(lr), filename
def __len__(self):
return len(self.hr)
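if __name__ == "__main__":
    # minimal usage sketch (not part of the original repo): iterate one batch
    # of multi-scale training pairs; assumes dataset/DIV2K_train.h5 was
    # generated by div2h5.py beforehand
    train_data = TrainDataset("dataset/DIV2K_train.h5", size=64, scale=0)
    loader = data.DataLoader(train_data, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    for (hr, lr), s in zip(batch, train_data.scale):
        print("x{}: hr {}, lr {}".format(s, tuple(hr.shape), tuple(lr.shape)))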
import torch
import torch.nn as nn
import model.ops as ops
class Block(nn.Module):
def __init__(self,
in_channels, out_channels,
group=1):
super(Block, self).__init__()
self.b1 = ops.ResidualBlock(64, 64)
self.b2 = ops.ResidualBlock(64, 64)
self.b3 = ops.ResidualBlock(64, 64)
self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
def forward(self, x):
c0 = o0 = x
b1 = self.b1(o0)
c1 = torch.cat([c0, b1], dim=1)
o1 = self.c1(c1)
b2 = self.b2(o1)
c2 = torch.cat([c1, b2], dim=1)
o2 = self.c2(c2)
b3 = self.b3(o2)
c3 = torch.cat([c2, b3], dim=1)
o3 = self.c3(c3)
return o3
class Net(nn.Module):
def __init__(self, **kwargs):
super(Net, self).__init__()
scale = kwargs.get("scale")
multi_scale = kwargs.get("multi_scale")
group = kwargs.get("group", 1)
self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
self.entry = nn.Conv2d(3, 64, 3, 1, 1)
self.b1 = Block(64, 64)
self.b2 = Block(64, 64)
self.b3 = Block(64, 64)
self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
self.upsample = ops.UpsampleBlock(64, scale=scale,
multi_scale=multi_scale,
group=group)
self.exit = nn.Conv2d(64, 3, 3, 1, 1)
def forward(self, x, scale):
x = self.sub_mean(x)
x = self.entry(x)
c0 = o0 = x
b1 = self.b1(o0)
c1 = torch.cat([c0, b1], dim=1)
o1 = self.c1(c1)
b2 = self.b2(o1)
c2 = torch.cat([c1, b2], dim=1)
o2 = self.c2(c2)
b3 = self.b3(o2)
c3 = torch.cat([c2, b3], dim=1)
o3 = self.c3(c3)
out = self.upsample(o3, scale=scale)
out = self.exit(out)
out = self.add_mean(out)
return out
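if __name__ == "__main__":
    # quick shape check (not part of the original repo): a single-scale x2
    # CARN maps a 64x64 LR patch to a 128x128 SR patch
    net = Net(scale=2, multi_scale=False, group=1)
    x = torch.randn(1, 3, 64, 64)
    print(tuple(net(x, 2).shape))  # expected: (1, 3, 128, 128)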
import torch
import torch.nn as nn
import model.ops as ops
class Block(nn.Module):
def __init__(self,
in_channels, out_channels,
group=1):
super(Block, self).__init__()
self.b1 = ops.EResidualBlock(64, 64, group=group)
self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
def forward(self, x):
c0 = o0 = x
b1 = self.b1(o0)
c1 = torch.cat([c0, b1], dim=1)
o1 = self.c1(c1)
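        # note: self.b1 is reused below; CARN-M shares a single efficient
        # residual block across all three stages of the cascading block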
b2 = self.b1(o1)
c2 = torch.cat([c1, b2], dim=1)
o2 = self.c2(c2)
b3 = self.b1(o2)
c3 = torch.cat([c2, b3], dim=1)
o3 = self.c3(c3)
return o3
class Net(nn.Module):
def __init__(self, **kwargs):
super(Net, self).__init__()
scale = kwargs.get("scale")
multi_scale = kwargs.get("multi_scale")
group = kwargs.get("group", 1)
self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
self.entry = nn.Conv2d(3, 64, 3, 1, 1)
self.b1 = Block(64, 64, group=group)
self.b2 = Block(64, 64, group=group)
self.b3 = Block(64, 64, group=group)
self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0)
self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0)
self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)
self.upsample = ops.UpsampleBlock(64, scale=scale,
multi_scale=multi_scale,
group=group)
self.exit = nn.Conv2d(64, 3, 3, 1, 1)
def forward(self, x, scale):
x = self.sub_mean(x)
x = self.entry(x)
c0 = o0 = x
b1 = self.b1(o0)
c1 = torch.cat([c0, b1], dim=1)
o1 = self.c1(c1)
b2 = self.b2(o1)
c2 = torch.cat([c1, b2], dim=1)
o2 = self.c2(c2)
b3 = self.b3(o2)
c3 = torch.cat([c2, b3], dim=1)
o3 = self.c3(c3)
out = self.upsample(o3, scale=scale)
out = self.exit(out)
out = self.add_mean(out)
return out
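if __name__ == "__main__":
    # parameter-count sketch (not part of the original repo): CARN-M shrinks
    # CARN by sharing one efficient residual block per cascading block and by
    # using group convolution (group=4, as in the README)
    import model.carn as carn

    def count(net):
        return sum(p.numel() for p in net.parameters())

    print("CARN:  ", count(carn.Net(scale=2, multi_scale=False, group=1)))
    print("CARN-M:", count(Net(scale=2, multi_scale=False, group=4)))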
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
def init_weights(modules):
    # intentionally a no-op: layers keep PyTorch's default initialization
    pass
class MeanShift(nn.Module):
def __init__(self, mean_rgb, sub):
super(MeanShift, self).__init__()
sign = -1 if sub else 1
r = mean_rgb[0] * sign
g = mean_rgb[1] * sign
b = mean_rgb[2] * sign
self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.shifter.bias.data = torch.Tensor([r, g, b])
# Freeze the mean shift layer
for params in self.shifter.parameters():
params.requires_grad = False
def forward(self, x):
x = self.shifter(x)
return x
class BasicBlock(nn.Module):
def __init__(self,
in_channels, out_channels,
ksize=3, stride=1, pad=1):
super(BasicBlock, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
nn.ReLU(inplace=True)
)
        init_weights(self.modules())
def forward(self, x):
out = self.body(x)
return out
class ResidualBlock(nn.Module):
def __init__(self,
in_channels, out_channels):
super(ResidualBlock, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
)
        init_weights(self.modules())
def forward(self, x):
out = self.body(x)
out = F.relu(out + x)
return out
class EResidualBlock(nn.Module):
def __init__(self,
in_channels, out_channels,
group=1):
super(EResidualBlock, self).__init__()
self.body = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 1, 1, 0),
)
        init_weights(self.modules())
def forward(self, x):
out = self.body(x)
out = F.relu(out + x)
return out
class UpsampleBlock(nn.Module):
def __init__(self,
n_channels, scale, multi_scale,
group=1):
super(UpsampleBlock, self).__init__()
if multi_scale:
self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
else:
self.up = _UpsampleBlock(n_channels, scale=scale, group=group)
self.multi_scale = multi_scale
def forward(self, x, scale):
if self.multi_scale:
if scale == 2:
return self.up2(x)
elif scale == 3:
return self.up3(x)
elif scale == 4:
return self.up4(x)
else:
return self.up(x)
class _UpsampleBlock(nn.Module):
def __init__(self,
n_channels, scale,
group=1):
super(_UpsampleBlock, self).__init__()
modules = []
if scale == 2 or scale == 4 or scale == 8:
for _ in range(int(math.log(scale, 2))):
modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
modules += [nn.PixelShuffle(2)]
elif scale == 3:
modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
modules += [nn.PixelShuffle(3)]
self.body = nn.Sequential(*modules)
        init_weights(self.modules())
def forward(self, x):
out = self.body(x)
return out
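if __name__ == "__main__":
    # shape check (not part of the original repo): each _UpsampleBlock should
    # enlarge spatial dimensions by its scale factor via PixelShuffle
    x = torch.randn(1, 64, 16, 16)
    for s in (2, 3, 4):
        up = _UpsampleBlock(64, scale=s)
        print(s, tuple(up(x).shape))  # expected: (1, 64, 16*s, 16*s)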
import os
import json
import time
import importlib
import argparse
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
from dataset import TestDataset
from PIL import Image
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--ckpt_path", type=str)
parser.add_argument("--group", type=int, default=1)
parser.add_argument("--sample_dir", type=str)
parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--scale", type=int, default=4)
parser.add_argument("--shave", type=int, default=20)
return parser.parse_args()
def save_image(tensor, filename):
tensor = tensor.cpu()
ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
im = Image.fromarray(ndarr)
im.save(filename)
def sample(net, device, dataset, cfg):
scale = cfg.scale
for step, (hr, lr, name) in enumerate(dataset):
if "DIV2K" in dataset.name:
t1 = time.time()
h, w = lr.size()[1:]
h_half, w_half = int(h/2), int(w/2)
h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave
            # split the large input into 4 overlapping patches to avoid OOM
            lr_patch = torch.zeros((4, 3, h_chop, w_chop), dtype=torch.float)
lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
lr_patch = lr_patch.to(device)
sr = net(lr_patch, cfg.scale).detach()
h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
            # stitch the 4 upscaled patches back into the full SR image
            result = torch.zeros((3, h, w), dtype=torch.float).to(device)
result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
sr = result
t2 = time.time()
else:
t1 = time.time()
lr = lr.unsqueeze(0).to(device)
sr = net(lr, cfg.scale).detach().squeeze(0)
lr = lr.squeeze(0)
t2 = time.time()
        model_name = os.path.splitext(os.path.basename(cfg.ckpt_path))[0]
sr_dir = os.path.join(cfg.sample_dir,
model_name,
cfg.test_data_dir.split("/")[-1],
"x{}".format(cfg.scale),
"SR")
hr_dir = os.path.join(cfg.sample_dir,
model_name,
cfg.test_data_dir.split("/")[-1],
"x{}".format(cfg.scale),
"HR")
os.makedirs(sr_dir, exist_ok=True)
os.makedirs(hr_dir, exist_ok=True)
sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR")))
hr_im_path = os.path.join(hr_dir, "{}".format(name))
save_image(sr, sr_im_path)
save_image(hr, hr_im_path)
print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
.format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1))
def main(cfg):
module = importlib.import_module("model.{}".format(cfg.model))
net = module.Net(multi_scale=True,
group=cfg.group)
print(json.dumps(vars(cfg), indent=4, sort_keys=True))
state_dict = torch.load(cfg.ckpt_path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k
# name = k[7:] # remove "module."
new_state_dict[name] = v
net.load_state_dict(new_state_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
dataset = TestDataset(cfg.test_data_dir, cfg.scale)
sample(net, device, dataset, cfg)
if __name__ == "__main__":
cfg = parse_args()
main(cfg)
import os
import random
import numpy as np
import skimage.measure as measure
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import TrainDataset, TestDataset
class Solver():
def __init__(self, model, cfg):
if cfg.scale > 0:
self.refiner = model(scale=cfg.scale,
group=cfg.group)
else:
self.refiner = model(multi_scale=True,
group=cfg.group)
if cfg.loss_fn in ["MSE"]:
self.loss_fn = nn.MSELoss()
elif cfg.loss_fn in ["L1"]:
self.loss_fn = nn.L1Loss()
elif cfg.loss_fn in ["SmoothL1"]:
self.loss_fn = nn.SmoothL1Loss()
self.optim = optim.Adam(
filter(lambda p: p.requires_grad, self.refiner.parameters()),
cfg.lr)
self.train_data = TrainDataset(cfg.train_data_path,
scale=cfg.scale,
size=cfg.patch_size)
self.train_loader = DataLoader(self.train_data,
batch_size=cfg.batch_size,
num_workers=1,
shuffle=True, drop_last=True)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.refiner = self.refiner.to(self.device)
self.cfg = cfg
self.step = 0
self.writer = SummaryWriter(log_dir=os.path.join("runs", cfg.ckpt_name))
if cfg.verbose:
num_params = 0
for param in self.refiner.parameters():
num_params += param.nelement()
print("# of params:", num_params)
os.makedirs(cfg.ckpt_dir, exist_ok=True)
def fit(self):
cfg = self.cfg
refiner = nn.DataParallel(self.refiner,
device_ids=range(cfg.num_gpu))
learning_rate = cfg.lr
while True:
for inputs in self.train_loader:
self.refiner.train()
if cfg.scale > 0:
scale = cfg.scale
hr, lr = inputs[-1][0], inputs[-1][1]
else:
                    # multi-scale training: randomly pick one scale per step
scale = random.randint(2, 4)
hr, lr = inputs[scale-2][0], inputs[scale-2][1]
hr = hr.to(self.device)
lr = lr.to(self.device)
sr = refiner(lr, scale)
loss = self.loss_fn(sr, hr)
self.optim.zero_grad()
loss.backward()
                nn.utils.clip_grad_norm_(self.refiner.parameters(), cfg.clip)
self.optim.step()
learning_rate = self.decay_learning_rate()
for param_group in self.optim.param_groups:
param_group["lr"] = learning_rate
self.step += 1
if cfg.verbose and self.step % cfg.print_interval == 0:
if cfg.scale > 0:
psnr = self.evaluate("dataset/Urban100", scale=cfg.scale, num_step=self.step)
self.writer.add_scalar("Urban100", psnr, self.step)
else:
psnr = [self.evaluate("dataset/Urban100", scale=i, num_step=self.step) for i in range(2, 5)]
self.writer.add_scalar("Urban100_2x", psnr[0], self.step)
self.writer.add_scalar("Urban100_3x", psnr[1], self.step)
self.writer.add_scalar("Urban100_4x", psnr[2], self.step)
self.save(cfg.ckpt_dir, cfg.ckpt_name)
                if self.step > cfg.max_steps:
                    # return rather than break so the outer while-loop ends too
                    return
def evaluate(self, test_data_dir, scale=2, num_step=0):
cfg = self.cfg
mean_psnr = 0
self.refiner.eval()
test_data = TestDataset(test_data_dir, scale=scale)
test_loader = DataLoader(test_data,
batch_size=1,
num_workers=1,
shuffle=False)
for step, inputs in enumerate(test_loader):
hr = inputs[0].squeeze(0)
lr = inputs[1].squeeze(0)
name = inputs[2][0]
h, w = lr.size()[1:]
h_half, w_half = int(h/2), int(w/2)
h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave
# split large image to 4 patch to avoid OOM error
lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop)
lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
lr_patch = lr_patch.to(self.device)
# run refine process in here!
sr = self.refiner(lr_patch, scale).data
h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale
# merge splited patch images
result = torch.FloatTensor(3, h, w).to(self.device)
result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
sr = result
hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
# evaluate PSNR
# this evaluation is different to MATLAB version
# we evaluate PSNR in RGB channel not Y in YCbCR
bnd = scale
im1 = hr[bnd:-bnd, bnd:-bnd]
im2 = sr[bnd:-bnd, bnd:-bnd]
mean_psnr += psnr(im1, im2) / len(test_data)
return mean_psnr
def load(self, path):
self.refiner.load_state_dict(torch.load(path))
        # recover the step count from a checkpoint name like "carn_600000.pth"
try:
self.step = int(path.split(".")[0].split("_")[-1])
except ValueError:
self.step = 0
print("Load pretrained {} model".format(path))
def save(self, ckpt_dir, ckpt_name):
save_path = os.path.join(
ckpt_dir, "{}_{}.pth".format(ckpt_name, self.step))
torch.save(self.refiner.state_dict(), save_path)
def decay_learning_rate(self):
lr = self.cfg.lr * (0.5 ** (self.step // self.cfg.decay))
return lr
def psnr(im1, im2):
def im2double(im):
min_val, max_val = 0, 255
out = (im.astype(np.float64)-min_val) / (max_val-min_val)
return out
im1 = im2double(im1)
im2 = im2double(im2)
psnr = measure.compare_psnr(im1, im2, data_range=1)
return psnr
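if __name__ == "__main__":
    # sanity check (not part of the original repo) for the split-and-merge
    # geometry used in Solver.evaluate: with an identity "network" at scale=1,
    # merging the four overlapping patches must reproduce the input exactly
    shave, scale = 20, 1
    im = torch.rand(3, 101, 77)
    h, w = im.size()[1:]
    h_half, w_half = int(h/2), int(w/2)
    h_chop, w_chop = h_half + shave, w_half + shave
    patches = torch.FloatTensor(4, 3, h_chop, w_chop)
    patches[0].copy_(im[:, 0:h_chop, 0:w_chop])
    patches[1].copy_(im[:, 0:h_chop, w-w_chop:w])
    patches[2].copy_(im[:, h-h_chop:h, 0:w_chop])
    patches[3].copy_(im[:, h-h_chop:h, w-w_chop:w])
    sr = patches  # identity network at scale=1
    result = torch.FloatTensor(3, h, w)
    result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
    result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
    result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
    result[:, h_half:h, w_half:w].copy_(sr[3, :, h_chop-h+h_half:h_chop, w_chop-w+w_half:w_chop])
    assert torch.equal(result, im)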
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import json
import argparse
import importlib
from solver import Solver
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--ckpt_name", type=str)
parser.add_argument("--print_interval", type=int, default=1000)
parser.add_argument("--train_data_path", type=str,
default="dataset/DIV2K_train.h5")
parser.add_argument("--ckpt_dir", type=str,
default="checkpoint")
parser.add_argument("--sample_dir", type=str,
default="sample/")
parser.add_argument("--num_gpu", type=int, default=1)
parser.add_argument("--shave", type=int, default=20)
parser.add_argument("--scale", type=int, default=2)
parser.add_argument("--verbose", action="store_true", default="store_true")
parser.add_argument("--group", type=int, default=1)
parser.add_argument("--patch_size", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--max_steps", type=int, default=200000)
parser.add_argument("--decay", type=int, default=150000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--clip", type=float, default=10.0)
parser.add_argument("--loss_fn", type=str,
choices=["MSE", "L1", "SmoothL1"], default="L1")
return parser.parse_args()
def main(cfg):
# dynamic import using --model argument
net = importlib.import_module("model.{}".format(cfg.model)).Net
print(json.dumps(vars(cfg), indent=4, sort_keys=True))
solver = Solver(net, cfg)
solver.fit()
if __name__ == "__main__":
cfg = parse_args()
main(cfg)
*
!.gitignore
!div2h5.py
import os
import glob
import h5py
import scipy.misc as misc
import numpy as np
dataset_dir = "DIV2K/"
dataset_type = "train"
f = h5py.File("DIV2K_{}.h5".format(dataset_type), "w")
for subdir in ["HR", "X2", "X3", "X4"]:
if subdir in ["HR"]:
im_paths = glob.glob(os.path.join(dataset_dir,
"DIV2K_{}_HR".format(dataset_type),
"*.png"))
else:
im_paths = glob.glob(os.path.join(dataset_dir,
"DIV2K_{}_LR_bicubic".format(dataset_type),
subdir, "*.png"))
im_paths.sort()
grp = f.create_group(subdir)
for i, path in enumerate(im_paths):
im = misc.imread(path)
print(path)
grp.create_dataset(str(i), data=im)