# -*- coding: utf-8 -*-
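"""Fine-tune KoGPT2 on ArticleDataset news articles.

Builds train/valid DataLoaders for the selected topics, optionally resumes from
the most recent matching checkpoint, caches tokenized datasets to .npy files,
and writes per-epoch train/valid losses to a log file.
"""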
import argparse
import os
import glob
import time

import torch
from torch.utils.data import DataLoader
from gluonnlp.data import SentencepieceTokenizer
from kogpt2.pytorch_kogpt2 import get_pytorch_kogpt2_model
from kogpt2.utils import get_tokenizer
from tqdm import tqdm
from util.data_loader import ArticleDataset, ToTensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train KoGPT2 with ArticleDataset.')
    parser.add_argument('--docker', action='store_true', help="Train inside docker. Sets model cache path to /code/model, dataset path to /dataset, save path to /code/save.")
    parser.add_argument('--resume', choices=['default', 'cpu', 'cuda', 'cuda:0', 'cuda:1'], nargs='?', const='default', help="Load the most recent checkpoint saved for the given device, then resume training.")
    parser.add_argument('--topic', nargs='+', choices=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'], default=['경제', '문화', '미용_건강', '사회', '생활', '스포츠', '연예', '정치', 'IT_과학'])
    parser.add_argument('--length', type=int, default=128, choices=[2**i for i in range(11)], help="Token length used by the transform.")
    parser.add_argument('--epoch', type=int, default=30, help="Number of training epochs.")
    parser.add_argument('device', choices=['cpu', 'cuda', 'cuda:0', 'cuda:1'], help="Device to train on.")
    args = parser.parse_args()
    print(args)
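    # Resolve the model cache, dataset, and checkpoint/save directories (docker layout vs. local relative paths).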
    model_cache_path = '/code/model' if args.docker else 'model'
    dataset_path = '/dataset' if args.docker else '../dataset'
    save_path = '/code/save' if args.docker else 'save'

    # Fall back to CPU when CUDA is unavailable, then load pretrained KoGPT2 and its SentencePiece tokenizer.
    ctx = args.device if torch.cuda.is_available() else 'cpu'
    print(ctx)
    device = torch.device(ctx)
    tokenizer_path = get_tokenizer(cachedir=model_cache_path)
    model, vocab = get_pytorch_kogpt2_model(ctx=ctx, cachedir=model_cache_path)
    tokenizer = SentencepieceTokenizer(tokenizer_path, num_best=0, alpha=0)
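    # Scale worker count and batch size inversely with sequence length; very long sequences (1024+ tokens) use small fixed values.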
    num_workers = int(32 * (128 / args.length)) if args.length < 1024 else 4
    batch_size = int(64 * (128 / args.length)) if args.length < 1024 else 4
    padding_id = vocab[vocab.padding_token]

    # Deduplicate and sort topics, and build a deterministic tag for checkpoint/cache/log file names.
    topics = sorted(set(args.topic))
    topics_tag = '_'.join(topics)
    transform = ToTensor(tokenizer, vocab, args.length)
    print("Preparing dataloader...")
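    # Start with num_workers=0 so the cache-building pass below runs in the main process; the worker count is raised after caching.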
    trainset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='train', transform=transform), batch_size=batch_size, num_workers=0, shuffle=True)
    validset = DataLoader(ArticleDataset(dataset_path, topics=topics, label='valid', transform=transform), batch_size=batch_size, num_workers=0)
    #testset=DataLoader(ArticleDataset(dataset_path,label='test', transform=transform),batch_size=128, num_workers=4)
    print("Prepared dataloader.")
    epochs = args.epoch
    checkpoint_epoch = 0
    last_valid_loss = float('inf')
    learning_rate = 3e-5
    # The model returns its own LM loss when labels are passed, so `criterion` is not used in the loop below.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    
    
    # Resume from the newest checkpoint matching the device tag, topics, and sequence length.
    if args.resume:
        save_ctx = ctx if args.resume == "default" else args.resume
        saves = glob.glob(f'{save_path}/KoGPT2_checkpoint_{save_ctx}_{topics_tag}_{transform.max_len}_*.state')
        if len(saves) > 0:
            last_save = max(saves, key=os.path.getmtime)
            print(f"Loading save from {last_save}")
            checkpoint = torch.load(last_save, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            checkpoint_epoch = checkpoint['epoch']
            last_valid_loss = checkpoint['loss']
            print("Loaded.")
        else:
            print("No save exists.")

    model.to(device)
    model.train()

    # Cache tokenized datasets: load an existing .npy cache, or iterate once to build it and then persist it.
    cached_trainset_path = f"{save_path}/train_{topics_tag}_{transform.max_len}"
    cached_validset_path = f"{save_path}/valid_{topics_tag}_{transform.max_len}"
    if os.path.isfile(cached_trainset_path + '.npy'):
        trainset.dataset.load_from_file(cached_trainset_path + '.npy')
    else:
        print("Caching trainset...")
        for temp in tqdm(trainset):
            pass
        trainset.dataset.set_use_cache(True, cached_trainset_path)
    if os.path.isfile(cached_validset_path + '.npy'):
        validset.dataset.load_from_file(cached_validset_path + '.npy')
    else:
        print("Caching validset...")
        for temp in tqdm(validset):
            pass
        validset.dataset.set_use_cache(True, cached_validset_path)
    print("Cached.")

    # Datasets are cached now; raise the worker count for the remaining epochs.
    trainset.num_workers = num_workers
    validset.num_workers = num_workers
    
    overfit = -1
    states = []
    
    # Epochs are 1-indexed: a fresh run trains for args.epoch epochs; a resumed run continues after checkpoint_epoch.
    for epoch in tqdm(range(checkpoint_epoch + 1, epochs + 1)):
        try:
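            # A KeyboardInterrupt (Ctrl-C) breaks out of the epoch loop; the log below is still written.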
            train_loss_list = []
            valid_loss_list = []
            for data in tqdm(trainset):
                optimizer.zero_grad()
                data = data.to(device)
                # Padding positions get label -100 (the LM loss's default ignore_index) and are zeroed in the attention mask.
                label = torch.where(data != padding_id, data, torch.ones_like(data) * -100)
                mask = torch.where(data != padding_id, torch.ones_like(data), torch.zeros_like(data))
                output = model(data, labels=label, attention_mask=mask)
                loss = output[0]
                loss.backward()
                optimizer.step()
                train_loss_list.append(loss.item())
                del loss
                del output
                del label
                del mask
                del data
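            # Validation pass: eval mode disables dropout, no_grad disables gradient tracking; train mode is restored afterwards.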
            model.eval()
            with torch.no_grad():
                for v_data in tqdm(validset):
                    v_data = v_data.to(device)
                    v_label = torch.where(v_data != padding_id, v_data, torch.ones_like(v_data) * -100)
                    v_mask = torch.where(v_data != padding_id, torch.ones_like(v_data), torch.zeros_like(v_data))
                    v_output = model(v_data, labels=v_label, attention_mask=v_mask)
                    v_loss = v_output[0]
                    valid_loss_list.append(v_loss.item())
                    del v_loss
                    del v_output
                    del v_mask
                    del v_label
                    del v_data
            model.train()
            valid_loss = sum(valid_loss_list) / len(valid_loss_list)
            train_loss = sum(train_loss_list) / len(train_loss_list)
            print(f"epoch: {epoch} train loss: {train_loss} valid loss: {valid_loss}")
            states.append((epoch,train_loss,valid_loss))
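            # Remember the last epoch where validation loss increased (a rough overfitting marker).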
            if valid_loss > last_valid_loss:
                overfit = epoch
            try:
                # The stored 'loss' is the validation loss; --resume reads it back as last_valid_loss.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss
                }, f"{save_path}/KoGPT2_checkpoint_{ctx}_{topics_tag}_{transform.max_len}_{epoch}.state")
            except Exception as e:
                print(e)
            last_valid_loss = valid_loss
        except KeyboardInterrupt:
            break
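    # Write per-epoch losses and the last overfitting epoch to a timestamped log file.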
    log_path = f"{save_path}/{topics_tag}_{transform.max_len}_{int(time.time())}.log"
    with open(log_path, 'w') as log:
        log.write(f"Overfit at: {overfit}\n")
        for state in states:
            log.write(f"epoch: {state[0]} train loss: {state[1]} valid loss: {state[2]}\n")
    print(f"Log written at: {log_path}")