train_test.py
# -*- coding: utf-8 -*-
import argparse
import os
import glob
import time
import torch
from torch.utils.data import DataLoader
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
    parser.add_argument('--docker', action='store_true',
                        help="Train inside Docker; use model cache path /code/model, dataset path /dataset, save path /code/save.")
    parser.add_argument('--resume', choices=['default', 'cpu', 'cuda', 'cuda:0', 'cuda:1'], nargs='?', const='default',
                        help="Load the latest matching checkpoint onto the given device ('default' = the training device), then resume training.")
    parser.add_argument('--topic', nargs='+',
                        choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'],
                        default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'],
                        help="News topics to train on (default: all topics).")
    parser.add_argument('--length', type=int, default=128, choices=[2**i for i in range(11)],
                        help="Token length used by the ToTensor transform.")
    parser.add_argument('--epoch', type=int, default=30, help="Number of training epochs.")
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'])
    args = parser.parse_args()
    print(args)
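
    # Resolve cache, dataset, and save locations: absolute container paths under --docker, relative paths otherwise.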
    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)
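
    # Download (or load from the local cache) the KoGPT2 model, vocabulary, and SentencePiece tokenizer.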
    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
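
    # Scale worker count and batch size inversely with sequence length so memory use stays roughly constant;
    # fall back to small fixed values for very long sequences.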
    num_workers = int(32 * (128 / args.length)) if args.length < 1024 else 4
    batch_size = int(64 * (128 / args.length)) if args.length < 1024 else 4
    padding_id = vocab[vocab.padding_token]
    topics = sorted(set(args.topic))  # deterministic order so checkpoint and cache file names stay stable across runs
    transform = ToTensor(tokenizer, vocab, args.length)
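
    # DataLoaders start with num_workers=0; the worker count is raised after the tokenized cache is built below.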
    print("Preparing dataloader...")
    trainset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='train', transform=transform),
                          batch_size=batch_size, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='valid', transform=transform),
                          batch_size=batch_size, num_workers=0)
    # testset = DataLoader(ArticleDataset(dataset_path, label='test', transform=transform), batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = args.epoch
    checkpoint_epoch = 0
    learning_rate = 3e-5
    last_valid_loss = float('inf')  # validation loss of the previous epoch (or the resumed checkpoint)
    criterion = torch.nn.CrossEntropyLoss()  # not used directly: the model returns its own LM loss
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
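
    # Optionally resume from the newest checkpoint that matches this device, topic set, and sequence length.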
    if args.resume:
        save_ctx = ctx if args.resume == "default" else args.resume
        saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{save_ctx}_{topics}_{transform.max_len}_*.state')
        if len(saves) > 0:
            last_save = max(saves, key=os.path.getmtime)
            print(f"Loading save from {last_save}")
            checkpoint = torch.load(last_save, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            last_valid_loss = checkpoint['loss']
            print("Loaded.")
        else:
            print("No save exists.")
    model.to(device)
    model.train()
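
    # Tokenizing articles is slow, so the first pass over each DataLoader materializes the dataset and caches it
    # as an .npy file under save_path; later runs load the cache directly.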
    cached_trainset_path = f"{save_path}/train_{topics}_{transform.max_len}"
    cached_validset_path = f"{save_path}/valid_{topics}_{transform.max_len}"
    if os.path.isfile(cached_trainset_path + '.npy'):
        trainset.dataset.load_from_file(cached_trainset_path + '.npy')
    else:
        print("Caching trainset...")
        for temp in tqdm(trainset):
            pass
        trainset.dataset.set_use_cache(True, cached_trainset_path)
    if os.path.isfile(cached_validset_path + '.npy'):
        validset.dataset.load_from_file(cached_validset_path + '.npy')
    else:
        print("Caching validset...")
        for temp in tqdm(validset):
            pass
        validset.dataset.set_use_cache(True, cached_validset_path)
    print("Cached.")
    trainset.num_workers = num_workers
    validset.num_workers = num_workers
    overfit = -1  # first epoch where validation loss increased; -1 if it never did
    states = []
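
    # Main loop: one training pass and one validation pass per epoch. KeyboardInterrupt stops training early
    # but still falls through to write the log.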
    for epoch in tqdm(range(checkpoint_epoch + 1, epochs)):
        try:
            train_loss_list = []
            valid_loss_list = []
            model.train()
            for data in tqdm(trainset):
                optimizer.zero_grad()
                data = data.to(device)
                # Padding positions get label -100 (ignored by the LM loss) and are zeroed out of the attention mask.
                label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                output = model(data, labels=label, attention_mask=mask)
                loss = output[0]
                loss.backward()
                optimizer.step()
                train_loss_list.append(loss.item())
                del loss, output, label, mask, data  # free GPU memory before the next batch
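
            # Validation pass: same padding handling as training, but in eval mode and without gradient tracking.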
            model.eval()
            with torch.no_grad():
                for v_data in tqdm(validset):
                    v_data = v_data.to(device)
                    v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                    v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                    v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                    v_loss = v_output[0]
                    valid_loss_list.append(v_loss.item())
                    del v_loss, v_output, v_mask, v_label, v_data
            valid_loss = sum(valid_loss_list) / len(valid_loss_list)
            train_loss = sum(train_loss_list) / len(train_loss_list)
            print(f"epoch: {epoch} train loss: {train_loss} valid loss: {valid_loss}")
            states.append((epoch, train_loss, valid_loss))
            if valid_loss > last_valid_loss:
                overfit = epoch
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss  # restored as last_valid_loss when resuming
                }, f"{save_path}/KoGPT2_checkpoint_{ctx}_{topics}_{transform.max_len}_{epoch}.state")
            except Exception as e:
                print(e)
            last_valid_loss = valid_loss
        except KeyboardInterrupt:
            break
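
    # Write per-epoch losses and the first epoch where validation loss rose to a timestamped log file.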
    log_path = f"{save_path}/{topics}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write(f"Overfit at: {overfit}\n")
        for state in states:
            log.write(f"epoch: {state[0]} train loss: {state[1]} valid loss: {state[2]}\n")
    print(f"Log written at: {log_path}")