Toggle navigation
Toggle navigation
This project
Loading...
Sign in
2021-1-capstone-design1
/
BSH_Project3
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
김재형
2021-06-05 23:28:20 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
8d6e4a041ac5bb684d8477302d9983a5f8d0c42d
8d6e4a04
1 parent
affc44f1
CARN 학습 시 PSNR, SSIM eval 코드 추가
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
41 deletions
carn/carn/solver.py
carn/carn/solver.py
View file @
8d6e4a0
...
...
@@ -2,7 +2,7 @@ import os
import
random
import
numpy
as
np
import
scipy.misc
as
misc
import
skimage.me
asure
as
measure
import
skimage.me
trics
as
metrics
from
tensorboardX
import
SummaryWriter
import
torch
import
torch.nn
as
nn
...
...
@@ -13,39 +13,39 @@ from dataset import TrainDataset, TestDataset
class
Solver
():
def
__init__
(
self
,
model
,
cfg
):
if
cfg
.
scale
>
0
:
self
.
refiner
=
model
(
scale
=
cfg
.
scale
,
self
.
refiner
=
model
(
scale
=
cfg
.
scale
,
group
=
cfg
.
group
)
else
:
self
.
refiner
=
model
(
multi_scale
=
True
,
self
.
refiner
=
model
(
multi_scale
=
True
,
group
=
cfg
.
group
)
if
cfg
.
loss_fn
in
[
"MSE"
]:
if
cfg
.
loss_fn
in
[
"MSE"
]:
self
.
loss_fn
=
nn
.
MSELoss
()
elif
cfg
.
loss_fn
in
[
"L1"
]:
elif
cfg
.
loss_fn
in
[
"L1"
]:
self
.
loss_fn
=
nn
.
L1Loss
()
elif
cfg
.
loss_fn
in
[
"SmoothL1"
]:
self
.
loss_fn
=
nn
.
SmoothL1Loss
()
self
.
optim
=
optim
.
Adam
(
filter
(
lambda
p
:
p
.
requires_grad
,
self
.
refiner
.
parameters
()),
filter
(
lambda
p
:
p
.
requires_grad
,
self
.
refiner
.
parameters
()),
cfg
.
lr
)
self
.
train_data
=
TrainDataset
(
cfg
.
train_data_path
,
scale
=
cfg
.
scale
,
self
.
train_data
=
TrainDataset
(
cfg
.
train_data_path
,
scale
=
cfg
.
scale
,
size
=
cfg
.
patch_size
)
self
.
train_loader
=
DataLoader
(
self
.
train_data
,
batch_size
=
cfg
.
batch_size
,
num_workers
=
1
,
shuffle
=
True
,
drop_last
=
True
)
self
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
self
.
refiner
=
self
.
refiner
.
to
(
self
.
device
)
self
.
loss_fn
=
self
.
loss_fn
self
.
cfg
=
cfg
self
.
step
=
0
self
.
writer
=
SummaryWriter
(
log_dir
=
os
.
path
.
join
(
"runs"
,
cfg
.
ckpt_name
))
if
cfg
.
verbose
:
num_params
=
0
...
...
@@ -57,9 +57,9 @@ class Solver():
def
fit
(
self
):
cfg
=
self
.
cfg
refiner
=
nn
.
DataParallel
(
self
.
refiner
,
refiner
=
nn
.
DataParallel
(
self
.
refiner
,
device_ids
=
range
(
cfg
.
num_gpu
))
learning_rate
=
cfg
.
lr
while
True
:
for
inputs
in
self
.
train_loader
:
...
...
@@ -73,13 +73,13 @@ class Solver():
# i know this is stupid but just temporary
scale
=
random
.
randint
(
2
,
4
)
hr
,
lr
=
inputs
[
scale
-
2
][
0
],
inputs
[
scale
-
2
][
1
]
hr
=
hr
.
to
(
self
.
device
)
lr
=
lr
.
to
(
self
.
device
)
sr
=
refiner
(
lr
,
scale
)
loss
=
self
.
loss_fn
(
sr
,
hr
)
self
.
optim
.
zero_grad
()
loss
.
backward
()
nn
.
utils
.
clip_grad_norm
(
self
.
refiner
.
parameters
(),
cfg
.
clip
)
...
...
@@ -88,18 +88,19 @@ class Solver():
learning_rate
=
self
.
decay_learning_rate
()
for
param_group
in
self
.
optim
.
param_groups
:
param_group
[
"lr"
]
=
learning_rate
self
.
step
+=
1
if
cfg
.
verbose
and
self
.
step
%
cfg
.
print_interval
==
0
:
if
cfg
.
scale
>
0
:
psnr
=
self
.
evaluate
(
"dataset/Urban100"
,
scale
=
cfg
.
scale
,
num_step
=
self
.
step
)
self
.
writer
.
add_scalar
(
"Urban100"
,
psnr
,
self
.
step
)
else
:
psnr
,
ssim
=
self
.
evaluate
(
"dataset/Urban100"
,
scale
=
cfg
.
scale
,
num_step
=
self
.
step
)
self
.
writer
.
add_scalar
(
"PSNR"
,
psnr
,
self
.
step
)
self
.
writer
.
add_scalar
(
"SSIM"
,
ssim
,
self
.
step
)
else
:
psnr
=
[
self
.
evaluate
(
"dataset/Urban100"
,
scale
=
i
,
num_step
=
self
.
step
)
for
i
in
range
(
2
,
5
)]
self
.
writer
.
add_scalar
(
"Urban100_2x"
,
psnr
[
0
],
self
.
step
)
self
.
writer
.
add_scalar
(
"Urban100_3x"
,
psnr
[
1
],
self
.
step
)
self
.
writer
.
add_scalar
(
"Urban100_4x"
,
psnr
[
2
],
self
.
step
)
self
.
save
(
cfg
.
ckpt_dir
,
cfg
.
ckpt_name
)
if
self
.
step
>
cfg
.
max_steps
:
break
...
...
@@ -107,8 +108,9 @@ class Solver():
def
evaluate
(
self
,
test_data_dir
,
scale
=
2
,
num_step
=
0
):
cfg
=
self
.
cfg
mean_psnr
=
0
mean_ssim
=
0
self
.
refiner
.
eval
()
test_data
=
TestDataset
(
test_data_dir
,
scale
=
scale
)
test_loader
=
DataLoader
(
test_data
,
batch_size
=
1
,
...
...
@@ -131,13 +133,13 @@ class Solver():
lr_patch
[
2
]
.
copy_
(
lr
[:,
h
-
h_chop
:
h
,
0
:
w_chop
])
lr_patch
[
3
]
.
copy_
(
lr
[:,
h
-
h_chop
:
h
,
w
-
w_chop
:
w
])
lr_patch
=
lr_patch
.
to
(
self
.
device
)
# run refine process in here!
sr
=
self
.
refiner
(
lr_patch
,
scale
)
.
data
h
,
h_half
,
h_chop
=
h
*
scale
,
h_half
*
scale
,
h_chop
*
scale
w
,
w_half
,
w_chop
=
w
*
scale
,
w_half
*
scale
,
w_chop
*
scale
# merge splited patch images
result
=
torch
.
FloatTensor
(
3
,
h
,
w
)
.
to
(
self
.
device
)
result
[:,
0
:
h_half
,
0
:
w_half
]
.
copy_
(
sr
[
0
,
:,
0
:
h_half
,
0
:
w_half
])
...
...
@@ -148,16 +150,17 @@ class Solver():
hr
=
hr
.
cpu
()
.
mul
(
255
)
.
clamp
(
0
,
255
)
.
byte
()
.
permute
(
1
,
2
,
0
)
.
numpy
()
sr
=
sr
.
cpu
()
.
mul
(
255
)
.
clamp
(
0
,
255
)
.
byte
()
.
permute
(
1
,
2
,
0
)
.
numpy
()
# evaluate PSNR
# evaluate PSNR
and SSIM
# this evaluation is different to MATLAB version
# we evaluate PSNR in RGB channel not Y in YCbCR
# we evaluate PSNR in RGB channel not Y in YCbCR
bnd
=
scale
im1
=
hr
[
bnd
:
-
bnd
,
bnd
:
-
bnd
]
im2
=
sr
[
bnd
:
-
bnd
,
bnd
:
-
bnd
]
im1
=
im2double
(
hr
[
bnd
:
-
bnd
,
bnd
:
-
bnd
])
im2
=
im2double
(
sr
[
bnd
:
-
bnd
,
bnd
:
-
bnd
])
mean_psnr
+=
psnr
(
im1
,
im2
)
/
len
(
test_data
)
mean_ssim
+=
ssim
(
im1
,
im2
)
/
len
(
test_data
)
return
mean_psnr
return
mean_psnr
,
mean_ssim
def
load
(
self
,
path
):
self
.
refiner
.
load_state_dict
(
torch
.
load
(
path
))
...
...
@@ -177,14 +180,15 @@ class Solver():
lr
=
self
.
cfg
.
lr
*
(
0.5
**
(
self
.
step
//
self
.
cfg
.
decay
))
return
lr
def
im2double
(
im
):
min_val
,
max_val
=
0
,
255
out
=
(
im
.
astype
(
np
.
float64
)
-
min_val
)
/
(
max_val
-
min_val
)
return
out
def
psnr
(
im1
,
im2
):
def
im2double
(
im
):
min_val
,
max_val
=
0
,
255
out
=
(
im
.
astype
(
np
.
float64
)
-
min_val
)
/
(
max_val
-
min_val
)
return
out
im1
=
im2double
(
im1
)
im2
=
im2double
(
im2
)
psnr
=
measure
.
compare_psnr
(
im1
,
im2
,
data_range
=
1
)
psnr
=
metrics
.
peak_signal_noise_ratio
(
im1
,
im2
,
data_range
=
1
)
return
psnr
def
ssim
(
im1
,
im2
):
ssim
=
metrics
.
structural_similarity
(
im1
,
im2
,
data_range
=
1
,
multichannel
=
True
)
return
ssim
...
...
Please
register
or
login
to post a comment