graykode

(fix) select cuda or cpu device in beam search

......@@ -99,7 +99,7 @@ def create_app():
     def tokenizer():
         if request.method == 'POST':
             payload = request.get_json()
-            tokens = args.tokenizer.tokenize(payload['line'])
+            tokens = args.tokenizer.tokenize(payload['code'])
             return jsonify(tokens=tokens)
     return app
......
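Note: after this change the tokenizer endpoint reads the JSON key "code" instead of "line", matching the client's tokenizing() helper below. A minimal sketch of a call against the renamed key (assuming the Flask app from create_app() is running on 127.0.0.1:5000; the snippet string and variable names are illustrative only):

import json
import requests

headers = {'Content-Type': 'application/json; charset=utf-8'}
payload = {"code": "def add(a, b): return a + b"}   # hypothetical code snippet to tokenize
res = requests.post('http://127.0.0.1:5000/tokenizer',
                    data=json.dumps(payload), headers=headers)
print(json.loads(res.text)["tokens"])               # token list returned by the server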
......@@ -12,28 +12,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
 import requests
 import argparse
 import subprocess
 import whatthepatch
-def preprocessing(diff):
-    added_examples, diff_examples = [], []
-    isadded, isdeleted = False, False
-    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
+def tokenizing(code):
+    data = {"code": code }
+    res = requests.post(
+        'http://127.0.0.1:5000/tokenizer',
+        data=json.dumps(data),
+        headers=args.headers
+    )
+    return json.loads(res.text)["tokens"]
+
+def preprocessing(diffs):
+    for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
+        isadded, isdeleted = False, False
         added, deleted = [], []
         for change in example.changes:
             if change.old == None and change.new != None:
-                added.extend(tokenizer.tokenize(change.line))
+                added.extend(tokenizing(change.line))
                 isadded = True
             elif change.old != None and change.new == None:
-                deleted.extend(tokenizer.tokenize(change.line))
+                deleted.extend(tokenizing(change.line))
                 isdeleted = True
-        if isadded and isdeleted:
-            pass
-        else:
-            pass
+        if isadded and isdeleted:
+            data = {"idx": idx, "added" : added, "deleted" : deleted}
+            res = requests.post(
+                'http://127.0.0.1:5000/diff',
+                data=json.dumps(data),
+                headers=args.headers
+            )
+            print(json.loads(res.text))
+        else:
+            data = {"idx": idx, "added": added, "deleted": deleted}
+            res = requests.post(
+                'http://127.0.0.1:5000/added',
+                data=json.dumps(data),
+                headers=args.headers
+            )
+            print(json.loads(res.text))

 def main():
     proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
     staged_files = proc.stdout.readlines()
     staged_files = [f.decode("utf-8") for f in staged_files]
......@@ -42,4 +66,10 @@ def main():
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="")
+    parser.add_argument("--endpoint", type=str, default="http://127.0.0.1:5000/")
     args = parser.parse_args()
+    args.headers = {'Content-Type': 'application/json; charset=utf-8'}
     main()
\ No newline at end of file
......
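Note: preprocessing() classifies lines through whatthepatch's change records: a line that exists only in the new file (change.old is None) is an addition, and one that exists only in the old file (change.new is None) is a deletion. A minimal sketch with a toy patch (not from this repository) showing that classification without any server calls:

import whatthepatch

toy_diff = """\
--- a/hello.py
+++ b/hello.py
@@ -1,2 +1,2 @@
-print("hello")
+print("hello, world")
 # unchanged comment
"""

for example in whatthepatch.parse_patch(toy_diff):
    for change in example.changes:
        if change.old is None and change.new is not None:
            print("added  :", change.line)    # present only in the new file
        elif change.old is not None and change.new is None:
            print("deleted:", change.line)    # present only in the old file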
......@@ -71,12 +71,15 @@ class Seq2Seq(nn.Module):
             return outputs
         else:
             #Predict
-            preds=[]
-            zero=torch.cuda.LongTensor(1).fill_(0)
+            preds=[]
+            if source_ids.device.type == 'cuda':
+                zero=torch.cuda.LongTensor(1).fill_(0)
+            elif source_ids.device.type == 'cpu':
+                zero = torch.LongTensor(1).fill_(0)
             for i in range(source_ids.shape[0]):
                 context=encoder_output[:,i:i+1]
                 context_mask=source_mask[i:i+1,:]
-                beam = Beam(self.beam_size,self.sos_id,self.eos_id)
+                beam = Beam(self.beam_size,self.sos_id,self.eos_id, device=source_ids.device.type)
                 input_ids=beam.getCurrentState()
                 context=context.repeat(1, self.beam_size,1)
                 context_mask=context_mask.repeat(self.beam_size,1)
......@@ -103,9 +106,12 @@ class Seq2Seq(nn.Module):
 class Beam(object):
-    def __init__(self, size,sos,eos):
+    def __init__(self, size,sos,eos, device):
         self.size = size
-        self.tt = torch.cuda
+        if device == 'cuda':
+            self.tt = torch.cuda
+        elif device == 'cpu':
+            self.tt = torch
         # The score for each translation on the beam.
         self.scores = self.tt.FloatTensor(size).zero_()
         # The backpointers at each time-step.
......
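Note: the branch on source_ids.device.type above keeps the original torch.cuda.LongTensor constructor on GPU and falls back to torch.LongTensor on CPU. A minimal alternative sketch (not the repository's code; make_zero is a hypothetical helper) that allocates the same zero tensor directly on the source tensor's device, so no branching is needed:

import torch

def make_zero(source_ids: torch.Tensor) -> torch.Tensor:
    # same shape and dtype as LongTensor(1).fill_(0), created on source_ids' device
    return torch.zeros(1, dtype=torch.long, device=source_ids.device)

The Beam class could likewise accept a torch.device and build its tensors with factory functions taking device=..., instead of switching between the torch and torch.cuda modules.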