ligtning_train_test.py
import argparse

import gluonnlp as nlp
import pandas as pd
import torch
import tqdm
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from torch.utils.data import DataLoader, Dataset
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup

# absolute import so the script can be executed directly (see __main__ below)
from utils.data_loader import ArticleDataset, ToTensor

class KoGPT2Chat(LightningModule):
    def __init__(self, hparams, **kwargs):
        super(KoGPT2Chat, self).__init__()
        self.hparams = hparams
        self.tok_path = get_tokenizer()
        # large negative value used to suppress logits at masked positions
        self.neg = -1e18
        self.model, self.vocab = get_pytorch_kogpt2_model()
        self.loss_function = torch.nn.CrossEntropyLoss(reduction='none')
    @staticmethod
    def add_model_specific_args(parent_parser):
        # add model specific args
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--max-len',
                            type=int,
                            default=32,
                            help='max sentence length on input (default: 32)')
        parser.add_argument('--batch-size',
                            type=int,
                            default=96,
                            help='batch size for training (default: 96)')
        parser.add_argument('--lr',
                            type=float,
                            default=5e-5,
                            help='the initial learning rate')
        parser.add_argument('--warmup_ratio',
                            type=float,
                            default=0.1,
                            help='warmup ratio')
        return parser
    def forward(self, inputs):
        # logits of shape (batch, seq_len, vocab_size)
        output, _ = self.model(inputs)
        return output
    def training_step(self, batch, batch_idx):
        token_ids, mask, label = batch
        out = self(token_ids)
        # expand the label mask to the vocab dimension and suppress logits
        # at positions that should not contribute to the loss
        mask_3d = mask.unsqueeze(dim=2).repeat_interleave(repeats=out.shape[2], dim=2)
        mask_out = torch.where(mask_3d == 1, out, self.neg * torch.ones_like(out))
        loss = self.loss_function(mask_out.transpose(2, 1), label)
        # average only over the unmasked (label) positions
        loss_avg = loss.sum() / mask.sum()
        tensorboard_logs = {'train_loss': loss_avg}
        return {'loss': loss_avg, 'log': tensorboard_logs}
    def configure_optimizers(self):
        # Prepare optimizer
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.hparams.lr, correct_bias=False)
        # warm up lr
        num_train_steps = len(self.train_dataloader()) * self.hparams.max_epochs
        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)
        lr_scheduler = {'scheduler': scheduler, 'name': 'cosine_schedule_with_warmup',
                        'monitor': 'loss', 'interval': 'step',
                        'frequency': 1}
        return [optimizer], [lr_scheduler]
    def _collate_fn(self, batch):
        data = [item[0] for item in batch]
        mask = [item[1] for item in batch]
        label = [item[2] for item in batch]
        return torch.LongTensor(data), torch.LongTensor(mask), torch.LongTensor(label)
    def train_dataloader(self):
        data = pd.read_csv('Chatbot_data/ChatbotData.csv')
        self.train_set = ArticleDataset(data, self.tok_path, self.vocab, max_len=self.hparams.max_len)
        train_dataloader = DataLoader(
            self.train_set, batch_size=self.hparams.batch_size, num_workers=2,
            shuffle=True, collate_fn=self._collate_fn)
        return train_dataloader
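    # NOTE (assumption, not stated in the original file): ArticleDataset from
    # utils/data_loader.py is expected to yield items of the form
    # (token_ids, mask, label), each a sequence of length max_len, where `mask`
    # is 1 on the positions that should contribute to the loss and 0 elsewhere,
    # matching how _collate_fn and training_step consume a batch.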

parser = argparse.ArgumentParser(description='KoGPT-2 chatbot training/test script')
parser.add_argument('--train', action='store_true', default=False, help='train the model')
parser.add_argument('--chat', action='store_true', default=False, help='load a checkpoint and switch to eval mode')
parser.add_argument('--model_params', type=str, help='checkpoint path to load when --chat is given')
parser = KoGPT2Chat.add_model_specific_args(parser)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
if __name__ == "__main__":
    if args.train:
        checkpoint_callback = ModelCheckpoint(
            filepath='model_chp/{epoch:02d}-{loss:.2f}',
            verbose=True,
            save_last=True,
            monitor='loss',
            mode='min',
            prefix='model_'
        )
        # example: python ligtning_train_test.py --train --gpus 1 --max_epochs 3
        model = KoGPT2Chat(args)
        model.train()
        trainer = Trainer.from_argparse_args(
            args,
            checkpoint_callback=checkpoint_callback, gradient_clip_val=1.0)
        trainer.fit(model)
    if args.chat:
        model = KoGPT2Chat.load_from_checkpoint(args.model_params)
        model.eval()
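        # ----------------------------------------------------------------
        # Illustrative sketch only (not part of the original script): one
        # way the restored model could be queried with greedy decoding.
        # The raw-text prompt, the '</s>' end-of-sequence token and the
        # '▁' whitespace marker are assumptions; the real prompt format
        # should mirror whatever ArticleDataset in utils/data_loader.py
        # encodes during training.
        tok = SentencepieceTokenizer(model.tok_path)
        with torch.no_grad():
            text = input('user > ').strip()
            ids = model.vocab[tok(text)]
            for _ in range(model.hparams.max_len):
                logits = model(torch.LongTensor([ids]))       # (1, seq_len, vocab)
                next_id = int(logits[0, -1].argmax())
                if model.vocab.to_tokens(next_id) == '</s>':  # assumed EOS token
                    break
                ids.append(next_id)
            print(''.join(model.vocab.to_tokens(ids)).replace('▁', ' '))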