Showing 6 changed files with 35 additions and 17 deletions
train.py
0 → 100644
1 | +# Copyright 2020-present Tae Hwan Jung | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import os | ||
16 | +import argparse | ||
17 | +import pytorch_lightning as pl | ||
18 | +from train.finetune import main, SummarizationModule | ||
19 | + | ||
20 | +if __name__ == "__main__": | ||
21 | + parser = argparse.ArgumentParser() | ||
22 | + parser = pl.Trainer.add_argparse_args(parser) | ||
23 | + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) | ||
24 | + | ||
25 | + args = parser.parse_args() | ||
26 | + | ||
27 | + main(args) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
... | @@ -12,7 +12,7 @@ import pytorch_lightning as pl | ... | @@ -12,7 +12,7 @@ import pytorch_lightning as pl |
12 | import torch | 12 | import torch |
13 | from torch.utils.data import DataLoader | 13 | from torch.utils.data import DataLoader |
14 | 14 | ||
15 | -from lightning_base import BaseTransformer, add_generic_args, generic_train | 15 | +from train.lightning_base import BaseTransformer, add_generic_args, generic_train |
16 | from transformers import MBartTokenizer, T5ForConditionalGeneration | 16 | from transformers import MBartTokenizer, T5ForConditionalGeneration |
17 | from transformers.modeling_bart import shift_tokens_right | 17 | from transformers.modeling_bart import shift_tokens_right |
18 | 18 | ||
... | @@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer): | ... | @@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer): |
260 | def get_dataset(self, type_path) -> Seq2SeqDataset: | 260 | def get_dataset(self, type_path) -> Seq2SeqDataset: |
261 | max_target_length = self.target_lens[type_path] | 261 | max_target_length = self.target_lens[type_path] |
262 | data_config = DataConfig( | 262 | data_config = DataConfig( |
263 | - endpoint=args.endpoint, | 263 | + endpoint=self.hparams.endpoint, |
264 | access_key=os.environ["access_key"], | 264 | access_key=os.environ["access_key"], |
265 | secret_key=os.environ["secret_key"], | 265 | secret_key=os.environ["secret_key"], |
266 | - region=args.region, | 266 | + region=self.hparams.region, |
267 | dataset_name="commit-autosuggestions", | 267 | dataset_name="commit-autosuggestions", |
268 | additional={ | 268 | additional={ |
269 | "mode": ("training" if type_path == "train" else "evaluation"), | 269 | "mode": ("training" if type_path == "train" else "evaluation"), |
270 | "max_source_length": self.hparams.max_source_length, | 270 | "max_source_length": self.hparams.max_source_length, |
271 | "max_target_length": max_target_length, | 271 | "max_target_length": max_target_length, |
272 | - "url": args.url, | 272 | + "url": self.hparams.url, |
273 | }, | 273 | }, |
274 | attributes=[ | 274 | attributes=[ |
275 | ("input_ids", "int32", (self.hparams.max_source_length,)), | 275 | ("input_ids", "int32", (self.hparams.max_source_length,)), |
... | @@ -462,13 +462,3 @@ def main(args, model=None) -> SummarizationModule: | ... | @@ -462,13 +462,3 @@ def main(args, model=None) -> SummarizationModule: |
462 | # test() without a model tests using the best checkpoint automatically | 462 | # test() without a model tests using the best checkpoint automatically |
463 | trainer.test() | 463 | trainer.test() |
464 | return model | 464 | return model |
... | \ No newline at end of file | ... | \ No newline at end of file |
465 | - | ||
466 | - | ||
467 | -if __name__ == "__main__": | ||
468 | - parser = argparse.ArgumentParser() | ||
469 | - parser = pl.Trainer.add_argparse_args(parser) | ||
470 | - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) | ||
471 | - | ||
472 | - args = parser.parse_args() | ||
473 | - | ||
474 | - main(args) | ... | ... |
... | @@ -21,7 +21,7 @@ from transformers import ( | ... | @@ -21,7 +21,7 @@ from transformers import ( |
21 | PretrainedConfig, | 21 | PretrainedConfig, |
22 | PreTrainedTokenizer, | 22 | PreTrainedTokenizer, |
23 | ) | 23 | ) |
24 | -from modeling_bart import BartForConditionalGeneration | 24 | +from train.modeling_bart import BartForConditionalGeneration |
25 | 25 | ||
26 | from transformers.optimization import ( | 26 | from transformers.optimization import ( |
27 | Adafactor, | 27 | Adafactor, | ... | ... |
... | @@ -41,7 +41,7 @@ from transformers.modeling_outputs import ( | ... | @@ -41,7 +41,7 @@ from transformers.modeling_outputs import ( |
41 | Seq2SeqQuestionAnsweringModelOutput, | 41 | Seq2SeqQuestionAnsweringModelOutput, |
42 | Seq2SeqSequenceClassifierOutput, | 42 | Seq2SeqSequenceClassifierOutput, |
43 | ) | 43 | ) |
44 | -from modeling_utils import PreTrainedModel | 44 | +from train.modeling_utils import PreTrainedModel |
45 | import logging | 45 | import logging |
46 | 46 | ||
47 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | 47 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ... | ... |
... | @@ -39,7 +39,7 @@ from transformers.file_utils import ( | ... | @@ -39,7 +39,7 @@ from transformers.file_utils import ( |
39 | is_torch_tpu_available, | 39 | is_torch_tpu_available, |
40 | replace_return_docstrings, | 40 | replace_return_docstrings, |
41 | ) | 41 | ) |
42 | -from generation_utils import GenerationMixin | 42 | +from train.generation_utils import GenerationMixin |
43 | import logging | 43 | import logging |
44 | 44 | ||
45 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | 45 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ... | ... |
-
Please register or log in to post a comment