이동주

Save augmented images, labels, and variance

# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms

from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet

model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=1,
                    help='input batch size for training (default: 1)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=1,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)

# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))


test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])
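
# For reference: a minimal sketch of what util.cutout.Cutout is assumed to do,
# following the original Cutout paper (https://arxiv.org/abs/1708.04552). It zeroes
# out `n_holes` square patches of side `length`; since it is appended after
# normalize, the holes are filled with zeros in normalized space. This class is
# illustrative only and is not used below; the real implementation is util/cutout.py.
class CutoutSketch(object):
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        # img: normalized tensor of shape (C, H, W)
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            y, x = np.random.randint(h), np.random.randint(w)
            y1, y2 = np.clip(y - self.length // 2, 0, h), np.clip(y + self.length // 2, 0, h)
            x1, x2 = np.clip(x - self.length // 2, 0, w), np.clip(x + self.length // 2, 0, w)
            mask[y1:y2, x1:x2] = 0.
        return img * torch.from_numpy(mask).expand_as(img)
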

if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = datasets.CIFAR10(root='data/',
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_dataset = datasets.CIFAR100(root='data/',
                                      train=True,
                                      transform=train_transform,
                                      download=True)

    test_dataset = datasets.CIFAR100(root='data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    train_dataset = datasets.SVHN(root='data/',
                                  split='train',
                                  transform=train_transform,
                                  download=True)

    extra_dataset = datasets.SVHN(root='data/',
                                  split='extra',
                                  transform=train_transform,
                                  download=True)

    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
    labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
    train_dataset.data = data
    train_dataset.labels = labels

    test_dataset = datasets.SVHN(root='data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           # sampler=RandomSampler(train_dataset, True, 40000),
                                           pin_memory=True,
                                           num_workers=0)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)

cnn = cnn.cuda()

criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args,
                       fieldnames=['epoch', 'train_acc', 'test_acc', 'xentropy',
                                   'var', 'avg_var', 'arg_var', 'index', 'labels'],
                       filename=filename)


def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)

        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = correct / total
    cnn.train()
    return val_acc

kl_sum = 0
y_bar = torch.zeros(8, 10).detach().cuda()

# Pass over the training set that computes y_bar (the aggregated predictive distribution) for the sampled model
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8),
                            map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    norm_const = 0

    kldiv = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()
    count = 0
    label_list = []
    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()
        label_list.append(labels.item())
        save_image(images[0], os.path.join('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/augmented_images/{0}/'.format(8),
                                           'img{0}.png'.format(i)))

        cnn.zero_grad()
        pred = cnn(images)
        xentropy_loss = criterion(pred, labels)
        # xentropy_loss.backward()
        # cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        # Accumulate per-class log-probabilities over the whole training set
        for a in range(pred_softmax.data.size()[0]):
            for b in range(y_bar.size()[1]):
                y_bar[epoch][b] += torch.log(pred_softmax.data[a][b])

        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
        count += 1
    xentropy = xentropy_loss_avg / count
    y_bar[epoch] = y_bar[epoch] / total    # total == number of training samples (50000 here)
    y_bar[epoch] = torch.exp(y_bar[epoch])
    for index in range(y_bar.size()[1]):
        norm_const += y_bar[epoch][index]
    for index in range(y_bar.size()[1]):
        y_bar[epoch][index] = y_bar[epoch][index] / norm_const
    print("y_bar[{0}] : ".format(epoch), y_bar[epoch])
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % test_acc)

    scheduler.step(epoch)    # Use this line for PyTorch <1.4
    # scheduler.step()       # Use this line for PyTorch >=1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy),
           'test_acc': str(test_acc), 'xentropy': float(xentropy)}
    csv_logger.writerow(row)
    # del pred
    # torch.cuda.empty_cache()

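# For clarity: the per-class accumulation above is equivalent to taking the
# normalized geometric mean of the softmax outputs over the training set.
# A vectorized sketch of the same computation is given below; it is illustrative
# only and is not called anywhere (it assumes 10 classes, matching y_bar's width).
def geometric_mean_prediction(loader, model):
    log_sum = torch.zeros(10).cuda()
    n = 0
    with torch.no_grad():
        for images, _ in loader:
            probs = nn.functional.softmax(model(images.cuda()), dim=1)
            log_sum += torch.log(probs).sum(dim=0)
            n += images.size(0)
    g = torch.exp(log_sum / n)    # element-wise geometric mean per class
    return g / g.sum()            # normalize to a probability vector
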

var_tensor = torch.zeros(8, 50000).detach().cuda()
var_addeachcol = torch.zeros(1, 50000).detach().cuda()

# Pass that computes the per-sample KL divergence against y_bar
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8),
                            map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    kldiv = 0
    progress_bar = tqdm(train_loader)    # fresh progress bar; the previous one is already exhausted
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')

        images = images.cuda()
        labels = labels.cuda()

        cnn.zero_grad()
        pred = cnn(images)

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()

        # If the two inputs have different shapes, the result is averaged over the batch size.
        kldiv = torch.nn.functional.kl_div(y_bar[epoch], pred_softmax, reduction='sum')
        # Store this sample's variance (|KL|) for this model in the 8 x 50000 tensor
        var_tensor[epoch][i] += abs(kldiv).detach()
        var_addeachcol[0][i] += var_tensor[epoch][i]
        kl_sum += kldiv.detach()
        # print(y_bar_copy.size(), pred_softmax.size())
        # print(kl_sum)
    var = abs(kl_sum.item() / 50000)
    print("Variance : ", var)
    csv_logger.writerow({'var': float(var)})
    # print(var_tensor)
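
# Note on the API: torch.nn.functional.kl_div(input, target, reduction='sum') expects
# `input` to be log-probabilities and `target` to be probabilities, and returns
# sum(target * (log(target) - input)). Since y_bar[epoch] holds probabilities rather
# than log-probabilities, a sketch of a call matching that contract is shown below.
# It is illustrative only and is not used above, because switching to it would change
# the logged variance values.
def kl_to_y_bar(pred_softmax, y_bar_row):
    # KL(pred_softmax || y_bar_row) using kl_div's (log_input, target) convention
    return torch.nn.functional.kl_div(torch.log(y_bar_row), pred_softmax, reduction='sum')
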
for i in range(var_addeachcol.size()[1]):
    var_addeachcol[0][i] = var_addeachcol[0][i] / 8    # average over the 8 sampled models

print(var_addeachcol)
# var_addeachcol[0] = torch.Tensor([x / 8 for x in var_addeachcol]).cuda()
var_sorted = torch.argsort(var_addeachcol)
print(var_sorted)
for i in range(var_addeachcol.size()[1]):
    csv_logger.writerow({'avg_var': float(var_addeachcol[0][i]),
                         'arg_var': float(var_sorted[0][i]),
                         'index': float(i + 1),
                         'labels': float(label_list[i])})
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()