graykode

(add) patch ids embedding roberta model
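
In brief: the customized embedding layer below adds a patch-type embedding on top of RoBERTa's word/position/token-type embeddings, so the encoder can tell tokens coming from added lines apart from tokens coming from deleted lines. A minimal sketch (not part of the commit) of how the new patch_ids are built and consumed — "roberta-base" is only a placeholder checkpoint, and the names mirror the code in this diff:

# Sketch only: build patch_ids for one diff example and feed them to the
# customized RobertaModel defined in this commit.
import torch
from transformers import RobertaConfig, RobertaTokenizer
from customized_roberta import RobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")   # placeholder checkpoint
config = RobertaConfig.from_pretrained("roberta-base")

added   = tokenizer.tokenize("return a + b")   # tokens from '+' lines of a patch
deleted = tokenizer.tokenize("return a - b")   # tokens from '-' lines of a patch

# <s> added </s> deleted </s>, mirroring convert_examples_to_features below
added_tokens   = [tokenizer.cls_token] + added + [tokenizer.sep_token]
deleted_tokens = deleted + [tokenizer.sep_token]
source_ids = tokenizer.convert_tokens_to_ids(added_tokens + deleted_tokens)

# patch-type ids: 1 = added tokens, 2 = deleted tokens, 0 = padding
patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)

encoder = RobertaModel(config)   # randomly initialized, as in the updated run script
outputs = encoder(input_ids=torch.tensor([source_ids]),
                  patch_ids=torch.tensor([patch_ids]),
                  attention_mask=torch.ones(1, len(source_ids), dtype=torch.long))
print(outputs[0].shape)          # (1, sequence_length, hidden_size)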

1 +# coding=utf-8
2 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 +#
5 +# Licensed under the Apache License, Version 2.0 (the "License");
6 +# you may not use this file except in compliance with the License.
7 +# You may obtain a copy of the License at
8 +#
9 +# http://www.apache.org/licenses/LICENSE-2.0
10 +#
11 +# Unless required by applicable law or agreed to in writing, software
12 +# distributed under the License is distributed on an "AS IS" BASIS,
13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 +# See the License for the specific language governing permissions and
15 +# limitations under the License.
16 +"""PyTorch RoBERTa model. """
17 +
18 +import torch
19 +import torch.nn as nn
20 +
21 +from transformers.modeling_roberta import (
22 + create_position_ids_from_input_ids,
23 + RobertaPreTrainedModel,
24 + RobertaEncoder,
25 + RobertaPooler,
26 + BaseModelOutputWithPooling
27 +)
28 +
29 +class RobertaEmbeddings(nn.Module):
30 + """
31 + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing, plus a patch-type embedding (0 = padding, 1 = added tokens, 2 = deleted tokens) that is added in when patch_ids are provided.
32 + """
33 +
34 + # Copied from transformers.modeling_bert.BertEmbeddings.__init__
35 + def __init__(self, config):
36 + super().__init__()
37 + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
38 + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
39 + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
40 + self.patch_type_embeddings = nn.Embedding(3, config.hidden_size)
41 +
42 + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
43 + # any TensorFlow checkpoint file
44 + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
45 + self.dropout = nn.Dropout(config.hidden_dropout_prob)
46 +
47 + # position_ids (1, len position emb) is contiguous in memory and exported when serialized
48 + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
49 +
50 + # End copy
51 + self.padding_idx = config.pad_token_id
52 + self.position_embeddings = nn.Embedding(
53 + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
54 + )
55 +
56 + def forward(self, input_ids=None, patch_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
57 + if position_ids is None:
58 + if input_ids is not None:
59 + # Create the position ids from the input token ids. Any padded tokens remain padded.
60 + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
61 + else:
62 + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
63 +
64 + # Copied from transformers.modeling_bert.BertEmbeddings.forward
65 + if input_ids is not None:
66 + input_shape = input_ids.size()
67 + else:
68 + input_shape = inputs_embeds.size()[:-1]
69 +
70 + seq_length = input_shape[1]
71 +
72 + if position_ids is None:
73 + position_ids = self.position_ids[:, :seq_length]
74 +
75 + if token_type_ids is None:
76 + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
77 +
78 + if inputs_embeds is None:
79 + inputs_embeds = self.word_embeddings(input_ids)
80 + position_embeddings = self.position_embeddings(position_ids)
81 + token_type_embeddings = self.token_type_embeddings(token_type_ids)
82 +
83 + embeddings = inputs_embeds + position_embeddings + token_type_embeddings
84 + if patch_ids is not None:
85 + patch_type_embeddings = self.patch_type_embeddings(patch_ids)
86 + embeddings += patch_type_embeddings
87 +
88 + embeddings = self.LayerNorm(embeddings)
89 + embeddings = self.dropout(embeddings)
90 + return embeddings
91 +
92 + def create_position_ids_from_inputs_embeds(self, inputs_embeds):
93 + """We are provided embeddings directly. We cannot infer which are padded so just generate
94 + sequential position ids.
95 +
96 + :param torch.Tensor inputs_embeds:
97 + :return torch.Tensor:
98 + """
99 + input_shape = inputs_embeds.size()[:-1]
100 + sequence_length = input_shape[1]
101 +
102 + position_ids = torch.arange(
103 + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
104 + )
105 + return position_ids.unsqueeze(0).expand(input_shape)
106 +
107 +class RobertaModel(RobertaPreTrainedModel):
108 + """
109 +
110 + The model can behave as an encoder (with only self-attention) as well
111 + as a decoder, in which case a layer of cross-attention is added between
112 + the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
113 + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
114 +
115 + To behave as a decoder the model needs to be initialized with the
116 + :obj:`is_decoder` argument of the configuration set to :obj:`True`.
117 + To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
118 + argument and :obj:`add_cross_attention` set to :obj:`True`; an
119 + :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
120 +
121 + .. _`Attention is all you need`:
122 + https://arxiv.org/abs/1706.03762
123 +
124 + """
125 +
126 + authorized_missing_keys = [r"position_ids"]
127 +
128 + # Copied from transformers.modeling_bert.BertModel.__init__ with Bert->Roberta
129 + def __init__(self, config, add_pooling_layer=True):
130 + super().__init__(config)
131 + self.config = config
132 +
133 + self.embeddings = RobertaEmbeddings(config)
134 + self.encoder = RobertaEncoder(config)
135 +
136 + self.pooler = RobertaPooler(config) if add_pooling_layer else None
137 +
138 + self.init_weights()
139 +
140 + def get_input_embeddings(self):
141 + return self.embeddings.word_embeddings
142 +
143 + def set_input_embeddings(self, value):
144 + self.embeddings.word_embeddings = value
145 +
146 + def _prune_heads(self, heads_to_prune):
147 + """Prunes heads of the model.
148 + heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
149 + See base class PreTrainedModel
150 + """
151 + for layer, heads in heads_to_prune.items():
152 + self.encoder.layer[layer].attention.prune_heads(heads)
153 +
154 + # Copied from transformers.modeling_bert.BertModel.forward
155 + def forward(
156 + self,
157 + input_ids=None,
158 + patch_ids=None,
159 + attention_mask=None,
160 + token_type_ids=None,
161 + position_ids=None,
162 + head_mask=None,
163 + inputs_embeds=None,
164 + encoder_hidden_states=None,
165 + encoder_attention_mask=None,
166 + output_attentions=None,
167 + output_hidden_states=None,
168 + return_dict=None,
169 + ):
170 + r"""
171 + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
172 + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
173 + if the model is configured as a decoder.
174 + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
175 + Mask to avoid performing attention on the padding token indices of the encoder input. This mask
176 + is used in the cross-attention if the model is configured as a decoder.
177 + Mask values selected in ``[0, 1]``:
178 + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
179 + """
180 + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
181 + output_hidden_states = (
182 + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
183 + )
184 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict
185 +
186 + if input_ids is not None and inputs_embeds is not None:
187 + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
188 + elif input_ids is not None:
189 + input_shape = input_ids.size()
190 + elif inputs_embeds is not None:
191 + input_shape = inputs_embeds.size()[:-1]
192 + else:
193 + raise ValueError("You have to specify either input_ids or inputs_embeds")
194 +
195 + device = input_ids.device if input_ids is not None else inputs_embeds.device
196 +
197 + if attention_mask is None:
198 + attention_mask = torch.ones(input_shape, device=device)
199 + if token_type_ids is None:
200 + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
201 +
202 + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
203 + # ourselves in which case we just need to make it broadcastable to all heads.
204 + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
205 +
206 + # If a 2D or 3D attention mask is provided for the cross-attention
207 + # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
208 + if self.config.is_decoder and encoder_hidden_states is not None:
209 + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
210 + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
211 + if encoder_attention_mask is None:
212 + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
213 + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
214 + else:
215 + encoder_extended_attention_mask = None
216 +
217 + # Prepare head mask if needed
218 + # 1.0 in head_mask indicates we keep the head
219 + # attention_probs has shape bsz x n_heads x N x N
220 + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
221 + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
222 + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
223 +
224 + embedding_output = self.embeddings(
225 + input_ids=input_ids, patch_ids=patch_ids,
226 + position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
227 + )
228 + encoder_outputs = self.encoder(
229 + embedding_output,
230 + attention_mask=extended_attention_mask,
231 + head_mask=head_mask,
232 + encoder_hidden_states=encoder_hidden_states,
233 + encoder_attention_mask=encoder_extended_attention_mask,
234 + output_attentions=output_attentions,
235 + output_hidden_states=output_hidden_states,
236 + return_dict=return_dict,
237 + )
238 + sequence_output = encoder_outputs[0]
239 + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
240 +
241 + if not return_dict:
242 + return (sequence_output, pooled_output) + encoder_outputs[1:]
243 +
244 + return BaseModelOutputWithPooling(
245 + last_hidden_state=sequence_output,
246 + pooler_output=pooled_output,
247 + hidden_states=encoder_outputs.hidden_states,
248 + attentions=encoder_outputs.attentions,
249 + )
\ No newline at end of file
@@ -51,8 +51,8 @@ class Seq2Seq(nn.Module):
         self._tie_or_clone_weights(self.lm_head,
                                    self.encoder.embeddings.word_embeddings)
 
-    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None):
-        outputs = self.encoder(source_ids, attention_mask=source_mask)
+    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,patch_ids=None,args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask, patch_ids=patch_ids)
         encoder_output = outputs[0].permute([1,0,2]).contiguous()
         if target_ids is not None:
             attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]])
@@ -35,10 +35,11 @@ from itertools import cycle
 import torch.nn as nn
 from model import Seq2Seq
 from tqdm import tqdm, trange
+from customized_roberta import RobertaModel
 from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          RobertaConfig, RobertaModel, RobertaTokenizer)
+                          RobertaConfig, RobertaTokenizer)
 MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -50,11 +51,13 @@ class Example(object):
     """A single training/test example."""
     def __init__(self,
                  idx,
-                 source,
+                 added,
+                 deleted,
                  target,
                  ):
         self.idx = idx
-        self.source = source
+        self.added = added
+        self.deleted = deleted
         self.target = target
 
 def read_examples(filename):
@@ -66,15 +69,12 @@ def read_examples(filename):
             js=json.loads(line)
             if 'idx' not in js:
                 js['idx']=idx
-            code=' '.join(js['code_tokens']).replace('\n',' ')
-            code=' '.join(code.strip().split())
-            nl=' '.join(js['docstring_tokens']).replace('\n','')
-            nl=' '.join(nl.strip().split())
             examples.append(
                 Example(
                         idx = idx,
-                        source=code,
-                        target = nl,
+                        added=js['added'],
+                        deleted=js['deleted'],
+                        target=js['msg'],
                         )
             )
     return examples
@@ -88,6 +88,7 @@ class InputFeatures(object):
                  target_ids,
                  source_mask,
                  target_mask,
+                 patch_ids,
 
     ):
         self.example_id = example_id
@@ -95,6 +96,7 @@ class InputFeatures(object):
         self.target_ids = target_ids
         self.source_mask = source_mask
         self.target_mask = target_mask
+        self.patch_ids = patch_ids
 
 
 
@@ -102,19 +104,26 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
     features = []
     for example_index, example in enumerate(examples):
         #source
-        source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length-2]
-        source_tokens =[tokenizer.cls_token]+source_tokens+[tokenizer.sep_token]
+        added_tokens=[tokenizer.cls_token]+example.added+[tokenizer.sep_token]
+        deleted_tokens=example.deleted+[tokenizer.sep_token]
+        source_tokens = added_tokens + deleted_tokens
+        patch_ids = [1] * len(added_tokens) + [2] * len(deleted_tokens)
         source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
         source_mask = [1] * (len(source_tokens))
         padding_length = args.max_source_length - len(source_ids)
         source_ids+=[tokenizer.pad_token_id]*padding_length
+        patch_ids+=[0]*padding_length
         source_mask+=[0]*padding_length
 
+        assert len(source_ids) == args.max_source_length
+        assert len(source_mask) == args.max_source_length
+        assert len(patch_ids) == args.max_source_length
+
         #target
         if stage=="test":
             target_tokens = tokenizer.tokenize("None")
         else:
-            target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length-2]
+            target_tokens = (example.target)[:args.max_target_length-2]
         target_tokens = [tokenizer.cls_token]+target_tokens+[tokenizer.sep_token]
         target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
         target_mask = [1] *len(target_ids)
@@ -129,6 +138,7 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
 
                 logger.info("source_tokens: {}".format([x.replace('\u0120','_') for x in source_tokens]))
                 logger.info("source_ids: {}".format(' '.join(map(str, source_ids))))
+                logger.info("patch_ids: {}".format(' '.join(map(str, patch_ids))))
                 logger.info("source_mask: {}".format(' '.join(map(str, source_mask))))
 
                 logger.info("target_tokens: {}".format([x.replace('\u0120','_') for x in target_tokens]))
@@ -142,6 +152,7 @@ def convert_examples_to_features(examples, tokenizer, args,stage=None):
                  target_ids,
                  source_mask,
                  target_mask,
+                 patch_ids,
             )
         )
     return features
@@ -255,7 +266,7 @@ def main():
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,do_lower_case=args.do_lower_case)
 
     #build model
-    encoder = model_class.from_pretrained(args.model_name_or_path,config=config)
+    encoder = model_class(config=config)
     decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
     decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
     model=Seq2Seq(encoder=encoder,decoder=decoder,config=config,
@@ -263,7 +274,7 @@ def main():
                   sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id)
     if args.load_model_path is not None:
         logger.info("reload model from {}".format(args.load_model_path))
-        model.load_state_dict(torch.load(args.load_model_path))
+        model.load_state_dict(torch.load(args.load_model_path), strict=False)
 
     model.to(device)
     if args.local_rank != -1:
@@ -289,7 +300,8 @@ def main():
         all_source_mask = torch.tensor([f.source_mask for f in train_features], dtype=torch.long)
         all_target_ids = torch.tensor([f.target_ids for f in train_features], dtype=torch.long)
         all_target_mask = torch.tensor([f.target_mask for f in train_features], dtype=torch.long)
-        train_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask)
+        all_patch_ids = torch.tensor([f.patch_ids for f in train_features], dtype=torch.long)
+        train_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask,all_patch_ids)
 
         if args.local_rank == -1:
             train_sampler = RandomSampler(train_data)
@@ -327,8 +339,9 @@ def main():
         for step in bar:
             batch = next(train_dataloader)
             batch = tuple(t.to(device) for t in batch)
-            source_ids,source_mask,target_ids,target_mask = batch
-            loss,_,_ = model(source_ids=source_ids,source_mask=source_mask,target_ids=target_ids,target_mask=target_mask)
+            source_ids,source_mask,target_ids,target_mask,patch_ids = batch
+            loss,_,_ = model(source_ids=source_ids,source_mask=source_mask,
+                             target_ids=target_ids,target_mask=target_mask,patch_ids=patch_ids)
 
             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu.
@@ -363,7 +376,8 @@ def main():
                 all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
                 all_target_ids = torch.tensor([f.target_ids for f in eval_features], dtype=torch.long)
                 all_target_mask = torch.tensor([f.target_mask for f in eval_features], dtype=torch.long)
-                eval_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask)
+                all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+                eval_data = TensorDataset(all_source_ids,all_source_mask,all_target_ids,all_target_mask,all_patch_ids)
                 dev_dataset['dev_loss']=eval_examples,eval_data
                 eval_sampler = SequentialSampler(eval_data)
                 eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
@@ -377,11 +391,11 @@ def main():
                 eval_loss,tokens_num = 0,0
                 for batch in eval_dataloader:
                     batch = tuple(t.to(device) for t in batch)
-                    source_ids,source_mask,target_ids,target_mask = batch
+                    source_ids,source_mask,target_ids,target_mask,patch_ids = batch
 
                     with torch.no_grad():
                         _,loss,num = model(source_ids=source_ids,source_mask=source_mask,
-                                           target_ids=target_ids,target_mask=target_mask)
+                                           target_ids=target_ids,target_mask=target_mask,patch_ids=patch_ids)
                     eval_loss += loss.sum().item()
                     tokens_num += num.sum().item()
                 #Print loss of dev dataset
@@ -423,7 +437,8 @@ def main():
                     eval_features = convert_examples_to_features(eval_examples, tokenizer, args,stage='test')
                     all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
                     all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
-                    eval_data = TensorDataset(all_source_ids,all_source_mask)
+                    all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+                    eval_data = TensorDataset(all_source_ids,all_source_mask,all_patch_ids)
                     dev_dataset['dev_bleu']=eval_examples,eval_data
 
 
@@ -435,9 +450,9 @@ def main():
                 p=[]
                 for batch in eval_dataloader:
                     batch = tuple(t.to(device) for t in batch)
-                    source_ids,source_mask= batch
+                    source_ids,source_mask,patch_ids= batch
                     with torch.no_grad():
-                        preds = model(source_ids=source_ids,source_mask=source_mask)
+                        preds = model(source_ids=source_ids,source_mask=source_mask,patch_ids=patch_ids)
                         for pred in preds:
                             t=pred[0].cpu().numpy()
                             t=list(t)
@@ -481,7 +496,8 @@ def main():
             eval_features = convert_examples_to_features(eval_examples, tokenizer, args,stage='test')
            all_source_ids = torch.tensor([f.source_ids for f in eval_features], dtype=torch.long)
             all_source_mask = torch.tensor([f.source_mask for f in eval_features], dtype=torch.long)
-            eval_data = TensorDataset(all_source_ids,all_source_mask)
+            all_patch_ids = torch.tensor([f.patch_ids for f in eval_features], dtype=torch.long)
+            eval_data = TensorDataset(all_source_ids,all_source_mask,all_patch_ids)
 
             # Calculate bleu
             eval_sampler = SequentialSampler(eval_data)
@@ -491,9 +507,9 @@ def main():
             p=[]
             for batch in tqdm(eval_dataloader,total=len(eval_dataloader)):
                 batch = tuple(t.to(device) for t in batch)
-                source_ids,source_mask= batch
+                source_ids,source_mask,patch_ids= batch
                 with torch.no_grad():
-                    preds = model(source_ids=source_ids,source_mask=source_mask)
+                    preds = model(source_ids=source_ids,source_mask=source_mask,patch_ids=patch_ids)
                     for pred in preds:
                         t=pred[0].cpu().numpy()
                         t=list(t)
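
For reference, read_examples above now assumes JSON-lines input in which added, deleted and msg are already token lists. A hypothetical record is sketched below; the token strings and tokenization scheme are illustrative only, not taken from the actual dataset.

import json

# Hypothetical shape of one line of the training .jsonl file, matching the keys
# that read_examples() now consumes ('added', 'deleted', 'msg').
record = {
    "added":   ["def", "\u0120add", "(", "a", ",", "\u0120b", ")", ":"],   # tokens from '+' lines
    "deleted": ["def", "\u0120sum", "(", "a", ",", "\u0120b", ")", ":"],   # tokens from '-' lines
    "msg":     ["Rename", "\u0120sum", "\u0120to", "\u0120add"],           # tokenized commit message
}
print(json.dumps(record))  # one JSON object per line, as read_examples() expects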