infer.py
2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
from torchvision.utils import save_image
import numpy as np
from PIL import Image
import cv2
from models import UNetSemantic, GatedGenerator
import argparse
from configs import Config
class Predictor():
def __init__(self, cfg):
self.cfg = cfg
self.device = torch.device('cuda:0' if cfg.cuda else 'cpu')
self.masking = UNetSemantic().to(self.device)
self.masking.load_state_dict(torch.load('weights\model_segm_19_135000.pth', map_location='cpu'))
#self.masking.eval()
self.inpaint = GatedGenerator().to(self.device)
self.inpaint.load_state_dict(torch.load('weights/model_6_100000.pth', map_location='cpu')['G'])
self.inpaint.eval()
def save_image(self, img_list, save_img_path, nrow):
img_list = [i.clone().cpu() for i in img_list]
imgs = torch.stack(img_list, dim=1)
imgs = imgs.view(-1, *list(imgs.size())[2:])
save_image(imgs, save_img_path, nrow = nrow)
print(f"Save image to {save_img_path}")
def predict(self, image, outpath='sample/results.png'):
outpath=f'sample/results_{image}.png'
image = 'sample/'+image
img = cv2.imread(image+'_masked.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (self.cfg.img_size, self.cfg.img_size))
img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
img = img.unsqueeze(0).to(self.device)
img_ori = cv2.imread(image+'.jpg')
img_ori = cv2.cvtColor(img_ori, cv2.COLOR_BGR2RGB)
img_ori = cv2.resize(img_ori, (self.cfg.img_size, self.cfg.img_size))
img_ori = torch.from_numpy(img_ori.astype(np.float32) / 255.0).permute(2, 0, 1).contiguous()
img_ori = img_ori.unsqueeze(0)
with torch.no_grad():
outputs = self.masking(img)
_, out = self.inpaint(img, outputs)
inpaint = img * (1 - outputs) + out * outputs
masks = img * (1 - outputs) + outputs #torch.cat([outputs, outputs, outputs], dim=1)
self.save_image([img, masks, inpaint, img_ori], outpath, nrow=4)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Training custom model')
parser.add_argument('--image', default=None, type=str, help='resume training')
parser.add_argument('config', default='config', type=str, help='config training')
args = parser.parse_args()
config = Config(f'./configs/{args.config}.yaml')
model = Predictor(config)
model.predict(args.image)