이동주

Test after removing the top n% highest-variance augmented data
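
The index file read below is assumed to contain one row per kept training image, laid out as [variance, image_index, label] and sorted by variance. A minimal sketch of the removal step such a file implies is given here; the drop ratio, file names, and column layout are assumptions for illustration, not the exact procedure that produced 1_5000_deleted.csv.

import csv

def drop_high_variance(in_csv, out_csv, drop_ratio=0.1):
    # rows are assumed to be [variance, image_index, label]
    with open(in_csv, 'r', encoding='utf-8-sig') as f:
        rows = [r for r in csv.reader(f) if r]
    # sort ascending by variance, then cut off the highest-variance tail
    rows.sort(key=lambda r: float(r[0]))
    keep = rows[:int(len(rows) * (1.0 - drop_ratio))]
    with open(out_csv, 'w', newline='') as f:
        csv.writer(f).writerows(keep)

# example call (hypothetical file names):
# drop_high_variance('per_sample_variance.csv', '1_5000_deleted.csv', drop_ratio=0.1)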

1 +# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
2 +# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
3 +# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20
4 +
5 +import pdb
6 +import argparse
7 +import numpy as np
8 +from tqdm import tqdm
9 +import os
10 +
11 +import torch
12 +import torch.nn as nn
13 +from torch.autograd import Variable
14 +import torch.backends.cudnn as cudnn
15 +from torch.optim.lr_scheduler import MultiStepLR
16 +
17 +from torchvision.utils import make_grid, save_image
18 +from torchvision import datasets, transforms
19 +from torchvision.transforms.transforms import ToTensor
20 +from torchvision.datasets import ImageFolder
21 +from torch.utils.data import Dataset, DataLoader
22 +
23 +from torch.utils.data.dataloader import RandomSampler
24 +from util.misc import CSVLogger
25 +from util.cutout import Cutout
26 +
27 +from model.resnet import ResNet18
28 +from model.wide_resnet import WideResNet
29 +
30 +from PIL import Image
31 +from matplotlib.pyplot import imshow
32 +import time
33 +
34 +def csv2list(filename):
35 +    lists = []
36 +    with open(filename, 'r', encoding='utf-8-sig') as file:
37 +        while True:
38 +            line = file.readline().strip("\n")
39 +            # int_list = [int(i) for i in line]
40 +            if line:
41 +                line = line.split(",")
42 +                lists.append(line)
43 +            else:
44 +                break
45 +    return lists
46 +
47 +# read from the logs file sorted by variance
48 +filelist = csv2list("C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/logs/image_save/1_5000_deleted.csv")
49 +for i in range(len(filelist)):
50 + for j in range(len(filelist[0])):
51 + filelist[i][j] = float(filelist[i][j])
52 +transposelist = np.transpose(filelist)
53 +
54 +# print(list)
55 +list_tensor = torch.tensor(transposelist, dtype=torch.long)
56 +target = list(list_tensor[2])
57 +train_img_list = list()
58 +
59 +for img_idx in transposelist[1]:
60 + img_path = "C:/Users/82109/Desktop/model1/img" + str(int(img_idx)) + ".png"
61 + train_img_list.append(img_path)
62 +
63 +
64 +class Img_Dataset(Dataset):
65 +
66 + def __init__(self,file_list,transform):
67 + self.file_list = file_list
68 + self.transform = transform
69 +
70 + def __len__(self):
71 + return len(self.file_list)
72 +
73 + def __getitem__(self, index):
74 + img_path = self.file_list[index]
75 + images = np.array(Image.open(img_path))
76 + # img_transformed = self.transform(images)
77 +
78 + labels = target[index]
79 +
80 + return images, labels
81 +# print(list_tensor)
82 +# set the number of top-k entries
83 +# k = 5000
84 +# values, indices = torch.topk(list_tensor[0], k)
85 +# # print(values)
86 +# image_list = []
87 +# for i in range(k):
88 +# image_list.append(int(list_tensor[1][49999-i].item()))
89 +
90 +# for i in image_list:
91 +# file = "C:/Users/82109/Desktop/1/1/img{0}.png".format(i)
92 +# if os.path.isfile(file):
93 +# os.remove(file)
94 +
95 +# transform_train = transforms.Compose([ transforms.ToTensor(), ])
96 +
97 +
98 +
99 +
100 +# normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
101 +# std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
102 +
103 +
104 +
105 +
106 +model_options = ['resnet18', 'wideresnet']
107 +dataset_options = ['cifar10', 'cifar100', 'svhn']
108 +
109 +parser = argparse.ArgumentParser(description='CNN')
110 +parser.add_argument('--dataset', '-d', default='cifar10',
111 + choices=dataset_options)
112 +parser.add_argument('--model', '-a', default='resnet18',
113 + choices=model_options)
114 +parser.add_argument('--batch_size', type=int, default=100,
115 +                    help='input batch size for training (default: 100)')
116 +parser.add_argument('--epochs', type=int, default=200,
117 +                    help='number of epochs to train (default: 200)')
118 +parser.add_argument('--learning_rate', type=float, default=0.1,
119 + help='learning rate')
120 +parser.add_argument('--data_augmentation', action='store_true', default=False,
121 + help='augment data by flipping and cropping')
122 +parser.add_argument('--cutout', action='store_true', default=False,
123 + help='apply cutout')
124 +parser.add_argument('--n_holes', type=int, default=0,
125 + help='number of holes to cut out from image')
126 +parser.add_argument('--length', type=int, default=0,
127 + help='length of the holes')
128 +parser.add_argument('--no-cuda', action='store_true', default=False,
129 +                    help='disable CUDA training')
130 +parser.add_argument('--seed', type=int, default=0,
131 +                    help='random seed (default: 0)')
132 +
133 +args = parser.parse_args()
134 +args.cuda = not args.no_cuda and torch.cuda.is_available()
135 +cudnn.benchmark = True  # Should make training go faster for large models
136 +
137 +torch.manual_seed(args.seed)
138 +if args.cuda:
139 + torch.cuda.manual_seed(args.seed)
140 +
141 +test_id = args.dataset + '_' + args.model
142 +
143 +print(args)
144 +
145 +# Image Preprocessing
146 +if args.dataset == 'svhn':
147 +    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
148 + std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
149 +else:
150 + normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
151 + std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
152 +
153 +train_transform = transforms.Compose([])
154 +if args.data_augmentation:
155 + train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
156 + train_transform.transforms.append(transforms.RandomHorizontalFlip())
157 +train_transform.transforms.append(transforms.ToTensor())
158 +train_transform.transforms.append(normalize)
159 +if args.cutout:
160 + train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
161 +
162 +test_transform = transforms.Compose([
163 + transforms.ToTensor(),
164 + normalize])
165 +
166 +
167 +if args.dataset == 'cifar10':
168 + num_classes = 10
169 + train_dataset = Img_Dataset(file_list = train_img_list,
170 + transform=train_transform)
171 + # custom_dataset = ImageFolder(root='C:/Users/82109/Desktop/1/', transform = transform_train)
172 +
173 + # train_dataset = datasets.CIFAR10(root='data/',
174 + # train=True,
175 + # transform=train_transform,
176 + # download=True)
177 +
178 + test_dataset = datasets.CIFAR10(root='data/',
179 + train=False,
180 + transform=test_transform,
181 + download=True)
182 +# elif args.dataset == 'cifar100':
183 +# num_classes = 100
184 +# train_dataset = datasets.CIFAR100(root='data/',
185 +# train=True,
186 +# transform=train_transform,
187 +# download=True)
188 +
189 +# test_dataset = datasets.CIFAR100(root='data/',
190 +# train=False,
191 +# transform=test_transform,
192 +# download=True)
193 +# elif args.dataset == 'svhn':
194 +# num_classes = 10
195 +# train_dataset = datasets.SVHN(root='data/',
196 +# split='train',
197 +# transform=train_transform,
198 +# download=True)
199 +
200 +# extra_dataset = datasets.SVHN(root='data/',
201 +# split='extra',
202 +# transform=train_transform,
203 +# download=True)
204 +
205 +# # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
206 +# data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
207 +# labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
208 +# train_dataset.data = data
209 +# train_dataset.labels = labels
210 +
211 +# test_dataset = datasets.SVHN(root='data/',
212 +# split='test',
213 +# transform=test_transform,
214 +# download=True)
215 +
216 +# # Data Loader (Input Pipeline)
217 +# train_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True,num_workers=0)
218 +train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
219 + batch_size=args.batch_size,
220 + shuffle=False,
221 + # sampler=RandomSampler(train_dataset, True, 40000),
222 + pin_memory=True,
223 + num_workers=0)
224 +
225 +test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
226 + batch_size=args.batch_size,
227 + shuffle=False,
228 + pin_memory=True,
229 + num_workers=0)
230 +
231 +if args.model == 'resnet18':
232 + cnn = ResNet18(num_classes=num_classes)
233 +elif args.model == 'wideresnet':
234 + if args.dataset == 'svhn':
235 + cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
236 + dropRate=0.4)
237 + else:
238 + cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
239 + dropRate=0.3)
240 +
241 +
242 +cnn = cnn.cuda()
243 +
244 +criterion = nn.CrossEntropyLoss().cuda()
245 +cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
246 + momentum=0.9, nesterov=True, weight_decay=5e-4)
247 +
248 +# scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
249 +if args.dataset == 'svhn':
250 + scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
251 +else:
252 + scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
253 +
254 +test_id = 'custom_dataset_resnet18'
255 +
256 +filename = 'logs/' + test_id + '.csv'
257 +csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc'], filename=filename)
258 +
259 +
260 +def test(loader):
261 + cnn.eval() # Change model to 'eval' mode (BN uses moving mean/var).
262 + correct = 0.
263 + total = 0.
264 + count = 0
265 + for images, labels in loader:
266 + images = images.cuda()
267 + labels = labels.cuda()
268 +
269 + with torch.no_grad():
270 + pred = cnn(images)
271 +
272 + pred = torch.max(pred.data, 1)[1]
273 +
274 + total += labels.size(0)
275 + # if (pred == labels).sum().item():
276 + # print('match')
277 + # count +=1
278 + correct += (pred == labels).sum().item()
279 + val_acc = correct / total
280 + cnn.train()
281 + return val_acc
282 +
283 +# kl_sum = 0
284 +# y_bar = torch.zeros(8, 10).detach().cuda()
285 +
286 +# epoch loop for computing y_bar
287 +for epoch in range(1):
288 + xentropy_loss_avg = 0.
289 + correct = 0.
290 + total = 0.
291 + norm_const = 0
292 +
293 + kldiv = 0
294 + # pred_sum = torch.Tensor([0] * 10).detach().cuda()
295 + count = 0
296 + progress_bar = tqdm(train_loader)
297 + for i, (images, labels) in enumerate(progress_bar):
298 + progress_bar.set_description('Epoch ' + str(epoch))
299 +
300 +        # Img_Dataset yields raw HWC uint8 images, so reorder the batch from (B, 32, 32, 3) to (B, 3, 32, 32)
301 +        images = images.permute(0, 3, 1, 2).contiguous().float().cuda()
302 +        labels = labels.to(device='cuda:0', dtype=torch.long)
303 + cnn.zero_grad()
304 + pred = cnn(images)
305 + xentropy_loss = criterion(pred, labels)
306 + xentropy_loss.backward()
307 + cnn_optimizer.step()
308 +
309 + xentropy_loss_avg += xentropy_loss.item()
310 +
311 +        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
312 + # Calculate running average of accuracy
313 + pred = torch.max(pred.data, 1)[1]
314 +
315 + total += labels.size(0)
316 + correct += (pred == labels.data).sum().item()
317 + accuracy = correct / total
318 + # for a in range(pred_softmax.data.size()[0]):
319 + # for b in range(y_bar.size()[1]):
320 + # y_bar[epoch][b] += torch.log(pred_softmax.data[a][b])
321 +
322 +
323 + progress_bar.set_postfix(
324 + xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
325 + acc='%.3f' % accuracy)
326 + # count += 1
327 + # xentropy = xentropy_loss_avg / count
328 + # y_bar[epoch] = torch.Tensor([x / 50000 for x in y_bar[epoch]]).cuda()
329 + # y_bar[epoch] = torch.exp(y_bar[epoch])
330 + # for index in range(y_bar.size()[1]):
331 + # norm_const += y_bar[epoch][index]
332 + # for index in range(y_bar.size()[1]):
333 + # y_bar[epoch][index] = y_bar[epoch][index] / norm_const
334 + # print("y_bar[{0}] : ".format(epoch), y_bar[epoch])
335 + test_acc = test(test_loader)
336 + # print(pred, labels.data)
337 + tqdm.write('test_acc: %.3f' % (test_acc))
338 +
339 +    scheduler.step()  # Use this line for PyTorch >= 1.4
340 +    # scheduler.step(epoch)  # Use this line for PyTorch < 1.4
341 +
342 + row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)}
343 + csv_logger.writerow(row)
344 + # del pred
345 + # torch.cuda.empty_cache()
346 +
347 +
348 +# var_tensor = torch.zeros(8, 50000).detach().cuda()
349 +# var_addeachcol = torch.zeros(1, 50000).detach().cuda()
350 +
351 +# # epoch loop for computing kl_div
352 +# for epoch in range(1):
353 +# checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/Cutout/checkpoints/sampling/sampling_{0}.pt'.format(8), map_location = torch.device('cuda:0'))
354 +# cnn.load_state_dict(checkpoint)
355 +# cnn.eval()
356 +# kldiv = 0
357 +# for i, (images, labels) in enumerate(progress_bar):
358 +# progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')
359 +
360 +# images = images.cuda()
361 +# labels = labels.cuda()
362 +
363 +# cnn.zero_grad()
364 +# pred = cnn(images)
365 +
366 +# pred_softmax = nn.functional.softmax(pred).cuda()
367 +
368 +# # if the shapes of the two inputs differ, the result is averaged over the batch size before being returned.
369 +# kldiv = torch.nn.functional.kl_div(y_bar[epoch], pred_softmax, reduction='sum')
370 +# # store each model's per-sample variance in a 1 x 50000 tensor
371 +# var_tensor[epoch][i] += abs(kldiv).detach()
372 +# var_addeachcol[0][i] += var_tensor[epoch][i]
373 +# kl_sum += kldiv.detach()
374 +# # print(y_bar_copy.size(), pred_softmax.size())
375 +# # print(kl_sum)
376 +# var = abs(kl_sum.item() / 50000)
377 +# print("Variance : ", var)
378 +# csv_logger.writerow({'var' : float(var)})
379 +# # print(var_tensor)
380 +# for i in range(var_addeachcol.size()[1]):
381 +# var_addeachcol[0][i] = var_addeachcol[0][i] / 8
382 +
383 +# print(var_addeachcol)
384 +# # var_addeachcol[0] = torch.Tensor([x / 8 for x in var_addeachcol]).cuda()
385 +# var_sorted = torch.argsort(var_addeachcol)
386 +# print(var_sorted)
387 +# for i in range(var_addeachcol.size()[1]):
388 +# csv_logger.writerow({'avg_var' : float(var_addeachcol[0][i]), 'arg_var' : float(var_sorted[0][i]), 'index' : float(i + 1)})
389 +torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
390 +csv_logger.close()