bongminkim

bert_SA_dataset.py

1 +import torch
2 +from torch.utils.data import Dataset
3 +import gluonnlp as nlp
4 +import numpy as np
5 +from kobert.utils import get_tokenizer
6 +from KoBERT.Sentiment_Analysis_BERT_main import bertmodel, vocab
7 +
8 +tokenizer = get_tokenizer()
9 +tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
10 +
class BERTDataset(Dataset):
    """Map-style dataset of BERT-transformed sentences paired with int32 labels.

    Each item is the tuple produced by ``BERTSentenceTransform`` (token ids,
    valid length, segment ids) with the label appended as the last element.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        # One transform instance is enough; it is applied row by row below.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([row[sent_idx]]) for row in dataset]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        sentence = self.sentences[i]
        label = self.labels[i]
        return sentence + (label,)

    def __len__(self):
        return len(self.labels)
25 +
class infer_BERTDataset(Dataset):
    """Inference-only dataset: BERT-transformed sentences without labels.

    Each item is the tuple produced by ``BERTSentenceTransform`` (token ids,
    valid length, segment ids) for a single sentence.
    """

    def __init__(self, dataset, sent_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([row[sent_idx]]) for row in dataset]

    def __getitem__(self, i):
        return self.sentences[i]

    def __len__(self):
        # Bug fix: the original class had no __len__, so len() and any
        # DataLoader using the default (map-style) sampler would raise.
        return len(self.sentences)
36 +
def get_loader(args, train_path="ratings_train.txt", test_path="ratings_test.txt"):
    """Build train/test DataLoaders from NSMC-style TSV files.

    Args:
        args: namespace providing ``max_len`` and ``batch_size``.
        train_path: training TSV path (default preserves original behavior).
        test_path: test TSV path (default preserves original behavior).

    Returns:
        Tuple of (train_dataloader, test_dataloader).
    """
    # field_indices=[1, 2]: column 1 is the sentence, column 2 the label;
    # num_discard_samples=1 skips the header row.
    dataset_train = nlp.data.TSVDataset(train_path, field_indices=[1, 2], num_discard_samples=1)
    dataset_test = nlp.data.TSVDataset(test_path, field_indices=[1, 2], num_discard_samples=1)

    data_train = BERTDataset(dataset_train, 0, 1, tok, args.max_len, True, False)
    data_test = BERTDataset(dataset_test, 0, 1, tok, args.max_len, True, False)

    # Shuffle and drop the ragged last batch only for training, so that
    # evaluation sees every test sample exactly once, in order.
    train_dataloader = torch.utils.data.DataLoader(
        data_train, batch_size=args.batch_size, drop_last=True, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(
        data_test, batch_size=args.batch_size, drop_last=False, shuffle=False)

    return train_dataloader, test_dataloader
50 +
def infer(args, src):
    """Wrap raw source rows in an inference dataset (sentence at index 0, no labels)."""
    return infer_BERTDataset(src, 0, tok, args.max_len, True, False)
54 +
55 +# import csv
56 +# num=0
57 +# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8')
58 +# rdr = csv.reader(f, delimiter='\t')
59 +# for idx, lin in enumerate(rdr):
60 +# num+=1
61 +# print(num)