bongminkim

bert_SA_datasetpy

import torch
from torch.utils.data import Dataset
import gluonnlp as nlp
import numpy as np
from kobert.utils import get_tokenizer
from KoBERT.Sentiment_Analysis_BERT_main import bertmodel, vocab
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
class BERTDataset(Dataset):
def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
pad, pair):
transform = nlp.data.BERTSentenceTransform(
bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
self.sentences = [transform([i[sent_idx]]) for i in dataset]
self.labels = [np.int32(i[label_idx]) for i in dataset]
def __getitem__(self, i):
return (self.sentences[i] + (self.labels[i], ))
def __len__(self):
return (len(self.labels))
class infer_BERTDataset(Dataset):
def __init__(self, dataset, sent_idx, bert_tokenizer, max_len,
pad, pair):
transform = nlp.data.BERTSentenceTransform(
bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
self.sentences = [transform([i[sent_idx]]) for i in dataset]
def __getitem__(self, i):
return (self.sentences[i])
def get_loader(args):
dataset_train = nlp.data.TSVDataset("ratings_train.txt", field_indices=[1, 2], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("ratings_test.txt", field_indices=[1, 2], num_discard_samples=1)
#chatbot_0325_label_0.txt
data_train = BERTDataset(dataset_train, 0, 1, tok, args.max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, args.max_len, True, False)
train_dataloader = torch.utils.data.DataLoader(
data_train, batch_size=args.batch_size, drop_last=True, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(
data_test, batch_size=args.batch_size, drop_last=False, shuffle=False)
return train_dataloader, test_dataloader
def infer(args, src):
SRC_data = infer_BERTDataset(src, 0, tok, args.max_len, True, False)
return SRC_data
# import csv
# num=0
# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8')
# rdr = csv.reader(f, delimiter='\t')
# for idx, lin in enumerate(rdr):
# num+=1
# print(num)