graykode

(add) patch ids embedding roberta model

@@ -51,8 +51,8 @@ class Seq2Seq(nn.Module):
             self._tie_or_clone_weights(self.lm_head,
                                        self.encoder.embeddings.word_embeddings)
 
-    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,args=None):
-        outputs = self.encoder(source_ids, attention_mask=source_mask)
+    def forward(self, source_ids=None,source_mask=None,target_ids=None,target_mask=None,patch_ids=None,args=None):
+        outputs = self.encoder(source_ids, attention_mask=source_mask, patch_ids=patch_ids)
         encoder_output = outputs[0].permute([1,0,2]).contiguous()
         if target_ids is not None:
             attn_mask=-1e4 *(1-self.bias[:target_ids.shape[1],:target_ids.shape[1]])
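The hunk threads a new `patch_ids` tensor through `Seq2Seq.forward` into the encoder, so the underlying RoBERTa model can fold a per-token patch embedding into its input representation, alongside the usual word and position embeddings. Below is a minimal sketch of what such an embedding layer might look like; the class name `PatchEmbeddings`, the assumption of three patch types (e.g. 0 = context, 1 = added line, 2 = deleted line), and all hyperparameter defaults are illustrative, not taken from this commit:

```python
import torch
import torch.nn as nn

class PatchEmbeddings(nn.Module):
    """RoBERTa-style input embeddings extended with a learned patch-type lookup.

    Hypothetical sketch: the number of patch types and their meaning
    (context / added / deleted) are assumptions for illustration.
    """
    def __init__(self, vocab_size=50265, hidden_size=768, max_positions=512,
                 num_patch_types=3, pad_token_id=1, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size,
                                            padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        # One extra lookup table, analogous to BERT's token_type (segment) ids.
        self.patch_embeddings = nn.Embedding(num_patch_types, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, patch_ids=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(positions))
        if patch_ids is not None:
            # Add a patch-type vector to every token so the encoder can
            # distinguish added, deleted, and context tokens of the diff.
            embeddings = embeddings + self.patch_embeddings(patch_ids)
        return self.dropout(self.layer_norm(embeddings))

# Usage: patch_ids has the same (batch, seq_len) shape as input_ids.
input_ids = torch.randint(0, 50265, (2, 16))
patch_ids = torch.randint(0, 3, (2, 16))
out = PatchEmbeddings()(input_ids, patch_ids)  # -> (2, 16, 768)
```

Here `patch_ids` plays the same role segment ids play in BERT: a second learned embedding summed with the word and position embeddings before layer normalization, which is why the encoder call in the diff simply gains one extra keyword argument.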