Showing
1 changed file
with
6 additions
and
2 deletions
... | @@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): | ... | @@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): |
76 | 76 | ||
77 | def main(cfg): | 77 | def main(cfg): |
78 | module = importlib.import_module("model.{}".format(cfg.model)) | 78 | module = importlib.import_module("model.{}".format(cfg.model)) |
79 | - net = module.Net(multi_scale=True, | 79 | + net = module.Net(multi_scale=False, |
80 | + scale=cfg.scale, | ||
80 | group=cfg.group) | 81 | group=cfg.group) |
81 | print(json.dumps(vars(cfg), indent=4, sort_keys=True)) | 82 | print(json.dumps(vars(cfg), indent=4, sort_keys=True)) |
82 | 83 | ||
83 | state_dict = torch.load(cfg.ckpt_path) | 84 | state_dict = torch.load(cfg.ckpt_path) |
85 | + # print(state_dict.keys()) | ||
84 | new_state_dict = OrderedDict() | 86 | new_state_dict = OrderedDict() |
85 | for k, v in state_dict.items(): | 87 | for k, v in state_dict.items(): |
86 | name = k | 88 | name = k |
... | @@ -88,12 +90,14 @@ def main(cfg): | ... | @@ -88,12 +90,14 @@ def main(cfg): |
88 | new_state_dict[name] = v | 90 | new_state_dict[name] = v |
89 | 91 | ||
90 | net.load_state_dict(new_state_dict) | 92 | net.load_state_dict(new_state_dict) |
93 | + net.eval() | ||
91 | 94 | ||
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 95 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
93 | net = net.to(device) | 96 | net = net.to(device) |
94 | 97 | ||
95 | dataset = TestDataset(cfg.test_data_dir, cfg.scale) | 98 | dataset = TestDataset(cfg.test_data_dir, cfg.scale) |
96 | - sample(net, device, dataset, cfg) | 99 | + with torch.no_grad(): |
100 | + sample(net, device, dataset, cfg) | ||
97 | 101 | ||
98 | 102 | ||
99 | if __name__ == "__main__": | 103 | if __name__ == "__main__": | ... | ... |
-
Please register or login to post a comment