graykode / commit-autosuggestions
Authored by graykode, 2020-09-08 11:10:54 +0900
Commit ab8c5e08f4ebde15a9ad35e309ca039da4ed2004
1 parent 3d9cba15
(add) matorage runnable, edit cpython to pandas in gitignore
Showing 3 changed files with 45 additions and 66 deletions
.gitignore
finetune.py
lightning_base.py
.gitignore
@@ -137,5 +137,5 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
-cpython
+pandas
 .idea/
\ No newline at end of file
finetune.py
@@ -16,6 +16,9 @@ from lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
+from matorage import DataConfig
+from matorage.torch import Dataset
 
 try:
     from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -75,18 +78,6 @@ class SummarizationModule(BaseTransformer):
         self.step_count = 0
         self.metrics = defaultdict(list)
-        self.dataset_kwargs: dict = dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            prefix=self.model.config.prefix or "",
-        )
-        n_observations_per_split = {
-            "train": self.hparams.n_train,
-            "val": self.hparams.n_val,
-            "test": self.hparams.n_test,
-        }
-        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
         self.target_lens = {
             "train": self.hparams.max_target_length,
             "val": self.hparams.val_max_target_length,
@@ -107,9 +98,7 @@ class SummarizationModule(BaseTransformer):
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        self.dataset_class = (
-            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
-        )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
@@ -137,8 +126,8 @@ class SummarizationModule(BaseTransformer):
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
-        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
-        tgt_ids = batch["labels"]
+        src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
+        tgt_ids = batch[3].long()
         if isinstance(self.model, T5ForConditionalGeneration):
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
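Note on the change above: _step no longer unpacks a dict-style batch; it indexes a tuple of four tensors. A minimal sketch of the assumed batch layout (the batch size and target length are illustrative placeholders; the field order follows the attributes declared in get_dataset further down):

import torch

# Hypothetical batch as the matorage-backed DataLoader would yield it:
# four int32 tensors, accessed positionally.
batch = (
    torch.zeros(8, 1024, dtype=torch.int32),  # 0: input_ids        (batch, max_source_length)
    torch.ones(8, 1024, dtype=torch.int32),   # 1: attention_masks  (batch, max_source_length)
    torch.zeros(8, 1024, dtype=torch.int32),  # 2: patch_ids        (batch, max_source_length)
    torch.zeros(8, 64, dtype=torch.int32),    # 3: targets          (batch, max_target_length)
)

# _step() then casts each tensor to int64 before feeding the model:
src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
tgt_ids = batch[3].long()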
@@ -168,7 +157,7 @@ class SummarizationModule(BaseTransformer):
         logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         # tokens per batch
-        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
+        logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
         return {"loss": loss_tensors[0], "log": logs}
 
     def validation_step(self, batch, batch_idx) -> Dict:
@@ -198,14 +187,15 @@ class SummarizationModule(BaseTransformer):
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
         generated_ids = self.model.generate(
-            batch["input_ids"],
-            attention_mask=batch["attention_mask"],
+            batch[0].long(),
+            attention_mask=batch[1].long(),
+            # patch_ids=batch[2].long(),
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
-        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
+        gen_time = (time.time() - t0) / batch[0].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
-        target: List[str] = self.ids_to_clean_text(batch["labels"])
+        target: List[str] = self.ids_to_clean_text(batch[3])
         loss_tensors = self._step(batch)
         base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
         rouge: Dict = self.calc_generative_metrics(preds, target)
@@ -220,29 +210,34 @@ class SummarizationModule(BaseTransformer):
         return self.validation_epoch_end(outputs, prefix="test")
 
     def get_dataset(self, type_path) -> Seq2SeqDataset:
-        n_obs = self.n_obs[type_path]
         max_target_length = self.target_lens[type_path]
-        dataset = self.dataset_class(
-            self.tokenizer,
-            type_path=type_path,
-            n_obs=n_obs,
-            max_target_length=max_target_length,
-            **self.dataset_kwargs,
+        data_config = DataConfig(
+            endpoint=args.matorage_dir,
+            access_key=os.environ['access_key'],
+            secret_key=os.environ['secret_key'],
+            dataset_name='commit-autosuggestions',
+            additional={
+                "mode": ("training" if type_path == "train" else "evaluation"),
+                "max_source_length": self.hparams.max_source_length,
+                "max_target_length": max_target_length,
+                "url": args.url,
+            },
+            attributes=[
+                ('input_ids', 'int32', (self.hparams.max_source_length,)),
+                ('attention_masks', 'int32', (self.hparams.max_source_length,)),
+                ('patch_ids', 'int32', (self.hparams.max_source_length,)),
+                ('targets', 'int32', (max_target_length,))
+            ]
         )
-        return dataset
+        return Dataset(config=data_config, clear=True)
 
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
             assert self.hparams.gpus <= 1  # TODO: assert earlier
             sampler = dataset.make_sortish_sampler(batch_size)
             shuffle = False
 
         dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
             collate_fn=dataset.collate_fn,
             shuffle=shuffle,
             num_workers=self.num_workers,
             sampler=sampler,
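For context, a standalone sketch of loading the same matorage dataset outside the Lightning module, mirroring the get_dataset/get_dataloader code above. The endpoint, credentials, lengths, and URL below are placeholders, not values fixed by this commit:

import os
from torch.utils.data import DataLoader
from matorage import DataConfig
from matorage.torch import Dataset

max_source_length = 1024   # placeholder; the module uses self.hparams.max_source_length
max_target_length = 64     # placeholder; the module uses self.target_lens[type_path]

data_config = DataConfig(
    endpoint='127.0.0.1:9000',                  # placeholder for args.matorage_dir
    access_key=os.environ['access_key'],
    secret_key=os.environ['secret_key'],
    dataset_name='commit-autosuggestions',
    additional={
        "mode": "training",
        "max_source_length": max_source_length,
        "max_target_length": max_target_length,
        "url": 'https://github.com/user/repo',  # placeholder for args.url
    },
    attributes=[
        ('input_ids', 'int32', (max_source_length,)),
        ('attention_masks', 'int32', (max_source_length,)),
        ('patch_ids', 'int32', (max_source_length,)),
        ('targets', 'int32', (max_target_length,)),
    ]
)

dataset = Dataset(config=data_config, clear=True)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)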
@@ -264,6 +259,18 @@ class SummarizationModule(BaseTransformer):
         BaseTransformer.add_model_specific_args(parser, root_dir)
         add_generic_args(parser, root_dir)
+        parser.add_argument("--url", type=str, required=True, help="github url")
+        parser.add_argument("--matorage_dir", type=str, required=True, help='matorage saved directory.')
         parser.add_argument(
             "--max_source_length",
             default=1024,
             type=int,
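Both new flags are required, so any invocation of finetune.py must supply them. A hedged sketch of how they are consumed (argument values are placeholders):

import argparse

# Stand-in for the parser assembled by add_model_specific_args();
# only the two flags introduced in this commit are shown.
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, required=True, help="github url")
parser.add_argument("--matorage_dir", type=str, required=True, help="matorage saved directory.")

args = parser.parse_args([
    "--url", "https://github.com/user/repo",   # placeholder repository URL
    "--matorage_dir", "127.0.0.1:9000",        # placeholder; passed to DataConfig(endpoint=...)
])

# get_dataset() above reads args.matorage_dir as the DataConfig endpoint
# and stores args.url under additional={"url": ...}.
print(args.url, args.matorage_dir)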
@@ -341,29 +348,8 @@ def main(args, model=None) -> SummarizationModule:
         else:
             model: SummarizationModule = TranslationModule(args)
-    dataset = Path(args.data_dir).name
-    if (
-        args.logger_name == "default"
-        or args.fast_dev_run
-        or str(args.output_dir).startswith("/tmp")
-        or str(args.output_dir).startswith("/var")
-    ):
-        logger = True  # don't pollute wandb logs unnecessarily
-    elif args.logger_name == "wandb":
-        from pytorch_lightning.loggers import WandbLogger
-
-        project = os.environ.get("WANDB_PROJECT", dataset)
-        logger = WandbLogger(name=model.output_dir.name, project=project)
-    elif args.logger_name == "wandb_shared":
-        from pytorch_lightning.loggers import WandbLogger
-
-        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
-    if args.early_stopping_patience >= 0:
-        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
-    else:
-        es_callback = False
+    logger = True
+    es_callback = False
 
     trainer: pl.Trainer = generic_train(
         model,
         args,
lightning_base.py
@@ -323,13 +323,6 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-    parser.add_argument(
-        "--data_dir",
-        default=None,
-        type=str,
-        required=True,
-        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-    )
 
 
 def generic_train(