Showing 20 changed files with 410 additions and 286 deletions
autocommit/__init__.py
0 → 100644
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
autocommit/app.py
0 → 100644
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import argparse
+import whatthepatch
+from tqdm import tqdm
+import torch.nn as nn
+from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
+from transformers import (RobertaConfig, RobertaTokenizer)
+
+from autocommit.model import Seq2Seq
+from autocommit.utils import (Example, convert_examples_to_features)
+from autocommit.model.diff_roberta import RobertaModel
+
+from flask import Flask, jsonify, request
+
+app = Flask(__name__)
+
+MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
+
+def get_model(model_class, config, tokenizer, mode):
+    encoder = model_class(config=config)
+    decoder_layer = nn.TransformerDecoderLayer(
+        d_model=config.hidden_size, nhead=config.num_attention_heads
+    )
+    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
+    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
+                    beam_size=args.beam_size, max_length=args.max_target_length,
+                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
+
+    assert args.load_model_path
+    assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))
+
+    model.load_state_dict(
+        torch.load(
+            os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
+            map_location=torch.device(args.device)
+        ),
+        strict=False
+    )
+    return model
+
+def get_features(examples):
+    features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
+    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
+    all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
+    all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long)
+    return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
+
+def create_app():
+    @app.route('/')
+    def index():
+        return jsonify(hello="world")
+
+    @app.route('/added', methods=['POST'])
+    def added():
+        if request.method == 'POST':
+            payload = request.get_json()
+            example = [
+                Example(
+                    idx=payload['idx'],
+                    added=payload['added'],
+                    deleted=payload['deleted'],
+                    target=None
+                )
+            ]
+            message = inference(model=args.added_model, data=get_features(example))
+            return jsonify(idx=payload['idx'], message=message)
+
+    @app.route('/diff', methods=['POST'])
+    def diff():
+        if request.method == 'POST':
+            payload = request.get_json()
+            example = [
+                Example(
+                    idx=payload['idx'],
+                    added=payload['added'],
+                    deleted=payload['deleted'],
+                    target=None
+                )
+            ]
+            message = inference(model=args.diff_model, data=get_features(example))
+            return jsonify(idx=payload['idx'], message=message)
+
+    @app.route('/tokenizer', methods=['POST'])
+    def tokenizer():
+        if request.method == 'POST':
+            payload = request.get_json()
+            tokens = args.tokenizer.tokenize(payload['line'])
+            return jsonify(tokens=tokens)
+
+    return app
+
+def inference(model, data):
+    # generate commit message predictions for the given features
+    eval_sampler = SequentialSampler(data)
+    eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))
+
+    model.eval()
+    p = []
+    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
+        batch = tuple(t.to(args.device) for t in batch)
+        source_ids, source_mask, patch_ids = batch
+        with torch.no_grad():
+            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
+            for pred in preds:
+                t = pred[0].cpu().numpy()
+                t = list(t)
+                if 0 in t:
+                    t = t[:t.index(0)]
+                text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)
+                p.append(text)
+    return p
+
+def main(args):
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name)
+    args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
+
+    # build model
+    args.added_model = get_model(model_class=model_class, config=config,
+                                 tokenizer=args.tokenizer, mode='added').to(args.device)
+    args.diff_model = get_model(model_class=model_class, config=config,
+                                tokenizer=args.tokenizer, mode='diff').to(args.device)
+
+    app = create_app()
+    app.run(host=args.host, debug=True, port=args.port)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument("--load_model_path", default='weight', type=str,
+                        help="Path to trained model: Should contain the .bin files")
+
+    parser.add_argument("--model_type", default='roberta', type=str,
+                        help="Model type: e.g. roberta")
+    parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", type=str,
+                        default="microsoft/codebert-base", help="The name of tokenizer", )
+    parser.add_argument("--max_source_length", default=256, type=int,
+                        help="The maximum total source sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
+    parser.add_argument("--max_target_length", default=128, type=int,
+                        help="The maximum total target sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
+    parser.add_argument("--beam_size", default=10, type=int,
+                        help="beam size for beam search")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Avoid using CUDA when available")
+
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=5000)
+
+    args = parser.parse_args()
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+
+    main(args)
\ No newline at end of file
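The routes above expect JSON with `idx`, `added`, and `deleted`, where `added`/`deleted` are token lists in the format produced by the `/tokenizer` route. A minimal client sketch, assuming a local server started with app.py's default --host/--port; the sample diff line and payload values are illustrative only:

import requests

API = "http://127.0.0.1:5000"   # assumed local instance of app.py

def tokenize(line):
    # /tokenizer returns the RoBERTa subword tokens for one line of code
    return requests.post(API + "/tokenizer", json={"line": line}).json()["tokens"]

# illustrative payload: one added line, nothing deleted
payload = {"idx": 0, "added": tokenize("import torch.nn as nn"), "deleted": []}

# /added serves the model trained on additions only; /diff serves the full-diff model
print(requests.post(API + "/added", json=payload).json()["message"])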
autocommit/commit.py
0 → 100644
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import whatthepatch
+
+def preprocessing(diff):
+    added_examples, diff_examples = [], []
+    isadded, isdeleted = False, False
+    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
+        added, deleted = [], []
+        for change in example.changes:
+            if change.old == None and change.new != None:
+                added.extend(tokenizer.tokenize(change.line))
+                isadded = True
+            elif change.old != None and change.new == None:
+                deleted.extend(tokenizer.tokenize(change.line))
+                isdeleted = True
+
+        if isadded and isdeleted:
+            pass
+        else:
+            pass
+
+def main():
+    proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
+    staged_files = proc.stdout.readlines()
+    staged_files = [f.decode("utf-8") for f in staged_files]
+    staged_files = [f.strip() for f in staged_files]
+    diffs = "\n".join(staged_files)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
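commit.py is still a stub at this point: `tokenizer` is undefined inside `preprocessing()`, both branches of the `isadded`/`isdeleted` check are `pass`, and `main()` collects the staged diff without using it. One possible way to wire it to the API exposed by app.py is sketched below; the server URL and the added-vs-diff routing rule are assumptions, not part of this commit:

import subprocess
import requests
import whatthepatch

API = "http://127.0.0.1:5000"   # assumed location of the running app.py server

def tokenize(line):
    return requests.post(API + "/tokenizer", json={"line": line}).json()["tokens"]

def suggest_messages(diff):
    messages = []
    for idx, patch in enumerate(whatthepatch.parse_patch(diff)):
        added, deleted = [], []
        for change in patch.changes:
            if change.old is None and change.new is not None:
                added.extend(tokenize(change.line))
            elif change.old is not None and change.new is None:
                deleted.extend(tokenize(change.line))
        # assumed routing: additions-only patches go to /added, mixed patches to /diff
        endpoint = "/diff" if deleted else "/added"
        payload = {"idx": idx, "added": added, "deleted": deleted}
        messages.append(requests.post(API + endpoint, json=payload).json()["message"])
    return messages

if __name__ == '__main__':
    staged = subprocess.run(["git", "diff", "--cached"], capture_output=True, text=True).stdout
    print(suggest_messages(staged))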
autocommit/model/__init__.py
0 → 100644
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from autocommit.model.diff_roberta import RobertaModel
+from autocommit.model.model import Seq2Seq
+
+__all__ = [
+    'RobertaModel',
+    'Seq2Seq'
+]
\ No newline at end of file
autocommit/utils.py
0 → 100644
+# Copyright 2020-present Tae Hwan Jung
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
+
+class Example(object):
+    """A single training/test example."""
+    def __init__(self,
+                 idx,
+                 added,
+                 deleted,
+                 target,
+                 ):
+        self.idx = idx
+        self.added = added
+        self.deleted = deleted
+        self.target = target
+
+class InputFeatures(object):
+    """Features for a single training/test example."""
+    def __init__(self,
+                 example_id,
+                 source_ids,
+                 target_ids,
+                 source_mask,
+                 target_mask,
+                 patch_ids,
+                 ):
+        self.example_id = example_id
+        self.source_ids = source_ids
+        self.target_ids = target_ids
+        self.source_mask = source_mask
+        self.target_mask = target_mask
+        self.patch_ids = patch_ids
+
+def convert_examples_to_features(examples, tokenizer, args, stage=None):
+    features = []
+    for example_index, example in enumerate(examples):
+        # source
+        added_tokens = [tokenizer.cls_token] + example.added + [tokenizer.sep_token]
+        deleted_tokens = example.deleted + [tokenizer.sep_token]
+        source_tokens = added_tokens + deleted_tokens
+        patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
+        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
+        source_mask = [1] * (len(source_tokens))
+        padding_length = args.max_source_length - len(source_ids)
+        source_ids += [tokenizer.pad_token_id] * padding_length
+        patch_ids += [0] * padding_length
+        source_mask += [0] * padding_length
+
+        # target
+        if stage == "test":
+            target_tokens = tokenizer.tokenize("None")
+        else:
+            target_tokens = (example.target)[:args.max_target_length - 2]
+        target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
+        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
+        target_mask = [1] * len(target_ids)
+        padding_length = args.max_target_length - len(target_ids)
+        target_ids += [tokenizer.pad_token_id] * padding_length
+        target_mask += [0] * padding_length
+
+        if example_index < 5:
+            if stage == 'train':
+                logger.info("*** Example ***")
+                logger.info("idx: {}".format(example.idx))
+
+                logger.info("source_tokens: {}".format([x.replace('\u0120', '_') for x in source_tokens]))
+                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
+                logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
+                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
+
+                logger.info("target_tokens: {}".format([x.replace('\u0120', '_') for x in target_tokens]))
+                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
+                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
+
+        features.append(
+            InputFeatures(
+                example_index,
+                source_ids,
+                target_ids,
+                source_mask,
+                target_mask,
+                patch_ids,
+            )
+        )
+
+    return features
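A small worked example of the encoding built by `convert_examples_to_features` above, using placeholder tokens instead of real CodeBERT subwords: segment id 1 marks the added span, 2 the deleted span, and 0 the padding.

cls, sep = "<s>", "</s>"             # RoBERTa special tokens
added = ["import", "torch"]          # example.added
deleted = ["import", "copy"]         # example.deleted
max_source_length = 12               # small value for illustration

added_tokens = [cls] + added + [sep]
deleted_tokens = deleted + [sep]
source_tokens = added_tokens + deleted_tokens
patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
source_mask = [1] * len(source_tokens)

padding = max_source_length - len(source_tokens)
patch_ids += [0] * padding
source_mask += [0] * padding

print(source_tokens)  # ['<s>', 'import', 'torch', '</s>', 'import', 'copy', '</s>']
print(patch_ids)      # [1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
print(source_mask)    # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]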
autocommit/weight/added/.keep
0 → 100644
File mode changed
autocommit/weight/diff/.keep
0 → 100644
File mode changed
File moved
File moved
File moved
src/api.py
deleted
100644 → 0
-# Copyright 2020-present Tae Hwan Jung
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import torch
-import logging
-from tqdm import tqdm
-import torch.nn as nn
-from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
-from transformers import (RobertaConfig, RobertaTokenizer)
-
-import argparse
-import whatthepatch
-from train.run import (Example, convert_examples_to_features)
-from train.model import Seq2Seq
-from train.customized_roberta import RobertaModel
-
-MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
-
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
-logger = logging.getLogger(__name__)
-
-def create_examples(diff, tokenizer):
-    examples = []
-    for idx, example in enumerate(whatthepatch.parse_patch(diff)):
-        added, deleted = [], []
-        for change in example.changes:
-            if change.old == None and change.new != None:
-                added.extend(tokenizer.tokenize(change.line))
-            elif change.old != None and change.new == None:
-                deleted.extend(tokenizer.tokenize(change.line))
-        examples.append(
-            Example(
-                idx=idx,
-                added=added,
-                deleted=deleted,
-                target=None
-            )
-        )
-
-    return examples
-
-def main(args):
-
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
-
-    # budild model
-    encoder = model_class(config=config)
-    decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
-    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
-    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
-                    beam_size=args.beam_size, max_length=args.max_target_length,
-                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
-    if args.load_model_path is not None:
-        logger.info("reload model from {}".format(args.load_model_path))
-        model.load_state_dict(torch.load(args.load_model_path), strict=False)
-
-    model.to(args.device)
-    with open("test.source", "r") as f:
-        eval_examples = create_examples(f.read(), tokenizer)
-
-    test_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
-    all_source_ids = torch.tensor([f.source_ids for f in test_features], dtype=torch.long)
-    all_source_mask = torch.tensor([f.source_mask for f in test_features], dtype=torch.long)
-    all_patch_ids = torch.tensor([f.patch_ids for f in test_features], dtype=torch.long)
-    test_data = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
-
-    # Calculate bleu
-    eval_sampler = SequentialSampler(test_data)
-    eval_dataloader = DataLoader(test_data, sampler=eval_sampler, batch_size=len(test_data))
-
-    model.eval()
-    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
-        batch = tuple(t.to(args.device) for t in batch)
-        source_ids, source_mask, patch_ids = batch
-        with torch.no_grad():
-            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
-            for pred in preds:
-                t = pred[0].cpu().numpy()
-                t = list(t)
-                if 0 in t:
-                    t = t[:t.index(0)]
-                text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
-                print(text)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="")
-    parser.add_argument("--load_model_path", default=None, type=str, required=True,
-                        help="Path to trained model: Should contain the .bin files")
-
-    parser.add_argument("--model_type", default='roberta', type=str,
-                        help="Model type: e.g. roberta")
-    parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
-                        help="Pretrained config name or path if not the same as model_name")
-    parser.add_argument("--tokenizer_name", type=str,
-                        default="microsoft/codebert-base", help="The name of tokenizer", )
-    parser.add_argument("--max_source_length", default=256, type=int,
-                        help="The maximum total source sequence length after tokenization. Sequences longer "
-                             "than this will be truncated, sequences shorter will be padded.")
-    parser.add_argument("--max_target_length", default=128, type=int,
-                        help="The maximum total target sequence length after tokenization. Sequences longer "
-                             "than this will be truncated, sequences shorter will be padded.")
-    parser.add_argument("--beam_size", default=10, type=int,
-                        help="beam size for beam search")
-    parser.add_argument("--do_lower_case", action='store_true',
-                        help="Set this flag if you are using an uncased model.")
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-
-    args = parser.parse_args()
-
-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-
-    main(args)
\ No newline at end of file
src/test.source
deleted
100644 → 0
-diff --git a/src/train/model.py b/src/train/model.py
-index 20e56b3..cab82e5 100644
---- a/src/train/model.py
-+++ b/src/train/model.py
-@@ -3,9 +3,7 @@
-
- import torch
- import torch.nn as nn
--import torch
--from torch.autograd import Variable
--import copy
-+
- class Seq2Seq(nn.Module):
-     """
-     Build Seqence-to-Sequence.
-diff --git a/src/train/run.py b/src/train/run.py
-index 5961ad1..be98fec 100644
---- a/src/train/run.py
-+++ b/src/train/run.py
-@@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss.
- from __future__ import absolute_import
- import os
- import sys
--import bleu
- import pickle
- import torch
- import json
-@@ -35,11 +34,14 @@ from itertools import cycle
- import torch.nn as nn
- from model import Seq2Seq
- from tqdm import tqdm, trange
--from customized_roberta import RobertaModel
- from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
- from torch.utils.data.distributed import DistributedSampler
- from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                           RobertaConfig, RobertaTokenizer)
-+
-+import train.bleu as bleu
-+from train.customized_roberta import RobertaModel
-+
- MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
-
- logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
test/add.diff
0 → 100644
+diff --git a/codebert/code.py b/codebert/code.py
+new file mode 100644
+index 0000000..b4bc953
+--- /dev/null
++++ b/codebert/code.py
+@@ -0,0 +1,21 @@
++def dailymotion_download(url, output_dir='.', merge=True, info_only=False, **kwargs):
++
++    html = get_content(rebuilt_url(url))
++    info = json.loads(match1(html, r'qualities":({.+?}),"'))
++    title = match1(html, r'"video_title"\s*:\s*"([^"]+)"') or \
++            match1(html, r'"title"\s*:\s*"([^"]+)"')
++    title = unicodize(title)
++
++    for quality in ['1080','720','480','380','240','144','auto']:
++        try:
++            real_url = info[quality][1]["url"]
++            if real_url:
++                break
++        except KeyError:
++            pass
++
++    mime, ext, size = url_info(real_url)
++
++    print_info(site_info, title, mime, size)
++    if not info_only:
++        download_urls([real_url], title, ext, size, output_dir=output_dir, merge=merge)
\ No newline at end of file
test/fixed.diff
0 → 100644
+diff --git a/src/train/model.py b/src/train/model.py
+index 20e56b3..cab82e5 100644
+--- a/src/train/model.py
++++ b/src/train/model.py
+@@ -3,9 +3,7 @@
+
+ import torch
+ import torch.nn as nn
+-import torch
+-from torch.autograd import Variable
+-import copy
++
+ class Seq2Seq(nn.Module):
+     """
+     Build Seqence-to-Sequence.
\ No newline at end of file
@@ -21,8 +21,6 @@ using a masked language modeling (MLM) loss.
 
 from __future__ import absolute_import
 import os
-import sys
-import pickle
 import torch
 import json
 import random
@@ -30,17 +28,17 @@ import logging
 import argparse
 import numpy as np
 from io import open
-from itertools import cycle
+from tqdm import tqdm
 import torch.nn as nn
-from model import Seq2Seq
-from tqdm import tqdm, trange
-from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
+from itertools import cycle
+
+from torch.utils.data import (DataLoader, SequentialSampler, RandomSampler, TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
-from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          RobertaConfig, RobertaTokenizer)
+from transformers import (AdamW, get_linear_schedule_with_warmup, RobertaConfig, RobertaTokenizer)
 
-import train.bleu as bleu
-from train.customized_roberta import RobertaModel
+import bleu
+from autocommit.model import Seq2Seq, RobertaModel
+from autocommit.utils import (convert_examples_to_features, Example)
 
 MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
 
@@ -49,19 +47,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
                     level = logging.INFO)
 logger = logging.getLogger(__name__)
 
-class Example(object):
-    """A single training/test example."""
-    def __init__(self,
-                 idx,
-                 added,
-                 deleted,
-                 target,
-                 ):
-        self.idx = idx
-        self.added = added
-        self.deleted = deleted
-        self.target = target
-
 def read_examples(filename):
     """Read examples from filename."""
     examples=[]
@@ -82,85 +67,6 @@ def read_examples(filename):
     return examples
 
 
-class InputFeatures(object):
-    """A single training/test features for a example."""
-    def __init__(self,
-                 example_id,
-                 source_ids,
-                 target_ids,
-                 source_mask,
-                 target_mask,
-                 patch_ids,
-
-                 ):
-        self.example_id = example_id
-        self.source_ids = source_ids
-        self.target_ids = target_ids
-        self.source_mask = source_mask
-        self.target_mask = target_mask
-        self.patch_ids = patch_ids
-
-
-
-def convert_examples_to_features(examples, tokenizer, args,stage=None):
-    features = []
-    for example_index, example in enumerate(examples):
-        #source
-        added_tokens=[tokenizer.cls_token]+example.added+[tokenizer.sep_token]
-        deleted_tokens=example.deleted+[tokenizer.sep_token]
-        source_tokens = added_tokens + deleted_tokens
-        patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
-        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
-        source_mask = [1] * (len(source_tokens))
-        padding_length = args.max_source_length - len(source_ids)
-        source_ids+=[tokenizer.pad_token_id]*padding_length
-        patch_ids+=[0]*padding_length
-        source_mask+=[0]*padding_length
-
-        assert len(source_ids) == args.max_source_length
-        assert len(source_mask) == args.max_source_length
-        assert len(patch_ids) == args.max_source_length
-
-        #target
-        if stage=="test":
-            target_tokens = tokenizer.tokenize("None")
-        else:
-            target_tokens = (example.target)[:args.max_target_length-2]
-        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
-        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
-        target_mask = [1] *len(target_ids)
-        padding_length = args.max_target_length - len(target_ids)
-        target_ids+=[tokenizer.pad_token_id]*padding_length
-        target_mask+=[0]*padding_length
-
-        if example_index < 5:
-            if stage=='train':
-                logger.info("*** Example ***")
-                logger.info("idx: {}".format(example.idx))
-
-                logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
-                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
-                logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
-                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
-
-                logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
-                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
-                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
-
-        features.append(
-            InputFeatures(
-                example_index,
-                source_ids,
-                target_ids,
-                source_mask,
-                target_mask,
-                patch_ids,
-            )
-        )
-    return features
-
-
-
 def set_seed(args):
     """set random seed."""
     random.seed(args.seed)
@@ -471,7 +377,7 @@ def main():
                     f1.write(str(gold.idx)+'\t'+' '.join(gold.target)+'\n')
 
                 (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold"))
-                dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
+                dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
                 logger.info(" %s = %s "%("bleu-4",str(dev_bleu)))
                 logger.info(" "+"*"*20)
                 if dev_bleu>best_bleu:
@@ -528,7 +434,7 @@ def main():
                     f1.write(str(gold.idx)+'\t'+' '.join(gold.target)+'\n')
 
                 (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "test_{}.gold".format(idx)))
-                dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
+                dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
                 logger.info(" %s = %s "%("bleu-4",str(dev_bleu)))
                 logger.info(" "+"*"*20)
 
...