김지훈

Before modifying the edge model; training speed needs improvement

# .gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
# const.py
CAN_ID_BIT = 29  # an extended CAN frame identifier is 29 bits wide
# dataset.py
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import const

'''
def int_to_binary(x, bits):
    mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
'''

def unpack_bits(x, num_bits):
    """Unpack an integer array into its binary representation.

    Args:
        x (np.ndarray): integer(s) to convert to bits
        num_bits (int): number of bits to represent

    Returns:
        np.ndarray: 0/1 array with a trailing axis of length num_bits,
        least-significant bit first.
    """
    xshape = list(x.shape)
    x = x.reshape([-1, 1])
    mask = 2**np.arange(num_bits).reshape([1, num_bits])
    return (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
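
# A quick, illustrative sanity check of unpack_bits (values are examples only).
# 0x5A0 = 0b101_1010_0000, and the mask above is least-significant-bit first:
#
#   >>> unpack_bits(np.array(0x5A0), const.CAN_ID_BIT).shape
#   (29,)
#   >>> unpack_bits(np.array(0x5A0), const.CAN_ID_BIT)[:12]
#   array([0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0])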

# def CsvToNumpy(csv_file):
#     target_csv = pd.read_csv(csv_file)
#     inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
#     labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy'
#     print(inputs_save_numpy, labels_save_numpy)

#     i = 0
#     inputs_array = []
#     labels_array = []
#     print(len(target_csv))

#     while i + const.CAN_ID_BIT - 1 < len(target_csv):

#         is_regular = True
#         for j in range(const.CAN_ID_BIT):
#             l = target_csv.iloc[i + j]
#             b = l[2]
#             r = (l[b+2+1] == 'R')

#             if not r:
#                 is_regular = False
#                 break

#         inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
#         for idx in range(const.CAN_ID_BIT):
#             can_id = int(target_csv.iloc[i + idx, 1], 16)
#             inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
#         inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))

#         if is_regular:
#             labels = 1
#         else:
#             labels = 0

#         inputs_array.append(inputs)
#         labels_array.append(labels)

#         i += 1
#         if (i % 5000 == 0):
#             print(i)
#             # break

#     inputs_array = np.array(inputs_array)
#     labels_array = np.array(labels_array)
#     np.save(inputs_save_numpy, arr=inputs_array)
#     np.save(labels_save_numpy, arr=labels_array)
#     print('done')

def CsvToText(csv_file):
    """Label every window of const.CAN_ID_BIT consecutive frames as regular
    ('R') or attack ('T') and write one "<start index> <flag>" line per window."""
    target_csv = pd.read_csv(csv_file)
    text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt'
    print(text_file_name)
    target_text = open(text_file_name, mode='wt', encoding='utf-8')

    i = 0
    print(len(target_csv))

    while i + const.CAN_ID_BIT - 1 < len(target_csv):

        # A window is regular only if every frame in it carries the 'R' flag.
        is_regular = True
        for j in range(const.CAN_ID_BIT):
            l = target_csv.iloc[i + j]
            b = l.iloc[2]               # column 2: DLC (number of data bytes)
            r = (l.iloc[b + 3] == 'R')  # flag column sits right after the data bytes

            if not r:
                is_regular = False
                break

        if is_regular:
            target_text.write("%d R\n" % i)
        else:
            target_text.write("%d T\n" % i)

        i += 1
        if i % 5000 == 0:
            print(i)

    target_text.close()
    print('done')
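
# Illustrative usage (the path is an example). Each line of the resulting txt
# file is "<window start index> R" for an all-regular window, or "... T" for a
# window containing at least one injected frame:
#
#   CsvToText('./dataset/DoS_dataset.csv')   # writes DoS_dataset.txt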

def record_net_data_stats(label_temp, data_idx_map):
    # Count, for each edge (net), how many samples of each class it holds.
    net_class_count = {}
    net_data_count = {}

    for net_i, dataidx in data_idx_map.items():
        unq, unq_cnt = np.unique(label_temp[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_class_count[net_i] = tmp
        net_data_count[net_i] = len(dataidx)
    print('Data statistics: %s' % str(net_class_count))
    return net_class_count, net_data_count

def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
    # kwargs maps each CSV of raw CAN frames to the label txt produced by CsvToText.
    csv_list = []
    total_datum = []
    total_label_temp = []
    csv_idx = 0
    for csv_file, txt_file in kwargs.items():
        csv = pd.read_csv(csv_file)
        csv_list.append(csv)

        txt = open(txt_file, "r")
        lines = txt.read().splitlines()

        idx = 0
        local_datum = []
        while idx + const.CAN_ID_BIT - 1 < len(csv):
            line = lines[idx]
            if not line:
                break

            # Each entry: (which csv, window start index, 1 = regular / 0 = attack)
            if line.split(' ')[1] == 'R':
                local_datum.append((csv_idx, idx, 1))
                total_label_temp.append(1)
            else:
                local_datum.append((csv_idx, idx, 0))
                total_label_temp.append(0)

            idx += 1
            if idx % 1000000 == 0:
                print(idx)

        csv_idx += 1
        total_datum += local_datum

    # 5-fold split: fold fold_num becomes the test set, the rest is training data.
    fold_length = int(len(total_label_temp) / 5)
    datum = []
    label_temp = []
    for i in range(5):
        if i != fold_num:
            datum += total_datum[i*fold_length:(i+1)*fold_length]
            label_temp += total_label_temp[i*fold_length:(i+1)*fold_length]
        else:
            test_datum = total_datum[i*fold_length:(i+1)*fold_length]

    min_size = 0
    output_class_num = 2
    N = len(label_temp)
    label_temp = np.array(label_temp)
    data_idx_map = {}

    # Dirichlet-based non-IID partition across the edges; resample until every
    # edge holds at least 512 samples.
    while min_size < 512:
        idx_batch = [[] for _ in range(total_edge)]
        # for each class in the dataset
        for k in range(output_class_num):
            idx_k = np.where(label_temp == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(1, total_edge))
            # Balance: zero out edges that already hold more than an equal share.
            proportions = np.array([p * (len(idx_j) < N / total_edge) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

    for j in range(total_edge):
        np.random.shuffle(idx_batch[j])
        data_idx_map[j] = idx_batch[j]

    net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map)

    return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, \
        CanDatasetKwarg(csv_list, test_datum, False)
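
# For intuition: np.random.dirichlet(np.repeat(1, total_edge)) draws one random
# split of 1.0 across the edges, so each class lands unevenly on the edges, and
# the while-loop above resamples until the smallest shard has >= 512 samples.
# Illustrative draw (actual values vary per run):
#
#   >>> np.random.dirichlet(np.repeat(1, 4))
#   array([0.52, 0.05, 0.31, 0.12])   # edge 0 would get ~52% of this class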

class CanDatasetKwarg(Dataset):

    def __init__(self, csv_list, datum, is_train=True):
        self.csv_list = csv_list
        self.datum = datum
        if is_train:
            # Training sets start empty; call set_idx_map with an edge's indices.
            self.idx_map = []
        else:
            self.idx_map = [idx for idx in range(len(self.datum))]

    def __len__(self):
        return len(self.idx_map)

    def set_idx_map(self, data_idx_map):
        self.idx_map = data_idx_map

    def __getitem__(self, idx):
        csv_idx = self.datum[self.idx_map[idx]][0]
        start_i = self.datum[self.idx_map[idx]][1]
        is_regular = self.datum[self.idx_map[idx]][2]

        # Build a 29x29 bit image: one row per frame, one column per CAN ID bit.
        l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
        for i in range(const.CAN_ID_BIT):
            id_ = int(self.csv_list[csv_idx].iloc[start_i + i, 1], 16)
            bits = unpack_bits(np.array(id_), const.CAN_ID_BIT)
            l[i] = bits
        l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))

        return (l, is_regular)
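
# Minimal usage sketch (hypothetical paths and sizes: 4 edges, fold 0 as test):
#
#   kwargs = {'./dataset/DoS_dataset.csv': './DoS_dataset.txt'}
#   train_set, data_idx_map, class_cnt, data_cnt, test_set = \
#       GetCanDatasetUsingTxtKwarg(4, 0, **kwargs)
#   for edge_id in range(4):
#       # point the shared training set at this edge's Dirichlet shard
#       train_set.set_idx_map(data_idx_map[edge_id])
#       loader = DataLoader(train_set, batch_size=64, shuffle=True)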

def GetCanDatasetUsingTxt(csv_file, txt_path, length):
    # Collect up to length//2 regular and length//2 attack windows using the
    # precomputed label txt, then split each class 90/10 into train/validation.
    csv = pd.read_csv(csv_file)
    txt = open(txt_path, "r")
    lines = txt.read().splitlines()

    idx = 0
    datum = [[], []]
    while idx + const.CAN_ID_BIT - 1 < len(csv):
        if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
            break

        line = lines[idx]
        if not line:
            break

        if line.split(' ')[1] == 'R':
            if len(datum[0]) < length//2:
                datum[0].append((idx, 1))
        else:
            if len(datum[1]) < length//2:
                datum[1].append((idx, 0))

        idx += 1
        if idx % 5000 == 0:
            print(idx, len(datum[0]), len(datum[1]))

    l = int((length // 2) * 0.9)
    return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
           CanDataset(csv, datum[0][l:] + datum[1][l:])

def GetCanDataset(csv_file, length):
    # Same as GetCanDatasetUsingTxt, but labels each window directly from the
    # CSV instead of a precomputed txt file.
    csv = pd.read_csv(csv_file)

    i = 0
    datum = [[], []]

    while i + const.CAN_ID_BIT - 1 < len(csv):
        if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
            break

        is_regular = True
        for j in range(const.CAN_ID_BIT):
            l = csv.iloc[i + j]
            b = l.iloc[2]               # column 2: DLC (number of data bytes)
            r = (l.iloc[b + 3] == 'R')  # flag column sits right after the data bytes

            if not r:
                is_regular = False
                break

        if is_regular:
            if len(datum[0]) < length//2:
                datum[0].append((i, 1))
        else:
            if len(datum[1]) < length//2:
                datum[1].append((i, 0))
        i += 1
        if i % 5000 == 0:
            print(i, len(datum[0]), len(datum[1]))

    l = int((length // 2) * 0.9)
    return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
           CanDataset(csv, datum[0][l:] + datum[1][l:])

class CanDataset(Dataset):

    def __init__(self, csv, datum):
        self.csv = csv
        self.datum = datum

    def __len__(self):
        return len(self.datum)

    def __getitem__(self, idx):
        start_i = self.datum[idx][0]
        is_regular = self.datum[idx][1]

        l = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
        for i in range(const.CAN_ID_BIT):
            can_id = int(self.csv.iloc[start_i + i, 1], 16)  # column 1: hex CAN ID
            bits = unpack_bits(np.array(can_id), const.CAN_ID_BIT)
            l[i] = bits
        l = np.reshape(l, (1, const.CAN_ID_BIT, const.CAN_ID_BIT))

        return (l, is_regular)

if __name__ == "__main__":
    # Smoke test with example values: 2 edges, fold 0 held out for testing.
    batch_size = 64
    kwargs = {"./dataset/DoS_dataset.csv": './DoS_dataset.txt'}
    _, _, _, _, test_data_set = GetCanDatasetUsingTxtKwarg(2, 0, **kwargs)
    testloader = DataLoader(test_data_set, batch_size=batch_size,
                            shuffle=False, num_workers=2)

    for x, y in testloader:
        print(x)
        print(y)
        break
# model.py (filename assumed)
import torch.nn as nn
import torch.nn.functional as F
import torch
import const

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Three 3x3 valid convolutions shrink the 29x29 input to 27 -> 25 -> 23,
        # hence the 8 * 23 * 23 flattened features below.
        self.f1 = nn.Sequential(
            nn.Conv2d(1, 2, 3),
            nn.ReLU(True),
        )
        self.f2 = nn.Sequential(
            nn.Conv2d(2, 4, 3),
            nn.ReLU(True),
        )
        self.f3 = nn.Sequential(
            nn.Conv2d(4, 8, 3),
            nn.ReLU(True),
        )
        self.f4 = nn.Sequential(
            nn.Linear(8 * 23 * 23, 2),  # 2 classes: regular vs. attack
        )

    def forward(self, x):
        x = self.f1(x)
        x = self.f2(x)
        x = self.f3(x)
        x = torch.flatten(x, 1)
        x = self.f4(x)
        return x
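
# Shape sanity check (illustrative): a batch of one 29x29 bit image maps to
# two logits (regular vs. attack).
#
#   >>> net = Net()
#   >>> net(torch.randn(1, 1, const.CAN_ID_BIT, const.CAN_ID_BIT)).shape
#   torch.Size([1, 2])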