graykode

(add) customized BART model to modify patch_ids

......@@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer):
t0 = time.time()
generated_ids = self.model.generate(
batch[0].long(),
patch_ids=batch[2].long(),
attention_mask=batch[1].long(),
# patch_ids=batch[2].long(),
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
......
This diff is collapsed. Click to expand it.
......@@ -21,6 +21,8 @@ from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
)
from modeling_bart import BartForConditionalGeneration
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
......@@ -40,7 +42,7 @@ MODEL_MODES = {
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"summarization": BartForConditionalGeneration,
"translation": AutoModelForSeq2SeqLM,
}
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.