Showing 3 changed files with 105 additions and 52 deletions
File 1 of 3 (name elided on the page; the usage comment below suggests eval.py):
@@ -5,19 +5,19 @@ from pprint import pprint
 
 import torch
 import torch.nn as nn
-from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms as transforms
+#from torch.utils.tensorboard import SummaryWriter
 
 from utils import *
 
 # command
 # python eval.py --model_path='logs/April_16_00:26:10__resnet50__None/'
 
-def eval(model_path, num_data):
+def eval(model_path):
     print('\n[+] Parse arguments')
     kwargs_path = os.path.join(model_path, 'kwargs.json')
     kwargs = json.loads(open(kwargs_path).read())
     args, kwargs = parse_args(kwargs)
-    args.batch_size = num_data
     pprint(args)
     device = torch.device('cuda' if args.use_cuda else 'cpu')
 
@@ -35,23 +35,25 @@ def eval(model_path, num_data):
     model.load_state_dict(torch.load(weight_path))
 
     print('\n[+] Load dataset')
-    test_dataset = get_dataset(args, 'test')
+    transform = transforms.Compose([
+        transforms.Resize([240, 240]),
+        transforms.ToTensor()
+    ])
+    test_dataset = get_dataset(args, transform, 'test')
 
 
     test_loader = iter(get_dataloader(args, test_dataset)) ###
 
-    print('\n[+] Start testing')
-    writer = SummaryWriter(log_dir=model_path)
-    _test_res = validate(args, model, criterion, test_loader, step=0, writer=writer)
+    # print('\n[+] Start testing')
+    # writer = SummaryWriter(log_dir=model_path)
+    _test_res = validate(args, model, criterion, test_loader, step=0)
 
     print('\n[+] Valid results')
     print(' Acc@1 : {:.3f}%'.format(_test_res[0].data.cpu().numpy()[0]*100))
-    print(' Acc@5 : {:.3f}%'.format(_test_res[1].data.cpu().numpy()[0]*100))
-    print(' Acc_all : {:.3f}%'.format(_test_res[2].data.cpu().numpy()[0]*100))
-    print(' Loss : {:.3f}'.format(_test_res[3].data))
-    print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[4]*1000 / len(test_dataset)))
+    print(' Loss : {:.3f}'.format(_test_res[1].data))
+    print(' Infer Time(per image) : {:.3f}ms'.format(_test_res[2]*1000 / len(test_dataset)))
 
-    writer.close()
+    #writer.close()
 
 if __name__ == '__main__':
     fire.Fire(eval)
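Note on the entry point: fire.Fire(eval) maps command-line flags to the function's parameters, so dropping num_data also changes the CLI. A minimal self-contained sketch of that mapping (the body here is a stand-in, not the repository's code):

import fire

def eval(model_path):
    # Stand-in body; the real eval loads kwargs.json, the weights, and the test set.
    print('evaluating checkpoint under', model_path)

if __name__ == '__main__':
    # Only --model_path is accepted after this commit, matching the usage
    # comment above; passing --num_data would now raise a fire error.
    fire.Fire(eval)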
File 2 of 3 (name elided; the train(**kwargs) entry point suggests the training script):
@@ -7,7 +7,7 @@ from pprint import pprint
 
 import torch.nn as nn
 import torch.backends.cudnn as cudnn
-from torch.utils.tensorboard import SummaryWriter
+#from torch.utils.tensorboard import SummaryWriter
 
 from networks import *
 from utils import *
@@ -27,7 +27,7 @@ def train(**kwargs):
     log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name)
     os.makedirs(os.path.join(log_dir, 'model'))
     json.dump(kwargs, open(os.path.join(log_dir, 'kwargs.json'), 'w'))
-    writer = SummaryWriter(log_dir=log_dir)
+    #writer = SummaryWriter(log_dir=log_dir)
 
     if args.seed is not None:
         random.seed(args.seed)
@@ -45,8 +45,10 @@ def train(**kwargs):
     #writer.add_graph(model)
 
     print('\n[+] Load dataset')
-    train_dataset = get_dataset(args, 'train')
-    valid_dataset = get_dataset(args, 'val')
+    transform = get_train_transform(args, model, log_dir)
+    val_transform = get_valid_transform(args, model)
+    train_dataset = get_dataset(args, transform, 'train')
+    valid_dataset = get_dataset(args, val_transform, 'val')
     train_loader = iter(get_inf_dataloader(args, train_dataset))
     max_epoch = len(train_dataset) // args.batch_size
     best_acc = -1
@@ -62,16 +64,16 @@ def train(**kwargs):
     start_t = time.time()
     for step in range(args.start_step, args.max_step):
         batch = next(train_loader)
-        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step, writer)
+        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step)
 
         if step % args.print_step == 0:
             print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                 step, args.max_step, current_epoch, max_epoch, (time.time()-start_t)/60, optimizer.param_groups[0]['lr']))
-            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
-            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
-            writer.add_scalar('train/loss', _train_res[1], global_step=step)
-            writer.add_scalar('train/forward_time', _train_res[2], global_step=step)
-            writer.add_scalar('train/backward_time', _train_res[3], global_step=step)
+            # writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
+            # writer.add_scalar('train/acc1', _train_res[0], global_step=step)
+            # writer.add_scalar('train/loss', _train_res[1], global_step=step)
+            # writer.add_scalar('train/forward_time', _train_res[2], global_step=step)
+            # writer.add_scalar('train/backward_time', _train_res[3], global_step=step)
             print(' Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0]*100))
             print(' Loss : {}'.format(_train_res[1].data))
             print(' FW Time : {:.3f}ms'.format(_train_res[2]*1000))
@@ -80,10 +82,10 @@ def train(**kwargs):
         if step % args.val_step == args.val_step-1:
             # print("\nstep, args.val_step: ", step, args.val_step)
             valid_loader = iter(get_dataloader(args, valid_dataset))
-            _valid_res = validate(args, model, criterion, valid_loader, step, writer)
-            print('\n[+] Valid results')
-            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
-            writer.add_scalar('valid/loss', _valid_res[1], global_step=step)
+            _valid_res = validate(args, model, criterion, valid_loader, step)
+            print('\n[+] (Valid results) Valid step: {}/{}'.format(step, args.max_step))
+            # writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
+            # writer.add_scalar('valid/loss', _valid_res[1], global_step=step)
             print(' Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0]*100))
             print(' Loss : {}'.format(_valid_res[1].data))
 
@@ -92,7 +94,7 @@ def train(**kwargs):
                 torch.save(model.state_dict(), os.path.join(log_dir, "model", "model.pt"))
                 print('\n[+] Model saved')
 
-    writer.close()
+    # writer.close()
 
 
 if __name__ == '__main__':
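This commit disables TensorBoard by commenting out each writer line. An alternative worth noting (a sketch under an assumed use_tensorboard flag, which this repository does not define) keeps the call sites intact and no-ops when logging is off:

from torch.utils.tensorboard import SummaryWriter

def make_writer(log_dir, use_tensorboard):
    # Build a real writer only when logging is wanted; otherwise None.
    return SummaryWriter(log_dir=log_dir) if use_tensorboard else None

def log_scalar(writer, tag, value, step):
    # Silently skip when logging is disabled, so train() needs no commented-out lines.
    if writer is not None:
        writer.add_scalar(tag, value, global_step=step)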
File 3 of 3 (name elided; the shared utilities module that eval and train import):
@@ -23,6 +23,7 @@ from sklearn.model_selection import KFold
 
 from networks import basenet, grayResNet2
 
+DATASET_PATH = '/content/drive/My Drive/CD2 Project/'
 
 TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/'
 TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv'
@@ -131,17 +132,17 @@ def parse_args(kwargs):
     kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS'
     kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50'
     kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam'
-    kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.001
+    kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.01
     kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None
     kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True
     kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available()
     kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4
-    kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 100
-    kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 100
+    kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 50
+    kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 50
     kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp'
-    kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 32
+    kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 16
     kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0
-    kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 2500
+    kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 500
     kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None
 
     # to named tuple
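The hunk ends at the '# to named tuple' comment, where the defaults-filled dict presumably becomes the attribute-style args object the rest of the code reads (args.batch_size, args.max_step, and so on). A common idiom for that step, shown as an assumption since the actual conversion sits below the visible hunk:

from collections import namedtuple

def to_named_tuple(kwargs):
    # Freeze the dict into an immutable object with attribute access.
    Args = namedtuple('Args', kwargs.keys())
    return Args(**kwargs)

args = to_named_tuple({'dataset': 'BraTS', 'batch_size': 16, 'max_step': 500})
print(args.batch_size, args.max_step)   # 16 500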
@@ -155,11 +156,10 @@ def select_model(args):
         'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()}
 
     if args.network in resnet_dict:
-        backbone = resnet_dict[args.network]
-        model = basenet.BaseNet(backbone, args)
-    else:
-        Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
-        model = Net(args)
+        model = resnet_dict[args.network]
+    # else: # 3 channels
+    #     backbone = models.__dict__[args.network]()
+    #     model = basenet.BaseNet(backbone, args)
 
     #print(model) # print model architecture
     return model
@@ -187,16 +187,44 @@ def select_scheduler(args, optimizer):
     else:
         raise Exception('Unknown Scheduler')
 
+def get_train_transform(args, model, log_dir=None):
+    if args.dataset == 'cifar10':
+        transform = transforms.Compose([
+            transforms.Pad(4),
+            transforms.RandomCrop(32),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor()
+        ])
+
+    else:
+        transform = transforms.Compose([
+            transforms.Resize([240, 240]),
+            transforms.ToTensor()
+        ])
+
+    return transform
+
+def get_valid_transform(args, model):
+    if args.dataset == 'cifar10':
+        val_transform = transforms.Compose([
+            transforms.Resize(32),
+            transforms.ToTensor()
+        ])
+
+    else:
+        val_transform = transforms.Compose([
+            transforms.Resize([240, 240]),
+            transforms.ToTensor()
+        ])
+
+    return val_transform
 
 class CustomDataset(Dataset):
-    def __init__(self, data_path, csv_path):
+    def __init__(self, data_path, csv_path, transform):
         self.path = data_path
         self.imgs = natsorted(os.listdir(data_path))
         self.len = len(self.imgs)
-        self.transform = transforms.Compose([
-            transforms.Resize([240, 240]),
-            transforms.ToTensor()
-        ])
+        self.transform = transform
 
         df = pd.read_csv(csv_path)
         targets_list = []
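One torchvision detail behind the transforms above: an integer size, as in the cifar10 validation branch's Resize(32), scales the shorter edge and keeps the aspect ratio, while a sequence such as Resize([240, 240]) forces an exact output size. A quick self-contained check:

import torchvision.transforms as transforms
from PIL import Image

img = Image.new('L', (640, 480))             # PIL sizes are (width, height)
short_edge = transforms.Resize(32)(img)      # shorter edge -> 32, ratio kept
exact = transforms.Resize([240, 240])(img)   # forced to exactly 240x240
print(short_edge.size, exact.size)           # (42, 32) (240, 240)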
@@ -215,6 +243,7 @@ class CustomDataset(Dataset):
         targets = self.targets[idx]
         image = Image.open(img_loc)
         image = self.transform(image)
+        #print("\n idx, img, targets: ", idx, img_loc, targets)
         return image, targets
 
 
@@ -222,12 +251,32 @@ def get_dataset(args, transform, split='train'):
 def get_dataset(args, transform, split='train'):
     assert split in ['train', 'val', 'test']
 
-    if split in ['train']:
-        dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH)
-    elif split in ['val']:
-        dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH)
-    else: # test
-        dataset = CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH)
+    if args.dataset == 'cifar10':
+        train = split in ['train', 'val', 'trainval']
+        dataset = torchvision.datasets.CIFAR10(DATASET_PATH,
+                                               train=train,
+                                               transform=transform,
+                                               download=True)
+
+        if split in ['train', 'val']:
+            split_path = os.path.join(DATASET_PATH,
+                                      'cifar-10-batches-py', 'train_val_index.cp')
+
+            if not os.path.exists(split_path):
+                [train_index], [val_index] = split_dataset(args, dataset, k=1)
+                split_index = {'train': train_index, 'val': val_index}
+                cp.dump(split_index, open(split_path, 'wb'))
+
+            split_index = cp.load(open(split_path, 'rb'))
+            dataset = Subset(dataset, split_index[split])
+
+    else:
+        if split in ['train']:
+            dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform)
+        elif split in ['val']:
+            dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform)
+        else: # test
+            dataset = CustomDataset(TEST_DATASET_PATH, TEST_TARGET_PATH, transform)
 
 
     return dataset
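The new cifar10 branch relies on split_dataset(args, dataset, k=1) and a pickle-like module cp, neither of which appears in the diff. A hypothetical reading consistent with the KFold import at the top of this file and with the [train_index], [val_index] unpacking (both the cp alias and the 5-fold choice are assumptions):

import pickle as cp   # assumption: 'cp' aliases a pickle-compatible module
import numpy as np
from sklearn.model_selection import KFold

def split_dataset(args, dataset, k):
    # Return the first k train/val index pairs of a shuffled 5-fold split;
    # k=1 lets the caller unpack single-element lists.
    kf = KFold(n_splits=5, shuffle=True, random_state=args.seed)
    train_splits, val_splits = [], []
    for train_idx, val_idx in kf.split(np.arange(len(dataset))):
        train_splits.append(train_idx)
        val_splits.append(val_idx)
        if len(train_splits) == k:
            break
    return train_splits, val_splits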
@@ -261,7 +310,7 @@ def get_inf_dataloader(args, dataset):
 
 
 
-def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None):
+def train_step(args, model, optimizer, scheduler, criterion, batch, step, device=None):
     model.train()
     images, target = batch
 
@@ -275,7 +324,7 @@ def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer
 
     # compute output
     start_t = time.time()
-    output, first = model(images)
+    output = model(images)
     forward_t = time.time() - start_t
     loss = criterion(output, target)
 
@@ -323,7 +372,7 @@ def accuracy(output, target, topk=(1,)):
         res.append(correct_k)
     return res
 
-def validate(args, model, criterion, valid_loader, step, writer, device=None):
+def validate(args, model, criterion, valid_loader, step, device=None):
     # switch to evaluate mode
     model.eval()
 
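Only the tail of accuracy is visible above. For orientation, a standard top-k implementation that matches the visible res.append(correct_k) and the .data.cpu().numpy()[0]*100 formatting at the call sites (a sketch, not necessarily this repository's exact code):

import torch

def accuracy(output, target, topk=(1,)):
    # Fraction of samples whose label is among the top-k predictions; each
    # entry is a 1-element tensor, hence numpy()[0]*100 in the print calls.
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) / batch_size
        res.append(correct_k)
    return res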
@@ -344,7 +393,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
         target = target.cuda(non_blocking=True)
 
         # compute output
-        output, first = model(images)
+        output = model(images)
         loss = criterion(output, target)
         infer_t += time.time() - start_t
 
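Putting the visible pieces together, validate now appears to return a 3-tuple that eval.py reads as (Acc@1, loss, total inference seconds). A hypothetical end-to-end shape under that assumption, reusing the accuracy sketch above:

import time
import torch

def validate(args, model, criterion, valid_loader, step, device=None):
    # Sketch only: average top-1 accuracy and loss over the loader and
    # accumulate inference time, matching how eval.py indexes the result.
    model.eval()
    acc_sum, loss_sum, infer_t, batches = 0.0, 0.0, 0.0, 0
    with torch.no_grad():
        for images, target in valid_loader:
            start_t = time.time()
            if device:
                images, target = images.to(device), target.to(device)
            elif args.use_cuda:
                images = images.cuda()
                target = target.cuda(non_blocking=True)
            output = model(images)
            loss = criterion(output, target)
            infer_t += time.time() - start_t
            acc_sum = acc_sum + accuracy(output, target, topk=(1,))[0]
            loss_sum = loss_sum + loss
            batches += 1
    return acc_sum / batches, loss_sum / batches, infer_t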