Showing
4 changed files
with
7 additions
and
9 deletions
| ... | @@ -33,12 +33,9 @@ def eval(model_path): | ... | @@ -33,12 +33,9 @@ def eval(model_path): |
| 33 | model.load_state_dict(torch.load(weight_path)) | 33 | model.load_state_dict(torch.load(weight_path)) |
| 34 | 34 | ||
| 35 | print('\n[+] Load dataset') | 35 | print('\n[+] Load dataset') |
| 36 | - test_transform = get_valid_transform(args, model) | ||
| 37 | - #print('\nTEST Transform\n', test_transform) | ||
| 38 | test_dataset = get_dataset(args, 'test') | 36 | test_dataset = get_dataset(args, 'test') |
| 39 | 37 | ||
| 40 | 38 | ||
| 41 | - | ||
| 42 | test_loader = iter(get_dataloader(args, test_dataset)) ### | 39 | test_loader = iter(get_dataloader(args, test_dataset)) ### |
| 43 | 40 | ||
| 44 | print('\n[+] Start testing') | 41 | print('\n[+] Start testing') | ... | ... |
| ... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): | ... | @@ -16,6 +16,8 @@ class BaseNet(nn.Module): |
| 16 | x = self.after(f) | 16 | x = self.after(f) |
| 17 | x = x.reshape(x.size(0), -1) | 17 | x = x.reshape(x.size(0), -1) |
| 18 | x = self.fc(x) | 18 | x = self.fc(x) |
| 19 | + | ||
| 20 | + # output, first | ||
| 19 | return x, f | 21 | return x, f |
| 20 | 22 | ||
| 21 | """ | 23 | """ | ... | ... |
| ... | @@ -24,7 +24,7 @@ def train(**kwargs): | ... | @@ -24,7 +24,7 @@ def train(**kwargs): |
| 24 | 24 | ||
| 25 | print('\n[+] Create log dir') | 25 | print('\n[+] Create log dir') |
| 26 | model_name = get_model_name(args) | 26 | model_name = get_model_name(args) |
| 27 | - log_dir = os.path.join('/content/drive/My Drive/CD2 Project/classify/', model_name) | 27 | + log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name) |
| 28 | os.makedirs(os.path.join(log_dir, 'model')) | 28 | os.makedirs(os.path.join(log_dir, 'model')) |
| 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) | 29 | json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w')) |
| 30 | writer = SummaryWriter(log_dir=log_dir) | 30 | writer = SummaryWriter(log_dir=log_dir) |
| ... | @@ -42,13 +42,11 @@ def train(**kwargs): | ... | @@ -42,13 +42,11 @@ def train(**kwargs): |
| 42 | if args.use_cuda: | 42 | if args.use_cuda: |
| 43 | model = model.cuda() | 43 | model = model.cuda() |
| 44 | criterion = criterion.cuda() | 44 | criterion = criterion.cuda() |
| 45 | - writer.add_graph(model) | 45 | + #writer.add_graph(model) |
| 46 | 46 | ||
| 47 | print('\n[+] Load dataset') | 47 | print('\n[+] Load dataset') |
| 48 | - transform = get_train_transform(args, model, log_dir) | 48 | + train_dataset = get_dataset(args, 'train') |
| 49 | - val_transform = get_valid_transform(args, model) | 49 | + valid_dataset = get_dataset(args, 'val') |
| 50 | - train_dataset = get_dataset(args, transform, 'train') | ||
| 51 | - valid_dataset = get_dataset(args, val_transform, 'val') | ||
| 52 | train_loader = iter(get_inf_dataloader(args, train_dataset)) | 50 | train_loader = iter(get_inf_dataloader(args, train_dataset)) |
| 53 | max_epoch = len(train_dataset) // args.batch_size | 51 | max_epoch = len(train_dataset) // args.batch_size |
| 54 | best_acc = -1 | 52 | best_acc = -1 |
| ... | @@ -82,6 +80,7 @@ def train(**kwargs): | ... | @@ -82,6 +80,7 @@ def train(**kwargs): |
| 82 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) | 80 | print(' BW Time : {:.3f}ms'.format(_train_res[4]*1000)) |
| 83 | 81 | ||
| 84 | if step % args.val_step == args.val_step-1: | 82 | if step % args.val_step == args.val_step-1: |
| 83 | + # print("\nstep, args.val_step: ", step, args.val_step) | ||
| 85 | valid_loader = iter(get_dataloader(args, valid_dataset)) | 84 | valid_loader = iter(get_dataloader(args, valid_dataset)) |
| 86 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) | 85 | _valid_res = validate(args, model, criterion, valid_loader, step, writer) |
| 87 | print('\n[+] Valid results') | 86 | print('\n[+] Valid results') | ... | ... |
code/classifier/utils.py
0 → 100644
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment