graykode

(add) patch ids embedding roberta model

@@ -51,8 +51,8 @@ class Seq2Seq(nn.Module):
         self._tie_or_clone_weights(self.lm_head,
                                    self.encoder.embeddings.word_embeddings)
 
-    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None):
-        outputs = self.encoder(source_ids, attention_mask=source_mask)
+    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,patch_ids=None,args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask, patch_ids=patch_ids)
         encoder_output = outputs[0].permute([1,0,2]).contiguous()
         if target_ids is not None:
             attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]])
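The hunk threads a patch_ids tensor through Seq2Seq.forward into the RoBERTa encoder; the encoder-side change that consumes it is not shown here. Below is a minimal sketch of how such a patch-id embedding might work, assuming a learned per-token patch-type embedding (e.g. 0 = context line, 1 = added line, 2 = removed line) summed with the word embeddings, in the style of RoBERTa's token-type embeddings. PatchEmbeddings, num_patch_types, and the sizes are hypothetical names for illustration, not code from this commit.

import torch
import torch.nn as nn

class PatchEmbeddings(nn.Module):
    """Hypothetical sketch: token embeddings augmented with a learned
    patch-type embedding, added elementwise the same way RoBERTa sums
    word, position, and token-type embeddings."""

    def __init__(self, vocab_size=50265, hidden_size=768, num_patch_types=3):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.patch_embeddings = nn.Embedding(num_patch_types, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids, patch_ids=None):
        embeddings = self.word_embeddings(input_ids)
        if patch_ids is not None:
            # Inject the per-token patch-type signal (context/added/removed).
            embeddings = embeddings + self.patch_embeddings(patch_ids)
        return self.layer_norm(embeddings)

# Usage: one patch-type id per token, same shape as input_ids.
emb = PatchEmbeddings()
input_ids = torch.randint(0, 50265, (2, 16))
patch_ids = torch.randint(0, 3, (2, 16))
hidden = emb(input_ids, patch_ids=patch_ids)  # shape: (2, 16, 768)

Passing patch_ids as a keyword argument, as the new forward signature does, keeps the model backward compatible: callers that omit it fall back to plain token embeddings.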