Showing 3 changed files with 0 additions and 1036 deletions
code/avg_var.py
deleted
100644 → 0
# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms

from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet

model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=1,
                    help='input batch size for training (default: 1)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=1,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)

# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))


test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = datasets.CIFAR10(root='data/',
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_dataset = datasets.CIFAR100(root='data/',
                                      train=True,
                                      transform=train_transform,
                                      download=True)

    test_dataset = datasets.CIFAR100(root='data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    train_dataset = datasets.SVHN(root='data/',
                                  split='train',
                                  transform=train_transform,
                                  download=True)

    extra_dataset = datasets.SVHN(root='data/',
                                  split='extra',
                                  transform=train_transform,
                                  download=True)

    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
    labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
    train_dataset.data = data
    train_dataset.labels = labels

    test_dataset = datasets.SVHN(root='data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           # sampler=RandomSampler(train_dataset, True, 40000),
                                           pin_memory=True,
                                           num_workers=0)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)

cnn = cnn.cuda()

criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc', 'xentropy', 'var', 'avg_var', 'arg_var', 'index', 'labels'], filename=filename)
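# Column meanings, as used further down in this script: 'xentropy' is the mean
# training cross-entropy, 'var' is abs(kl_sum / 50000) for a checkpoint,
# 'avg_var' is the per-image KL value divided by 8 (the number of sampled
# checkpoints this script is written around), 'arg_var' is the argsort of those
# values, and 'index'/'labels' identify the training image.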


def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)

        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = correct / total
    cnn.train()
    return val_acc

kl_sum = 0
y_bar = torch.zeros(8, 10).detach().cuda()
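# y_bar[k] accumulates sum_i log softmax(f_k(x_i)) over the training set for
# checkpoint k; it is later divided by the dataset size (50000), exponentiated
# and renormalized, i.e. a normalized geometric mean of the predicted class
# distribution. A minimal sketch of the same reduction, assuming probs is an
# (N, 10) tensor of softmax outputs for one checkpoint:
#     y_bar_k = torch.exp(torch.log(probs).mean(dim=0))
#     y_bar_k = y_bar_k / y_bar_k.sum()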

# Epoch loop for computing y_bar
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8), map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    norm_const = 0

    kldiv = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()
    count = 0
    label_list = []
    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()
        label_list.append(labels.item())
        save_image(images[0], os.path.join('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/augmented_images/{0}/'.format(8), 'img{0}.png'.format(i)))

        cnn.zero_grad()
        pred = cnn(images)
        xentropy_loss = criterion(pred, labels)
        # xentropy_loss.backward()
        # cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        for a in range(pred_softmax.data.size()[0]):
            for b in range(y_bar.size()[1]):
                y_bar[epoch][b] += torch.log(pred_softmax.data[a][b])

        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
        count += 1
    xentropy = xentropy_loss_avg / count
    y_bar[epoch] = torch.Tensor([x / 50000 for x in y_bar[epoch]]).cuda()
    y_bar[epoch] = torch.exp(y_bar[epoch])
    for index in range(y_bar.size()[1]):
        norm_const += y_bar[epoch][index]
    for index in range(y_bar.size()[1]):
        y_bar[epoch][index] = y_bar[epoch][index] / norm_const
    print("y_bar[{0}] : ".format(epoch), y_bar[epoch])
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % (test_acc))

    scheduler.step(epoch)  # Use this line for PyTorch <1.4
    # scheduler.step()     # Use this line for PyTorch >=1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc),
           'xentropy': float(xentropy)}
    csv_logger.writerow(row)
    # del pred
    # torch.cuda.empty_cache()


var_tensor = torch.zeros(8, 50000).detach().cuda()
var_addeachcol = torch.zeros(1, 50000).detach().cuda()

# Epoch loop for computing the KL divergence
for epoch in range(1):
    checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/checkpoints/sampling/sampling_{0}.pt'.format(8), map_location=torch.device('cuda:0'))
    cnn.load_state_dict(checkpoint)
    cnn.eval()
    kldiv = 0
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')

        images = images.cuda()
        labels = labels.cuda()

        cnn.zero_grad()
        pred = cnn(images)

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()

        # If the shapes of the two inputs differ, kl_div averages over the batch size before returning.
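        # Note: torch.nn.functional.kl_div expects its first argument to be
        # log-probabilities, while y_bar holds probabilities, so the value below
        # is a heuristic spread measure rather than a textbook KL divergence.
        # A sketch matching the documented convention would be:
        #     kldiv = torch.nn.functional.kl_div(y_bar[epoch].log(), pred_softmax, reduction='sum')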
        kldiv = torch.nn.functional.kl_div(y_bar[epoch], pred_softmax, reduction='sum')
        # Store one model's per-sample variance in the 1 x 50000 tensor
        var_tensor[epoch][i] += abs(kldiv).detach()
        var_addeachcol[0][i] += var_tensor[epoch][i]
        kl_sum += kldiv.detach()
        # print(y_bar_copy.size(), pred_softmax.size())
        # print(kl_sum)
    var = abs(kl_sum.item() / 50000)
    print("Variance : ", var)
    csv_logger.writerow({'var': float(var)})
    # print(var_tensor)
for i in range(var_addeachcol.size()[1]):
    var_addeachcol[0][i] = var_addeachcol[0][i] / 8

print(var_addeachcol)
# var_addeachcol[0] = torch.Tensor([x / 8 for x in var_addeachcol]).cuda()
var_sorted = torch.argsort(var_addeachcol)
print(var_sorted)
for i in range(var_addeachcol.size()[1]):
    csv_logger.writerow({'avg_var': float(var_addeachcol[0][i]), 'arg_var': float(var_sorted[0][i]), 'index': float(i + 1), 'labels': float(label_list[i])})
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()
code/cal_variance.py
deleted
100644 → 0
# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid
from torchvision import datasets, transforms

from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet

model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=128,
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=1,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)

# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))


test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])

if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = datasets.CIFAR10(root='data/',
                                     train=True,
                                     transform=train_transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_dataset = datasets.CIFAR100(root='data/',
                                      train=True,
                                      transform=train_transform,
                                      download=True)

    test_dataset = datasets.CIFAR100(root='data/',
                                     train=False,
                                     transform=test_transform,
                                     download=True)
elif args.dataset == 'svhn':
    num_classes = 10
    train_dataset = datasets.SVHN(root='data/',
                                  split='train',
                                  transform=train_transform,
                                  download=True)

    extra_dataset = datasets.SVHN(root='data/',
                                  split='extra',
                                  transform=train_transform,
                                  download=True)

    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
    labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
    train_dataset.data = data
    train_dataset.labels = labels

    test_dataset = datasets.SVHN(root='data/',
                                 split='test',
                                 transform=test_transform,
                                 download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           # sampler=RandomSampler(train_dataset, True, 40000),
                                           pin_memory=True,
                                           num_workers=0)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)
checkpoint = torch.load('/content/drive/MyDrive/capstone/Cutout/checkpoints/baseline_cifar10_resnet18.pt', map_location=torch.device('cuda:0'))
cnn = cnn.cuda()
cnn.load_state_dict(checkpoint)
criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc'], filename=filename)


def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)

        pred = torch.max(pred.data, 1)[1]

        total += labels.size(0)
        correct += (pred == labels).sum().item()

    val_acc = correct / total
    cnn.train()
    return val_acc

kl_sum = 0
y_bar = torch.Tensor([0] * 10).detach().cuda()
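# Unlike avg_var.py (which keeps an (8, 10) buffer, one row per sampled
# checkpoint), y_bar here is a single 10-dimensional accumulator for the one
# loaded model: it collects summed log-softmax values over the training set
# and is divided by 50000, exponentiated and renormalized after the loop.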

# Epoch loop for computing y_bar
for epoch in range(args.epochs):

    cnn.eval()
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    norm_const = 0

    kldiv = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()

    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        images = images.cuda()
        labels = labels.cuda()

        cnn.zero_grad()
        pred = cnn(images)
        xentropy_loss = criterion(pred, labels)
        # xentropy_loss.backward()
        # cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]
        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        for a in range(pred_softmax.data.size()[0]):
            for b in range(y_bar.size()[0]):
                y_bar[b] += torch.log(pred_softmax.data[a][b])

        # expectation(log y_hat)
        # y_bar = [x / pred.data.size()[0] for x in y_bar]

        # print(pred.data.size()[0], y_bar.size()[0]) # 128, 10

        # print(pred)
        # y_hat : per-model prediction --> pred_softmax
        # y_bar : mean of the predictions --> pred / total : pred_sum
        # labels.data : ground truth

        # y_bar = pred_sum / (i+1)
        # kl = torch.nn.functional.kl_div(pred, y_bar)
        # kl_sum += kl

        # Without the extra for loop, the xentropy printed at each iteration is xentropy_loss_avg / iter.
        # With the for loop, xentropy_loss_avg is identical at each iteration, but the printed xentropy is 1/10 of it.
        # Regardless of the for loop, pred and labels were confirmed to be the same.

        # for a in range(list(pred_sum.size())[0]):
        #     for b in range(list(pred.size())[0]):
        #         if pred[b] == a:
        #             pred_sum[a] += 1

        # variance calculation: E[KL_div(y_bar, y_hat)] -> expectation of KLDivLoss(pred_sum, pred)
        # should probably be computed and printed once per epoch
        # nn.functional.kl_div(pred_sum, pred)

        # print('\n', i, ' ', xentropy_loss_avg)
        progress_bar.set_postfix(
            # y_hat = '%.5f' % pred,
            # y_bar = '%.5f' % y_bar,
            # ground_truth = '%.5f' % labels.data,
            # kl = '%.3f' % kl.item(),
            # kl_sum = '%.3f' % (kl_sum.item()),
            # kl_div = '%.3f' % (kl_sum.item() / (i + 1)),  # kl_div call
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
    # pred_sum = [x / 40000 for x in pred_sum]
    y_bar = torch.Tensor([x / 50000 for x in y_bar]).cuda()
    y_bar = torch.exp(y_bar)
    # print(y_bar)
    for index in range(y_bar.size()[0]):
        norm_const += y_bar[index]
    print(y_bar)
    print(norm_const)
    # print(norm_const)
    for index in range(y_bar.size()[0]):
        y_bar[index] = y_bar[index] / norm_const
    print(y_bar)
    # print(y_bar)
    # print(pred_softmax)
    # print(y_bar)
    # kldiv = torch.nn.functional.kl_div(y_bar, pred_softmax, reduction='batchmean')
    # kl_sum += kldiv
    # print(kldiv, kl_sum)
    y_bar_copy = y_bar.clone().detach()
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % (test_acc))

    scheduler.step(epoch)  # Use this line for PyTorch <1.4
    # scheduler.step()     # Use this line for PyTorch >=1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)}
    csv_logger.writerow(row)
    del pred
    torch.cuda.empty_cache()

# Epoch loop for computing the KL divergence
for epoch in range(args.epochs):
    cnn.eval()
    kldiv = 0
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')

        images = images.cuda()
        labels = labels.cuda()

        cnn.zero_grad()
        pred = cnn(images)

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()

        # If the shapes of the two inputs differ, kl_div averages over the batch size before returning.
        kldiv = torch.nn.functional.kl_div(y_bar_copy, pred_softmax, reduction='sum')
        kl_sum += kldiv.detach()
        # print(y_bar_copy.size(), pred_softmax.size())
        # print(kl_sum)
    print("Average KL_div : ", abs(kl_sum / 50000))
    # y_bar = torch.Tensor([x / 40000 for x in y_bar]).cuda()
    # y_bar = torch.exp(y_bar)
    # # print(y_bar)
    # for index in range(y_bar.size()[0]):
    #     norm_const += y_bar[index]
    # # print(norm_const)
    # for index in range(y_bar.size()[0]):
    #     y_bar[index] = y_bar[index] / norm_const
    # # print(y_bar)
    # # print(pred_softmax)
    # # print(y_bar)
    # kldiv = torch.nn.functional.kl_div(y_bar, pred_softmax, reduction='batchmean')
    # kl_sum += kldiv
    # print(kldiv, kl_sum)

torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()
code/test_n%.py
deleted
100644 → 0
# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20

import pdb
import argparse
import numpy as np
from tqdm import tqdm
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms
from torchvision.transforms.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout

from model.resnet import ResNet18
from model.wide_resnet import WideResNet

from PIL import Image
from matplotlib.pyplot import imshow
import time

def csv2list(filename):
    lists = []
    file = open(filename, 'r', encoding='utf-8-sig')
    while True:
        line = file.readline().strip("\n")
        # int_list = [int(i) for i in line]
        if line:
            line = line.split(",")
            lists.append(line)
        else:
            break
    return lists
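
# Assumed CSV layout (inferred from the indexing below): each row of the
# "deleted" log file holds (variance value, image index, class label), so after
# transposing, transposelist[1] gives the image indices and transposelist[2]
# (via list_tensor[2]) gives the labels used as training targets.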

# Read from the logs file sorted by variance
filelist = csv2list("C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/cutout/logs/image_save/1_5000_deleted.csv")
for i in range(len(filelist)):
    for j in range(len(filelist[0])):
        filelist[i][j] = float(filelist[i][j])
transposelist = np.transpose(filelist)

# print(list)
list_tensor = torch.tensor(transposelist, dtype=torch.long)
target = list(list_tensor[2])
train_img_list = list()

for img_idx in transposelist[1]:
    img_path = "C:/Users/82109/Desktop/model1/img" + str(int(img_idx)) + ".png"
    train_img_list.append(img_path)


class Img_Dataset(Dataset):

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.file_list[index]
        images = np.array(Image.open(img_path))
        # img_transformed = self.transform(images)

        labels = target[index]

        return images, labels
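
# Note: Img_Dataset stores `transform` but never applies it, so each item is a
# raw HWC uint8 numpy array (32, 32, 3) read straight from the saved PNG; the
# training loop below therefore converts and reorders the batch to NCHW itself
# instead of relying on the usual ToTensor()/Normalize() pipeline.
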
# print(list_tensor)
# Set the number of top-k entries
# k = 5000
# values, indices = torch.topk(list_tensor[0], k)
# # print(values)
# image_list = []
# for i in range(k):
#     image_list.append(int(list_tensor[1][49999-i].item()))

# for i in image_list:
#     file = "C:/Users/82109/Desktop/1/1/img{0}.png".format(i)
#     if os.path.isfile(file):
#         os.remove(file)

# transform_train = transforms.Compose([ transforms.ToTensor(), ])

# normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
#                                  std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']

parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
                    choices=model_options)
parser.add_argument('--batch_size', type=int, default=100,
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=200,
                    help='number of epochs to train (default: 200)')
parser.add_argument('--learning_rate', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
                    help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='apply cutout')
parser.add_argument('--n_holes', type=int, default=0,
                    help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=0,
                    help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed (default: 0)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True  # Should make training go faster for large models

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

test_id = args.dataset + '_' + args.model

print(args)

# Image Preprocessing
if args.dataset == 'svhn':
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
                                     std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.data_augmentation:
    train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
    train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
    train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize])


if args.dataset == 'cifar10':
    num_classes = 10
    train_dataset = Img_Dataset(file_list=train_img_list,
                                transform=train_transform)
    # custom_dataset = ImageFolder(root='C:/Users/82109/Desktop/1/', transform = transform_train)

    # train_dataset = datasets.CIFAR10(root='data/',
    #                                  train=True,
    #                                  transform=train_transform,
    #                                  download=True)

    test_dataset = datasets.CIFAR10(root='data/',
                                    train=False,
                                    transform=test_transform,
                                    download=True)
# elif args.dataset == 'cifar100':
#     num_classes = 100
#     train_dataset = datasets.CIFAR100(root='data/',
#                                       train=True,
#                                       transform=train_transform,
#                                       download=True)

#     test_dataset = datasets.CIFAR100(root='data/',
#                                      train=False,
#                                      transform=test_transform,
#                                      download=True)
# elif args.dataset == 'svhn':
#     num_classes = 10
#     train_dataset = datasets.SVHN(root='data/',
#                                   split='train',
#                                   transform=train_transform,
#                                   download=True)

#     extra_dataset = datasets.SVHN(root='data/',
#                                   split='extra',
#                                   transform=train_transform,
#                                   download=True)

#     # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
#     data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
#     labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
#     train_dataset.data = data
#     train_dataset.labels = labels

#     test_dataset = datasets.SVHN(root='data/',
#                                  split='test',
#                                  transform=test_transform,
#                                  download=True)

# # Data Loader (Input Pipeline)
# train_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           # sampler=RandomSampler(train_dataset, True, 40000),
                                           pin_memory=True,
                                           num_workers=0)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          pin_memory=True,
                                          num_workers=0)

if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    if args.dataset == 'svhn':
        cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
                         dropRate=0.4)
    else:
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
                         dropRate=0.3)


cnn = cnn.cuda()

criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=5e-4)

# scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)

test_id = 'custom_dataset_resnet18'

filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc'], filename=filename)


def test(loader):
    cnn.eval()    # Change model to 'eval' mode (BN uses moving mean/var).
    correct = 0.
    total = 0.
    count = 0
    for images, labels in loader:
        images = images.cuda()
        labels = labels.cuda()

        with torch.no_grad():
            pred = cnn(images)

        pred = torch.max(pred.data, 1)[1]

        total += labels.size(0)
        # if (pred == labels).sum().item():
        #     print('match')
        #     count += 1
        correct += (pred == labels).sum().item()
    val_acc = correct / total
    cnn.train()
    return val_acc

# kl_sum = 0
# y_bar = torch.zeros(8, 10).detach().cuda()

# Training epoch loop (the y_bar computation is commented out in this script)
for epoch in range(1):
    xentropy_loss_avg = 0.
    correct = 0.
    total = 0.
    norm_const = 0

    kldiv = 0
    # pred_sum = torch.Tensor([0] * 10).detach().cuda()
    count = 0
    progress_bar = tqdm(train_loader)
    for i, (images, labels) in enumerate(progress_bar):
        progress_bar.set_description('Epoch ' + str(epoch))

        # The custom loader yields HWC uint8 batches of shape [B, 32, 32, 3],
        # so permute (rather than view) is needed to get NCHW input.
        images = Variable(images.permute(0, 3, 1, 2).contiguous().float().cuda())
        labels = Variable(labels.float().cuda())
        labels = torch.tensor(labels, dtype=torch.long, device=torch.device('cuda:0'))
        cnn.zero_grad()
        pred = cnn(images)
        xentropy_loss = criterion(pred, labels)
        xentropy_loss.backward()
        cnn_optimizer.step()

        xentropy_loss_avg += xentropy_loss.item()

        pred_softmax = nn.functional.softmax(pred, dim=1).cuda()
        # Calculate running average of accuracy
        pred = torch.max(pred.data, 1)[1]

        total += labels.size(0)
        correct += (pred == labels.data).sum().item()
        accuracy = correct / total
        # for a in range(pred_softmax.data.size()[0]):
        #     for b in range(y_bar.size()[1]):
        #         y_bar[epoch][b] += torch.log(pred_softmax.data[a][b])

        progress_bar.set_postfix(
            xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
            acc='%.3f' % accuracy)
        # count += 1
    # xentropy = xentropy_loss_avg / count
    # y_bar[epoch] = torch.Tensor([x / 50000 for x in y_bar[epoch]]).cuda()
    # y_bar[epoch] = torch.exp(y_bar[epoch])
    # for index in range(y_bar.size()[1]):
    #     norm_const += y_bar[epoch][index]
    # for index in range(y_bar.size()[1]):
    #     y_bar[epoch][index] = y_bar[epoch][index] / norm_const
    # print("y_bar[{0}] : ".format(epoch), y_bar[epoch])
    test_acc = test(test_loader)
    # print(pred, labels.data)
    tqdm.write('test_acc: %.3f' % (test_acc))

    scheduler.step()         # Use this line for PyTorch >=1.4
    # scheduler.step(epoch)  # Use this line for PyTorch <1.4

    row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)}
    csv_logger.writerow(row)
    # del pred
    # torch.cuda.empty_cache()


# var_tensor = torch.zeros(8, 50000).detach().cuda()
# var_addeachcol = torch.zeros(1, 50000).detach().cuda()

# # Epoch loop for computing the KL divergence
# for epoch in range(1):
#     checkpoint = torch.load('C:/Users/82109/Desktop/캡디/캡디자료들/논문모델/Cutout/checkpoints/sampling/sampling_{0}.pt'.format(8), map_location=torch.device('cuda:0'))
#     cnn.load_state_dict(checkpoint)
#     cnn.eval()
#     kldiv = 0
#     for i, (images, labels) in enumerate(progress_bar):
#         progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')

#         images = images.cuda()
#         labels = labels.cuda()

#         cnn.zero_grad()
#         pred = cnn(images)

#         pred_softmax = nn.functional.softmax(pred).cuda()

#         # If the shapes of the two inputs differ, kl_div averages over the batch size before returning.
#         kldiv = torch.nn.functional.kl_div(y_bar[epoch], pred_softmax, reduction='sum')
#         # Store one model's per-sample variance in the 1 x 50000 tensor
#         var_tensor[epoch][i] += abs(kldiv).detach()
#         var_addeachcol[0][i] += var_tensor[epoch][i]
#         kl_sum += kldiv.detach()
#         # print(y_bar_copy.size(), pred_softmax.size())
#         # print(kl_sum)
#     var = abs(kl_sum.item() / 50000)
#     print("Variance : ", var)
#     csv_logger.writerow({'var' : float(var)})
#     # print(var_tensor)
# for i in range(var_addeachcol.size()[1]):
#     var_addeachcol[0][i] = var_addeachcol[0][i] / 8

# print(var_addeachcol)
# # var_addeachcol[0] = torch.Tensor([x / 8 for x in var_addeachcol]).cuda()
# var_sorted = torch.argsort(var_addeachcol)
# print(var_sorted)
# for i in range(var_addeachcol.size()[1]):
#     csv_logger.writerow({'avg_var' : float(var_addeachcol[0][i]), 'arg_var' : float(var_sorted[0][i]), 'index' : float(i + 1)})
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()