김지훈

quantization

1 +# 추론시간 개선 - 양자화 시도
2 +
3 +## Pytorch quantization - 학습해도 cpu 에서만 실행 가능, 모델의 채널을 신중하게 고르지 않으면 속도 개선 미미함. 또한 양자화 과정으로 학습된 모델은 pytorch model -> onnx -> tensorRT 변환이 불가능하여 gpu 에서 실행 불가능 학습해도 cpu 에서만 실행 가능, 모델의 채널을 신중하게 고르지 않으면 속도 개선 미미함
4 +
5 +TensorRT - 양자화 학습을 사용하지 않고 바로 정밀도 감소 및 양자화 시도. float16 은 10% 정도 속도가 개선되었으나, int8 은 실패함 (사용법 미숙, 입력 값이 0.0 ~ 1.0 등)
...\ No newline at end of file ...\ No newline at end of file
1 +import utils
2 +import copy
3 +from collections import OrderedDict
4 +
5 +import model
6 +import dataset
7 +
8 +import importlib
9 +importlib.reload(utils)
10 +importlib.reload(model)
11 +importlib.reload(dataset)
12 +
13 +from utils import *
14 +import torch.quantization
15 +
16 +
17 +def add_args(parser):
18 + # parser.add_argument('--model', type=str, default='moderate-cnn',
19 + # help='neural network used in training')
20 + parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
21 + help='dataset used for training')
22 + parser.add_argument('--fold_num', type=int, default=0,
23 + help='5-fold, 0 ~ 4')
24 + parser.add_argument('--batch_size', type=int, default=256, metavar='N',
25 + help='input batch size for training')
26 + parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
27 + help='learning rate')
28 + parser.add_argument('--n_nets', type=int, default=100, metavar='NN',
29 + help='number of workers in a distributed cluster')
30 + parser.add_argument('--comm_type', type=str, default='fedtwa',
31 + help='which type of communication strategy is going to be used: layerwise/blockwise')
32 + parser.add_argument('--comm_round', type=int, default=10,
33 + help='how many round of communications we shoud use')
34 + args = parser.parse_args(args=[])
35 + return args
36 +
37 +
38 +def start_fedavg(fed_model, args,
39 + train_data_set,
40 + data_idx_map,
41 + net_data_count,
42 + testloader,
43 + device):
44 + print("start fed avg")
45 + criterion = nn.CrossEntropyLoss()
46 + C = 0.1
47 + num_edge = int(max(C * args.n_nets, 1))
48 + total_data_count = 0
49 + for _, data_count in net_data_count.items():
50 + total_data_count += data_count
51 + print("total data: %d" % total_data_count)
52 +
53 + # quantize
54 + # fed_model.eval()
55 + # torch.jit.save(torch.jit.script(fed_model), './float.pth')
56 + # return
57 + fed_model.fuse_model()
58 + # modules_to_fuse = [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']]
59 + # torch.quantization.fuse_modules(fed_model, modules_to_fuse, inplace=True)
60 +
61 + fed_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
62 + torch.quantization.prepare_qat(fed_model, inplace=True)
63 +
64 + # for making shape of weight_fake_quant.scale
65 + train_data_set.set_idx_map([0])
66 + fed_model(torch.from_numpy(np.expand_dims(train_data_set[0][0], axis=0)).float())
67 +
68 + edges, _, _ = init_models(args.n_nets, args)
69 + # edges = [copy.deepcopy(fed_model) for net_cnt in range(args.n_nets)]
70 + for edge_now in edges:
71 + edge_now.fuse_model()
72 + edge_now.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
73 + torch.quantization.prepare_qat(edge_now, inplace=True)
74 + edge_now(torch.from_numpy(np.expand_dims(train_data_set[0][0], axis=0)).float())
75 +
76 + # print('quantized \n', edges[edge_index].conv1)
77 + # end
78 +
79 + for cr in range(1, args.comm_round + 1):
80 + print("Communication round : %d" % (cr))
81 +
82 + np.random.seed(cr) # make sure for each comparison, select the same clients each round
83 + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False)
84 + print("selected edge", selected_edge)
85 +
86 + for edge_progress, edge_index in enumerate(selected_edge):
87 + train_data_set.set_idx_map(data_idx_map[edge_index])
88 + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size,
89 + shuffle=True, num_workers=2)
90 + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set)))
91 +
92 + edges[edge_index] = copy.deepcopy(fed_model)
93 + edges[edge_index].to(device)
94 + edges[edge_index].train()
95 + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr)
96 + # train
97 + for data_idx, (inputs, labels) in enumerate(train_loader):
98 + inputs, labels = inputs.float().to(device), labels.long().to(device)
99 +
100 + edge_opt.zero_grad()
101 + # edge_opt[edge_index].zero_grad()
102 + edge_pred = edges[edge_index](inputs)
103 +
104 + edge_loss = criterion(edge_pred, labels)
105 + edge_loss.backward()
106 +
107 + edge_opt.step()
108 + # edge_opt[edge_index].step()
109 + edge_loss = edge_loss.item()
110 + if data_idx % 100 == 0:
111 + print('[%4d] loss: %.3f' % (data_idx, edge_loss))
112 + break
113 + edges[edge_index].to('cpu')
114 + # print(edge_index)
115 + # local_state = edges[edge_index].state_dict()
116 + # for key in edges[edge_index].state_dict().keys():
117 + # if 'activation_post_process' in key or 'fake_quant' in key:
118 + # print(key, local_state[key])
119 + # print()
120 + # return
121 + # cal weight using fed avg
122 + update_state = OrderedDict()
123 + for k, edge in enumerate(edges):
124 + local_state = edge.state_dict()
125 + for key in fed_model.state_dict().keys():
126 + # if 'zero_point' in key:
127 + # print(local_state[key])
128 + if 'activation_post_process' in key or 'fake_quant' in key:
129 + if k == 0:
130 + update_state[key] = local_state[key]
131 + else:
132 + update_state[key] += local_state[key]
133 + elif 'enable' in key:
134 + update_state[key] = local_state[key]
135 + else:
136 + if k == 0:
137 + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count)
138 + else:
139 + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count)
140 + # break
141 + for key in update_state.keys():
142 + if 'enable' in key:
143 + continue
144 + if 'activation_post_process' in key or 'fake_quant' in key:
145 + # print(key, update_state[key], update_state[key].type())
146 + # print(key, update_state[key])
147 + if torch.is_floating_point(update_state[key]):
148 + update_state[key] = update_state[key] / args.n_nets
149 + else:
150 + update_state[key] = torch.floor_divide(update_state[key], args.n_nets)
151 + # print(update_state[key])
152 +
153 + fed_model.load_state_dict(update_state)
154 + if cr % 1 == 0:
155 + fed_model.to(device)
156 + fed_model.eval()
157 +
158 + total_loss = 0.0
159 + cnt = 0
160 + step_acc = 0.0
161 + with torch.no_grad():
162 + for i, data in enumerate(testloader):
163 + inputs, labels = data
164 + inputs, labels = inputs.float().to(device), labels.long().to(device)
165 +
166 + outputs = fed_model(inputs)
167 + _, preds = torch.max(outputs, 1)
168 +
169 + loss = criterion(outputs, labels)
170 + cnt += inputs.shape[0]
171 +
172 + corr_sum = torch.sum(preds == labels.data)
173 + step_acc += corr_sum.double()
174 + running_loss = loss.item() * inputs.shape[0]
175 + total_loss += running_loss
176 + if i % 200 == 0:
177 + print('test [%4d] loss: %.3f' % (i, loss.item()))
178 + break
179 + print((step_acc / cnt).item())
180 + print(total_loss / cnt)
181 + fed_model.to('cpu')
182 + quantized_fed_model = torch.quantization.convert(fed_model.eval(), inplace=False)
183 + torch.jit.save(torch.jit.script(quantized_fed_model), './quan.pth')
184 +
185 +
186 +
187 +def start_train():
188 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
189 + print(device)
190 + args = add_args(argparse.ArgumentParser())
191 +
192 + seed = 0
193 + np.random.seed(seed)
194 + torch.manual_seed(seed)
195 +
196 + print("Loading data...")
197 + # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt',
198 + # "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt',
199 + # "./dataset/RPM_dataset.csv" : './RPM_dataset.txt',
200 + # "./dataset/gear_dataset.csv" : './gear_dataset.txt'
201 + # }
202 + kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'}
203 + train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs)
204 + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size,
205 + shuffle=False, num_workers=2)
206 +
207 + run_benchmark('./quan.pth', testloader)
208 + run_benchmark('./float.pth', testloader)
209 + # run_benchmark('./quan.pth', testloader)
210 +
211 + fed_model = model.Net()
212 + args.comm_type = 'fedavg'
213 + if args.comm_type == "fedavg":
214 + start_fedavg(fed_model, args,
215 + train_data_set,
216 + data_idx_map,
217 + net_data_count,
218 + testloader,
219 + device)
220 +
221 +if __name__ == "__main__":
222 + start_train()
1 +import torch.nn as nn
2 +import torch.nn.functional as F
3 +import torch
4 +from torch.quantization import QuantStub, DeQuantStub
5 +
6 +class Net(nn.Module):
7 + def __init__(self):
8 + super(Net, self).__init__()
9 +
10 + self.quant = QuantStub()
11 + self.dequant = DeQuantStub()
12 +
13 + self.conv1 = nn.Sequential(
14 + nn.Conv2d(1, 8, 3),
15 + nn.ReLU(True),
16 + )
17 + self.conv2 = nn.Sequential(
18 + nn.Conv2d(8, 8, 3),
19 + nn.ReLU(True),
20 + )
21 + self.conv3 = nn.Sequential(
22 + nn.Conv2d(8, 8, 3),
23 + nn.ReLU(True),
24 + )
25 + self.fc4 = nn.Linear(8 * 23 * 23, 2)
26 +
27 + def forward(self, x):
28 + x = self.quant(x)
29 + x = self.conv1(x)
30 + x = self.conv2(x)
31 + x = self.conv3(x)
32 + x = torch.flatten(x, 1)
33 + x = self.fc4(x)
34 + x = self.dequant(x)
35 + return x
36 +
37 + def fuse_model(self):
38 + for m in self.modules():
39 + if type(m) == nn.Sequential:
40 + torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
1 +import os
2 +import argparse
3 +import json
4 +import numpy as np
5 +import torch
6 +import torch.optim as optim
7 +import torch.nn as nn
8 +import torchvision
9 +import torchvision.transforms as transforms
10 +import torch.utils.data as data
11 +import math
12 +import copy
13 +import time
14 +
15 +import model
16 +import torch.quantization
17 +from torch.quantization import QuantStub, DeQuantStub
18 +
19 +
20 +def run_benchmark(model_file, img_loader):
21 + elapsed = 0
22 + # myModel = torch.jit.load(model_file)
23 + # torch.backends.quantized.engine='fbgemm'
24 + # myModel.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
25 + # myModel.eval()
26 + myModel = model.Net()
27 + # myModel = torch.quantization.quantize_dynamic(myModel, {torch.nn.Linear, torch.nn.Sequential}, dtype=torch.qint8)
28 + # print(myModel)
29 + # set quantization config for server (x86)
30 + myModel.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
31 + num_batches = 10
32 + # # insert observers
33 + torch.quantization.prepare(myModel, inplace=True)
34 + # # Calibrate the model and collect statistics
35 + with torch.no_grad():
36 + for i, (images, target) in enumerate(img_loader):
37 + images = images.float()
38 + target = target.long()
39 + if i < num_batches:
40 + start = time.time()
41 + output = myModel(images)
42 + end = time.time()
43 + # elapsed = elapsed + (end-start)
44 + else:
45 + break
46 +
47 + # # convert to quantized version
48 + torch.quantization.convert(myModel, inplace=True)
49 +
50 + # quant = QuantStub()
51 + with torch.no_grad():
52 + for i, (images, target) in enumerate(img_loader):
53 + images = images.float()
54 + target = target.long()
55 + if i < num_batches:
56 + start = time.time()
57 + output = myModel(images)
58 + end = time.time()
59 + elapsed = elapsed + (end-start)
60 + else:
61 + break
62 + num_images = images.size()[0] * num_batches
63 + print(elapsed)
64 + print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000))
65 + return elapsed
66 +
67 +
68 +def init_models(n_nets, args):
69 + models = []
70 + layer_shape = []
71 + layer_type = []
72 +
73 + for idx in range(n_nets):
74 + # if args.model == "lenet":
75 + # cnn = LeNet()
76 + # elif args.model == "vgg":
77 + # cnn = vgg11()
78 + models.append(model.Net())
79 +
80 + for (k, v) in models[0].state_dict().items():
81 + layer_shape.append(v.shape)
82 + layer_type.append(k)
83 +
84 + return models, layer_shape, layer_type
85 +
...\ No newline at end of file ...\ No newline at end of file