김지훈

docs: federated train readme

# khu_capstone_1
## 연합학습 기반 유해트래픽 탐지
- Pytorch
- CAN protocol 유해 트래픽 데이터 셋
- FedAvg, FedProx, Fed using timestamp, Fed dynamic weight 논문 구현 및 성능 비교
## Model train
- Install [PyTorch](http://pytorch.org)
- Train model
```bash
python3 fed_train.py --packet_num 3 --fold_num 0 --batch_size 128 --lr 0.001 --n_nets 100 --comm_type fedprox --comm_round 50
```
......@@ -49,8 +49,6 @@ alpha, beta, gamma = 40.0/100.0, 40.0/100.0, 20.0/100.0
def add_args(parser):
parser.add_argument('--packet_num', type=int, default=1,
help='packet number used in training, 1 ~ 3')
parser.add_argument('--dataset', type=str, default='can',
help='dataset used for training, can or syncan')
parser.add_argument('--fold_num', type=int, default=0,
help='5-fold, 0 ~ 4')
parser.add_argument('--batch_size', type=int, default=128,
......@@ -549,21 +547,14 @@ def start_train():
torch.manual_seed(seed)
print("Loading data...")
if args.dataset == 'can':
train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
elif args.dataset == 'syncan':
train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/test_mixed.csv", "./dataset/Mixed_dataset_1.txt")
train_data_set, data_idx_map, net_data_count, test_data_set = dataset.GetCanDataset(args.n_nets, args.fold_num, args.packet_num, "./dataset/Mixed_dataset.csv", "./dataset/Mixed_dataset_1.txt")
sampler = dataset.BatchIntervalSampler(len(test_data_set), args.batch_size)
testloader = torch.utils.data.DataLoader(test_data_set, batch_size=args.batch_size, sampler=sampler,
shuffle=False, num_workers=2, drop_last=True)
if args.dataset == 'can':
fed_model = model.OneNet(args.packet_num)
edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
elif args.dataset == 'syncan':
fed_model = model.OneNet(args.packet_num)
edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
fed_model = model.OneNet(args.packet_num)
edges = [model.OneNet(args.packet_num) for _ in range(args.n_nets)]
if args.comm_type == "fedavg":
start_fedavg(fed_model, args,
......