Showing 1 changed file with 61 additions and 0 deletions.
KoBERT/dataset_.py
0 → 100644
1 | +import torch | ||
2 | +from torch.utils.data import Dataset | ||
3 | +import gluonnlp as nlp | ||
4 | +import numpy as np | ||
5 | +from kobert.utils import get_tokenizer | ||
6 | +from KoBERT.Sentiment_Analysis_BERT_main import bertmodel, vocab | ||
7 | + | ||
# Module-level tokenizer setup, built once at import time.
# `get_tokenizer()` returns the path/handle for the pretrained KoBERT
# SentencePiece model (project-local helper — behavior assumed, TODO confirm).
tokenizer = get_tokenizer()
# SentencePiece-backed BERT tokenizer over the KoBERT vocab;
# lower=False preserves case (input is Korean text, where lowercasing
# would only affect embedded Latin characters).
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
10 | + | ||
class BERTDataset(Dataset):
    """Labeled sentence dataset for BERT fine-tuning.

    Each item is the tuple produced by gluonnlp's BERTSentenceTransform
    (token ids, valid length, segment ids) with the integer label appended.

    Args:
        dataset: iterable of records (e.g. rows from a TSVDataset).
        sent_idx: index of the sentence field within each record.
        label_idx: index of the label field within each record.
        bert_tokenizer: a gluonnlp BERTSPTokenizer.
        max_len: maximum sequence length for the transform.
        pad: whether to pad sequences up to max_len.
        pair: whether inputs are sentence pairs.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # Eagerly tokenize everything up front so __getitem__ is a cheap lookup.
        self.sentences = []
        self.labels = []
        for record in dataset:
            self.sentences.append(to_features([record[sent_idx]]))
            self.labels.append(np.int32(record[label_idx]))

    def __getitem__(self, i):
        # Append the label to the transform's feature tuple.
        return self.sentences[i] + (self.labels[i],)

    def __len__(self):
        return len(self.labels)
25 | + | ||
class infer_BERTDataset(Dataset):
    """Unlabeled sentence dataset for inference.

    Mirrors BERTDataset but without a label column: each item is just the
    tuple produced by gluonnlp's BERTSentenceTransform.

    Args:
        dataset: iterable of records containing the sentences.
        sent_idx: index of the sentence field within each record.
        bert_tokenizer: a gluonnlp BERTSPTokenizer.
        max_len: maximum sequence length for the transform.
        pad: whether to pad sequences up to max_len.
        pair: whether inputs are sentence pairs.
    """

    def __init__(self, dataset, sent_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i])

    def __len__(self):
        # Fix: map-style Dataset subclasses must define __len__ for
        # DataLoader's default sampler to work; the original omitted it,
        # making this class unusable with a standard DataLoader.
        return len(self.sentences)
36 | + | ||
def get_loader(args):
    """Build train/test DataLoaders over the NSMC ratings TSV files.

    Args:
        args: namespace providing `max_len` and `batch_size`.

    Returns:
        (train_dataloader, test_dataloader) tuple; training batches are
        shuffled and incomplete final batches dropped, test batches are
        neither shuffled nor dropped.
    """
    raw_train = nlp.data.TSVDataset(
        "ratings_train.txt", field_indices=[1, 2], num_discard_samples=1)
    raw_test = nlp.data.TSVDataset(
        "ratings_test.txt", field_indices=[1, 2], num_discard_samples=1)
    # Alternative corpus left by the author: chatbot_0325_label_0.txt
    train_set = BERTDataset(raw_train, 0, 1, tok, args.max_len, True, False)
    test_set = BERTDataset(raw_test, 0, 1, tok, args.max_len, True, False)

    loader_train = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, drop_last=True, shuffle=True)
    loader_test = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, drop_last=False, shuffle=False)

    return loader_train, loader_test
50 | + | ||
def infer(args, src):
    """Wrap raw source sentences in an inference dataset (no labels).

    Args:
        args: namespace providing `max_len`.
        src: iterable of records whose field 0 is the sentence text.

    Returns:
        An infer_BERTDataset over `src`, tokenized with the module-level
        `tok` tokenizer.
    """
    return infer_BERTDataset(src, 0, tok, args.max_len, True, False)
54 | + | ||
55 | +# import csv | ||
56 | +# num=0 | ||
57 | +# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8') | ||
58 | +# rdr = csv.reader(f, delimiter='\t') | ||
59 | +# for idx, lin in enumerate(rdr): | ||
60 | +# num+=1 | ||
61 | +# print(num) |
-
Please register or login to post a comment