train.py
import os
import json
import time
import random
from pprint import pprint

import fire
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

from networks import *
from utils import *
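# NOTE: the star imports above are assumed to provide the helpers used below
# (parse_args, get_model_name, select_model, select_optimizer, select_scheduler,
# get_train_transform, get_valid_transform, get_dataset, get_dataloader,
# get_inf_dataloader, train_step, validate) from this repo's networks.py/utils.py.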
# python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam
# nohup python train.py --use_cuda=True --network=resnet50 --dataset=BraTS --optimizer=adam &
def train(**kwargs):
    print('\n[+] Parse arguments')
    args, kwargs = parse_args(kwargs)
    pprint(args)
    device = torch.device('cuda' if args.use_cuda else 'cpu')

    print('\n[+] Create log dir')
    model_name = get_model_name(args)
    log_dir = os.path.join('/content/drive/My Drive/CD2 Project/runs/classify/', model_name)
    os.makedirs(os.path.join(log_dir, 'model'), exist_ok=True)
    os.makedirs(os.path.join(log_dir, 'train'), exist_ok=True)
    with open(os.path.join(log_dir, 'kwargs.json'), 'w') as f:
        json.dump(kwargs, f)
    writer = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
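        # NOTE: cudnn.deterministic=True makes cuDNN pick deterministic kernels,
        # which improves reproducibility at some cost in training speed.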
    print('\n[+] Create network')
    model = select_model(args)
    optimizer = select_optimizer(args, model)
    scheduler = select_scheduler(args, optimizer)
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)
    criterion = criterion.to(device)
    print('\n[+] Load dataset')
    transform = get_train_transform(args, model, log_dir)
    val_transform = get_valid_transform(args, model)
    train_dataset = get_dataset(args, transform, 'train')
    valid_dataset = get_dataset(args, val_transform, 'val')
    train_loader = iter(get_inf_dataloader(args, train_dataset))
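    # NOTE: get_inf_dataloader (from utils) is assumed to yield batches
    # indefinitely, restarting its underlying DataLoader after each pass over
    # the dataset, so the step-based loop below never hits StopIteration.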
    steps_per_epoch = len(train_dataset) // args.batch_size
    max_epoch = args.max_step // steps_per_epoch
    best_acc = -1
    print('\n[+] Start training')
    if torch.cuda.device_count() > 1:
        print('\n[+] Use {} GPUs'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    if args.use_cuda:
        print('\n[+] Using GPU: {}'.format(torch.cuda.get_device_name(0)))

    start_t = time.time()
    for step in range(args.start_step, args.max_step):
        batch = next(train_loader)
        _train_res = train_step(args, model, optimizer, scheduler, criterion, batch, step)
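        # NOTE: train_step (from utils) is assumed to return a tuple of
        # (top-1 accuracy, loss, forward time in s, backward time in s),
        # which is how _train_res[0..3] are unpacked for logging below.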
        if step % args.print_step == 0:
            current_epoch = step // steps_per_epoch
            print('\n[+] Training step: {}/{}\tTraining epoch: {}/{}\tElapsed time: {:.2f}min\tLearning rate: {}'.format(
                step, args.max_step, current_epoch, max_epoch,
                (time.time() - start_t) / 60, optimizer.param_groups[0]['lr']))
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step=step)
            writer.add_scalar('train/acc1', _train_res[0], global_step=step)
            writer.add_scalar('train/loss', _train_res[1], global_step=step)
            writer.add_scalar('train/forward_time', _train_res[2], global_step=step)
            writer.add_scalar('train/backward_time', _train_res[3], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_train_res[0].data.cpu().numpy()[0] * 100))
            print('  Loss : {}'.format(_train_res[1].data))
            print('  FW Time : {:.3f}ms'.format(_train_res[2] * 1000))
            print('  BW Time : {:.3f}ms'.format(_train_res[3] * 1000))
        if step % args.val_step == args.val_step - 1:
            valid_loader = iter(get_dataloader(args, valid_dataset))
            _valid_res = validate(args, model, criterion, valid_loader, step)
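            # NOTE: validate (from utils) is assumed to return
            # (top-1 accuracy, loss) averaged over the validation set.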
            print('\n[+] (Valid results) Valid step: {}/{}'.format(step, args.max_step))
            writer.add_scalar('valid/acc1', _valid_res[0], global_step=step)
            writer.add_scalar('valid/loss', _valid_res[1], global_step=step)
            print('  Acc@1 : {:.3f}%'.format(_valid_res[0].data.cpu().numpy()[0] * 100))
            print('  Loss : {}'.format(_valid_res[1].data))

            if _valid_res[0] >= best_acc:
                best_acc = _valid_res[0]
                # Unwrap DataParallel so checkpoint keys are not prefixed with 'module.'
                model_to_save = model.module if isinstance(model, nn.DataParallel) else model
                torch.save(model_to_save.state_dict(), os.path.join(log_dir, 'model', 'model.pt'))
                print('\n[+] Model saved')
    writer.close()


if __name__ == '__main__':
    fire.Fire(train)