bongminkim

update_pyfile

......@@ -15,13 +15,13 @@ parser.add_argument('--max_len', type=int, default=40) # max_len í¬ê²Œ 해야 ì
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=22)
parser.add_argument('--warming_up_epochs', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--lr', type=float, default=0.0002)#0.0002
parser.add_argument('--embedding_dim', type=int, default=160)
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--nhead', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--train', type=bool, default=True)
parser.add_argument('--per_soft', type=bool, default=False)
parser.add_argument('--per_soft', type=bool, default=True)
parser.add_argument('--per_rough', type=bool, default=False)
args = parser.parse_args()
......@@ -30,6 +30,7 @@ def epoch_time(start_time, end_time):
elapsed_time = end_time - start_time
elapsed_mins = int(elapsed_time / 60)
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
return elapsed_mins, elapsed_secs
# 학습
......
import torch
import torch.nn.functional as F
from math import log
from numpy import array
from get_data import tokenizer1
from torch.autograd import Variable
from chatspace import ChatSpace
spacer = ChatSpace()
from konlpy.tag import Mecab
import re
def tokenizer1(text):
    """Tokenize *text* into Korean morphemes.

    Strips punctuation/special symbols, then runs the Mecab morphological
    analyzer on the cleaned string.

    Args:
        text: raw input sentence (str).

    Returns:
        list[str]: Mecab morphemes of the cleaned sentence.
    """
    # NOTE(review): this redefinition shadows the `tokenizer1` imported from
    # get_data above — confirm the shadowing is intentional.
    # Remove punctuation / special symbols before morpheme analysis.
    result_text = re.sub('[-=+.,#/\:$@*\"※&%ㆍ!?』\\‘|\(\)\[\]\<\>`\'…》;]', '', text)
    # NOTE(review): Mecab() is constructed on every call; hoisting a single
    # module-level instance would avoid reloading its dictionary each time.
    # morphs() already returns a list — the old element-by-element copy
    # ([a[i] for i in range(len(a))]) was redundant.
    return Mecab().morphs(result_text)
def _get_length_penalty(text, alpha=1.2, min_length=5):
    """Length-normalize beam/greedy scores with a GNMT-style length penalty.

    lp(y) = ((min_length + |y|) ** alpha) / ((min_length + 1) ** alpha)

    where |y| is the number of Mecab morphemes in the sentence.

    Args:
        text: list of [sentence(str), score(float)] pairs.
        alpha: penalty strength.
        min_length: additive constant of the penalty. Previously this
            parameter existed but the value 5 was hard-coded in the formula;
            it is now honored (default 5 keeps the old behavior).

    Returns:
        list[float]: score / lp(sentence) for each input pair.
    """
    denom = (min_length + 1) ** alpha
    penalties = [
        ((min_length + len(tokenizer1(pair[0]))) ** alpha) / denom
        for pair in text
    ]
    # Divide each cumulative score by its sentence's length penalty.
    return [pair[1] / p for pair, p in zip(text, penalties)]
def compair_beam_and_greedy(beam_pair, greedy_pair):
    """Print the best beam hypothesis next to the greedy one.

    Both candidate lists are length-penalized, the lowest-scoring beam
    sentence is selected, and a one-line verdict ("use beam"/"use greedy")
    is printed. Nothing is returned; this is a diagnostic printer.

    Args:
        beam_pair: list of [sentence, score] pairs from beam search.
        greedy_pair: single-element list with the greedy [sentence, score].
    """
    beam_scores = _get_length_penalty(beam_pair)
    greedy_scores = _get_length_penalty(greedy_pair)
    # Track the first minimum, matching strict '<' comparison semantics.
    low_val = float('inf')
    checked_sen = ""
    for pair, penalized in zip(beam_pair, beam_scores):
        if penalized < low_val:
            low_val = penalized
            checked_sen = pair[0]
    print(" beam output > ", checked_sen, " |", low_val)
    print("greedy output > ", greedy_pair[0][0]," |", greedy_scores[0])
    # Lower (length-penalized) score wins.
    print("use beam" if low_val < greedy_scores[0] else "use greedy")
def cal_score(pred, score):
    """Fold one greedy decoding step into the running score.

    Softmaxes the step's logits, takes the probability of the most likely
    token, and multiplies the running score by its negative log.

    Args:
        pred: 1-D tensor of vocabulary logits for one decoder step.
        score: running greedy score (callers pass 1.0 on the first step).

    Returns:
        float: updated cumulative score.
    """
    probs = F.softmax(pred, dim=-1)
    best_prob = probs.max(dim=-1)[0].to('cpu').tolist()
    return score * -log(best_prob)
def Beam_Search(data, k, first, sequences):
    """Expand the given hypotheses over one or more distribution rows.

    Every row of *data* is softmaxed; each existing hypothesis is extended
    by every vocabulary index, scored as score * -log(p), and only the k
    lowest-scoring candidates survive to the next row.

    Args:
        data: logits — with first=True a (1, seq, vocab) tensor that is
            squeezed to (seq, vocab); otherwise a 1-D vocab row that is
            unsqueezed to a single row.
        k: beam width.
        first: whether this is the seeding call (see *data* above).
        sequences: list of [token_id_list, score] hypotheses to expand.

    Returns:
        list: the k best [token_id_list, score] hypotheses, ascending score.
    """
    data = data.squeeze(0) if first else data.unsqueeze(0)
    data = F.softmax(data, dim=-1)
    for row in data:
        candidates = []
        for seq, score in sequences:
            for idx in range(len(row)):
                prob = row[idx].to('cpu').tolist()
                candidates.append([seq + [idx], score * -log(prob)])
        # list.sort is stable, like sorted(), so ties resolve identically.
        candidates.sort(key=lambda cand: cand[1])
        sequences = candidates[:k]
    return sequences
def beam(args, dec_input, enc_input_index, model, first, device, k_, LABEL):
    """Beam-search decode one input sentence into [sentence, score] pairs.

    Runs the decoder step by step keeping ``k_`` parallel rows, and returns
    every row that finished (emitted token id 3 — presumably <eos>, confirm
    against LABEL.vocab) as a ``[spaced_sentence, beam_score]`` pair.

    Args:
        args: parsed CLI args; only ``args.max_len`` is read here.
        dec_input: initial decoder input (start token) tensor.
        enc_input_index: encoder input token indices.
        model: seq2seq model, called as model(enc_input, dec_input).
        first: seeds the beams from the first decoder step; must be True
            (see NOTE below).
        device: torch device the model runs on.
        k_: beam width.
        LABEL: field whose ``vocab.itos`` maps token ids back to strings.

    Returns:
        list: [sentence(str), score(float)] for each finished beam row.
    """
    # k_ decoder input rows, each seeded with the same start token.
    temp_dec_input = torch.zeros([k_,1], dtype=torch.long)
    temp_dec_input = temp_dec_input + dec_input
    # Column holding the latest chosen token id for each beam row.
    deliver_high_beam_value = torch.zeros([k_,1], dtype=torch.long)
    return_sentence_beamVal_pair = []
    # check_k[i]: cumulative score of row i once it emits token id 3;
    # inf while the row is still unfinished.
    check_k = [float('inf')] * k_
    sequences = [[list(), 1.0]]
    end_sentence = []
    end_idx = []
    if first:
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        # Seed the k_ beams from the first decoder step's distribution.
        first_beam_sequence = Beam_Search(y_pred, k_, True, sequences)
        for i in range(len(deliver_high_beam_value)):
            deliver_high_beam_value[i] = first_beam_sequence[i][0][0]
        temp_dec_input = torch.cat(
            [temp_dec_input.to(torch.device('cpu')),
            deliver_high_beam_value.to(torch.device('cpu'))], dim=-1)
    check_num = 0
    # NOTE(review): first_beam_sequence only exists when first=True;
    # calling with first=False raises NameError on the next line.
    beam_input_sequence = first_beam_sequence
    for i in range(args.max_len):
        which_value = [float('inf')] * k_  # NOTE(review): written, never read
        which_node = [0] * k_
        for j in range(len(temp_dec_input)):
            # Row already ended with token id 3 -> leave it frozen.
            if temp_dec_input[j][-1] == torch.LongTensor([3]):
                continue
            y_pred = model(enc_input_index.to(device), temp_dec_input[j].unsqueeze(0).to(device))
            # Expand only this row's hypothesis by one decoder step.
            beam_seq = Beam_Search(y_pred.squeeze(0)[-1], k_, False, [beam_input_sequence[j]])
            beam_input_sequence[j] = [[beam_seq[0][0][-1]], beam_seq[0][1]]
            which_node[j] = beam_seq[0][0][-1]  # pick the best-scoring of the k expansions
        for l in range(len(deliver_high_beam_value)):
            # NOTE(review): this re-tests index j (last value from the loop
            # above), not l — looks like an index bug; confirm intent.
            if temp_dec_input[j][-1] == torch.LongTensor([3]):
                continue
            deliver_high_beam_value[l] = which_node[l]
        # Append this step's chosen token to every beam row.
        temp_dec_input = torch.cat(
            [temp_dec_input.to(torch.device('cpu')),
            deliver_high_beam_value.to(torch.device('cpu'))], dim=-1)
        # Record the score of any row that has just produced token id 3.
        for x in range(len(temp_dec_input)):
            for y in range(len(temp_dec_input[x])):
                if temp_dec_input[x][y] == torch.LongTensor([3]) and check_k[x] == float('inf'):
                    check_k[x] = beam_input_sequence[x][1]
        if i+1 == args.max_len:
            # Final step: collect every row that contains token id 3.
            for k in range(k_):
                for kk in range(len(temp_dec_input[k])):
                    if temp_dec_input[k][kk] == torch.LongTensor([3]):
                        check_num += 1
                        end_sentence.append(temp_dec_input[k])
                        end_idx.append(k)
                        break
    # Convert each finished row of token ids back into a spaced sentence.
    for l in range(len(end_sentence)):
        pred = []
        for idx in range(len(end_sentence[l])):
            if end_sentence[l][idx] == torch.LongTensor([3]):
                pred_sentence = "".join(pred)
                pred_str = spacer.space(pred_sentence)
                #print(pred_str, " |", check_k[end_idx[l]])
                return_sentence_beamVal_pair.append([pred_str, check_k[end_idx[l]]])
                break
            else:
                if idx == 0:
                    continue  # skip the start token
                pred.append(LABEL.vocab.itos[end_sentence[l][idx]])
    return return_sentence_beamVal_pair
def inference(device, args, TEXT, LABEL, model, sa_model):
from KoBERT.Sentiment_Analysis_BERT_main import bert_inference
......@@ -40,23 +185,37 @@ def inference(device, args, TEXT, LABEL, model, sa_model):
model.eval()
pred = []
beam_k = 10
beam_sen_val_pair = beam(args, dec_input, enc_input_index, model, True, device, beam_k, LABEL)
greedy_pair = []
for i in range(args.max_len):
y_pred = model(enc_input_index.to(device), dec_input.to(device))
if i == 0:
score = cal_score(y_pred.squeeze(0)[-1], 1.0)
else:
score = cal_score(y_pred.squeeze(0)[-1], score )
y_pred_ids = y_pred.max(dim=-1)[1]
if (y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']):
y_pred_ids = y_pred_ids.squeeze(0)
print(">", end=" ")
#print(">", end=" ")
for idx in range(len(y_pred_ids)):
if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
pred_sentence = "".join(pred)
pred_str = spacer.space(pred_sentence)
print(pred_str)
#print(pred_str, " |", score)
greedy_pair.append([pred_str, score])
break
else:
pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
compair_beam_and_greedy(beam_sen_val_pair, greedy_pair)
return 0
dec_input = torch.cat(
[dec_input.to(torch.device('cpu')),
y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
return 0
\ No newline at end of file
......
......@@ -51,8 +51,8 @@ def data_preprocessing(args, device):
# TEXT, LABEL 에 필요한 special token 만듦.
text_specials, label_specials = make_special_token(args)
TEXT.build_vocab(train_data, vectors=vectors, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, vectors=vectors, max_size=15000, specials=label_specials)
TEXT.build_vocab(train_data,vectors=vectors, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data,vectors=vectors, max_size=15000, specials=label_specials)
train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)
......