Toggle navigation
Toggle navigation
This project
Loading...
Sign in
graykode
/
commit-autosuggestions
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
graykode
2020-11-06 00:23:25 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
a21e790d7fb778becc0e5123aebb23db69c0a6d9
a21e790d
1 parent
4d064d9b
(fix) device cuda or cpu in beam search
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
16 deletions
autocommit/app.py
autocommit/commit.py
autocommit/model/model.py
autocommit/app.py
View file @
a21e790
...
...
@@ -99,7 +99,7 @@ def create_app():
def tokenizer():
    """Flask view: tokenize the code snippet in a POSTed JSON payload.

    Expects a JSON body with a "code" key (this commit renamed the key
    from "line" to "code") and returns {"tokens": [...]} produced by the
    shared tokenizer held on the module-level ``args`` namespace.
    """
    if request.method == 'POST':
        payload = request.get_json()
        # Post-commit behavior: read the 'code' key (the deleted diff line read 'line').
        tokens = args.tokenizer.tokenize(payload['code'])
        return jsonify(tokens=tokens)
return
app
...
...
autocommit/commit.py
View file @
a21e790
...
...
@@ -12,28 +12,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
requests
import
argparse
import
subprocess
import
whatthepatch
def
preprocessing
(
diff
):
added_examples
,
diff_examples
=
[],
[]
isadded
,
isdeleted
=
False
,
False
for
idx
,
example
in
enumerate
(
whatthepatch
.
parse_patch
(
diff
)):
def tokenizing(code):
    """Send *code* to the local tokenizer service and return its token list.

    POSTs {"code": code} as JSON to the /tokenizer endpoint and extracts
    the "tokens" field from the JSON response.  Relies on the module-level
    ``args.headers`` set up in the __main__ block.
    """
    data = {"code": code}
    res = requests.post('http://127.0.0.1:5000/tokenizer',
                        data=json.dumps(data),
                        headers=args.headers)
    return json.loads(res.text)["tokens"]
def preprocessing(diffs):
    """Tokenize each hunk of a git diff and send it to the suggestion server.

    For every parsed patch example, added lines (no old line number) and
    deleted lines (no new line number) are tokenized via the /tokenizer
    service.  Hunks containing both additions and deletions are posted to
    the /diff endpoint; otherwise the hunk goes to /added.  The server's
    JSON reply is printed for each hunk.
    """
    for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
        isadded, isdeleted = False, False
        added, deleted = [], []
        for change in example.changes:
            # A change with no old line number is an added line; use
            # identity comparison with None per PEP 8 (original used ==).
            if change.old is None and change.new is not None:
                added.extend(tokenizing(change.line))
                isadded = True
            elif change.old is not None and change.new is None:
                deleted.extend(tokenizing(change.line))
                isdeleted = True
        # Both branches of the original sent an identical payload; only the
        # endpoint differed, so build the payload once and pick the URL.
        if isadded and isdeleted:
            url = 'http://127.0.0.1:5000/diff'
        else:
            url = 'http://127.0.0.1:5000/added'
        data = {"idx": idx, "added": added, "deleted": deleted}
        res = requests.post(url,
                            data=json.dumps(data),
                            headers=args.headers)
        print(json.loads(res.text))
def
main
():
proc
=
subprocess
.
Popen
([
"git"
,
"diff"
,
"--cached"
],
stdout
=
subprocess
.
PIPE
)
staged_files
=
proc
.
stdout
.
readlines
()
staged_files
=
[
f
.
decode
(
"utf-8"
)
for
f
in
staged_files
]
...
...
@@ -42,4 +66,10 @@ def main():
if __name__ == '__main__':
    # CLI entry point: parse options, then run main() against the staged diff.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--endpoint", type=str, default="http://127.0.0.1:5000/")
    args = parser.parse_args()
    # Shared JSON headers for every request sent by tokenizing()/preprocessing().
    args.headers = {'Content-Type': 'application/json; charset=utf-8'}
    main()
\ No newline at end of file
...
...
autocommit/model/model.py
View file @
a21e790
...
...
@@ -71,12 +71,15 @@ class Seq2Seq(nn.Module):
return
outputs
else
:
#Predict
preds
=
[]
zero
=
torch
.
cuda
.
LongTensor
(
1
)
.
fill_
(
0
)
preds
=
[]
if
source_ids
.
device
.
type
==
'cuda'
:
zero
=
torch
.
cuda
.
LongTensor
(
1
)
.
fill_
(
0
)
elif
source_ids
.
device
.
type
==
'cpu'
:
zero
=
torch
.
LongTensor
(
1
)
.
fill_
(
0
)
for
i
in
range
(
source_ids
.
shape
[
0
]):
context
=
encoder_output
[:,
i
:
i
+
1
]
context_mask
=
source_mask
[
i
:
i
+
1
,:]
beam
=
Beam
(
self
.
beam_size
,
self
.
sos_id
,
self
.
eos_id
)
beam
=
Beam
(
self
.
beam_size
,
self
.
sos_id
,
self
.
eos_id
,
device
=
source_ids
.
device
.
type
)
input_ids
=
beam
.
getCurrentState
()
context
=
context
.
repeat
(
1
,
self
.
beam_size
,
1
)
context_mask
=
context_mask
.
repeat
(
self
.
beam_size
,
1
)
...
...
@@ -103,9 +106,12 @@ class Seq2Seq(nn.Module):
class
Beam
(
object
):
def
__init__
(
self
,
size
,
sos
,
eos
):
def
__init__
(
self
,
size
,
sos
,
eos
,
device
):
self
.
size
=
size
self
.
tt
=
torch
.
cuda
if
device
==
'cuda'
:
self
.
tt
=
torch
.
cuda
elif
device
==
'cpu'
:
self
.
tt
=
torch
# The score for each translation on the beam.
self
.
scores
=
self
.
tt
.
FloatTensor
(
size
)
.
zero_
()
# The backpointers at each time-step.
...
...
Please
register
or
login
to post a comment