Showing
1 changed file
with
185 additions
and
0 deletions
KoBERT/Sentiment_Analysis_BERT_main.py
0 → 100644
| 1 | +import argparse | ||
| 2 | +import torch | ||
| 3 | +from torch import nn | ||
| 4 | +from tqdm import tqdm | ||
| 5 | +from transformers import AdamW | ||
| 6 | +#from transformers.optimization import WarmupLinearSchedule | ||
| 7 | +import time | ||
| 8 | +import random | ||
| 9 | +import numpy as np | ||
| 10 | +from kobert.pytorch_kobert import get_pytorch_kobert_model | ||
| 11 | +bertmodel, vocab = get_pytorch_kobert_model() | ||
| 12 | +import KoBERT.dataset_ as dataset | ||
| 13 | +# print(vocab.to_tokens(517)) | ||
| 14 | +# print(vocab.to_tokens(5515)) | ||
| 15 | +# print(vocab.to_tokens(517)) | ||
| 16 | +# print(vocab.to_tokens(492)) | ||
| 17 | +# print("----------------------------------------------") | ||
| 18 | +# print(vocab.to_tokens(3610)) | ||
| 19 | +# print(vocab.to_tokens(7096)) | ||
| 20 | +# print(vocab.to_tokens(4214)) | ||
| 21 | +# print(vocab.to_tokens(1770)) | ||
| 22 | +# print(vocab.to_tokens(517)) | ||
| 23 | +# print(vocab.to_tokens(46)) | ||
| 24 | +# print(vocab.to_tokens(4525)) | ||
| 25 | +# print(vocab.to_tokens(3610)) | ||
| 26 | +# print(vocab.to_tokens(6954)) | ||
| 27 | +# | ||
| 28 | +# exit() | ||
| 29 | + | ||
# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs on machines without a GPU instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every random number generator used by the script so runs are repeatable.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and switch off its autotuner, which could
# otherwise pick different algorithms from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
| 36 | + | ||
def train(model, iter_loader, optimizer, loss_fn):
    """Run one training epoch over ``iter_loader``.

    Relies on the module-level ``device``, ``args`` (for ``max_grad_norm``)
    and ``calc_accuracy``.

    Args:
        model: Called as ``model(token_ids, valid_length, segment_ids)``.
        iter_loader: Iterable yielding
            ``(token_ids, valid_length, segment_ids, label)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Called as ``loss_fn(out, label)``.

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over the batches of the
        epoch. (Reporting only the last batch's loss, as before, gave a noisy
        and unrepresentative number.)

    Raises:
        ValueError: If ``iter_loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    num_batches = 0
    for token_ids, valid_length, segment_ids, label in tqdm(iter_loader):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        # valid_length is passed through unchanged (it is not moved to `device`).
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        # scheduler.step()  # Update learning rate schedule (scheduler is disabled)
        loss_sum += loss.item()
        acc_sum += calc_accuracy(out, label)
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train() received an empty data loader")
    return loss_sum / num_batches, acc_sum / num_batches
| 54 | + | ||
def test(model, iter_loader, loss_fn):
    """Evaluate ``model`` on ``iter_loader`` without tracking gradients.

    Relies on the module-level ``device`` and ``calc_accuracy``.

    Args:
        model: Called as ``model(token_ids, valid_length, segment_ids)``.
        iter_loader: Iterable yielding
            ``(token_ids, valid_length, segment_ids, label)`` batches.
        loss_fn: Called as ``loss_fn(out, label)``.

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over all batches. The
        loss is used by the caller to pick the best checkpoint, so it must
        reflect the whole set rather than just the final batch.

    Raises:
        ValueError: If ``iter_loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, label in iter_loader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)

            # valid_length is passed through unchanged (it is not moved to `device`).
            out = model(token_ids, valid_length, segment_ids)
            loss_sum += loss_fn(out, label).item()
            acc_sum += calc_accuracy(out, label)
            num_batches += 1
    if num_batches == 0:
        raise ValueError("test() received an empty data loader")
    return loss_sum / num_batches, acc_sum / num_batches
| 69 | + | ||
def bert_inference(model, src):
    """Classify ``src`` with ``model`` and return 0, 1 or -1.

    ``src`` is handed to ``dataset.infer`` together with the module-level
    ``args``. Only the first item it produces is scored: the result is 0 when
    the highest-scoring class index is 0 and 1 otherwise. If nothing is
    produced, -1 is returned.
    """
    model.eval()
    with torch.no_grad():
        for token_ids, valid_length, segment_ids in dataset.infer(args, src):
            # Wrapping each field in a list adds a leading dimension of size 1.
            batch_tokens = torch.tensor([token_ids]).long().to(device)
            batch_segments = torch.tensor([segment_ids]).long().to(device)
            batch_lengths = torch.tensor([valid_length.tolist()]).long()

            scores = model(batch_tokens, batch_lengths, batch_segments)

            _, best_index = torch.max(scores, 1)
            predicted = best_index.data.cpu().numpy()
            # Return straight away: only the first item is ever evaluated.
            return 0 if predicted == 0 else 1
    return -1
| 90 | + | ||
| 91 | +import csv | ||
def calc_accuracy(X, Y):
    """Return the fraction of rows of ``X`` whose arg-max equals ``Y``.

    When the module-level ``args.do_test`` is truthy, the predicted class
    indices are also appended, one per line, to ``chat_Q_label_0325.txt``.

    Args:
        X: Score tensor; the prediction is the index of the maximum along
            dimension 1.
        Y: Tensor of target class indices compared element-wise with the
            predictions.

    Returns:
        Number of matching predictions divided by the number of predictions.
    """
    _, max_indices = torch.max(X, 1)
    if args.do_test:
        predictions = max_indices.data.cpu().numpy().tolist()
        # `with` closes the file even on error; newline='' is what the csv
        # module expects so that it controls line endings itself.
        with open('chat_Q_label_0325.txt', 'a', encoding='utf-8', newline='') as f:
            wr = csv.writer(f, delimiter='\t')
            # Each row is a one-element list. Passing a bare string would make
            # csv split it into one column per character (e.g. "10" -> "1\t0").
            wr.writerows([str(pred)] for pred in predictions)
    return (max_indices == Y).sum().data.cpu().numpy() / max_indices.size()[0]
| 102 | + | ||
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Both arguments are numbers of seconds (for example from ``time.time()``).
    Returns ``(minutes, seconds)`` as integers; fractional parts are truncated.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds
| 108 | + | ||
# Argparse init
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--max_len', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=64)
# A ratio and a norm are fractional quantities; type=int rejected e.g. "0.2".
parser.add_argument('--warmup_ratio', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=5)
parser.add_argument('--max_grad_norm', type=float, default=1.0)
parser.add_argument('--learning_rate', type=float, default=5e-5)
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--do_train', type=_str2bool, default=False)
parser.add_argument('--do_test', type=_str2bool, default=False)
parser.add_argument('--train', type=_str2bool, default=True)
# Parsed at import time because the functions in this module read `args` as a global.
args = parser.parse_args()
| 122 | + | ||
def main():
    """Build the classifier, optionally fine-tune it, then optionally evaluate it.

    Driven by the module-level ``args``:

    * ``args.do_train``: train for ``args.num_epochs`` epochs and save the
      weights to ``bert_SA-model.pt`` whenever the validation loss improves.
    * ``args.do_test``: evaluate on the test loader and print loss/accuracy.

    The checkpoint ``bert_SA-model.pt`` is loaded unconditionally after the
    training stage, so it must exist even when ``args.do_train`` is off.
    """
    from Bert_model import BERTClassifier

    model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    train_dataloader, test_dataloader = dataset.get_loader(args)

    # Prepare optimizer: parameters whose name contains one of these markers
    # get no weight decay; all others use a decay of 0.01.
    no_decay = ['bias', 'LayerNorm.weight']
    decayed_params = []
    undecayed_params = []
    for param_name, param in model.named_parameters():
        if any(marker in param_name for marker in no_decay):
            undecayed_params.append(param)
        else:
            decayed_params.append(param)
    optimizer_grouped_parameters = [
        {'params': decayed_params, 'weight_decay': 0.01},
        {'params': undecayed_params, 'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    # Step counts for the linear-warmup scheduler, which is currently disabled:
    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)
    t_total = len(train_dataloader) * args.num_epochs
    warmup_step = int(t_total * args.warmup_ratio)

    best_valid_loss = float('inf')

    if args.do_train:
        for epoch in range(args.num_epochs):
            start_time = time.time()

            print("\n\t-----Train-----")
            train_loss, train_acc = train(model, train_dataloader, optimizer, loss_fn)
            valid_loss, valid_acc = test(model, test_dataloader, loss_fn)

            epoch_mins, epoch_secs = epoch_time(start_time, time.time())

            # Keep only the weights with the lowest validation loss so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), 'bert_SA-model.pt')

            print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
            print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
            print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc * 100:.2f}%')

    # Restore the saved weights before any evaluation.
    model.load_state_dict(torch.load('bert_SA-model.pt'))

    if args.do_test:
        test_loss, test_acc = test(model, test_dataloader, loss_fn)
        print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc * 100:.2f}%')

    # Interactive inference, kept for reference:
    # while(1):
    #     se = input("input : ")
    #     se_list = [se, '-1']
    #     bert_inference(model, [se_list])


if __name__ == "__main__":
    main()
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment