graykode

(refactor) move train.py out of the train folder

@@ -2,6 +2,7 @@ whatthepatch
 gitpython
 matorage
 transformers
+packaging
 
 psutil
 sacrebleu
......
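The diff does not say why packaging was added; a plausible reason, given that this PR vendors BART internals whose location changed in transformers 4.x, is version gating. A hypothetical sketch (the version boundary and the gated import are assumptions, not taken from this PR):

import transformers
from packaging import version

# Hypothetical: releases before 4.0 expose BART helpers as
# transformers.modeling_bart, which is the layout this PR's imports assume.
if version.parse(transformers.__version__) < version.parse("4.0.0"):
    from transformers.modeling_bart import shift_tokens_right
else:
    from transformers.models.bart.modeling_bart import shift_tokens_right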
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import pytorch_lightning as pl
+from train.finetune import main, SummarizationModule
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
+
+    args = parser.parse_args()
+
+    main(args)
\ No newline at end of file
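With the entry point hoisted to the repository root, training is launched from there. A hedged sketch of an equivalent programmatic call; the endpoint, region and url values are placeholders for a local MinIO deployment, and the parser built by add_model_specific_args may require further upstream finetune options not shown here:

import os
import argparse
import pytorch_lightning as pl
from train.finetune import main, SummarizationModule

# get_dataset() reads the matorage credentials from the environment,
# so they must be set before main() runs; values here are placeholders.
os.environ.setdefault("access_key", "minio-access-key")
os.environ.setdefault("secret_key", "minio-secret-key")

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

# --endpoint, --region and --url end up on self.hparams and feed the
# DataConfig in get_dataset(); see the hunks below.
args = parser.parse_args([
    "--endpoint", "http://127.0.0.1:9000",
    "--region", "ap-northeast-2",
    "--url", "http://127.0.0.1:9000",
])
main(args)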
@@ -12,7 +12,7 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
 
-from lightning_base import BaseTransformer, add_generic_args, generic_train
+from train.lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
 
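For orientation on the unchanged import in this hunk: under the transformers 3.x API this code targets, shift_tokens_right(input_ids, pad_token_id) builds decoder inputs by rotating the final non-pad token (usually </s>) to position 0 and shifting everything else right. A small sketch with hand-picked token ids:

import torch
from transformers.modeling_bart import shift_tokens_right

pad_token_id = 1                                        # BART's pad token
labels = torch.tensor([[0, 42, 43, 2, pad_token_id]])   # <s> x y </s> <pad>
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
print(decoder_input_ids)  # tensor([[ 2,  0, 42, 43,  2]])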
@@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer):
     def get_dataset(self, type_path) -> Seq2SeqDataset:
         max_target_length = self.target_lens[type_path]
         data_config = DataConfig(
-            endpoint=args.endpoint,
+            endpoint=self.hparams.endpoint,
             access_key=os.environ["access_key"],
             secret_key=os.environ["secret_key"],
-            region=args.region,
+            region=self.hparams.region,
             dataset_name="commit-autosuggestions",
             additional={
                 "mode": ("training" if type_path == "train" else "evaluation"),
                 "max_source_length": self.hparams.max_source_length,
                 "max_target_length": max_target_length,
-                "url": args.url,
+                "url": self.hparams.url,
             },
             attributes=[
                 ("input_ids", "int32", (self.hparams.max_source_length,)),
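The args-to-self.hparams change above fixes a latent NameError: once finetune.py is imported from train.py instead of run as a script, the module-level args created in its deleted __main__ block (next hunk) never exists. Lightning keeps the parsed namespace on the module itself, so methods read it from self.hparams. A minimal sketch of the pattern; the class and flag names are illustrative:

import argparse
import pytorch_lightning as pl

class Sketch(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace):
        super().__init__()
        # Attach the argparse namespace to the module (this era's
        # BaseTransformer assigned self.hparams = hparams directly;
        # newer Lightning versions use save_hyperparameters).
        self.save_hyperparameters(hparams)

    def get_endpoint(self):
        # No global `args` needed: the value travels with the instance.
        return self.hparams.endpoint

ns = argparse.Namespace(endpoint="http://127.0.0.1:9000")
print(Sketch(ns).get_endpoint())  # http://127.0.0.1:9000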
@@ -462,13 +462,3 @@ def main(args, model=None) -> SummarizationModule:
     # test() without a model tests using the best checkpoint automatically
     trainer.test()
     return model
\ No newline at end of file
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
-
-    args = parser.parse_args()
-
-    main(args)
......
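For context on what get_dataset() configures above: matorage (already in requirements.txt) serves the stored attributes back as a torch-compatible dataset. A hedged sketch mirroring the fields in the hunk; the endpoint, region, lengths and batch size are placeholders, and the exact Dataset behavior is matorage's, not verified here:

import os
from matorage import DataConfig
from matorage.torch import Dataset
from torch.utils.data import DataLoader

data_config = DataConfig(
    endpoint="127.0.0.1:9000",         # placeholder MinIO endpoint
    access_key=os.environ["access_key"],
    secret_key=os.environ["secret_key"],
    region=None,                       # the diff passes self.hparams.region
    dataset_name="commit-autosuggestions",
    additional={
        "mode": "training",
        "max_source_length": 512,      # placeholder lengths
        "max_target_length": 128,
        "url": "127.0.0.1:9000",
    },
    attributes=[("input_ids", "int32", (512,))],
)
loader = DataLoader(Dataset(config=data_config), batch_size=8)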
@@ -21,7 +21,7 @@ from transformers import (
     PretrainedConfig,
     PreTrainedTokenizer,
 )
-from modeling_bart import BartForConditionalGeneration
+from train.modeling_bart import BartForConditionalGeneration
 
 from transformers.optimization import (
     Adafactor,
......
@@ -41,7 +41,7 @@ from transformers.modeling_outputs import (
     Seq2SeqQuestionAnsweringModelOutput,
     Seq2SeqSequenceClassifierOutput,
 )
-from modeling_utils import PreTrainedModel
+from train.modeling_utils import PreTrainedModel
 import logging
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
......
@@ -39,7 +39,7 @@ from transformers.file_utils import (
     is_torch_tpu_available,
     replace_return_docstrings,
 )
-from generation_utils import GenerationMixin
+from train.generation_utils import GenerationMixin
 import logging
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
......
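Taken together, the remaining hunks qualify every sibling import inside train/ with the package name, which is exactly what moving the entry point up one level requires. A sketch of the implied layout; the file names for the last two hunks are inferred from their imports (modeling_utils.py defines PreTrainedModel and therefore pulls in GenerationMixin):

# commit-autosuggestions/
# |-- train.py                  # new root-level entry point (this PR)
# `-- train/
#     |-- finetune.py           # from train.lightning_base import ...
#     |-- lightning_base.py     # from train.modeling_bart import ...
#     |-- modeling_bart.py      # from train.modeling_utils import ...
#     |-- modeling_utils.py     # from train.generation_utils import ...
#     `-- generation_utils.py
#
# Running `python train.py` from the repository root puts that root on
# sys.path, so the absolute train.* imports resolve without installing
# the project as a package.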