# Chatbot_main.py

import time
import torch
import argparse
from torch import nn
from metric import acc, train_test
from Styling import styling, make_special_token
from get_data import data_preprocessing, tokenizer1
from generation import inference

SEED = 1234

# argparse definitions
def str2bool(v):
    # argparse's type=bool treats every non-empty string (even 'False') as
    # True, so boolean flags are parsed explicitly instead.
    return str(v).lower() in ('true', 't', '1', 'yes', 'y')

parser = argparse.ArgumentParser()
parser.add_argument('--max_len', type=int, default=40)  # max_len must be large enough, or errors occur.
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=22)
parser.add_argument('--warming_up_epochs', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--embedding_dim', type=int, default=160)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--nhead', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--train', type=str2bool, default=True)
parser.add_argument('--per_soft', type=str2bool, default=False)
parser.add_argument('--per_rough', type=str2bool, default=False)
args = parser.parse_args()
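
# Example invocation (illustrative values; the boolean flags accept
# true/false strings via str2bool above):
#   python Chatbot_main.py --num_epochs 22 --train true --per_soft true
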
# elapsed-time helper: convert a span of seconds into (minutes, seconds)
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
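
# e.g. epoch_time(t0, t0 + 125.0) returns (2, 5): 125 seconds -> 2m 5s
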
# training: one epoch over the training iterator
def train(model, iterator, optimizer, criterion):
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA

        # shift the decoder sequence left by one token: the model is trained
        # to predict each next token from the previous ones
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

        # apply emotion and speech style
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # indices of the non-padding positions (<pad> == 1)
        real_value_index = [dec_output != 1]

        # padding is excluded from the loss
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        # accumulate plain scalars so the graph attached to `loss` is freed
        total_loss += loss.item()
        iter_num += 1
        tr_acc += train_acc.item()

        train_test(step, y_pred, dec_output, real_value_index, enc_input,
                   args, TEXT, LABEL)

    return total_loss / iter_num, tr_acc / iter_num
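
# Masking sketch (illustrative values only): with <pad> == 1, a flattened
# target tensor([5, 9, 1, 1]) yields the mask tensor([True, True, False, False]),
# so only the first two logit rows and their labels reach the criterion:
#   mask = dec_output != 1
#   loss = criterion(y_pred[mask], dec_output[mask])
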
# evaluation: one pass over the test iterator, no gradient updates
def test(model, iterator, criterion):
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # apply emotion and speech style
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = [dec_output != 1]  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            test_acc = acc(y_pred, dec_output)
            total_loss += loss.item()
            iter_num += 1
            te_acc += test_acc.item()

    return total_loss / iter_num, te_acc / iter_num

def main(TEXT, LABEL, train_loader, test_loader):

    # for sentiment analysis: load the pretrained KoBERT classifier from a .pt file
    from KoBERT.Bert_model import BERTClassifier
    from kobert.pytorch_kobert import get_pytorch_kobert_model
    bertmodel, vocab = get_pytorch_kobert_model()
    sa_model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    sa_model.load_state_dict(torch.load('bert_SA-model.pt', map_location=device))
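
    # 'bert_SA-model.pt' is assumed to be a sentiment classifier trained ahead
    # of time; it is not updated here and is only handed to inference() for
    # sentiment analysis of the user's input.
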
122 + # print argparse
123 + for idx, (key, value) in enumerate(args.__dict__.items()):
124 + if idx == 0:
125 + print("\nargparse{\n", "\t", key, ":", value)
126 + elif idx == len(args.__dict__)-1:
127 + print("\t", key, ":", value, "\n}")
128 + else:
129 + print("\t", key, ":", value)
130 +
    from model import Transformer, GradualWarmupScheduler

    # Transformer model init
    model = Transformer(args, TEXT, LABEL)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'

    # exclude <pad> when computing the loss
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=args.num_epochs)
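
    # GradualWarmupScheduler (defined in model.py) is assumed to follow the
    # usual warmup pattern: the learning rate ramps from args.lr toward
    # multiplier * args.lr over total_epoch epochs.
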
    # load pretrained embedding vectors into the model
    model.src_embedding.weight.data.copy_(TEXT.vocab.vectors)
    model.trg_embedding.weight.data.copy_(LABEL.vocab.vectors)
    model.to(device)
    criterion.to(device)

    # track the best validation loss to avoid keeping an overfit model
    best_valid_loss = float('inf')

    # training
    if args.train:
        for epoch in range(args.num_epochs):
            torch.manual_seed(SEED)
            scheduler.step(epoch)
            start_time = time.time()

            # train and validate
            train_loss, train_acc = train(model, train_loader, optimizer, criterion)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            # measure epoch time
            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # save a checkpoint whenever the validation loss improves on the best so far
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # print loss and accuracy
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    # inference
    print("\t---------- Evaluation ----------")
    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t--------------------------------")

    # interactive chat loop; runs until the process is interrupted
    while True:
        inference(device, args, TEXT, LABEL, model, sa_model)
        print("\n")

if __name__ == '__main__':
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # TEXT is the Field for the user's utterance, LABEL the Field for the chatbot's reply.
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, train_loader, test_loader)
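
# Resuming from a saved checkpoint (a sketch; assumes the same args and model
# configuration that produced the file):
#   checkpoint = torch.load('sorted_model-soft.pth', map_location=device)
#   model.load_state_dict(checkpoint['model_state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#   start_epoch = checkpoint['epoch'] + 1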