graykode

(add) customized bart model to modify patch_ids

......@@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer):
t0 = time.time()
generated_ids = self.model.generate(
batch[0].long(),
patch_ids=batch[2].long(),
attention_mask=batch[1].long(),
# patch_ids=batch[2].long(),
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import functional as F
from transformers.file_utils import ModelOutput
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
class GenerationMixin:
"""
A class contraining all of the functions supporting generation, to be used as a mixin in
:class:`~transfomers.PreTrainedModel`.
"""
def prepare_inputs_for_generation(self, input_ids, **kwargs):
"""
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to prepare inputs in the
generate method.
"""
return {"input_ids": input_ids}
def adjust_logits_during_generation(self, logits, **kwargs):
"""
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to adjust the logits in
the generate method.
"""
return logits
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
"""
for i in range(batch_size * num_beams):
for previous_token in set(prev_output_tokens[i].tolist()):
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if lprobs[i, previous_token] < 0:
lprobs[i, previous_token] *= repetition_penalty
else:
lprobs[i, previous_token] /= repetition_penalty
def postprocess_next_token_scores(
self,
scores,
input_ids,
no_repeat_ngram_size,
bad_words_ids,
cur_len,
min_length,
max_length,
eos_token_id,
repetition_penalty,
batch_size,
num_beams,
):
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(
scores,
batch_size,
num_beams,
input_ids,
repetition_penalty,
)
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_batch_tokens = calc_banned_ngram_tokens(
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
if bad_words_ids is not None:
# Exclude EOS token (already processed)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
# Modify the scores in place by setting the banned tokens logits to `-inf`
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
indicated are the default values of those config.
Most of these parameters are explained in more detail in `this blog post
<https://huggingface.co/blog/how-to-generate>`__.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes
it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
min_length (:obj:`int`, `optional`, defaults to 10):
The minimum length of the sequence to be generated.
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use sampling ; use greedy decoding otherwise.
early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
num_beams (:obj:`int`, `optional`, defaults to 1):
Number of beams for beam search. 1 means no beam search.
temperature (:obj:`float`, `optional`, defaults tp 1.0):
The value used to module the next token probabilities.
top_k (:obj:`int`, `optional`, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (:obj:`float`, `optional`, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
higher are kept for generation.
repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
bos_token_id (:obj:`int`, `optional`):
The id of the `beginning-of-sequence` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty.
Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
order to encourage the model to produce longer sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids(:obj:`List[int]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
tokens that are not masked, and 0 for masked tokens.
If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
Examples::
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
outputs = model.generate(max_length=40) # do greedy decoding
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
input_context = 'The dog'
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
for i in range(3): # 3 output sequences were generated
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
"""
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
)
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
if input_ids is not None:
batch_size = input_ids.shape[0] # overriden by the input batch_size
else:
batch_size = 1
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
), "`no_repeat_ngram_size` should be a positive integer."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a strictly positive integer."
assert (
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = torch.full(
(batch_size, 1),
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# not allow to duplicate outputs when greedy decoding
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created
if pad_token_id is None and eos_token_id is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
)
pad_token_id = eos_token_id
# current position and vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
self.config.is_encoder_decoder
and hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else:
effective_batch_size = batch_size
effective_batch_mult = 1
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
# see if BOS token can be used for decoder_start_token_id
if bos_token_id is not None:
decoder_start_token_id = bos_token_id
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
)
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(
batch_size, effective_batch_mult * num_beams, input_ids_len
)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
# create empty decoder_input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
cur_len = 1
assert (
batch_size == encoder_outputs.last_hidden_state.shape[0]
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
torch.arange(batch_size)
.view(-1, 1)
.repeat(1, num_beams * effective_batch_mult)
.view(-1)
.to(input_ids.device)
)
# expand encoder_outputs
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)
# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
else:
cur_len = input_ids.shape[-1]
assert (
cur_len < max_length
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
if num_beams > 1:
output = self._generate_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
attention_mask=attention_mask,
use_cache=use_cache,
model_kwargs=model_kwargs,
)
else:
output = self._generate_no_beam_search(
input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
attention_mask=attention_mask,
use_cache=use_cache,
model_kwargs=model_kwargs,
)
return output
def _generate_no_beam_search(
self,
input_ids,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
attention_mask,
use_cache,
model_kwargs,
):
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
"""
# length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
scores = self.postprocess_next_token_scores(
scores=next_token_logits,
input_ids=input_ids,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
cur_len=cur_len,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=1,
)
# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
# Sample
probs = F.softmax(next_token_logscores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
# Greedy decoding
next_token = torch.argmax(next_token_logits, dim=-1)
# update generations and finished sentences
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
# add token and increase length by one
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
# stop when there is a </s> in each sentence, or if we exceed the maximul length
if unfinished_sents.max() == 0:
break
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return input_ids
def _generate_beam_search(
self,
input_ids,
cur_len,
max_length,
min_length,
do_sample,
early_stopping,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
bad_words_ids,
pad_token_id,
eos_token_id,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
attention_mask,
use_cache,
model_kwargs,
):
"""Generate sequences for each example with beam search."""
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
scores = self.postprocess_next_token_scores(
scores=scores,
input_ids=input_ids,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
cur_len=cur_len,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
batch_size=batch_size,
num_beams=num_beams,
)
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
scores.shape, (batch_size * num_beams, vocab_size)
)
if do_sample:
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Temperature
if temperature != 1.0:
_scores = _scores / temperature
# Top-p/top-k filtering
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
probs = F.softmax(_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else:
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence, add a pad token
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content, this will get added to next_batch_beam
next_sent_beam = []
# next tokens for this sentence
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and token IDs
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(),
beam_token_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams:
break
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
)
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
# stop when we are done with each sentence
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and update current length
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
cur_len = cur_len + 1
# re-order internal states
if past is not None:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx in range(batch_size):
if done[batch_idx]:
continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_id is not None and all(
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx],
beam_scores.view(batch_size, num_beams)[batch_idx],
)
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(num_beams):
effective_beam_id = batch_idx * num_beams + beam_id
final_score = beam_scores[effective_beam_id].item()
final_tokens = input_ids[effective_beam_id]
generated_hyps[batch_idx].add(final_tokens, final_score)
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths = input_ids.new(output_batch_size)
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
effective_batch_idx = output_num_return_sequences_per_batch * i + j
best_hyp = sorted_hyps.pop()[1]
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# shorter batches are padded
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
# none of the hypotheses have an eos_token
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded
@staticmethod
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
def _tokens_match(prev_tokens, tokens):
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
if len(tokens) > len(prev_tokens):
# if bad word tokens are longer than prev tokens they can't be equal
return False
if prev_tokens[-len(tokens) :] == tokens:
# if tokens match
return True
else:
return False
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
Args:
scores: logits distribution of shape (batch size, vocabulary size)
banned_tokens: list of list of tokens to ban of length (batch_size)
"""
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
if not banned_mask_list:
return
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
scores.masked_fill_(banned_mask, -float("inf"))
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens_to_keep per batch example in the output
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
......@@ -21,6 +21,8 @@ from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
)
from modeling_bart import BartForConditionalGeneration
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
......@@ -40,7 +42,7 @@ MODEL_MODES = {
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"summarization": BartForConditionalGeneration,
"translation": AutoModelForSeq2SeqLM,
}
......
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import math
import random
import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.configuration_bart import BartConfig
from transformers.file_utils import (
add_code_sample_docstrings,
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
from modeling_utils import PreTrainedModel
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"
BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/bart-base",
"facebook/bart-large",
"facebook/bart-large-mnli",
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"facebook/mbart-large-en-ro",
]
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
BART_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
Parameters:
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BART_GENERATION_EXAMPLE = r"""
Summarization example::
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/run_eval.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
BART_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
Padding will be ignored by default should you provide it.
Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices in input_ids.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If ``past_key_values`` are used, the user can optionally input only the last
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
``past_key_values``).
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
plain tuple.
"""
def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
return attention_mask.eq(0)
def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
"""Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
Note: this is not called during generation
"""
pad_token_id = config.pad_token_id
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
bsz, tgt_len = decoder_input_ids.size()
if decoder_padding_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask
class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, SinusoidalPositionalEmbedding):
pass
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
if shape_1 != shape2:
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens
def make_padding_mask(input_ids, padding_idx=1):
"""True for pad tokens"""
padding_mask = input_ids.eq(padding_idx)
if not padding_mask.any():
padding_mask = None
return padding_mask
# Helper Modules
class EncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(self, x, encoder_padding_mask, output_attentions=False):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, attn_weights
class BartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
is a :class:`EncoderLayer`.
Args:
config: BartConfig
"""
def __init__(self, config: BartConfig, embed_tokens):
super().__init__()
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = config.max_position_embeddings
self.embed_tokens = embed_tokens
if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
self.padding_idx,
config.extra_pos_embeddings,
)
self.embed_patches = nn.Embedding(3, config.d_model)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(
self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens.
Returns:
BaseModelOutput or Tuple comprised of:
- **x** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *output_hidden_states:* is True.
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
embed_patch = self.embed_patches(patch_ids)
x = inputs_embeds + embed_pos + embed_patch
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states = [] if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None
else:
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
if output_attentions:
all_attentions = all_attentions + (attn,)
if self.layer_norm:
x = self.layer_norm(x)
if output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if not return_dict:
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class DecoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = Attention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(
self,
x,
encoder_hidden_states,
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
decoder_padding_mask=None,
output_attentions=False,
):
residual = x
if layer_state is None:
layer_state = {}
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# Self Attention
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# Cross attention
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
# Fully Connected
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class BartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer
is a :class:`DecoderLayer`.
Args:
config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__()
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding(
config.max_position_embeddings, config.d_model, config.pad_token_id
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
self.padding_idx,
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward(
self,
input_ids,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
**unused,
):
"""
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
input_ids (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
past_key_values (dict or None): dictionary used for storing state during generation
Returns:
BaseModelOutputWithPast or tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- the cache
- hidden states
- attentions
"""
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
# check attention mask and invert
if encoder_padding_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions
positions = self.embed_positions(input_ids, use_cache=use_cache)
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids) * self.embed_scale
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (x,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = past_key_values[idx] if past_key_values is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask,
decoder_padding_mask=decoder_padding_mask,
layer_state=layer_state,
causal_mask=decoder_causal_mask,
output_attentions=output_attentions,
)
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
if output_attentions:
all_self_attns += (layer_self_attn,)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
if output_hidden_states:
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
)
def _reorder_buffer(attn_cache, new_order):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = input_buffer_k.index_select(0, new_order)
return attn_cache
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
encoder_decoder_attention=False, # otherwise self_attention
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
def _shape(self, tensor, seq_len, bsz):
return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
def forward(
self,
query,
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state and static_kv:
# previous time steps are cached - no need to recompute key and value if they are static
key = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
if static_kv:
if key is None:
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
k = self.k_proj(query)
v = self.v_proj(query)
q = self._shape(q, tgt_len, bsz)
if k is not None:
k = self._shape(k, -1, bsz)
if v is not None:
v = self._shape(v, -1, bsz)
if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
# Update cache
layer_state[self.cache_key] = {
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
}
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
bsz,
src_len,
)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(
attn_weights,
p=self.dropout,
training=self.training,
)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
if output_attentions:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
else:
attn_weights = None
return attn_output, attn_weights
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
assert k is not None and v is not None
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
if prev_key_padding_mask is not None:
if static_kv:
new_key_padding_mask = prev_key_padding_mask
else:
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
else:
new_key_padding_mask = key_padding_mask
return k, v, new_key_padding_mask
class BartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
# This can trivially be shared with RobertaClassificationHead
def __init__(
self,
input_dim,
inner_dim,
num_classes,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, x):
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
Padding ids are ignored by either offsetting based on padding_idx
or by setting padding_idx to None and ensuring that the appropriate
position ids are passed to the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models dont have this hack
self.offset = offset
assert padding_idx is not None
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at 1-seq_len
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions + self.offset)
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
# Public API
def _get_shape(t):
return getattr(t, "shape", None)
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
def __init__(self, config: BartConfig):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
self.encoder = BartEncoder(config, self.shared)
self.decoder = BartDecoder(config, self.shared)
self.init_weights()
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
patch_ids=None,
attention_mask=None,
decoder_input_ids=None,
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
if decoder_input_ids is None:
use_cache = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# make masks if user doesn't supply
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
decoder_padding_mask=decoder_attention_mask,
causal_mask_dtype=self.shared.weight.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert decoder_input_ids is not None
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
patch_ids=patch_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOuput when return_dict=False
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs[0],
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared) # make it on the fly
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class BartForConditionalGeneration(PretrainedBartModel):
base_model_prefix = "model"
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
def __init__(self, config: BartConfig):
super().__init__(config)
base_model = BartModel(config)
self.model = base_model
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
old_num_tokens = self.model.shared.num_embeddings
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self.model.shared = new_embeddings
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(BART_GENERATION_EXAMPLE)
def forward(
self,
input_ids,
patch_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
past_key_values=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**unused,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
with labels in ``[0, ..., config.vocab_size]``.
Returns:
Conditional generation example::
# Mask filling only works for bart-large
from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
TXT = "My friends are <mask> but they eat too many carbs."
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
logits = model(input_ids).logits
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
tokenizer.decode(predictions).split()
# ['good', 'great', 'all', 'really', 'very']
"""
if "lm_labels" in unused:
warnings.warn(
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = unused.pop("lm_labels")
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model(
input_ids,
patch_ids=patch_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in labels?
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_ids_generation(logits, self.config.bos_token_id)
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
def _force_token_ids_generation(self, scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
return reordered_past
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
self.classification_head = BartClassificationHead(
config.d_model,
config.d_model,
config.num_labels,
config.classif_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
logits = self.classification_head(sentence_representation)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(PretrainedBartModel):
def __init__(self, config):
super().__init__(config)
config.num_labels = 2
self.num_labels = config.num_labels
self.model = BartModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.model._init_weights(self.qa_outputs)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=Seq2SeqQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
start_positions=None,
end_positions=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if start_positions is not None and end_positions is not None:
use_cache = False
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (
start_logits,
end_logits,
) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return Seq2SeqQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
class SinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions, embedding_dim, padding_idx=None):
super().__init__(num_positions, embedding_dim)
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter):
"""Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
The cos features are in the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
out.requires_grad = False
return out
@torch.no_grad()
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at 1-seq_len
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers.activations import get_activation
from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_NAME,
ModelOutput,
cached_path,
hf_bucket_url,
is_remote_url,
is_torch_tpu_available,
replace_return_docstrings,
)
from generation_utils import GenerationMixin
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
try:
from torch.nn import Identity
except ImportError:
# Older PyTorch compatibility
class Identity(nn.Module):
r"""A placeholder identity operator that is argument-insensitive."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, input):
return input
def find_pruneable_heads_and_indices(
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
"""
Finds the heads and their indices taking :obj:`already_pruned_heads` into account.
Args:
heads (:obj:`List[int]`): List of the indices of heads to prune.
n_heads (:obj:`int`): The number of heads in the model.
head_size (:obj:`int`): The size of each head.
already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads.
Returns:
:obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
return heads, index
class ModuleUtilsMixin:
"""
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_pre_forward = mem.rss
return None
@staticmethod
def _hook_rss_memory_post_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_post_forward = mem.rss
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
return None
def add_memory_hooks(self):
"""
Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
zero with :obj:`model.reset_memory_hooks_state()`.
"""
for module in self.modules():
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
module.register_forward_hook(self._hook_rss_memory_post_forward)
self.reset_memory_hooks_state()
def reset_memory_hooks_state(self):
"""
Reset the :obj:`mem_rss_diff` attribute of each module (see
:func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`).
"""
for module in self.modules():
module.mem_rss_diff = 0
module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0
@property
def device(self) -> device:
"""
:obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
try:
return next(self.parameters()).device
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].device
@property
def dtype(self) -> dtype:
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
try:
return next(self.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
Args:
encoder_attention_mask (:obj:`torch.Tensor`): An attention mask.
Returns:
:obj:`torch.Tensor`: The inverted attention mask.
"""
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype == torch.float32:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
self.dtype
)
)
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(
self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
) -> Tensor:
"""
Prepare the head mask if needed.
Args:
head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
num_hidden_layers (:obj:`int`):
The number of hidden layers in the model.
is_attention_chunked: (:obj:`bool`, `optional, defaults to :obj:`False`):
Whether or not the attentions scores are computed by chunks or not.
Returns:
:obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]`
or list with :obj:`[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
return head_mask
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
return head_mask
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""
Base class for all models.
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
for loading, downloading and saving models as well as a few methods common to all models to:
* resize the input embeddings,
* prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
PyTorch model, taking as arguments:
- **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
TensorFlow checkpoint.
- **config** (:class:`~transformers.PreTrainedConfig`) -- An instance of the configuration associated
to the model.
- **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
when loading the model (and avoid unnecessary warnings).
"""
config_class = None
base_model_prefix = ""
authorized_missing_keys = None
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
:obj:`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
"""
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
)
)
# Save config in model
self.config = config
@property
def base_model(self) -> nn.Module:
"""
:obj:`torch.nn.Module`: The main body of the model.
"""
return getattr(self, self.base_model_prefix, self)
def get_input_embeddings(self) -> nn.Module:
"""
Returns the model's input embeddings.
Returns:
:obj:`nn.Module`: A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
def set_input_embeddings(self, value: nn.Module):
"""
Set model's input embeddings
Args:
value (:obj:`nn.Module`): A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError
def get_output_embeddings(self) -> nn.Module:
"""
Returns the model's output embeddings.
Returns:
:obj:`nn.Module`: A torch module mapping hidden states to vocabulary.
"""
return None # Overwrite for models with output embeddings
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
the weights instead.
"""
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
uninitialized_encoder_weights: List[str] = []
assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
encoder_pointer: nn.Module,
module_name: str,
uninitialized_encoder_weights: List[str],
depth=0,
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
if hasattr(decoder_pointer, "bias"):
assert hasattr(encoder_pointer, "bias")
encoder_pointer.bias = decoder_pointer.bias
return
encoder_modules = encoder_pointer._modules
decoder_modules = decoder_pointer._modules
if len(decoder_modules) > 0:
assert (
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
encoder_name = str(int(name) + encoder_layer_pos)
decoder_name = name
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
# this can happen if the name corresponds to the position in a list module list of layers
# in this case the decoder has added a cross-attention that the encoder does not have
# thus skip this step and substract one layer pos from encoder
encoder_layer_pos -= 1
continue
elif name not in encoder_modules:
continue
elif depth > 500:
raise ValueError(
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
)
else:
decoder_name = encoder_name = name
tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name],
encoder_modules[encoder_name],
module_name + "/" + name,
uninitialized_encoder_weights,
depth=depth + 1,
)
all_encoder_weights.remove(module_name + "/" + encoder_name)
uninitialized_encoder_weights += list(all_encoder_weights)
# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
if len(uninitialized_encoder_weights) > 0:
logger.warning(
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
)
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
if self.config.torchscript:
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
else:
output_embeddings.weight = input_embeddings.weight
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(
0,
output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
),
"constant",
0,
)
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
"""
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
Arguments:
new_num_tokens (:obj:`int`, `optional`):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model wihtout doing
anything.
Return:
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
# Tie weights again if needed
self.tie_weights()
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
def _get_resized_embeddings(
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
) -> torch.nn.Embedding:
"""
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
initialized vectors at the end. Reducing the size will remove vectors from the end
Args:
old_embeddings (:obj:`torch.nn.Embedding`):
Old embeddings to be resized.
new_num_tokens (:obj:`int`, `optional`):
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
:obj:`torch.nn.Embedding`` module of the model wihtout doing anything.
Return:
:obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
:obj:`new_num_tokens` is :obj:`None`
"""
if new_num_tokens is None:
return old_embeddings
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(old_embeddings.weight.device)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
return new_embeddings
def init_weights(self):
"""
Initializes and prunes weights if needed.
"""
# Initialize weights
self.apply(self._init_weights)
# Prune heads if needed
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
# Tie weights if needed
self.tie_weights()
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
"""
Prunes heads of the base model.
Arguments:
heads_to_prune (:obj:`Dict[int, List[int]]`):
Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
self.base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
Arguments:
save_directory (:obj:`str`):
Directory to which to save. Will be created if it doesn't exist.
"""
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
return
os.makedirs(save_directory, exist_ok=True)
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, "module") else self
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
if getattr(self.config, "xla_device", False):
import torch_xla.core.xla_model as xm
if xm.is_master_ordinal():
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
# xm.save takes care of saving only from master
xm.save(model_to_save.state_dict(), output_model_file)
else:
model_to_save.config.save_pretrained(save_directory)
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model weights saved in {}".format(output_model_file))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
To train the model, you should first set it back in training mode with ``model.train()``.
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (:obj:`str`, `optional`):
Can be either:
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
arguments ``config`` and ``state_dict``).
model_args (sequence of positional arguments, `optional`):
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
config (:obj:`Union[PretrainedConfig, str]`, `optional`):
Can be either:
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
Configuration for the model to use instead of an automatically loaded configuation. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the `shortcut name` string of a
pretrained model).
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
by suppling the save directory.
- The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own
weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir (:obj:`str`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g.,
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
request.
output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether ot not to also return a dictionnary containing missing keys, unexpected keys and error
messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try doanloading the model).
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
Examples::
from transformers import BertConfig, BertModel
# Download model and configuration from S3 and cache.
model = BertModel.from_pretrained('bert-base-uncased')
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
model = BertModel.from_pretrained('./test/saved_model/')
# Update configuration during loading.
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
**kwargs,
)
else:
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Load from a TF 1.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError(
"Error no file named {} found in directory {} or `from_tf` set to False".format(
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
pretrained_model_name_or_path,
)
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
assert (
from_tf
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
pretrained_model_name_or_path + ".index"
)
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn,
)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
else:
resolved_archive_file = None
# Instantiate model.
model = cls(config, *model_args, **model_kwargs)
if state_dict is None and not from_tf:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
"Unable to load weights from pytorch checkpoint file. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
missing_keys = []
unexpected_keys = []
error_msgs = []
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
else:
# Load from our TensorFlow 2.0 checkpoints
try:
from transformers import load_tf2_checkpoint_in_pytorch_model
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
else:
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
if model.__class__.__name__ != model_to_load.__class__.__name__:
base_model_state_dict = model_to_load.state_dict().keys()
head_model_state_dict_without_base_prefix = [
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
]
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
# make sure token embedding weights are still tied if needed
model.tie_weights()
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"error_msgs": error_msgs,
}
return model, loading_info
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())
model.to(xm.xla_device())
return model
class Conv1D(nn.Module):
"""
1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
Basically works like a linear layer but the weights are transposed.
Args:
nf (:obj:`int`): The number of output features.
nx (:obj:`int`): The number of input features.
"""
def __init__(self, nf, nx):
super().__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
def forward(self, x):
size_out = x.size()[:-1] + (self.nf,)
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
return x
class PoolerStartLogits(nn.Module):
"""
Compute SQuAD start logits from sequence hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, 1)
def forward(
self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
Returns:
:obj:`torch.FloatTensor`: The start logits for SQuAD.
"""
x = self.dense(hidden_states).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
class PoolerEndLogits(nn.Module):
"""
Compute SQuAD end logits from sequence hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
:obj:`layer_norm_eps` to use.
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dense_1 = nn.Linear(config.hidden_size, 1)
def forward(
self,
hidden_states: torch.FloatTensor,
start_states: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
The hidden states of the first tokens for the labeled span.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
The position of the first token for the labeled span.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
.. note::
One of ``start_states`` or ``start_positions`` should be not obj:`None`. If both are set,
``start_positions`` overrides ``start_states``.
Returns:
:obj:`torch.FloatTensor`: The end logits for SQuAD.
"""
assert (
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
x = self.activation(x)
x = self.LayerNorm(x)
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
class PoolerAnswerClass(nn.Module):
"""
Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
"""
def __init__(self, config):
super().__init__()
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
self.activation = nn.Tanh()
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
def forward(
self,
hidden_states: torch.FloatTensor,
start_states: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
cls_index: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
The final hidden states of the model.
start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
The hidden states of the first tokens for the labeled span.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
The position of the first token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
.. note::
One of ``start_states`` or ``start_positions`` should be not obj:`None`. If both are set,
``start_positions`` overrides ``start_states``.
Returns:
:obj:`torch.FloatTensor`: The SQuAD 2.0 answer class.
"""
# No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
hsz = hidden_states.shape[-1]
assert (
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
if cls_index is not None:
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
else:
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
x = self.activation(x)
x = self.dense_1(x).squeeze(-1)
return x
@dataclass
class SquadHeadOutput(ModelOutput):
"""
Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
"""
loss: Optional[torch.FloatTensor] = None
start_top_log_probs: Optional[torch.FloatTensor] = None
start_top_index: Optional[torch.LongTensor] = None
end_top_log_probs: Optional[torch.FloatTensor] = None
end_top_index: Optional[torch.LongTensor] = None
cls_logits: Optional[torch.FloatTensor] = None
class SQuADHead(nn.Module):
r"""
A SQuAD head inspired by XLNet.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
:obj:`layer_norm_eps` to use.
"""
def __init__(self, config):
super().__init__()
self.start_n_top = config.start_n_top
self.end_n_top = config.end_n_top
self.start_logits = PoolerStartLogits(config)
self.end_logits = PoolerEndLogits(config)
self.answer_class = PoolerAnswerClass(config)
@replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
def forward(
self,
hidden_states: torch.FloatTensor,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
cls_index: Optional[torch.LongTensor] = None,
is_impossible: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
return_dict: bool = False,
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
Final hidden states of the model on the sequence tokens.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the first token for the labeled span.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the last token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Whether the question has a possible answer in the paragraph or not.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
Returns:
"""
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for x in (start_positions, end_positions, cls_index, is_impossible):
if x is not None and x.dim() > 1:
x.squeeze_(-1)
# during training, compute the end logits based on the ground truth of the start position
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if cls_index is not None and is_impossible is not None:
# Predict answerability from the representation of CLS and START
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
loss_fct_cls = nn.BCEWithLogitsLoss()
cls_loss = loss_fct_cls(cls_logits, is_impossible)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
else:
# during inference, compute the end logits based on beam search
bsz, slen, hsz = hidden_states.size()
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
start_top_log_probs, start_top_index = torch.topk(
start_log_probs, self.start_n_top, dim=-1
) # shape (bsz, start_n_top)
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
start_states
) # shape (bsz, slen, start_n_top, hsz)
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
end_top_log_probs, end_top_index = torch.topk(
end_log_probs, self.end_n_top, dim=1
) # shape (bsz, end_n_top, start_n_top)
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
if not return_dict:
return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
else:
return SquadHeadOutput(
start_top_log_probs=start_top_log_probs,
start_top_index=start_top_index,
end_top_log_probs=end_top_log_probs,
end_top_index=end_top_index,
cls_logits=cls_logits,
)
class SequenceSummary(nn.Module):
r"""
Compute a single vector summary of a sequence hidden states.
Args:
config (:class:`~transformers.PretrainedConfig`):
The config used by the model. Relevant arguments in the config class of the model are (refer to the
actual config class of your model for the default values it uses):
- **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:
- :obj:`"last"` -- Take the last token hidden state (like XLNet)
- :obj:`"first"` -- Take the first token hidden state (like Bert)
- :obj:`"mean"` -- Take the mean of all tokens hidden states
- :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- :obj:`"attn"` -- Not implemented now, use multi-head attention
- **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
:obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
- **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
output, another string or :obj:`None` will add no activation.
- **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
activation.
- **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and
activation.
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.summary_type = getattr(config, "summary_type", "last")
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise NotImplementedError
self.summary = Identity()
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
activation_string = getattr(config, "summary_activation", None)
self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(config.summary_first_dropout)
self.last_dropout = Identity()
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
) -> torch.FloatTensor:
"""
Compute a single vector summary of a sequence hidden states.
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification
token.
Returns:
:obj:`torch.FloatTensor`: The summary of the sequence hidden states.
"""
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == "mean":
output = hidden_states.mean(dim=1)
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = torch.full_like(
hidden_states[..., :1, :],
hidden_states.shape[-2] - 1,
dtype=torch.long,
)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
raise NotImplementedError
output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output)
return output
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
"""
Prune a linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (:obj:`torch.nn.Linear`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices.
Returns:
:obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
if dim == 1:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
if layer.bias is not None:
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
"""
Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
are transposed.
Used to remove heads.
Args:
layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices.
Returns:
:class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
def prune_layer(
layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
) -> Union[torch.nn.Linear, Conv1D]:
"""
Prune a Conv1D or linear layer to keep only entries in index.
Used to remove heads.
Args:
layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
index (:obj:`torch.LongTensor`): The indices to keep in the layer.
dim (:obj:`int`, `optional`): The dimension on which to keep the indices.
Returns:
:obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`:
The pruned layer as a new layer with :obj:`requires_grad=True`.
"""
if isinstance(layer, nn.Linear):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
elif isinstance(layer, Conv1D):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
"""
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.
If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
Args:
forward_fn (:obj:`Callable[..., torch.Tensor]`):
The forward function of the model.
chunk_size (:obj:`int`):
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
chunk_dim (:obj:`int`):
The dimension over which the :obj:`input_tensors` should be chunked.
input_tensors (:obj:`Tuple[torch.Tensor]`):
The input tensors of ``forward_fn`` which will be chunked.
Returns:
:obj:`torch.Tensor`: A tensor with the same shape as the :obj:`foward_fn` would have given if applied`.
Examples::
# rename the usual forward() fn to forward_chunk()
def forward_chunk(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states
# implement a chunked forward function
def forward(self, hidden_states):
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
"""
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
tensor_shape = input_tensors[0].shape
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
assert num_args_in_forward_chunk_fn == len(
input_tensors
), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
num_args_in_forward_chunk_fn, len(input_tensors)
)
if chunk_size > 0:
assert (
input_tensors[0].shape[chunk_dim] % chunk_size == 0
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
input_tensors[0].shape[chunk_dim], chunk_size
)
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
# chunk input tensor into tuples
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
# apply forward fn to every tuple
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
# concatenate output at same dimension
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)