generation.py
import torch
from get_data import tokenizer1
from torch.autograd import Variable  # legacy API; kept for compatibility with the original code
from chatspace import ChatSpace

spacer = ChatSpace()


def inference(device, args, TEXT, LABEL, model, sa_model):
    from KoBERT.Sentiment_Analysis_BERT_main import bert_inference

    sentence = input("문장을 입력하세요 : ")  # prompt: "Enter a sentence"
    se_list = [sentence]

    # https://github.com/SKTBrain/KoBERT
    # Use SKT's published KoBERT sentiment-analysis model to classify the input
    # sentence as positive or negative.
    sa_label = int(bert_inference(sa_model, se_list))

    # Choose the sentiment token added to the encoder input, depending on the SA label.
    if sa_label == 0:
        sa_token = TEXT.vocab.stoi['<nega>']
    else:
        sa_token = TEXT.vocab.stoi['<posi>']

    # Convert the encoder input from tokens to vocabulary indices.
    enc_input = tokenizer1(sentence)
    enc_input_index = []
    for tok in enc_input:
        enc_input_index.append(TEXT.vocab.stoi[tok])

    # Append the sentiment token (soft-persona mode only), then pad up to max_len.
    if args.per_soft:
        enc_input_index.append(sa_token)
    for _ in range(args.max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])
    enc_input_index = Variable(torch.LongTensor([enc_input_index]))

    # The decoder starts from <sos> and is extended one token at a time (greedy decoding).
    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    # print("긍정" if sa_label == 1 else "부정")  # "positive" / "negative"
    model.eval()
    pred = []
    for _ in range(args.max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]

        # Once the model emits <eos>, detokenize everything before it and print the reply.
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    pred_str = spacer.space(pred_sentence)  # restore word spacing with ChatSpace
                    print(pred_str)
                    break
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 0

        # Otherwise append the newest predicted token to the decoder input and keep decoding.
        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 0
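

# Usage sketch (not part of the original file): one way this function might be
# invoked from a driver script. `load_fields_and_models` and the argparse flags
# below are hypothetical placeholders; the real project builds TEXT, LABEL, the
# transformer `model`, and the KoBERT `sa_model` in its own training/entry
# scripts, so this block is left commented out.
#
# if __name__ == "__main__":
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--max_len", type=int, default=40)
#     parser.add_argument("--per_soft", action="store_true")
#     args = parser.parse_args()
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     TEXT, LABEL, model, sa_model = load_fields_and_models(args, device)  # hypothetical helper
#     while True:
#         inference(device, args, TEXT, LABEL, model, sa_model)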