Showing 3 changed files with 52 additions and 16 deletions
@@ -99,7 +99,7 @@ def create_app():
     def tokenizer():
         if request.method == 'POST':
             payload = request.get_json()
-            tokens = args.tokenizer.tokenize(payload['line'])
+            tokens = args.tokenizer.tokenize(payload['code'])
             return jsonify(tokens=tokens)
 
     return app
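For reference, a minimal client call matching the renamed JSON key (the endpoint now expects "code" rather than "line" in the request body). This sketch assumes the Flask service is running locally on port 5000, as the preprocessing script below does; the sample code string is illustrative:

import json
import requests

# Hypothetical smoke test against the /tokenizer route after the key rename.
res = requests.post(
    "http://127.0.0.1:5000/tokenizer",
    data=json.dumps({"code": "def add(a, b): return a + b"}),
    headers={"Content-Type": "application/json; charset=utf-8"},
)
print(res.json()["tokens"])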
@@ -12,28 +12,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+import requests
+import argparse
 import subprocess
 import whatthepatch
 
-def preprocessing(diff):
-    added_examples, diff_examples = [], []
-    isadded, isdeleted = False, False
-    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
+def tokenizing(code):
+    data = {"code": code }
+    res = requests.post(
+        'http://127.0.0.1:5000/tokenizer',
+        data=json.dumps(data),
+        headers=args.headers
+    )
+    return json.loads(res.text)["tokens"]
+
+def preprocessing(diffs):
+    for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
+        isadded, isdeleted = False, False
         added, deleted = [], []
         for change in example.changes:
             if change.old == None and change.new != None:
-                added.extend(tokenizer.tokenize(change.line))
+                added.extend(tokenizing(change.line))
                 isadded = True
             elif change.old != None and change.new == None:
-                deleted.extend(tokenizer.tokenize(change.line))
+                deleted.extend(tokenizing(change.line))
                 isdeleted = True
 
-    if isadded and isdeleted:
-        pass
-    else:
-        pass
+        if isadded and isdeleted:
+            data = {"idx": idx, "added" : added, "deleted" : deleted}
+            res = requests.post(
+                'http://127.0.0.1:5000/diff',
+                data=json.dumps(data),
+                headers=args.headers
+            )
+            print(json.loads(res.text))
+        else:
+            data = {"idx": idx, "added": added, "deleted": deleted}
+            res = requests.post(
+                'http://127.0.0.1:5000/added',
+                data=json.dumps(data),
+                headers=args.headers
+            )
+            print(json.loads(res.text))
 
 def main():
+
     proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
     staged_files = proc.stdout.readlines()
     staged_files = [f.decode("utf-8") for f in staged_files]
@@ -42,4 +66,10 @@ def main():
 
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument("--endpoint", type=str, default="http://127.0.0.1:5000/")
+    args = parser.parse_args()
+
+    args.headers = {'Content-Type': 'application/json; charset=utf-8'}
+
     main()
\ No newline at end of file
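One caveat with the new script: tokenizing() and preprocessing() read the module-level args, which only exists when the file is run directly, and the request URLs are hard-coded even though an --endpoint flag was added. A sketch of one way to thread the endpoint through explicitly; the signature below is illustrative, not part of this change:

import json
import requests

def tokenizing(code, endpoint="http://127.0.0.1:5000/", headers=None):
    # POST one changed line to the tokenizer service and return the
    # token list from its JSON response; endpoint would come from --endpoint.
    headers = headers or {"Content-Type": "application/json; charset=utf-8"}
    res = requests.post(
        endpoint.rstrip("/") + "/tokenizer",
        data=json.dumps({"code": code}),
        headers=headers,
    )
    return res.json()["tokens"]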
@@ -71,12 +71,15 @@ class Seq2Seq(nn.Module):
             return outputs
         else:
             #Predict
-            preds=[]
-            zero=torch.cuda.LongTensor(1).fill_(0)
+            preds=[]
+            if source_ids.device.type == 'cuda':
+                zero=torch.cuda.LongTensor(1).fill_(0)
+            elif source_ids.device.type == 'cpu':
+                zero = torch.LongTensor(1).fill_(0)
             for i in range(source_ids.shape[0]):
                 context=encoder_output[:,i:i+1]
                 context_mask=source_mask[i:i+1,:]
-                beam = Beam(self.beam_size,self.sos_id,self.eos_id)
+                beam = Beam(self.beam_size,self.sos_id,self.eos_id, device=source_ids.device.type)
                 input_ids=beam.getCurrentState()
                 context=context.repeat(1, self.beam_size,1)
                 context_mask=context_mask.repeat(self.beam_size,1)
@@ -103,9 +106,12 @@ class Seq2Seq(nn.Module):
 
 
 class Beam(object):
-    def __init__(self, size,sos,eos):
+    def __init__(self, size,sos,eos, device):
         self.size = size
-        self.tt = torch.cuda
+        if device == 'cuda':
+            self.tt = torch.cuda
+        elif device == 'cpu':
+            self.tt = torch
         # The score for each translation on the beam.
         self.scores = self.tt.FloatTensor(size).zero_()
         # The backpointers at each time-step.
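The paired cuda/cpu branches in forward and Beam.__init__ could also be collapsed by using device-aware factory functions instead of the legacy torch.cuda.LongTensor constructors. A sketch of the equivalent one-path allocation, assuming the reference tensor (source_ids in the code above) carries the target device:

import torch

def initial_zero(reference: torch.Tensor) -> torch.Tensor:
    # Allocate the start token on whatever device the reference tensor
    # lives on, covering 'cuda', 'cpu', and any other backend in one path.
    return torch.zeros(1, dtype=torch.long, device=reference.device)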