SA_main.py

import argparse
import csv
import random
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
# transformers.AdamW is deprecated/removed in newer releases;
# torch.optim.AdamW is a drop-in replacement if this import fails.
from transformers import AdamW
# WarmupLinearSchedule was dropped from transformers; see the
# get_linear_schedule_with_warmup sketch in main() below.
from kobert.pytorch_kobert import get_pytorch_kobert_model

# Downloads the pretrained KoBERT weights and vocab on first use.
bertmodel, vocab = get_pytorch_kobert_model()
import KoBERT.dataset_ as dataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix the random seeds for reproducibility.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

def train(model, iter_loader, optimizer, loss_fn):
    """Run one training epoch and return (mean loss, mean accuracy)."""
    train_loss = 0.0
    train_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(iter_loader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is used by the model to build the attention mask and can stay on the CPU.
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        # scheduler.step()  # update the learning-rate schedule if a scheduler is enabled
        train_loss += loss.item()
        train_acc += calc_accuracy(out, label)
    # Average over batches instead of reporting only the last batch's loss.
    return train_loss / (batch_id + 1), train_acc / (batch_id + 1)

def test(model, iter_loader, loss_fn):
    """Evaluate on iter_loader and return (mean loss, mean accuracy)."""
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(iter_loader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)

            out = model(token_ids, valid_length, segment_ids)
            loss = loss_fn(out, label)
            test_loss += loss.item()
            test_acc += calc_accuracy(out, label)
    return test_loss / (batch_id + 1), test_acc / (batch_id + 1)

def bert_inference(model, src):
    """Classify one example: returns the predicted class (0 or 1), or -1 if no batch is produced."""
    model.eval()
    with torch.no_grad():
        src_data = dataset.infer(args, src)
        for batch_id, (token_ids, valid_length, segment_ids) in enumerate(src_data):
            # Wrap each un-batched example into a batch of size one.
            token_ids = torch.tensor([token_ids]).long().to(device)
            segment_ids = torch.tensor([segment_ids]).long().to(device)
            valid_length = torch.tensor([valid_length.tolist()]).long()

            out = model(token_ids, valid_length, segment_ids)
            max_vals, max_indices = torch.max(out, 1)
            label = max_indices.data.cpu().numpy()
            return 0 if label == 0 else 1
    return -1

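# Example call (a sketch, inferred from the commented-out loop at the bottom of main();
# dataset.infer appears to take [sentence, label] pairs, with a dummy label at inference time):
#   pred = bert_inference(model, [["재밌게 봤어요", '-1']])
#   # pred is the predicted class index (0 or 1), or -1 if dataset.infer yields nothing
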
def calc_accuracy(X, Y):
    max_vals, max_indices = torch.max(X, 1)
    if args.do_test:
        # Dump the predicted labels so they can be paired with the test set later.
        max_list = max_indices.data.cpu().numpy().tolist()
        with open('chat_Q_label_0325.txt', 'a', encoding='utf-8') as f:
            wr = csv.writer(f, delimiter='\t')
            for pred in max_list:
                # writerow expects a sequence; a bare string would be split per character.
                wr.writerow([str(pred)])
    train_acc = (max_indices == Y).sum().data.cpu().numpy() / max_indices.size()[0]
    return train_acc

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Argparse init
parser = argparse.ArgumentParser()
parser.add_argument('--max_len', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--warmup_ratio', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=5)
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--learning_rate', type=float, default=5e-5)
parser.add_argument('--num_workers', type=int, default=1)
# type=bool is a trap: bool('False') is True, so any value passed on the command
# line would enable the flag. store_true keeps the same defaults with correct parsing.
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--train', type=bool, default=True)  # appears unused below
args = parser.parse_args()

def main():
    from Bert_model import BERTClassifier
    model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    train_dataloader, test_dataloader = dataset.get_loader(args)

    # Prepare optimizer and schedule (linear warmup and decay);
    # bias and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    t_total = len(train_dataloader) * args.num_epochs
    warmup_step = int(t_total * args.warmup_ratio)

    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)
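    # A minimal sketch of the modern replacement for WarmupLinearSchedule in
    # current transformers releases (uncomment scheduler.step() in train() as well):
    #
    #   from transformers import get_linear_schedule_with_warmup
    #   scheduler = get_linear_schedule_with_warmup(
    #       optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)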
    best_valid_loss = float('inf')

    if args.do_train:
        for epoch in range(args.num_epochs):
            start_time = time.time()

            print("\n\t-----Train-----")
            train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
            valid_loss, valid_acc = test(model, test_dataloader, loss_fn)

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Checkpoint the model whenever validation loss improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), 'bert_SA-model.pt')

            print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
            print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
            print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc * 100:.2f}%')

    model.load_state_dict(torch.load('bert_SA-model.pt', map_location=device))

    if args.do_test:
        test_loss, test_acc = test(model, test_dataloader, loss_fn)
        print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc * 100:.2f}%')

    # Interactive inference loop (kept for reference):
    # while True:
    #     se = input("input : ")
    #     se_list = [se, '-1']
    #     bert_inference(model, [se_list])

if __name__ == "__main__":
    main()
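
# Example invocation (a sketch; flags as defined in the argparse block above):
#   python SA_main.py --do_train --do_test --num_epochs 5 --batch_size 64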