Showing 3 changed files with 116 additions and 0 deletions
report/캡스톤 디자인 2 주간보고서-3.docx
0 → 100644
No preview for this file type
train.py
0 → 100644
# -*- coding: utf-8 -*-
import argparse
import glob
import os

import gluonnlp as nlp
import torch
from torch.utils.data import DataLoader, Dataset
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
    ctx = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir='/code/model')
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir='/code/model')
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
    num_workers = 0
    padding_id = vocab[vocab.padding_token]

    transform = ToTensor(tokenizer, vocab)
    print("Preparing dataloader...")
    # num_workers must stay 0 for the first epoch so the dataset cache is
    # filled in the main process (see the end of the epoch loop below).
    trainset = DataLoader(ArticleDataset('/dataset', label='train', transform=transform),
                          batch_size=64, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset('/dataset', label='valid', transform=transform),
                          batch_size=64, num_workers=0)
    # testset = DataLoader(ArticleDataset('/dataset', label='test', transform=transform), batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = 200
    checkpoint_epoch = 0
    learning_rate = 3e-5
    criterion = torch.nn.CrossEntropyLoss()  # unused: the model computes the LM loss internally
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from the most recent checkpoint if one exists.
    save_path = '/model/save'
    saves = glob.glob(save_path + '*.state')
    if len(saves) > 0:
        last_save = max(saves, key=os.path.getmtime)
        checkpoint = torch.load(last_save)
        print(f"Loading save from {last_save}")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        checkpoint_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        print("No save exists.")

    model.to(device)

    last_valid_loss = float('inf')
    for epoch in tqdm(range(checkpoint_epoch, epochs)):
        train_loss_list = []
        valid_loss_list = []
        model.train()
        for data in tqdm(trainset):
            optimizer.zero_grad()
            data = data.to(device)
            # Padding positions get label -100 (ignored by the LM loss) and are
            # masked out of attention.
            label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
            mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
            output = model(data, labels=label, attention_mask=mask)
            loss, logits = output[0], output[1]
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
        model.eval()
        with torch.no_grad():
            for v_data in tqdm(validset):
                v_data = v_data.to(device)
                v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                v_loss, v_logits = v_output[0], v_output[1]
                valid_loss_list.append(v_loss.item())
        valid_loss = sum(valid_loss_list) / len(valid_loss_list)
        print(f"epoch: {epoch} train loss: {sum(train_loss_list)/len(train_loss_list)} valid loss: {valid_loss}")
        # Snapshot when validation loss stops improving, and every 10th epoch.
        if valid_loss > last_valid_loss or (epoch % 10 == 9):
            try:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss
                }, f"{save_path}KoGPT2_checkpoint_{ctx}{epoch}.state")
            except Exception as e:
                print(e)
        last_valid_loss = valid_loss
        if epoch == checkpoint_epoch:  # Must run the entire first epoch with num_workers=0 to fully cache the dataset.
            trainset.dataset.set_use_cache(True)
            trainset.num_workers = num_workers
            validset.dataset.set_use_cache(True)
            validset.num_workers = num_workers
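train.py imports ArticleDataset and ToTensor from util/data_loader, which is not part of this diff. The following is a minimal sketch of what that module plausibly looks like, reconstructed only from how the two names are used above; the one-article-per-line file layout, the cache dictionary, and the fixed max_len are assumptions, while set_use_cache and the constructor signatures are required by the calls in train.py.

```python
# Hypothetical reconstruction of util/data_loader.py -- not the committed file.
import os
import torch
from torch.utils.data import Dataset


class ArticleDataset(Dataset):
    """Loads one article per line from <root>/<label>.txt (assumed layout) and
    memoizes transformed samples so later epochs can skip tokenization."""

    def __init__(self, root, label='train', transform=None):
        with open(os.path.join(root, f'{label}.txt'), encoding='utf-8') as f:
            self.articles = [line.strip() for line in f if line.strip()]
        self.transform = transform
        self.use_cache = False
        self.cache = {}

    def set_use_cache(self, use_cache):
        # train.py flips this on after one full epoch with num_workers=0,
        # so the cache is complete before worker processes fork it.
        self.use_cache = use_cache

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        if self.use_cache and idx in self.cache:
            return self.cache[idx]
        sample = self.articles[idx]
        if self.transform:
            sample = self.transform(sample)
        self.cache[idx] = sample
        return sample


class ToTensor:
    """Tokenizes text, maps tokens to KoGPT2 vocab ids, and pads/truncates to a
    fixed length so samples can be stacked into batches."""

    def __init__(self, tokenizer, vocab, max_len=1024):  # max_len assumed: KoGPT2's context limit
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_len = max_len

    def __call__(self, text):
        ids = ([self.vocab[self.vocab.bos_token]]
               + self.vocab[self.tokenizer(text)]
               + [self.vocab[self.vocab.eos_token]])
        ids = ids[:self.max_len]
        ids += [self.vocab[self.vocab.padding_token]] * (self.max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)
```

Padding every sample to the same max_len is what lets train.py batch with batch_size=64 and then recover the real token positions via padding_id in the label and attention-mask construction.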
통계 (statistics)
0 → 100644
과학 (science category)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3361/3361 [00:16<00:00, 197.95it/s]
count[256]: 30451/215067 (14.1%)
count[512]: 137611/215067 (63.9%)
count[768]: 185856/215067 (86.4%)
count[1024]: 205300/215067 (95.4%) -- longer is impossible due to the model's limit
count[1280]: 211386/215067 (98.2%)
count[1536]: 213877/215067 (99.4%)
count[1792]: 214932/215067 (99.9%)
전체 (all categories)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53825/53825 [04:19<00:00, 207.39it/s]
count[256]: 421097/3444755 (12.2%)
count[512]: 2110517/3444755 (61.2%)
count[768]: 2927091/3444755 (84.9%)
count[1024]: 3242747/3444755 (94.1%) -- longer is impossible due to the model's limit
count[1280]: 3355523/3444755 (97.4%)
count[1536]: 3410390/3444755 (99.0%)
count[1792]: 3437609/3444755 (99.7%)
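The 통계 file records, for each candidate maximum sequence length N, how many articles tokenize to N KoGPT2 tokens or fewer, first for the 과학 (science) category and then for the whole corpus; KoGPT2's 1024-token input limit explains the note at count[1024]. A minimal sketch of the counting logic is below; the corpus path and one-article-per-line layout are assumptions, while the thresholds and the count[N] output format mirror the file above.

```python
# Sketch: measure what fraction of articles fit within each candidate
# maximum sequence length, using the KoGPT2 SentencePiece tokenizer.
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.utils import get_tokenizer
from tqdm import tqdm

tokenizer = SentencepieceTokenizer(get_tokenizer(cachedir='/code/model'))

thresholds = [256, 512, 768, 1024, 1280, 1536, 1792]
counts = {n: 0 for n in thresholds}
total = 0

with open('/dataset/train.txt', encoding='utf-8') as f:  # hypothetical path
    for line in tqdm(f):
        text = line.strip()
        if not text:
            continue
        length = len(tokenizer(text))  # token count for this article
        total += 1
        for n in thresholds:
            if length <= n:
                counts[n] += 1

for n in thresholds:
    print(f"count[{n}]: {counts[n]}/{total} ({100 * counts[n] / total:.1f}%)")
```

With 95.4% of science articles and 94.1% of the full corpus fitting in 1024 tokens, the statistics support padding/truncating to the model's 1024-token ceiling rather than a shorter length.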
Mentioned in commit ee62c0cc