Showing 3 changed files with 293 additions and 28 deletions
code2nl/customized_roberta.py (new file, mode 100644)
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch RoBERTa model. """
+
+import torch
+import torch.nn as nn
+
+from transformers.modeling_roberta import (
+    create_position_ids_from_input_ids,
+    RobertaPreTrainedModel,
+    RobertaEncoder,
+    RobertaPooler,
+    BaseModelOutputWithPooling
+)
+
+class RobertaEmbeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    # Copied from transformers.modeling_bert.BertEmbeddings.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+        self.patch_type_embeddings = nn.Embedding(3, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+        # End copy
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+    def forward(self, input_ids=None, patch_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        # Copied from transformers.modeling_bert.BertEmbeddings.forward
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        if patch_ids is not None:
+            patch_type_embeddings = self.patch_type_embeddings(patch_ids)
+            embeddings += patch_type_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """We are provided embeddings directly. We cannot infer which are padded so just generate
+        sequential position ids.
+
+        :param torch.Tensor inputs_embeds:
+        :return torch.Tensor:
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+class RobertaModel(RobertaPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well
+    as a decoder, in which case a layer of cross-attention is added between
+    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
+    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as an decoder the model needs to be initialized with the
+    :obj:`is_decoder` argument of the configuration set to :obj:`True`.
+    To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an
+    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
+
+    .. _`Attention is all you need`:
+        https://arxiv.org/abs/1706.03762
+
+    """
+
+    authorized_missing_keys = [r"position_ids"]
+
+    # Copied from transformers.modeling_bert.BertModel.__init__ with Bert->Roberta
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = RobertaEmbeddings(config)
+        self.encoder = RobertaEncoder(config)
+
+        self.pooler = RobertaPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """Prunes heads of the model.
+        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        See base class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    # Copied from transformers.modeling_bert.BertModel.forward
+    def forward(
+        self,
+        input_ids=None,
+        patch_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+            if the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
+            is used in the cross-attention if the model is configured as a decoder.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(input_shape, device=device)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids, patch_ids=patch_ids,
+            position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
\ No newline at end of file
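The file above copies the stock `transformers` RoBERTa embedding and model classes with one functional addition: a 3-entry `patch_type_embeddings` table whose learned vector (padding, added, or deleted) is summed into each token's embedding whenever `patch_ids` is passed to `forward`. A minimal standalone sketch of that idea, using toy sizes rather than the real config:

```python
import torch
import torch.nn as nn

hidden_size = 8                                   # toy value; roberta-base uses 768
word_emb = nn.Embedding(50, hidden_size, padding_idx=0)
patch_emb = nn.Embedding(3, hidden_size)          # 0 = padding, 1 = added tokens, 2 = deleted tokens

token_ids = torch.tensor([[11, 12, 13, 14, 0, 0]])   # one sequence with two padding positions
patch_ids = torch.tensor([[1, 1, 2, 2, 0, 0]])       # first two tokens from added lines, next two from deleted lines

# In RobertaEmbeddings.forward above, this sum also includes position and
# token-type embeddings and is followed by LayerNorm and dropout.
embeddings = word_emb(token_ids) + patch_emb(patch_ids)
print(embeddings.shape)   # torch.Size([1, 6, 8])
```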
@@ -51,8 +51,8 @@ class Seq2Seq(nn.Module):
         self._tie_or_clone_weights(self.lm_head,
                                    self.encoder.embeddings.word_embeddings)
 
-    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None):
-        outputs = self.encoder(source_ids, attention_mask=source_mask)
+    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,patch_ids=None,args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask, patch_ids=patch_ids)
         encoder_output = outputs[0].permute([1,0,2]).contiguous()
         if target_ids is not None:
             attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]])
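The `Seq2Seq` wrapper (imported as `from model import Seq2Seq` in the script below) now accepts per-token `patch_ids` and simply forwards them to the customized encoder. A hedged sketch of how the new signature is called; the tensor contents are illustrative and `model` is assumed to be an already constructed `Seq2Seq` instance:

```python
import torch

batch_size, max_source_length = 2, 8                                # illustrative sizes only
source_ids = torch.randint(5, 100, (batch_size, max_source_length))
source_mask = torch.ones(batch_size, max_source_length, dtype=torch.long)
patch_ids = torch.randint(0, 3, (batch_size, max_source_length))    # 0/1/2 per position

# Training-style call (targets given) returns the loss terms:
#   loss, _, _ = model(source_ids=source_ids, source_mask=source_mask,
#                      target_ids=target_ids, target_mask=target_mask, patch_ids=patch_ids)
# Inference-style call (no targets) returns beam-search predictions:
#   preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
```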
@@ -35,10 +35,11 @@ from itertools import cycle
 import torch.nn as nn
 from model import Seq2Seq
 from tqdm import tqdm, trange
+from customized_roberta import RobertaModel
 from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          RobertaConfig, RobertaModel, RobertaTokenizer)
+                          RobertaConfig, RobertaTokenizer)
 MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -50,11 +51,13 @@ class Example(object):
     """A single training/test example."""
     def __init__(self,
                  idx,
-                 source,
+                 added,
+                 deleted,
                  target,
                  ):
         self.idx = idx
-        self.source = source
+        self.added = added
+        self.deleted = deleted
         self.target = target
 
 def read_examples(filename):
@@ -66,15 +69,12 @@ def read_examples(filename):
             js=json.loads(line)
             if 'idx' not in js:
                 js['idx']=idx
-            code=' '.join(js['code_tokens']).replace('\n',' ')
-            code=' '.join(code.strip().split())
-            nl=' '.join(js['docstring_tokens']).replace('\n','')
-            nl=' '.join(nl.strip().split())
             examples.append(
                 Example(
                         idx = idx,
-                        source=code,
-                        target = nl,
+                        added=js['added'],
+                        deleted=js['deleted'],
+                        target=js['msg'],
                 )
             )
     return examples
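`read_examples` therefore expects each JSONL line to carry already tokenized `added` and `deleted` token lists plus the tokenized commit message under `msg` (field names as in the hunk above; the example values and token splitting below are made-up assumptions):

```python
import json

# One hypothetical input line; in the real data the lists hold tokenizer-level (subword) tokens.
line = '{"idx": 0, "added": ["return", "x", "+", "1"], "deleted": ["return", "x"], "msg": ["fix", "off-by-one"]}'
js = json.loads(line)
print(js["added"], js["deleted"], js["msg"])
```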
@@ -88,6 +88,7 @@ class InputFeatures(object):
                  target_ids,
                  source_mask,
                  target_mask,
+                 patch_ids,
 
     ):
         self.example_id = example_id
@@ -95,6 +96,7 @@ class InputFeatures(object):
         self.target_ids = target_ids
         self.source_mask = source_mask
         self.target_mask = target_mask
+        self.patch_ids = patch_ids
 
 
 
@@ -102,19 +104,26 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
     features = []
     for example_index, example in enumerate(examples):
         #source
-        source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2]
-        source_tokens =[tokenizer.cls_token]+source_tokens+[tokenizer.sep_token]
+        added_tokens=[tokenizer.cls_token]+example.added+[tokenizer.sep_token]
+        deleted_tokens=example.deleted+[tokenizer.sep_token]
+        source_tokens = added_tokens + deleted_tokens
+        patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
         source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
         source_mask = [1] * (len(source_tokens))
         padding_length = args.max_source_length - len(source_ids)
         source_ids+=[tokenizer.pad_token_id]*padding_length
+        patch_ids+=[0]*padding_length
         source_mask+=[0]*padding_length
 
+        assert len(source_ids) == args.max_source_length
+        assert len(source_mask) == args.max_source_length
+        assert len(patch_ids) == args.max_source_length
+
         #target
         if stage=="test":
             target_tokens = tokenizer.tokenize("None")
         else:
-            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
+            target_tokens = (example.target)[:args.max_target_length-2]
         target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
         target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
         target_mask = [1] *len(target_ids)
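Instead of a single truncated source string, the features are now laid out as `<s> added-tokens </s> deleted-tokens </s>`, and every position gets a patch id: 1 across the added segment (including its special tokens), 2 across the deleted segment, 0 for padding. A small worked example of that bookkeeping (special tokens follow the RoBERTa tokenizer; the sizes and token lists are toy assumptions):

```python
cls_token, sep_token = "<s>", "</s>"      # RoBERTa special tokens
max_source_length = 12                    # stand-in for args.max_source_length

added = ["return", "x", "+", "1"]
deleted = ["return", "x"]

added_tokens = [cls_token] + added + [sep_token]
deleted_tokens = deleted + [sep_token]
source_tokens = added_tokens + deleted_tokens
patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)

padding_length = max_source_length - len(source_tokens)
patch_ids += [0] * padding_length          # padding positions carry patch id 0

print(source_tokens)
# ['<s>', 'return', 'x', '+', '1', '</s>', 'return', 'x', '</s>']
print(patch_ids)
# [1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0]
```

Note that, unlike the removed code, the concatenated source no longer appears to be truncated, so the added asserts would fail for examples longer than `args.max_source_length` unless the data is pre-truncated upstream.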
@@ -129,6 +138,7 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
 
                 logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
                 logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
+                logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
                 logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
 
                 logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
@@ -142,6 +152,7 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
                 target_ids,
                 source_mask,
                 target_mask,
+                patch_ids,
             )
         )
     return features
@@ -255,7 +266,7 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,do_lower_case=args.do_lower_case)
 
     #budild model
-    encoder = model_class.from_pretrained(args.model_name_or_path,config=config)
+    encoder = model_class(config=config)
     decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
     decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
     model=Seq2Seq(encoder=encoder,decoder=decoder,config=config,
@@ -263,7 +274,7 @@ def main():
                   sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id)
     if args.load_model_path is not None:
         logger.info("reload model from {}".format(args.load_model_path))
-        model.load_state_dict(torch.load(args.load_model_path))
+        model.load_state_dict(torch.load(args.load_model_path), strict=False)
 
     model.to(device)
     if args.local_rank != -1:
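Two related changes in model construction: the encoder is now instantiated from the bare config rather than `from_pretrained` (the customized `RobertaModel` carries an extra embedding table that a stock checkpoint does not provide), and `--load_model_path` checkpoints are loaded with `strict=False`, so parameter names that are missing from or absent in the checkpoint no longer raise. A toy sketch of what non-strict loading does (the module names are illustrative, not from this repository):

```python
import torch.nn as nn

# Pretend the saved checkpoint predates a newly added submodule.
old = nn.Linear(4, 4)
new = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))   # second layer absent from the checkpoint

checkpoint = {"0." + k: v for k, v in old.state_dict().items()}  # stands in for torch.load(path)
result = new.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)      # ['1.weight', '1.bias'] — these keep their fresh initialization
print(result.unexpected_keys)   # []
```

The flip side is that genuinely mismatched names are skipped silently, so logging `missing_keys` and `unexpected_keys` after loading is a cheap sanity check.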
@@ -289,7 +300,8 @@ def main():
         all_source_mask = torch.tensor([f.source_mask for f in train_features], dtype=torch.long)
         all_target_ids = torch.tensor([f.target_ids for f in train_features], dtype=torch.long)
         all_target_mask = torch.tensor([f.target_mask for f in train_features], dtype=torch.long)
-        train_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask)
+        all_patch_ids = torch.tensor([f.patch_ids for f in train_features], dtype=torch.long)
+        train_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask,all_patch_ids)
 
         if args.local_rank == -1:
             train_sampler = RandomSampler(train_data)
@@ -327,8 +339,9 @@ def main():
         for step in bar:
             batch = next(train_dataloader)
             batch = tuple(t.to(device) for t in batch)
-            source_ids,source_mask,target_ids,target_mask = batch
-            loss,_,_ = model(source_ids=source_ids,source_mask=source_mask,target_ids=target_ids,target_mask=target_mask)
+            source_ids,source_mask,target_ids,target_mask,patch_ids = batch
+            loss,_,_ = model(source_ids=source_ids,source_mask=source_mask,
+                             target_ids=target_ids,target_mask=target_mask,patch_ids=patch_ids)
 
             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu.
@@ -363,7 +376,8 @@ def main():
                 all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
                 all_target_ids = torch.tensor([f.target_ids for f in eval_features], dtype=torch.long)
                 all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
-                eval_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask)
+                all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+                eval_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask,all_patch_ids)
                 dev_dataset['dev_loss']=eval_examples,eval_data
             eval_sampler = SequentialSampler(eval_data)
             eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
@@ -377,11 +391,11 @@ def main():
             eval_loss,tokens_num = 0,0
             for batch in eval_dataloader:
                 batch = tuple(t.to(device) for t in batch)
-                source_ids,source_mask,target_ids,target_mask = batch
+                source_ids,source_mask,target_ids,target_mask,patch_ids = batch
 
                 with torch.no_grad():
                     _,loss,num = model(source_ids=source_ids,source_mask=source_mask,
-                                       target_ids=target_ids,target_mask=target_mask)
+                                       target_ids=target_ids,target_mask=target_mask,patch_ids=patch_ids)
                 eval_loss += loss.sum().item()
                 tokens_num += num.sum().item()
             #Pring loss of dev dataset
@@ -423,7 +437,8 @@ def main():
                 eval_features = convert_examples_to_features(eval_examples, tokenizer, args,stage='test')
                 all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
                 all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
-                eval_data = TensorDataset(all_source_ids,all_source_mask)
+                all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+                eval_data = TensorDataset(all_source_ids,all_source_mask,all_patch_ids)
                 dev_dataset['dev_bleu']=eval_examples,eval_data
 
 
@@ -435,9 +450,9 @@ def main():
             p=[]
             for batch in eval_dataloader:
                 batch = tuple(t.to(device) for t in batch)
-                source_ids,source_mask= batch
+                source_ids,source_mask,patch_ids= batch
                 with torch.no_grad():
-                    preds = model(source_ids=source_ids,source_mask=source_mask)
+                    preds = model(source_ids=source_ids,source_mask=source_mask,patch_ids=patch_ids)
                     for pred in preds:
                         t=pred[0].cpu().numpy()
                         t=list(t)
@@ -481,7 +496,8 @@ def main():
             eval_features = convert_examples_to_features(eval_examples, tokenizer, args,stage='test')
             all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
             all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
-            eval_data = TensorDataset(all_source_ids,all_source_mask)
+            all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+            eval_data = TensorDataset(all_source_ids,all_source_mask,all_patch_ids)
 
             # Calculate bleu
             eval_sampler = SequentialSampler(eval_data)
@@ -491,9 +507,9 @@ def main():
             p=[]
             for batch in tqdm(eval_dataloader,total=len(eval_dataloader)):
                 batch = tuple(t.to(device) for t in batch)
-                source_ids,source_mask= batch
+                source_ids,source_mask,patch_ids= batch
                 with torch.no_grad():
-                    preds = model(source_ids=source_ids,source_mask=source_mask)
+                    preds = model(source_ids=source_ids,source_mask=source_mask,patch_ids=patch_ids)
                     for pred in preds:
                         t=pred[0].cpu().numpy()
                         t=list(t)
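In both prediction loops, each `pred[0]` is the token-id sequence of the top beam hypothesis; the lines cut off by this view typically truncate the ids at the first padding token and decode them with the tokenizer. A hedged sketch of that post-processing (the exact cut-off token used by the script is not visible in this diff):

```python
def ids_to_text(t, tokenizer):
    # `t` is a Python list of predicted token ids, e.g. list(pred[0].cpu().numpy()).
    if 0 in t:                      # assumed: stop at the first padding id
        t = t[:t.index(0)]
    return tokenizer.decode(t, clean_up_tokenization_spaces=False)
```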