Showing
5 changed files
with
495 additions
and
0 deletions
코드/.gitignore
0 → 100644
1 | +# Byte-compiled / optimized / DLL files | ||
2 | +__pycache__/ | ||
3 | +*.py[cod] | ||
4 | +*$py.class | ||
5 | + | ||
6 | +# C extensions | ||
7 | +*.so | ||
8 | + | ||
9 | +# Distribution / packaging | ||
10 | +.Python | ||
11 | +build/ | ||
12 | +develop-eggs/ | ||
13 | +dist/ | ||
14 | +downloads/ | ||
15 | +eggs/ | ||
16 | +.eggs/ | ||
17 | +lib/ | ||
18 | +lib64/ | ||
19 | +parts/ | ||
20 | +sdist/ | ||
21 | +var/ | ||
22 | +wheels/ | ||
23 | +share/python-wheels/ | ||
24 | +*.egg-info/ | ||
25 | +.installed.cfg | ||
26 | +*.egg | ||
27 | +MANIFEST | ||
28 | + | ||
29 | +# PyInstaller | ||
30 | +# Usually these files are written by a python script from a template | ||
31 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
32 | +*.manifest | ||
33 | +*.spec | ||
34 | + | ||
35 | +# Installer logs | ||
36 | +pip-log.txt | ||
37 | +pip-delete-this-directory.txt | ||
38 | + | ||
39 | +# Unit test / coverage reports | ||
40 | +htmlcov/ | ||
41 | +.tox/ | ||
42 | +.nox/ | ||
43 | +.coverage | ||
44 | +.coverage.* | ||
45 | +.cache | ||
46 | +nosetests.xml | ||
47 | +coverage.xml | ||
48 | +*.cover | ||
49 | +*.py,cover | ||
50 | +.hypothesis/ | ||
51 | +.pytest_cache/ | ||
52 | +cover/ | ||
53 | + | ||
54 | +# Translations | ||
55 | +*.mo | ||
56 | +*.pot | ||
57 | + | ||
58 | +# Django stuff: | ||
59 | +*.log | ||
60 | +local_settings.py | ||
61 | +db.sqlite3 | ||
62 | +db.sqlite3-journal | ||
63 | + | ||
64 | +# Flask stuff: | ||
65 | +instance/ | ||
66 | +.webassets-cache | ||
67 | + | ||
68 | +# Scrapy stuff: | ||
69 | +.scrapy | ||
70 | + | ||
71 | +# Sphinx documentation | ||
72 | +docs/_build/ | ||
73 | + | ||
74 | +# PyBuilder | ||
75 | +.pybuilder/ | ||
76 | +target/ | ||
77 | + | ||
78 | +# Jupyter Notebook | ||
79 | +.ipynb_checkpoints | ||
80 | + | ||
81 | +# IPython | ||
82 | +profile_default/ | ||
83 | +ipython_config.py | ||
84 | + | ||
85 | +# pyenv | ||
86 | +# For a library or package, you might want to ignore these files since the code is | ||
87 | +# intended to run in multiple environments; otherwise, check them in: | ||
88 | +# .python-version | ||
89 | + | ||
90 | +# pipenv | ||
91 | +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
92 | +# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
93 | +# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
94 | +# install all needed dependencies. | ||
95 | +#Pipfile.lock | ||
96 | + | ||
97 | +# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
98 | +__pypackages__/ | ||
99 | + | ||
100 | +# Celery stuff | ||
101 | +celerybeat-schedule | ||
102 | +celerybeat.pid | ||
103 | + | ||
104 | +# SageMath parsed files | ||
105 | +*.sage.py | ||
106 | + | ||
107 | +# Environments | ||
108 | +.env | ||
109 | +.venv | ||
110 | +env/ | ||
111 | +venv/ | ||
112 | +ENV/ | ||
113 | +env.bak/ | ||
114 | +venv.bak/ | ||
115 | + | ||
116 | +# Spyder project settings | ||
117 | +.spyderproject | ||
118 | +.spyproject | ||
119 | + | ||
120 | +# Rope project settings | ||
121 | +.ropeproject | ||
122 | + | ||
123 | +# mkdocs documentation | ||
124 | +/site | ||
125 | + | ||
126 | +# mypy | ||
127 | +.mypy_cache/ | ||
128 | +.dmypy.json | ||
129 | +dmypy.json | ||
130 | + | ||
131 | +# Pyre type checker | ||
132 | +.pyre/ | ||
133 | + | ||
134 | +# pytype static type analyzer | ||
135 | +.pytype/ | ||
136 | + | ||
137 | +# Cython debug symbols | ||
138 | +cython_debug/ | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/const.py
0 → 100644
# Number of bits used to represent one CAN identifier (presumably the 29-bit
# extended-frame CAN ID — confirm against the capture format); the dataset code
# also reuses it as the square window size (CAN_ID_BIT rows x CAN_ID_BIT bits)
# of the bit-images fed to the model.
CAN_ID_BIT = 29
... | \ No newline at end of file | ... | \ No newline at end of file |
코드/dataset.py
0 → 100644
1 | +import os | ||
2 | +import torch | ||
3 | +import pandas as pd | ||
4 | +import numpy as np | ||
5 | +from torch.utils.data import Dataset, DataLoader | ||
6 | +import const | ||
7 | + | ||
8 | +''' | ||
9 | +def int_to_binary(x, bits): | ||
10 | + mask = 2 ** torch.arange(bits).to(x.device, x.dtype) | ||
11 | + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte() | ||
12 | +''' | ||
13 | + | ||
def unpack_bits(x, num_bits):
    """Expand integers into LSB-first binary bit vectors.

    Args:
        x (np.ndarray): integer array (any shape, including 0-d) to convert.
        num_bits (int): number of bits used to represent each value.

    Returns:
        np.ndarray: int array of shape ``x.shape + [num_bits]`` holding 0/1
        values, least-significant bit first.
    """
    original_shape = list(x.shape)
    flat = x.reshape([-1, 1])
    bit_weights = (2 ** np.arange(num_bits)).reshape([1, num_bits])
    bits = (flat & bit_weights).astype(bool).astype(int)
    return bits.reshape(original_shape + [num_bits])
24 | + | ||
25 | + | ||
26 | +# def CsvToNumpy(csv_file): | ||
27 | +# target_csv = pd.read_csv(csv_file) | ||
28 | +# inputs_save_numpy = 'inputs_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | ||
29 | +# labels_save_numpy = 'labels_' + csv_file.split('/')[-1].split('.')[0].split('_')[0] + '.npy' | ||
30 | +# print(inputs_save_numpy, labels_save_numpy) | ||
31 | + | ||
32 | +# i = 0 | ||
33 | +# inputs_array = [] | ||
34 | +# labels_array = [] | ||
35 | +# print(len(target_csv)) | ||
36 | + | ||
37 | +# while i + const.CAN_ID_BIT - 1 < len(target_csv): | ||
38 | + | ||
39 | +# is_regular = True | ||
40 | +# for j in range(const.CAN_ID_BIT): | ||
41 | +# l = target_csv.iloc[i + j] | ||
42 | +# b = l[2] | ||
43 | +# r = (l[b+2+1] == 'R') | ||
44 | + | ||
45 | +# if not r: | ||
46 | +# is_regular = False | ||
47 | +# break | ||
48 | + | ||
49 | +# inputs = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
50 | +# for idx in range(const.CAN_ID_BIT): | ||
51 | +# can_id = int(target_csv.iloc[i + idx, 1], 16) | ||
52 | +# inputs[idx] = unpack_bits(np.array(can_id), const.CAN_ID_BIT) | ||
53 | +# inputs = np.reshape(inputs, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)) | ||
54 | + | ||
55 | +# if is_regular: | ||
56 | +# labels = 1 | ||
57 | +# else: | ||
58 | +# labels = 0 | ||
59 | + | ||
60 | +# inputs_array.append(inputs) | ||
61 | +# labels_array.append(labels) | ||
62 | + | ||
63 | +# i+=1 | ||
64 | +# if (i % 5000 == 0): | ||
65 | +# print(i) | ||
66 | +# # break | ||
67 | + | ||
68 | +# inputs_array = np.array(inputs_array) | ||
69 | +# labels_array = np.array(labels_array) | ||
70 | +# np.save(inputs_save_numpy, arr=inputs_array) | ||
71 | +# np.save(labels_save_numpy, arr=labels_array) | ||
72 | +# print('done') | ||
73 | + | ||
74 | + | ||
def CsvToText(csv_file):
    """Label every sliding window of a CAN-traffic CSV and write the labels to a text file.

    For each window of ``const.CAN_ID_BIT`` consecutive rows starting at row
    ``i``, writes "<i> R" when every frame in the window has flag 'R' (regular)
    and "<i> T" otherwise. The output file takes the CSV's basename with a
    ``.txt`` extension and is created in the current working directory.

    Args:
        csv_file (str): path to the CAN traffic CSV to label.
    """
    target_csv = pd.read_csv(csv_file)
    text_file_name = csv_file.split('/')[-1].split('.')[0] + '.txt'
    print(text_file_name)
    print(len(target_csv))

    # Fix: the original opened the file without `with` and leaked the handle
    # if any row access raised; the unused `datum` accumulator is also removed.
    with open(text_file_name, mode='wt', encoding='utf-8') as target_text:
        i = 0
        while i + const.CAN_ID_BIT - 1 < len(target_csv):
            is_regular = True
            for j in range(const.CAN_ID_BIT):
                row = target_csv.iloc[i + j]
                # Column 2 is assumed to be the payload length (DLC), so the
                # R/T flag sits just past the payload columns — TODO confirm.
                payload_len = row[2]
                if row[payload_len + 2 + 1] != 'R':
                    is_regular = False
                    break

            if is_regular:
                target_text.write("%d R\n" % i)
            else:
                target_text.write("%d T\n" % i)

            i += 1
            if (i % 5000 == 0):
                print(i)  # progress indicator for long captures

    print('done')
108 | + | ||
109 | + | ||
def record_net_data_stats(label_temp, data_idx_map):
    """Summarize the per-client label distribution of a federated partition.

    Args:
        label_temp (np.ndarray): label for every sample in the pool.
        data_idx_map (dict): client id -> list/array of sample indices.

    Returns:
        tuple: (net_class_count, net_data_count) where ``net_class_count``
        maps client id -> {label: count} and ``net_data_count`` maps
        client id -> total samples held by that client.
    """
    net_class_count = {}
    net_data_count = {}

    for net_i, dataidx in data_idx_map.items():
        values, counts = np.unique(label_temp[dataidx], return_counts=True)
        net_class_count[net_i] = dict(zip(values, counts))
        net_data_count[net_i] = len(dataidx)

    print('Data statistics: %s' % str(net_class_count))
    return net_class_count, net_data_count
121 | + | ||
122 | + | ||
def GetCanDatasetUsingTxtKwarg(total_edge, fold_num, **kwargs):
    """Build federated train/test CAN datasets from CSVs plus precomputed label files.

    Args:
        total_edge (int): number of federated clients to partition the
            training split across.
        fold_num (int): index in [0, 4] of the 1/5 fold held out as the test
            set; the other four folds become the training pool.
        **kwargs: mapping of csv_path -> txt_path, where each txt file holds
            one "<index> R|T" line per window (as produced by CsvToText).

    Returns:
        tuple: (train_dataset, data_idx_map, net_class_count, net_data_count,
        test_dataset). ``data_idx_map`` maps client id -> indices of its
        training samples, drawn per class from a Dirichlet(1) distribution
        and re-sampled until every client holds at least 512 samples.

    Raises:
        NameError: if ``fold_num`` is outside [0, 4] (no test fold is chosen).
    """
    csv_list = []
    total_datum = []
    total_label_temp = []
    csv_idx = 0
    for csv_file, txt_file in kwargs.items():
        csv = pd.read_csv(csv_file)
        csv_list.append(csv)

        # Fix: the original opened the label file without `with` and never
        # closed it, leaking one handle per dataset.
        with open(txt_file, "r") as txt:
            lines = txt.read().splitlines()

        idx = 0
        local_datum = []
        while idx + const.CAN_ID_BIT - 1 < len(csv):
            line = lines[idx]
            if not line:
                break

            # Each datum is (csv index, window start row, label); 1 = regular.
            if line.split(' ')[1] == 'R':
                local_datum.append((csv_idx, idx, 1))
                total_label_temp.append(1)
            else:
                local_datum.append((csv_idx, idx, 0))
                total_label_temp.append(0)

            idx += 1
            if (idx % 1000000 == 0):
                print(idx)  # progress indicator

        csv_idx += 1
        total_datum += local_datum

    # 5-fold split: fold `fold_num` is the test set, the rest the train pool.
    fold_length = int(len(total_label_temp) / 5)
    datum = []
    label_temp = []
    for i in range(5):
        if i != fold_num:
            datum += total_datum[i*fold_length:(i+1)*fold_length]
            label_temp += total_label_temp[i*fold_length:(i+1)*fold_length]
        else:
            test_datum = total_datum[i*fold_length:(i+1)*fold_length]

    # Non-IID Dirichlet partition of the train pool across clients.
    min_size = 0
    output_class_num = 2
    N = len(label_temp)
    label_temp = np.array(label_temp)
    data_idx_map = {}

    while min_size < 512:
        idx_batch = [[] for _ in range(total_edge)]
        # for each class in the dataset
        for k in range(output_class_num):
            idx_k = np.where(label_temp == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(1, total_edge))
            # Balance: zero the share of any client already at/above the
            # average sample count, then renormalize.
            proportions = np.array([p*(len(idx_j)<N/total_edge) for p,idx_j in zip(proportions,idx_batch)])
            proportions = proportions/proportions.sum()
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(total_edge):
            np.random.shuffle(idx_batch[j])
            data_idx_map[j] = idx_batch[j]

    net_class_count, net_data_count = record_net_data_stats(label_temp, data_idx_map)

    return CanDatasetKwarg(csv_list, datum), data_idx_map, net_class_count, net_data_count, CanDatasetKwarg(csv_list, test_datum, False)
193 | + | ||
194 | + | ||
class CanDatasetKwarg(Dataset):
    """Dataset of CAN-ID bit-image windows spanning several source CSVs.

    Each datum is a (csv index, window start row, label) triple. Samples are
    addressed indirectly through ``idx_map`` so that a federated client can be
    restricted to its own partition via ``set_idx_map``.
    """

    def __init__(self, csv_list, datum, is_train=True):
        self.csv_list = csv_list
        self.datum = datum
        # Train sets start empty and receive their indices via set_idx_map;
        # test sets expose every datum immediately.
        self.idx_map = [] if is_train else list(range(len(datum)))

    def __len__(self):
        return len(self.idx_map)

    def set_idx_map(self, data_idx_map):
        """Install the list of datum indices this dataset should serve."""
        self.idx_map = data_idx_map

    def __getitem__(self, idx):
        csv_idx, start_i, is_regular = self.datum[self.idx_map[idx]]

        window = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
        for row in range(const.CAN_ID_BIT):
            # Column 1 holds the hex CAN ID; expand it to LSB-first bits.
            can_id = int(self.csv_list[csv_idx].iloc[start_i + row, 1], 16)
            window[row] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)

        # Single-channel image shaped (1, CAN_ID_BIT, CAN_ID_BIT).
        return (np.reshape(window, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)), is_regular)
224 | + | ||
225 | + | ||
def GetCanDatasetUsingTxt(csv_file, txt_path, length):
    """Build balanced train/validation CAN datasets using precomputed labels.

    Collects up to ``length // 2`` regular and ``length // 2`` attack window
    start indices from the label text file (see CsvToText), then splits each
    class 90/10 into train and validation sets.

    Args:
        csv_file (str): path to the CAN traffic CSV.
        txt_path (str): path to the matching "<index> R|T" label file.
        length (int): total number of samples to gather (half per class).

    Returns:
        tuple[CanDataset, CanDataset]: (train_set, validation_set).
    """
    csv = pd.read_csv(csv_file)
    # Fix: the original opened the label file without `with` and never closed
    # it, leaking the file handle.
    with open(txt_path, "r") as txt:
        lines = txt.read().splitlines()

    idx = 0
    datum = [ [], [] ]  # datum[0]: regular windows, datum[1]: attack windows
    while idx + const.CAN_ID_BIT - 1 < len(csv):
        if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
            break

        line = lines[idx]
        if not line:
            break

        if line.split(' ')[1] == 'R':
            if len(datum[0]) < length//2:
                datum[0].append((idx, 1))
        else:
            if len(datum[1]) < length//2:
                datum[1].append((idx, 0))

        idx += 1
        if (idx % 5000 == 0):
            print(idx, len(datum[0]), len(datum[1]))  # progress indicator

    # 90/10 class-balanced split.
    l = int((length // 2) * 0.9)
    return CanDataset(csv, datum[0][:l] + datum[1][:l]), \
           CanDataset(csv, datum[0][l:] + datum[1][l:])
255 | + | ||
256 | + | ||
def GetCanDataset(csv_file, length):
    """Build balanced train/validation CAN datasets by labeling windows on the fly.

    Scans windows of ``const.CAN_ID_BIT`` consecutive rows, classifying a
    window as regular only when every frame's flag column reads 'R'. Gathers
    up to ``length // 2`` samples per class, then splits each class 90/10
    into train and validation sets.

    Args:
        csv_file (str): path to the CAN traffic CSV.
        length (int): total number of samples to gather (half per class).

    Returns:
        tuple[CanDataset, CanDataset]: (train_set, validation_set).
    """
    csv = pd.read_csv(csv_file)

    start = 0
    datum = [ [], [] ]  # datum[0]: regular windows, datum[1]: attack windows

    while start + const.CAN_ID_BIT - 1 < len(csv):
        if len(datum[0]) >= length//2 and len(datum[1]) >= length//2:
            break

        is_regular = True
        for offset in range(const.CAN_ID_BIT):
            row = csv.iloc[start + offset]
            # Column 2 is assumed to be the payload length, placing the R/T
            # flag just past the payload columns — TODO confirm schema.
            payload_len = row[2]
            if row[payload_len + 2 + 1] != 'R':
                is_regular = False
                break

        if is_regular:
            if len(datum[0]) < length//2:
                datum[0].append((start, 1))
        else:
            if len(datum[1]) < length//2:
                datum[1].append((start, 0))

        start += 1
        if (start % 5000 == 0):
            print(start, len(datum[0]), len(datum[1]))

    # 90/10 class-balanced split.
    split = int((length // 2) * 0.9)
    return CanDataset(csv, datum[0][:split] + datum[1][:split]), \
           CanDataset(csv, datum[0][split:] + datum[1][split:])
290 | + | ||
291 | + | ||
class CanDataset(Dataset):
    """Dataset of CAN-ID bit-image windows drawn from a single CSV.

    Each datum is a (window start row, label) pair; ``__getitem__`` renders
    the window's CAN IDs as a (1, CAN_ID_BIT, CAN_ID_BIT) binary image.
    """

    def __init__(self, csv, datum):
        self.csv = csv
        self.datum = datum

    def __len__(self):
        return len(self.datum)

    def __getitem__(self, idx):
        start_i, is_regular = self.datum[idx]

        window = np.zeros((const.CAN_ID_BIT, const.CAN_ID_BIT))
        for row in range(const.CAN_ID_BIT):
            # Column 1 holds the hex CAN ID; expand it to LSB-first bits.
            # (Renamed the loop temp: the original shadowed the builtin `id`.)
            can_id = int(self.csv.iloc[start_i + row, 1], 16)
            window[row] = unpack_bits(np.array(can_id), const.CAN_ID_BIT)

        return (np.reshape(window, (1, const.CAN_ID_BIT, const.CAN_ID_BIT)), is_regular)
313 | + | ||
314 | + | ||
if __name__ == "__main__":
    # Smoke test: build datasets from one CSV/label pair and inspect a batch.
    kwargs = {"./dataset/DoS_dataset.csv": './DoS_dataset.txt'}
    # Fixes to the original: the call was qualified with an undefined
    # `dataset.` prefix (this IS the dataset module), passed a stray third
    # positional argument `False` that the signature does not accept, and
    # unpacked the 5-tuple return value into a single variable; `batch_size`
    # was also never defined (NameError).
    train_set, data_idx_map, net_class_count, net_data_count, test_data_set = \
        GetCanDatasetUsingTxtKwarg(1, 0, **kwargs)
    batch_size = 32
    testloader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    for x, y in testloader:
        print(x)
        print(y)
        break
코드/fed_train.py
0 → 100644
This diff is collapsed. Click to expand it.
코드/model.py
0 → 100644
1 | +import torch.nn as nn | ||
2 | +import torch.nn.functional as F | ||
3 | +import torch | ||
4 | +import const | ||
5 | + | ||
class Net(nn.Module):
    """Small CNN that classifies 29x29 CAN-ID bit-images into 2 classes.

    Three unpadded 3x3 conv+ReLU stages (1 -> 2 -> 4 -> 8 channels) shrink a
    (N, 1, 29, 29) input to (N, 8, 23, 23); a single linear layer then maps
    the flattened features to 2 logits.
    """

    def __init__(self):
        super(Net, self).__init__()

        # Keep the f1..f4 attribute names: checkpoints and the training code
        # address the submodules (and state_dict keys) by these names.
        self.f1 = self._conv_block(1, 2)
        self.f2 = self._conv_block(2, 4)
        self.f3 = self._conv_block(4, 8)
        self.f4 = nn.Sequential(
            nn.Linear(8 * 23 * 23, 2),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Unpadded 3x3 convolution followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3),
            nn.ReLU(True),
        )

    def forward(self, x):
        # Conv stages, then flatten to (N, 8*23*23) for the linear classifier.
        features = self.f3(self.f2(self.f1(x)))
        return self.f4(torch.flatten(features, 1))
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment