Showing
3 changed files
with
2 additions
and
2 deletions
code2nl/customized_roberta.py
0 → 100644
This diff is collapsed. Click to expand it.
... | @@ -51,8 +51,8 @@ class Seq2Seq(nn.Module): | ... | @@ -51,8 +51,8 @@ class Seq2Seq(nn.Module): |
51 | self._tie_or_clone_weights(self.lm_head, | 51 | self._tie_or_clone_weights(self.lm_head, |
52 | self.encoder.embeddings.word_embeddings) | 52 | self.encoder.embeddings.word_embeddings) |
53 | 53 | ||
54 | - def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None): | 54 | + def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,patch_ids=None,args=None): |
55 | - outputs = self.encoder(source_ids, attention_mask=source_mask) | 55 | + outputs = self.encoder(source_ids, attention_mask=source_mask, patch_ids=patch_ids) |
56 | encoder_output = outputs[0].permute([1,0,2]).contiguous() | 56 | encoder_output = outputs[0].permute([1,0,2]).contiguous() |
57 | if target_ids is not None: | 57 | if target_ids is not None: |
58 | attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]]) | 58 | attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]]) | ... | ... |
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment