김민수

Upload weekly report & code

1 +# -*- coding: utf-8 -*-
2 +import argparse
3 +from argparse import ArgumentError
4 +import os
5 +import glob
6 +import time
7 +import subprocess
8 +
9 +import gluonnlp as nlp
10 +from numpy.lib.function_base import delete
11 +import torch
12 +from torch.utils.data import DataLoader, Dataset
13 +from gluonnlp.data import SentencepieceTokenizer
14 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
15 +from kogpt2.utils import get_tokenizer
16 +from tqdm import tqdm
17 +from util.data_loader import ArticleDataset, ToTensor
18 +
def get_gpu_memory_map():
    """Query ``nvidia-smi`` for current GPU memory usage.

    Returns
    -------
    dict
        Maps integer device id -> memory used, in MB.
    """
    smi_output = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ], encoding='utf-8')
    # One line of output per device, in device-id order.
    used_per_device = [int(line) for line in smi_output.strip().split('\n')]
    return {device_id: used for device_id, used in enumerate(used_per_device)}
36 +
37 +
if __name__ == "__main__":
    # Evaluate a (possibly fine-tuned) KoGPT2 checkpoint on the test split,
    # reporting per-topic and all-topic language-model loss.
    parser = argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
    parser.add_argument('--docker', action='store_true', help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('--default', action='store_true', help="Use un-tuned KoGPT2")
    parser.add_argument('--model_topic', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('--epoch', type=int)
    parser.add_argument('--topic', nargs='+', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
    args = parser.parse_args()
    print(args)

    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 32
    batch_size = 64
    padding_id = vocab[vocab.padding_token]
    topics = set(sorted(args.topic))
    transform = ToTensor(tokenizer, vocab, 128)
    print("Preparing dataloader...")
    dataloaders = {}
    # "all" evaluates across every topic; per-topic loaders evaluate each topic separately.
    dataloaders["all"] = DataLoader(ArticleDataset(dataset_path, label='test', transform=transform), batch_size=batch_size, num_workers=0)
    for topic in tqdm(topics):
        dataloaders[topic] = DataLoader(ArticleDataset(dataset_path, topics={topic}, label='test', transform=transform), batch_size=batch_size, num_workers=0)
    print("Prepared dataloader.")
    epoches = 30
    checkpoint_epoch = 0
    learning_rate = 3e-5
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    topic_all = ['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학']
    model_topic = topic_all if args.model_topic is None else sorted(list({args.model_topic}))
    model_epoch = '*' if args.epoch is None else args.epoch
    dev = ctx if ctx in {'cpu', 'cuda'} else 'cuda:*'
    # Reconstruct the exact set-repr string embedded in checkpoint filenames at save time.
    braced = str("{'생활', '경제', 'IT_과학', '미용_건강', '스포츠', '사회', '연예', '문화', '정치'}") if args.model_topic is None else '{' + str(model_topic)[1:-1] + '}'
    saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{dev}_{braced}_{transform.max_len}_{model_epoch}.state')
    if not args.default:
        if len(saves) > 0:
            last_save = max(saves, key=os.path.getmtime)
            checkpoint = torch.load(last_save, map_location=device)
            print(f"Loading save from {last_save}")
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            last_test_loss = checkpoint['loss']
        else:
            print("No save exists.")
            raise FileNotFoundError(f'{save_path}/KoGPT2_checkpoint_{ctx}_{model_topic}_{transform.max_len}_{model_epoch}.state')
    model.to(device)
    model.eval()

    cached_testset_path = f"{save_path}/test_{topic_all}_{transform.max_len}"
    if os.path.isfile(cached_testset_path + '.npy'):
        dataloaders["all"].dataset.load_from_file(cached_testset_path + '.npy')
    else:
        print("Caching testset... topic: all")
        for temp in tqdm(dataloaders["all"]):
            pass
        dataloaders["all"].dataset.set_use_cache(True, cached_testset_path)
        print("Cached. topic: all")
    dataloaders["all"].dataset.num_workers = num_workers
    for topic in tqdm(topics):
        # BUG FIX: this was f"{save_path}/test_{{topic}}_..." (escaped braces), so every
        # topic resolved to the same literal ".../test_{topic}_..." cache file. Each topic
        # now gets its own cache file.
        cached_testset_path = f"{save_path}/test_{topic}_{transform.max_len}"
        if os.path.isfile(cached_testset_path + '.npy'):
            dataloaders[topic].dataset.load_from_file(cached_testset_path + '.npy')
        else:
            print(f"Caching testset... topic: {topic}")
            for temp in tqdm(dataloaders[topic]):
                pass
            dataloaders[topic].dataset.set_use_cache(True, cached_testset_path)
            print(f"Cached. topic: {topic}")
        dataloaders[topic].dataset.num_workers = num_workers

    last_test_loss = float('infinity')
    overfit = -1
    states = []

    for topic in tqdm(dataloaders):
        try:
            test_loss_list = []
            for data in tqdm(dataloaders[topic]):
                data = data.to(ctx)
                # Padding positions get label -100, which the LM loss ignores.
                label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                output = model(data, labels=label, attention_mask=mask)
                loss = output[0]
                test_loss_list.append(loss.item())
                # Drop references eagerly to release GPU memory between batches.
                del label, mask, loss, output, data
            test_loss = sum(test_loss_list) / len(test_loss_list)
            print(f"data_topic: {topic}, model_topic: {model_topic} test loss: {test_loss}")
            states.append((topic, model_topic, test_loss))
        except KeyboardInterrupt:
            break
    log_path = f"{save_path}/test_{'DEFAULT' if args.default else model_topic}_{topics}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write("data_topic, model_topic, test loss,\n")
        for state in states:
            log.write(f"{state[0]}, {state[1]},{state[2]},\n")
    print(f"Log written at: {log_path}")
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from random import choice, choices, randint
3 +import argparse
4 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
5 +from gluonnlp.data import SentencepieceTokenizer
6 +from kogpt2.utils import get_tokenizer
7 +
def top_k(predict, vocab, k):
    """Pick one of the *k* highest-scoring token ids uniformly at random
    and return its string token via *vocab*."""
    _, best_indices = torch.topk(predict, k=k, dim=-1)
    sampled_id = choice(best_indices.tolist())
    return vocab.to_tokens(sampled_id)
12 +
def top_p(logits, vocab, threshold = 0.9):
    """Nucleus (top-p) sampling: pick uniformly among the smallest set of
    tokens whose cumulative softmax probability exceeds *threshold*."""
    ranked_logits, ranked_indices = torch.sort(logits, descending=True)
    ranked = ranked_indices.tolist()
    probs = torch.nn.functional.softmax(ranked_logits, dim=-1)
    cumulative = 0
    nucleus_bound = 0
    # Find the index where the cumulative probability crosses the threshold.
    for position, p in enumerate(probs):
        if cumulative > threshold:
            nucleus_bound = 0 if position == 0 else position - 1
            break
        cumulative += p
    # Uniform draw inside the nucleus.
    return vocab.to_tokens(ranked[randint(0, nucleus_bound)])
27 +
def weighted_random(logits, vocab):
    """Sample a token with probability proportional to its softmax weight."""
    ordered_logits, ordered_indices = torch.sort(logits, descending=True)
    probabilities = torch.nn.functional.softmax(ordered_logits, dim=-1)
    sampled_id = choices(ordered_indices.tolist(), weights=probabilities)[0]
    return vocab.to_tokens(sampled_id)
33 +
if __name__ == "__main__":
    # Demo: generate continuations of one seed sentence with several sampling strategies.
    parser = argparse.ArgumentParser(description='KoGPT2 generation example')
    parser.add_argument('sentence', metavar='S', type=str, nargs='?', default='2019년 한해를 보내며,',
                        help='korean sentence to use as input.')

    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    tok_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    # BUG/SMELL FIX: parse_args() was previously re-invoked before every run.
    base_sentence = parser.parse_args().sentence

    def _generate(label, pick_token):
        """Extend base_sentence one token at a time with pick_token(model_logits)
        until '</s>' appears or 100 tokens were added, then print it."""
        sent = base_sentence
        toked = tok(sent)
        token_count = 0
        while token_count < 100:
            try:
                input_ids = torch.tensor([vocab[vocab.bos_token], ] + vocab[toked]).unsqueeze(0)
                pred = model(input_ids)[0]
                gen = pick_token(pred)
                if gen == '</s>':
                    break
                sent += gen.replace('▁', ' ')
                toked = tok(sent)
                token_count += 1
            except KeyboardInterrupt:
                break
        print(label, sent)

    # The seven duplicated generation loops collapse to one helper + per-strategy lambdas.
    _generate('Greedy:', lambda pred: vocab.to_tokens(torch.argmax(pred, axis=-1).squeeze().tolist())[-1])
    _generate('Top 3:', lambda pred: top_k(pred.squeeze()[-1], vocab, 3))
    _generate('Top 5:', lambda pred: top_k(pred.squeeze()[-1], vocab, 5))
    _generate('Top p=0.5:', lambda pred: top_p(pred.squeeze()[-1], vocab, 0.5))
    _generate('Top p=0.7:', lambda pred: top_p(pred.squeeze()[-1], vocab, 0.7))
    _generate('Top p=0.9:', lambda pred: top_p(pred.squeeze()[-1], vocab))
    _generate('Weighted random:', lambda pred: weighted_random(pred.squeeze()[-1], vocab))
...\ No newline at end of file ...\ No newline at end of file
1 +import argparse
2 +import tqdm
3 +import gluonnlp as nlp
4 +import torch
5 +from gluonnlp.data import SentencepieceTokenizer
6 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
7 +from kogpt2.utils import get_tokenizer
8 +from pytorch_lightning import Trainer
9 +from pytorch_lightning.callbacks import ModelCheckpoint
10 +from pytorch_lightning.core.lightning import LightningModule
11 +from torch.utils.data import DataLoader, Dataset
12 +from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
13 +from .utils.data_loader import ArticleDataset, ToTensor
14 +
class KoGPT2Chat(LightningModule):
    """LightningModule that fine-tunes KoGPT2 with a masked LM loss."""

    def __init__(self, hparams, **kwargs):
        super(KoGPT2Chat, self).__init__()
        self.hparams = hparams
        self.tok_path = get_tokenizer()
        self.neg = -1e18  # large negative value used to mask out padded logits
        self.model, self.vocab = get_pytorch_kogpt2_model()
        self.loss_function = torch.nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Append model-specific CLI flags to *parent_parser*; return the new parser."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--max-len',
                            type=int,
                            default=32,
                            help='max sentence length on input (default: 32)')
        parser.add_argument('--batch-size',
                            type=int,
                            default=96,
                            help='batch size for training (default: 96)')
        parser.add_argument('--lr',
                            type=float,
                            default=5e-5,
                            help='The initial learning rate')
        parser.add_argument('--warmup_ratio',
                            type=float,
                            default=0.1,
                            help='warmup ratio')
        return parser

    def forward(self, inputs):
        # (batch, seq_len, hiddens)
        # BUG FIX: the model is stored as self.model in __init__; self.kogpt2 never existed.
        output, _ = self.model(inputs)
        return output

    def training_step(self, batch, batch_idx):
        token_ids, mask, label = batch
        out = self(token_ids)
        # Broadcast the (batch, seq) mask across the vocabulary dimension.
        mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
        mask_out = torch.where(mask_3d == 1, out, self.neg * torch.ones_like(out))
        loss = self.loss_function(mask_out.transpose(2, 1), label)
        loss_avg = loss.sum() / mask.sum()  # average over non-padded tokens only
        tensorboard_logs = {'train_loss': loss_avg}
        return {'loss': loss_avg, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # AdamW with weight decay disabled for biases and LayerNorm parameters.
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.hparams.lr, correct_bias=False)
        # Cosine LR schedule with linear warmup over warmup_ratio of all steps.
        num_train_steps = len(self.train_dataloader()) * self.hparams.max_epochs
        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
        lr_scheduler = {'scheduler': scheduler, 'name': 'cosine_schedule_with_warmup',
                        'monitor': 'loss', 'interval': 'step',
                        'frequency': 1}
        return [optimizer], [lr_scheduler]

    def _collate_fn(self, batch):
        # Re-pack a list of (data, mask, label) triples into batched LongTensors.
        data = [item[0] for item in batch]
        mask = [item[1] for item in batch]
        label = [item[2] for item in batch]
        return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)

    def train_dataloader(self):
        import pandas as pd  # BUG FIX: pd was referenced below but never imported
        data = pd.read_csv('Chatbot_data/ChatbotData.csv')
        # BUG FIX: removed stray 'jhkhk' token that made this line a syntax error.
        self.train_set = ArticleDataset(data, self.tok_path, self.vocab, max_len=self.hparams.max_len)
        train_dataloader = DataLoader(
            self.train_set, batch_size=self.hparams.batch_size, num_workers=2,
            shuffle=True, collate_fn=self._collate_fn)
        return train_dataloader
96 +
# BUG FIX: 'parser' was used before being defined (NameError at import time), and the
# --train/--chat/--model_params flags read below were never added anywhere.
_base_parser = argparse.ArgumentParser(description='KoGPT2 fine-tuning / chat', add_help=False)
_base_parser.add_argument('--train', action='store_true', help='Run fine-tuning.')
_base_parser.add_argument('--chat', action='store_true', help='Load a checkpoint for evaluation.')
_base_parser.add_argument('--model_params', type=str, default='model_chp/model_last.ckpt',
                          help='Checkpoint path used with --chat.')
parser = KoGPT2Chat.add_model_specific_args(_base_parser)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

if __name__ == "__main__":
    if args.train:
        # Keep the checkpoint with the lowest training loss, plus the latest one.
        checkpoint_callback = ModelCheckpoint(
            filepath='model_chp/{epoch:02d}-{loss:.2f}',
            verbose=True,
            save_last=True,
            monitor='loss',
            mode='min',
            prefix='model_'
        )
        # python train_torch.py --train --gpus 1 --max_epochs 3
        model = KoGPT2Chat(args)
        model.train()
        trainer = Trainer.from_argparse_args(
            args,
            checkpoint_callback=checkpoint_callback, gradient_clip_val=1.0)
        trainer.fit(model)
    if args.chat:
        model = KoGPT2Chat.load_from_checkpoint(args.model_params)
        model.eval()
...\ No newline at end of file ...\ No newline at end of file
1 +from random import choice, choices, randint
2 +import argparse
3 +import re
4 +import time
5 +import torch
6 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
7 +from gluonnlp.data import SentencepieceTokenizer
8 +from kogpt2.utils import get_tokenizer
9 +
def greedy(predict):
    """Return the index of the highest-scoring entry along the last axis."""
    best = torch.argmax(predict, axis=-1)
    return best.tolist()
12 +
def top_k(predict, k):
    """Return one of the *k* highest-scoring token ids, picked uniformly at random."""
    _, best_indices = torch.topk(predict, k=k, dim=-1)
    return choice(best_indices.tolist())
17 +
def top_p(logits, threshold = 0.9):
    """Nucleus (top-p) sampling: pick uniformly among the smallest set of
    tokens whose cumulative softmax probability exceeds *threshold*."""
    ranked_logits, ranked_indices = torch.sort(logits, descending=True)
    ranked = ranked_indices.tolist()
    probs = torch.nn.functional.softmax(ranked_logits, dim=-1)
    cumulative = 0
    nucleus_bound = 0
    # Find where the cumulative probability crosses the threshold.
    for position, p in enumerate(probs):
        if cumulative > threshold:
            nucleus_bound = 0 if position == 0 else position - 1
            break
        cumulative += p
    # Uniform draw inside the nucleus.
    return ranked[randint(0, nucleus_bound)]
32 +
def weighted_random(logits):
    """Sample among the non-negative logits, weighted by softmax probability.
    Negative logits are deliberately excluded from the candidate pool."""
    nonneg_indices = torch.where(logits >= 0)[0]
    candidate_logits = torch.index_select(logits, -1, nonneg_indices)
    weights = torch.nn.functional.softmax(candidate_logits, dim=-1)
    return choices(nonneg_indices.tolist(), weights=weights)[0]
38 +
def weighted_top_k(predict, k):
    """Sample among the *k* best tokens, weighted by their renormalized softmax mass."""
    top_vals, top_idx = torch.topk(predict, k=k, dim=-1)
    top_weights = torch.nn.functional.softmax(top_vals, dim=-1)
    return choices(top_idx.tolist(), weights=top_weights)[0]
43 +
def weighted_top_p(logits, threshold = 0.9):
    """Weighted nucleus sampling: draw a token id from the smallest set whose
    cumulative softmax probability exceeds *threshold*, weighted by probability.

    BUG FIX: previously the nucleus weights were divided by the cumulative
    probability *before* the last token (``last_cum_prob``), which is 0 when
    the top-1 token alone crosses the threshold (inf weights); and when the
    loop never broke (threshold >= 1) the population was empty. The nucleus
    now always contains at least the top-1 token, and no manual division is
    needed because random.choices renormalizes weights itself.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    cum_prob = 0.0
    bound = len(sorted_probs)  # fall back to the full distribution if never crossed
    # Find the nucleus boundary index.
    for i, prob in enumerate(sorted_probs):
        if cum_prob > threshold:
            bound = i
            break
        cum_prob += prob.item()
    bound = max(bound, 1)  # never allow an empty nucleus
    population = sorted_indices[:bound].tolist()
    weights = sorted_probs[:bound].tolist()
    return choices(population, weights=weights)[0]
58 +
59 +
if __name__ == "__main__":
    # CLI: choose exactly one strategy (-g/-k/-p), optionally weighted (-w), then
    # generate up to --length sentences (max 1000 tokens) from the seed sentence(s).
    parser = argparse.ArgumentParser(description='KoGPT2 generation example')
    group=parser.add_mutually_exclusive_group()
    group.add_argument('-g','--greedy',action='store_const',const='greedy',help='Greedy sampling')
    group.add_argument('-k','--topk',type=int, choices=range(1,51), help='Top k sampling. 1<=K<=50', metavar='K')
    group.add_argument('-p','--topp',type=float, help='Top p sampling. 0<P<=1.0', metavar='P')
    parser.add_argument('-w','--weighted',action='store_true', help='Use weighted version of sampling.')
    parser.add_argument('-d','--docker', action='store_true', help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('-c','--checkpoint',type=str , help='Model chekpoint path',metavar='PATH')
    parser.add_argument('-f','--full_sentence', action='store_true' , help='Treat last S as a full_sentence. (Do not append it.)')
    parser.add_argument('-l','--length', type=int, choices=range(1,21) , help='Set length of paragraph.', metavar='LENGTH', default=15)
    parser.add_argument('sentence', metavar='S', type=str, nargs='*',
                        help='korean sentence to use as input.')
    args = parser.parse_args()
    print(args)
    model_cache_path='/code/model' if args.docker else 'model'
    save_path='/code/save' if args.docker else 'save'

    # Resolve the sampling callable; --weighted switches to the softmax-weighted variant.
    if args.greedy:
        sampling_name = "Weighted" if args.weighted else "Greedy"
        sampling=weighted_random if args.weighted else greedy
    elif args.topk is not None:
        sampling_name=f"Weighted Top k={args.topk}" if args.weighted else f"Top k={args.topk}"
        sampling= (lambda pred: weighted_top_k(pred,args.topk)) if args.weighted else (lambda pred: top_k(pred,args.topk))
    elif args.topp is not None:
        sampling_name=f"Weighted Top p={args.topp}" if args.weighted else f"Top p={args.topp}"
        sampling= (lambda pred: weighted_top_p(pred,args.topp)) if args.weighted else (lambda pred: top_p(pred,args.topp))
    else: #if args.weighted:
        sampling_name="Weighted"
        sampling=weighted_random

    ctx='cuda:0' if torch.cuda.is_available() else 'cpu'
    device=torch.device(ctx)
    tok_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    if args.checkpoint:
        # Resume from a fine-tuned checkpoint instead of the stock KoGPT2 weights.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint['epoch']
    model.eval()

    # Tokenize every seed sentence, separating them with </s><s> pairs.
    toked=[]
    for sent in args.sentence:
        toked += (tok(sent)+[vocab.eos_token,vocab.bos_token])
    else:
        # NOTE(review): this for/else always runs after the loop (there is no break);
        # it strips the trailing </s><s> so generation continues the last seed sentence.
        if not args.full_sentence:
            toked=toked[:-2]
    token_count=0
    sent_count=0
    started=time.time()
    # Generate until args.length sentences are emitted or 1000 tokens are produced.
    while token_count<1000:
        try:
            input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(device=device)
            pred = model(input_ids)[0]
            gen_id = sampling(pred.squeeze()[-1])
            gen_token=vocab.to_tokens(gen_id)
            if gen_token == vocab.eos_token:
                # Sentence boundary: count it and start a new <s> segment.
                sent_count+=1
                print(sent_count, token_count)
                if sent_count>=args.length:
                    break
                else:
                    toked+=[vocab.eos_token,vocab.bos_token]
                    token_count+=2
            else:
                toked.append(gen_token)
                token_count+=1
        except KeyboardInterrupt:
            break
    # Render: '▁'/<s> become spaces, </s> becomes a line break.
    print(f'{sampling_name}:',re.sub('</s>', '\r\n',re.sub('(▁|<s>)',' ',''.join(toked))))
    print("Time elapsed:", time.time()-started)
132 +
1 +from random import choice, choices, randint
2 +import argparse
3 +import torch
4 +from gluonnlp.data import SentencepieceTokenizer
5 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
6 +from kogpt2.utils import get_tokenizer
7 +from tqdm import trange
8 +
def greedy(predict):
    """Return the index of the highest-scoring entry along the last axis."""
    best = torch.argmax(predict, axis=-1)
    return best.tolist()
11 +
def top_k(predict, k):
    """Return one of the *k* highest-scoring token ids, picked uniformly at random."""
    _, best_indices = torch.topk(predict, k=k, dim=-1)
    return choice(best_indices.tolist())
16 +
def top_p(logits, threshold = 0.9):
    """Nucleus (top-p) sampling: pick uniformly among the smallest set of
    tokens whose cumulative softmax probability exceeds *threshold*."""
    ranked_logits, ranked_indices = torch.sort(logits, descending=True)
    ranked = ranked_indices.tolist()
    probs = torch.nn.functional.softmax(ranked_logits, dim=-1)
    cumulative = 0
    nucleus_bound = 0
    # Find where the cumulative probability crosses the threshold.
    for position, p in enumerate(probs):
        if cumulative > threshold:
            nucleus_bound = 0 if position == 0 else position - 1
            break
        cumulative += p
    # Uniform draw inside the nucleus.
    return ranked[randint(0, nucleus_bound)]
31 +
def weighted_random(logits):
    """Sample among the non-negative logits, weighted by softmax probability.
    Negative logits are deliberately excluded from the candidate pool."""
    nonneg_indices = torch.where(logits >= 0)[0]
    candidate_logits = torch.index_select(logits, -1, nonneg_indices)
    weights = torch.nn.functional.softmax(candidate_logits, dim=-1)
    return choices(nonneg_indices.tolist(), weights=weights)[0]
37 +
def weighted_top_k(predict, k):
    """Sample among the *k* best tokens, weighted by their renormalized softmax mass."""
    top_vals, top_idx = torch.topk(predict, k=k, dim=-1)
    top_weights = torch.nn.functional.softmax(top_vals, dim=-1)
    return choices(top_idx.tolist(), weights=top_weights)[0]
42 +
def weighted_top_p(logits, threshold = 0.9):
    """Weighted nucleus sampling: draw a token id from the smallest set whose
    cumulative softmax probability exceeds *threshold*, weighted by probability.

    BUG FIX: previously the nucleus weights were divided by the cumulative
    probability *before* the last token (``last_cum_prob``), which is 0 when
    the top-1 token alone crosses the threshold (inf weights); and when the
    loop never broke (threshold >= 1) the population was empty. The nucleus
    now always contains at least the top-1 token, and no manual division is
    needed because random.choices renormalizes weights itself.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    cum_prob = 0.0
    bound = len(sorted_probs)  # fall back to the full distribution if never crossed
    # Find the nucleus boundary index.
    for i, prob in enumerate(sorted_probs):
        if cum_prob > threshold:
            bound = i
            break
        cum_prob += prob.item()
    bound = max(bound, 1)  # never allow an empty nucleus
    population = sorted_indices[:bound].tolist()
    weights = sorted_probs[:bound].tolist()
    return choices(population, weights=weights)[0]
57 +
58 +
if __name__ == "__main__":
    # Runs generation 20 times with the chosen sampling strategy and appends
    # (run, sentence_count, token_count) records to length.log.
    parser = argparse.ArgumentParser(description='KoGPT2 generation example')
    group=parser.add_mutually_exclusive_group()
    group.add_argument('-g','--greedy',action='store_const',const='greedy',help='Greedy sampling')
    group.add_argument('-k','--topk',type=int, choices=range(1,51), help='Top k sampling. 1<=K<=50', metavar='K')
    group.add_argument('-p','--topp',type=float, help='Top p sampling. 0<P<=1.0', metavar='P')
    parser.add_argument('-w','--weighted',action='store_true', help='Use weighted version of sampling.')
    parser.add_argument('-d','--docker', action='store_true', help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('-c','--checkpoint',type=str , help='Model chekpoint path',metavar='PATH')
    parser.add_argument('-f','--full_sentence', action='store_true' , help='Treat last S as a full_sentence. (Do not append it.)')
    parser.add_argument('-l','--length', type=int, choices=range(1,21) , help='Set length of paragraph.', metavar='LENGTH', default=15)
    parser.add_argument('sentence', metavar='S', type=str, nargs='*',
                        help='korean sentence to use as input.')
    args = parser.parse_args()
    print(args)
    model_cache_path='/code/model' if args.docker else 'model'
    save_path='/code/save' if args.docker else 'save'

    # Resolve the sampling callable; --weighted switches to the softmax-weighted variant.
    if args.greedy:
        sampling_name = "Weighted" if args.weighted else "Greedy"
        sampling=weighted_random if args.weighted else greedy
    elif args.topk is not None:
        sampling_name=f"Weighted Top k={args.topk}" if args.weighted else f"Top k={args.topk}"
        sampling= (lambda pred: weighted_top_k(pred,args.topk)) if args.weighted else (lambda pred: top_k(pred,args.topk))
    elif args.topp is not None:
        sampling_name=f"Weighted Top p={args.topp}" if args.weighted else f"Top p={args.topp}"
        sampling= (lambda pred: weighted_top_p(pred,args.topp)) if args.weighted else (lambda pred: top_p(pred,args.topp))
    else: #if args.weighted:
        sampling_name="Weighted"
        sampling=weighted_random

    ctx='cuda:0' if torch.cuda.is_available() else 'cpu'
    device=torch.device(ctx)
    tok_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
    if args.checkpoint:
        # Resume from a fine-tuned checkpoint instead of the stock KoGPT2 weights.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint['epoch']
    model.eval()
    # NOTE(review): 'lenght_list' is a typo for 'length_list'; kept unchanged in
    # this documentation-only pass.
    lenght_list=[]
    for i in trange(20):
        # Re-tokenize the seed sentences for each of the 20 measurement runs.
        toked=[]
        for sent in args.sentence:
            toked += (tok(sent)+[vocab.eos_token,vocab.bos_token])
        else:
            # NOTE(review): this for/else always runs after the loop (no break);
            # it strips the trailing </s><s> so generation continues the last seed.
            if not args.full_sentence:
                toked=toked[:-2]
        token_count=0
        sent_count=0
        # Generate until args.length sentences are emitted or 1000 tokens are produced.
        while token_count<1000:
            try:
                input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(device=device)
                pred = model(input_ids)[0]
                gen_id = sampling(pred.squeeze()[-1])
                gen_token=vocab.to_tokens(gen_id)
                if gen_token == vocab.eos_token:
                    # Sentence boundary: record (run, sentences so far, tokens so far).
                    sent_count+=1
                    lenght_list.append(f"{i},{sent_count}, {token_count}\n")
                    if sent_count>=args.length:
                        break
                    else:
                        toked+=[vocab.eos_token,vocab.bos_token]
                        token_count+=2
                else:
                    toked.append(gen_token)
                    token_count+=1
            except KeyboardInterrupt:
                break
    # Append this run's records, bracketed by checkpoint markers.
    with open('length.log','a') as log:
        log.write(f'#-*- {args.checkpoint} -*-\n')
        log.writelines(lenght_list)
        log.write('#-*- -*-\n')
...\ No newline at end of file ...\ No newline at end of file
1 +import torch
2 +from tqdm import tqdm
3 +from util.data_loader import ArticleDataset, ToTensor
4 +from torch.utils.data import DataLoader
5 +from gluonnlp.data import SentencepieceTokenizer
6 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
7 +from kogpt2.utils import get_tokenizer
8 +
# Histogram of article token lengths: for each multiple-of-256 bound, count how
# many training articles fit within that many tokens.
max_len = 1024

ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(ctx)
tokenizer_path = get_tokenizer(cachedir='/code/model')
model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
transform = ToTensor(tokenizer, vocab, max_len=max_len)
batch_size = 64
trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform, use_cache=False),
                      batch_size=batch_size, num_workers=32, shuffle=True)
count_dict = dict((idx, 0) for idx in range(256, max_len, 256))
for (data, original_len) in tqdm(trainset):
    # BUG FIX: Tensor.to() is not in-place; the moved tensor was previously
    # discarded, so the counting silently stayed on the original device.
    original_len = original_len.to(device)
    for bound in count_dict:
        count_dict[bound] += torch.sum(torch.where(original_len <= bound,
                                                   torch.ones_like(original_len),
                                                   torch.zeros_like(original_len))).item()
for bound in count_dict:
    print(f"count[{bound}]: {count_dict[bound]}/{len(trainset.dataset)} ({100*count_dict[bound]/len(trainset.dataset):.1f}%)")
...\ No newline at end of file ...\ No newline at end of file
1 +# -*- coding: utf-8 -*-
2 +import argparse
3 +import os
4 +import glob
5 +import time
6 +
7 +import gluonnlp as nlp
8 +import torch
9 +from torch.nn import DataParallel
10 +from torch.utils.data import DataLoader, Dataset
11 +from gluonnlp.data import SentencepieceTokenizer
12 +from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
13 +from kogpt2.utils import get_tokenizer
14 +from tqdm import tqdm
15 +from util.data_loader import ArticleDataset, ToTensor
16 +
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
    parser.add_argument('--docker', action='store_true',
                        help="Train on docker. Sets model cache path:/code/model, dataset path:/dataset, save path:/code/save.")
    parser.add_argument('--resume', choices=['default', 'cpu', 'cuda', 'cuda:0', 'cuda:1'], nargs='?', const='default',
                        help="Load state file to device; then resume train.")
    parser.add_argument('--topic', nargs='+',
                        choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'],
                        default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('--length', type=int, default=128, choices=[2**i for i in range(11)],
                        help="token length for transform")
    parser.add_argument('--epoch', type=int, default=30, help="Train epoch")
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
    args = parser.parse_args()
    print(args)

    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    # Fall back to CPU when CUDA was requested but is unavailable.
    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)

    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)

    # Scale batch size / worker count inversely with sequence length so the
    # per-step memory footprint stays roughly constant.
    num_workers = int(32 * (128 / args.length)) if args.length < 1024 else 4
    batch_size = int(64 * (128 / args.length)) if args.length < 1024 else 4
    padding_id = vocab[vocab.padding_token]

    # BUGFIX: topics used to be a set interpolated straight into file names
    # (f"...{topics}..."); set iteration order is not stable across processes
    # (hash randomization), so checkpoint/cache/log names could differ between
    # runs and resume/cache lookups could silently miss existing files.
    # A sorted, de-duplicated tag is deterministic. (Old nondeterministic
    # save files will not be matched by the new names.)
    topics = sorted(set(args.topic))
    topics_tag = '_'.join(topics)
    transform = ToTensor(tokenizer, vocab, args.length)

    print("Preparing dataloader...")
    # num_workers=0 for now: the in-process dataset cache must be filled in
    # the main process before workers can be raised (see caching below).
    trainset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='train', transform=transform),
                          batch_size=batch_size, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='valid', transform=transform),
                          batch_size=batch_size, num_workers=0)
    print("Prepared dataloader.")

    epochs = args.epoch
    checkpoint_epoch = 0
    learning_rate = 3e-5
    # NOTE: the model computes its own LM loss when called with labels=...,
    # so no separate criterion object is needed (the old unused
    # CrossEntropyLoss instance was removed).
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # BUGFIX: initialise *before* the resume block. It used to be reset to
    # infinity after the checkpoint had already restored it, discarding the
    # loaded loss and breaking overfit detection on the first resumed epoch.
    last_valid_loss = float('inf')

    if args.resume:
        # Checkpoints are tagged with the device they were written on;
        # --resume may name a different device tag to load from.
        save_ctx = ctx if args.resume == "default" else args.resume
        saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{save_ctx}_{topics_tag}_{transform.max_len}_*.state')
        if saves:
            last_save = max(saves, key=os.path.getmtime)  # most recent checkpoint
            print(f"Loading save from {last_save}")
            checkpoint = torch.load(last_save, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            # NOTE(review): 'loss' is saved as the *train* loss below but is
            # compared against validation loss here — confirm intended.
            last_valid_loss = checkpoint['loss']
            print("Loaded.")
        else:
            print("No save exists.")

    model.to(device)
    model.train()

    # Cache the transformed tensors on disk so later runs skip tokenisation.
    cached_trainset_path = f"{save_path}/train_{topics_tag}_{transform.max_len}"
    cached_validset_path = f"{save_path}/valid_{topics_tag}_{transform.max_len}"
    if os.path.isfile(cached_trainset_path + '.npy'):
        trainset.dataset.load_from_file(cached_trainset_path + '.npy')
    else:
        print("Caching trainset...")
        for _ in tqdm(trainset):  # one full pass fills the in-memory cache
            pass
        trainset.dataset.set_use_cache(True, cached_trainset_path)
    if os.path.isfile(cached_validset_path + '.npy'):
        validset.dataset.load_from_file(cached_validset_path + '.npy')
    else:
        print("Caching validset...")
        for _ in tqdm(validset):
            pass
        validset.dataset.set_use_cache(True, cached_validset_path)
    print("Cached.")

    # The cache is complete: workers may now be raised for real training.
    trainset.num_workers = num_workers
    validset.num_workers = num_workers

    overfit = -1   # first epoch at which validation loss increased (-1 = never)
    states = []    # (epoch, train_loss, valid_loss) history for the log file

    # NOTE(review): a fresh run starts at epoch 1 and therefore performs
    # epochs-1 iterations; kept as-is to preserve existing numbering.
    for epoch in tqdm(range(checkpoint_epoch + 1, epochs)):
        try:
            train_loss_list = []
            valid_loss_list = []
            for data in tqdm(trainset):
                optimizer.zero_grad()
                data = data.to(ctx)
                # Label padding positions with -100 so the LM loss ignores
                # them; mask them out of attention as well.
                label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                output = model(data, labels=label, attention_mask=mask)
                loss = output[0]
                loss.backward()
                optimizer.step()
                train_loss_list.append(loss.item())
                del loss, output, label, mask, data  # release GPU memory eagerly
            with torch.no_grad():
                for v_data in tqdm(validset):
                    v_data = v_data.to(ctx)
                    v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                    v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                    v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                    valid_loss_list.append(v_output[0].item())
                    del v_output, v_mask, v_label, v_data
            train_loss = sum(train_loss_list) / len(train_loss_list)
            valid_loss = sum(valid_loss_list) / len(valid_loss_list)
            print(f"epoch: {epoch} train loss: {train_loss} valid loss: {valid_loss}")
            states.append((epoch, train_loss, valid_loss))
            if valid_loss > last_valid_loss:
                overfit = epoch
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss
                }, f"{save_path}/KoGPT2_checkpoint_{ctx}_{topics_tag}_{transform.max_len}_{epoch}.state")
            except Exception as e:
                print(e)  # a failed checkpoint write should not abort training
            last_valid_loss = valid_loss
        except KeyboardInterrupt:
            break  # Ctrl-C stops training but still writes the log below

    log_path = f"{save_path}/{topics_tag}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write(f"Overfit at: {overfit}\n")
        for ep, tr, va in states:
            log.write(f"epoch: {ep} train loss: {tr} valid loss: {va}\n")
    print(f"Log written at: {log_path}")
1 import os 1 import os
2 +import numpy as np
2 import pyarrow as pa 3 import pyarrow as pa
3 import pyarrow.parquet as pq 4 import pyarrow.parquet as pq
4 import torch 5 import torch
class ArticleDataset(Dataset):
    """
    Dataset of news articles for language-model training.

    Rows come from a partitioned parquet dataset
    (``<path>/topic=<topic>/label=<split>``); each item is a
    (paragraph, topic) pair, optionally run through ``transform`` and
    memoised in an in-process cache.
    """

    def __init__(self, dataset_path: str, topics: set = None, label: str = 'train',
                 transform=None, use_cache=False):
        """
        Initializer.

        :param dataset_path: path of parquet dataset
        :param topics: if not None, only use specified topics; must be subset of
            {경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학}.
            Defaults to all topics.
        :param label: dataset split; must be one of [train, test, valid] (default train)
        :param transform: if not None, transforms data. (paragraph:stringScalar, topic:stringScalar)=>Tensor
        :param use_cache: if True, __getitem__ uses the cache. Must be used after first epoch.
        """
        # BUGFIX: the default used to be a mutable set literal evaluated once
        # and shared between all calls; a None sentinel is the safe idiom.
        if topics is None:
            topics = {'경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'}
        expanded_dataset_path = os.path.expanduser(dataset_path)
        tables = []
        for topic in topics:
            table = pq.read_table(f'{expanded_dataset_path}/topic={topic}/label={label}', columns=['paragraph'])
            # Record the originating topic alongside each paragraph.
            tables.append(table.append_column('topic', pa.array([topic] * len(table))))
        self.data = pa.concat_tables(tables)
        self.transform = transform
        self.use_cache = use_cache
        self.cache = [None] * len(self.data)  # one slot per item
        # NOTE: transforming every row eagerly here proved too slow;
        # items are transformed lazily in __getitem__ instead.

    def __len__(self):
        """Return the number of articles in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return item ``index``, serving from the cache when enabled."""
        # BUGFIX: the cache was consulted only *after* running the transform,
        # so a cache hit still paid the full tokenisation cost.
        if self.use_cache and self.cache[index] is not None:
            return self.cache[index]
        item = (self.data['paragraph'][index], self.data['topic'][index])
        if self.transform is not None:
            item = self.transform(item)
        # Always record the item so a warm-up epoch (use_cache=False) fills
        # the cache for set_use_cache(True) afterwards.
        self.cache[index] = item
        return item

    def load_from_file(self, cache_file_path: str):
        """Load a previously saved cache (.npy) as a tensor and enable caching."""
        self.use_cache = True
        self.cache = torch.from_numpy(np.load(cache_file_path))

    def set_use_cache(self, use_cache: bool, cache_file_path: str = None):
        """
        Enable or disable the item cache.

        When enabling, the per-item cache list is stacked into a single
        tensor and, if ``cache_file_path`` is given, persisted as
        ``<cache_file_path>.npy``. Requires a completed warm-up epoch with
        num_workers=0 (otherwise the cache is incomplete and is left as-is).
        """
        self.use_cache = use_cache
        if not use_cache:
            # BUGFIX: this used to reset the cache to [], which made any later
            # self.cache[index] access raise IndexError; keep one slot per item
            # so caching can be re-enabled.
            self.cache = [None] * len(self.data)
            return
        if isinstance(self.cache, torch.Tensor):
            if cache_file_path is not None:
                np.save(cache_file_path, self.cache.numpy())
            else:
                print("Already fully cached.")
            return
        try:
            self.cache = torch.stack(self.cache)
            if cache_file_path is not None:
                np.save(cache_file_path, self.cache.numpy())
        # BUGFIX: torch.stack raises TypeError (not RuntimeError) when the
        # cache still contains None entries; catch both so an incomplete
        # cache degrades gracefully instead of crashing.
        except (RuntimeError, TypeError):
            print("Not fully cached yet. Please run epoch with num_worker=0.")
32 70
class ToTensor(object):
    """
    Transform an article sample into a fixed-length LongTensor of token ids.

    A sample is (paragraph, topic) where paragraph is a sequence of string
    scalars (one per sentence) and topic is a string scalar; both expose
    ``.as_py()``. Each sentence is wrapped in BOS/EOS, the topic text is
    prepended to the first sentence, whole sentences are packed until the
    next one would overflow ``max_len``, and the result is right-padded with
    the padding token up to ``max_len``. (The training loop later maps
    padding positions to the ignored label value.)
    """

    def __init__(self, tokenizer, vocab, max_len=512):
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_len = max_len

    def __call__(self, sample):
        paragraph, topic = sample[0], sample[1]
        # Hoist special-token lookups out of the sentence loop.
        bos = self.vocab[self.vocab.bos_token]
        eos = self.vocab[self.vocab.eos_token]
        tokens = []
        for i, sentence in enumerate(paragraph):
            if i == 0:
                # The topic string is prepended to the first sentence only.
                pieces = self.tokenizer(topic.as_py()) + self.tokenizer(sentence.as_py())
            else:
                pieces = self.tokenizer(sentence.as_py())
            line = [bos] + self.vocab[pieces] + [eos]
            if len(tokens) + len(line) > self.max_len:
                break  # never emit a sentence fragment
            tokens += line
        pad = self.vocab[self.vocab.padding_token]
        tokens.extend([pad] * (self.max_len - len(tokens)))
        return torch.tensor(tokens, dtype=torch.long)
...\ No newline at end of file ...\ No newline at end of file
......