videotester.py
2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import math
import utility
from data import common
import torch
import cv2
from tqdm import tqdm
class VideoTester():
def __init__(self, args, my_model, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.model = my_model
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
def test(self):
torch.set_grad_enabled(False)
self.ckp.write_log('\nEvaluation on video:')
self.model.eval()
timer_test = utility.timer()
for idx_scale, scale in enumerate(self.scale):
vidcap = cv2.VideoCapture(self.args.dir_demo)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
vidwri = cv2.VideoWriter(
self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
cv2.VideoWriter_fourcc(*'XVID'),
vidcap.get(cv2.CAP_PROP_FPS),
(
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
)
)
tqdm_test = tqdm(range(total_frames), ncols=80)
for _ in tqdm_test:
success, lr = vidcap.read()
if not success: break
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
lr, = self.prepare(lr.unsqueeze(0))
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
normalized = sr * 255 / self.args.rgb_range
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
vidwri.write(ndarr)
vidcap.release()
vidwri.release()
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]