Showing 1 changed file with 304 additions and 0 deletions

code/avg_var.py (new file, mode 0 → 100644)
# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms

from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet
model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=1,
                    help='input batch size for training (default: 1)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=1,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)
# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))


test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])
if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = datasets.CIFAR10(root='data/',
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_dataset = datasets.CIFAR100(root='data/',
                                      train=True,
                                      transform=train_transform,
                                      download=True)

    test_dataset = datasets.CIFAR100(root='data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    train_dataset = datasets.SVHN(root='data/',
                                  split='train',
                                  transform=train_transform,
                                  download=True)

    extra_dataset = datasets.SVHN(root='data/',
                                  split='extra',
                                  transform=train_transform,
                                  download=True)

    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
    labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
    train_dataset.data = data
    train_dataset.labels = labels

    test_dataset = datasets.SVHN(root='data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)
# Data Loader (Input Pipeline)
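# NOTE (assumption about intent): shuffle is left off and batch_size defaults to 1 so
# that the i-th iteration of the y_bar pass and of the KL pass below sees the same
# training image; the per-sample indexing of var_tensor relies on that alignment.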
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           # sampler=RandomSampler(train_dataset, True, 40000),
                                           pin_memory=True,
                                           num_workers=0)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)
if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)

cnn = cnn.cuda()

criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args,
                       fieldnames=['epoch', 'train_acc', 'test_acc', 'xentropy',
                                   'var', 'avg_var', 'arg_var', 'index', 'labels'],
                       filename=filename)

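# The epoch/train_acc/test_acc/xentropy columns are written once per checkpoint pass
# below; var, avg_var, arg_var, index and labels are filled in by the KL-divergence
# pass at the end of the script.
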
def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)

        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = correct / total
    cnn.train()
    return val_acc

kl_sum = 0
y_bar = torch.zeros(8, 10).detach().cuda()

# Pass 1: epoch loop that computes y_bar
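# Assumption: the eight rows of y_bar correspond to the eight sampled checkpoints
# (sampling_1 ... sampling_8) and the ten columns to the class probabilities; only
# row 0 is filled here because the loop runs over a single checkpoint.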
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8),
                            map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    norm_const = 0

    kldiv = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()
    count = 0
    label_list = []
    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()
        label_list.append(labels.item())
        save_image(images[0],
                   os.path.join('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/augmented_images/{0}/'.format(8),
                                'img{0}.png'.format(i)))

        cnn.zero_grad()
        pred = cnn(images)
        xentropy_loss = criterion(pred, labels)
        # xentropy_loss.backward()
        # cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
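        # Sum the log-softmax mass per class; after dividing by the number of images
        # and exponentiating below, y_bar[epoch] is (up to normalization) the geometric
        # mean of this model's predicted class distributions over the training set.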
        for a in range(pred_softmax.data.size()[0]):
            for b in range(y_bar.size()[1]):
                y_bar[epoch][b] += torch.log(pred_softmax.data[a][b])

        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
        count += 1
    xentropy = xentropy_loss_avg / count
    y_bar[epoch] = y_bar[epoch] / 50000    # 50,000 = hard-coded CIFAR training-set size
    y_bar[epoch] = torch.exp(y_bar[epoch])
    for index in range(y_bar.size()[1]):
        norm_const += y_bar[epoch][index]
    for index in range(y_bar.size()[1]):
        y_bar[epoch][index] = y_bar[epoch][index] / norm_const
    print("y_bar[{0}] : ".format(epoch), y_bar[epoch])
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % (test_acc))

    scheduler.step(epoch)    # Use this line for PyTorch <1.4
    # scheduler.step()       # Use this line for PyTorch >=1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy),
           'test_acc': str(test_acc), 'xentropy': float(xentropy)}
    csv_logger.writerow(row)
    # del pred
    # torch.cuda.empty_cache()


var_tensor = torch.zeros(8, 50000).detach().cuda()
var_addeachcol = torch.zeros(1, 50000).detach().cuda()

# Pass 2: epoch loop that computes the KL divergence
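# Assumed intent: reload the same checkpoint, recompute the softmax for every training
# image, and use the KL divergence between y_bar and each per-image prediction as a
# per-sample "variance" score (var_tensor); var_addeachcol accumulates those scores so
# they can be averaged over checkpoints and sorted below.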
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8),
                            map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    kldiv = 0
    progress_bar = tqdm(train_loader)    # fresh iterator over the training set for this pass
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')

        images = images.cuda()
        labels = labels.cuda()

        cnn.zero_grad()
        pred = cnn(images)

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()

        # If the two inputs have different shapes, kl_div returns the mean over the batch size.
        # F.kl_div expects its first argument to be log-probabilities.
        kldiv = torch.nn.functional.kl_div(y_bar[epoch].log(), pred_softmax, reduction='sum')
        # Store this model's per-sample variance in a 1 x 50000 row.
        var_tensor[epoch][i] += abs(kldiv).detach()
        var_addeachcol[0][i] += var_tensor[epoch][i]
        kl_sum += kldiv.detach()
        # print(y_bar_copy.size(), pred_softmax.size())
        # print(kl_sum)
    var = abs(kl_sum.item() / 50000)
    print("Variance : ", var)
    csv_logger.writerow({'var': float(var)})
    # print(var_tensor)
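# Divide by 8, presumably to average the accumulated scores over the eight sampled
# checkpoints (only a single checkpoint is processed in this version of the script).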
for i in range(var_addeachcol.size()[1]):
    var_addeachcol[0][i] = var_addeachcol[0][i] / 8

print(var_addeachcol)
# var_addeachcol[0] = torch.Tensor([x / 8 for x in var_addeachcol]).cuda()
var_sorted = torch.argsort(var_addeachcol)
print(var_sorted)
for i in range(var_addeachcol.size()[1]):
    csv_logger.writerow({'avg_var': float(var_addeachcol[0][i]),
                         'arg_var': float(var_sorted[0][i]),
                         'index': float(i + 1),
                         'labels': float(label_list[i])})
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()