Showing
2 changed files
with
17 additions
and
13 deletions
코드/연합학습/README.md
0 → 100644
| 1 | +# khu_capstone_1 | ||
| 2 | + | ||
| 3 | +## 연합학습 기반 유해트래픽 탐지 | ||
| 4 | +- Pytorch | ||
| 5 | +- CAN protocol 유해 트래픽 데이터 셋 | ||
| 6 | +- FedAvg, FedProx, Fed using timestamp, Fed dynamic weight 논문 구현 및 성능 비교 | ||
| 7 | + | ||
| 8 | +## Model train | ||
| 9 | +- Install [PyTorch](http://pytorch.org) | ||
| 10 | +- Train model | ||
| 11 | +```bash | ||
| 12 | +python3 fed_train.py --packet_num 3 --fold_num 0 --batch_size 128 --lr 0.001 --n_nets 100 --comm_type fedprox --comm_round 50 | ||
| 13 | +``` |
| ... | @@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 | ... | @@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0 |
| 49 | def add_args(parser): | 49 | def add_args(parser): |
| 50 | parser.add_argument('--packet_num', type=int, default=1, | 50 | parser.add_argument('--packet_num', type=int, default=1, |
| 51 | help='packet number used in training, 1 ~ 3') | 51 | help='packet number used in training, 1 ~ 3') |
| 52 | - parser.add_argument('--dataset', type=str, default='can', | ||
| 53 | - help='dataset used for training, can or syncan') | ||
| 54 | parser.add_argument('--fold_num', type=int, default=0, | 52 | parser.add_argument('--fold_num', type=int, default=0, |
| 55 | help='5-fold, 0 ~ 4') | 53 | help='5-fold, 0 ~ 4') |
| 56 | parser.add_argument('--batch_size', type=int, default=128, | 54 | parser.add_argument('--batch_size', type=int, default=128, |
| ... | @@ -549,21 +547,14 @@ def start_train(): | ... | @@ -549,21 +547,14 @@ def start_train(): |
| 549 | torch.manual_seed(seed) | 547 | torch.manual_seed(seed) |
| 550 | 548 | ||
| 551 | print("Loading data...") | 549 | print("Loading data...") |
| 552 | - if args.dataset == 'can': | 550 | + train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") |
| 553 | - train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt") | 551 | + |
| 554 | - elif args.dataset == 'syncan': | ||
| 555 | - train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt") | ||
| 556 | - | ||
| 557 | sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) | 552 | sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size) |
| 558 | testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, | 553 | testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler, |
| 559 | shuffle=False, num_workers=2, drop_last=True) | 554 | shuffle=False, num_workers=2, drop_last=True) |
| 560 | 555 | ||
| 561 | - if args.dataset == 'can': | 556 | + fed_model = model.OneNet(args.packet_num) |
| 562 | - fed_model = model.OneNet(args.packet_num) | 557 | + edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] |
| 563 | - edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] | ||
| 564 | - elif args.dataset == 'syncan': | ||
| 565 | - fed_model = model.OneNet(args.packet_num) | ||
| 566 | - edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)] | ||
| 567 | 558 | ||
| 568 | if args.comm_type == "fedavg": | 559 | if args.comm_type == "fedavg": |
| 569 | start_fedavg(fed_model, args, | 560 | start_fedavg(fed_model, args, | ... | ... |
-
Please register or login to post a comment