graykode

(refactor) move train.py out of the train folder

......@@ -2,6 +2,7 @@ whatthepatch
gitpython
matorage
transformers
+packaging
psutil
sacrebleu
......
# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import pytorch_lightning as pl
from train.finetune import main, SummarizationModule
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    main(args)
\ No newline at end of file
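The new entry point only wires argparse into PyTorch Lightning. A minimal, self-contained sketch of that pattern, assuming the pre-2.0 Lightning API that add_argparse_args belongs to (the --max_epochs value here is only an illustration, not a flag taken from this repository):

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)      # injects the Trainer's own flags
args = parser.parse_args(["--max_epochs", "3"])    # example CLI input for the sketch
trainer = pl.Trainer.from_argparse_args(args)      # rebuilds a Trainer from the parsed namespace

In the real script, SummarizationModule.add_model_specific_args layers the model's flags on top of the Trainer's before main(args) takes over.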
......@@ -12,7 +12,7 @@ import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
-from lightning_base import BaseTransformer, add_generic_args, generic_train
+from train.lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
......@@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer):
    def get_dataset(self, type_path) -> Seq2SeqDataset:
        max_target_length = self.target_lens[type_path]
        data_config = DataConfig(
-            endpoint=args.endpoint,
+            endpoint=self.hparams.endpoint,
            access_key=os.environ["access_key"],
            secret_key=os.environ["secret_key"],
-            region=args.region,
+            region=self.hparams.region,
            dataset_name="commit-autosuggestions",
            additional={
                "mode": ("training" if type_path == "train" else "evaluation"),
                "max_source_length": self.hparams.max_source_length,
                "max_target_length": max_target_length,
-                "url": args.url,
+                "url": self.hparams.url,
            },
            attributes=[
                ("input_ids", "int32", (self.hparams.max_source_length,)),
......@@ -462,13 +462,3 @@ def main(args, model=None) -> SummarizationModule:
    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
\ No newline at end of file
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
-    args = parser.parse_args()
-    main(args)
......
......@@ -21,7 +21,7 @@ from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
)
-from modeling_bart import BartForConditionalGeneration
+from train.modeling_bart import BartForConditionalGeneration
from transformers.optimization import (
Adafactor,
......
......@@ -41,7 +41,7 @@ from transformers.modeling_outputs import (
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
-from modeling_utils import PreTrainedModel
+from train.modeling_utils import PreTrainedModel
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......
......@@ -39,7 +39,7 @@ from transformers.file_utils import (
is_torch_tpu_available,
replace_return_docstrings,
)
-from generation_utils import GenerationMixin
+from train.generation_utils import GenerationMixin
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
......
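Taken together, the import rewrites describe one layout: the entry point moves to the repository root while finetune.py and the vendored model files stay inside train/, which is now addressed as a package. A sketch of the layout implied by the imports above (the __init__.py is an assumption; only the module names come from the diff):

# train.py                  <- new entry point at the repository root
# train/
#   __init__.py             <- assumed, so `train` resolves as a package
#   finetune.py             <- main(), SummarizationModule
#   lightning_base.py       <- BaseTransformer, add_generic_args, generic_train
#   modeling_bart.py        <- BartForConditionalGeneration
#   modeling_utils.py       <- PreTrainedModel
#   generation_utils.py     <- GenerationMixin
#
# Running `python train.py` from the root puts the root directory on sys.path,
# so the absolute `train.*` imports used throughout the diff resolve.
from train.finetune import main, SummarizationModule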