Showing 9 changed files with 929 additions and 8 deletions
code/cross_test.py
0 → 100644
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import glob
+import time
+import subprocess
+
+import gluonnlp as nlp
+import torch
+from torch.utils.data import DataLoader, Dataset
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from kogpt2.utils import get_tokenizer
+from tqdm import tqdm
+from util.data_loader import ArticleDataset, ToTensor
+
+def get_gpu_memory_map():
+    """Get the current GPU memory usage.
+    Returns
+    -------
+    usage: dict
+        Keys are device ids as integers.
+        Values are memory usage as integers in MB.
+    """
+    result = subprocess.check_output(
+        [
+            'nvidia-smi', '--query-gpu=memory.used',
+            '--format=csv,nounits,noheader'
+        ], encoding='utf-8')
+    # Convert lines into a dictionary keyed by device id
+    gpu_memory = [int(x) for x in result.strip().split('\n')]
+    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
+    return gpu_memory_map
+
+
+if __name__ == "__main__":
+    parser=argparse.ArgumentParser(description='Cross-test KoGPT2 checkpoints against per-topic test sets of ArticleDataset.')
+    parser.add_argument('--docker', action='store_true', help="Run on docker. Sets model cache path: /code/model, dataset path: /dataset, save path: /code/save.")
+    parser.add_argument('--default', action='store_true', help="Use the un-tuned KoGPT2 model.")
+    parser.add_argument('--model_topic', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
+    parser.add_argument('--epoch', type=int)
+    parser.add_argument('--topic', nargs='+', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
+    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
+    args = parser.parse_args()
+    print(args)
+
+    model_cache_path='/code/model' if args.docker else 'model'
+    dataset_path='/dataset' if args.docker else '../dataset'
+    save_path='/code/save' if args.docker else 'save'
+
+    ctx=args.device if torch.cuda.is_available() else 'cpu'
+    print(ctx)
+    device=torch.device(ctx)
+    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
+    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
+    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
+    num_workers=32
+    batch_size=64
+    padding_id=vocab[vocab.padding_token]
+    topics=set(args.topic)
+    transform=ToTensor(tokenizer, vocab, 128)
+    print("Preparing dataloader...")
+    dataloaders={}
+    dataloaders["all"]=DataLoader(ArticleDataset(dataset_path, label='test', transform=transform), batch_size=batch_size, num_workers=0)
+    for topic in tqdm(topics):
+        dataloaders[topic]=DataLoader(ArticleDataset(dataset_path, topics={topic}, label='test', transform=transform), batch_size=batch_size, num_workers=0)
+    print("Prepared dataloader.")
+    epoches=30
+    checkpoint_epoch=0
+    learning_rate = 3e-5
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+    topic_all=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학']
+    model_topic=topic_all if args.model_topic is None else sorted(list({args.model_topic}))
+    model_epoch='*' if args.epoch is None else args.epoch
+    dev=ctx if ctx in {'cpu', 'cuda'} else 'cuda:*'
+    # literal str() of the full topic set as it appears in checkpoint filenames written by train_test.py
+    braced="{'생활', '경제', 'IT_과학', '미용_건강', '스포츠', '사회', '연예', '문화', '정치'}" if args.model_topic is None else '{'+str(model_topic)[1:-1]+'}'
+    saves=glob.glob(f'{save_path}/KoGPT2_checkpoint_{dev}_{braced}_{transform.max_len}_{model_epoch}.state')
+    if not args.default:
+        if len(saves)>0:
+            last_save=max(saves, key=os.path.getmtime)
+            checkpoint = torch.load(last_save, map_location=device)
+            print(f"Loading save from {last_save}")
+            model.load_state_dict(checkpoint['model_state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            checkpoint_epoch = checkpoint['epoch']
+            last_test_loss = checkpoint['loss']
+        else:
+            print("No save exists.")
+            raise FileNotFoundError(f'{save_path}/KoGPT2_checkpoint_{ctx}_{model_topic}_{transform.max_len}_{model_epoch}.state')
+    model.to(device)
+    model.eval()
+
+    cached_testset_path=f"{save_path}/test_{topic_all}_{transform.max_len}"
+    if os.path.isfile(cached_testset_path+'.npy'):
+        dataloaders["all"].dataset.load_from_file(cached_testset_path+'.npy')
+    else:
+        print("Caching testset... topic: all")
+        for temp in tqdm(dataloaders["all"]):
+            pass
+        dataloaders["all"].dataset.set_use_cache(True, cached_testset_path)
+        print("Cached. topic: all")
+    dataloaders["all"].num_workers=num_workers
+    for topic in tqdm(topics):
+        cached_testset_path=f"{save_path}/test_{topic}_{transform.max_len}"
+        if os.path.isfile(cached_testset_path+'.npy'):
+            dataloaders[topic].dataset.load_from_file(cached_testset_path+'.npy')
+        else:
+            print(f"Caching testset... topic: {topic}")
+            for temp in tqdm(dataloaders[topic]):
+                pass
+            dataloaders[topic].dataset.set_use_cache(True, cached_testset_path)
+            print(f"Cached. topic: {topic}")
+        dataloaders[topic].num_workers=num_workers
+
+    states=[]
+
+    with torch.no_grad():  # evaluation only; do not track gradients
+        for topic in tqdm(dataloaders):
+            try:
+                test_loss_list=[]
+                for data in tqdm(dataloaders[topic]):
+                    data = data.to(ctx)
+                    # relabel padding as -100 so the LM loss ignores it
+                    label = torch.where(data!=padding_id, data, torch.ones_like(data)*-100)
+                    mask = torch.where(data!=padding_id, torch.ones_like(data), torch.zeros_like(data))
+                    output=model(data, labels=label, attention_mask=mask)
+                    loss=output[0]
+                    test_loss_list.append(loss.item())
+                    del label, mask, loss, output, data
+                test_loss=sum(test_loss_list)/len(test_loss_list)
+                print(f"data_topic: {topic}, model_topic: {model_topic} test loss: {test_loss}")
+                states.append((topic, model_topic, test_loss))
+            except KeyboardInterrupt:
+                break
+    log_path=f"{save_path}/test_{'DEFAULT' if args.default else model_topic}_{topics}_{transform.max_len}_{int(time.time())}.log"
+    with open(log_path, 'w') as log:
+        log.write("data_topic, model_topic, test loss,\n")
+        for state in states:
+            log.write(f"{state[0]}, {state[1]}, {state[2]},\n")
+    print(f"Log written at: {log_path}")
\ No newline at end of file
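cross_test.py relies on the Hugging Face convention that label positions set to -100 are ignored by GPT2LMHeadModel's language-modeling loss. A minimal sketch of the masking idiom used in its test loop, on a toy batch (the padding id 3 is illustrative; the script reads the real value from vocab[vocab.padding_token]):

    import torch

    padding_id = 3  # illustrative; the script uses vocab[vocab.padding_token]
    data = torch.tensor([[5, 8, 9, 3, 3],
                         [7, 6, 3, 3, 3]])
    # relabel padding positions as -100 so the LM loss skips them
    label = torch.where(data != padding_id, data, torch.full_like(data, -100))
    # attention mask: 1 for real tokens, 0 for padding
    mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
    print(label)  # tensor([[5, 8, 9, -100, -100], [7, 6, -100, -100, -100]])
    print(mask)   # tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])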
code/example.py
0 → 100644
+import torch
+from random import choice, choices, randint
+import argparse
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.utils import get_tokenizer
+
+def top_k(predict, vocab, k):
+    # Return a token chosen uniformly at random from the top-k candidates.
+    probs, indices = torch.topk(predict, k=k, dim=-1)
+    return vocab.to_tokens(choice(indices.tolist()))
+
+def top_p(logits, vocab, threshold=0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    indexs = sorted_indices.tolist()
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cum_prob = 0
+    top_p_index = 0
+    # Find the last index whose cumulative probability stays within the top-p mass
+    for i, prob in enumerate(sorted_softmax_logits):
+        if cum_prob>threshold:
+            top_p_index = 0 if i==0 else i-1
+            break
+        cum_prob+=prob
+    rand_num = randint(0, top_p_index)  # uniform random sampling within the top-p pool
+    return vocab.to_tokens(indexs[rand_num])
+
+def weighted_random(logits, vocab):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    indexs = sorted_indices.tolist()
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    return vocab.to_tokens(choices(indexs, weights=sorted_softmax_logits)[0])
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='KoGPT2 generation example')
+    parser.add_argument('sentence', metavar='S', type=str, nargs='?', default='2019년 한해를 보내며,',
+                        help='Korean sentence to use as input.')
+    args = parser.parse_args()
+
+    ctx='cuda' if torch.cuda.is_available() else 'cpu'
+    tok_path = get_tokenizer(cachedir='/code/model')
+    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
+    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
+
+    # Each strategy regenerates up to 100 tokens from the same prompt.
+    samplers = [
+        ('Greedy:', lambda pred: vocab.to_tokens(torch.argmax(pred, axis=-1).squeeze().tolist())[-1]),
+        ('Top 3:', lambda pred: top_k(pred.squeeze()[-1], vocab, 3)),
+        ('Top 5:', lambda pred: top_k(pred.squeeze()[-1], vocab, 5)),
+        ('Top p=0.5:', lambda pred: top_p(pred.squeeze()[-1], vocab, 0.5)),
+        ('Top p=0.7:', lambda pred: top_p(pred.squeeze()[-1], vocab, 0.7)),
+        ('Top p=0.9:', lambda pred: top_p(pred.squeeze()[-1], vocab)),
+        ('Weighted random:', lambda pred: weighted_random(pred.squeeze()[-1], vocab)),
+    ]
+    for label, sample in samplers:
+        sent = args.sentence
+        toked = tok(sent)
+        token_count = 0
+        while token_count < 100:
+            try:
+                input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(ctx)
+                pred = model(input_ids)[0]
+                gen = sample(pred)
+                if gen == '</s>':
+                    break
+                sent += gen.replace('▁', ' ')
+                toked = tok(sent)
+                token_count += 1
+            except KeyboardInterrupt:
+                break
+        print(label, sent)
\ No newline at end of file
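For intuition on the top_p routine above, here is its cumulative-probability cutoff on a small logit vector (values illustrative):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
    print(probs)                        # roughly tensor([0.61, 0.22, 0.14, 0.03])
    print(torch.cumsum(probs, dim=-1))  # roughly tensor([0.61, 0.83, 0.97, 1.00])
    # with threshold=0.9 the loop in top_p stops once the cumulative mass passes 0.9,
    # so the first three sorted tokens form the pool it samples from uniformly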
code/ligtning_train_test.py
0 → 100644
+import argparse
+import tqdm
+import gluonnlp as nlp
+import pandas as pd
+import torch
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from kogpt2.utils import get_tokenizer
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.core.lightning import LightningModule
+from torch.utils.data import DataLoader, Dataset
+from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
+from util.data_loader import ArticleDataset, ToTensor
+
+class KoGPT2Chat(LightningModule):
+    def __init__(self, hparams, **kwargs):
+        super(KoGPT2Chat, self).__init__()
+        self.hparams = hparams
+        self.tok_path = get_tokenizer()
+        self.neg = -1e18
+        self.model, self.vocab = get_pytorch_kogpt2_model()
+        self.loss_function = torch.nn.CrossEntropyLoss(reduction='none')
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        # add model specific args
+        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--max-len',
+                            type=int,
+                            default=32,
+                            help='max sentence length on input (default: 32)')
+        parser.add_argument('--batch-size',
+                            type=int,
+                            default=96,
+                            help='batch size for training (default: 96)')
+        parser.add_argument('--lr',
+                            type=float,
+                            default=5e-5,
+                            help='The initial learning rate')
+        parser.add_argument('--warmup_ratio',
+                            type=float,
+                            default=0.1,
+                            help='warmup ratio')
+        return parser
+
+    def forward(self, inputs):
+        # (batch, seq_len, hiddens)
+        output, _ = self.model(inputs)  # the model is stored as self.model in __init__
+        return output
+
+    def training_step(self, batch, batch_idx):
+        token_ids, mask, label = batch
+        out = self(token_ids)
+        # broadcast the token mask over the vocab dimension, then blank out padded logits
+        mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
+        mask_out = torch.where(mask_3d == 1, out, self.neg * torch.ones_like(out))
+        loss = self.loss_function(mask_out.transpose(2, 1), label)
+        loss_avg = loss.sum() / mask.sum()
+        tensorboard_logs = {'train_loss': loss_avg}
+        return {'loss': loss_avg, 'log': tensorboard_logs}
+
+    def configure_optimizers(self):
+        # Prepare optimizer with decoupled weight decay, excluding bias and LayerNorm
+        param_optimizer = list(self.named_parameters())
+        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+        optimizer_grouped_parameters = [
+            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
+            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+        ]
+        optimizer = AdamW(optimizer_grouped_parameters,
+                          lr=self.hparams.lr, correct_bias=False)
+        # warm up lr
+        num_train_steps = len(self.train_dataloader()) * self.hparams.max_epochs
+        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
+        scheduler = get_cosine_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
+        lr_scheduler = {'scheduler': scheduler, 'name': 'cosine_schedule_with_warmup',
+                        'monitor': 'loss', 'interval': 'step',
+                        'frequency': 1}
+        return [optimizer], [lr_scheduler]
+
+    def _collate_fn(self, batch):
+        data = [item[0] for item in batch]
+        mask = [item[1] for item in batch]
+        label = [item[2] for item in batch]
+        return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)
+
+    def train_dataloader(self):
+        data = pd.read_csv('Chatbot_data/ChatbotData.csv')
+        self.train_set = ArticleDataset(data, self.tok_path, self.vocab, max_len=self.hparams.max_len)
+        train_dataloader = DataLoader(
+            self.train_set, batch_size=self.hparams.batch_size, num_workers=2,
+            shuffle=True, collate_fn=self._collate_fn)
+        return train_dataloader
+
+# base parser; the --train/--chat/--model_params flags are required by the usage below
+parser = argparse.ArgumentParser(description='KoGPT2 fine-tuning with PyTorch Lightning')
+parser.add_argument('--train', action='store_true', help='run training')
+parser.add_argument('--chat', action='store_true', help='load a checkpoint for generation')
+parser.add_argument('--model_params', type=str, help='checkpoint path used with --chat')
+parser = KoGPT2Chat.add_model_specific_args(parser)
+parser = Trainer.add_argparse_args(parser)
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    if args.train:
+        checkpoint_callback = ModelCheckpoint(
+            filepath='model_chp/{epoch:02d}-{loss:.2f}',
+            verbose=True,
+            save_last=True,
+            monitor='loss',
+            mode='min',
+            prefix='model_'
+        )
+        # python train_torch.py --train --gpus 1 --max_epochs 3
+        model = KoGPT2Chat(args)
+        model.train()
+        trainer = Trainer.from_argparse_args(
+            args,
+            checkpoint_callback=checkpoint_callback, gradient_clip_val=1.0)
+        trainer.fit(model)
+    if args.chat:
+        model = KoGPT2Chat.load_from_checkpoint(args.model_params)
+        model.eval()
\ No newline at end of file
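training_step above masks padded positions by overwriting their logits with a large negative constant (self.neg) before the unreduced CrossEntropyLoss, then divides the summed per-token loss by mask.sum(). A toy-shaped sketch of the same normalization done directly on the per-token loss (shapes and values illustrative; masking the loss rather than the logits also avoids the constant log-vocab term that fully-masked positions appear to contribute in the logit-overwrite variant):

    import torch

    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    out = torch.randn(2, 4, 10)           # (batch, seq_len, vocab)
    label = torch.randint(0, 10, (2, 4))  # (batch, seq_len)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])   # 1 = real token, 0 = padding
    loss = loss_fn(out.transpose(2, 1), label)   # per-token loss, shape (2, 4)
    loss_avg = (loss * mask).sum() / mask.sum()  # average over real tokens only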
code/paragraph_gen.py
0 → 100644
+from random import choice, choices, randint
+import argparse
+import re
+import time
+import torch
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.utils import get_tokenizer
+
+def greedy(predict):
+    return torch.argmax(predict, axis=-1).tolist()
+
+def top_k(predict, k):
+    # Return a token id chosen uniformly at random from the top-k candidates.
+    probs, indices = torch.topk(predict, k=k, dim=-1)
+    return choice(indices.tolist())
+
+def top_p(logits, threshold=0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    indices = sorted_indices.tolist()
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cum_prob = 0
+    top_p_index = 0
+    # Find the last index whose cumulative probability stays within the top-p mass
+    for i, prob in enumerate(sorted_softmax_logits):
+        if cum_prob>threshold:
+            top_p_index = 0 if i==0 else i-1
+            break
+        cum_prob+=prob
+    rand_num = randint(0, top_p_index)  # uniform random sampling within the top-p pool
+    return indices[rand_num]
+
+def weighted_random(logits):
+    indices=torch.where(logits>=0)[0]  # negative logits are ignored
+    selected_logits=torch.index_select(logits,-1,indices)
+    softmax_logits = torch.nn.functional.softmax(selected_logits, dim=-1)
+    return choices(indices.tolist(),weights=softmax_logits)[0]
+
+def weighted_top_k(predict, k):
+    probs, indices = torch.topk(predict, k=k,dim=-1)
+    softmax_probs = torch.nn.functional.softmax(probs, dim=-1)
+    return choices(indices.tolist(),weights=softmax_probs)[0]
+
+def weighted_top_p(logits, threshold=0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cum_prob = 0
+    last_cum_prob=0
+    top_p_bound = 0
+    # Find the index bound of the top-p pool (assumes threshold < 1 so the break always fires)
+    for i, prob in enumerate(sorted_softmax_logits):
+        if cum_prob>threshold:
+            top_p_bound = i
+            break
+        last_cum_prob=cum_prob
+        cum_prob+=prob
+    return choices(sorted_indices[:top_p_bound].tolist(),weights=sorted_softmax_logits[:top_p_bound]/last_cum_prob)[0]
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='KoGPT2 paragraph generation example')
+    group=parser.add_mutually_exclusive_group()
+    group.add_argument('-g','--greedy',action='store_const',const='greedy',help='Greedy sampling')
+    group.add_argument('-k','--topk',type=int, choices=range(1,51), help='Top-k sampling. 1<=K<=50', metavar='K')
+    group.add_argument('-p','--topp',type=float, help='Top-p sampling. 0<P<=1.0', metavar='P')
+    parser.add_argument('-w','--weighted',action='store_true', help='Use the weighted version of the chosen sampling.')
+    parser.add_argument('-d','--docker', action='store_true', help="Run on docker. Sets model cache path: /code/model, save path: /code/save.")
+    parser.add_argument('-c','--checkpoint',type=str, help='Model checkpoint path',metavar='PATH')
+    parser.add_argument('-f','--full_sentence', action='store_true', help='Treat the last S as a full sentence (do not append generation to it).')
+    parser.add_argument('-l','--length', type=int, choices=range(1,21), help='Number of sentences in the generated paragraph.', metavar='LENGTH', default=15)
+    parser.add_argument('sentence', metavar='S', type=str, nargs='*',
+                        help='Korean sentence(s) to use as input.')
+    args = parser.parse_args()
+    print(args)
+    model_cache_path='/code/model' if args.docker else 'model'
+    save_path='/code/save' if args.docker else 'save'
+
+    if args.greedy:
+        sampling_name = "Weighted" if args.weighted else "Greedy"
+        sampling=weighted_random if args.weighted else greedy
+    elif args.topk is not None:
+        sampling_name=f"Weighted Top k={args.topk}" if args.weighted else f"Top k={args.topk}"
+        sampling=(lambda pred: weighted_top_k(pred,args.topk)) if args.weighted else (lambda pred: top_k(pred,args.topk))
+    elif args.topp is not None:
+        sampling_name=f"Weighted Top p={args.topp}" if args.weighted else f"Top p={args.topp}"
+        sampling=(lambda pred: weighted_top_p(pred,args.topp)) if args.weighted else (lambda pred: top_p(pred,args.topp))
+    else:  # default: weighted random over all non-negative logits
+        sampling_name="Weighted"
+        sampling=weighted_random
+
+    ctx='cuda:0' if torch.cuda.is_available() else 'cpu'
+    device=torch.device(ctx)
+    tok_path = get_tokenizer(cachedir=model_cache_path)
+    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
+    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
+    if args.checkpoint:
+        checkpoint = torch.load(args.checkpoint, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        epoch = checkpoint['epoch']
+    model.eval()
+
+    toked=[]
+    for sent in args.sentence:
+        toked += (tok(sent)+[vocab.eos_token,vocab.bos_token])
+    else:  # for-else: runs once the loop over input sentences completes
+        if not args.full_sentence:
+            toked=toked[:-2]  # reopen the last sentence so generation continues it
+    token_count=0
+    sent_count=0
+    started=time.time()
+    while token_count<1000:
+        try:
+            input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(device=device)
+            pred = model(input_ids)[0]
+            gen_id = sampling(pred.squeeze()[-1])
+            gen_token=vocab.to_tokens(gen_id)
+            if gen_token == vocab.eos_token:
+                sent_count+=1
+                print(sent_count, token_count)
+                if sent_count>=args.length:
+                    break
+                else:
+                    toked+=[vocab.eos_token,vocab.bos_token]
+                    token_count+=2
+            else:
+                toked.append(gen_token)
+                token_count+=1
+        except KeyboardInterrupt:
+            break
+    print(f'{sampling_name}:', re.sub('</s>', '\r\n', re.sub('(▁|<s>)', ' ', ''.join(toked))))
+    print("Time elapsed:", time.time()-started)
+
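weighted_random above draws the next token id in proportion to its softmax probability via random.choices. A self-contained illustration with toy logits (note the function drops negative logits entirely before normalizing):

    import torch
    from random import choices

    logits = torch.tensor([2.0, 1.0, -0.5, 0.3])
    indices = torch.where(logits >= 0)[0]              # negative logits dropped, as in weighted_random
    selected = torch.index_select(logits, -1, indices)
    weights = torch.nn.functional.softmax(selected, dim=-1)
    token_id = choices(indices.tolist(), weights=weights.tolist())[0]
    print(token_id)  # one of 0, 1, 3, drawn in proportion to its softmax weight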
code/paragraph_len.py
0 → 100644
+from random import choice, choices, randint
+import argparse
+import torch
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from kogpt2.utils import get_tokenizer
+from tqdm import trange
+
+def greedy(predict):
+    return torch.argmax(predict, axis=-1).tolist()
+
+def top_k(predict, k):
+    # Return a token id chosen uniformly at random from the top-k candidates.
+    probs, indices = torch.topk(predict, k=k, dim=-1)
+    return choice(indices.tolist())
+
+def top_p(logits, threshold=0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    indices = sorted_indices.tolist()
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cum_prob = 0
+    top_p_index = 0
+    # Find the last index whose cumulative probability stays within the top-p mass
+    for i, prob in enumerate(sorted_softmax_logits):
+        if cum_prob>threshold:
+            top_p_index = 0 if i==0 else i-1
+            break
+        cum_prob+=prob
+    rand_num = randint(0, top_p_index)  # uniform random sampling within the top-p pool
+    return indices[rand_num]
+
+def weighted_random(logits):
+    indices=torch.where(logits>=0)[0]  # negative logits are ignored
+    selected_logits=torch.index_select(logits,-1,indices)
+    softmax_logits = torch.nn.functional.softmax(selected_logits, dim=-1)
+    return choices(indices.tolist(),weights=softmax_logits)[0]
+
+def weighted_top_k(predict, k):
+    probs, indices = torch.topk(predict, k=k,dim=-1)
+    softmax_probs = torch.nn.functional.softmax(probs, dim=-1)
+    return choices(indices.tolist(),weights=softmax_probs)[0]
+
+def weighted_top_p(logits, threshold=0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    sorted_softmax_logits = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cum_prob = 0
+    last_cum_prob=0
+    top_p_bound = 0
+    # Find the index bound of the top-p pool (assumes threshold < 1 so the break always fires)
+    for i, prob in enumerate(sorted_softmax_logits):
+        if cum_prob>threshold:
+            top_p_bound = i
+            break
+        last_cum_prob=cum_prob
+        cum_prob+=prob
+    return choices(sorted_indices[:top_p_bound].tolist(),weights=sorted_softmax_logits[:top_p_bound]/last_cum_prob)[0]
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Measure generated paragraph lengths with KoGPT2')
+    group=parser.add_mutually_exclusive_group()
+    group.add_argument('-g','--greedy',action='store_const',const='greedy',help='Greedy sampling')
+    group.add_argument('-k','--topk',type=int, choices=range(1,51), help='Top-k sampling. 1<=K<=50', metavar='K')
+    group.add_argument('-p','--topp',type=float, help='Top-p sampling. 0<P<=1.0', metavar='P')
+    parser.add_argument('-w','--weighted',action='store_true', help='Use the weighted version of the chosen sampling.')
+    parser.add_argument('-d','--docker', action='store_true', help="Run on docker. Sets model cache path: /code/model, save path: /code/save.")
+    parser.add_argument('-c','--checkpoint',type=str, help='Model checkpoint path',metavar='PATH')
+    parser.add_argument('-f','--full_sentence', action='store_true', help='Treat the last S as a full sentence (do not append generation to it).')
+    parser.add_argument('-l','--length', type=int, choices=range(1,21), help='Number of sentences in the generated paragraph.', metavar='LENGTH', default=15)
+    parser.add_argument('sentence', metavar='S', type=str, nargs='*',
+                        help='Korean sentence(s) to use as input.')
+    args = parser.parse_args()
+    print(args)
+    model_cache_path='/code/model' if args.docker else 'model'
+    save_path='/code/save' if args.docker else 'save'
+
+    if args.greedy:
+        sampling_name = "Weighted" if args.weighted else "Greedy"
+        sampling=weighted_random if args.weighted else greedy
+    elif args.topk is not None:
+        sampling_name=f"Weighted Top k={args.topk}" if args.weighted else f"Top k={args.topk}"
+        sampling=(lambda pred: weighted_top_k(pred,args.topk)) if args.weighted else (lambda pred: top_k(pred,args.topk))
+    elif args.topp is not None:
+        sampling_name=f"Weighted Top p={args.topp}" if args.weighted else f"Top p={args.topp}"
+        sampling=(lambda pred: weighted_top_p(pred,args.topp)) if args.weighted else (lambda pred: top_p(pred,args.topp))
+    else:  # default: weighted random over all non-negative logits
+        sampling_name="Weighted"
+        sampling=weighted_random
+
+    ctx='cuda:0' if torch.cuda.is_available() else 'cpu'
+    device=torch.device(ctx)
+    tok_path = get_tokenizer(cachedir=model_cache_path)
+    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
+    tok = SentencepieceTokenizer(tok_path, num_best=0, alpha=0)
+    if args.checkpoint:
+        checkpoint = torch.load(args.checkpoint, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        epoch = checkpoint['epoch']
+    model.eval()
+    length_list=[]
+    for i in trange(20):
+        toked=[]
+        for sent in args.sentence:
+            toked += (tok(sent)+[vocab.eos_token,vocab.bos_token])
+        else:  # for-else: runs once the loop over input sentences completes
+            if not args.full_sentence:
+                toked=toked[:-2]
+        token_count=0
+        sent_count=0
+        while token_count<1000:
+            try:
+                input_ids = torch.tensor([vocab[vocab.bos_token],] + vocab[toked]).unsqueeze(0).to(device=device)
+                pred = model(input_ids)[0]
+                gen_id = sampling(pred.squeeze()[-1])
+                gen_token=vocab.to_tokens(gen_id)
+                if gen_token == vocab.eos_token:
+                    sent_count+=1
+                    length_list.append(f"{i},{sent_count},{token_count}\n")
+                    if sent_count>=args.length:
+                        break
+                    else:
+                        toked+=[vocab.eos_token,vocab.bos_token]
+                        token_count+=2
+                else:
+                    toked.append(gen_token)
+                    token_count+=1
+            except KeyboardInterrupt:
+                break
+    with open('length.log','a') as log:
+        log.write(f'#-*- {args.checkpoint} -*-\n')
+        log.writelines(length_list)
+        log.write('#-*- -*-\n')
\ No newline at end of file
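Each run of paragraph_len.py appends run,sentence_count,token_count rows to length.log between #-*- sentinel lines. A small sketch of reading the final row of each run back to estimate average paragraph length in tokens (the parsing logic is an assumption based on the write format above):

    finals = {}
    with open('length.log') as log:
        for line in log:
            if line.startswith('#'):  # skip the '#-*- ... -*-' sentinel lines
                continue
            run, sents, toks = (int(x) for x in line.split(','))
            finals[run] = toks        # the last row per run holds the final token count
    if finals:
        print(sum(finals.values()) / len(finals))  # mean tokens per generated paragraph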
code/tokenize_stastics.py
0 → 100644
+import torch
+from tqdm import tqdm
+from util.data_loader import ArticleDataset, ToTensor
+from torch.utils.data import DataLoader
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from kogpt2.utils import get_tokenizer
+
+max_len=1024
+
+ctx='cuda' if torch.cuda.is_available() else 'cpu'
+device=torch.device(ctx)
+tokenizer_path = get_tokenizer(cachedir='/code/model')
+model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
+tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
+transform=ToTensor(tokenizer, vocab, max_len=max_len)
+batch_size=64
+trainset=DataLoader(ArticleDataset('/dataset', label='train', transform=transform, use_cache=False), batch_size=batch_size, num_workers=32, shuffle=True)
+count_dict=dict((idx, 0) for idx in range(256, max_len, 256))
+for (data, original_len) in tqdm(trainset):
+    original_len = original_len.to(device)  # .to() returns a copy; the result must be reassigned
+    for bound in count_dict:
+        count_dict[bound] += torch.sum(torch.where(original_len<=bound, torch.ones_like(original_len), torch.zeros_like(original_len))).item()
+for bound in count_dict:
+    print(f"count[{bound}]: {count_dict[bound]}/{len(trainset.dataset)} ({100*count_dict[bound]/len(trainset.dataset):.1f}%)")
\ No newline at end of file
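The torch.where/ones_like form above is just a count of elements at or below each bound; the same tally on a toy length tensor (values illustrative; note the loop unpacks (data, original_len) pairs, which assumes a transform variant that also returns the pre-truncation token length):

    import torch

    original_len = torch.tensor([100, 300, 700, 1200])
    for bound in (256, 512, 768):
        n = torch.sum(original_len <= bound).item()  # equivalent to the torch.where/ones_like expression
        print(bound, n)  # prints: 256 1, 512 2, 768 3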
code/train_test.py
0 → 100644
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import glob
+import time
+
+import gluonnlp as nlp
+import torch
+from torch.nn import DataParallel
+from torch.utils.data import DataLoader, Dataset
+from gluonnlp.data import SentencepieceTokenizer
+from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
+from kogpt2.utils import get_tokenizer
+from tqdm import tqdm
+from util.data_loader import ArticleDataset, ToTensor
+
+if __name__ == "__main__":
+    parser=argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
+    parser.add_argument('--docker', action='store_true', help="Train on docker. Sets model cache path: /code/model, dataset path: /dataset, save path: /code/save.")
+    parser.add_argument('--resume', choices=['default', 'cpu', 'cuda', 'cuda:0', 'cuda:1'], nargs='?', const='default', help="Load the state file saved for the given device, then resume training.")
+    parser.add_argument('--topic', nargs='+', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
+    parser.add_argument('--length', type=int, default=128, choices=[2**i for i in range(11)], help="Token length for the ToTensor transform")
+    parser.add_argument('--epoch', type=int, default=30, help="Number of training epochs")
+    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
+    args = parser.parse_args()
+    print(args)
+    model_cache_path='/code/model' if args.docker else 'model'
+    dataset_path='/dataset' if args.docker else '../dataset'
+    save_path='/code/save' if args.docker else 'save'
+
+    ctx=args.device if torch.cuda.is_available() else 'cpu'
+    print(ctx)
+    device=torch.device(ctx)
+    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
+    model, vocab = get_pytorch_kogpt2_model(ctx=ctx,cachedir=model_cache_path)
+    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
+    # scale workers and batch size down as sequence length grows
+    num_workers=int(32*(128/args.length)) if args.length<1024 else 4
+    batch_size=int(64*(128/args.length)) if args.length<1024 else 4
+    padding_id=vocab[vocab.padding_token]
+
+    topics=set(args.topic)
+    transform=ToTensor(tokenizer,vocab,args.length)
+    print("Preparing dataloader...")
+    trainset=DataLoader(ArticleDataset(dataset_path, topics=topics,label='train', transform=transform),batch_size=batch_size, num_workers=0,shuffle=True)
+    validset=DataLoader(ArticleDataset(dataset_path,topics=topics,label='valid', transform=transform),batch_size=batch_size, num_workers=0)
+    #testset=DataLoader(ArticleDataset(dataset_path,label='test', transform=transform),batch_size=128, num_workers=4)
+    print("Prepared dataloader.")
+    epoches=args.epoch
+    checkpoint_epoch=0
+    last_valid_loss=float('infinity')  # default when not resuming
+    learning_rate = 3e-5
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+    if args.resume:
+        save_ctx=ctx if args.resume=="default" else args.resume
+        saves=glob.glob(f'{save_path}/KoGPT2_checkpoint_{save_ctx}_{topics}_{transform.max_len}_*.state')
+        if len(saves)>0:
+            last_save=max(saves,key=os.path.getmtime)
+            checkpoint = torch.load(last_save, map_location=device)
+            print(f"Loading save from {last_save}")
+            model.load_state_dict(checkpoint['model_state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            checkpoint_epoch = checkpoint['epoch']
+            last_valid_loss = checkpoint['loss']
+            print("Loaded.")
+        else:
+            print("No save exists.")
+
+    model.to(device)
+    model.train()
+
+    cached_trainset_path=f"{save_path}/train_{topics}_{transform.max_len}"
+    cached_validset_path=f"{save_path}/valid_{topics}_{transform.max_len}"
+    if os.path.isfile(cached_trainset_path+'.npy'):
+        trainset.dataset.load_from_file(cached_trainset_path+'.npy')
+    else:
+        print("Caching trainset...")
+        for temp in tqdm(trainset):
+            pass
+        trainset.dataset.set_use_cache(True, cached_trainset_path)
+    if os.path.isfile(cached_validset_path+'.npy'):
+        validset.dataset.load_from_file(cached_validset_path+'.npy')
+    else:
+        print("Caching validset...")
+        for temp in tqdm(validset):
+            pass
+        validset.dataset.set_use_cache(True, cached_validset_path)
+    print("Cached.")
+
+    trainset.num_workers=num_workers
+    validset.num_workers=num_workers
+
+    overfit=-1
+    states=[]
+
+    for epoch in tqdm(range(checkpoint_epoch+1,epoches)):
+        try:
+            train_loss_list=[]
+            valid_loss_list=[]
+            for data in tqdm(trainset):
+                optimizer.zero_grad()
+                data = data.to(ctx)
+                # relabel padding as -100 so the LM loss ignores it
+                label = torch.where(data!=padding_id, data, torch.ones_like(data)*-100)
+                mask = torch.where(data!=padding_id,torch.ones_like(data),torch.zeros_like(data))
+                output=model(data, labels=label, attention_mask=mask)
+                loss=output[0]
+                loss.backward()
+                optimizer.step()
+                train_loss_list.append(loss.item())
+                del loss, output, label, mask, data
+            model.eval()  # disable dropout for validation
+            with torch.no_grad():
+                for v_data in tqdm(validset):
+                    v_data = v_data.to(ctx)
+                    v_label = torch.where(v_data!=padding_id, v_data, torch.ones_like(v_data)*-100)
+                    v_mask = torch.where(v_data!=padding_id,torch.ones_like(v_data),torch.zeros_like(v_data))
+                    v_output=model(v_data,labels=v_label, attention_mask=v_mask)
+                    v_loss=v_output[0]
+                    valid_loss_list.append(v_loss.item())
+                    del v_loss, v_output, v_mask, v_label, v_data
+            model.train()
+            valid_loss=sum(valid_loss_list)/len(valid_loss_list)
+            train_loss=sum(train_loss_list)/len(train_loss_list)
+            print(f"epoch: {epoch} train loss: {train_loss} valid loss: {valid_loss}")
+            states.append((epoch,train_loss,valid_loss))
+            if valid_loss>last_valid_loss:
+                overfit=epoch  # validation loss rose: record the epoch as an overfitting signal
+            try:
+                torch.save({
+                    'epoch': epoch,
+                    'model_state_dict': model.state_dict(),
+                    'optimizer_state_dict': optimizer.state_dict(),
+                    'loss': train_loss
+                }, f"{save_path}/KoGPT2_checkpoint_{ctx}_{topics}_{transform.max_len}_{epoch}.state")
+            except Exception as e:
+                print(e)
+            last_valid_loss=valid_loss
+        except KeyboardInterrupt:
+            break
+    log_path=f"{save_path}/{topics}_{transform.max_len}_{int(time.time())}.log"
+    with open(log_path, 'w') as log:
+        log.write(f"Overfit at: {overfit}\n")
+        for state in states:
+            log.write(f"epoch: {state[0]} train loss: {state[1]} valid loss: {state[2]}\n")
+    print(f"Log written at: {log_path}")
+
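Since the losses train_test.py logs are mean token-level cross-entropies (in nats), perplexity follows directly by exponentiation; a one-line reminder with an illustrative value:

    import math

    valid_loss = 3.2             # illustrative value taken from a .log file
    print(math.exp(valid_loss))  # perplexity ≈ 24.5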
code/util/data_loader.py
 import os
+import numpy as np
 import pyarrow as pa
 import pyarrow.parquet as pq
 import torch
@@ -10,12 +11,15 @@ class ArticleDataset(Dataset):
     기사 학습을 위한 데이터셋
     dataset for learn articles
     """
-    def __init__(self, dataset_path:str, topics:list=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], label:str='train'):
+    def __init__(self, dataset_path:str, topics:set=set(['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학']), label:str='train',
+                 transform=None, use_cache=False):
         """
         Initializer
         :param dataset_path: path of parquet dataset
-        :param topic: if not None, only use specified topics; must be sublist of [경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학]
+        :param topics: if not None, only use the specified topics; must be a subset of {경제, 문화, 미용_건강, 사회, 생활, 스포츠, 연예, 정치, IT_과학}
         :param label: specify type of dataset; must be one of [train, test, valid] (default is train)
+        :param transform: if not None, transforms each item: (paragraph:StringScalar, topic:StringScalar) => Tensor
+        :param use_cache: if True, __getitem__ serves items from the cache; only valid after a first full epoch
         """
         expanded_dataset_path = os.path.expanduser(dataset_path)
         tables=[]
@@ -23,26 +27,68 @@ class ArticleDataset(Dataset):
             table=pq.read_table(f'{expanded_dataset_path}/topic={topic}/label={label}',columns=['paragraph'])
             tables.append(table.append_column('topic',pa.array([topic]*len(table))))
         self.data=pa.concat_tables(tables)
+        self.transform=transform
+        self.use_cache=use_cache
+        self.cache=[None]*len(self.data)
+        # eagerly transforming everything in __init__ proved too slow:
+        #    self.data=[self.transform((p,t)) for p, t in zip(self.data['paragraph'],self.data['topic'])]
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self,index):
-        return self.data['paragraph'][index], self.data['topic'][index]
+        if self.use_cache and self.cache[index] is not None:
+            return self.cache[index]
+        # cache miss: transform (if any) and memoize
+        item=(self.data['paragraph'][index], self.data['topic'][index]) if self.transform is None \
+            else self.transform((self.data['paragraph'][index], self.data['topic'][index]))
+        self.cache[index]=item
+        return item
+
+    def load_from_file(self, cache_file_path:str):
+        self.use_cache=True
+        self.cache=torch.from_numpy(np.load(cache_file_path))
+
+    def set_use_cache(self, use_cache:bool, cache_file_path:str=None):
+        self.use_cache=use_cache
+        if use_cache:
+            if isinstance(self.cache,torch.Tensor):
+                if cache_file_path is not None:
+                    np.save(cache_file_path,self.cache.numpy())
+                else:
+                    print("Already fully cached.")
+                return
+            try:
+                self.cache=torch.stack(self.cache)
+                if cache_file_path is not None:
+                    np.save(cache_file_path,self.cache.numpy())
+            except (RuntimeError, TypeError):  # stack fails while uncached entries are still None
+                print("Not fully cached yet. Please run an epoch with num_workers=0 first.")
+                return
+        else:
+            self.cache=[None]*len(self.data)  # reset; an empty list would break later indexing
 
 class ToTensor(object):
     """
     Convert Article dataset paragraph to Tensor using tokenizer
     """
-    def __init__(self, tokenizer, vocab):
+    def __init__(self, tokenizer, vocab, max_len=512):
         self.tokenizer=tokenizer
         self.vocab=vocab
+        self.max_len=max_len
 
     def __call__(self, sample):
         tokens=[]
-        for i, sentence in enumerate(sample[0]):
+        paragraph=sample[0]
+        topic=sample[1]
+        for i, sentence in enumerate(paragraph):
             if i==0:
-                tokens+=[self.vocab[self.vocab.bos_token]]+self.vocab[self.tokenizer(sample[1].as_py())+self.tokenizer(sentence.as_py())]+[self.vocab[self.vocab.eos_token]]
+                line=[self.vocab[self.vocab.bos_token]]+self.vocab[self.tokenizer(topic.as_py())+self.tokenizer(sentence.as_py())]+[self.vocab[self.vocab.eos_token]]
+            else:
+                line=[self.vocab[self.vocab.bos_token]]+self.vocab[self.tokenizer(sentence.as_py())]+[self.vocab[self.vocab.eos_token]]
+            if len(tokens)+len(line)<=self.max_len:  # keep whole sentences only; never a sentence fragment
+                tokens+=line
             else:
-                tokens+=[self.vocab[self.vocab.bos_token]]+self.vocab[self.tokenizer(sentence.as_py())]+[self.vocab[self.vocab.eos_token]]
-        return torch.Tensor(tokens)
+                break
+        tokens+=([self.vocab[self.vocab.padding_token]]*(self.max_len-len(tokens)))  # pad to max_len; training scripts relabel padding as -100 (see https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel)
+        return torch.tensor(tokens,dtype=torch.long)
\ No newline at end of file
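Pieced together from the methods above, the intended caching workflow is: run one warm-up epoch with num_workers=0 so every __getitem__ result lands in the in-process cache, persist it, then load it directly on later runs (the paths and max_len below are illustrative):

    from gluonnlp.data import SentencepieceTokenizer
    from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
    from kogpt2.utils import get_tokenizer
    from torch.utils.data import DataLoader
    from util.data_loader import ArticleDataset, ToTensor

    _, vocab = get_pytorch_kogpt2_model()
    tokenizer = SentencepieceTokenizer(get_tokenizer(), num_best=0, alpha=0)
    dataset = ArticleDataset('/dataset', label='valid', transform=ToTensor(tokenizer, vocab, 128))
    loader = DataLoader(dataset, batch_size=64, num_workers=0)  # 0 workers: the cache must fill in-process
    for _ in loader:                                            # warm-up pass populates dataset.cache
        pass
    dataset.set_use_cache(True, 'save/valid_128')               # stacks the cache and writes save/valid_128.npy
    # later runs can skip the warm-up pass entirely:
    dataset.load_from_file('save/valid_128.npy')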
report/캡스톤 디자인 2 주간보고서-4.docx
0 → 100644
No preview for this file type