graykode

reinit

@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+*.bin
 
 # C extensions
 *.so
......
1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
\ No newline at end of file
1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import os
16 +import torch
17 +import argparse
18 +import whatthepatch
19 +from tqdm import tqdm
20 +import torch.nn as nn
21 +from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
22 +from transformers import (RobertaConfig, RobertaTokenizer)
23 +
24 +from autocommit.model import Seq2Seq
25 +from autocommit.utils import (Example, convert_examples_to_features)
26 +from autocommit.model.diff_roberta import RobertaModel
27 +
28 +from flask import Flask, jsonify, request
29 +
30 +app = Flask(__name__)
31 +
32 +MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
33 +
34 +def get_model(model_class, config, tokenizer, mode):
35 + encoder = model_class(config=config)
36 + decoder_layer = nn.TransformerDecoderLayer(
37 + d_model=config.hidden_size, nhead=config.num_attention_heads
38 + )
39 + decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
40 + model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
41 + beam_size=args.beam_size, max_length=args.max_target_length,
42 + sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
43 +
44 + assert args.load_model_path
45 + assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))
46 +
47 + model.load_state_dict(
48 + torch.load(
49 + os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
50 + map_location=torch.device(args.device)
51 + ),
52 + strict=False
53 + )
54 + return model
55 +
56 +def get_features(examples):
57 + features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
58 + all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
59 + all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
60 + all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long)
61 + return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
62 +
63 +def create_app():
64 + @app.route('/')
65 + def index():
66 + return jsonify(hello="world")
67 +
68 + @app.route('/added', methods=['POST'])
69 + def added():
70 + if request.method == 'POST':
71 + payload = request.get_json()
72 + example = [
73 + Example(
74 + idx=payload['idx'],
75 + added=payload['added'],
76 + deleted=payload['deleted'],
77 + target=None
78 + )
79 + ]
80 + message = inference(model=args.added_model, data=get_features(example))
81 + return jsonify(idx=payload['idx'], message=message)
82 +
83 + @app.route('/diff', methods=['POST'])
84 + def diff():
85 + if request.method == 'POST':
86 + payload = request.get_json()
87 + example = [
88 + Example(
89 + idx=payload['idx'],
90 + added=payload['added'],
91 + deleted=payload['deleted'],
92 + target=None
93 + )
94 + ]
95 + message = inference(model=args.diff_model, data=get_features(example))
96 + return jsonify(idx=payload['idx'], message=message)
97 +
98 + @app.route('/tokenizer', methods=['POST'])
99 + def tokenizer():
100 + if request.method == 'POST':
101 + payload = request.get_json()
102 + tokens = args.tokenizer.tokenize(payload['line'])
103 + return jsonify(tokens=tokens)
104 +
105 + return app
106 +
107 +def inference(model, data):
108 +    # run beam-search inference over the request batch
109 + eval_sampler = SequentialSampler(data)
110 + eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))
111 +
112 + model.eval()
113 + p=[]
114 + for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
115 + batch = tuple(t.to(args.device) for t in batch)
116 + source_ids, source_mask, patch_ids = batch
117 + with torch.no_grad():
118 + preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
119 + for pred in preds:
120 + t = pred[0].cpu().numpy()
121 + t = list(t)
122 + if 0 in t:
123 + t = t[:t.index(0)]
124 + text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)
125 + p.append(text)
126 + return p
127 +
128 +def main(args):
129 + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
130 + config = config_class.from_pretrained(args.config_name)
131 + args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
132 +
133 +    # build models
134 +    args.added_model = get_model(model_class=model_class, config=config,
135 + tokenizer=args.tokenizer, mode='added').to(args.device)
136 + args.diff_model = get_model(model_class=model_class, config=config,
137 + tokenizer=args.tokenizer, mode='diff').to(args.device)
138 +
139 + app = create_app()
140 + app.run(host=args.host, debug=True, port=args.port)
141 +
142 +if __name__ == '__main__':
143 + parser = argparse.ArgumentParser(description="")
144 + parser.add_argument("--load_model_path", default='weight', type=str,
145 + help="Path to trained model: Should contain the .bin files")
146 +
147 + parser.add_argument("--model_type", default='roberta', type=str,
148 + help="Model type: e.g. roberta")
149 + parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
150 + help="Pretrained config name or path if not the same as model_name")
151 + parser.add_argument("--tokenizer_name", type=str,
152 + default="microsoft/codebert-base", help="The name of tokenizer", )
153 + parser.add_argument("--max_source_length", default=256, type=int,
154 + help="The maximum total source sequence length after tokenization. Sequences longer "
155 + "than this will be truncated, sequences shorter will be padded.")
156 + parser.add_argument("--max_target_length", default=128, type=int,
157 + help="The maximum total target sequence length after tokenization. Sequences longer "
158 + "than this will be truncated, sequences shorter will be padded.")
159 + parser.add_argument("--beam_size", default=10, type=int,
160 + help="beam size for beam search")
161 + parser.add_argument("--do_lower_case", action='store_true',
162 + help="Set this flag if you are using an uncased model.")
163 + parser.add_argument("--no_cuda", action='store_true',
164 + help="Avoid using CUDA when available")
165 +
166 + parser.add_argument("--host", type=str, default="0.0.0.0")
167 + parser.add_argument("--port", type=int, default=5000)
168 +
169 + args = parser.parse_args()
170 +
171 + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
172 +
173 + main(args)
\ No newline at end of file
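
For reference, a minimal client for the server defined above might look like the sketch below. It only assumes what app.py itself defines: the /tokenizer, /added and /diff routes, the payload keys read from request.get_json(), and the argparse defaults for host and port; the requests library, the URL, and the sample line are illustrative. Note that /added and /diff expect added and deleted to already be token lists (convert_examples_to_features concatenates them as-is), which is exactly what the /tokenizer route returns.

# hypothetical client for the Flask server above -- not part of this commit
import requests

SERVER = "http://127.0.0.1:5000"  # argparse defaults: host 0.0.0.0, port 5000

# tokenize one added line with the server-side RobertaTokenizer
tokens = requests.post(SERVER + "/tokenizer",
                       json={"line": "print('hello world')"}).json()["tokens"]

# ask the added-lines model for commit message candidates (no deleted lines here)
resp = requests.post(SERVER + "/added",
                     json={"idx": 0, "added": tokens, "deleted": []}).json()
print(resp["message"])
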
1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import subprocess
16 +import whatthepatch
17 +
18 +def preprocessing(diff):
19 + added_examples, diff_examples = [], []
20 + isadded, isdeleted = False, False
21 + for idx, example in enumerate(whatthepatch.parse_patch(diff)):
22 + added, deleted = [], []
23 + for change in example.changes:
24 + if change.old == None and change.new != None:
25 + added.extend(tokenizer.tokenize(change.line))
26 + isadded = True
27 + elif change.old != None and change.new == None:
28 + deleted.extend(tokenizer.tokenize(change.line))
29 + isdeleted = True
30 +
31 + if isadded and isdeleted:
32 + pass
33 + else:
34 + pass
35 +
36 +def main():
37 + proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
38 + staged_files = proc.stdout.readlines()
39 + staged_files = [f.decode("utf-8") for f in staged_files]
40 + staged_files = [f.strip() for f in staged_files]
41 + diffs = "\n".join(staged_files)
42 +
43 +
44 +if __name__ == '__main__':
45 + main()
\ No newline at end of file
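
The hook script above is still a stub: tokenizer is never defined, the isadded/isdeleted flags are not reset between files, both branches end in pass, and the collected diff text is unused. A possible completion, assuming the Flask server from app.py above is running locally, might look like the following; the endpoint URLs and payload keys follow app.py, everything else is illustrative.

# illustrative completion of the stub above -- not part of this commit
import subprocess

import requests
import whatthepatch

SERVER = "http://127.0.0.1:5000"

def tokenize(line):
    # delegate tokenization to the server's /tokenizer route
    return requests.post(SERVER + "/tokenizer", json={"line": line}).json()["tokens"]

def suggest_messages(diff):
    messages = []
    for idx, patch in enumerate(whatthepatch.parse_patch(diff)):
        added, deleted = [], []
        for change in patch.changes:
            if change.old is None and change.new is not None:
                added.extend(tokenize(change.line))
            elif change.old is not None and change.new is None:
                deleted.extend(tokenize(change.line))
        # use the diff model when lines were both added and deleted,
        # otherwise fall back to the added-only model
        endpoint = "/diff" if (added and deleted) else "/added"
        payload = {"idx": idx, "added": added, "deleted": deleted}
        messages.extend(requests.post(SERVER + endpoint, json=payload).json()["message"])
    return messages

if __name__ == "__main__":
    staged_diff = subprocess.check_output(["git", "diff", "--cached"]).decode("utf-8")
    print("\n".join(suggest_messages(staged_diff)))
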
1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +from autocommit.model.diff_roberta import RobertaModel
16 +from autocommit.model.model import Seq2Seq
17 +
18 +__all__ = [
19 + 'RobertaModel',
20 + 'Seq2Seq'
21 +]
\ No newline at end of file
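
With this __init__, both classes are importable directly from the package, which is how the updated train/run.py further down pulls them in:

# equivalent to the longer module paths used in app.py
from autocommit.model import Seq2Seq, RobertaModel
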
1 +# Copyright 2020-present Tae Hwan Jung
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import logging
16 +
17 +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
18 + datefmt = '%m/%d/%Y %H:%M:%S',
19 + level = logging.INFO)
20 +logger = logging.getLogger(__name__)
21 +
22 +class Example(object):
23 + """A single training/test example."""
24 + def __init__(self,
25 + idx,
26 + added,
27 + deleted,
28 + target,
29 + ):
30 + self.idx = idx
31 + self.added = added
32 + self.deleted = deleted
33 + self.target = target
34 +
35 +class InputFeatures(object):
36 + """A single training/test features for a example."""
37 + def __init__(self,
38 + example_id,
39 + source_ids,
40 + target_ids,
41 + source_mask,
42 + target_mask,
43 + patch_ids,
44 +
45 + ):
46 + self.example_id = example_id
47 + self.source_ids = source_ids
48 + self.target_ids = target_ids
49 + self.source_mask = source_mask
50 + self.target_mask = target_mask
51 + self.patch_ids = patch_ids
52 +
53 +def convert_examples_to_features(examples, tokenizer, args, stage=None):
54 + features = []
55 + for example_index, example in enumerate(examples):
56 + # source
57 + added_tokens = [tokenizer.cls_token] + example.added + [tokenizer.sep_token]
58 + deleted_tokens = example.deleted + [tokenizer.sep_token]
59 + source_tokens = added_tokens + deleted_tokens
60 + patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
61 + source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
62 + source_mask = [1] * (len(source_tokens))
63 + padding_length = args.max_source_length - len(source_ids)
64 + source_ids += [tokenizer.pad_token_id] * padding_length
65 + patch_ids += [0] * padding_length
66 + source_mask += [0] * padding_length
67 +
68 + # target
69 + if stage == "test":
70 + target_tokens = tokenizer.tokenize("None")
71 + else:
72 + target_tokens = (example.target)[:args.max_target_length - 2]
73 + target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
74 + target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
75 + target_mask = [1] * len(target_ids)
76 + padding_length = args.max_target_length - len(target_ids)
77 + target_ids += [tokenizer.pad_token_id] * padding_length
78 + target_mask += [0] * padding_length
79 +
80 + if example_index < 5:
81 + if stage == 'train':
82 + logger.info("*** Example ***")
83 + logger.info("idx: {}".format(example.idx))
84 +
85 + logger.info("source_tokens: {}".format([x.replace('\u0120', '_') for x in source_tokens]))
86 + logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
87 + logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
88 + logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
89 +
90 + logger.info("target_tokens: {}".format([x.replace('\u0120', '_') for x in target_tokens]))
91 + logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
92 + logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
93 +
94 + features.append(
95 + InputFeatures(
96 + example_index,
97 + source_ids,
98 + target_ids,
99 + source_mask,
100 + target_mask,
101 + patch_ids,
102 + )
103 + )
104 +
105 + return features
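
To make the encoding above concrete, here is a schematic of what convert_examples_to_features builds for one example (token strings and lengths are illustrative, not output from a real tokenizer run): the added block, including <s> and its trailing </s>, is tagged with patch id 1, the deleted block with patch id 2, and all padding positions with 0.

# schematic of the feature layout built above (illustrative values)
# example.added  = ['print', '(x)']          example.deleted = ['print', 'x']
# source_tokens  = ['<s>', 'print', '(x)', '</s>', 'print', 'x', '</s>']
# patch_ids      = [  1,     1,      1,      1,      2,      2,    2  ] + [0] * padding
# source_mask    = [1] * 7                                             + [0] * padding
# source_ids     = tokenizer.convert_tokens_to_ids(source_tokens)      + [pad_token_id] * padding
# where padding = args.max_source_length - 7
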
 whatthepatch
 gitpython
-matorage
-transformers
 packaging
-
-psutil
-sacrebleu
-pyarrow>=0.16.0
-rouge-score
-pytorch-lightning==0.8.5
-pytest
\ No newline at end of file
......
1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -import os
16 -import torch
17 -import logging
18 -from tqdm import tqdm
19 -import torch.nn as nn
20 -from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
21 -from transformers import (RobertaConfig, RobertaTokenizer)
22 -
23 -import argparse
24 -import whatthepatch
25 -from train.run import (Example, convert_examples_to_features)
26 -from train.model import Seq2Seq
27 -from train.customized_roberta import RobertaModel
28 -
29 -MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
30 -
31 -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
32 - datefmt = '%m/%d/%Y %H:%M:%S',
33 - level = logging.INFO)
34 -logger = logging.getLogger(__name__)
35 -
36 -def create_examples(diff, tokenizer):
37 - examples = []
38 - for idx, example in enumerate(whatthepatch.parse_patch(diff)):
39 - added, deleted = [], []
40 - for change in example.changes:
41 - if change.old == None and change.new != None:
42 - added.extend(tokenizer.tokenize(change.line))
43 - elif change.old != None and change.new == None:
44 - deleted.extend(tokenizer.tokenize(change.line))
45 - examples.append(
46 - Example(
47 - idx=idx,
48 - added=added,
49 - deleted=deleted,
50 - target=None
51 - )
52 - )
53 -
54 - return examples
55 -
56 -def main(args):
57 -
58 - config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
59 - config = config_class.from_pretrained(args.config_name)
60 - tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
61 -
62 - # budild model
63 - encoder = model_class(config=config)
64 - decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
65 - decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
66 - model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
67 - beam_size=args.beam_size, max_length=args.max_target_length,
68 - sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)
69 - if args.load_model_path is not None:
70 - logger.info("reload model from {}".format(args.load_model_path))
71 - model.load_state_dict(torch.load(args.load_model_path), strict=False)
72 -
73 - model.to(args.device)
74 - with open("test.source", "r") as f:
75 - eval_examples = create_examples(f.read(), tokenizer)
76 -
77 - test_features = convert_examples_to_features(eval_examples, tokenizer, args, stage='test')
78 - all_source_ids = torch.tensor([f.source_ids for f in test_features], dtype=torch.long)
79 - all_source_mask = torch.tensor([f.source_mask for f in test_features], dtype=torch.long)
80 - all_patch_ids = torch.tensor([f.patch_ids for f in test_features], dtype=torch.long)
81 - test_data = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
82 -
83 - # Calculate bleu
84 - eval_sampler = SequentialSampler(test_data)
85 - eval_dataloader = DataLoader(test_data, sampler=eval_sampler, batch_size=len(test_data))
86 -
87 - model.eval()
88 - for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
89 - batch = tuple(t.to(args.device) for t in batch)
90 - source_ids, source_mask, patch_ids = batch
91 - with torch.no_grad():
92 - preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
93 - for pred in preds:
94 - t = pred[0].cpu().numpy()
95 - t = list(t)
96 - if 0 in t:
97 - t = t[:t.index(0)]
98 - text = tokenizer.decode(t, clean_up_tokenization_spaces=False)
99 - print(text)
100 -
101 -
102 -if __name__ == '__main__':
103 - parser = argparse.ArgumentParser(description="")
104 - parser.add_argument("--load_model_path", default=None, type=str, required=True,
105 - help="Path to trained model: Should contain the .bin files")
106 -
107 - parser.add_argument("--model_type", default='roberta', type=str,
108 - help="Model type: e.g. roberta")
109 - parser.add_argument("--config_name", default="microsoft/codebert-base", type=str,
110 - help="Pretrained config name or path if not the same as model_name")
111 - parser.add_argument("--tokenizer_name", type=str,
112 - default="microsoft/codebert-base", help="The name of tokenizer", )
113 - parser.add_argument("--max_source_length", default=256, type=int,
114 - help="The maximum total source sequence length after tokenization. Sequences longer "
115 - "than this will be truncated, sequences shorter will be padded.")
116 - parser.add_argument("--max_target_length", default=128, type=int,
117 - help="The maximum total target sequence length after tokenization. Sequences longer "
118 - "than this will be truncated, sequences shorter will be padded.")
119 - parser.add_argument("--beam_size", default=10, type=int,
120 - help="beam size for beam search")
121 - parser.add_argument("--do_lower_case", action='store_true',
122 - help="Set this flag if you are using an uncased model.")
123 - parser.add_argument("--no_cuda", action='store_true',
124 - help="Avoid using CUDA when available")
125 -
126 - args = parser.parse_args()
127 -
128 - args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
129 -
130 - main(args)
\ No newline at end of file
1 -diff --git a/src/train/model.py b/src/train/model.py
2 -index 20e56b3..cab82e5 100644
3 ---- a/src/train/model.py
4 -+++ b/src/train/model.py
5 -@@ -3,9 +3,7 @@
6 -
7 - import torch
8 - import torch.nn as nn
9 --import torch
10 --from torch.autograd import Variable
11 --import copy
12 -+
13 - class Seq2Seq(nn.Module):
14 - """
15 - Build Seqence-to-Sequence.
16 -diff --git a/src/train/run.py b/src/train/run.py
17 -index 5961ad1..be98fec 100644
18 ---- a/src/train/run.py
19 -+++ b/src/train/run.py
20 -@@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss.
21 - from __future__ import absolute_import
22 - import os
23 - import sys
24 --import bleu
25 - import pickle
26 - import torch
27 - import json
28 -@@ -35,11 +34,14 @@ from itertools import cycle
29 - import torch.nn as nn
30 - from model import Seq2Seq
31 - from tqdm import tqdm, trange
32 --from customized_roberta import RobertaModel
33 - from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
34 - from torch.utils.data.distributed import DistributedSampler
35 - from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
36 - RobertaConfig, RobertaTokenizer)
37 -+
38 -+import train.bleu as bleu
39 -+from train.customized_roberta import RobertaModel
40 -+
41 - MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
42 -
43 - logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
1 +diff --git a/codebert/code.py b/codebert/code.py
2 +new file mode 100644
3 +index 0000000..b4bc953
4 +--- /dev/null
5 ++++ b/codebert/code.py
6 +@@ -0,0 +1,21 @@
7 ++def dailymotion_download(url, output_dir='.', merge=True, info_only=False, **kwargs):
8 ++
9 ++ html = get_content(rebuilt_url(url))
10 ++ info = json.loads(match1(html, r'qualities":({.+?}),"'))
11 ++ title = match1(html, r'"video_title"\s*:\s*"([^"]+)"') or \
12 ++ match1(html, r'"title"\s*:\s*"([^"]+)"')
13 ++ title = unicodize(title)
14 ++
15 ++ for quality in ['1080','720','480','380','240','144','auto']:
16 ++ try:
17 ++ real_url = info[quality][1]["url"]
18 ++ if real_url:
19 ++ break
20 ++ except KeyError:
21 ++ pass
22 ++
23 ++ mime, ext, size = url_info(real_url)
24 ++
25 ++ print_info(site_info, title, mime, size)
26 ++ if not info_only:
27 ++ download_urls([real_url], title, ext, size, output_dir=output_dir, merge=merge)
\ No newline at end of file
1 +diff --git a/src/train/model.py b/src/train/model.py
2 +index 20e56b3..cab82e5 100644
3 +--- a/src/train/model.py
4 ++++ b/src/train/model.py
5 +@@ -3,9 +3,7 @@
6 +
7 + import torch
8 + import torch.nn as nn
9 +-import torch
10 +-from torch.autograd import Variable
11 +-import copy
12 ++
13 + class Seq2Seq(nn.Module):
14 + """
15 + Build Seqence-to-Sequence.
\ No newline at end of file
@@ -21,8 +21,6 @@ using a masked language modeling (MLM) loss.
 
 from __future__ import absolute_import
 import os
-import sys
-import pickle
 import torch
 import json
 import random
@@ -30,17 +28,17 @@ import logging
 import argparse
 import numpy as np
 from io import open
-from itertools import cycle
+from tqdm import tqdm
 import torch.nn as nn
-from model import Seq2Seq
-from tqdm import tqdm, trange
-from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
+from itertools import cycle
+
+from torch.utils.data import (DataLoader, SequentialSampler, RandomSampler, TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
-from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          RobertaConfig, RobertaTokenizer)
+from transformers import (AdamW, get_linear_schedule_with_warmup, RobertaConfig, RobertaTokenizer)
 
-import train.bleu as bleu
-from train.customized_roberta import RobertaModel
+import bleu
+from autocommit.model import Seq2Seq, RobertaModel
+from autocommit.utils import (convert_examples_to_features, Example)
 
 MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
 
@@ -49,19 +47,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
                     level = logging.INFO)
 logger = logging.getLogger(__name__)
 
-class Example(object):
-    """A single training/test example."""
-    def __init__(self,
-                 idx,
-                 added,
-                 deleted,
-                 target,
-                 ):
-        self.idx = idx
-        self.added = added
-        self.deleted = deleted
-        self.target = target
-
 def read_examples(filename):
     """Read examples from filename."""
     examples=[]
@@ -82,85 +67,6 @@ def read_examples(filename):
     return examples
 
 
-class InputFeatures(object):
-    """A single training/test features for a example."""
-    def __init__(self,
-                 example_id,
-                 source_ids,
-                 target_ids,
-                 source_mask,
-                 target_mask,
-                 patch_ids,
-
-                 ):
-        self.example_id = example_id
-        self.source_ids = source_ids
-        self.target_ids = target_ids
-        self.source_mask = source_mask
-        self.target_mask = target_mask
-        self.patch_ids = patch_ids
-
-
-
-def convert_examples_to_features(examples, tokenizer, args,stage=None):
-    features = []
-    for example_index, example in enumerate(examples):
-        #source
-        added_tokens=[tokenizer.cls_token]+example.added+[tokenizer.sep_token]
-        deleted_tokens=example.deleted+[tokenizer.sep_token]
-        source_tokens = added_tokens + deleted_tokens
-        patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
-        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
-        source_mask = [1] * (len(source_tokens))
-        padding_length = args.max_source_length - len(source_ids)
-        source_ids+=[tokenizer.pad_token_id]*padding_length
-        patch_ids+=[0]*padding_length
-        source_mask+=[0]*padding_length
-
-        assert len(source_ids) == args.max_source_length
-        assert len(source_mask) == args.max_source_length
-        assert len(patch_ids) == args.max_source_length
-
-        #target
-        if stage=="test":
-            target_tokens = tokenizer.tokenize("None")
-        else:
-            target_tokens = (example.target)[:args.max_target_length-2]
-        target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
-        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
-        target_mask = [1] *len(target_ids)
-        padding_length = args.max_target_length - len(target_ids)
-        target_ids+=[tokenizer.pad_token_id]*padding_length
-        target_mask+=[0]*padding_length
-
-        if example_index < 5:
-            if stage=='train':
-                logger.info("*** Example ***")
-                logger.info("idx: {}".format(example.idx))
-
-                logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
-                logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
-                logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
-                logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
-
-                logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
-                logger.info("target_ids: {}".format(' '.join(map(str, target_ids))))
-                logger.info("target_mask: {}".format(' '.join(map(str, target_mask))))
-
-        features.append(
-            InputFeatures(
-                example_index,
-                source_ids,
-                target_ids,
-                source_mask,
-                target_mask,
-                patch_ids,
-            )
-        )
-    return features
-
-
-
 def set_seed(args):
     """set random seed."""
     random.seed(args.seed)
@@ -471,7 +377,7 @@ def main():
 f1.write(str(gold.idx)+'\t'+' '.join(gold.target)+'\n')
 
 (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold"))
-dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
+dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
 logger.info(" %s = %s "%("bleu-4",str(dev_bleu)))
 logger.info(" "+"*"*20)
 if dev_bleu>best_bleu:
@@ -528,7 +434,7 @@ def main():
 f1.write(str(gold.idx)+'\t'+' '.join(gold.target)+'\n')
 
 (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "test_{}.gold".format(idx)))
-dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
+dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
 logger.info(" %s = %s "%("bleu-4",str(dev_bleu)))
 logger.info(" "+"*"*20)
 
......