Showing
5 changed files
with
4 additions
and
2 deletions
... | @@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer): | ... | @@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer): |
188 | t0 = time.time() | 188 | t0 = time.time() |
189 | generated_ids = self.model.generate( | 189 | generated_ids = self.model.generate( |
190 | batch[0].long(), | 190 | batch[0].long(), |
191 | + patch_ids=batch[2].long(), | ||
191 | attention_mask=batch[1].long(), | 192 | attention_mask=batch[1].long(), |
192 | - # patch_ids=batch[2].long(), | ||
193 | use_cache=True, | 193 | use_cache=True, |
194 | decoder_start_token_id=self.decoder_start_token_id, | 194 | decoder_start_token_id=self.decoder_start_token_id, |
195 | ) | 195 | ) | ... | ... |
generation_utils.py
0 → 100644
This diff is collapsed. Click to expand it.
... | @@ -21,6 +21,8 @@ from transformers import ( | ... | @@ -21,6 +21,8 @@ from transformers import ( |
21 | PretrainedConfig, | 21 | PretrainedConfig, |
22 | PreTrainedTokenizer, | 22 | PreTrainedTokenizer, |
23 | ) | 23 | ) |
24 | +from modeling_bart import BartForConditionalGeneration | ||
25 | + | ||
24 | from transformers.optimization import ( | 26 | from transformers.optimization import ( |
25 | Adafactor, | 27 | Adafactor, |
26 | get_cosine_schedule_with_warmup, | 28 | get_cosine_schedule_with_warmup, |
... | @@ -40,7 +42,7 @@ MODEL_MODES = { | ... | @@ -40,7 +42,7 @@ MODEL_MODES = { |
40 | "pretraining": AutoModelForPreTraining, | 42 | "pretraining": AutoModelForPreTraining, |
41 | "token-classification": AutoModelForTokenClassification, | 43 | "token-classification": AutoModelForTokenClassification, |
42 | "language-modeling": AutoModelWithLMHead, | 44 | "language-modeling": AutoModelWithLMHead, |
43 | - "summarization": AutoModelForSeq2SeqLM, | 45 | + "summarization": BartForConditionalGeneration, |
44 | "translation": AutoModelForSeq2SeqLM, | 46 | "translation": AutoModelForSeq2SeqLM, |
45 | } | 47 | } |
46 | 48 | ... | ... |
modeling_bart.py
0 → 100644
This diff is collapsed. Click to expand it.
modeling_utils.py
0 → 100644
This diff is collapsed. Click to expand it.
-
Please register or login to post a comment