Showing 1 changed file with 185 additions and 0 deletions
KoBERT/Sentiment_Analysis_BERT_main.py
0 → 100644
import argparse
import csv
import random
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AdamW  # deprecated in newer transformers releases; torch.optim.AdamW is a drop-in alternative
# from transformers.optimization import WarmupLinearSchedule  # removed upstream; see the scheduler sketch in main()

from kobert.pytorch_kobert import get_pytorch_kobert_model

bertmodel, vocab = get_pytorch_kobert_model()

import KoBERT.dataset_ as dataset
# Vocab-inspection debug snippets:
# print(vocab.to_tokens(517))
# print(vocab.to_tokens(5515))
# print(vocab.to_tokens(517))
# print(vocab.to_tokens(492))
# print("----------------------------------------------")
# print(vocab.to_tokens(3610))
# print(vocab.to_tokens(7096))
# print(vocab.to_tokens(4214))
# print(vocab.to_tokens(1770))
# print(vocab.to_tokens(517))
# print(vocab.to_tokens(46))
# print(vocab.to_tokens(4525))
# print(vocab.to_tokens(3610))
# print(vocab.to_tokens(6954))
#
# exit()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present

# Fix random seeds for reproducibility
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

def train(model, iter_loader, optimizer, loss_fn):
    """One training epoch; returns the last batch's loss and the mean batch accuracy."""
    train_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(iter_loader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length stays on the CPU; the classifier consumes it to build its attention mask
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        # scheduler.step()  # update the learning-rate schedule (see the sketch in main())
        train_acc += calc_accuracy(out, label)
    return loss.item(), train_acc / (batch_id + 1)

def test(model, iter_loader, loss_fn):
    """Evaluation pass; returns the last batch's loss and the mean batch accuracy."""
    model.eval()
    test_acc = 0.0
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(iter_loader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            # valid_length stays on the CPU, as in train()
            label = label.long().to(device)

            out = model(token_ids, valid_length, segment_ids)
            loss = loss_fn(out, label)
            test_acc += calc_accuracy(out, label)
    return loss.item(), test_acc / (batch_id + 1)

def bert_inference(model, src):
    """Classify a single input; returns 0 (negative) or 1 (positive), or -1 if src is empty."""
    model.eval()
    with torch.no_grad():
        src_data = dataset.infer(args, src)
        for token_ids, valid_length, segment_ids in src_data:
            token_ids = torch.tensor([token_ids]).long().to(device)
            segment_ids = torch.tensor([segment_ids]).long().to(device)
            valid_length = torch.tensor([valid_length.tolist()]).long()

            out = model(token_ids, valid_length, segment_ids)
            _, max_indices = torch.max(out, 1)

            label = max_indices.item()
            return 0 if label == 0 else 1
    return -1
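
# A minimal usage sketch (hypothetical input, mirroring the commented-out interactive
# loop at the bottom of main(): dataset.infer() consumes [sentence, label] pairs,
# with '-1' as a dummy label):
#   sentiment = bert_inference(model, [["재미있는 영화였어요", "-1"]])  # 1 = positive, 0 = negative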

def calc_accuracy(X, Y):
    _, max_indices = torch.max(X, 1)
    if args.do_test:
        max_list = max_indices.data.cpu().numpy().tolist()
        # 'with' ensures the file handle is closed even if a write fails
        with open('chat_Q_label_0325.txt', 'a', encoding='utf-8') as f:
            wr = csv.writer(f, delimiter='\t')
            for pred in max_list:
                wr.writerow([str(pred)])  # wrap in a list so csv writes one row per label
    train_acc = (max_indices == Y).sum().item() / max_indices.size(0)
    return train_acc
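
# Note: calc_accuracy doubles as the test-time label dump; with --do_test it appends
# one predicted label per line (tab-delimited) to chat_Q_label_0325.txt.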

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# Argparse init
parser = argparse.ArgumentParser()
parser.add_argument('--max_len', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--warmup_ratio', type=float, default=0.1)  # type=int would reject fractional values
parser.add_argument('--num_epochs', type=int, default=5)
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--learning_rate', type=float, default=5e-5)
parser.add_argument('--num_workers', type=int, default=1)
# type=bool is unreliable on the command line (bool('False') is True), so use flags
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--train', action='store_true', default=True)  # unused below; kept for compatibility
args = parser.parse_args()
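
# Example invocations (sketch, using the flags defined above):
#   python Sentiment_Analysis_BERT_main.py --do_train --do_test   # train, then evaluate
#   python Sentiment_Analysis_BERT_main.py --do_test              # evaluate a saved checkpoint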

def main():
    from Bert_model import BERTClassifier
    model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    train_dataloader, test_dataloader = dataset.get_loader(args)

    # Prepare optimizer and schedule (linear warmup and decay);
    # bias and LayerNorm weights are excluded from weight decay, as is standard for BERT
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    loss_fn = nn.CrossEntropyLoss()
137 | + | ||
138 | + t_total = len(train_dataloader) * args.num_epochs | ||
139 | + warmup_step = int(t_total * args.warmup_ratio) | ||
140 | + | ||
141 | + # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total) | ||
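    # Sketch (assumes transformers >= 3.x): WarmupLinearSchedule was removed upstream;
    # get_linear_schedule_with_warmup provides the same linear warmup-and-decay behaviour.
    # To re-enable warmup, uncomment this and the scheduler.step() call in train():
    # from transformers import get_linear_schedule_with_warmup
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)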
    best_valid_loss = float('inf')

    # Dump the parsed arguments (debug):
    # for idx, (key, value) in enumerate(args.__dict__.items()):
    #     if idx == 0:
    #         print("\nargparse{\n", "\t", key, ":", value)
    #     elif idx == len(args.__dict__) - 1:
    #         print("\t", key, ":", value, "\n}")
    #     else:
    #         print("\t", key, ":", value)

    if args.do_train:
        for epoch in range(args.num_epochs):
            start_time = time.time()

            print("\n\t-----Train-----")
            train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
            valid_loss, valid_acc = test(model, test_dataloader, loss_fn)

            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Keep the checkpoint with the lowest validation loss
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), 'bert_SA-model.pt')

            print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
            print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
            print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc * 100:.2f}%')

    model.load_state_dict(torch.load('bert_SA-model.pt', map_location=device))

    if args.do_test:
        test_loss, test_acc = test(model, test_dataloader, loss_fn)
        print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc * 100:.2f}%')

    # Interactive inference loop (debug):
    # while True:
    #     se = input("input : ")
    #     se_list = [se, '-1']
    #     bert_inference(model, [se_list])


if __name__ == "__main__":
    main()