graykode

(add) customized bart model to modify patch_ids

...@@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer):
188 188         t0 = time.time()
189 189         generated_ids = self.model.generate(
190 190             batch[0].long(),
    191 +           patch_ids=batch[2].long(),
191 192             attention_mask=batch[1].long(),
192     -           # patch_ids=batch[2].long(),
193 193             use_cache=True,
194 194             decoder_start_token_id=self.decoder_start_token_id,
195 195         )
......
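For context: the customized `generate()` below sweeps unrecognized keyword arguments into `**model_kwargs`, so the new `patch_ids=batch[2].long()` argument in the hunk above only reaches the model through `prepare_inputs_for_generation`/`forward`. A rough sketch of the assumed call flow (the batch layout is inferred from the indices used above; names are illustrative):

    # assumed batch layout: batch[0] = input_ids, batch[1] = attention_mask, batch[2] = patch_ids
    generated_ids = self.model.generate(
        batch[0].long(),
        patch_ids=batch[2].long(),       # picked up by generate() as part of **model_kwargs
        attention_mask=batch[1].long(),
        use_cache=True,
        decoder_start_token_id=self.decoder_start_token_id,
    )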
1 +# coding=utf-8
2 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 +#
5 +# Licensed under the Apache License, Version 2.0 (the "License");
6 +# you may not use this file except in compliance with the License.
7 +# You may obtain a copy of the License at
8 +#
9 +# http://www.apache.org/licenses/LICENSE-2.0
10 +#
11 +# Unless required by applicable law or agreed to in writing, software
12 +# distributed under the License is distributed on an "AS IS" BASIS,
13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 +# See the License for the specific language governing permissions and
15 +# limitations under the License.
16 +
17 +from typing import Iterable, List, Optional, Tuple
18 +
19 +import torch
20 +from torch import Tensor
21 +from torch.nn import functional as F
22 +
23 +from transformers.file_utils import ModelOutput
24 +import logging
25 +
26 +logger = logging.getLogger(__name__) # pylint: disable=invalid-name
27 +logging.basicConfig(
28 + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
29 + datefmt="%m/%d/%Y %H:%M:%S",
30 + level=logging.INFO,
31 +)
32 +
33 +class GenerationMixin:
34 + """
35 +    A class containing all of the functions supporting generation, to be used as a mixin in
36 +    :class:`~transformers.PreTrainedModel`.
37 + """
38 +
39 + def prepare_inputs_for_generation(self, input_ids, **kwargs):
40 + """
41 +        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
42 + generate method.
43 + """
44 + return {"input_ids": input_ids}
45 +
46 + def adjust_logits_during_generation(self, logits, **kwargs):
47 + """
48 +        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
49 + the generate method.
50 + """
51 + return logits
52 +
53 + def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
54 + """
55 + Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
56 + """
57 + for i in range(batch_size * num_beams):
58 + for previous_token in set(prev_output_tokens[i].tolist()):
59 +                # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
60 + if lprobs[i, previous_token] < 0:
61 + lprobs[i, previous_token] *= repetition_penalty
62 + else:
63 + lprobs[i, previous_token] /= repetition_penalty
64 +
65 + def postprocess_next_token_scores(
66 + self,
67 + scores,
68 + input_ids,
69 + no_repeat_ngram_size,
70 + bad_words_ids,
71 + cur_len,
72 + min_length,
73 + max_length,
74 + eos_token_id,
75 + repetition_penalty,
76 + batch_size,
77 + num_beams,
78 + ):
79 + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
80 + if repetition_penalty != 1.0:
81 + self.enforce_repetition_penalty_(
82 + scores,
83 + batch_size,
84 + num_beams,
85 + input_ids,
86 + repetition_penalty,
87 + )
88 +
89 + # set eos token prob to zero if min_length is not reached
90 + if eos_token_id is not None and cur_len < min_length:
91 + scores[:, eos_token_id] = -float("inf")
92 +
93 + if no_repeat_ngram_size > 0:
94 + # calculate a list of banned tokens to prevent repetitively generating the same ngrams
95 + num_batch_hypotheses = batch_size * num_beams
96 + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
97 + banned_batch_tokens = calc_banned_ngram_tokens(
98 + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
99 + )
100 + for i, banned_tokens in enumerate(banned_batch_tokens):
101 + scores[i, banned_tokens] = -float("inf")
102 +
103 + if bad_words_ids is not None:
104 + # Exclude EOS token (already processed)
105 + bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
106 + # calculate a list of banned tokens according to bad words
107 + banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
108 + # Modify the scores in place by setting the banned tokens logits to `-inf`
109 + set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
110 +
111 + return scores
112 +
113 + @torch.no_grad()
114 + def generate(
115 + self,
116 + input_ids: Optional[torch.LongTensor] = None,
117 + max_length: Optional[int] = None,
118 + min_length: Optional[int] = None,
119 + do_sample: Optional[bool] = None,
120 + early_stopping: Optional[bool] = None,
121 + num_beams: Optional[int] = None,
122 + temperature: Optional[float] = None,
123 + top_k: Optional[int] = None,
124 + top_p: Optional[float] = None,
125 + repetition_penalty: Optional[float] = None,
126 + bad_words_ids: Optional[Iterable[int]] = None,
127 + bos_token_id: Optional[int] = None,
128 + pad_token_id: Optional[int] = None,
129 + eos_token_id: Optional[int] = None,
130 + length_penalty: Optional[float] = None,
131 + no_repeat_ngram_size: Optional[int] = None,
132 + num_return_sequences: Optional[int] = None,
133 + attention_mask: Optional[torch.LongTensor] = None,
134 + decoder_start_token_id: Optional[int] = None,
135 + use_cache: Optional[bool] = None,
136 + **model_kwargs
137 + ) -> torch.LongTensor:
138 + r"""
139 + Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
140 + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
141 +
142 + Adapted in part from `Facebook's XLM beam search code
143 + <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
144 +
145 + Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
146 + attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
147 +        indicated are the default values of the model's config.
148 +
149 + Most of these parameters are explained in more detail in `this blog post
150 + <https://huggingface.co/blog/how-to-generate>`__.
151 +
152 + Parameters:
153 +
154 + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
155 + The sequence used as a prompt for the generation. If :obj:`None` the method initializes
156 + it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
157 + max_length (:obj:`int`, `optional`, defaults to 20):
158 + The maximum length of the sequence to be generated.
159 + min_length (:obj:`int`, `optional`, defaults to 10):
160 + The minimum length of the sequence to be generated.
161 + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
162 +                Whether or not to use sampling; use greedy decoding otherwise.
163 + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
164 + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
165 + num_beams (:obj:`int`, `optional`, defaults to 1):
166 + Number of beams for beam search. 1 means no beam search.
167 +            temperature (:obj:`float`, `optional`, defaults to 1.0):
168 +                The value used to modulate the next token probabilities.
169 + top_k (:obj:`int`, `optional`, defaults to 50):
170 + The number of highest probability vocabulary tokens to keep for top-k-filtering.
171 + top_p (:obj:`float`, `optional`, defaults to 1.0):
172 + If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
173 + higher are kept for generation.
174 + repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
175 + The parameter for repetition penalty. 1.0 means no penalty. See `this paper
176 + <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
177 + pad_token_id (:obj:`int`, `optional`):
178 + The id of the `padding` token.
179 + bos_token_id (:obj:`int`, `optional`):
180 + The id of the `beginning-of-sequence` token.
181 + eos_token_id (:obj:`int`, `optional`):
182 + The id of the `end-of-sequence` token.
183 + length_penalty (:obj:`float`, `optional`, defaults to 1.0):
184 + Exponential penalty to the length. 1.0 means no penalty.
185 +
186 + Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
187 + order to encourage the model to produce longer sequences.
188 + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
189 + If set to int > 0, all ngrams of that size can only occur once.
190 + bad_words_ids(:obj:`List[int]`, `optional`):
191 + List of token ids that are not allowed to be generated. In order to get the tokens of the words that
192 + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
193 + num_return_sequences(:obj:`int`, `optional`, defaults to 1):
194 + The number of independently computed returned sequences for each element in the batch.
195 + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
196 + Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
197 + tokens that are not masked, and 0 for masked tokens.
198 +
199 + If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
200 +
201 + `What are attention masks? <../glossary.html#attention-mask>`__
202 + decoder_start_token_id (:obj:`int`, `optional`):
203 + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
204 + use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
205 + Whether or not the model should use the past last key/values attentions (if applicable to the model) to
206 + speed up decoding.
207 + model_kwargs:
208 + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
209 +
210 + Return:
211 +
212 + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
213 + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
214 + shorter if all batches finished early due to the :obj:`eos_token_id`.
215 +
216 + Examples::
217 +
218 + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
219 + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
220 + outputs = model.generate(max_length=40) # do greedy decoding
221 + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
222 +
223 + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
224 + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
225 + input_context = 'The dog'
226 + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
227 + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
228 + for i in range(3): # 3 output sequences were generated
229 + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
230 +
231 + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
232 + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
233 + input_context = 'The dog'
234 + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
235 + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
236 + for i in range(3): # 3 output sequences were generated
237 + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
238 +
239 + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
240 + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
241 + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
242 + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
243 + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
244 + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
245 +
246 + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
247 + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
248 +        input_context = 'My cute dog'
249 + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
250 + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
251 + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
252 + """
253 +
254 + # We cannot generate if the model does not have a LM head
255 + if self.get_output_embeddings() is None:
256 + raise AttributeError(
257 + "You tried to generate sequences with a model that does not have a LM Head."
258 + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
259 + )
260 +
261 + max_length = max_length if max_length is not None else self.config.max_length
262 + min_length = min_length if min_length is not None else self.config.min_length
263 + do_sample = do_sample if do_sample is not None else self.config.do_sample
264 + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
265 + use_cache = use_cache if use_cache is not None else self.config.use_cache
266 + num_beams = num_beams if num_beams is not None else self.config.num_beams
267 + temperature = temperature if temperature is not None else self.config.temperature
268 + top_k = top_k if top_k is not None else self.config.top_k
269 + top_p = top_p if top_p is not None else self.config.top_p
270 + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
271 + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
272 + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
273 + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
274 + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
275 + no_repeat_ngram_size = (
276 + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
277 + )
278 + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
279 + num_return_sequences = (
280 + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
281 + )
282 + decoder_start_token_id = (
283 + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
284 + )
285 +
286 + if input_ids is not None:
287 +            batch_size = input_ids.shape[0]  # overridden by the input batch_size
288 + else:
289 + batch_size = 1
290 +
291 + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
292 + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
293 + assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
294 + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
295 + assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
296 + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
297 + assert temperature > 0, "`temperature` should be strictly positive."
298 + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
299 + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
300 + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
301 + assert input_ids is not None or (
302 + isinstance(bos_token_id, int) and bos_token_id >= 0
303 + ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
304 + assert pad_token_id is None or (
305 + isinstance(pad_token_id, int) and (pad_token_id >= 0)
306 + ), "`pad_token_id` should be a positive integer."
307 + assert (eos_token_id is None) or (
308 + isinstance(eos_token_id, int) and (eos_token_id >= 0)
309 + ), "`eos_token_id` should be a positive integer."
310 + assert length_penalty > 0, "`length_penalty` should be strictly positive."
311 + assert (
312 + isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
313 + ), "`no_repeat_ngram_size` should be a positive integer."
314 + assert (
315 + isinstance(num_return_sequences, int) and num_return_sequences > 0
316 + ), "`num_return_sequences` should be a strictly positive integer."
317 + assert (
318 + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
319 + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
320 +
321 + if input_ids is None:
322 + assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
323 + "you should either supply a context to complete as `input_ids` input "
324 + "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
325 + )
326 + input_ids = torch.full(
327 + (batch_size, 1),
328 + bos_token_id,
329 + dtype=torch.long,
330 + device=next(self.parameters()).device,
331 + )
332 + else:
333 + assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
334 +
335 +        # do not allow duplicate outputs when doing greedy decoding
336 + if do_sample is False:
337 + if num_beams == 1:
338 + # no_beam_search greedy generation conditions
339 + assert (
340 + num_return_sequences == 1
341 + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
342 +
343 + else:
344 + # beam_search greedy generation conditions
345 + assert (
346 + num_beams >= num_return_sequences
347 + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
348 +
349 + # create attention mask if necessary
350 + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
351 + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
352 + attention_mask = input_ids.ne(pad_token_id).long()
353 + elif attention_mask is None:
354 + attention_mask = input_ids.new_ones(input_ids.shape)
355 +
356 + # set pad_token_id to eos_token_id if not set. Important that this is done after
357 + # attention_mask is created
358 + if pad_token_id is None and eos_token_id is not None:
359 + logger.warning(
360 + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
361 + )
362 + pad_token_id = eos_token_id
363 +
364 + # current position and vocab size
365 + if hasattr(self.config, "vocab_size"):
366 + vocab_size = self.config.vocab_size
367 + elif (
368 + self.config.is_encoder_decoder
369 + and hasattr(self.config, "decoder")
370 + and hasattr(self.config.decoder, "vocab_size")
371 + ):
372 + vocab_size = self.config.decoder.vocab_size
373 +
374 + # set effective batch size and effective batch multiplier according to do_sample
375 + if do_sample:
376 + effective_batch_size = batch_size * num_return_sequences
377 + effective_batch_mult = num_return_sequences
378 + else:
379 + effective_batch_size = batch_size
380 + effective_batch_mult = 1
381 +
382 + if self.config.is_encoder_decoder:
383 + if decoder_start_token_id is None:
384 + # see if BOS token can be used for decoder_start_token_id
385 + if bos_token_id is not None:
386 + decoder_start_token_id = bos_token_id
387 + elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
388 + decoder_start_token_id = self.config.decoder.bos_token_id
389 + else:
390 + raise ValueError(
391 + "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
392 + )
393 +
394 + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
395 + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
396 +
397 + # get encoder and store encoder outputs
398 + encoder = self.get_encoder()
399 + encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
400 +
401 + # Expand input ids if num_beams > 1 or num_return_sequences > 1
402 + if num_return_sequences > 1 or num_beams > 1:
403 + input_ids_len = input_ids.shape[-1]
404 + input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
405 + attention_mask = attention_mask.unsqueeze(1).expand(
406 + batch_size, effective_batch_mult * num_beams, input_ids_len
407 + )
408 +
409 + input_ids = input_ids.contiguous().view(
410 + effective_batch_size * num_beams, input_ids_len
411 + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
412 + attention_mask = attention_mask.contiguous().view(
413 + effective_batch_size * num_beams, input_ids_len
414 + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
415 +
416 + if self.config.is_encoder_decoder:
417 + # create empty decoder_input_ids
418 + input_ids = torch.full(
419 + (effective_batch_size * num_beams, 1),
420 + decoder_start_token_id,
421 + dtype=torch.long,
422 + device=next(self.parameters()).device,
423 + )
424 + cur_len = 1
425 +
426 + assert (
427 + batch_size == encoder_outputs.last_hidden_state.shape[0]
428 + ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
429 +
430 + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
431 + expanded_batch_idxs = (
432 + torch.arange(batch_size)
433 + .view(-1, 1)
434 + .repeat(1, num_beams * effective_batch_mult)
435 + .view(-1)
436 + .to(input_ids.device)
437 + )
438 +
439 + # expand encoder_outputs
440 + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
441 + 0, expanded_batch_idxs
442 + )
443 +
444 + # save encoder_outputs in `model_kwargs`
445 + model_kwargs["encoder_outputs"] = encoder_outputs
446 +
447 + else:
448 + cur_len = input_ids.shape[-1]
449 +
450 + assert (
451 + cur_len < max_length
452 + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
453 +
454 + if num_beams > 1:
455 + output = self._generate_beam_search(
456 + input_ids,
457 + cur_len=cur_len,
458 + max_length=max_length,
459 + min_length=min_length,
460 + do_sample=do_sample,
461 + early_stopping=early_stopping,
462 + temperature=temperature,
463 + top_k=top_k,
464 + top_p=top_p,
465 + repetition_penalty=repetition_penalty,
466 + no_repeat_ngram_size=no_repeat_ngram_size,
467 + bad_words_ids=bad_words_ids,
468 + pad_token_id=pad_token_id,
469 + eos_token_id=eos_token_id,
470 + batch_size=effective_batch_size,
471 + num_return_sequences=num_return_sequences,
472 + length_penalty=length_penalty,
473 + num_beams=num_beams,
474 + vocab_size=vocab_size,
475 + attention_mask=attention_mask,
476 + use_cache=use_cache,
477 + model_kwargs=model_kwargs,
478 + )
479 + else:
480 + output = self._generate_no_beam_search(
481 + input_ids,
482 + cur_len=cur_len,
483 + max_length=max_length,
484 + min_length=min_length,
485 + do_sample=do_sample,
486 + temperature=temperature,
487 + top_k=top_k,
488 + top_p=top_p,
489 + repetition_penalty=repetition_penalty,
490 + no_repeat_ngram_size=no_repeat_ngram_size,
491 + bad_words_ids=bad_words_ids,
492 + pad_token_id=pad_token_id,
493 + eos_token_id=eos_token_id,
494 + batch_size=effective_batch_size,
495 + attention_mask=attention_mask,
496 + use_cache=use_cache,
497 + model_kwargs=model_kwargs,
498 + )
499 +
500 + return output
501 +
502 + def _generate_no_beam_search(
503 + self,
504 + input_ids,
505 + cur_len,
506 + max_length,
507 + min_length,
508 + do_sample,
509 + temperature,
510 + top_k,
511 + top_p,
512 + repetition_penalty,
513 + no_repeat_ngram_size,
514 + bad_words_ids,
515 + pad_token_id,
516 + eos_token_id,
517 + batch_size,
518 + attention_mask,
519 + use_cache,
520 + model_kwargs,
521 + ):
522 + """Generate sequences for each example without beam search (num_beams == 1).
523 +        All returned sequences are generated independently.
524 + """
525 + # length of generated sentences / unfinished sentences
526 + unfinished_sents = input_ids.new(batch_size).fill_(1)
527 + sent_lengths = input_ids.new(batch_size).fill_(max_length)
528 +
529 + past = None
530 + while cur_len < max_length:
531 + model_inputs = self.prepare_inputs_for_generation(
532 + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
533 + )
534 +
535 + outputs = self(**model_inputs, return_dict=True)
536 + next_token_logits = outputs.logits[:, -1, :]
537 +
538 + scores = self.postprocess_next_token_scores(
539 + scores=next_token_logits,
540 + input_ids=input_ids,
541 + no_repeat_ngram_size=no_repeat_ngram_size,
542 + bad_words_ids=bad_words_ids,
543 + cur_len=cur_len,
544 + min_length=min_length,
545 + max_length=max_length,
546 + eos_token_id=eos_token_id,
547 + repetition_penalty=repetition_penalty,
548 + batch_size=batch_size,
549 + num_beams=1,
550 + )
551 +
552 + # if model has past, then set the past variable to speed up decoding
553 + if "past_key_values" in outputs:
554 + past = outputs.past_key_values
555 + elif "mems" in outputs:
556 + past = outputs.mems
557 +
558 + if do_sample:
559 + # Temperature (higher temperature => more likely to sample low probability tokens)
560 + if temperature != 1.0:
561 + scores = scores / temperature
562 + # Top-p/top-k filtering
563 + next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
564 + # Sample
565 + probs = F.softmax(next_token_logscores, dim=-1)
566 + next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
567 + else:
568 + # Greedy decoding
569 + next_token = torch.argmax(next_token_logits, dim=-1)
570 +
571 + # update generations and finished sentences
572 + if eos_token_id is not None:
573 + # pad finished sentences if eos_token_id exist
574 + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
575 + else:
576 + tokens_to_add = next_token
577 +
578 + # add token and increase length by one
579 + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
580 + cur_len = cur_len + 1
581 +
582 + if eos_token_id is not None:
583 + eos_in_sents = tokens_to_add == eos_token_id
584 + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
585 + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
586 + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
587 + # unfinished_sents is set to zero if eos in sentence
588 + unfinished_sents.mul_((~eos_in_sents).long())
589 +
590 +            # stop when there is a </s> in each sentence, or if we exceed the maximum length
591 + if unfinished_sents.max() == 0:
592 + break
593 +
594 + # extend attention_mask for new generated input if only decoder
595 + if self.config.is_encoder_decoder is False:
596 + attention_mask = torch.cat(
597 + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
598 + )
599 +
600 + return input_ids
601 +
602 + def _generate_beam_search(
603 + self,
604 + input_ids,
605 + cur_len,
606 + max_length,
607 + min_length,
608 + do_sample,
609 + early_stopping,
610 + temperature,
611 + top_k,
612 + top_p,
613 + repetition_penalty,
614 + no_repeat_ngram_size,
615 + bad_words_ids,
616 + pad_token_id,
617 + eos_token_id,
618 + batch_size,
619 + num_return_sequences,
620 + length_penalty,
621 + num_beams,
622 + vocab_size,
623 + attention_mask,
624 + use_cache,
625 + model_kwargs,
626 + ):
627 + """Generate sequences for each example with beam search."""
628 +
629 + # generated hypotheses
630 + generated_hyps = [
631 + BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
632 + for _ in range(batch_size)
633 + ]
634 +
635 + # scores for each sentence in the beam
636 + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
637 +
638 +        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens multiple times
639 + if do_sample is False:
640 + beam_scores[:, 1:] = -1e9
641 + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
642 +
643 + # cache compute states
644 + past = None
645 +
646 + # done sentences
647 + done = [False for _ in range(batch_size)]
648 +
649 + while cur_len < max_length:
650 + model_inputs = self.prepare_inputs_for_generation(
651 + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
652 + )
653 + outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
654 + next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
655 +
656 + # if model has past, then set the past variable to speed up decoding
657 + if "past_key_values" in outputs:
658 + past = outputs.past_key_values
659 + elif "mems" in outputs:
660 + past = outputs.mems
661 +
662 + if self.config.is_encoder_decoder and do_sample is False:
663 + # TODO (PVP) still a bit hacky here - there might be a better solution
664 + next_token_logits = self.adjust_logits_during_generation(
665 + next_token_logits, cur_len=cur_len, max_length=max_length
666 + )
667 +
668 + scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
669 +
670 + scores = self.postprocess_next_token_scores(
671 + scores=scores,
672 + input_ids=input_ids,
673 + no_repeat_ngram_size=no_repeat_ngram_size,
674 + bad_words_ids=bad_words_ids,
675 + cur_len=cur_len,
676 + min_length=min_length,
677 + max_length=max_length,
678 + eos_token_id=eos_token_id,
679 + repetition_penalty=repetition_penalty,
680 + batch_size=batch_size,
681 + num_beams=num_beams,
682 + )
683 +
684 + assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
685 + scores.shape, (batch_size * num_beams, vocab_size)
686 + )
687 +
688 + if do_sample:
689 + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
690 + # Temperature
691 + if temperature != 1.0:
692 + _scores = _scores / temperature
693 + # Top-p/top-k filtering
694 + _scores = top_k_top_p_filtering(
695 + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
696 + ) # (batch_size * num_beams, vocab_size)
697 + # re-organize to group the beam together to sample from all beam_idxs
698 + _scores = _scores.contiguous().view(
699 + batch_size, num_beams * vocab_size
700 + ) # (batch_size, num_beams * vocab_size)
701 +
702 + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
703 + probs = F.softmax(_scores, dim=-1)
704 + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
705 + # Compute next scores
706 + next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
707 + # sort the sampled vector to make sure that the first num_beams samples are the best
708 + next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
709 + next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
710 +
711 + else:
712 + next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
713 +
714 +                # re-organize to group the beams together (we keep the top hypotheses across beams)
715 + next_scores = next_scores.view(
716 + batch_size, num_beams * vocab_size
717 + ) # (batch_size, num_beams * vocab_size)
718 +
719 + next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
720 +
721 + assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
722 +
723 + # next batch beam content
724 + next_batch_beam = []
725 +
726 + # for each sentence
727 + for batch_idx in range(batch_size):
728 +
729 + # if we are done with this sentence, add a pad token
730 + if done[batch_idx]:
731 + assert (
732 + len(generated_hyps[batch_idx]) >= num_beams
733 + ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
734 + assert (
735 + eos_token_id is not None and pad_token_id is not None
736 + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
737 + next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
738 + continue
739 +
740 + # next sentence beam content, this will get added to next_batch_beam
741 + next_sent_beam = []
742 +
743 + # next tokens for this sentence
744 + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
745 + zip(next_tokens[batch_idx], next_scores[batch_idx])
746 + ):
747 + # get beam and token IDs
748 + beam_id = beam_token_id // vocab_size
749 + token_id = beam_token_id % vocab_size
750 +
751 + effective_beam_id = batch_idx * num_beams + beam_id
752 + # add to generated hypotheses if end of sentence
753 + if (eos_token_id is not None) and (token_id.item() == eos_token_id):
754 + # if beam_token does not belong to top num_beams tokens, it should not be added
755 + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
756 + if is_beam_token_worse_than_top_num_beams:
757 + continue
758 + generated_hyps[batch_idx].add(
759 + input_ids[effective_beam_id].clone(),
760 + beam_token_score.item(),
761 + )
762 + else:
763 + # add next predicted token since it is not eos_token
764 + next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
765 +
766 + # once the beam for next step is full, don't add more tokens to it.
767 + if len(next_sent_beam) == num_beams:
768 + break
769 +
770 + # Check if we are done so that we can save a pad step if all(done)
771 + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
772 + next_scores[batch_idx].max().item(), cur_len
773 + )
774 +
775 + # update next beam content
776 + assert len(next_sent_beam) == num_beams, "Beam should always be full"
777 + next_batch_beam.extend(next_sent_beam)
778 + assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
779 +
780 + # stop when we are done with each sentence
781 + if all(done):
782 + break
783 +
784 + # sanity check / prepare next batch
785 + assert len(next_batch_beam) == batch_size * num_beams
786 + beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
787 + beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
788 + beam_idx = input_ids.new([x[2] for x in next_batch_beam])
789 +
790 + # re-order batch and update current length
791 + input_ids = input_ids[beam_idx, :]
792 + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
793 + cur_len = cur_len + 1
794 +
795 + # re-order internal states
796 + if past is not None:
797 + past = self._reorder_cache(past, beam_idx)
798 +
799 + # extend attention_mask for new generated input if only decoder
800 + if self.config.is_encoder_decoder is False:
801 + attention_mask = torch.cat(
802 + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
803 + )
804 +
805 + # finalize all open beam hypotheses and add to generated hypotheses
806 + for batch_idx in range(batch_size):
807 + if done[batch_idx]:
808 + continue
809 +
810 + # test that beam scores match previously calculated scores if not eos and batch_idx not done
811 + if eos_token_id is not None and all(
812 + (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
813 + ):
814 + assert torch.all(
815 + next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
816 + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
817 + next_scores[:, :num_beams][batch_idx],
818 + beam_scores.view(batch_size, num_beams)[batch_idx],
819 + )
820 +
821 + # need to add best num_beams hypotheses to generated hyps
822 + for beam_id in range(num_beams):
823 + effective_beam_id = batch_idx * num_beams + beam_id
824 + final_score = beam_scores[effective_beam_id].item()
825 + final_tokens = input_ids[effective_beam_id]
826 + generated_hyps[batch_idx].add(final_tokens, final_score)
827 +
828 + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
829 + output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
830 + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
831 +
832 + # select the best hypotheses
833 + sent_lengths = input_ids.new(output_batch_size)
834 + best = []
835 +
836 + # retrieve best hypotheses
837 + for i, hypotheses in enumerate(generated_hyps):
838 + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
839 + for j in range(output_num_return_sequences_per_batch):
840 + effective_batch_idx = output_num_return_sequences_per_batch * i + j
841 + best_hyp = sorted_hyps.pop()[1]
842 + sent_lengths[effective_batch_idx] = len(best_hyp)
843 + best.append(best_hyp)
844 +
845 + # shorter batches are padded
846 + if sent_lengths.min().item() != sent_lengths.max().item():
847 + assert pad_token_id is not None, "`Pad_token_id` has to be defined"
848 + sent_max_len = min(sent_lengths.max().item() + 1, max_length)
849 + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
850 +
851 + # fill with hypothesis and eos_token_id if necessary
852 + for i, hypo in enumerate(best):
853 + decoded[i, : sent_lengths[i]] = hypo
854 + if sent_lengths[i] < max_length:
855 + decoded[i, sent_lengths[i]] = eos_token_id
856 + else:
857 + # none of the hypotheses have an eos_token
858 +            assert all(len(hypo) == max_length for hypo in best)
859 + decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
860 +
861 + return decoded
862 +
863 + @staticmethod
864 + def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
865 + return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
866 +
867 +
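# Illustrative note (an assumption, not part of the original HuggingFace file): since `generate()`
# above forwards extra keyword arguments such as `patch_ids` through **model_kwargs into
# `prepare_inputs_for_generation`, a model that consumes them is expected to override that hook
# roughly as follows:
#
#     def prepare_inputs_for_generation(self, decoder_input_ids, past=None, attention_mask=None,
#                                       patch_ids=None, encoder_outputs=None, use_cache=None, **kwargs):
#         return {
#             "input_ids": None,                    # encoder outputs are already computed
#             "patch_ids": patch_ids,               # keep the patch ids available at every decoding step
#             "encoder_outputs": encoder_outputs,
#             "decoder_input_ids": decoder_input_ids,
#             "attention_mask": attention_mask,
#             "past_key_values": past,
#             "use_cache": use_cache,
#         }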
868 +def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
869 + """Copied from fairseq for no_repeat_ngram in beam_search"""
870 + if cur_len + 1 < no_repeat_ngram_size:
871 + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
872 + return [[] for _ in range(num_hypos)]
873 + generated_ngrams = [{} for _ in range(num_hypos)]
874 + for idx in range(num_hypos):
875 + gen_tokens = prev_input_ids[idx].tolist()
876 + generated_ngram = generated_ngrams[idx]
877 + for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
878 + prev_ngram_tuple = tuple(ngram[:-1])
879 + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
880 +
881 + def _get_generated_ngrams(hypo_idx):
882 + # Before decoding the next token, prevent decoding of ngrams that have already appeared
883 + start_idx = cur_len + 1 - no_repeat_ngram_size
884 + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
885 + return generated_ngrams[hypo_idx].get(ngram_idx, [])
886 +
887 + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
888 + return banned_tokens
889 +
890 +
891 +def calc_banned_bad_words_ids(prev_input_ids: List[List[int]], bad_words_ids: List[List[int]]) -> List[List[int]]:
892 + banned_tokens = []
893 +
894 + def _tokens_match(prev_tokens, tokens):
895 + if len(tokens) == 0:
896 + # if bad word tokens is just one token always ban it
897 + return True
898 + if len(tokens) > len(prev_tokens):
899 + # if bad word tokens are longer than prev tokens they can't be equal
900 + return False
901 +
902 + if prev_tokens[-len(tokens) :] == tokens:
903 + # if tokens match
904 + return True
905 + else:
906 + return False
907 +
908 + for prev_input_ids_slice in prev_input_ids:
909 + banned_tokens_slice = []
910 +
911 + for banned_token_seq in bad_words_ids:
912 + assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
913 + bad_words_ids
914 + )
915 +
916 + if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
917 + # if tokens do not match continue
918 + continue
919 +
920 + banned_tokens_slice.append(banned_token_seq[-1])
921 +
922 + banned_tokens.append(banned_tokens_slice)
923 +
924 + return banned_tokens
925 +
926 +
927 +def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
928 + """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
929 + a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
930 + Args:
931 + scores: logits distribution of shape (batch size, vocabulary size)
932 + banned_tokens: list of list of tokens to ban of length (batch_size)
933 + """
934 + banned_mask_list = []
935 + for idx, batch_banned_tokens in enumerate(banned_tokens):
936 + for token in batch_banned_tokens:
937 + banned_mask_list.append([idx, token])
938 + if not banned_mask_list:
939 + return
940 + banned_mask = torch.LongTensor(banned_mask_list)
941 + indices = torch.ones(len(banned_mask))
942 + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
943 + # [ 0 1 1 ]
944 + # [ 0 0 0 ]
945 + # [ 1 0 0 ]
946 +
947 + banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
948 + scores.masked_fill_(banned_mask, -float("inf"))
949 +
950 +
951 +def top_k_top_p_filtering(
952 + logits: Tensor,
953 + top_k: int = 0,
954 + top_p: float = 1.0,
955 + filter_value: float = -float("Inf"),
956 + min_tokens_to_keep: int = 1,
957 +) -> Tensor:
958 + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
959 + Args:
960 + logits: logits distribution shape (batch size, vocabulary size)
961 + if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
962 + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
963 + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
964 + Make sure we keep at least min_tokens_to_keep per batch example in the output
965 + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
966 + """
967 + if top_k > 0:
968 + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
969 + # Remove all tokens with a probability less than the last token of the top-k
970 + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
971 + logits[indices_to_remove] = filter_value
972 +
973 + if top_p < 1.0:
974 + sorted_logits, sorted_indices = torch.sort(logits, descending=True)
975 + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
976 +
977 + # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
978 + sorted_indices_to_remove = cumulative_probs > top_p
979 + if min_tokens_to_keep > 1:
980 + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
981 + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
982 + # Shift the indices to the right to keep also the first token above the threshold
983 + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
984 + sorted_indices_to_remove[..., 0] = 0
985 +
986 + # scatter sorted tensors to original indexing
987 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
988 + logits[indices_to_remove] = filter_value
989 + return logits
990 +
991 +
992 +class BeamHypotheses(object):
993 + def __init__(self, num_beams, max_length, length_penalty, early_stopping):
994 + """
995 + Initialize n-best list of hypotheses.
996 + """
997 + self.max_length = max_length - 1 # ignoring bos_token
998 + self.length_penalty = length_penalty
999 + self.early_stopping = early_stopping
1000 + self.num_beams = num_beams
1001 + self.beams = []
1002 + self.worst_score = 1e9
1003 +
1004 + def __len__(self):
1005 + """
1006 + Number of hypotheses in the list.
1007 + """
1008 + return len(self.beams)
1009 +
1010 + def add(self, hyp, sum_logprobs):
1011 + """
1012 + Add a new hypothesis to the list.
1013 + """
1014 + score = sum_logprobs / len(hyp) ** self.length_penalty
1015 + if len(self) < self.num_beams or score > self.worst_score:
1016 + self.beams.append((score, hyp))
1017 + if len(self) > self.num_beams:
1018 + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
1019 + del self.beams[sorted_scores[0][1]]
1020 + self.worst_score = sorted_scores[1][0]
1021 + else:
1022 + self.worst_score = min(score, self.worst_score)
1023 +
1024 + def is_done(self, best_sum_logprobs, cur_len):
1025 + """
1026 +        If there are enough hypotheses and none of the hypotheses being generated
1027 + can become better than the worst one in the heap, then we are done with this sentence.
1028 + """
1029 +
1030 + if len(self) < self.num_beams:
1031 + return False
1032 + elif self.early_stopping:
1033 + return True
1034 + else:
1035 + cur_score = best_sum_logprobs / cur_len ** self.length_penalty
1036 + ret = self.worst_score >= cur_score
1037 + return ret
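For reference, `BeamHypotheses` ranks finished beams by `sum_logprobs / len(hyp) ** self.length_penalty` (see `add()` above). A quick numeric sketch with illustrative values:

    # length_penalty = 1.0
    # hyp A: 10 tokens, sum log-prob = -8.0  -> score = -0.80
    # hyp B:  6 tokens, sum log-prob = -5.4  -> score = -0.90
    # A is ranked higher; a length_penalty above 1.0 shifts the ranking further toward longer hypotheses.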
...@@ -21,6 +21,8 @@ from transformers import (
 21  21     PretrainedConfig,
 22  22     PreTrainedTokenizer,
 23  23 )
     24 +from modeling_bart import BartForConditionalGeneration
     25 +
 24  26 from transformers.optimization import (
 25  27     Adafactor,
 26  28     get_cosine_schedule_with_warmup,
...@@ -40,7 +42,7 @@ MODEL_MODES = {
 40  42     "pretraining": AutoModelForPreTraining,
 41  43     "token-classification": AutoModelForTokenClassification,
 42  44     "language-modeling": AutoModelWithLMHead,
 43     -   "summarization": AutoModelForSeq2SeqLM,
     45 +   "summarization": BartForConditionalGeneration,
 44  46     "translation": AutoModelForSeq2SeqLM,
 45  47 }
 46  48
......
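With the `MODEL_MODES` change above, the summarization task now instantiates the locally customized `BartForConditionalGeneration` from `modeling_bart.py` (reproduced below) instead of the generic `AutoModelForSeq2SeqLM`, so the extra `patch_ids` argument can be accepted. A minimal sketch (the checkpoint name is only an example):

    model_cls = MODEL_MODES["summarization"]            # -> local BartForConditionalGeneration
    model = model_cls.from_pretrained("facebook/bart-base")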
1 +# coding=utf-8
2 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
3 +#
4 +# Licensed under the Apache License, Version 2.0 (the "License");
5 +# you may not use this file except in compliance with the License.
6 +# You may obtain a copy of the License at
7 +#
8 +# http://www.apache.org/licenses/LICENSE-2.0
9 +#
10 +# Unless required by applicable law or agreed to in writing, software
11 +# distributed under the License is distributed on an "AS IS" BASIS,
12 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 +# See the License for the specific language governing permissions and
14 +# limitations under the License.
15 +"""PyTorch BART model, ported from the fairseq repo."""
16 +import math
17 +import random
18 +import warnings
19 +from typing import Dict, List, Optional, Tuple
20 +
21 +import numpy as np
22 +import torch
23 +import torch.nn.functional as F
24 +from torch import Tensor, nn
25 +from torch.nn import CrossEntropyLoss
26 +
27 +from transformers.activations import ACT2FN
28 +from transformers.configuration_bart import BartConfig
29 +from transformers.file_utils import (
30 + add_code_sample_docstrings,
31 + add_end_docstrings,
32 + add_start_docstrings,
33 + add_start_docstrings_to_callable,
34 + replace_return_docstrings,
35 +)
36 +from transformers.modeling_outputs import (
37 + BaseModelOutput,
38 + BaseModelOutputWithPast,
39 + Seq2SeqLMOutput,
40 + Seq2SeqModelOutput,
41 + Seq2SeqQuestionAnsweringModelOutput,
42 + Seq2SeqSequenceClassifierOutput,
43 +)
44 +from modeling_utils import PreTrainedModel
45 +import logging
46 +
47 +logger = logging.getLogger(__name__) # pylint: disable=invalid-name
48 +logging.basicConfig(
49 + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
50 + datefmt="%m/%d/%Y %H:%M:%S",
51 + level=logging.INFO,
52 +)
53 +
54 +_CONFIG_FOR_DOC = "BartConfig"
55 +_TOKENIZER_FOR_DOC = "BartTokenizer"
56 +
57 +
58 +BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
59 + "facebook/bart-base",
60 + "facebook/bart-large",
61 + "facebook/bart-large-mnli",
62 + "facebook/bart-large-cnn",
63 + "facebook/bart-large-xsum",
64 + "facebook/mbart-large-en-ro",
65 +]
66 +# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
67 +
68 +
69 +BART_START_DOCSTRING = r"""
70 +
71 + This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
72 + refer to the PyTorch documentation for all matters related to general usage and behavior.
73 +
74 + Parameters:
75 + config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
76 + Initializing with a config file does not load the weights associated with the model, only the configuration.
77 + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
78 +
79 +"""
80 +BART_GENERATION_EXAMPLE = r"""
81 + Summarization example::
82 +
83 + from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
84 +
85 + # see ``examples/summarization/bart/run_eval.py`` for a longer example
86 + model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
87 + tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
88 +
89 + ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
90 + inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
91 +
92 + # Generate Summary
93 + summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
94 + print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
95 +
96 +"""
97 +
98 +BART_INPUTS_DOCSTRING = r"""
99 + Args:
100 + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
101 + Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
102 + Padding will be ignored by default should you provide it.
103 + Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
104 + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
105 + Mask to avoid performing attention on padding token indices in input_ids.
106 + Mask values selected in ``[0, 1]``:
107 + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
108 + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
109 + Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
110 + `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
111 + Used in the cross-attention of the decoder.
112 + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
113 + Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
114 + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
115 + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
116 + If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
117 + See diagram 1 in the paper for more info on the default strategy
118 + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
119 + Contains pre-computed key and value hidden-states of the attention blocks.
120 + Can be used to speed up decoding.
121 + If ``past_key_values`` are used, the user can optionally input only the last
122 + ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
123 + :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
124 + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
125 + If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
126 + ``past_key_values``).
127 + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
128 + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
129 + output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
130 + If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
131 + return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
132 + If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
133 + plain tuple.
134 +"""
135 +
136 +
137 +def invert_mask(attention_mask):
138 + """Turns 1->0, 0->1, False->True, True-> False"""
139 + assert attention_mask.dim() == 2
140 + return attention_mask.eq(0)
141 +
142 +
143 +def _prepare_bart_decoder_inputs(
144 + config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
145 +):
146 + """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
147 + none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
148 + Note: this is not called during generation
149 + """
150 + pad_token_id = config.pad_token_id
151 + if decoder_input_ids is None:
152 + decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
153 + bsz, tgt_len = decoder_input_ids.size()
154 + if decoder_padding_mask is None:
155 + decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
156 + else:
157 + decoder_padding_mask = invert_mask(decoder_padding_mask)
158 + if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
159 + # never mask leading token, even if it is pad
160 + decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
161 + causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
162 + dtype=causal_mask_dtype, device=decoder_input_ids.device
163 + )
164 + return decoder_input_ids, decoder_padding_mask, causal_mask
165 +
166 +
167 +class PretrainedBartModel(PreTrainedModel):
168 + config_class = BartConfig
169 + base_model_prefix = "model"
170 +
171 + def _init_weights(self, module):
172 + std = self.config.init_std
173 + if isinstance(module, nn.Linear):
174 + module.weight.data.normal_(mean=0.0, std=std)
175 + if module.bias is not None:
176 + module.bias.data.zero_()
177 + elif isinstance(module, SinusoidalPositionalEmbedding):
178 + pass
179 + elif isinstance(module, nn.Embedding):
180 + module.weight.data.normal_(mean=0.0, std=std)
181 + if module.padding_idx is not None:
182 + module.weight.data[module.padding_idx].zero_()
183 +
184 + @property
185 + def dummy_inputs(self):
186 + pad_token = self.config.pad_token_id
187 + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
188 + dummy_inputs = {
189 + "attention_mask": input_ids.ne(pad_token),
190 + "input_ids": input_ids,
191 + }
192 + return dummy_inputs
193 +
194 +
195 +def _make_linear_from_emb(emb):
196 + vocab_size, emb_size = emb.weight.shape
197 + lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
198 + lin_layer.weight.data = emb.weight.data
199 + return lin_layer
200 +
201 +
202 +# Helper Functions, mostly for making masks
203 +def _check_shapes(shape_1, shape2):
204 + if shape_1 != shape2:
205 + raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
206 +
207 +
208 +def shift_tokens_right(input_ids, pad_token_id):
209 + """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
210 + prev_output_tokens = input_ids.clone()
211 + index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
212 + prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
213 + prev_output_tokens[:, 1:] = input_ids[:, :-1]
214 + return prev_output_tokens
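# Worked example (hypothetical ids, pad_token_id = 1, eos = 2):
#   input_ids          = [[0, 31, 47, 2, 1]]
#   shift_tokens_right = [[2,  0, 31, 47, 2]]
# i.e. the last non-pad token (here eos) is wrapped around to position 0 and the
# remaining tokens are shifted one step to the right.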
215 +
216 +
217 +def make_padding_mask(input_ids, padding_idx=1):
218 + """True for pad tokens"""
219 + padding_mask = input_ids.eq(padding_idx)
220 + if not padding_mask.any():
221 + padding_mask = None
222 + return padding_mask
223 +
224 +
225 +# Helper Modules
226 +
227 +
228 +class EncoderLayer(nn.Module):
229 + def __init__(self, config: BartConfig):
230 + super().__init__()
231 + self.embed_dim = config.d_model
232 + self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
233 + self.normalize_before = config.normalize_before
234 + self.self_attn_layer_norm = LayerNorm(self.embed_dim)
235 + self.dropout = config.dropout
236 + self.activation_fn = ACT2FN[config.activation_function]
237 + self.activation_dropout = config.activation_dropout
238 + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
239 + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
240 + self.final_layer_norm = LayerNorm(self.embed_dim)
241 +
242 + def forward(self, x, encoder_padding_mask, output_attentions=False):
243 + """
244 + Args:
245 + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
246 + encoder_padding_mask (ByteTensor): binary ByteTensor of shape
247 + `(batch, src_len)` where padding elements are indicated by ``1``.
248 + Positions set to ``1`` are excluded (masked out) from attention,
249 + while positions set to ``0`` are included in attention.
250 +
251 + Returns:
252 + encoded output of shape `(seq_len, batch, embed_dim)`
253 + """
254 + residual = x
255 + if self.normalize_before:
256 + x = self.self_attn_layer_norm(x)
257 + x, attn_weights = self.self_attn(
258 + query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
259 + )
260 + x = F.dropout(x, p=self.dropout, training=self.training)
261 + x = residual + x
262 + if not self.normalize_before:
263 + x = self.self_attn_layer_norm(x)
264 +
265 + residual = x
266 + if self.normalize_before:
267 + x = self.final_layer_norm(x)
268 + x = self.activation_fn(self.fc1(x))
269 + x = F.dropout(x, p=self.activation_dropout, training=self.training)
270 + x = self.fc2(x)
271 + x = F.dropout(x, p=self.dropout, training=self.training)
272 + x = residual + x
273 + if not self.normalize_before:
274 + x = self.final_layer_norm(x)
275 + return x, attn_weights
276 +
277 +
278 +class BartEncoder(nn.Module):
279 + """
280 + Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer
281 + is a :class:`EncoderLayer`.
282 +
283 + Args:
284 + config: BartConfig
285 + """
286 +
287 + def __init__(self, config: BartConfig, embed_tokens):
288 + super().__init__()
289 +
290 + self.dropout = config.dropout
291 + self.layerdrop = config.encoder_layerdrop
292 +
293 + embed_dim = embed_tokens.embedding_dim
294 + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
295 + self.padding_idx = embed_tokens.padding_idx
296 + self.max_source_positions = config.max_position_embeddings
297 +
298 + self.embed_tokens = embed_tokens
299 + if config.static_position_embeddings:
300 + self.embed_positions = SinusoidalPositionalEmbedding(
301 + config.max_position_embeddings, embed_dim, self.padding_idx
302 + )
303 + else:
304 + self.embed_positions = LearnedPositionalEmbedding(
305 + config.max_position_embeddings,
306 + embed_dim,
307 + self.padding_idx,
308 + config.extra_pos_embeddings,
309 + )
310 + self.embed_patches = nn.Embedding(3, config.d_model)
311 + self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
312 + self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
313 + # mbart has one extra layer_norm
314 + self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
315 +
316 + def forward(
317 + self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
318 + ):
319 + """
320 + Args:
321 + input_ids (LongTensor): tokens in the source language of shape
322 + `(batch, src_len)`
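patch_ids (LongTensor): per-token patch ids of shape `(batch, src_len)`;
    each id is embedded via `self.embed_patches` and added to the token and
    position embeddings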
323 + attention_mask (torch.LongTensor): indicating which indices are padding tokens.
324 + Returns:
325 + BaseModelOutput or Tuple comprised of:
326 + - **x** (Tensor): the last encoder layer's output of
327 + shape `(src_len, batch, embed_dim)`
328 + - **encoder_states** (tuple(torch.FloatTensor)): all intermediate
329 + hidden states of shape `(src_len, batch, embed_dim)`.
330 + Only populated if *output_hidden_states* is True.
331 + - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
332 + During training might not be of length n_layers because of layer dropout.
333 + """
334 + # check attention mask and invert
335 + if attention_mask is not None:
336 + attention_mask = invert_mask(attention_mask)
337 +
338 + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
339 + embed_pos = self.embed_positions(input_ids)
340 + embed_patch = self.embed_patches(patch_ids)
341 + x = inputs_embeds + embed_pos + embed_patch
342 + x = self.layernorm_embedding(x)
343 + x = F.dropout(x, p=self.dropout, training=self.training)
344 +
345 + # B x T x C -> T x B x C
346 + x = x.transpose(0, 1)
347 +
348 + encoder_states = [] if output_hidden_states else None
349 + all_attentions = () if output_attentions else None
350 + for encoder_layer in self.layers:
351 + if output_hidden_states:
352 + encoder_states.append(x)
353 + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
354 + dropout_probability = random.uniform(0, 1)
355 + if self.training and (dropout_probability < self.layerdrop): # skip the layer
356 + attn = None
357 + else:
358 + x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
359 +
360 + if output_attentions:
361 + all_attentions = all_attentions + (attn,)
362 +
363 + if self.layer_norm:
364 + x = self.layer_norm(x)
365 + if output_hidden_states:
366 + encoder_states.append(x)
367 + # T x B x C -> B x T x C
368 + encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)
369 +
370 + # T x B x C -> B x T x C
371 + x = x.transpose(0, 1)
372 +
373 + if not return_dict:
374 + return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
375 + return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
376 +
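# A minimal sketch of calling the customized encoder (illustrative values only; the
# meaning of patch ids 0/1/2 is defined by the data pipeline, not by this module):
#   input_ids = torch.tensor([[0, 31, 47, 2]])
#   patch_ids = torch.zeros_like(input_ids)          # same shape as input_ids, values in {0, 1, 2}
#   attention_mask = input_ids.ne(config.pad_token_id).long()
#   encoder_out = encoder(input_ids, patch_ids, attention_mask=attention_mask)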
377 +
378 +class DecoderLayer(nn.Module):
379 + def __init__(self, config: BartConfig):
380 + super().__init__()
381 + self.embed_dim = config.d_model
382 +
383 + self.self_attn = Attention(
384 + embed_dim=self.embed_dim,
385 + num_heads=config.decoder_attention_heads,
386 + dropout=config.attention_dropout,
387 + )
388 + self.dropout = config.dropout
389 + self.activation_fn = ACT2FN[config.activation_function]
390 + self.activation_dropout = config.activation_dropout
391 + self.normalize_before = config.normalize_before
392 +
393 + self.self_attn_layer_norm = LayerNorm(self.embed_dim)
394 + self.encoder_attn = Attention(
395 + self.embed_dim,
396 + config.decoder_attention_heads,
397 + dropout=config.attention_dropout,
398 + encoder_decoder_attention=True,
399 + )
400 + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
401 + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
402 + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
403 + self.final_layer_norm = LayerNorm(self.embed_dim)
404 +
405 + def forward(
406 + self,
407 + x,
408 + encoder_hidden_states,
409 + encoder_attn_mask=None,
410 + layer_state=None,
411 + causal_mask=None,
412 + decoder_padding_mask=None,
413 + output_attentions=False,
414 + ):
415 + residual = x
416 +
417 + if layer_state is None:
418 + layer_state = {}
419 + if self.normalize_before:
420 + x = self.self_attn_layer_norm(x)
421 + # Self Attention
422 +
423 + x, self_attn_weights = self.self_attn(
424 + query=x,
425 + key=x,
426 + layer_state=layer_state, # adds keys to layer state
427 + key_padding_mask=decoder_padding_mask,
428 + attn_mask=causal_mask,
429 + output_attentions=output_attentions,
430 + )
431 + x = F.dropout(x, p=self.dropout, training=self.training)
432 + x = residual + x
433 + if not self.normalize_before:
434 + x = self.self_attn_layer_norm(x)
435 +
436 + # Cross attention
437 + residual = x
438 + assert self.encoder_attn.cache_key != self.self_attn.cache_key
439 + if self.normalize_before:
440 + x = self.encoder_attn_layer_norm(x)
441 + x, _ = self.encoder_attn(
442 + query=x,
443 + key=encoder_hidden_states,
444 + key_padding_mask=encoder_attn_mask,
445 + layer_state=layer_state, # mutates layer state
446 + )
447 + x = F.dropout(x, p=self.dropout, training=self.training)
448 + x = residual + x
449 + if not self.normalize_before:
450 + x = self.encoder_attn_layer_norm(x)
451 +
452 + # Fully Connected
453 + residual = x
454 + if self.normalize_before:
455 + x = self.final_layer_norm(x)
456 + x = self.activation_fn(self.fc1(x))
457 + x = F.dropout(x, p=self.activation_dropout, training=self.training)
458 + x = self.fc2(x)
459 + x = F.dropout(x, p=self.dropout, training=self.training)
460 + x = residual + x
461 + if not self.normalize_before:
462 + x = self.final_layer_norm(x)
463 + return (
464 + x,
465 + self_attn_weights,
466 + layer_state,
467 + ) # just self_attn weights for now, following t5, layer_state = cache for decoding
468 +
469 +
470 +class BartDecoder(nn.Module):
471 + """
472 + Transformer decoder consisting of *config.decoder_layers* layers. Each layer
473 + is a :class:`DecoderLayer`.
474 + Args:
475 + config: BartConfig
476 + embed_tokens (torch.nn.Embedding): output embedding
477 + """
478 +
479 + def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
480 + super().__init__()
481 + self.dropout = config.dropout
482 + self.layerdrop = config.decoder_layerdrop
483 + self.padding_idx = embed_tokens.padding_idx
484 + self.max_target_positions = config.max_position_embeddings
485 + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
486 + self.embed_tokens = embed_tokens
487 + if config.static_position_embeddings:
488 + self.embed_positions = SinusoidalPositionalEmbedding(
489 + config.max_position_embeddings, config.d_model, config.pad_token_id
490 + )
491 + else:
492 + self.embed_positions = LearnedPositionalEmbedding(
493 + config.max_position_embeddings,
494 + config.d_model,
495 + self.padding_idx,
496 + config.extra_pos_embeddings,
497 + )
498 + self.layers = nn.ModuleList(
499 + [DecoderLayer(config) for _ in range(config.decoder_layers)]
500 + ) # type: List[DecoderLayer]
501 + self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
502 + self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
503 +
504 + def forward(
505 + self,
506 + input_ids,
507 + encoder_hidden_states,
508 + encoder_padding_mask,
509 + decoder_padding_mask,
510 + decoder_causal_mask,
511 + past_key_values=None,
512 + use_cache=False,
513 + output_attentions=False,
514 + output_hidden_states=False,
515 + return_dict=False,
516 + **unused,
517 + ):
518 + """
519 + Includes several features from "Jointly Learning to Align and
520 + Translate with Transformer Models" (Garg et al., EMNLP 2019).
521 +
522 + Args:
523 + input_ids (LongTensor): previous decoder outputs of shape
524 + `(batch, tgt_len)`, for teacher forcing
525 + encoder_hidden_states: output from the encoder, used for
526 + encoder-side attention
527 + encoder_padding_mask: for ignoring pad tokens
528 + past_key_values (dict or None): dictionary used for storing state during generation
529 +
530 + Returns:
531 + BaseModelOutputWithPast or tuple:
532 + - the decoder's features of shape `(batch, tgt_len, embed_dim)`
533 + - the cache
534 + - hidden states
535 + - attentions
536 + """
537 + if "decoder_cached_states" in unused:
538 + warnings.warn(
539 + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
540 + FutureWarning,
541 + )
542 + past_key_values = unused.pop("decoder_cached_states")
543 + if "decoder_past_key_values" in unused:
544 + warnings.warn(
545 + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
546 + FutureWarning,
547 + )
548 + past_key_values = unused.pop("decoder_past_key_values")
549 +
550 + # check attention mask and invert
551 + if encoder_padding_mask is not None:
552 + encoder_padding_mask = invert_mask(encoder_padding_mask)
553 +
554 + # embed positions
555 + positions = self.embed_positions(input_ids, use_cache=use_cache)
556 +
557 + if use_cache:
558 + input_ids = input_ids[:, -1:]
559 + positions = positions[:, -1:] # happens after we embed them
560 + # assert input_ids.ne(self.padding_idx).any()
561 +
562 + x = self.embed_tokens(input_ids) * self.embed_scale
563 + x += positions
564 + x = self.layernorm_embedding(x)
565 + x = F.dropout(x, p=self.dropout, training=self.training)
566 +
567 + # Convert to time-first format for the decoder layers: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
568 + x = x.transpose(0, 1)
569 + encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
570 +
571 + # decoder layers
572 + all_hidden_states = () if output_hidden_states else None
573 + all_self_attns = () if output_attentions else None
574 + next_decoder_cache = []
575 + for idx, decoder_layer in enumerate(self.layers):
576 + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
577 + if output_hidden_states:
578 + all_hidden_states += (x,)
579 + dropout_probability = random.uniform(0, 1)
580 + if self.training and (dropout_probability < self.layerdrop):
581 + continue
582 +
583 + layer_state = past_key_values[idx] if past_key_values is not None else None
584 +
585 + x, layer_self_attn, layer_past = decoder_layer(
586 + x,
587 + encoder_hidden_states,
588 + encoder_attn_mask=encoder_padding_mask,
589 + decoder_padding_mask=decoder_padding_mask,
590 + layer_state=layer_state,
591 + causal_mask=decoder_causal_mask,
592 + output_attentions=output_attentions,
593 + )
594 +
595 + if use_cache:
596 + next_decoder_cache.append(layer_past.copy())
597 +
598 + if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
599 + x = self.layer_norm(x)
600 + if output_attentions:
601 + all_self_attns += (layer_self_attn,)
602 +
603 + # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
604 + if output_hidden_states:
605 + all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
606 + x = x.transpose(0, 1)
607 + encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
608 +
609 + next_cache = next_decoder_cache if use_cache else None
610 +
611 + if not return_dict:
612 + return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
613 + return BaseModelOutputWithPast(
614 + last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
615 + )
616 +
617 +
618 +def _reorder_buffer(attn_cache, new_order):
619 + for k, input_buffer_k in attn_cache.items():
620 + if input_buffer_k is not None:
621 + attn_cache[k] = input_buffer_k.index_select(0, new_order)
622 + return attn_cache
623 +
624 +
625 +class Attention(nn.Module):
626 + """Multi-headed attention from 'Attention Is All You Need' paper"""
627 +
628 + def __init__(
629 + self,
630 + embed_dim,
631 + num_heads,
632 + dropout=0.0,
633 + bias=True,
634 + encoder_decoder_attention=False, # otherwise self_attention
635 + ):
636 + super().__init__()
637 + self.embed_dim = embed_dim
638 + self.num_heads = num_heads
639 + self.dropout = dropout
640 + self.head_dim = embed_dim // num_heads
641 + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
642 + self.scaling = self.head_dim ** -0.5
643 +
644 + self.encoder_decoder_attention = encoder_decoder_attention
645 + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
646 + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
647 + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
648 + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
649 + self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
650 +
651 + def _shape(self, tensor, seq_len, bsz):
652 + return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
653 +
654 + def forward(
655 + self,
656 + query,
657 + key: Optional[Tensor],
658 + key_padding_mask: Optional[Tensor] = None,
659 + layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
660 + attn_mask: Optional[Tensor] = None,
661 + output_attentions=False,
662 + ) -> Tuple[Tensor, Optional[Tensor]]:
663 + """Input shape: Time(SeqLen) x Batch x Channel"""
664 + static_kv: bool = self.encoder_decoder_attention
665 + tgt_len, bsz, embed_dim = query.size()
666 + assert embed_dim == self.embed_dim
667 + assert list(query.size()) == [tgt_len, bsz, embed_dim]
668 + # this branch is reached for encoder-decoder attention because of static_kv
669 + if layer_state is not None: # reuse k,v and encoder_padding_mask
670 + saved_state = layer_state.get(self.cache_key, {})
671 + if "prev_key" in saved_state and static_kv:
672 + # previous time steps are cached - no need to recompute key and value if they are static
673 + key = None
674 + else:
675 + saved_state = None
676 + layer_state = {}
677 +
678 + q = self.q_proj(query) * self.scaling
679 + if static_kv:
680 + if key is None:
681 + k = v = None
682 + else:
683 + k = self.k_proj(key)
684 + v = self.v_proj(key)
685 + else:
686 + k = self.k_proj(query)
687 + v = self.v_proj(query)
688 +
689 + q = self._shape(q, tgt_len, bsz)
690 + if k is not None:
691 + k = self._shape(k, -1, bsz)
692 + if v is not None:
693 + v = self._shape(v, -1, bsz)
694 +
695 + if saved_state is not None:
696 + k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
697 +
698 + # Update cache
699 + layer_state[self.cache_key] = {
700 + "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
701 + "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
702 + "prev_key_padding_mask": key_padding_mask if not static_kv else None,
703 + }
704 +
705 + assert k is not None
706 + src_len = k.size(1)
707 + attn_weights = torch.bmm(q, k.transpose(1, 2))
708 + assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
709 +
710 + if attn_mask is not None:
711 + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
712 + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
713 +
714 + # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
715 + if key_padding_mask is not None and key_padding_mask.dim() == 0:
716 + key_padding_mask = None
717 + assert key_padding_mask is None or key_padding_mask.size()[:2] == (
718 + bsz,
719 + src_len,
720 + )
721 +
722 + if key_padding_mask is not None: # don't attend to padding symbols
723 + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
724 + reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
725 + attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
726 + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
727 + attn_weights = F.softmax(attn_weights, dim=-1)
728 + attn_probs = F.dropout(
729 + attn_weights,
730 + p=self.dropout,
731 + training=self.training,
732 + )
733 +
734 + assert v is not None
735 + attn_output = torch.bmm(attn_probs, v)
736 + assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
737 + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
738 + attn_output = self.out_proj(attn_output)
739 + if output_attentions:
740 + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
741 + else:
742 + attn_weights = None
743 + return attn_output, attn_weights
744 +
745 + def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
746 + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
747 + if "prev_key" in saved_state:
748 + _prev_key = saved_state["prev_key"]
749 + assert _prev_key is not None
750 + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
751 + if static_kv:
752 + k = prev_key
753 + else:
754 + assert k is not None
755 + k = torch.cat([prev_key, k], dim=1)
756 + if "prev_value" in saved_state:
757 + _prev_value = saved_state["prev_value"]
758 + assert _prev_value is not None
759 + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
760 + if static_kv:
761 + v = prev_value
762 + else:
763 + assert v is not None
764 + v = torch.cat([prev_value, v], dim=1)
765 + assert k is not None and v is not None
766 + prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
767 + if prev_key_padding_mask is not None:
768 + if static_kv:
769 + new_key_padding_mask = prev_key_padding_mask
770 + else:
771 + new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
772 + else:
773 + new_key_padding_mask = key_padding_mask
774 + return k, v, new_key_padding_mask
775 +
776 +
777 +class BartClassificationHead(nn.Module):
778 + """Head for sentence-level classification tasks."""
779 +
780 + # This can trivially be shared with RobertaClassificationHead
781 +
782 + def __init__(
783 + self,
784 + input_dim,
785 + inner_dim,
786 + num_classes,
787 + pooler_dropout,
788 + ):
789 + super().__init__()
790 + self.dense = nn.Linear(input_dim, inner_dim)
791 + self.dropout = nn.Dropout(p=pooler_dropout)
792 + self.out_proj = nn.Linear(inner_dim, num_classes)
793 +
794 + def forward(self, x):
795 + x = self.dropout(x)
796 + x = self.dense(x)
797 + x = torch.tanh(x)
798 + x = self.dropout(x)
799 + x = self.out_proj(x)
800 + return x
801 +
802 +
803 +class LearnedPositionalEmbedding(nn.Embedding):
804 + """
805 + This module learns positional embeddings up to a fixed maximum size.
806 + Padding ids are ignored by either offsetting based on padding_idx
807 + or by setting padding_idx to None and ensuring that the appropriate
808 + position ids are passed to the forward function.
809 + """
810 +
811 + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
812 + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
813 + # and adjust num_embeddings appropriately. Other models don't have this hack
814 + self.offset = offset
815 + assert padding_idx is not None
816 + num_embeddings += offset
817 + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
818 +
819 + def forward(self, input_ids, use_cache=False):
820 + """Input is expected to be of size [bsz x seqlen]."""
821 + bsz, seq_len = input_ids.shape[:2]
822 + if use_cache:
823 + positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
824 + else:
825 + # starts at 0, ends at seq_len - 1
826 + positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
827 + return super().forward(positions + self.offset)
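# For example, with offset = 2 (the BART convention noted above) a sequence of length 4
# uses embedding rows [2, 3, 4, 5]; rows 0 and 1 are never selected by this forward pass.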
828 +
829 +
830 +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
831 + if torch.cuda.is_available():
832 + try:
833 + from apex.normalization import FusedLayerNorm
834 +
835 + return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
836 + except ImportError:
837 + pass
838 + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
839 +
840 +
841 +def fill_with_neg_inf(t):
842 + """FP16-compatible function that fills a input_ids with -inf."""
843 + return t.float().fill_(float("-inf")).type_as(t)
844 +
845 +
846 +# Public API
847 +def _get_shape(t):
848 + return getattr(t, "shape", None)
849 +
850 +
851 +@add_start_docstrings(
852 + "The bare BART Model outputting raw hidden-states without any specific head on top.",
853 + BART_START_DOCSTRING,
854 +)
855 +class BartModel(PretrainedBartModel):
856 + def __init__(self, config: BartConfig):
857 + super().__init__(config)
858 +
859 + padding_idx, vocab_size = config.pad_token_id, config.vocab_size
860 + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
861 +
862 + self.encoder = BartEncoder(config, self.shared)
863 + self.decoder = BartDecoder(config, self.shared)
864 +
865 + self.init_weights()
866 +
867 + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
868 + @add_code_sample_docstrings(
869 + tokenizer_class=_TOKENIZER_FOR_DOC,
870 + checkpoint="facebook/bart-large",
871 + output_type=BaseModelOutputWithPast,
872 + config_class=_CONFIG_FOR_DOC,
873 + )
874 + def forward(
875 + self,
876 + input_ids,
877 + patch_ids=None,
878 + attention_mask=None,
879 + decoder_input_ids=None,
880 + encoder_outputs: Optional[Tuple] = None,
881 + decoder_attention_mask=None,
882 + past_key_values=None,
883 + use_cache=None,
884 + output_attentions=None,
885 + output_hidden_states=None,
886 + return_dict=None,
887 + **kwargs,
888 + ):
889 + if "decoder_past_key_values" in kwargs:
890 + warnings.warn(
891 + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
892 + FutureWarning,
893 + )
894 + past_key_values = kwargs.pop("decoder_past_key_values")
895 +
896 + if decoder_input_ids is None:
897 + use_cache = False
898 +
899 + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
900 + output_hidden_states = (
901 + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
902 + )
903 + use_cache = use_cache if use_cache is not None else self.config.use_cache
904 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict
905 +
906 + # make masks if user doesn't supply
907 + if not use_cache:
908 + decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
909 + self.config,
910 + input_ids,
911 + decoder_input_ids=decoder_input_ids,
912 + decoder_padding_mask=decoder_attention_mask,
913 + causal_mask_dtype=self.shared.weight.dtype,
914 + )
915 + else:
916 + decoder_padding_mask, causal_mask = None, None
917 +
918 + assert decoder_input_ids is not None
919 +
920 + if encoder_outputs is None:
921 + encoder_outputs = self.encoder(
922 + input_ids=input_ids,
923 + patch_ids=patch_ids,
924 + attention_mask=attention_mask,
925 + output_attentions=output_attentions,
926 + output_hidden_states=output_hidden_states,
927 + return_dict=return_dict,
928 + )
929 + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
930 + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
931 + encoder_outputs = BaseModelOutput(
932 + last_hidden_state=encoder_outputs[0],
933 + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
934 + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
935 + )
936 +
937 + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
938 + decoder_outputs = self.decoder(
939 + decoder_input_ids,
940 + encoder_outputs[0],
941 + attention_mask,
942 + decoder_padding_mask,
943 + decoder_causal_mask=causal_mask,
944 + past_key_values=past_key_values,
945 + use_cache=use_cache,
946 + output_attentions=output_attentions,
947 + output_hidden_states=output_hidden_states,
948 + return_dict=return_dict,
949 + )
950 +
951 + if not return_dict:
952 + return decoder_outputs + encoder_outputs
953 +
954 + return Seq2SeqModelOutput(
955 + last_hidden_state=decoder_outputs.last_hidden_state,
956 + past_key_values=decoder_outputs.past_key_values,
957 + decoder_hidden_states=decoder_outputs.hidden_states,
958 + decoder_attentions=decoder_outputs.attentions,
959 + encoder_last_hidden_state=encoder_outputs.last_hidden_state,
960 + encoder_hidden_states=encoder_outputs.hidden_states,
961 + encoder_attentions=encoder_outputs.attentions,
962 + )
963 +
964 + def get_input_embeddings(self):
965 + return self.shared
966 +
967 + def set_input_embeddings(self, value):
968 + self.shared = value
969 + self.encoder.embed_tokens = self.shared
970 + self.decoder.embed_tokens = self.shared
971 +
972 + def get_output_embeddings(self):
973 + return _make_linear_from_emb(self.shared) # make it on the fly
974 +
975 +
976 +@add_start_docstrings(
977 + "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
978 +)
979 +class BartForConditionalGeneration(PretrainedBartModel):
980 + base_model_prefix = "model"
981 + authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
982 +
983 + def __init__(self, config: BartConfig):
984 + super().__init__(config)
985 + base_model = BartModel(config)
986 + self.model = base_model
987 + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
988 +
989 + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
990 + old_num_tokens = self.model.shared.num_embeddings
991 + new_embeddings = super().resize_token_embeddings(new_num_tokens)
992 + self.model.shared = new_embeddings
993 + self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
994 + return new_embeddings
995 +
996 + def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
997 + if new_num_tokens <= old_num_tokens:
998 + new_bias = self.final_logits_bias[:, :new_num_tokens]
999 + else:
1000 + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1001 + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1002 + self.register_buffer("final_logits_bias", new_bias)
1003 +
1004 + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1005 + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
1006 + @add_end_docstrings(BART_GENERATION_EXAMPLE)
1007 + def forward(
1008 + self,
1009 + input_ids,
1010 + patch_ids,
1011 + attention_mask=None,
1012 + encoder_outputs=None,
1013 + decoder_input_ids=None,
1014 + decoder_attention_mask=None,
1015 + past_key_values=None,
1016 + labels=None,
1017 + use_cache=None,
1018 + output_attentions=None,
1019 + output_hidden_states=None,
1020 + return_dict=None,
1021 + **unused,
1022 + ):
1023 + r"""
1024 + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1025 + Labels for computing the masked language modeling loss.
1026 + Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
1027 + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
1028 + with labels in ``[0, ..., config.vocab_size]``.
1029 +
1030 + Returns:
1031 +
1032 + Conditional generation example::
1033 +
1034 + # Mask filling only works for bart-large
1035 + from transformers import BartTokenizer, BartForConditionalGeneration
1036 + tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
1037 + TXT = "My friends are <mask> but they eat too many carbs."
1038 +
1039 + model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
1040 + input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
1041 + logits = model(input_ids).logits
1042 +
1043 + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
1044 + probs = logits[0, masked_index].softmax(dim=0)
1045 + values, predictions = probs.topk(5)
1046 +
1047 + tokenizer.decode(predictions).split()
1048 + # ['good', 'great', 'all', 'really', 'very']
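# NOTE: in this customized class `patch_ids` is a required argument, so the upstream
# example above additionally needs an aligned patch id tensor, e.g. (assuming all-zero
# ids are acceptable for plain text):
#   model(input_ids, patch_ids=torch.zeros_like(input_ids)).logits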
1049 + """
1050 + if "lm_labels" in unused:
1051 + warnings.warn(
1052 + "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1053 + FutureWarning,
1054 + )
1055 + labels = unused.pop("lm_labels")
1056 + if "decoder_cached_states" in unused:
1057 + warnings.warn(
1058 + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
1059 + FutureWarning,
1060 + )
1061 + past_key_values = unused.pop("decoder_cached_states")
1062 + if "decoder_past_key_values" in unused:
1063 + warnings.warn(
1064 + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
1065 + FutureWarning,
1066 + )
1067 + past_key_values = unused.pop("decoder_past_key_values")
1068 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1069 +
1070 + if labels is not None:
1071 + use_cache = False
1072 + if decoder_input_ids is None:
1073 + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
1074 +
1075 + outputs = self.model(
1076 + input_ids,
1077 + patch_ids=patch_ids,
1078 + attention_mask=attention_mask,
1079 + decoder_input_ids=decoder_input_ids,
1080 + encoder_outputs=encoder_outputs,
1081 + decoder_attention_mask=decoder_attention_mask,
1082 + past_key_values=past_key_values,
1083 + use_cache=use_cache,
1084 + output_attentions=output_attentions,
1085 + output_hidden_states=output_hidden_states,
1086 + return_dict=return_dict,
1087 + )
1088 + lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
1089 +
1090 + masked_lm_loss = None
1091 + if labels is not None:
1092 + loss_fct = CrossEntropyLoss()
1093 + # TODO(SS): do we need to ignore pad tokens in labels?
1094 + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1095 +
1096 + if not return_dict:
1097 + output = (lm_logits,) + outputs[1:]
1098 + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1099 +
1100 + return Seq2SeqLMOutput(
1101 + loss=masked_lm_loss,
1102 + logits=lm_logits,
1103 + past_key_values=outputs.past_key_values,
1104 + decoder_hidden_states=outputs.decoder_hidden_states,
1105 + decoder_attentions=outputs.decoder_attentions,
1106 + encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1107 + encoder_hidden_states=outputs.encoder_hidden_states,
1108 + encoder_attentions=outputs.encoder_attentions,
1109 + )
1110 +
1111 + def prepare_inputs_for_generation(
1112 + self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
1113 + ):
1114 + return {
1115 + "input_ids": None, # encoder_outputs is defined. input_ids not needed
1116 + "encoder_outputs": encoder_outputs,
1117 + "past_key_values": past,
1118 + "decoder_input_ids": decoder_input_ids,
1119 + "attention_mask": attention_mask,
1120 + "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1121 + }
1122 +
1123 + def adjust_logits_during_generation(self, logits, cur_len, max_length):
1124 + if cur_len == 1 and self.config.force_bos_token_to_be_generated:
1125 + self._force_token_ids_generation(logits, self.config.bos_token_id)
1126 + elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
1127 + self._force_token_ids_generation(logits, self.config.eos_token_id)
1128 + return logits
1129 +
1130 + def _force_token_ids_generation(self, scores, token_id) -> None:
1131 + """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
1132 + scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
1133 +
1134 + @staticmethod
1135 + def _reorder_cache(past, beam_idx):
1136 + reordered_past = []
1137 + for layer_past in past:
1138 + # get the correct batch idx from decoder layer's batch dim for cross and self-attn
1139 + layer_past_new = {
1140 + attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
1141 + }
1142 + reordered_past.append(layer_past_new)
1143 + return reordered_past
1144 +
1145 + def get_encoder(self):
1146 + return self.model.encoder
1147 +
1148 + def get_output_embeddings(self):
1149 + return _make_linear_from_emb(self.model.shared) # make it on the fly
1150 +
1151 +
1152 +@add_start_docstrings(
1153 + """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
1154 + BART_START_DOCSTRING,
1155 +)
1156 +class BartForSequenceClassification(PretrainedBartModel):
1157 + def __init__(self, config: BartConfig, **kwargs):
1158 + super().__init__(config, **kwargs)
1159 + self.model = BartModel(config)
1160 + self.classification_head = BartClassificationHead(
1161 + config.d_model,
1162 + config.d_model,
1163 + config.num_labels,
1164 + config.classif_dropout,
1165 + )
1166 + self.model._init_weights(self.classification_head.dense)
1167 + self.model._init_weights(self.classification_head.out_proj)
1168 +
1169 + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1170 + @add_code_sample_docstrings(
1171 + tokenizer_class=_TOKENIZER_FOR_DOC,
1172 + checkpoint="facebook/bart-large",
1173 + output_type=Seq2SeqSequenceClassifierOutput,
1174 + config_class=_CONFIG_FOR_DOC,
1175 + )
1176 + def forward(
1177 + self,
1178 + input_ids,
1179 + attention_mask=None,
1180 + encoder_outputs=None,
1181 + decoder_input_ids=None,
1182 + decoder_attention_mask=None,
1183 + labels=None,
1184 + use_cache=None,
1185 + output_attentions=None,
1186 + output_hidden_states=None,
1187 + return_dict=None,
1188 + ):
1189 + r"""
1190 + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1191 + Labels for computing the sequence classification/regression loss.
1192 + Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1193 + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1194 + """
1195 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1196 + if labels is not None:
1197 + use_cache = False
1198 +
1199 + outputs = self.model(
1200 + input_ids,
1201 + attention_mask=attention_mask,
1202 + decoder_input_ids=decoder_input_ids,
1203 + decoder_attention_mask=decoder_attention_mask,
1204 + encoder_outputs=encoder_outputs,
1205 + use_cache=use_cache,
1206 + output_attentions=output_attentions,
1207 + output_hidden_states=output_hidden_states,
1208 + return_dict=return_dict,
1209 + )
1210 + x = outputs[0] # last hidden state
1211 + eos_mask = input_ids.eq(self.config.eos_token_id)
1212 + if len(torch.unique(eos_mask.sum(1))) > 1:
1213 + raise ValueError("All examples must have the same number of <eos> tokens.")
1214 + sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
1215 + logits = self.classification_head(sentence_representation)
1216 +
1217 + loss = None
1218 + if labels is not None:
1219 + loss_fct = CrossEntropyLoss()
1220 + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1221 +
1222 + if not return_dict:
1223 + output = (logits,) + outputs[1:]
1224 + return ((loss,) + output) if loss is not None else output
1225 +
1226 + return Seq2SeqSequenceClassifierOutput(
1227 + loss=loss,
1228 + logits=logits,
1229 + past_key_values=outputs.past_key_values,
1230 + decoder_hidden_states=outputs.decoder_hidden_states,
1231 + decoder_attentions=outputs.decoder_attentions,
1232 + encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1233 + encoder_hidden_states=outputs.encoder_hidden_states,
1234 + encoder_attentions=outputs.encoder_attentions,
1235 + )
1236 +
1237 +
1238 +@add_start_docstrings(
1239 + """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
1240 + the hidden-states output to compute `span start logits` and `span end logits`). """,
1241 + BART_START_DOCSTRING,
1242 +)
1243 +class BartForQuestionAnswering(PretrainedBartModel):
1244 + def __init__(self, config):
1245 + super().__init__(config)
1246 +
1247 + config.num_labels = 2
1248 + self.num_labels = config.num_labels
1249 +
1250 + self.model = BartModel(config)
1251 + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1252 +
1253 + self.model._init_weights(self.qa_outputs)
1254 +
1255 + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1256 + @add_code_sample_docstrings(
1257 + tokenizer_class=_TOKENIZER_FOR_DOC,
1258 + checkpoint="facebook/bart-large",
1259 + output_type=Seq2SeqQuestionAnsweringModelOutput,
1260 + config_class=_CONFIG_FOR_DOC,
1261 + )
1262 + def forward(
1263 + self,
1264 + input_ids,
1265 + attention_mask=None,
1266 + encoder_outputs=None,
1267 + decoder_input_ids=None,
1268 + decoder_attention_mask=None,
1269 + start_positions=None,
1270 + end_positions=None,
1271 + use_cache=None,
1272 + output_attentions=None,
1273 + output_hidden_states=None,
1274 + return_dict=None,
1275 + ):
1276 + r"""
1277 + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1278 + Labels for position (index) of the start of the labelled span for computing the token classification loss.
1279 + Positions are clamped to the length of the sequence (`sequence_length`).
1280 + Positions outside of the sequence are not taken into account for computing the loss.
1281 + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1282 + Labels for position (index) of the end of the labelled span for computing the token classification loss.
1283 + Positions are clamped to the length of the sequence (`sequence_length`).
1284 + Positions outside of the sequence are not taken into account for computing the loss.
1285 + """
1286 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1287 + if start_positions is not None and end_positions is not None:
1288 + use_cache = False
1289 +
1290 + outputs = self.model(
1291 + input_ids,
1292 + attention_mask=attention_mask,
1293 + decoder_input_ids=decoder_input_ids,
1294 + decoder_attention_mask=decoder_attention_mask,
1295 + encoder_outputs=encoder_outputs,
1296 + use_cache=use_cache,
1297 + output_attentions=output_attentions,
1298 + output_hidden_states=output_hidden_states,
1299 + return_dict=return_dict,
1300 + )
1301 +
1302 + sequence_output = outputs[0]
1303 +
1304 + logits = self.qa_outputs(sequence_output)
1305 + start_logits, end_logits = logits.split(1, dim=-1)
1306 + start_logits = start_logits.squeeze(-1)
1307 + end_logits = end_logits.squeeze(-1)
1308 +
1309 + total_loss = None
1310 + if start_positions is not None and end_positions is not None:
1311 + # If we are on multi-GPU, splitting adds a dimension; squeeze it
1312 + if len(start_positions.size()) > 1:
1313 + start_positions = start_positions.squeeze(-1)
1314 + if len(end_positions.size()) > 1:
1315 + end_positions = end_positions.squeeze(-1)
1316 + # sometimes the start/end positions are outside our model inputs, we ignore these terms
1317 + ignored_index = start_logits.size(1)
1318 + start_positions.clamp_(0, ignored_index)
1319 + end_positions.clamp_(0, ignored_index)
1320 +
1321 + loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1322 + start_loss = loss_fct(start_logits, start_positions)
1323 + end_loss = loss_fct(end_logits, end_positions)
1324 + total_loss = (start_loss + end_loss) / 2
1325 +
1326 + if not return_dict:
1327 + output = (
1328 + start_logits,
1329 + end_logits,
1330 + ) + outputs[1:]
1331 + return ((total_loss,) + output) if total_loss is not None else output
1332 +
1333 + return Seq2SeqQuestionAnsweringModelOutput(
1334 + loss=total_loss,
1335 + start_logits=start_logits,
1336 + end_logits=end_logits,
1337 + past_key_values=outputs.past_key_values,
1338 + decoder_hidden_states=outputs.decoder_hidden_states,
1339 + decoder_attentions=outputs.decoder_attentions,
1340 + encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1341 + encoder_hidden_states=outputs.encoder_hidden_states,
1342 + encoder_attentions=outputs.encoder_attentions,
1343 + )
1344 +
1345 +
1346 +class SinusoidalPositionalEmbedding(nn.Embedding):
1347 + """This module produces sinusoidal positional embeddings of any length."""
1348 +
1349 + def __init__(self, num_positions, embedding_dim, padding_idx=None):
1350 + super().__init__(num_positions, embedding_dim)
1351 + if embedding_dim % 2 != 0:
1352 + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
1353 + self.weight = self._init_weight(self.weight)
1354 +
1355 + @staticmethod
1356 + def _init_weight(out: nn.Parameter):
1357 + """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
1358 + The cos features are in the 2nd half of the vector. [dim // 2:]
1359 + """
1360 + n_pos, dim = out.shape
1361 + position_enc = np.array(
1362 + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
1363 + )
1364 + out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
1365 + out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1366 + out.detach_()
1367 + out.requires_grad = False
1368 + return out
1369 +
1370 + @torch.no_grad()
1371 + def forward(self, input_ids, use_cache=False):
1372 + """Input is expected to be of size [bsz x seqlen]."""
1373 + bsz, seq_len = input_ids.shape[:2]
1374 + if use_cache:
1375 + positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
1376 + else:
1377 + # starts at 0, ends at seq_len - 1
1378 + positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
1379 + return super().forward(positions)
1 +# coding=utf-8
2 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 +#
5 +# Licensed under the Apache License, Version 2.0 (the "License");
6 +# you may not use this file except in compliance with the License.
7 +# You may obtain a copy of the License at
8 +#
9 +# http://www.apache.org/licenses/LICENSE-2.0
10 +#
11 +# Unless required by applicable law or agreed to in writing, software
12 +# distributed under the License is distributed on an "AS IS" BASIS,
13 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 +# See the License for the specific language governing permissions and
15 +# limitations under the License.
16 +
17 +import inspect
18 +import os
19 +import re
20 +from dataclasses import dataclass
21 +from typing import Callable, Dict, List, Optional, Set, Tuple, Union
22 +
23 +import torch
24 +from torch import Tensor, device, dtype, nn
25 +from torch.nn import CrossEntropyLoss
26 +from torch.nn import functional as F
27 +
28 +from transformers.activations import get_activation
29 +from transformers.configuration_utils import PretrainedConfig
30 +from transformers.file_utils import (
31 + DUMMY_INPUTS,
32 + TF2_WEIGHTS_NAME,
33 + TF_WEIGHTS_NAME,
34 + WEIGHTS_NAME,
35 + ModelOutput,
36 + cached_path,
37 + hf_bucket_url,
38 + is_remote_url,
39 + is_torch_tpu_available,
40 + replace_return_docstrings,
41 +)
42 +from generation_utils import GenerationMixin
43 +import logging
44 +
45 +logger = logging.getLogger(__name__) # pylint: disable=invalid-name
46 +logging.basicConfig(
47 + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
48 + datefmt="%m/%d/%Y %H:%M:%S",
49 + level=logging.INFO,
50 +)
51 +
52 +
53 +try:
54 + from torch.nn import Identity
55 +except ImportError:
56 + # Older PyTorch compatibility
57 + class Identity(nn.Module):
58 + r"""A placeholder identity operator that is argument-insensitive."""
59 +
60 + def __init__(self, *args, **kwargs):
61 + super().__init__()
62 +
63 + def forward(self, input):
64 + return input
65 +
66 +
67 +def find_pruneable_heads_and_indices(
68 + heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
69 +) -> Tuple[Set[int], torch.LongTensor]:
70 + """
71 + Finds the heads and their indices taking :obj:`already_pruned_heads` into account.
72 +
73 + Args:
74 + heads (:obj:`List[int]`): List of the indices of heads to prune.
75 + n_heads (:obj:`int`): The number of heads in the model.
76 + head_size (:obj:`int`): The size of each head.
77 + already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads.
78 +
79 + Returns:
80 + :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
81 + """
82 + mask = torch.ones(n_heads, head_size)
83 + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
84 + for head in heads:
85 + # Compute how many pruned heads are before the head and move the index accordingly
86 + head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
87 + mask[head] = 0
88 + mask = mask.view(-1).contiguous().eq(1)
89 + index: torch.LongTensor = torch.arange(len(mask))[mask].long()
90 + return heads, index
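# For example (hypothetical sizes): heads=[2], n_heads=4, head_size=64 and no heads
# pruned yet returns ({2}, a LongTensor holding the 192 flat indices of the 3 kept heads).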
91 +
92 +
93 +class ModuleUtilsMixin:
94 + """
95 + A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
96 + """
97 +
98 + def num_parameters(self, only_trainable: bool = False) -> int:
99 + """
100 + Get the number of (optionally, trainable) parameters in the model.
101 +
102 + Args:
103 + only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
104 + Whether or not to return only the number of trainable parameters
105 +
106 + Returns:
107 + :obj:`int`: The number of parameters.
108 + """
109 + params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
110 + return sum(p.numel() for p in params)
111 +
112 + @staticmethod
113 + def _hook_rss_memory_pre_forward(module, *args, **kwargs):
114 + try:
115 + import psutil
116 + except (ImportError):
117 + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
118 +
119 + process = psutil.Process(os.getpid())
120 + mem = process.memory_info()
121 + module.mem_rss_pre_forward = mem.rss
122 + return None
123 +
124 + @staticmethod
125 + def _hook_rss_memory_post_forward(module, *args, **kwargs):
126 + try:
127 + import psutil
128 + except (ImportError):
129 + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
130 +
131 + process = psutil.Process(os.getpid())
132 + mem = process.memory_info()
133 + module.mem_rss_post_forward = mem.rss
134 + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
135 + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
136 + return None
137 +
138 + def add_memory_hooks(self):
139 + """
140 + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
141 +
142 + Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
143 + zero with :obj:`model.reset_memory_hooks_state()`.
144 + """
145 + for module in self.modules():
146 + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
147 + module.register_forward_hook(self._hook_rss_memory_post_forward)
148 + self.reset_memory_hooks_state()
149 +
150 + def reset_memory_hooks_state(self):
151 + """
152 + Reset the :obj:`mem_rss_diff` attribute of each module (see
153 + :func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`).
154 + """
155 + for module in self.modules():
156 + module.mem_rss_diff = 0
157 + module.mem_rss_post_forward = 0
158 + module.mem_rss_pre_forward = 0
159 +
160 + @property
161 + def device(self) -> device:
162 + """
163 + :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
164 + device).
165 + """
166 + try:
167 + return next(self.parameters()).device
168 + except StopIteration:
169 + # For nn.DataParallel compatibility in PyTorch 1.5
170 +
171 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
172 + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
173 + return tuples
174 +
175 + gen = self._named_members(get_members_fn=find_tensor_attributes)
176 + first_tuple = next(gen)
177 + return first_tuple[1].device
178 +
179 + @property
180 + def dtype(self) -> dtype:
181 + """
182 + :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
183 + """
184 + try:
185 + return next(self.parameters()).dtype
186 + except StopIteration:
187 + # For nn.DataParallel compatibility in PyTorch 1.5
188 +
189 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
190 + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
191 + return tuples
192 +
193 + gen = self._named_members(get_members_fn=find_tensor_attributes)
194 + first_tuple = next(gen)
195 + return first_tuple[1].dtype
196 +
197 + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
198 + """
199 + Invert an attention mask (e.g., switches 0. and 1.).
200 +
201 + Args:
202 + encoder_attention_mask (:obj:`torch.Tensor`): An attention mask.
203 +
204 + Returns:
205 + :obj:`torch.Tensor`: The inverted attention mask.
206 + """
207 + if encoder_attention_mask.dim() == 3:
208 + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
209 + if encoder_attention_mask.dim() == 2:
210 + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
211 + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
212 + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
213 + # /transformer/transformer_layers.py#L270
214 + # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
215 + # encoder_extended_attention_mask.transpose(-1, -2))
216 + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
217 +
218 + if self.dtype == torch.float16:
219 + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
220 + elif self.dtype == torch.float32:
221 + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
222 + else:
223 + raise ValueError(
224 + "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
225 + self.dtype
226 + )
227 + )
228 +
229 + return encoder_extended_attention_mask
230 +
231 + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
232 + """
233 + Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
234 +
235 + Arguments:
236 + attention_mask (:obj:`torch.Tensor`):
237 + Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
238 + input_shape (:obj:`Tuple[int]`):
239 + The shape of the input to the model.
240 + device: (:obj:`torch.device`):
241 + The device of the input to the model.
242 +
243 + Returns:
244 + :obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
245 + """
246 + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
247 + # ourselves in which case we just need to make it broadcastable to all heads.
248 + if attention_mask.dim() == 3:
249 + extended_attention_mask = attention_mask[:, None, :, :]
250 + elif attention_mask.dim() == 2:
251 + # Provided a padding mask of dimensions [batch_size, seq_length]
252 + # - if the model is a decoder, apply a causal mask in addition to the padding mask
253 + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
254 + if self.config.is_decoder:
255 + batch_size, seq_length = input_shape
256 + seq_ids = torch.arange(seq_length, device=device)
257 + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
258 + # causal and attention masks must have same type with pytorch version < 1.3
259 + causal_mask = causal_mask.to(attention_mask.dtype)
260 + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
261 + else:
262 + extended_attention_mask = attention_mask[:, None, None, :]
263 + else:
264 + raise ValueError(
265 + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
266 + input_shape, attention_mask.shape
267 + )
268 + )
269 +
270 + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
271 + # masked positions, this operation will create a tensor which is 0.0 for
272 + # positions we want to attend and -10000.0 for masked positions.
273 + # Since we are adding it to the raw scores before the softmax, this is
274 + # effectively the same as removing these entirely.
275 + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
276 + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
277 + return extended_attention_mask
278 +
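# Usage sketch (illustrative, assuming an encoder-style model, i.e. `config.is_decoder` is False):
#   mask = torch.tensor([[1, 1, 1, 0]])
#   ext = model.get_extended_attention_mask(mask, mask.shape, mask.device)
#   # ext.shape == (1, 1, 1, 4); attended positions are 0.0, the padded position is -10000.0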
279 + def get_head_mask(
280 + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
281 + ) -> Tensor:
282 + """
283 + Prepare the head mask if needed.
284 +
285 + Args:
286 + head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
287 + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
288 + num_hidden_layers (:obj:`int`):
289 + The number of hidden layers in the model.
290 + is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`):
291 + Whether or not the attention scores are computed by chunks.
292 +
293 + Returns:
294 + :obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]`
295 + or list with :obj:`[None]` for each layer.
296 + """
297 + if head_mask is not None:
298 + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
299 + if is_attention_chunked is True:
300 + head_mask = head_mask.unsqueeze(-1)
301 + else:
302 + head_mask = [None] * num_hidden_layers
303 +
304 + return head_mask
305 +
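# Usage sketch (illustrative, assuming a model with 12 layers and 4 attention heads):
#   head_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
#   mask = model.get_head_mask(head_mask, num_hidden_layers=12)
#   # mask.shape == (12, 1, 4, 1, 1); passing head_mask=None returns [None] * 12 instead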
306 + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
307 + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
308 + if head_mask.dim() == 1:
309 + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
310 + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
311 + elif head_mask.dim() == 2:
312 + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
313 + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
314 + head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility
315 + return head_mask
316 +
317 +
318 +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
319 + r"""
320 + Base class for all models.
321 +
322 + :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
323 + for loading, downloading and saving models as well as a few methods common to all models to:
324 +
325 + * resize the input embeddings,
326 + * prune heads in the self-attention heads.
327 +
328 + Class attributes (overridden by derived classes):
329 + - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
330 + :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
331 + - **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
332 + PyTorch model, taking as arguments:
333 +
334 + - **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
335 + TensorFlow checkpoint.
336 + - **config** (:class:`~transformers.PreTrainedConfig`) -- An instance of the configuration associated
337 + to the model.
338 + - **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.
339 +
340 + - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
341 + derived classes of the same architecture adding modules on top of the base model.
342 + - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of regex patterns of tensor names to ignore
343 + when loading the model (to avoid unnecessary warnings).
344 + """
345 + config_class = None
346 + base_model_prefix = ""
347 + authorized_missing_keys = None
348 +
349 + @property
350 + def dummy_inputs(self) -> Dict[str, torch.Tensor]:
351 + """
352 + :obj:`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
353 + """
354 + return {"input_ids": torch.tensor(DUMMY_INPUTS)}
355 +
356 + def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
357 + super().__init__()
358 + if not isinstance(config, PretrainedConfig):
359 + raise ValueError(
360 + "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
361 + "To create a model from a pretrained model use "
362 + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
363 + self.__class__.__name__, self.__class__.__name__
364 + )
365 + )
366 + # Save config in model
367 + self.config = config
368 +
369 + @property
370 + def base_model(self) -> nn.Module:
371 + """
372 + :obj:`torch.nn.Module`: The main body of the model.
373 + """
374 + return getattr(self, self.base_model_prefix, self)
375 +
376 + def get_input_embeddings(self) -> nn.Module:
377 + """
378 + Returns the model's input embeddings.
379 +
380 + Returns:
381 + :obj:`nn.Module`: A torch module mapping vocabulary to hidden states.
382 + """
383 + base_model = getattr(self, self.base_model_prefix, self)
384 + if base_model is not self:
385 + return base_model.get_input_embeddings()
386 + else:
387 + raise NotImplementedError
388 +
389 + def set_input_embeddings(self, value: nn.Module):
390 + """
391 + Set the model's input embeddings.
392 +
393 + Args:
394 + value (:obj:`nn.Module`): A module mapping vocabulary to hidden states.
395 + """
396 + base_model = getattr(self, self.base_model_prefix, self)
397 + if base_model is not self:
398 + base_model.set_input_embeddings(value)
399 + else:
400 + raise NotImplementedError
401 +
402 + def get_output_embeddings(self) -> nn.Module:
403 + """
404 + Returns the model's output embeddings.
405 +
406 + Returns:
407 + :obj:`nn.Module`: A torch module mapping hidden states to vocabulary.
408 + """
409 + return None # Overwrite for models with output embeddings
410 +
411 + def tie_weights(self):
412 + """
413 + Tie the weights between the input embeddings and the output embeddings.
414 +
415 + If the :obj:`torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing, so we
416 + clone the weights instead.
417 + """
418 + output_embeddings = self.get_output_embeddings()
419 + if output_embeddings is not None and self.config.tie_word_embeddings:
420 + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
421 +
422 + if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
423 + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
424 +
425 + @staticmethod
426 + def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
427 + uninitialized_encoder_weights: List[str] = []
428 + assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
429 +
430 + def tie_encoder_to_decoder_recursively(
431 + decoder_pointer: nn.Module,
432 + encoder_pointer: nn.Module,
433 + module_name: str,
434 + uninitialized_encoder_weights: List[str],
435 + depth=0,
436 + ):
437 + assert isinstance(decoder_pointer, nn.Module) and isinstance(
438 + encoder_pointer, nn.Module
439 + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
440 + if hasattr(decoder_pointer, "weight"):
441 + assert hasattr(encoder_pointer, "weight")
442 + encoder_pointer.weight = decoder_pointer.weight
443 + if hasattr(decoder_pointer, "bias"):
444 + assert hasattr(encoder_pointer, "bias")
445 + encoder_pointer.bias = decoder_pointer.bias
446 + return
447 +
448 + encoder_modules = encoder_pointer._modules
449 + decoder_modules = decoder_pointer._modules
450 + if len(decoder_modules) > 0:
451 + assert (
452 + len(encoder_modules) > 0
453 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
454 +
455 + all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
456 + encoder_layer_pos = 0
457 + for name, module in decoder_modules.items():
458 + if name.isdigit():
459 + encoder_name = str(int(name) + encoder_layer_pos)
460 + decoder_name = name
461 + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
462 + # this can happen if the name corresponds to the position in a list of layers (an nn.ModuleList);
463 + # in this case the decoder has added a cross-attention layer that the encoder does not have,
464 + # so skip this step and subtract one layer position from the encoder
465 + encoder_layer_pos -= 1
466 + continue
467 + elif name not in encoder_modules:
468 + continue
469 + elif depth > 500:
470 + raise ValueError(
471 + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
472 + )
473 + else:
474 + decoder_name = encoder_name = name
475 + tie_encoder_to_decoder_recursively(
476 + decoder_modules[decoder_name],
477 + encoder_modules[encoder_name],
478 + module_name + "/" + name,
479 + uninitialized_encoder_weights,
480 + depth=depth + 1,
481 + )
482 + all_encoder_weights.remove(module_name + "/" + encoder_name)
483 +
484 + uninitialized_encoder_weights += list(all_encoder_weights)
485 +
486 + # tie weights recursively
487 + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
488 + if len(uninitialized_encoder_weights) > 0:
489 + logger.warning(
490 + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
491 + )
492 +
493 + def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
494 + """Tie or clone module weights depending on whether we are using TorchScript or not"""
495 + if self.config.torchscript:
496 + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
497 + else:
498 + output_embeddings.weight = input_embeddings.weight
499 +
500 + if getattr(output_embeddings, "bias", None) is not None:
501 + output_embeddings.bias.data = torch.nn.functional.pad(
502 + output_embeddings.bias.data,
503 + (
504 + 0,
505 + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
506 + ),
507 + "constant",
508 + 0,
509 + )
510 + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
511 + output_embeddings.out_features = input_embeddings.num_embeddings
512 +
513 + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
514 + """
515 + Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
516 +
517 + Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
518 +
519 + Arguments:
520 + new_num_tokens (:obj:`int`, `optional`):
521 + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
522 + vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
523 + just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model without doing
524 + anything.
525 +
526 + Return:
527 + :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
528 + """
529 + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
530 + model_embeds = base_model._resize_token_embeddings(new_num_tokens)
531 + if new_num_tokens is None:
532 + return model_embeds
533 +
534 + # Update base model and current model config
535 + self.config.vocab_size = new_num_tokens
536 + base_model.vocab_size = new_num_tokens
537 +
538 + # Tie weights again if needed
539 + self.tie_weights()
540 +
541 + return model_embeds
542 +
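# Usage sketch (illustrative, assuming `tokenizer` is the tokenizer matching `model` and "<patch>" is a new token):
#   tokenizer.add_tokens(["<patch>"])
#   model.resize_token_embeddings(len(tokenizer))
#   # config.vocab_size is updated and the newly added embedding rows are freshly initialized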
543 + def _resize_token_embeddings(self, new_num_tokens):
544 + old_embeddings = self.get_input_embeddings()
545 + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
546 + self.set_input_embeddings(new_embeddings)
547 + return self.get_input_embeddings()
548 +
549 + def _get_resized_embeddings(
550 + self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
551 + ) -> torch.nn.Embedding:
552 + """
553 + Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
554 + initialized vectors at the end. Reducing the size will remove vectors from the end.
555 +
556 + Args:
557 + old_embeddings (:obj:`torch.nn.Embedding`):
558 + Old embeddings to be resized.
559 + new_num_tokens (:obj:`int`, `optional`):
560 + New number of tokens in the embedding matrix.
561 +
562 + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
563 + vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
564 + :obj:`torch.nn.Embedding` module of the model without doing anything.
565 +
566 + Return:
567 + :obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
568 + :obj:`new_num_tokens` is :obj:`None`
569 + """
570 + if new_num_tokens is None:
571 + return old_embeddings
572 +
573 + old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
574 + if old_num_tokens == new_num_tokens:
575 + return old_embeddings
576 +
577 + # Build new embeddings
578 + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
579 + new_embeddings.to(old_embeddings.weight.device)
580 +
581 + # initialize all new embeddings (in particular added tokens)
582 + self._init_weights(new_embeddings)
583 +
584 + # Copy token embeddings from the previous weights
585 + num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
586 + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
587 +
588 + return new_embeddings
589 +
590 + def init_weights(self):
591 + """
592 + Initializes and prunes weights if needed.
593 + """
594 + # Initialize weights
595 + self.apply(self._init_weights)
596 +
597 + # Prune heads if needed
598 + if self.config.pruned_heads:
599 + self.prune_heads(self.config.pruned_heads)
600 +
601 + # Tie weights if needed
602 + self.tie_weights()
603 +
604 + def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
605 + """
606 + Prunes heads of the base model.
607 +
608 + Arguments:
609 + heads_to_prune (:obj:`Dict[int, List[int]]`):
610 + Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
611 + of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
612 + prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
613 + """
614 + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
615 + for layer, heads in heads_to_prune.items():
616 + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
617 + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
618 +
619 + self.base_model._prune_heads(heads_to_prune)
620 +
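# Usage sketch (illustrative, mirroring the docstring example above):
#   model.prune_heads({1: [0, 2], 2: [2, 3]})
#   # prunes heads 0 and 2 of layer 1 and heads 2 and 3 of layer 2;
#   # the pruned heads are also recorded in model.config.pruned_heads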
621 + def save_pretrained(self, save_directory):
622 + """
623 + Save a model and its configuration file to a directory, so that it can be re-loaded using the
624 + :func:`~transformers.PreTrainedModel.from_pretrained` class method.
625 +
626 + Arguments:
627 + save_directory (:obj:`str`):
628 + Directory to which to save. Will be created if it doesn't exist.
629 + """
630 + if os.path.isfile(save_directory):
631 + logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
632 + return
633 + os.makedirs(save_directory, exist_ok=True)
634 +
635 + # Only save the model itself if we are using distributed training
636 + model_to_save = self.module if hasattr(self, "module") else self
637 +
638 + # Attach architecture to the config
639 + model_to_save.config.architectures = [model_to_save.__class__.__name__]
640 +
641 + # If we save using the predefined names, we can load using `from_pretrained`
642 + output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
643 +
644 + if getattr(self.config, "xla_device", False):
645 + import torch_xla.core.xla_model as xm
646 +
647 + if xm.is_master_ordinal():
648 + # Save configuration file
649 + model_to_save.config.save_pretrained(save_directory)
650 + # xm.save takes care of saving only from master
651 + xm.save(model_to_save.state_dict(), output_model_file)
652 + else:
653 + model_to_save.config.save_pretrained(save_directory)
654 + torch.save(model_to_save.state_dict(), output_model_file)
655 +
656 + logger.info("Model weights saved in {}".format(output_model_file))
657 +
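# Usage sketch (illustrative; `MyModel` stands for any concrete subclass of PreTrainedModel):
#   model.save_pretrained("./my_model_directory/")
#   reloaded = MyModel.from_pretrained("./my_model_directory/")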
658 + @classmethod
659 + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
660 + r"""
661 + Instantiate a pretrained pytorch model from a pre-trained model configuration.
662 +
663 + The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
664 + To train the model, you should first set it back in training mode with ``model.train()``.
665 +
666 + The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
667 + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
668 + task.
669 +
670 + The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
671 + weights are discarded.
672 +
673 + Parameters:
674 + pretrained_model_name_or_path (:obj:`str`, `optional`):
675 + Can be either:
676 +
677 + - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
678 + ``bert-base-uncased``.
679 + - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
680 + ``dbmdz/bert-base-german-cased``.
681 + - A path to a `directory` containing model weights saved using
682 + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
683 + - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
684 + this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
685 + as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
686 + a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
687 + - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
688 + arguments ``config`` and ``state_dict``).
689 + model_args (sequence of positional arguments, `optional`):
690 + All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
691 + config (:obj:`Union[PretrainedConfig, str]`, `optional`):
692 + Can be either:
693 +
694 + - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
695 + - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
696 +
697 + Configuration for the model to use instead of an automatically loaded configuration. Configuration can
698 + be automatically loaded when:
699 +
700 + - The model is a model provided by the library (loaded with the `shortcut name` string of a
701 + pretrained model).
702 + - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
703 + by supplying the save directory.
704 + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
705 + configuration JSON file named `config.json` is found in the directory.
706 + state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
707 + A state dictionary to use instead of a state dictionary loaded from saved weights file.
708 +
709 + This option can be used if you want to create a model from a pretrained configuration but load your own
710 + weights. In this case though, you should check if using
711 + :func:`~transformers.PreTrainedModel.save_pretrained` and
712 + :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
713 + cache_dir (:obj:`str`, `optional`):
714 + Path to a directory in which a downloaded pretrained model configuration should be cached if the
715 + standard cache should not be used.
716 + from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
717 + Load the model weights from a TensorFlow checkpoint save file (see docstring of
718 + ``pretrained_model_name_or_path`` argument).
719 + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
720 + Whether or not to force the (re-)download of the model weights and configuration files, overriding the
721 + cached versions if they exist.
722 + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
723 + Whether or not to delete incompletely received files. Will attempt to resume the download if such a
724 + file exists.
725 + proxies (:obj:`Dict[str, str]`, `optional`):
726 + A dictionary of proxy servers to use by protocol or endpoint, e.g.,
727 + :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
728 + request.
729 + output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
730 + Whether or not to also return a dictionary containing missing keys, unexpected keys and error
731 + messages.
732 + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
733 + Whether or not to only look at local files (e.g., not try downloading the model).
734 + use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
735 + Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
736 + our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
737 + kwargs (remaining dictionary of keyword arguments, `optional`):
738 + Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
739 + :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
740 + automatically loaded:
741 +
742 + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
743 + underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
744 + already been done)
745 + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
746 + initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
747 + ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
748 + with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
749 + attribute will be passed to the underlying model's ``__init__`` function.
750 +
751 + Examples::
752 +
753 + from transformers import BertConfig, BertModel
754 + # Download model and configuration from S3 and cache.
755 + model = BertModel.from_pretrained('bert-base-uncased')
756 + # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
757 + model = BertModel.from_pretrained('./test/saved_model/')
758 + # Update configuration during loading.
759 + model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
760 + assert model.config.output_attention == True
761 + # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
762 + config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
763 + model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
764 + """
765 + config = kwargs.pop("config", None)
766 + state_dict = kwargs.pop("state_dict", None)
767 + cache_dir = kwargs.pop("cache_dir", None)
768 + from_tf = kwargs.pop("from_tf", False)
769 + force_download = kwargs.pop("force_download", False)
770 + resume_download = kwargs.pop("resume_download", False)
771 + proxies = kwargs.pop("proxies", None)
772 + output_loading_info = kwargs.pop("output_loading_info", False)
773 + local_files_only = kwargs.pop("local_files_only", False)
774 + use_cdn = kwargs.pop("use_cdn", True)
775 +
776 + # Load config if we don't provide a configuration
777 + if not isinstance(config, PretrainedConfig):
778 + config_path = config if config is not None else pretrained_model_name_or_path
779 + config, model_kwargs = cls.config_class.from_pretrained(
780 + config_path,
781 + *model_args,
782 + cache_dir=cache_dir,
783 + return_unused_kwargs=True,
784 + force_download=force_download,
785 + resume_download=resume_download,
786 + proxies=proxies,
787 + local_files_only=local_files_only,
788 + **kwargs,
789 + )
790 + else:
791 + model_kwargs = kwargs
792 +
793 + # Load model
794 + if pretrained_model_name_or_path is not None:
795 + if os.path.isdir(pretrained_model_name_or_path):
796 + if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
797 + # Load from a TF 1.0 checkpoint
798 + archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
799 + elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
800 + # Load from a TF 2.0 checkpoint
801 + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
802 + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
803 + # Load from a PyTorch checkpoint
804 + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
805 + else:
806 + raise EnvironmentError(
807 + "Error no file named {} found in directory {} or `from_tf` set to False".format(
808 + [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
809 + pretrained_model_name_or_path,
810 + )
811 + )
812 + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
813 + archive_file = pretrained_model_name_or_path
814 + elif os.path.isfile(pretrained_model_name_or_path + ".index"):
815 + assert (
816 + from_tf
817 + ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
818 + pretrained_model_name_or_path + ".index"
819 + )
820 + archive_file = pretrained_model_name_or_path + ".index"
821 + else:
822 + archive_file = hf_bucket_url(
823 + pretrained_model_name_or_path,
824 + filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
825 + use_cdn=use_cdn,
826 + )
827 +
828 + try:
829 + # Load from URL or cache if already cached
830 + resolved_archive_file = cached_path(
831 + archive_file,
832 + cache_dir=cache_dir,
833 + force_download=force_download,
834 + proxies=proxies,
835 + resume_download=resume_download,
836 + local_files_only=local_files_only,
837 + )
838 + if resolved_archive_file is None:
839 + raise EnvironmentError
840 + except EnvironmentError:
841 + msg = (
842 + f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
843 + f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
844 + f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
845 + )
846 + raise EnvironmentError(msg)
847 +
848 + if resolved_archive_file == archive_file:
849 + logger.info("loading weights file {}".format(archive_file))
850 + else:
851 + logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
852 + else:
853 + resolved_archive_file = None
854 +
855 + # Instantiate model.
856 + model = cls(config, *model_args, **model_kwargs)
857 +
858 + if state_dict is None and not from_tf:
859 + try:
860 + state_dict = torch.load(resolved_archive_file, map_location="cpu")
861 + except Exception:
862 + raise OSError(
863 + "Unable to load weights from pytorch checkpoint file. "
864 + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
865 + )
866 +
867 + missing_keys = []
868 + unexpected_keys = []
869 + error_msgs = []
870 +
871 + if from_tf:
872 + if resolved_archive_file.endswith(".index"):
873 + # Load from a TensorFlow 1.X checkpoint - provided by original authors
874 + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
875 + else:
876 + # Load from our TensorFlow 2.0 checkpoints
877 + try:
878 + from transformers import load_tf2_checkpoint_in_pytorch_model
879 +
880 + model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
881 + except ImportError:
882 + logger.error(
883 + "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
884 + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
885 + )
886 + raise
887 + else:
888 + # Convert old format to new format if needed from a PyTorch state_dict
889 + old_keys = []
890 + new_keys = []
891 + for key in state_dict.keys():
892 + new_key = None
893 + if "gamma" in key:
894 + new_key = key.replace("gamma", "weight")
895 + if "beta" in key:
896 + new_key = key.replace("beta", "bias")
897 + if new_key:
898 + old_keys.append(key)
899 + new_keys.append(new_key)
900 + for old_key, new_key in zip(old_keys, new_keys):
901 + state_dict[new_key] = state_dict.pop(old_key)
902 +
903 + # copy state_dict so _load_from_state_dict can modify it
904 + metadata = getattr(state_dict, "_metadata", None)
905 + state_dict = state_dict.copy()
906 + if metadata is not None:
907 + state_dict._metadata = metadata
908 +
909 + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
910 + # so we need to apply the function recursively.
911 + def load(module: nn.Module, prefix=""):
912 + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
913 + module._load_from_state_dict(
914 + state_dict,
915 + prefix,
916 + local_metadata,
917 + True,
918 + missing_keys,
919 + unexpected_keys,
920 + error_msgs,
921 + )
922 + for name, child in module._modules.items():
923 + if child is not None:
924 + load(child, prefix + name + ".")
925 +
926 + # Make sure we are able to load base models as well as derived models (with heads)
927 + start_prefix = ""
928 + model_to_load = model
929 + has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
930 + if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
931 + start_prefix = cls.base_model_prefix + "."
932 + if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
933 + model_to_load = getattr(model, cls.base_model_prefix)
934 +
935 + load(model_to_load, prefix=start_prefix)
936 +
937 + if model.__class__.__name__ != model_to_load.__class__.__name__:
938 + base_model_state_dict = model_to_load.state_dict().keys()
939 + head_model_state_dict_without_base_prefix = [
940 + key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
941 + ]
942 + missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
943 +
944 + # Some models may have keys that are not in the state by design, removing them before needlessly warning
945 + # the user.
946 + if cls.authorized_missing_keys is not None:
947 + for pat in cls.authorized_missing_keys:
948 + missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
949 +
950 + if len(unexpected_keys) > 0:
951 + logger.warning(
952 + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
953 + f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
954 + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
955 + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
956 + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
957 + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
958 + )
959 + else:
960 + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
961 + if len(missing_keys) > 0:
962 + logger.warning(
963 + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
964 + f"and are newly initialized: {missing_keys}\n"
965 + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
966 + )
967 + else:
968 + logger.info(
969 + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
970 + f"If your task is similar to the task the model of the checkpoint was trained on, "
971 + f"you can already use {model.__class__.__name__} for predictions without further training."
972 + )
973 + if len(error_msgs) > 0:
974 + raise RuntimeError(
975 + "Error(s) in loading state_dict for {}:\n\t{}".format(
976 + model.__class__.__name__, "\n\t".join(error_msgs)
977 + )
978 + )
979 + # make sure token embedding weights are still tied if needed
980 + model.tie_weights()
981 +
982 + # Set model in evaluation mode to deactivate DropOut modules by default
983 + model.eval()
984 +
985 + if output_loading_info:
986 + loading_info = {
987 + "missing_keys": missing_keys,
988 + "unexpected_keys": unexpected_keys,
989 + "error_msgs": error_msgs,
990 + }
991 + return model, loading_info
992 +
993 + if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
994 + import torch_xla.core.xla_model as xm
995 +
996 + model = xm.send_cpu_data_to_device(model, xm.xla_device())
997 + model.to(xm.xla_device())
998 +
999 + return model
1000 +
1001 +
1002 +class Conv1D(nn.Module):
1003 + """
1004 + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
1005 +
1006 + Basically works like a linear layer but the weights are transposed.
1007 +
1008 + Args:
1009 + nf (:obj:`int`): The number of output features.
1010 + nx (:obj:`int`): The number of input features.
1011 + """
1012 +
1013 + def __init__(self, nf, nx):
1014 + super().__init__()
1015 + self.nf = nf
1016 + w = torch.empty(nx, nf)
1017 + nn.init.normal_(w, std=0.02)
1018 + self.weight = nn.Parameter(w)
1019 + self.bias = nn.Parameter(torch.zeros(nf))
1020 +
1021 + def forward(self, x):
1022 + size_out = x.size()[:-1] + (self.nf,)
1023 + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
1024 + x = x.view(*size_out)
1025 + return x
1026 +
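# Usage sketch (illustrative): Conv1D acts like nn.Linear with a transposed weight matrix --
#   conv = Conv1D(nf=8, nx=4)
#   out = conv(torch.randn(2, 10, 4))
#   # out.shape == (2, 10, 8)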
1027 +
1028 +class PoolerStartLogits(nn.Module):
1029 + """
1030 + Compute SQuAD start logits from sequence hidden states.
1031 +
1032 + Args:
1033 + config (:class:`~transformers.PretrainedConfig`):
1034 + The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
1035 + """
1036 +
1037 + def __init__(self, config: PretrainedConfig):
1038 + super().__init__()
1039 + self.dense = nn.Linear(config.hidden_size, 1)
1040 +
1041 + def forward(
1042 + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
1043 + ) -> torch.FloatTensor:
1044 + """
1045 + Args:
1046 + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1047 + The final hidden states of the model.
1048 + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1049 + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1050 + 1.0 means the token should be masked.
1051 +
1052 + Returns:
1053 + :obj:`torch.FloatTensor`: The start logits for SQuAD.
1054 + """
1055 + x = self.dense(hidden_states).squeeze(-1)
1056 +
1057 + if p_mask is not None:
1058 + if next(self.parameters()).dtype == torch.float16:
1059 + x = x * (1 - p_mask) - 65500 * p_mask
1060 + else:
1061 + x = x * (1 - p_mask) - 1e30 * p_mask
1062 +
1063 + return x
1064 +
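# Usage sketch (illustrative, assuming `config.hidden_size` == 768):
#   start_pooler = PoolerStartLogits(config)
#   start_logits = start_pooler(torch.randn(2, 384, 768))
#   # start_logits.shape == (2, 384): one start logit per token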
1065 +
1066 +class PoolerEndLogits(nn.Module):
1067 + """
1068 + Compute SQuAD end logits from sequence hidden states.
1069 +
1070 + Args:
1071 + config (:class:`~transformers.PretrainedConfig`):
1072 + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
1073 + :obj:`layer_norm_eps` to use.
1074 + """
1075 +
1076 + def __init__(self, config: PretrainedConfig):
1077 + super().__init__()
1078 + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
1079 + self.activation = nn.Tanh()
1080 + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1081 + self.dense_1 = nn.Linear(config.hidden_size, 1)
1082 +
1083 + def forward(
1084 + self,
1085 + hidden_states: torch.FloatTensor,
1086 + start_states: Optional[torch.FloatTensor] = None,
1087 + start_positions: Optional[torch.LongTensor] = None,
1088 + p_mask: Optional[torch.FloatTensor] = None,
1089 + ) -> torch.FloatTensor:
1090 + """
1091 + Args:
1092 + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1093 + The final hidden states of the model.
1094 + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
1095 + The hidden states of the first tokens for the labeled span.
1096 + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1097 + The position of the first token for the labeled span.
1098 + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1099 + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1100 + 1.0 means the token should be masked.
1101 +
1102 + .. note::
1103 +
1104 + One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
1105 + ``start_positions`` overrides ``start_states``.
1106 +
1107 + Returns:
1108 + :obj:`torch.FloatTensor`: The end logits for SQuAD.
1109 + """
1110 + assert (
1111 + start_states is not None or start_positions is not None
1112 + ), "One of start_states, start_positions should be not None"
1113 + if start_positions is not None:
1114 + slen, hsz = hidden_states.shape[-2:]
1115 + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
1116 + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
1117 + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
1118 +
1119 + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
1120 + x = self.activation(x)
1121 + x = self.LayerNorm(x)
1122 + x = self.dense_1(x).squeeze(-1)
1123 +
1124 + if p_mask is not None:
1125 + if next(self.parameters()).dtype == torch.float16:
1126 + x = x * (1 - p_mask) - 65500 * p_mask
1127 + else:
1128 + x = x * (1 - p_mask) - 1e30 * p_mask
1129 +
1130 + return x
1131 +
1132 +
1133 +class PoolerAnswerClass(nn.Module):
1134 + """
1135 + Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
1136 +
1137 + Args:
1138 + config (:class:`~transformers.PretrainedConfig`):
1139 + The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
1140 + """
1141 +
1142 + def __init__(self, config):
1143 + super().__init__()
1144 + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
1145 + self.activation = nn.Tanh()
1146 + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
1147 +
1148 + def forward(
1149 + self,
1150 + hidden_states: torch.FloatTensor,
1151 + start_states: Optional[torch.FloatTensor] = None,
1152 + start_positions: Optional[torch.LongTensor] = None,
1153 + cls_index: Optional[torch.LongTensor] = None,
1154 + ) -> torch.FloatTensor:
1155 + """
1156 + Args:
1157 + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1158 + The final hidden states of the model.
1159 + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
1160 + The hidden states of the first tokens for the labeled span.
1161 + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1162 + The position of the first token for the labeled span.
1163 + cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1164 + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
1165 +
1166 + .. note::
1167 +
1168 + One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
1169 + ``start_positions`` overrides ``start_states``.
1170 +
1171 + Returns:
1172 + :obj:`torch.FloatTensor`: The SQuAD 2.0 answer class.
1173 + """
1174 + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
1175 + hsz = hidden_states.shape[-1]
1176 + assert (
1177 + start_states is not None or start_positions is not None
1178 + ), "One of start_states, start_positions should be not None"
1179 + if start_positions is not None:
1180 + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
1181 + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
1182 +
1183 + if cls_index is not None:
1184 + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
1185 + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
1186 + else:
1187 + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
1188 +
1189 + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
1190 + x = self.activation(x)
1191 + x = self.dense_1(x).squeeze(-1)
1192 +
1193 + return x
1194 +
1195 +
1196 +@dataclass
1197 +class SquadHeadOutput(ModelOutput):
1198 + """
1199 + Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`.
1200 +
1201 + Args:
1202 + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
1203 + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
1204 + start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1205 + Log probabilities for the top config.start_n_top start token possibilities (beam-search).
1206 + start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1207 + Indices for the top config.start_n_top start token possibilities (beam-search).
1208 + end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1209 + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
1210 + end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1211 + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
1212 + cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1213 + Log probabilities for the ``is_impossible`` label of the answers.
1214 +
1215 + """
1216 +
1217 + loss: Optional[torch.FloatTensor] = None
1218 + start_top_log_probs: Optional[torch.FloatTensor] = None
1219 + start_top_index: Optional[torch.LongTensor] = None
1220 + end_top_log_probs: Optional[torch.FloatTensor] = None
1221 + end_top_index: Optional[torch.LongTensor] = None
1222 + cls_logits: Optional[torch.FloatTensor] = None
1223 +
1224 +
1225 +class SQuADHead(nn.Module):
1226 + r"""
1227 + A SQuAD head inspired by XLNet.
1228 +
1229 + Args:
1230 + config (:class:`~transformers.PretrainedConfig`):
1231 + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
1232 + :obj:`layer_norm_eps` to use.
1233 + """
1234 +
1235 + def __init__(self, config):
1236 + super().__init__()
1237 + self.start_n_top = config.start_n_top
1238 + self.end_n_top = config.end_n_top
1239 +
1240 + self.start_logits = PoolerStartLogits(config)
1241 + self.end_logits = PoolerEndLogits(config)
1242 + self.answer_class = PoolerAnswerClass(config)
1243 +
1244 + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
1245 + def forward(
1246 + self,
1247 + hidden_states: torch.FloatTensor,
1248 + start_positions: Optional[torch.LongTensor] = None,
1249 + end_positions: Optional[torch.LongTensor] = None,
1250 + cls_index: Optional[torch.LongTensor] = None,
1251 + is_impossible: Optional[torch.LongTensor] = None,
1252 + p_mask: Optional[torch.FloatTensor] = None,
1253 + return_dict: bool = False,
1254 + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
1255 + """
1256 + Args:
1257 + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1258 + Final hidden states of the model on the sequence tokens.
1259 + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1260 + Positions of the first token for the labeled span.
1261 + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1262 + Positions of the last token for the labeled span.
1263 + cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1264 + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
1265 + is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1266 + Whether the question has a possible answer in the paragraph or not.
1267 + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1268 + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1269 + 1.0 means the token should be masked.
1270 + return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
1271 + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
1272 +
1273 + Returns:
1274 + """
1275 + start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1276 +
1277 + if start_positions is not None and end_positions is not None:
1278 + # If we are on multi-GPU, let's remove the dimension added by batch splitting
1279 + for x in (start_positions, end_positions, cls_index, is_impossible):
1280 + if x is not None and x.dim() > 1:
1281 + x.squeeze_(-1)
1282 +
1283 + # during training, compute the end logits based on the ground truth of the start position
1284 + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
1285 +
1286 + loss_fct = CrossEntropyLoss()
1287 + start_loss = loss_fct(start_logits, start_positions)
1288 + end_loss = loss_fct(end_logits, end_positions)
1289 + total_loss = (start_loss + end_loss) / 2
1290 +
1291 + if cls_index is not None and is_impossible is not None:
1292 + # Predict answerability from the representation of CLS and START
1293 + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
1294 + loss_fct_cls = nn.BCEWithLogitsLoss()
1295 + cls_loss = loss_fct_cls(cls_logits, is_impossible)
1296 +
1297 + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1298 + total_loss += cls_loss * 0.5
1299 +
1300 + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
1301 +
1302 + else:
1303 + # during inference, compute the end logits based on beam search
1304 + bsz, slen, hsz = hidden_states.size()
1305 + start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1306 +
1307 + start_top_log_probs, start_top_index = torch.topk(
1308 + start_log_probs, self.start_n_top, dim=-1
1309 + ) # shape (bsz, start_n_top)
1310 + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
1311 + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
1312 + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
1313 +
1314 + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1315 + start_states
1316 + ) # shape (bsz, slen, start_n_top, hsz)
1317 + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1318 + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
1319 + end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
1320 +
1321 + end_top_log_probs, end_top_index = torch.topk(
1322 + end_log_probs, self.end_n_top, dim=1
1323 + ) # shape (bsz, end_n_top, start_n_top)
1324 + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
1325 + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1326 +
1327 + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1328 + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
1329 +
1330 + if not return_dict:
1331 + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
1332 + else:
1333 + return SquadHeadOutput(
1334 + start_top_log_probs=start_top_log_probs,
1335 + start_top_index=start_top_index,
1336 + end_top_log_probs=end_top_log_probs,
1337 + end_top_index=end_top_index,
1338 + cls_logits=cls_logits,
1339 + )
1340 +
1341 +
1342 +class SequenceSummary(nn.Module):
1343 + r"""
1344 + Compute a single vector summary of a sequence hidden states.
1345 +
1346 + Args:
1347 + config (:class:`~transformers.PretrainedConfig`):
1348 + The config used by the model. Relevant arguments in the config class of the model are (refer to the
1349 + actual config class of your model for the default values it uses):
1350 +
1351 + - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:
1352 +
1353 + - :obj:`"last"` -- Take the last token hidden state (like XLNet)
1354 + - :obj:`"first"` -- Take the first token hidden state (like Bert)
1355 + - :obj:`"mean"` -- Take the mean of all tokens hidden states
1356 + - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
1357 + - :obj:`"attn"` -- Not implemented now, use multi-head attention
1358 +
1359 + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
1360 + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
1361 + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
1362 + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
1363 + output, another string or :obj:`None` will add no activation.
1364 + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
1365 + activation.
1366 + - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
1367 + activation.
1368 + """
1369 +
1370 + def __init__(self, config: PretrainedConfig):
1371 + super().__init__()
1372 +
1373 + self.summary_type = getattr(config, "summary_type", "last")
1374 + if self.summary_type == "attn":
1375 + # We should use a standard multi-head attention module with absolute positional embedding for that.
1376 + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1377 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1378 + raise NotImplementedError
1379 +
1380 + self.summary = Identity()
1381 + if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1382 + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
1383 + num_classes = config.num_labels
1384 + else:
1385 + num_classes = config.hidden_size
1386 + self.summary = nn.Linear(config.hidden_size, num_classes)
1387 +
1388 + activation_string = getattr(config, "summary_activation", None)
1389 + self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
1390 +
1391 + self.first_dropout = Identity()
1392 + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
1393 + self.first_dropout = nn.Dropout(config.summary_first_dropout)
1394 +
1395 + self.last_dropout = Identity()
1396 + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1397 + self.last_dropout = nn.Dropout(config.summary_last_dropout)
1398 +
1399 + def forward(
1400 + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
1401 + ) -> torch.FloatTensor:
1402 + """
1403 + Compute a single vector summary of a sequence hidden states.
1404 +
1405 + Args:
1406 + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
1407 + The hidden states of the last layer.
1408 + cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
1409 + Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification
1410 + token.
1411 +
1412 + Returns:
1413 + :obj:`torch.FloatTensor`: The summary of the sequence hidden states.
1414 + """
1415 + if self.summary_type == "last":
1416 + output = hidden_states[:, -1]
1417 + elif self.summary_type == "first":
1418 + output = hidden_states[:, 0]
1419 + elif self.summary_type == "mean":
1420 + output = hidden_states.mean(dim=1)
1421 + elif self.summary_type == "cls_index":
1422 + if cls_index is None:
1423 + cls_index = torch.full_like(
1424 + hidden_states[..., :1, :],
1425 + hidden_states.shape[-2] - 1,
1426 + dtype=torch.long,
1427 + )
1428 + else:
1429 + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1430 + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
1431 + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1432 + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
1433 + elif self.summary_type == "attn":
1434 + raise NotImplementedError
1435 +
1436 + output = self.first_dropout(output)
1437 + output = self.summary(output)
1438 + output = self.activation(output)
1439 + output = self.last_dropout(output)
1440 +
1441 + return output
1442 +
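# Usage sketch (illustrative, assuming `config.summary_type == "mean"` and no projection/dropout options set):
#   summary = SequenceSummary(config)
#   pooled = summary(torch.randn(2, 7, config.hidden_size))
#   # pooled.shape == (2, config.hidden_size): the mean over the 7 token positions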
1443 +
1444 +def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
1445 + """
1446 + Prune a linear layer to keep only entries in index.
1447 +
1448 + Used to remove heads.
1449 +
1450 + Args:
1451 + layer (:obj:`torch.nn.Linear`): The layer to prune.
1452 + index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1453 + dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices.
1454 +
1455 + Returns:
1456 + :obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`.
1457 + """
1458 + index = index.to(layer.weight.device)
1459 + W = layer.weight.index_select(dim, index).clone().detach()
1460 + if layer.bias is not None:
1461 + if dim == 1:
1462 + b = layer.bias.clone().detach()
1463 + else:
1464 + b = layer.bias[index].clone().detach()
1465 + new_size = list(layer.weight.size())
1466 + new_size[dim] = len(index)
1467 + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1468 + new_layer.weight.requires_grad = False
1469 + new_layer.weight.copy_(W.contiguous())
1470 + new_layer.weight.requires_grad = True
1471 + if layer.bias is not None:
1472 + new_layer.bias.requires_grad = False
1473 + new_layer.bias.copy_(b.contiguous())
1474 + new_layer.bias.requires_grad = True
1475 + return new_layer
1476 +
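+# --- Illustrative sketch, not part of the original patch ---
+# Pruning a Linear layer along its output dimension (dim=0) keeps only the selected
+# output units; the sizes and indices below are made up for the example.
+def _example_prune_linear_layer():
+    layer = nn.Linear(16, 8)  # weight shape (8, 16)
+    index = torch.tensor([0, 2, 5], dtype=torch.long)
+    pruned = prune_linear_layer(layer, index, dim=0)
+    assert pruned.weight.shape == (3, 16)  # three output units remain
+    assert pruned.bias.shape == (3,)
+    return pruned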
1477 +
1478 +def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
1479 + """
1480 + Prune a Conv1D layer to keep only entries in index. A Conv1D works like a Linear layer (see e.g. BERT) but the weights
1481 + are transposed.
1482 +
1483 + Used to remove heads.
1484 +
1485 + Args:
1486 + layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune.
1487 + index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1488 + dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices.
1489 +
1490 + Returns:
1491 + :class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`.
1492 + """
1493 + index = index.to(layer.weight.device)
1494 + W = layer.weight.index_select(dim, index).clone().detach()
1495 + if dim == 0:
1496 + b = layer.bias.clone().detach()
1497 + else:
1498 + b = layer.bias[index].clone().detach()
1499 + new_size = list(layer.weight.size())
1500 + new_size[dim] = len(index)
1501 + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1502 + new_layer.weight.requires_grad = False
1503 + new_layer.weight.copy_(W.contiguous())
1504 + new_layer.weight.requires_grad = True
1505 + new_layer.bias.requires_grad = False
1506 + new_layer.bias.copy_(b.contiguous())
1507 + new_layer.bias.requires_grad = True
1508 + return new_layer
1509 +
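+# --- Illustrative sketch, not part of the original patch ---
+# Conv1D stores its weight as (input_features, output_features), i.e. transposed with
+# respect to nn.Linear, so output units are pruned along dim=1 here. Sizes are made up.
+def _example_prune_conv1d_layer():
+    layer = Conv1D(8, 16)  # nf=8 outputs, nx=16 inputs -> weight shape (16, 8)
+    index = torch.tensor([1, 4], dtype=torch.long)
+    pruned = prune_conv1d_layer(layer, index, dim=1)
+    assert pruned.weight.shape == (16, 2)  # two output units remain
+    assert pruned.bias.shape == (2,)
+    return pruned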
1510 +
1511 +def prune_layer(
1512 + layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None
1513 +) -> Union[torch.nn.Linear, Conv1D]:
1514 + """
1515 + Prune a Conv1D or linear layer to keep only entries in index.
1516 +
1517 + Used to remove heads.
1518 +
1519 + Args:
1520 + layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
1521 + index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1522 + dim (:obj:`int`, `optional`): The dimension on which to keep the indices.
1523 +
1524 + Returns:
1525 + :obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`:
1526 + The pruned layer as a new layer with :obj:`requires_grad=True`.
1527 + """
1528 + if isinstance(layer, nn.Linear):
1529 + return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1530 + elif isinstance(layer, Conv1D):
1531 + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1532 + else:
1533 + raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1534 +
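+# --- Illustrative sketch, not part of the original patch ---
+# prune_layer picks the pruning dimension automatically: dim=0 for nn.Linear and dim=1
+# for Conv1D, matching the two helpers above. Sizes and indices are made up.
+def _example_prune_layer():
+    index = torch.tensor([0, 3], dtype=torch.long)
+    pruned_linear = prune_layer(nn.Linear(16, 8), index)  # dispatches to prune_linear_layer, dim=0
+    pruned_conv1d = prune_layer(Conv1D(8, 16), index)  # dispatches to prune_conv1d_layer, dim=1
+    assert pruned_linear.weight.shape == (2, 16)
+    assert pruned_conv1d.weight.shape == (16, 2)
+    return pruned_linear, pruned_conv1d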
1535 +
1536 +def apply_chunking_to_forward(
1537 + forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
1538 +) -> torch.Tensor:
1539 + """
1540 + This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
1541 + dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.
1542 +
1543 + If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
1544 + directly applying :obj:`forward_fn` to :obj:`input_tensors`.
1545 +
1546 + Args:
1547 + forward_fn (:obj:`Callable[..., torch.Tensor]`):
1548 + The forward function of the model.
1549 + chunk_size (:obj:`int`):
1550 + The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
1551 + chunk_dim (:obj:`int`):
1552 + The dimension over which the :obj:`input_tensors` should be chunked.
1553 + input_tensors (:obj:`Tuple[torch.Tensor]`):
1554 + The input tensors of ``forward_fn`` which will be chunked.
1555 + Returns:
1556 + :obj:`torch.Tensor`: A tensor with the same shape as :obj:`forward_fn` would have given if applied directly to :obj:`input_tensors`.
1557 +
1558 +
1559 + Examples::
1560 +
1561 + # rename the usual forward() fn to forward_chunk()
1562 + def forward_chunk(self, hidden_states):
1563 + hidden_states = self.decoder(hidden_states)
1564 + return hidden_states
1565 +
1566 + # implement a chunked forward function
1567 + def forward(self, hidden_states):
1568 + return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
1569 + """
1570 +
1571 + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
1572 + tensor_shape = input_tensors[0].shape
1573 + assert all(
1574 + input_tensor.shape == tensor_shape for input_tensor in input_tensors
1575 + ), "All input tenors have to be of the same shape"
1576 +
1577 + # inspect.signature has existed since Python 3.5 and is a Python method -> no problem with backward compatibility
1578 + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1579 + assert num_args_in_forward_chunk_fn == len(
1580 + input_tensors
1581 + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1582 + num_args_in_forward_chunk_fn, len(input_tensors)
1583 + )
1584 +
1585 + if chunk_size > 0:
1586 + assert (
1587 + input_tensors[0].shape[chunk_dim] % chunk_size == 0
1588 + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1589 + input_tensors[0].shape[chunk_dim], chunk_size
1590 + )
1591 +
1592 + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1593 +
1594 + # chunk input tensor into tuples
1595 + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1596 + # apply forward fn to every tuple
1597 + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1598 + # concatenate output at same dimension
1599 + return torch.cat(output_chunks, dim=chunk_dim)
1600 +
1601 + return forward_fn(*input_tensors)
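+
+
+# --- Illustrative sketch, not part of the original patch ---
+# When forward_fn treats positions along chunk_dim independently (as a plain Linear
+# does over the sequence dimension), chunked and direct application give the same
+# result. The shapes and chunk size below are made up for the example.
+def _example_apply_chunking_to_forward():
+    decoder = nn.Linear(8, 8)
+    hidden_states = torch.rand(2, 16, 8)  # (batch, seq_len, hidden_size)
+
+    def forward_chunk(chunk):
+        return decoder(chunk)
+
+    chunked = apply_chunking_to_forward(forward_chunk, 4, 1, hidden_states)  # chunk_size=4 over dim 1
+    direct = forward_chunk(hidden_states)
+    assert torch.allclose(chunked, direct, atol=1e-6)
+    return chunked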