graykode

(refactor) move train.py out of the train folder

@@ -2,6 +2,7 @@ whatthepatch
 gitpython
 matorage
 transformers
+packaging
 
 psutil
 sacrebleu
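
This hunk only declares the new packaging dependency; the diff does not show where it is used. For reference, a minimal sketch of the library's usual job, PEP 440 version comparison — the transformers check below is an illustrative assumption, not code from this commit:

from packaging import version

# packaging turns version strings into properly comparable objects (PEP 440);
# plain string comparison would get this wrong ("3.10" < "3.9" as strings).
assert version.parse("3.10.0") > version.parse("3.9.0")

# A common pattern in ML codebases: gate behavior on an installed version.
import transformers  # already listed in the requirements above
if version.parse(transformers.__version__) >= version.parse("3.0.0"):
    pass  # e.g. call an API introduced in 3.0
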
@@ -0,0 +1,27 @@
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import pytorch_lightning as pl
+from train.finetune import main, SummarizationModule
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
+
+    args = parser.parse_args()
+
+    main(args)
\ No newline at end of file
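
A quick, self-contained sketch of the argparse pattern the new entry point relies on. Hedged: Trainer.add_argparse_args existed on the PyTorch Lightning releases current when this commit was made (it was removed in PL 2.0), and the --endpoint flag below is only a stand-in for the model-specific arguments:

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
# Trainer contributes its own flags (--max_epochs, --gpus, ...) to the parser.
parser = pl.Trainer.add_argparse_args(parser)
# Stand-in for SummarizationModule.add_model_specific_args, which adds
# model/data flags such as --endpoint, --region and --url (see hunks below).
parser.add_argument("--endpoint", type=str, default="http://127.0.0.1:9000")

args = parser.parse_args(["--max_epochs", "3"])
print(args.max_epochs, args.endpoint)
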
@@ -12,7 +12,7 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
 
-from lightning_base import BaseTransformer, add_generic_args, generic_train
+from train.lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
 
@@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer):
     def get_dataset(self, type_path) -> Seq2SeqDataset:
         max_target_length = self.target_lens[type_path]
         data_config = DataConfig(
-            endpoint=args.endpoint,
+            endpoint=self.hparams.endpoint,
             access_key=os.environ["access_key"],
             secret_key=os.environ["secret_key"],
-            region=args.region,
+            region=self.hparams.region,
             dataset_name="commit-autosuggestions",
             additional={
                 "mode": ("training" if type_path == "train" else "evaluation"),
                 "max_source_length": self.hparams.max_source_length,
                 "max_target_length": max_target_length,
-                "url": args.url,
+                "url": self.hparams.url,
             },
             attributes=[
                 ("input_ids", "int32", (self.hparams.max_source_length,)),
@@ -461,14 +461,4 @@ def main(args, model=None) -> SummarizationModule:
 
     # test() without a model tests using the best checkpoint automatically
     trainer.test()
-    return model
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
-
-    args = parser.parse_args()
-
-    main(args)
+    return model
\ No newline at end of file
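
Note that, apart from the license header and the train.finetune import, the block deleted above is the same entry-point code that reappears in the new top-level train.py shown earlier; main(args, model=None) stays importable, so programmatic callers of finetune are unaffected.
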
@@ -21,7 +21,7 @@ from transformers import (
     PretrainedConfig,
     PreTrainedTokenizer,
 )
-from modeling_bart import BartForConditionalGeneration
+from train.modeling_bart import BartForConditionalGeneration
 
 from transformers.optimization import (
     Adafactor,
@@ -41,7 +41,7 @@ from transformers.modeling_outputs import (
     Seq2SeqQuestionAnsweringModelOutput,
     Seq2SeqSequenceClassifierOutput,
 )
-from modeling_utils import PreTrainedModel
+from train.modeling_utils import PreTrainedModel
 import logging
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -39,7 +39,7 @@ from transformers.file_utils import (
     is_torch_tpu_available,
     replace_return_docstrings,
 )
-from generation_utils import GenerationMixin
+from train.generation_utils import GenerationMixin
 import logging
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
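
All four import rewrites in this commit follow one pattern: bare sibling imports (from modeling_bart import ...) become package-qualified ones (from train.modeling_bart import ...). They resolve because the relocated train.py now starts the process from the repository root. A small sketch of why the launch directory matters (paths hypothetical):

# Before the move: "python train/train.py" put train/ itself on sys.path,
# so "from lightning_base import ..." worked but "import train.*" did not.
# After the move: "python train.py" puts the repo root on sys.path instead,
# so every project module must be addressed through the train package.
import sys
print(sys.path[0])               # the directory of the script being launched
from train.finetune import main  # resolves only when that is the repo root
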