이동주

1. Train model with sampling data
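The script below trains ResNet-18 or WideResNet on CIFAR-10/100 or SVHN; instead of iterating over the full training set, each epoch draws 40,000 images via a RandomSampler with replacement.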

# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms

from torch.utils.data import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet

model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=100,
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=0,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=0,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)

# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
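    # NOTE: with the default --n_holes 0 above, Cutout should mask no region at all;
    # the example commands at the top likely also need --n_holes 1 to apply the usual single-hole Cutout.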


test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = datasets.CIFAR10(root='data/',
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_dataset = datasets.CIFAR100(root='data/',
                                      train=True,
                                      transform=train_transform,
                                      download=True)

    test_dataset = datasets.CIFAR100(root='data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    train_dataset = datasets.SVHN(root='data/',
                                  split='train',
                                  transform=train_transform,
                                  download=True)

    extra_dataset = datasets.SVHN(root='data/',
                                  split='extra',
                                  transform=train_transform,
                                  download=True)

    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
    labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
    train_dataset.data = data
    train_dataset.labels = labels

    test_dataset = datasets.SVHN(root='data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           sampler=RandomSampler(train_dataset, replacement=True, num_samples=40000),
                                           pin_memory=True,
                                           num_workers=0)
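# The sampler above draws 40,000 images per epoch with replacement, so each epoch
# trains on a random (possibly repeated) subset of the training data rather than the full set.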

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
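# MultiStepLR multiplies the learning rate by gamma at each milestone epoch,
# e.g. the CIFAR schedule above decays 0.1 -> 0.02 -> 0.004 -> 0.0008.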

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc', 'labels'], filename=filename)


def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    count = 0
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)
            # print(pred, labels)

        pred = torch.max(pred.data, 1)[1]
        print(pred, '\n', labels, (pred == labels).sum().item())
        # if (pred == labels).sum().item():
        #     print('match')
        #     count += 1
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        print(correct)
    val_acc = correct / total
    cnn.train()
    return val_acc


for epoch in range(args.epochs):

    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    # kl_sum = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()
    label_list = []
    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()
        cnn.zero_grad()
        pred = cnn(images)

        xentropy_loss = criterion(pred, labels)
        xentropy_loss.backward()
        cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()


        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total

        # print(pred)
        # y_hat: predicted value of each model --> pred
        # y_bar: mean of the predicted values --> pred / total
        # labels.data: ground truth

        # pred_sum = torch.add(pred_sum, pred)
        # y_bar = pred_sum / (i + 1)
        # kl = torch.nn.functional.kl_div(pred, y_bar)
        # kl_sum += kl



        # Without the extra for loop, each iteration of an epoch prints xentropy as xentropy_loss_avg / iter.
        # With the extra for loop, xentropy_loss_avg per iteration is unchanged, but the printed xentropy is 1/10 of x_l_avg.
        # pred and labels were confirmed to be identical with or without the for loop.

        # for a in range(list(pred_sum.size())[0]):
        #     for b in range(list(pred.size())[0]):
        #         if pred[b] == a:
        #             pred_sum[a] += 1


        # print('\n', i, ' ', xentropy_loss_avg)
        progress_bar.set_postfix(
            # y_hat='%.5f' % pred,
            # y_bar='%.5f' % y_bar,
            # ground_truth='%.5f' % labels.data,
            # kl='%.3f' % kl.item(),
            # kl_sum='%.3f' % (kl_sum.item()),
            # kl_div='%.3f' % (kl_sum.item() / (i + 1)),  # report kl_div
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
    # pred_sum = [x / 40000 for x in pred_sum]
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % (test_acc))

    scheduler.step(epoch)  # Use this line for PyTorch <1.4
    # scheduler.step()     # Use this line for PyTorch >=1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)}
    csv_logger.writerow(row)
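# NOTE: label_list is never appended to during training, so the loop below
# writes no 'labels' rows unless values are collected in the epoch loop above.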
for i in range(len(label_list)):
    csv_logger.writerow({'labels': float(label_list[i])})
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()