Showing
4 changed files
with
352 additions
and
0 deletions
코드/연합학습/quantization/README.md
0 → 100644
1 | +# 추론시간 개선 - 양자화 시도 | ||
2 | + | ||
3 | +## Pytorch quantization - 학습해도 cpu 에서만 실행 가능, 모델의 채널을 신중하게 고르지 않으면 속도 개선 미미함. 또한 양자화 과정으로 학습된 모델은 pytorch model -> onnx -> tensorRT 변환이 불가능하여 gpu 에서 실행 불가능 학습해도 cpu 에서만 실행 가능, 모델의 채널을 신중하게 고르지 않으면 속도 개선 미미함 | ||
4 | + | ||
5 | +TensorRT - 양자화 학습을 사용하지 않고 바로 정밀도 감소 및 양자화 시도. float16 은 10% 정도 속도가 개선되었으나, int8 은 실패함 (사용법 미숙, 입력 값이 0.0 ~ 1.0 등) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/연합학습/quantization/fed_train.py
0 → 100644
1 | +import utils | ||
2 | +import copy | ||
3 | +from collections import OrderedDict | ||
4 | + | ||
5 | +import model | ||
6 | +import dataset | ||
7 | + | ||
8 | +import importlib | ||
9 | +importlib.reload(utils) | ||
10 | +importlib.reload(model) | ||
11 | +importlib.reload(dataset) | ||
12 | + | ||
13 | +from utils import * | ||
14 | +import torch.quantization | ||
15 | + | ||
16 | + | ||
17 | +def add_args(parser): | ||
18 | + # parser.add_argument('--model', type=str, default='moderate-cnn', | ||
19 | + # help='neural network used in training') | ||
20 | + parser.add_argument('--dataset', type=str, default='cifar10', metavar='N', | ||
21 | + help='dataset used for training') | ||
22 | + parser.add_argument('--fold_num', type=int, default=0, | ||
23 | + help='5-fold, 0 ~ 4') | ||
24 | + parser.add_argument('--batch_size', type=int, default=256, metavar='N', | ||
25 | + help='input batch size for training') | ||
26 | + parser.add_argument('--lr', type=float, default=0.002, metavar='LR', | ||
27 | + help='learning rate') | ||
28 | + parser.add_argument('--n_nets', type=int, default=100, metavar='NN', | ||
29 | + help='number of workers in a distributed cluster') | ||
30 | + parser.add_argument('--comm_type', type=str, default='fedtwa', | ||
31 | + help='which type of communication strategy is going to be used: layerwise/blockwise') | ||
32 | + parser.add_argument('--comm_round', type=int, default=10, | ||
33 | + help='how many round of communications we shoud use') | ||
34 | + args = parser.parse_args(args=[]) | ||
35 | + return args | ||
36 | + | ||
37 | + | ||
38 | +def start_fedavg(fed_model, args, | ||
39 | + train_data_set, | ||
40 | + data_idx_map, | ||
41 | + net_data_count, | ||
42 | + testloader, | ||
43 | + device): | ||
44 | + print("start fed avg") | ||
45 | + criterion = nn.CrossEntropyLoss() | ||
46 | + C = 0.1 | ||
47 | + num_edge = int(max(C * args.n_nets, 1)) | ||
48 | + total_data_count = 0 | ||
49 | + for _, data_count in net_data_count.items(): | ||
50 | + total_data_count += data_count | ||
51 | + print("total data: %d" % total_data_count) | ||
52 | + | ||
53 | + # quantize | ||
54 | + # fed_model.eval() | ||
55 | + # torch.jit.save(torch.jit.script(fed_model), './float.pth') | ||
56 | + # return | ||
57 | + fed_model.fuse_model() | ||
58 | + # modules_to_fuse = [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']] | ||
59 | + # torch.quantization.fuse_modules(fed_model, modules_to_fuse, inplace=True) | ||
60 | + | ||
61 | + fed_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') | ||
62 | + torch.quantization.prepare_qat(fed_model, inplace=True) | ||
63 | + | ||
64 | + # for making shape of weight_fake_quant.scale | ||
65 | + train_data_set.set_idx_map([0]) | ||
66 | + fed_model(torch.from_numpy(np.expand_dims(train_data_set[0][0], axis=0)).float()) | ||
67 | + | ||
68 | + edges, _, _ = init_models(args.n_nets, args) | ||
69 | + # edges = [copy.deepcopy(fed_model) for net_cnt in range(args.n_nets)] | ||
70 | + for edge_now in edges: | ||
71 | + edge_now.fuse_model() | ||
72 | + edge_now.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') | ||
73 | + torch.quantization.prepare_qat(edge_now, inplace=True) | ||
74 | + edge_now(torch.from_numpy(np.expand_dims(train_data_set[0][0], axis=0)).float()) | ||
75 | + | ||
76 | + # print('quantized \n', edges[edge_index].conv1) | ||
77 | + # end | ||
78 | + | ||
79 | + for cr in range(1, args.comm_round + 1): | ||
80 | + print("Communication round : %d" % (cr)) | ||
81 | + | ||
82 | + np.random.seed(cr) # make sure for each comparison, select the same clients each round | ||
83 | + selected_edge = np.random.choice(args.n_nets, num_edge, replace=False) | ||
84 | + print("selected edge", selected_edge) | ||
85 | + | ||
86 | + for edge_progress, edge_index in enumerate(selected_edge): | ||
87 | + train_data_set.set_idx_map(data_idx_map[edge_index]) | ||
88 | + train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=args.batch_size, | ||
89 | + shuffle=True, num_workers=2) | ||
90 | + print("[%2d/%2d] edge: %d, data len: %d" % (edge_progress, len(selected_edge), edge_index, len(train_data_set))) | ||
91 | + | ||
92 | + edges[edge_index] = copy.deepcopy(fed_model) | ||
93 | + edges[edge_index].to(device) | ||
94 | + edges[edge_index].train() | ||
95 | + edge_opt = optim.Adam(params=edges[edge_index].parameters(), lr=args.lr) | ||
96 | + # train | ||
97 | + for data_idx, (inputs, labels) in enumerate(train_loader): | ||
98 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
99 | + | ||
100 | + edge_opt.zero_grad() | ||
101 | + # edge_opt[edge_index].zero_grad() | ||
102 | + edge_pred = edges[edge_index](inputs) | ||
103 | + | ||
104 | + edge_loss = criterion(edge_pred, labels) | ||
105 | + edge_loss.backward() | ||
106 | + | ||
107 | + edge_opt.step() | ||
108 | + # edge_opt[edge_index].step() | ||
109 | + edge_loss = edge_loss.item() | ||
110 | + if data_idx % 100 == 0: | ||
111 | + print('[%4d] loss: %.3f' % (data_idx, edge_loss)) | ||
112 | + break | ||
113 | + edges[edge_index].to('cpu') | ||
114 | + # print(edge_index) | ||
115 | + # local_state = edges[edge_index].state_dict() | ||
116 | + # for key in edges[edge_index].state_dict().keys(): | ||
117 | + # if 'activation_post_process' in key or 'fake_quant' in key: | ||
118 | + # print(key, local_state[key]) | ||
119 | + # print() | ||
120 | + # return | ||
121 | + # cal weight using fed avg | ||
122 | + update_state = OrderedDict() | ||
123 | + for k, edge in enumerate(edges): | ||
124 | + local_state = edge.state_dict() | ||
125 | + for key in fed_model.state_dict().keys(): | ||
126 | + # if 'zero_point' in key: | ||
127 | + # print(local_state[key]) | ||
128 | + if 'activation_post_process' in key or 'fake_quant' in key: | ||
129 | + if k == 0: | ||
130 | + update_state[key] = local_state[key] | ||
131 | + else: | ||
132 | + update_state[key] += local_state[key] | ||
133 | + elif 'enable' in key: | ||
134 | + update_state[key] = local_state[key] | ||
135 | + else: | ||
136 | + if k == 0: | ||
137 | + update_state[key] = local_state[key] * (net_data_count[k] / total_data_count) | ||
138 | + else: | ||
139 | + update_state[key] += local_state[key] * (net_data_count[k] / total_data_count) | ||
140 | + # break | ||
141 | + for key in update_state.keys(): | ||
142 | + if 'enable' in key: | ||
143 | + continue | ||
144 | + if 'activation_post_process' in key or 'fake_quant' in key: | ||
145 | + # print(key, update_state[key], update_state[key].type()) | ||
146 | + # print(key, update_state[key]) | ||
147 | + if torch.is_floating_point(update_state[key]): | ||
148 | + update_state[key] = update_state[key] / args.n_nets | ||
149 | + else: | ||
150 | + update_state[key] = torch.floor_divide(update_state[key], args.n_nets) | ||
151 | + # print(update_state[key]) | ||
152 | + | ||
153 | + fed_model.load_state_dict(update_state) | ||
154 | + if cr % 1 == 0: | ||
155 | + fed_model.to(device) | ||
156 | + fed_model.eval() | ||
157 | + | ||
158 | + total_loss = 0.0 | ||
159 | + cnt = 0 | ||
160 | + step_acc = 0.0 | ||
161 | + with torch.no_grad(): | ||
162 | + for i, data in enumerate(testloader): | ||
163 | + inputs, labels = data | ||
164 | + inputs, labels = inputs.float().to(device), labels.long().to(device) | ||
165 | + | ||
166 | + outputs = fed_model(inputs) | ||
167 | + _, preds = torch.max(outputs, 1) | ||
168 | + | ||
169 | + loss = criterion(outputs, labels) | ||
170 | + cnt += inputs.shape[0] | ||
171 | + | ||
172 | + corr_sum = torch.sum(preds == labels.data) | ||
173 | + step_acc += corr_sum.double() | ||
174 | + running_loss = loss.item() * inputs.shape[0] | ||
175 | + total_loss += running_loss | ||
176 | + if i % 200 == 0: | ||
177 | + print('test [%4d] loss: %.3f' % (i, loss.item())) | ||
178 | + break | ||
179 | + print((step_acc / cnt).item()) | ||
180 | + print(total_loss / cnt) | ||
181 | + fed_model.to('cpu') | ||
182 | + quantized_fed_model = torch.quantization.convert(fed_model.eval(), inplace=False) | ||
183 | + torch.jit.save(torch.jit.script(quantized_fed_model), './quan.pth') | ||
184 | + | ||
185 | + | ||
186 | + | ||
187 | +def start_train(): | ||
188 | + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
189 | + print(device) | ||
190 | + args = add_args(argparse.ArgumentParser()) | ||
191 | + | ||
192 | + seed = 0 | ||
193 | + np.random.seed(seed) | ||
194 | + torch.manual_seed(seed) | ||
195 | + | ||
196 | + print("Loading data...") | ||
197 | + # kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt', | ||
198 | + # "./dataset/Fuzzy_dataset.csv" : './Fuzzy_dataset.txt', | ||
199 | + # "./dataset/RPM_dataset.csv" : './RPM_dataset.txt', | ||
200 | + # "./dataset/gear_dataset.csv" : './gear_dataset.txt' | ||
201 | + # } | ||
202 | + kwargs = {"./dataset/DoS_dataset.csv" : './DoS_dataset.txt'} | ||
203 | + train_data_set, data_idx_map, net_class_count, net_data_count, test_data_set = dataset.GetCanDatasetUsingTxtKwarg(args.n_nets, args.fold_num, **kwargs) | ||
204 | + testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, | ||
205 | + shuffle=False, num_workers=2) | ||
206 | + | ||
207 | + run_benchmark('./quan.pth', testloader) | ||
208 | + run_benchmark('./float.pth', testloader) | ||
209 | + # run_benchmark('./quan.pth', testloader) | ||
210 | + | ||
211 | + fed_model = model.Net() | ||
212 | + args.comm_type = 'fedavg' | ||
213 | + if args.comm_type == "fedavg": | ||
214 | + start_fedavg(fed_model, args, | ||
215 | + train_data_set, | ||
216 | + data_idx_map, | ||
217 | + net_data_count, | ||
218 | + testloader, | ||
219 | + device) | ||
220 | + | ||
221 | +if __name__ == "__main__": | ||
222 | + start_train() |
코드/연합학습/quantization/model.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import torch.nn.functional as F | ||
3 | +import torch | ||
4 | +from torch.quantization import QuantStub, DeQuantStub | ||
5 | + | ||
6 | +class Net(nn.Module): | ||
7 | + def __init__(self): | ||
8 | + super(Net, self).__init__() | ||
9 | + | ||
10 | + self.quant = QuantStub() | ||
11 | + self.dequant = DeQuantStub() | ||
12 | + | ||
13 | + self.conv1 = nn.Sequential( | ||
14 | + nn.Conv2d(1, 8, 3), | ||
15 | + nn.ReLU(True), | ||
16 | + ) | ||
17 | + self.conv2 = nn.Sequential( | ||
18 | + nn.Conv2d(8, 8, 3), | ||
19 | + nn.ReLU(True), | ||
20 | + ) | ||
21 | + self.conv3 = nn.Sequential( | ||
22 | + nn.Conv2d(8, 8, 3), | ||
23 | + nn.ReLU(True), | ||
24 | + ) | ||
25 | + self.fc4 = nn.Linear(8 * 23 * 23, 2) | ||
26 | + | ||
27 | + def forward(self, x): | ||
28 | + x = self.quant(x) | ||
29 | + x = self.conv1(x) | ||
30 | + x = self.conv2(x) | ||
31 | + x = self.conv3(x) | ||
32 | + x = torch.flatten(x, 1) | ||
33 | + x = self.fc4(x) | ||
34 | + x = self.dequant(x) | ||
35 | + return x | ||
36 | + | ||
37 | + def fuse_model(self): | ||
38 | + for m in self.modules(): | ||
39 | + if type(m) == nn.Sequential: | ||
40 | + torch.quantization.fuse_modules(m, ['0', '1'], inplace=True) |
코드/연합학습/quantization/utils.py
0 → 100644
1 | +import os | ||
2 | +import argparse | ||
3 | +import json | ||
4 | +import numpy as np | ||
5 | +import torch | ||
6 | +import torch.optim as optim | ||
7 | +import torch.nn as nn | ||
8 | +import torchvision | ||
9 | +import torchvision.transforms as transforms | ||
10 | +import torch.utils.data as data | ||
11 | +import math | ||
12 | +import copy | ||
13 | +import time | ||
14 | + | ||
15 | +import model | ||
16 | +import torch.quantization | ||
17 | +from torch.quantization import QuantStub, DeQuantStub | ||
18 | + | ||
19 | + | ||
20 | +def run_benchmark(model_file, img_loader): | ||
21 | + elapsed = 0 | ||
22 | + # myModel = torch.jit.load(model_file) | ||
23 | + # torch.backends.quantized.engine='fbgemm' | ||
24 | + # myModel.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') | ||
25 | + # myModel.eval() | ||
26 | + myModel = model.Net() | ||
27 | + # myModel = torch.quantization.quantize_dynamic(myModel, {torch.nn.Linear, torch.nn.Sequential}, dtype=torch.qint8) | ||
28 | + # print(myModel) | ||
29 | + # set quantization config for server (x86) | ||
30 | + myModel.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') | ||
31 | + num_batches = 10 | ||
32 | + # # insert observers | ||
33 | + torch.quantization.prepare(myModel, inplace=True) | ||
34 | + # # Calibrate the model and collect statistics | ||
35 | + with torch.no_grad(): | ||
36 | + for i, (images, target) in enumerate(img_loader): | ||
37 | + images = images.float() | ||
38 | + target = target.long() | ||
39 | + if i < num_batches: | ||
40 | + start = time.time() | ||
41 | + output = myModel(images) | ||
42 | + end = time.time() | ||
43 | + # elapsed = elapsed + (end-start) | ||
44 | + else: | ||
45 | + break | ||
46 | + | ||
47 | + # # convert to quantized version | ||
48 | + torch.quantization.convert(myModel, inplace=True) | ||
49 | + | ||
50 | + # quant = QuantStub() | ||
51 | + with torch.no_grad(): | ||
52 | + for i, (images, target) in enumerate(img_loader): | ||
53 | + images = images.float() | ||
54 | + target = target.long() | ||
55 | + if i < num_batches: | ||
56 | + start = time.time() | ||
57 | + output = myModel(images) | ||
58 | + end = time.time() | ||
59 | + elapsed = elapsed + (end-start) | ||
60 | + else: | ||
61 | + break | ||
62 | + num_images = images.size()[0] * num_batches | ||
63 | + print(elapsed) | ||
64 | + print('Elapsed time: %3.0f ms' % (elapsed/num_images*1000)) | ||
65 | + return elapsed | ||
66 | + | ||
67 | + | ||
68 | +def init_models(n_nets, args): | ||
69 | + models = [] | ||
70 | + layer_shape = [] | ||
71 | + layer_type = [] | ||
72 | + | ||
73 | + for idx in range(n_nets): | ||
74 | + # if args.model == "lenet": | ||
75 | + # cnn = LeNet() | ||
76 | + # elif args.model == "vgg": | ||
77 | + # cnn = vgg11() | ||
78 | + models.append(model.Net()) | ||
79 | + | ||
80 | + for (k, v) in models[0].state_dict().items(): | ||
81 | + layer_shape.append(v.shape) | ||
82 | + layer_type.append(k) | ||
83 | + | ||
84 | + return models, layer_shape, layer_type | ||
85 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment