Showing 1 changed file with 202 additions and 0 deletions
Chatbot/Chatbot_main.py
new file mode 100644
import time
import torch
import argparse
from torch import nn
from metric import acc, train_test
from Styling import styling, make_special_token
from get_data import data_preprocessing, tokenizer1
from generation import inference

SEED = 1234

# command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough, or errors occur.
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=22)
parser.add_argument('--warming_up_epochs', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--embedding_dim', type=int, default=160)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--nhead', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.1)
# note: argparse's type=bool treats any non-empty string as True, so these three
# flags can only be relied on at their defaults.
parser.add_argument('--train', type=bool, default=True)
parser.add_argument('--per_soft', type=bool, default=False)
parser.add_argument('--per_rough', type=bool, default=False)
args = parser.parse_args()
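
# A minimal sketch of how the script might be launched (flag values are
# illustrative, not taken from the repository's docs):
#   python Chatbot/Chatbot_main.py --num_epochs 22 --batch_size 256 --per_soft True
# Because of the type=bool caveat above, passing --per_soft False would still
# evaluate to True; omit the flag to keep its default.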

# elapsed-time helper: converts a start/end timestamp pair into minutes and seconds
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

# training loop for one epoch
def train(model, iterator, optimizer, criterion):
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA

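        # Teacher forcing: the decoder is fed the target sequence itself, so the
        # expected output is that sequence shifted left by one token (dropping
        # what is presumably the start-of-sequence token). dec_outputs is a fresh
        # max_len-wide buffer that styling() fills with the style-adjusted targets.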
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

        # apply emotion and speech style to the inputs/targets
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # boolean mask of the non-padding positions
        real_value_index = dec_output != 1  # <pad> == 1

        # padding is excluded from the loss
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        total_loss += loss.detach()  # detach so the graph is not kept across iterations
        iter_num += 1
        tr_acc += train_acc

        train_test(step, y_pred, dec_output, real_value_index, enc_input,
                   args, TEXT, LABEL)

    return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num

# evaluation
def test(model, iterator, criterion):
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # apply emotion and speech style to the inputs/targets
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = dec_output != 1  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            # the enclosing no_grad() already disables gradients here
            test_acc = acc(y_pred, dec_output)
            total_loss += loss
            iter_num += 1
            te_acc += test_acc

    return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num

def main(TEXT, LABEL, train_loader, test_loader):

    # for sentiment analysis: load the fine-tuned KoBERT model from its .pt checkpoint
    from KoBERT.Bert_model import BERTClassifier
    from kobert.pytorch_kobert import get_pytorch_kobert_model
    bertmodel, vocab = get_pytorch_kobert_model()
    sa_model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    sa_model.load_state_dict(torch.load('bert_SA-model.pt', map_location=device))

    # print the argparse settings
    for idx, (key, value) in enumerate(args.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(args.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    from model import Transformer, GradualWarmupScheduler

    # Transformer model init
    model = Transformer(args, TEXT, LABEL)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'

    # exclude <pad> when computing the loss
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
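    # Note: ignore_index already masks <pad> inside the criterion, so the manual
    # real_value_index filtering in train()/test() is belt-and-braces rather than
    # strictly required (assuming LABEL.vocab.stoi['<pad>'] == 1, as those masks imply).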

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=args.num_epochs)
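    # Assuming the common GradualWarmupScheduler behaviour (as in ildoonet's
    # pytorch-gradual-warmup-lr), the learning rate ramps from args.lr up to
    # multiplier * args.lr over total_epoch epochs; the exact semantics depend
    # on the implementation in model.py.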

    # copy the pre-trained embedding vectors into the model
    model.src_embedding.weight.data.copy_(TEXT.vocab.vectors)
    model.trg_embedding.weight.data.copy_(LABEL.vocab.vectors)
    model.to(device)
    criterion.to(device)

    # track the best validation loss to guard against overfitting
    best_valid_loss = float('inf')

    # train
    if args.train:
        for epoch in range(args.num_epochs):
            torch.manual_seed(SEED)  # re-seed every epoch for reproducibility
            scheduler.step(epoch)
            start_time = time.time()

            # train, validation
            train_loss, train_acc = train(model, train_loader, optimizer, criterion)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            # measure epoch time
            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # torch.save(model.state_dict(), sorted_path)  # unconditional save, left disabled to avoid keeping an overfit model
            # save the model only when the current validation loss beats the best so far
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # print loss and acc
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')


    # inference
    print("\t----------Evaluation----------")
    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        inference(device, args, TEXT, LABEL, model, sa_model)
        print("\n")

if __name__ == '__main__':
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # TEXT is the Field for user utterances; LABEL is the Field for chatbot replies.
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, train_loader, test_loader)