Toggle navigation
Toggle navigation
This project
Loading...
Sign in
graykode
/
commit-autosuggestions
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
graykode
2020-09-09 23:39:11 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
ce932c6884ef6673302bcf864500b3e597d18968
ce932c68
1 parent
9f8021c1
(fixed) remove max_source_length, add diff_parse function
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
15 deletions
commit_suggester.py
commit_suggester.py
View file @
ce932c6
...
...
@@ -20,18 +20,25 @@ from transformers import AutoTokenizer
from
preprocess
import
diff_parse
,
truncate
from
train
import
BartForConditionalGeneration
def get_length(chunks):
    """Return the combined length of every item in *chunks*.

    Args:
        chunks: An iterable of sized objects (anything supporting ``len``).

    Returns:
        The sum of ``len(chunk)`` over all chunks; ``0`` when *chunks* is
        empty.
    """
    # The built-in sum replaces the manual counter loop.
    return sum(len(chunk) for chunk in chunks)
def
suggester
(
chunks
,
model
,
tokenizer
,
device
):
max_source_length
=
get_length
(
chunks
)
def
suggester
(
chunks
,
max_source_length
,
model
,
tokenizer
,
device
):
input_ids
,
attention_masks
,
patch_ids
=
zip
(
*
chunks
)
input_ids
=
torch
.
LongTensor
(
[
truncate
(
input_ids
,
max_source_length
,
value
=
0
)])
.
to
(
device
)
input_ids
=
torch
.
LongTensor
(
[
truncate
(
input_ids
,
max_source_length
,
value
=
0
)]
)
.
to
(
device
)
attention_masks
=
torch
.
LongTensor
(
[
truncate
(
attention_masks
,
max_source_length
,
value
=
1
)]
)
.
to
(
device
)
patch_ids
=
torch
.
LongTensor
(
[
truncate
(
patch_ids
,
max_source_length
,
value
=
0
)])
.
to
(
device
)
patch_ids
=
torch
.
LongTensor
(
[
truncate
(
patch_ids
,
max_source_length
,
value
=
0
)]
)
.
to
(
device
)
summaries
=
model
.
generate
(
input_ids
=
input_ids
,
patch_ids
=
patch_ids
,
attention_mask
=
attention_masks
...
...
@@ -59,9 +66,13 @@ def main(args):
staged_files
=
[
f
.
strip
()
for
f
in
staged_files
]
chunks
=
"
\n
"
.
join
(
staged_files
)
chunks
=
diff_parse
(
chunks
,
tokenizer
)
if
not
chunks
:
print
(
'There is no file in staged state.'
)
return
commit_message
=
suggester
(
chunks
,
max_source_length
=
args
.
max_source_length
,
model
=
model
,
tokenizer
=
tokenizer
,
device
=
device
,
...
...
@@ -89,13 +100,6 @@ if __name__ == "__main__":
type
=
str
,
help
=
"Pretrained tokenizer name or path if not the same as model_name"
,
)
parser
.
add_argument
(
"--max_source_length"
,
default
=
1024
,
type
=
int
,
help
=
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
,
)
args
=
parser
.
parse_args
()
main
(
args
)
...
...
Please
register
or
login
to post a comment