Showing
1 changed file
with
6 additions
and
2 deletions
| ... | @@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): | ... | @@ -76,11 +76,13 @@ def sample(net, device, dataset, cfg): |
| 76 | 76 | ||
| 77 | def main(cfg): | 77 | def main(cfg): |
| 78 | module = importlib.import_module("model.{}".format(cfg.model)) | 78 | module = importlib.import_module("model.{}".format(cfg.model)) |
| 79 | - net = module.Net(multi_scale=True, | 79 | + net = module.Net(multi_scale=False, |
| 80 | + scale=cfg.scale, | ||
| 80 | group=cfg.group) | 81 | group=cfg.group) |
| 81 | print(json.dumps(vars(cfg), indent=4, sort_keys=True)) | 82 | print(json.dumps(vars(cfg), indent=4, sort_keys=True)) |
| 82 | 83 | ||
| 83 | state_dict = torch.load(cfg.ckpt_path) | 84 | state_dict = torch.load(cfg.ckpt_path) |
| 85 | + # print(state_dict.keys()) | ||
| 84 | new_state_dict = OrderedDict() | 86 | new_state_dict = OrderedDict() |
| 85 | for k, v in state_dict.items(): | 87 | for k, v in state_dict.items(): |
| 86 | name = k | 88 | name = k |
| ... | @@ -88,12 +90,14 @@ def main(cfg): | ... | @@ -88,12 +90,14 @@ def main(cfg): |
| 88 | new_state_dict[name] = v | 90 | new_state_dict[name] = v |
| 89 | 91 | ||
| 90 | net.load_state_dict(new_state_dict) | 92 | net.load_state_dict(new_state_dict) |
| 93 | + net.eval() | ||
| 91 | 94 | ||
| 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 95 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 93 | net = net.to(device) | 96 | net = net.to(device) |
| 94 | 97 | ||
| 95 | dataset = TestDataset(cfg.test_data_dir, cfg.scale) | 98 | dataset = TestDataset(cfg.test_data_dir, cfg.scale) |
| 96 | - sample(net, device, dataset, cfg) | 99 | + with torch.no_grad(): |
| 100 | + sample(net, device, dataset, cfg) | ||
| 97 | 101 | ||
| 98 | 102 | ||
| 99 | if __name__ == "__main__": | 103 | if __name__ == "__main__": | ... | ... |
-
Please register or login to post a comment