김재형

Test 코드 수정

...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): ...@@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg):
76 76
77 def main(cfg): 77 def main(cfg):
78 module = importlib.import_module("model.{}".format(cfg.model)) 78 module = importlib.import_module("model.{}".format(cfg.model))
79 - net = module.Net(multi_scale=True, 79 + net = module.Net(multi_scale=False,
80 + scale=cfg.scale,
80 group=cfg.group) 81 group=cfg.group)
81 print(json.dumps(vars(cfg), indent=4, sort_keys=True)) 82 print(json.dumps(vars(cfg), indent=4, sort_keys=True))
82 83
83 state_dict = torch.load(cfg.ckpt_path) 84 state_dict = torch.load(cfg.ckpt_path)
85 + # print(state_dict.keys())
84 new_state_dict = OrderedDict() 86 new_state_dict = OrderedDict()
85 for k, v in state_dict.items(): 87 for k, v in state_dict.items():
86 name = k 88 name = k
...@@ -88,12 +90,14 @@ def main(cfg): ...@@ -88,12 +90,14 @@ def main(cfg):
88 new_state_dict[name] = v 90 new_state_dict[name] = v
89 91
90 net.load_state_dict(new_state_dict) 92 net.load_state_dict(new_state_dict)
93 + net.eval()
91 94
92 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 net = net.to(device) 96 net = net.to(device)
94 97
95 dataset = TestDataset(cfg.test_data_dir, cfg.scale) 98 dataset = TestDataset(cfg.test_data_dir, cfg.scale)
96 - sample(net, device, dataset, cfg) 99 + with torch.no_grad():
100 + sample(net, device, dataset, cfg)
97 101
98 102
99 if __name__ == "__main__": 103 if __name__ == "__main__":
......