graykode

(add) customized bart model to modify patch_ids

@@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer):
         t0 = time.time()
         generated_ids = self.model.generate(
             batch[0].long(),
+            patch_ids=batch[2].long(),
             attention_mask=batch[1].long(),
-            # patch_ids=batch[2].long(),
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
         )
...
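Below is a minimal sketch of the kind of BART subclass that would accept a patch_ids keyword like the one now passed to generate(). The class name PatchEmbeddingBart, the num_patch_types argument, and the 0/1/2 patch-type convention are assumptions for illustration only, not the contents of the repository's modeling_bart.py (whose diff is not shown here); the real module also has to thread patch_ids through generate(), which is omitted below.

# Illustrative sketch only, not the actual modeling_bart.py. Assumes a
# transformers version whose BART forward() accepts inputs_embeds.
import torch.nn as nn
from transformers import BartConfig, BartForConditionalGeneration

class PatchEmbeddingBart(BartForConditionalGeneration):  # hypothetical name
    def __init__(self, config: BartConfig, num_patch_types: int = 3):
        super().__init__(config)
        # e.g. 0 = context token, 1 = added token, 2 = deleted token (assumed convention)
        self.patch_embed = nn.Embedding(num_patch_types, config.d_model)

    def forward(self, input_ids=None, patch_ids=None, **kwargs):
        if input_ids is not None and patch_ids is not None:
            # Add a patch-type embedding on top of the shared token embedding,
            # analogous to token_type_ids in BERT (embedding scaling omitted).
            embeds = self.model.shared(input_ids) + self.patch_embed(patch_ids)
            return super().forward(inputs_embeds=embeds, **kwargs)
        return super().forward(input_ids=input_ids, **kwargs)

In the generate() call above, batch[2] then presumably carries one patch-type id per token of batch[0], with the same shape as the attention mask in batch[1].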
@@ -21,6 +21,8 @@ from transformers import (
     PretrainedConfig,
     PreTrainedTokenizer,
 )
+from modeling_bart import BartForConditionalGeneration
+
 from transformers.optimization import (
     Adafactor,
     get_cosine_schedule_with_warmup,
@@ -40,7 +42,7 @@ MODEL_MODES = {
     "pretraining": AutoModelForPreTraining,
     "token-classification": AutoModelForTokenClassification,
     "language-modeling": AutoModelWithLMHead,
-    "summarization": AutoModelForSeq2SeqLM,
+    "summarization": BartForConditionalGeneration,
     "translation": AutoModelForSeq2SeqLM,
 }

...
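With the mapping change above, the "summarization" mode instantiates the customized BART class directly instead of resolving one through AutoModelForSeq2SeqLM; the other modes are untouched. A rough sketch of the effect, assuming the usual from_pretrained pattern of the upstream examples (the checkpoint name is a placeholder; the actual call site is presumably elsewhere in the same module and is not shown in this diff):

# Sketch of how the changed MODEL_MODES entry would be consumed.
model_class = MODEL_MODES["summarization"]  # now the customized BartForConditionalGeneration
model = model_class.from_pretrained("facebook/bart-base")  # placeholder checkpoint
# Pretrained BART weights load as usual; any extra parameters defined by the
# customized class (e.g. a patch-type embedding) are newly initialized and
# learned during fine-tuning.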