Showing
1 changed file
with
61 additions
and
0 deletions
KoBERT/dataset_.py
0 → 100644
| 1 | +import torch | ||
| 2 | +from torch.utils.data import Dataset | ||
| 3 | +import gluonnlp as nlp | ||
| 4 | +import numpy as np | ||
| 5 | +from kobert.utils import get_tokenizer | ||
| 6 | +from KoBERT.Sentiment_Analysis_BERT_main import bertmodel, vocab | ||
| 7 | + | ||
| 8 | +tokenizer = get_tokenizer() | ||
| 9 | +tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False) | ||
| 10 | + | ||
| 11 | +class BERTDataset(Dataset): | ||
| 12 | + def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, | ||
| 13 | + pad, pair): | ||
| 14 | + transform = nlp.data.BERTSentenceTransform( | ||
| 15 | + bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) | ||
| 16 | + | ||
| 17 | + self.sentences = [transform([i[sent_idx]]) for i in dataset] | ||
| 18 | + self.labels = [np.int32(i[label_idx]) for i in dataset] | ||
| 19 | + | ||
| 20 | + def __getitem__(self, i): | ||
| 21 | + return (self.sentences[i] + (self.labels[i], )) | ||
| 22 | + | ||
| 23 | + def __len__(self): | ||
| 24 | + return (len(self.labels)) | ||
| 25 | + | ||
| 26 | +class infer_BERTDataset(Dataset): | ||
| 27 | + def __init__(self, dataset, sent_idx, bert_tokenizer, max_len, | ||
| 28 | + pad, pair): | ||
| 29 | + transform = nlp.data.BERTSentenceTransform( | ||
| 30 | + bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair) | ||
| 31 | + | ||
| 32 | + self.sentences = [transform([i[sent_idx]]) for i in dataset] | ||
| 33 | + | ||
| 34 | + def __getitem__(self, i): | ||
| 35 | + return (self.sentences[i]) | ||
| 36 | + | ||
| 37 | +def get_loader(args): | ||
| 38 | + dataset_train = nlp.data.TSVDataset("ratings_train.txt", field_indices=[1, 2], num_discard_samples=1) | ||
| 39 | + dataset_test = nlp.data.TSVDataset("ratings_test.txt", field_indices=[1, 2], num_discard_samples=1) | ||
| 40 | + #chatbot_0325_label_0.txt | ||
| 41 | + data_train = BERTDataset(dataset_train, 0, 1, tok, args.max_len, True, False) | ||
| 42 | + data_test = BERTDataset(dataset_test, 0, 1, tok, args.max_len, True, False) | ||
| 43 | + | ||
| 44 | + train_dataloader = torch.utils.data.DataLoader( | ||
| 45 | + data_train, batch_size=args.batch_size, drop_last=True, shuffle=True) | ||
| 46 | + test_dataloader = torch.utils.data.DataLoader( | ||
| 47 | + data_test, batch_size=args.batch_size, drop_last=False, shuffle=False) | ||
| 48 | + | ||
| 49 | + return train_dataloader, test_dataloader | ||
| 50 | + | ||
| 51 | +def infer(args, src): | ||
| 52 | + SRC_data = infer_BERTDataset(src, 0, tok, args.max_len, True, False) | ||
| 53 | + return SRC_data | ||
| 54 | + | ||
| 55 | +# import csv | ||
| 56 | +# num=0 | ||
| 57 | +# f = open('chatbot_0325_label_0.txt', 'r', encoding='utf-8') | ||
| 58 | +# rdr = csv.reader(f, delimiter='\t') | ||
| 59 | +# for idx, lin in enumerate(rdr): | ||
| 60 | +# num+=1 | ||
| 61 | +# print(num) |
-
Please register or login to post a comment