graykode

(fix) select CUDA or CPU device correctly in beam search

...@@ -99,7 +99,7 @@ def create_app(): ...@@ -99,7 +99,7 @@ def create_app():
99 def tokenizer(): 99 def tokenizer():
100 if request.method == 'POST': 100 if request.method == 'POST':
101 payload = request.get_json() 101 payload = request.get_json()
102 - tokens = args.tokenizer.tokenize(payload['line']) 102 + tokens = args.tokenizer.tokenize(payload['code'])
103 return jsonify(tokens=tokens) 103 return jsonify(tokens=tokens)
104 104
105 return app 105 return app
......
...@@ -12,28 +12,52 @@ ...@@ -12,28 +12,52 @@
12 # See the License for the specific language governing permissions and 12 # See the License for the specific language governing permissions and
13 # limitations under the License. 13 # limitations under the License.
14 14
15 +import json
16 +import requests
17 +import argparse
15 import subprocess 18 import subprocess
16 import whatthepatch 19 import whatthepatch
17 20
18 -def preprocessing(diff): 21 +def tokenizing(code):
19 - added_examples, diff_examples = [], [] 22 + data = {"code": code }
20 - isadded, isdeleted = False, False 23 + res = requests.post(
21 - for idx, example in enumerate(whatthepatch.parse_patch(diff)): 24 + 'http://127.0.0.1:5000/tokenizer',
25 + data=json.dumps(data),
26 + headers=args.headers
27 + )
28 + return json.loads(res.text)["tokens"]
29 +
30 +def preprocessing(diffs):
31 + for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
32 + isadded, isdeleted = False, False
22 added, deleted = [], [] 33 added, deleted = [], []
23 for change in example.changes: 34 for change in example.changes:
24 if change.old == None and change.new != None: 35 if change.old == None and change.new != None:
25 - added.extend(tokenizer.tokenize(change.line)) 36 + added.extend(tokenizing(change.line))
26 isadded = True 37 isadded = True
27 elif change.old != None and change.new == None: 38 elif change.old != None and change.new == None:
28 - deleted.extend(tokenizer.tokenize(change.line)) 39 + deleted.extend(tokenizing(change.line))
29 isdeleted = True 40 isdeleted = True
30 41
31 - if isadded and isdeleted: 42 + if isadded and isdeleted:
32 - pass 43 + data = {"idx": idx, "added" : added, "deleted" : deleted}
33 - else: 44 + res = requests.post(
34 - pass 45 + 'http://127.0.0.1:5000/diff',
46 + data=json.dumps(data),
47 + headers=args.headers
48 + )
49 + print(json.loads(res.text))
50 + else:
51 + data = {"idx": idx, "added": added, "deleted": deleted}
52 + res = requests.post(
53 + 'http://127.0.0.1:5000/added',
54 + data=json.dumps(data),
55 + headers=args.headers
56 + )
57 + print(json.loads(res.text))
35 58
36 def main(): 59 def main():
60 +
37 proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE) 61 proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
38 staged_files = proc.stdout.readlines() 62 staged_files = proc.stdout.readlines()
39 staged_files = [f.decode("utf-8") for f in staged_files] 63 staged_files = [f.decode("utf-8") for f in staged_files]
...@@ -42,4 +66,10 @@ def main(): ...@@ -42,4 +66,10 @@ def main():
42 66
43 67
44 if __name__ == '__main__': 68 if __name__ == '__main__':
69 + parser = argparse.ArgumentParser(description="")
70 + parser.add_argument("--endpoint", type=str, default="http://127.0.0.1:5000/")
71 + args = parser.parse_args()
72 +
73 + args.headers = {'Content-Type': 'application/json; charset=utf-8'}
74 +
45 main() 75 main()
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -71,12 +71,15 @@ class Seq2Seq(nn.Module): ...@@ -71,12 +71,15 @@ class Seq2Seq(nn.Module):
71 return outputs 71 return outputs
72 else: 72 else:
73 #Predict 73 #Predict
74 - preds=[] 74 + preds=[]
75 - zero=torch.cuda.LongTensor(1).fill_(0) 75 + if source_ids.device.type == 'cuda':
76 + zero=torch.cuda.LongTensor(1).fill_(0)
77 + elif source_ids.device.type == 'cpu':
78 + zero = torch.LongTensor(1).fill_(0)
76 for i in range(source_ids.shape[0]): 79 for i in range(source_ids.shape[0]):
77 context=encoder_output[:,i:i+1] 80 context=encoder_output[:,i:i+1]
78 context_mask=source_mask[i:i+1,:] 81 context_mask=source_mask[i:i+1,:]
79 - beam = Beam(self.beam_size,self.sos_id,self.eos_id) 82 + beam = Beam(self.beam_size,self.sos_id,self.eos_id, device=source_ids.device.type)
80 input_ids=beam.getCurrentState() 83 input_ids=beam.getCurrentState()
81 context=context.repeat(1, self.beam_size,1) 84 context=context.repeat(1, self.beam_size,1)
82 context_mask=context_mask.repeat(self.beam_size,1) 85 context_mask=context_mask.repeat(self.beam_size,1)
...@@ -103,9 +106,12 @@ class Seq2Seq(nn.Module): ...@@ -103,9 +106,12 @@ class Seq2Seq(nn.Module):
103 106
104 107
105 class Beam(object): 108 class Beam(object):
106 - def __init__(self, size,sos,eos): 109 + def __init__(self, size,sos,eos, device):
107 self.size = size 110 self.size = size
108 - self.tt = torch.cuda 111 + if device == 'cuda':
112 + self.tt = torch.cuda
113 + elif device == 'cpu':
114 + self.tt = torch
109 # The score for each translation on the beam. 115 # The score for each translation on the beam.
110 self.scores = self.tt.FloatTensor(size).zero_() 116 self.scores = self.tt.FloatTensor(size).zero_()
111 # The backpointers at each time-step. 117 # The backpointers at each time-step.
......