metric.py 1.15 KB
import torch

# acc 출력
def acc(yhat, y):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1] # [0]: max value, [1]: index of max value
        acc = (yhat == y).float()[y != 1].mean() # padding은 acc에서 제거
    return acc

# 학습시 모델에 넣는 입력과 모델의 예측 출력.
def train_test(step, y_pred, dec_output, real_value_index, enc_input, args, TEXT, LABEL):

    if 0 <= step < 3:
        _, ix = y_pred[real_value_index].data.topk(1)
        train_Q = enc_input[0]
        print("<<Q>> :", end=" ")
        for i in train_Q:
            if TEXT.vocab.itos[i] == "<pad>":
                break
            print(TEXT.vocab.itos[i], end=" ")

        print("\n<<trg A>> :", end=" ")
        for jj, jx in enumerate(dec_output[real_value_index]):
            if LABEL.vocab.itos[jx] == "<eos>":
                break
            print(LABEL.vocab.itos[jx], end=" ")

        print("\n<<pred A>> :", end=" ")
        for jj, ix in enumerate(ix):
            if jj == args.max_len:
                break
            if LABEL.vocab.itos[ix] == '<eos>':
                break
            print(LABEL.vocab.itos[ix], end=" ")
        print("\n")