graykode / commit-autosuggestions
Authored by graykode, 2020-11-06 00:45:02 +0900
Commit 16c23c2e3795462876d50b96c4d3713f3a27a3f3 (short: 16c23c2e), 1 parent a21e790d

(refactor) print message in api
Showing 2 changed files with 21 additions and 9 deletions:
  autocommit/app.py
  autocommit/commit.py
autocommit/app.py  (view file @ 16c23c2)

@@ -15,7 +15,6 @@
import os
import torch
import argparse
import whatthepatch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
@@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode):
    model.load_state_dict(torch.load(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
-       map_location=torch.device(args.device)
+       map_location=torch.device('cpu')
    ), strict=False)
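
Note: the map_location change above decouples checkpoint loading from the serving device. Weights saved from a GPU run are remapped onto CPU at load time, and strict=False tolerates non-matching keys. A minimal, self-contained sketch of that loading pattern; the placeholder model and file name below are illustrative, not the project's Seq2Seq checkpoint:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                               # placeholder model, not the project's
torch.save(model.state_dict(), "pytorch_model.bin")   # stand-in for a previously saved checkpoint

# remap every tensor in the checkpoint to CPU, whatever device it was saved from,
# and load with strict=False so mismatched keys are skipped instead of raising
state_dict = torch.load("pytorch_model.bin", map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)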
@@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode):
def get_features(examples):
    features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
-   all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
-   all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
-   all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long)
+   all_source_ids = torch.tensor(
+       [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long
+   )
+   all_source_mask = torch.tensor(
+       [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long
+   )
+   all_patch_ids = torch.tensor(
+       [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long
+   )
    return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)

def create_app():
@@ -150,7 +155,7 @@ if __name__ == '__main__':
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", type=str,
                        default="microsoft/codebert-base", help="The name of tokenizer",)
-   parser.add_argument("--max_source_length", default=256, type=int,
+   parser.add_argument("--max_source_length", default=512, type=int,
                        help="The maximum total source sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--max_target_length", default=128, type=int,
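
Note: together, the get_features and --max_source_length changes cap every feature tensor at args.max_source_length (default now 512) before the TensorDataset is built, so over-long inputs are truncated rather than producing oversized batches. A minimal sketch of that truncation pattern; the Feature class and the 600-token sample are made up for illustration and are not the project's own types:

import torch
from torch.utils.data import TensorDataset

class Feature:
    # hypothetical stand-in for the objects returned by convert_examples_to_features
    def __init__(self, source_ids, source_mask, patch_ids):
        self.source_ids = source_ids
        self.source_mask = source_mask
        self.patch_ids = patch_ids

max_source_length = 512  # mirrors the new --max_source_length default

# one over-long example: 600 tokens, longer than the model limit
features = [Feature(list(range(600)), [1] * 600, [0] * 600)]

# clip each id list to max_source_length before tensorizing, as the diff above does
all_source_ids = torch.tensor(
    [f.source_ids[:max_source_length] for f in features], dtype=torch.long)
all_source_mask = torch.tensor(
    [f.source_mask[:max_source_length] for f in features], dtype=torch.long)
all_patch_ids = torch.tensor(
    [f.patch_ids[:max_source_length] for f in features], dtype=torch.long)

dataset = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
print(dataset[0][0].shape)  # torch.Size([512])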
autocommit/commit.py  (view file @ 16c23c2)

@@ -27,8 +27,12 @@ def tokenizing(code):
    )
    return json.loads(res.text)["tokens"]

-def preprocessing(diffs):
+def autocommit(diffs):
+    commit_message = []
    for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
        if not example.changes:
            continue

        isadded, isdeleted = False, False
        added, deleted = [], []
        for change in example.changes:
@@ -46,7 +50,7 @@ def preprocessing(diffs):
                data=json.dumps(data),
                headers=args.headers)
-            print(json.loads(res.text))
+            commit_message.append(json.loads(res.text))
        else:
            data = {"idx": idx, "added": added, "deleted": deleted}
            res = requests.post(
@@ -54,7 +58,8 @@ def preprocessing(diffs):
                data=json.dumps(data),
                headers=args.headers)
-            print(json.loads(res.text))
+            commit_message.append(json.loads(res.text))
+    return commit_message

def main():
@@ -64,6 +69,8 @@ def main():
    staged_files = [f.strip() for f in staged_files]
    diffs = "\n".join(staged_files)

+    message = autocommit(diffs=diffs)
+    print(message)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="")
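
Note: read together, the commit.py changes rename preprocessing() to autocommit(), collect each API response in commit_message instead of printing inside the loop, and let main() print the returned list once. A rough usage sketch of that call pattern; the endpoint URL, headers, and payload shape are illustrative assumptions rather than the repository's exact values, and it assumes a suggestion server is already listening locally:

import json
import requests

API_URL = "http://127.0.0.1:5000/added"          # hypothetical endpoint
HEADERS = {"Content-Type": "application/json"}   # hypothetical headers

def autocommit(payloads):
    # collect every response instead of printing it inside the loop,
    # mirroring the print -> commit_message.append refactor above
    commit_message = []
    for payload in payloads:
        res = requests.post(API_URL, data=json.dumps(payload), headers=HEADERS)
        commit_message.append(json.loads(res.text))
    return commit_message

def main():
    # the caller decides what to do with the messages, e.g. print them once
    payloads = [{"idx": 0, "added": ["print('hello world')"], "deleted": []}]
    print(autocommit(payloads))

if __name__ == '__main__':
    main()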