Showing 5 changed files with 4021 additions and 2 deletions
... | @@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer): | ... | @@ -188,8 +188,8 @@ class SummarizationModule(BaseTransformer): |
188 | t0 = time.time() | 188 | t0 = time.time() |
189 | generated_ids = self.model.generate( | 189 | generated_ids = self.model.generate( |
190 | batch[0].long(), | 190 | batch[0].long(), |
191 | + patch_ids=batch[2].long(), | ||
191 | attention_mask=batch[1].long(), | 192 | attention_mask=batch[1].long(), |
192 | - # patch_ids=batch[2].long(), | ||
193 | use_cache=True, | 193 | use_cache=True, |
194 | decoder_start_token_id=self.decoder_start_token_id, | 194 | decoder_start_token_id=self.decoder_start_token_id, |
195 | ) | 195 | ) | ... | ... |
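The hunk above forwards a third batch tensor to the vendored `generate` as `patch_ids`; inside `generate` it travels through `**model_kwargs` to the model's `prepare_inputs_for_generation`, where the patched BART model can pick it up. A minimal sketch of the batch layout this call assumes (tensor values and variable names are invented for illustration only)::

    import torch

    input_ids = torch.tensor([[0, 713, 16, 10, 1296, 2]])  # (batch, seq_len), made-up token ids
    attention_mask = torch.ones_like(input_ids)             # 1 = real token, 0 = padding
    patch_ids = torch.zeros_like(input_ids)                 # extra per-token ids consumed by the patched model

    batch = (input_ids, attention_mask, patch_ids)

    # generated_ids = model.generate(
    #     batch[0].long(),
    #     patch_ids=batch[2].long(),
    #     attention_mask=batch[1].long(),
    #     use_cache=True,
    #     decoder_start_token_id=decoder_start_token_id,
    # )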
generation_utils.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. | ||
3 | +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
4 | +# | ||
5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | +# you may not use this file except in compliance with the License. | ||
7 | +# You may obtain a copy of the License at | ||
8 | +# | ||
9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | +# | ||
11 | +# Unless required by applicable law or agreed to in writing, software | ||
12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | +# See the License for the specific language governing permissions and | ||
15 | +# limitations under the License. | ||
16 | + | ||
17 | +from typing import Iterable, List, Optional, Tuple | ||
18 | + | ||
19 | +import torch | ||
20 | +from torch import Tensor | ||
21 | +from torch.nn import functional as F | ||
22 | + | ||
23 | +from transformers.file_utils import ModelOutput | ||
24 | +import logging | ||
25 | + | ||
26 | +logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
27 | +logging.basicConfig( | ||
28 | + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
29 | + datefmt="%m/%d/%Y %H:%M:%S", | ||
30 | + level=logging.INFO, | ||
31 | +) | ||
32 | + | ||
33 | +class GenerationMixin: | ||
34 | + """ | ||
35 | + A class containing all of the functions supporting generation, to be used as a mixin in | ||
36 | + :class:`~transformers.PreTrainedModel`. | ||
37 | + """ | ||
38 | + | ||
39 | + def prepare_inputs_for_generation(self, input_ids, **kwargs): | ||
40 | + """ | ||
41 | + Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the | ||
42 | + generate method. | ||
43 | + """ | ||
44 | + return {"input_ids": input_ids} | ||
45 | + | ||
46 | + def adjust_logits_during_generation(self, logits, **kwargs): | ||
47 | + """ | ||
48 | + Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in | ||
49 | + the generate method. | ||
50 | + """ | ||
51 | + return logits | ||
52 | + | ||
53 | + def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): | ||
54 | + """ | ||
55 | + Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). | ||
56 | + """ | ||
57 | + for i in range(batch_size * num_beams): | ||
58 | + for previous_token in set(prev_output_tokens[i].tolist()): | ||
59 | + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability | ||
60 | + if lprobs[i, previous_token] < 0: | ||
61 | + lprobs[i, previous_token] *= repetition_penalty | ||
62 | + else: | ||
63 | + lprobs[i, previous_token] /= repetition_penalty | ||
64 | + | ||
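To make the branch in `enforce_repetition_penalty_` concrete, here is a standalone rendering of the same arithmetic with invented values: log-probabilities are usually negative, so multiplying by a penalty > 1 pushes an already-generated token further down, while a positive score is divided instead::

    import torch

    lprobs = torch.tensor([[-1.0, -2.0, 0.5]])  # made-up scores for a 3-token vocabulary
    previous_tokens = [0, 2]                     # tokens already generated for this hypothesis
    penalty = 1.2

    for tok in previous_tokens:
        if lprobs[0, tok] < 0:
            lprobs[0, tok] *= penalty  # -1.0 -> -1.2
        else:
            lprobs[0, tok] /= penalty  # 0.5 -> ~0.42

    # lprobs is now tensor([[-1.2000, -2.0000, 0.4167]])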
65 | + def postprocess_next_token_scores( | ||
66 | + self, | ||
67 | + scores, | ||
68 | + input_ids, | ||
69 | + no_repeat_ngram_size, | ||
70 | + bad_words_ids, | ||
71 | + cur_len, | ||
72 | + min_length, | ||
73 | + max_length, | ||
74 | + eos_token_id, | ||
75 | + repetition_penalty, | ||
76 | + batch_size, | ||
77 | + num_beams, | ||
78 | + ): | ||
79 | + # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) | ||
80 | + if repetition_penalty != 1.0: | ||
81 | + self.enforce_repetition_penalty_( | ||
82 | + scores, | ||
83 | + batch_size, | ||
84 | + num_beams, | ||
85 | + input_ids, | ||
86 | + repetition_penalty, | ||
87 | + ) | ||
88 | + | ||
89 | + # set eos token prob to zero if min_length is not reached | ||
90 | + if eos_token_id is not None and cur_len < min_length: | ||
91 | + scores[:, eos_token_id] = -float("inf") | ||
92 | + | ||
93 | + if no_repeat_ngram_size > 0: | ||
94 | + # calculate a list of banned tokens to prevent repetitively generating the same ngrams | ||
95 | + num_batch_hypotheses = batch_size * num_beams | ||
96 | + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 | ||
97 | + banned_batch_tokens = calc_banned_ngram_tokens( | ||
98 | + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len | ||
99 | + ) | ||
100 | + for i, banned_tokens in enumerate(banned_batch_tokens): | ||
101 | + scores[i, banned_tokens] = -float("inf") | ||
102 | + | ||
103 | + if bad_words_ids is not None: | ||
104 | + # Exclude EOS token (already processed) | ||
105 | + bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) | ||
106 | + # calculate a list of banned tokens according to bad words | ||
107 | + banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) | ||
108 | + # Modify the scores in place by setting the banned tokens logits to `-inf` | ||
109 | + set_scores_to_inf_for_banned_tokens(scores, banned_tokens) | ||
110 | + | ||
111 | + return scores | ||
112 | + | ||
113 | + @torch.no_grad() | ||
114 | + def generate( | ||
115 | + self, | ||
116 | + input_ids: Optional[torch.LongTensor] = None, | ||
117 | + max_length: Optional[int] = None, | ||
118 | + min_length: Optional[int] = None, | ||
119 | + do_sample: Optional[bool] = None, | ||
120 | + early_stopping: Optional[bool] = None, | ||
121 | + num_beams: Optional[int] = None, | ||
122 | + temperature: Optional[float] = None, | ||
123 | + top_k: Optional[int] = None, | ||
124 | + top_p: Optional[float] = None, | ||
125 | + repetition_penalty: Optional[float] = None, | ||
126 | + bad_words_ids: Optional[Iterable[int]] = None, | ||
127 | + bos_token_id: Optional[int] = None, | ||
128 | + pad_token_id: Optional[int] = None, | ||
129 | + eos_token_id: Optional[int] = None, | ||
130 | + length_penalty: Optional[float] = None, | ||
131 | + no_repeat_ngram_size: Optional[int] = None, | ||
132 | + num_return_sequences: Optional[int] = None, | ||
133 | + attention_mask: Optional[torch.LongTensor] = None, | ||
134 | + decoder_start_token_id: Optional[int] = None, | ||
135 | + use_cache: Optional[bool] = None, | ||
136 | + **model_kwargs | ||
137 | + ) -> torch.LongTensor: | ||
138 | + r""" | ||
139 | + Generates sequences for models with a language modeling head. The method currently supports greedy decoding, | ||
140 | + beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. | ||
141 | + | ||
142 | + Adapted in part from `Facebook's XLM beam search code | ||
143 | + <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. | ||
144 | + | ||
145 | + Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the | ||
146 | + attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values | ||
147 | + indicated are those of the model's configuration. | ||
148 | + | ||
149 | + Most of these parameters are explained in more detail in `this blog post | ||
150 | + <https://huggingface.co/blog/how-to-generate>`__. | ||
151 | + | ||
152 | + Parameters: | ||
153 | + | ||
154 | + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | ||
155 | + The sequence used as a prompt for the generation. If :obj:`None` the method initializes | ||
156 | + it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. | ||
157 | + max_length (:obj:`int`, `optional`, defaults to 20): | ||
158 | + The maximum length of the sequence to be generated. | ||
159 | + min_length (:obj:`int`, `optional`, defaults to 10): | ||
160 | + The minimum length of the sequence to be generated. | ||
161 | + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
162 | + Whether or not to use sampling ; use greedy decoding otherwise. | ||
163 | + early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
164 | + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. | ||
165 | + num_beams (:obj:`int`, `optional`, defaults to 1): | ||
166 | + Number of beams for beam search. 1 means no beam search. | ||
167 | + temperature (:obj:`float`, `optional`, defaults to 1.0): | ||
168 | + The value used to modulate the next token probabilities. | ||
169 | + top_k (:obj:`int`, `optional`, defaults to 50): | ||
170 | + The number of highest probability vocabulary tokens to keep for top-k-filtering. | ||
171 | + top_p (:obj:`float`, `optional`, defaults to 1.0): | ||
172 | + If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or | ||
173 | + higher are kept for generation. | ||
174 | + repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): | ||
175 | + The parameter for repetition penalty. 1.0 means no penalty. See `this paper | ||
176 | + <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. | ||
177 | + pad_token_id (:obj:`int`, `optional`): | ||
178 | + The id of the `padding` token. | ||
179 | + bos_token_id (:obj:`int`, `optional`): | ||
180 | + The id of the `beginning-of-sequence` token. | ||
181 | + eos_token_id (:obj:`int`, `optional`): | ||
182 | + The id of the `end-of-sequence` token. | ||
183 | + length_penalty (:obj:`float`, `optional`, defaults to 1.0): | ||
184 | + Exponential penalty to the length. 1.0 means no penalty. | ||
185 | + | ||
186 | + Set to a value < 1.0 to encourage the model to generate shorter sequences, or to a value > 1.0 to | ||
187 | + encourage the model to produce longer sequences. | ||
188 | + no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): | ||
189 | + If set to int > 0, all ngrams of that size can only occur once. | ||
190 | + bad_words_ids(:obj:`List[int]`, `optional`): | ||
191 | + List of token ids that are not allowed to be generated. In order to get the tokens of the words that | ||
192 | + should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. | ||
193 | + num_return_sequences(:obj:`int`, `optional`, defaults to 1): | ||
194 | + The number of independently computed returned sequences for each element in the batch. | ||
195 | + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | ||
196 | + Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for | ||
197 | + tokens that are not masked, and 0 for masked tokens. | ||
198 | + | ||
199 | + If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. | ||
200 | + | ||
201 | + `What are attention masks? <../glossary.html#attention-mask>`__ | ||
202 | + decoder_start_token_id (:obj:`int`, `optional`): | ||
203 | + If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. | ||
204 | + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
205 | + Whether or not the model should use the past last key/values attentions (if applicable to the model) to | ||
206 | + speed up decoding. | ||
207 | + model_kwargs: | ||
208 | + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. | ||
209 | + | ||
210 | + Return: | ||
211 | + | ||
212 | + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: | ||
213 | + The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | ||
214 | + shorter if all batches finished early due to the :obj:`eos_token_id`. | ||
215 | + | ||
216 | + Examples:: | ||
217 | + | ||
218 | + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer | ||
219 | + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. | ||
220 | + outputs = model.generate(max_length=40) # do greedy decoding | ||
221 | + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) | ||
222 | + | ||
223 | + tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer | ||
224 | + model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. | ||
225 | + input_context = 'The dog' | ||
226 | + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
227 | + outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' | ||
228 | + for i in range(3): # 3 output sequences were generated | ||
229 | + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) | ||
230 | + | ||
231 | + tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer | ||
232 | + model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. | ||
233 | + input_context = 'The dog' | ||
234 | + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
235 | + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling | ||
236 | + for i in range(3): # 3 output sequences were generated | ||
237 | + print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) | ||
238 | + | ||
239 | + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer | ||
240 | + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. | ||
241 | + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl | ||
242 | + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
243 | + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences | ||
244 | + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) | ||
245 | + | ||
246 | + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer | ||
247 | + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. | ||
248 | + input_context = 'My cute dog' # context to complete | ||
249 | + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] | ||
250 | + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
251 | + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated | ||
252 | + """ | ||
253 | + | ||
254 | + # We cannot generate if the model does not have a LM head | ||
255 | + if self.get_output_embeddings() is None: | ||
256 | + raise AttributeError( | ||
257 | + "You tried to generate sequences with a model that does not have a LM Head." | ||
258 | + "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" | ||
259 | + ) | ||
260 | + | ||
261 | + max_length = max_length if max_length is not None else self.config.max_length | ||
262 | + min_length = min_length if min_length is not None else self.config.min_length | ||
263 | + do_sample = do_sample if do_sample is not None else self.config.do_sample | ||
264 | + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping | ||
265 | + use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
266 | + num_beams = num_beams if num_beams is not None else self.config.num_beams | ||
267 | + temperature = temperature if temperature is not None else self.config.temperature | ||
268 | + top_k = top_k if top_k is not None else self.config.top_k | ||
269 | + top_p = top_p if top_p is not None else self.config.top_p | ||
270 | + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty | ||
271 | + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id | ||
272 | + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id | ||
273 | + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id | ||
274 | + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty | ||
275 | + no_repeat_ngram_size = ( | ||
276 | + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size | ||
277 | + ) | ||
278 | + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids | ||
279 | + num_return_sequences = ( | ||
280 | + num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences | ||
281 | + ) | ||
282 | + decoder_start_token_id = ( | ||
283 | + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id | ||
284 | + ) | ||
285 | + | ||
286 | + if input_ids is not None: | ||
287 | + batch_size = input_ids.shape[0] # overridden by the input batch_size | ||
288 | + else: | ||
289 | + batch_size = 1 | ||
290 | + | ||
291 | + assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." | ||
292 | + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." | ||
293 | + assert isinstance(do_sample, bool), "`do_sample` should be a boolean." | ||
294 | + assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." | ||
295 | + assert isinstance(use_cache, bool), "`use_cache` should be a boolean." | ||
296 | + assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." | ||
297 | + assert temperature > 0, "`temperature` should be strictly positive." | ||
298 | + assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." | ||
299 | + assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." | ||
300 | + assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." | ||
301 | + assert input_ids is not None or ( | ||
302 | + isinstance(bos_token_id, int) and bos_token_id >= 0 | ||
303 | + ), "If input_ids is not defined, `bos_token_id` should be a positive integer." | ||
304 | + assert pad_token_id is None or ( | ||
305 | + isinstance(pad_token_id, int) and (pad_token_id >= 0) | ||
306 | + ), "`pad_token_id` should be a positive integer." | ||
307 | + assert (eos_token_id is None) or ( | ||
308 | + isinstance(eos_token_id, int) and (eos_token_id >= 0) | ||
309 | + ), "`eos_token_id` should be a positive integer." | ||
310 | + assert length_penalty > 0, "`length_penalty` should be strictly positive." | ||
311 | + assert ( | ||
312 | + isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 | ||
313 | + ), "`no_repeat_ngram_size` should be a positive integer." | ||
314 | + assert ( | ||
315 | + isinstance(num_return_sequences, int) and num_return_sequences > 0 | ||
316 | + ), "`num_return_sequences` should be a strictly positive integer." | ||
317 | + assert ( | ||
318 | + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) | ||
319 | + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" | ||
320 | + | ||
321 | + if input_ids is None: | ||
322 | + assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( | ||
323 | + "you should either supply a context to complete as `input_ids` input " | ||
324 | + "or a `bos_token_id` (integer >= 0) as a first token to start the generation." | ||
325 | + ) | ||
326 | + input_ids = torch.full( | ||
327 | + (batch_size, 1), | ||
328 | + bos_token_id, | ||
329 | + dtype=torch.long, | ||
330 | + device=next(self.parameters()).device, | ||
331 | + ) | ||
332 | + else: | ||
333 | + assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." | ||
334 | + | ||
335 | + # do not allow duplicate outputs when doing greedy decoding | ||
336 | + if do_sample is False: | ||
337 | + if num_beams == 1: | ||
338 | + # no_beam_search greedy generation conditions | ||
339 | + assert ( | ||
340 | + num_return_sequences == 1 | ||
341 | + ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" | ||
342 | + | ||
343 | + else: | ||
344 | + # beam_search greedy generation conditions | ||
345 | + assert ( | ||
346 | + num_beams >= num_return_sequences | ||
347 | + ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" | ||
348 | + | ||
349 | + # create attention mask if necessary | ||
350 | + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 | ||
351 | + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): | ||
352 | + attention_mask = input_ids.ne(pad_token_id).long() | ||
353 | + elif attention_mask is None: | ||
354 | + attention_mask = input_ids.new_ones(input_ids.shape) | ||
355 | + | ||
356 | + # set pad_token_id to eos_token_id if not set. Important that this is done after | ||
357 | + # attention_mask is created | ||
358 | + if pad_token_id is None and eos_token_id is not None: | ||
359 | + logger.warning( | ||
360 | + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) | ||
361 | + ) | ||
362 | + pad_token_id = eos_token_id | ||
363 | + | ||
364 | + # current position and vocab size | ||
365 | + if hasattr(self.config, "vocab_size"): | ||
366 | + vocab_size = self.config.vocab_size | ||
367 | + elif ( | ||
368 | + self.config.is_encoder_decoder | ||
369 | + and hasattr(self.config, "decoder") | ||
370 | + and hasattr(self.config.decoder, "vocab_size") | ||
371 | + ): | ||
372 | + vocab_size = self.config.decoder.vocab_size | ||
373 | + | ||
374 | + # set effective batch size and effective batch multiplier according to do_sample | ||
375 | + if do_sample: | ||
376 | + effective_batch_size = batch_size * num_return_sequences | ||
377 | + effective_batch_mult = num_return_sequences | ||
378 | + else: | ||
379 | + effective_batch_size = batch_size | ||
380 | + effective_batch_mult = 1 | ||
381 | + | ||
382 | + if self.config.is_encoder_decoder: | ||
383 | + if decoder_start_token_id is None: | ||
384 | + # see if BOS token can be used for decoder_start_token_id | ||
385 | + if bos_token_id is not None: | ||
386 | + decoder_start_token_id = bos_token_id | ||
387 | + elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"): | ||
388 | + decoder_start_token_id = self.config.decoder.bos_token_id | ||
389 | + else: | ||
390 | + raise ValueError( | ||
391 | + "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" | ||
392 | + ) | ||
393 | + | ||
394 | + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) | ||
395 | + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) | ||
396 | + | ||
397 | + # get encoder and store encoder outputs | ||
398 | + encoder = self.get_encoder() | ||
399 | + encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) | ||
400 | + | ||
401 | + # Expand input ids if num_beams > 1 or num_return_sequences > 1 | ||
402 | + if num_return_sequences > 1 or num_beams > 1: | ||
403 | + input_ids_len = input_ids.shape[-1] | ||
404 | + input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) | ||
405 | + attention_mask = attention_mask.unsqueeze(1).expand( | ||
406 | + batch_size, effective_batch_mult * num_beams, input_ids_len | ||
407 | + ) | ||
408 | + | ||
409 | + input_ids = input_ids.contiguous().view( | ||
410 | + effective_batch_size * num_beams, input_ids_len | ||
411 | + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) | ||
412 | + attention_mask = attention_mask.contiguous().view( | ||
413 | + effective_batch_size * num_beams, input_ids_len | ||
414 | + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) | ||
415 | + | ||
416 | + if self.config.is_encoder_decoder: | ||
417 | + # create empty decoder_input_ids | ||
418 | + input_ids = torch.full( | ||
419 | + (effective_batch_size * num_beams, 1), | ||
420 | + decoder_start_token_id, | ||
421 | + dtype=torch.long, | ||
422 | + device=next(self.parameters()).device, | ||
423 | + ) | ||
424 | + cur_len = 1 | ||
425 | + | ||
426 | + assert ( | ||
427 | + batch_size == encoder_outputs.last_hidden_state.shape[0] | ||
428 | + ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " | ||
429 | + | ||
430 | + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) | ||
431 | + expanded_batch_idxs = ( | ||
432 | + torch.arange(batch_size) | ||
433 | + .view(-1, 1) | ||
434 | + .repeat(1, num_beams * effective_batch_mult) | ||
435 | + .view(-1) | ||
436 | + .to(input_ids.device) | ||
437 | + ) | ||
438 | + | ||
439 | + # expand encoder_outputs | ||
440 | + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( | ||
441 | + 0, expanded_batch_idxs | ||
442 | + ) | ||
443 | + | ||
444 | + # save encoder_outputs in `model_kwargs` | ||
445 | + model_kwargs["encoder_outputs"] = encoder_outputs | ||
446 | + | ||
447 | + else: | ||
448 | + cur_len = input_ids.shape[-1] | ||
449 | + | ||
450 | + assert ( | ||
451 | + cur_len < max_length | ||
452 | + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" | ||
453 | + | ||
454 | + if num_beams > 1: | ||
455 | + output = self._generate_beam_search( | ||
456 | + input_ids, | ||
457 | + cur_len=cur_len, | ||
458 | + max_length=max_length, | ||
459 | + min_length=min_length, | ||
460 | + do_sample=do_sample, | ||
461 | + early_stopping=early_stopping, | ||
462 | + temperature=temperature, | ||
463 | + top_k=top_k, | ||
464 | + top_p=top_p, | ||
465 | + repetition_penalty=repetition_penalty, | ||
466 | + no_repeat_ngram_size=no_repeat_ngram_size, | ||
467 | + bad_words_ids=bad_words_ids, | ||
468 | + pad_token_id=pad_token_id, | ||
469 | + eos_token_id=eos_token_id, | ||
470 | + batch_size=effective_batch_size, | ||
471 | + num_return_sequences=num_return_sequences, | ||
472 | + length_penalty=length_penalty, | ||
473 | + num_beams=num_beams, | ||
474 | + vocab_size=vocab_size, | ||
475 | + attention_mask=attention_mask, | ||
476 | + use_cache=use_cache, | ||
477 | + model_kwargs=model_kwargs, | ||
478 | + ) | ||
479 | + else: | ||
480 | + output = self._generate_no_beam_search( | ||
481 | + input_ids, | ||
482 | + cur_len=cur_len, | ||
483 | + max_length=max_length, | ||
484 | + min_length=min_length, | ||
485 | + do_sample=do_sample, | ||
486 | + temperature=temperature, | ||
487 | + top_k=top_k, | ||
488 | + top_p=top_p, | ||
489 | + repetition_penalty=repetition_penalty, | ||
490 | + no_repeat_ngram_size=no_repeat_ngram_size, | ||
491 | + bad_words_ids=bad_words_ids, | ||
492 | + pad_token_id=pad_token_id, | ||
493 | + eos_token_id=eos_token_id, | ||
494 | + batch_size=effective_batch_size, | ||
495 | + attention_mask=attention_mask, | ||
496 | + use_cache=use_cache, | ||
497 | + model_kwargs=model_kwargs, | ||
498 | + ) | ||
499 | + | ||
500 | + return output | ||
501 | + | ||
502 | + def _generate_no_beam_search( | ||
503 | + self, | ||
504 | + input_ids, | ||
505 | + cur_len, | ||
506 | + max_length, | ||
507 | + min_length, | ||
508 | + do_sample, | ||
509 | + temperature, | ||
510 | + top_k, | ||
511 | + top_p, | ||
512 | + repetition_penalty, | ||
513 | + no_repeat_ngram_size, | ||
514 | + bad_words_ids, | ||
515 | + pad_token_id, | ||
516 | + eos_token_id, | ||
517 | + batch_size, | ||
518 | + attention_mask, | ||
519 | + use_cache, | ||
520 | + model_kwargs, | ||
521 | + ): | ||
522 | + """Generate sequences for each example without beam search (num_beams == 1). | ||
523 | + All returned sequences are generated independently. | ||
524 | + """ | ||
525 | + # length of generated sentences / unfinished sentences | ||
526 | + unfinished_sents = input_ids.new(batch_size).fill_(1) | ||
527 | + sent_lengths = input_ids.new(batch_size).fill_(max_length) | ||
528 | + | ||
529 | + past = None | ||
530 | + while cur_len < max_length: | ||
531 | + model_inputs = self.prepare_inputs_for_generation( | ||
532 | + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs | ||
533 | + ) | ||
534 | + | ||
535 | + outputs = self(**model_inputs, return_dict=True) | ||
536 | + next_token_logits = outputs.logits[:, -1, :] | ||
537 | + | ||
538 | + scores = self.postprocess_next_token_scores( | ||
539 | + scores=next_token_logits, | ||
540 | + input_ids=input_ids, | ||
541 | + no_repeat_ngram_size=no_repeat_ngram_size, | ||
542 | + bad_words_ids=bad_words_ids, | ||
543 | + cur_len=cur_len, | ||
544 | + min_length=min_length, | ||
545 | + max_length=max_length, | ||
546 | + eos_token_id=eos_token_id, | ||
547 | + repetition_penalty=repetition_penalty, | ||
548 | + batch_size=batch_size, | ||
549 | + num_beams=1, | ||
550 | + ) | ||
551 | + | ||
552 | + # if model has past, then set the past variable to speed up decoding | ||
553 | + if "past_key_values" in outputs: | ||
554 | + past = outputs.past_key_values | ||
555 | + elif "mems" in outputs: | ||
556 | + past = outputs.mems | ||
557 | + | ||
558 | + if do_sample: | ||
559 | + # Temperature (higher temperature => more likely to sample low probability tokens) | ||
560 | + if temperature != 1.0: | ||
561 | + scores = scores / temperature | ||
562 | + # Top-p/top-k filtering | ||
563 | + next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) | ||
564 | + # Sample | ||
565 | + probs = F.softmax(next_token_logscores, dim=-1) | ||
566 | + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) | ||
567 | + else: | ||
568 | + # Greedy decoding | ||
569 | + next_token = torch.argmax(next_token_logits, dim=-1) | ||
570 | + | ||
571 | + # update generations and finished sentences | ||
572 | + if eos_token_id is not None: | ||
573 | + # pad finished sentences if eos_token_id exist | ||
574 | + tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) | ||
575 | + else: | ||
576 | + tokens_to_add = next_token | ||
577 | + | ||
578 | + # add token and increase length by one | ||
579 | + input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) | ||
580 | + cur_len = cur_len + 1 | ||
581 | + | ||
582 | + if eos_token_id is not None: | ||
583 | + eos_in_sents = tokens_to_add == eos_token_id | ||
584 | + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length | ||
585 | + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() | ||
586 | + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) | ||
587 | + # unfinished_sents is set to zero if eos in sentence | ||
588 | + unfinished_sents.mul_((~eos_in_sents).long()) | ||
589 | + | ||
590 | + # stop when there is a </s> in each sentence, or if we exceed the maximum length | ||
591 | + if unfinished_sents.max() == 0: | ||
592 | + break | ||
593 | + | ||
594 | + # extend attention_mask for new generated input if only decoder | ||
595 | + if self.config.is_encoder_decoder is False: | ||
596 | + attention_mask = torch.cat( | ||
597 | + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | ||
598 | + ) | ||
599 | + | ||
600 | + return input_ids | ||
601 | + | ||
602 | + def _generate_beam_search( | ||
603 | + self, | ||
604 | + input_ids, | ||
605 | + cur_len, | ||
606 | + max_length, | ||
607 | + min_length, | ||
608 | + do_sample, | ||
609 | + early_stopping, | ||
610 | + temperature, | ||
611 | + top_k, | ||
612 | + top_p, | ||
613 | + repetition_penalty, | ||
614 | + no_repeat_ngram_size, | ||
615 | + bad_words_ids, | ||
616 | + pad_token_id, | ||
617 | + eos_token_id, | ||
618 | + batch_size, | ||
619 | + num_return_sequences, | ||
620 | + length_penalty, | ||
621 | + num_beams, | ||
622 | + vocab_size, | ||
623 | + attention_mask, | ||
624 | + use_cache, | ||
625 | + model_kwargs, | ||
626 | + ): | ||
627 | + """Generate sequences for each example with beam search.""" | ||
628 | + | ||
629 | + # generated hypotheses | ||
630 | + generated_hyps = [ | ||
631 | + BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) | ||
632 | + for _ in range(batch_size) | ||
633 | + ] | ||
634 | + | ||
635 | + # scores for each sentence in the beam | ||
636 | + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) | ||
637 | + | ||
638 | + # for greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens across all beams | ||
639 | + if do_sample is False: | ||
640 | + beam_scores[:, 1:] = -1e9 | ||
641 | + beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) | ||
642 | + | ||
643 | + # cache compute states | ||
644 | + past = None | ||
645 | + | ||
646 | + # done sentences | ||
647 | + done = [False for _ in range(batch_size)] | ||
648 | + | ||
649 | + while cur_len < max_length: | ||
650 | + model_inputs = self.prepare_inputs_for_generation( | ||
651 | + input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs | ||
652 | + ) | ||
653 | + outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) | ||
654 | + next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) | ||
655 | + | ||
656 | + # if model has past, then set the past variable to speed up decoding | ||
657 | + if "past_key_values" in outputs: | ||
658 | + past = outputs.past_key_values | ||
659 | + elif "mems" in outputs: | ||
660 | + past = outputs.mems | ||
661 | + | ||
662 | + if self.config.is_encoder_decoder and do_sample is False: | ||
663 | + # TODO (PVP) still a bit hacky here - there might be a better solution | ||
664 | + next_token_logits = self.adjust_logits_during_generation( | ||
665 | + next_token_logits, cur_len=cur_len, max_length=max_length | ||
666 | + ) | ||
667 | + | ||
668 | + scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) | ||
669 | + | ||
670 | + scores = self.postprocess_next_token_scores( | ||
671 | + scores=scores, | ||
672 | + input_ids=input_ids, | ||
673 | + no_repeat_ngram_size=no_repeat_ngram_size, | ||
674 | + bad_words_ids=bad_words_ids, | ||
675 | + cur_len=cur_len, | ||
676 | + min_length=min_length, | ||
677 | + max_length=max_length, | ||
678 | + eos_token_id=eos_token_id, | ||
679 | + repetition_penalty=repetition_penalty, | ||
680 | + batch_size=batch_size, | ||
681 | + num_beams=num_beams, | ||
682 | + ) | ||
683 | + | ||
684 | + assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( | ||
685 | + scores.shape, (batch_size * num_beams, vocab_size) | ||
686 | + ) | ||
687 | + | ||
688 | + if do_sample: | ||
689 | + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) | ||
690 | + # Temperature | ||
691 | + if temperature != 1.0: | ||
692 | + _scores = _scores / temperature | ||
693 | + # Top-p/top-k filtering | ||
694 | + _scores = top_k_top_p_filtering( | ||
695 | + _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 | ||
696 | + ) # (batch_size * num_beams, vocab_size) | ||
697 | + # re-organize to group the beams together so we can sample from all beam_idxs | ||
698 | + _scores = _scores.contiguous().view( | ||
699 | + batch_size, num_beams * vocab_size | ||
700 | + ) # (batch_size, num_beams * vocab_size) | ||
701 | + | ||
702 | + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) | ||
703 | + probs = F.softmax(_scores, dim=-1) | ||
704 | + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) | ||
705 | + # Compute next scores | ||
706 | + next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) | ||
707 | + # sort the sampled vector to make sure that the first num_beams samples are the best | ||
708 | + next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) | ||
709 | + next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) | ||
710 | + | ||
711 | + else: | ||
712 | + next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) | ||
713 | + | ||
714 | + # re-organize to group the beams together (we are keeping the top hypotheses across beams) | ||
715 | + next_scores = next_scores.view( | ||
716 | + batch_size, num_beams * vocab_size | ||
717 | + ) # (batch_size, num_beams * vocab_size) | ||
718 | + | ||
719 | + next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) | ||
720 | + | ||
721 | + assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) | ||
722 | + | ||
723 | + # next batch beam content | ||
724 | + next_batch_beam = [] | ||
725 | + | ||
726 | + # for each sentence | ||
727 | + for batch_idx in range(batch_size): | ||
728 | + | ||
729 | + # if we are done with this sentence, add a pad token | ||
730 | + if done[batch_idx]: | ||
731 | + assert ( | ||
732 | + len(generated_hyps[batch_idx]) >= num_beams | ||
733 | + ), "Batch can only be done if at least {} beams have been generated".format(num_beams) | ||
734 | + assert ( | ||
735 | + eos_token_id is not None and pad_token_id is not None | ||
736 | + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" | ||
737 | + next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch | ||
738 | + continue | ||
739 | + | ||
740 | + # next sentence beam content, this will get added to next_batch_beam | ||
741 | + next_sent_beam = [] | ||
742 | + | ||
743 | + # next tokens for this sentence | ||
744 | + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( | ||
745 | + zip(next_tokens[batch_idx], next_scores[batch_idx]) | ||
746 | + ): | ||
747 | + # get beam and token IDs | ||
748 | + beam_id = beam_token_id // vocab_size | ||
749 | + token_id = beam_token_id % vocab_size | ||
750 | + | ||
751 | + effective_beam_id = batch_idx * num_beams + beam_id | ||
752 | + # add to generated hypotheses if end of sentence | ||
753 | + if (eos_token_id is not None) and (token_id.item() == eos_token_id): | ||
754 | + # if beam_token does not belong to top num_beams tokens, it should not be added | ||
755 | + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams | ||
756 | + if is_beam_token_worse_than_top_num_beams: | ||
757 | + continue | ||
758 | + generated_hyps[batch_idx].add( | ||
759 | + input_ids[effective_beam_id].clone(), | ||
760 | + beam_token_score.item(), | ||
761 | + ) | ||
762 | + else: | ||
763 | + # add next predicted token since it is not eos_token | ||
764 | + next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) | ||
765 | + | ||
766 | + # once the beam for next step is full, don't add more tokens to it. | ||
767 | + if len(next_sent_beam) == num_beams: | ||
768 | + break | ||
769 | + | ||
770 | + # Check if we are done so that we can save a pad step if all(done) | ||
771 | + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( | ||
772 | + next_scores[batch_idx].max().item(), cur_len | ||
773 | + ) | ||
774 | + | ||
775 | + # update next beam content | ||
776 | + assert len(next_sent_beam) == num_beams, "Beam should always be full" | ||
777 | + next_batch_beam.extend(next_sent_beam) | ||
778 | + assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" | ||
779 | + | ||
780 | + # stop when we are done with each sentence | ||
781 | + if all(done): | ||
782 | + break | ||
783 | + | ||
784 | + # sanity check / prepare next batch | ||
785 | + assert len(next_batch_beam) == batch_size * num_beams | ||
786 | + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) | ||
787 | + beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) | ||
788 | + beam_idx = input_ids.new([x[2] for x in next_batch_beam]) | ||
789 | + | ||
790 | + # re-order batch and update current length | ||
791 | + input_ids = input_ids[beam_idx, :] | ||
792 | + input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) | ||
793 | + cur_len = cur_len + 1 | ||
794 | + | ||
795 | + # re-order internal states | ||
796 | + if past is not None: | ||
797 | + past = self._reorder_cache(past, beam_idx) | ||
798 | + | ||
799 | + # extend attention_mask for new generated input if only decoder | ||
800 | + if self.config.is_encoder_decoder is False: | ||
801 | + attention_mask = torch.cat( | ||
802 | + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | ||
803 | + ) | ||
804 | + | ||
805 | + # finalize all open beam hypotheses and add to generated hypotheses | ||
806 | + for batch_idx in range(batch_size): | ||
807 | + if done[batch_idx]: | ||
808 | + continue | ||
809 | + | ||
810 | + # test that beam scores match previously calculated scores if not eos and batch_idx not done | ||
811 | + if eos_token_id is not None and all( | ||
812 | + (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] | ||
813 | + ): | ||
814 | + assert torch.all( | ||
815 | + next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] | ||
816 | + ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( | ||
817 | + next_scores[:, :num_beams][batch_idx], | ||
818 | + beam_scores.view(batch_size, num_beams)[batch_idx], | ||
819 | + ) | ||
820 | + | ||
821 | + # need to add best num_beams hypotheses to generated hyps | ||
822 | + for beam_id in range(num_beams): | ||
823 | + effective_beam_id = batch_idx * num_beams + beam_id | ||
824 | + final_score = beam_scores[effective_beam_id].item() | ||
825 | + final_tokens = input_ids[effective_beam_id] | ||
826 | + generated_hyps[batch_idx].add(final_tokens, final_score) | ||
827 | + | ||
828 | + # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch | ||
829 | + output_batch_size = batch_size if do_sample else batch_size * num_return_sequences | ||
830 | + output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences | ||
831 | + | ||
832 | + # select the best hypotheses | ||
833 | + sent_lengths = input_ids.new(output_batch_size) | ||
834 | + best = [] | ||
835 | + | ||
836 | + # retrieve best hypotheses | ||
837 | + for i, hypotheses in enumerate(generated_hyps): | ||
838 | + sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) | ||
839 | + for j in range(output_num_return_sequences_per_batch): | ||
840 | + effective_batch_idx = output_num_return_sequences_per_batch * i + j | ||
841 | + best_hyp = sorted_hyps.pop()[1] | ||
842 | + sent_lengths[effective_batch_idx] = len(best_hyp) | ||
843 | + best.append(best_hyp) | ||
844 | + | ||
845 | + # shorter batches are padded | ||
846 | + if sent_lengths.min().item() != sent_lengths.max().item(): | ||
847 | + assert pad_token_id is not None, "`Pad_token_id` has to be defined" | ||
848 | + sent_max_len = min(sent_lengths.max().item() + 1, max_length) | ||
849 | + decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) | ||
850 | + | ||
851 | + # fill with hypothesis and eos_token_id if necessary | ||
852 | + for i, hypo in enumerate(best): | ||
853 | + decoded[i, : sent_lengths[i]] = hypo | ||
854 | + if sent_lengths[i] < max_length: | ||
855 | + decoded[i, sent_lengths[i]] = eos_token_id | ||
856 | + else: | ||
857 | + # none of the hypotheses have an eos_token | ||
858 | + assert all(len(hypo) == max_length for hypo in best) | ||
859 | + decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) | ||
860 | + | ||
861 | + return decoded | ||
862 | + | ||
863 | + @staticmethod | ||
864 | + def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: | ||
865 | + return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) | ||
866 | + | ||
867 | + | ||
868 | +def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]: | ||
869 | + """Copied from fairseq for no_repeat_ngram in beam_search""" | ||
870 | + if cur_len + 1 < no_repeat_ngram_size: | ||
871 | + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet | ||
872 | + return [[] for _ in range(num_hypos)] | ||
873 | + generated_ngrams = [{} for _ in range(num_hypos)] | ||
874 | + for idx in range(num_hypos): | ||
875 | + gen_tokens = prev_input_ids[idx].tolist() | ||
876 | + generated_ngram = generated_ngrams[idx] | ||
877 | + for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): | ||
878 | + prev_ngram_tuple = tuple(ngram[:-1]) | ||
879 | + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] | ||
880 | + | ||
881 | + def _get_generated_ngrams(hypo_idx): | ||
882 | + # Before decoding the next token, prevent decoding of ngrams that have already appeared | ||
883 | + start_idx = cur_len + 1 - no_repeat_ngram_size | ||
884 | + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) | ||
885 | + return generated_ngrams[hypo_idx].get(ngram_idx, []) | ||
886 | + | ||
887 | + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] | ||
888 | + return banned_tokens | ||
889 | + | ||
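A small, hypothetical check of `calc_banned_ngram_tokens` with invented token ids: because the bigram "5 7" already occurs in the hypothesis and the last generated token is again 5, token 7 is banned for the next step when `no_repeat_ngram_size=2`::

    import torch
    from generation_utils import calc_banned_ngram_tokens  # assuming the new file above is on the import path

    prev_input_ids = torch.tensor([[5, 7, 9, 5]])  # one hypothesis, made-up ids
    banned = calc_banned_ngram_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
    # banned == [[7]]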
890 | + | ||
891 | +def calc_banned_bad_words_ids(prev_input_ids: List[List[int]], bad_words_ids: List[List[int]]) -> List[List[int]]: | ||
892 | + banned_tokens = [] | ||
893 | + | ||
894 | + def _tokens_match(prev_tokens, tokens): | ||
895 | + if len(tokens) == 0: | ||
896 | + # if bad word tokens is just one token always ban it | ||
897 | + return True | ||
898 | + if len(tokens) > len(prev_tokens): | ||
899 | + # if bad word tokens are longer than prev tokens they can't be equal | ||
900 | + return False | ||
901 | + | ||
902 | + if prev_tokens[-len(tokens) :] == tokens: | ||
903 | + # if tokens match | ||
904 | + return True | ||
905 | + else: | ||
906 | + return False | ||
907 | + | ||
908 | + for prev_input_ids_slice in prev_input_ids: | ||
909 | + banned_tokens_slice = [] | ||
910 | + | ||
911 | + for banned_token_seq in bad_words_ids: | ||
912 | + assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( | ||
913 | + bad_words_ids | ||
914 | + ) | ||
915 | + | ||
916 | + if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: | ||
917 | + # if tokens do not match continue | ||
918 | + continue | ||
919 | + | ||
920 | + banned_tokens_slice.append(banned_token_seq[-1]) | ||
921 | + | ||
922 | + banned_tokens.append(banned_tokens_slice) | ||
923 | + | ||
924 | + return banned_tokens | ||
925 | + | ||
926 | + | ||
927 | +def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: | ||
928 | + """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be | ||
929 | + a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] | ||
930 | + Args: | ||
931 | + scores: logits distribution of shape (batch size, vocabulary size) | ||
932 | + banned_tokens: list of list of tokens to ban of length (batch_size) | ||
933 | + """ | ||
934 | + banned_mask_list = [] | ||
935 | + for idx, batch_banned_tokens in enumerate(banned_tokens): | ||
936 | + for token in batch_banned_tokens: | ||
937 | + banned_mask_list.append([idx, token]) | ||
938 | + if not banned_mask_list: | ||
939 | + return | ||
940 | + banned_mask = torch.LongTensor(banned_mask_list) | ||
941 | + indices = torch.ones(len(banned_mask)) | ||
942 | + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: | ||
943 | + # [ 0 1 1 ] | ||
944 | + # [ 0 0 0 ] | ||
945 | + # [ 1 0 0 ] | ||
946 | + | ||
947 | + banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() | ||
948 | + scores.masked_fill_(banned_mask, -float("inf")) | ||
949 | + | ||
950 | + | ||
951 | +def top_k_top_p_filtering( | ||
952 | + logits: Tensor, | ||
953 | + top_k: int = 0, | ||
954 | + top_p: float = 1.0, | ||
955 | + filter_value: float = -float("Inf"), | ||
956 | + min_tokens_to_keep: int = 1, | ||
957 | +) -> Tensor: | ||
958 | + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering | ||
959 | + Args: | ||
960 | + logits: logits distribution shape (batch size, vocabulary size) | ||
961 | + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). | ||
962 | + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). | ||
963 | + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) | ||
964 | + Make sure we keep at least min_tokens_to_keep per batch example in the output | ||
965 | + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 | ||
966 | + """ | ||
967 | + if top_k > 0: | ||
968 | + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | ||
969 | + # Remove all tokens with a probability less than the last token of the top-k | ||
970 | + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | ||
971 | + logits[indices_to_remove] = filter_value | ||
972 | + | ||
973 | + if top_p < 1.0: | ||
974 | + sorted_logits, sorted_indices = torch.sort(logits, descending=True) | ||
975 | + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||
976 | + | ||
977 | + # Remove tokens with cumulative probability above the threshold (positions where the mask stays 0 are kept) | ||
978 | + sorted_indices_to_remove = cumulative_probs > top_p | ||
979 | + if min_tokens_to_keep > 1: | ||
980 | + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | ||
981 | + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | ||
982 | + # Shift the indices to the right to keep also the first token above the threshold | ||
983 | + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||
984 | + sorted_indices_to_remove[..., 0] = 0 | ||
985 | + | ||
986 | + # scatter sorted tensors to original indexing | ||
987 | + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | ||
988 | + logits[indices_to_remove] = filter_value | ||
989 | + return logits | ||
990 | + | ||
991 | + | ||
992 | +class BeamHypotheses(object): | ||
993 | + def __init__(self, num_beams, max_length, length_penalty, early_stopping): | ||
994 | + """ | ||
995 | + Initialize n-best list of hypotheses. | ||
996 | + """ | ||
997 | + self.max_length = max_length - 1 # ignoring bos_token | ||
998 | + self.length_penalty = length_penalty | ||
999 | + self.early_stopping = early_stopping | ||
1000 | + self.num_beams = num_beams | ||
1001 | + self.beams = [] | ||
1002 | + self.worst_score = 1e9 | ||
1003 | + | ||
1004 | + def __len__(self): | ||
1005 | + """ | ||
1006 | + Number of hypotheses in the list. | ||
1007 | + """ | ||
1008 | + return len(self.beams) | ||
1009 | + | ||
1010 | + def add(self, hyp, sum_logprobs): | ||
1011 | + """ | ||
1012 | + Add a new hypothesis to the list. | ||
1013 | + """ | ||
1014 | + score = sum_logprobs / len(hyp) ** self.length_penalty | ||
1015 | + if len(self) < self.num_beams or score > self.worst_score: | ||
1016 | + self.beams.append((score, hyp)) | ||
1017 | + if len(self) > self.num_beams: | ||
1018 | + sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) | ||
1019 | + del self.beams[sorted_scores[0][1]] | ||
1020 | + self.worst_score = sorted_scores[1][0] | ||
1021 | + else: | ||
1022 | + self.worst_score = min(score, self.worst_score) | ||
1023 | + | ||
1024 | + def is_done(self, best_sum_logprobs, cur_len): | ||
1025 | + """ | ||
1026 | + If there are enough hypotheses and none of the hypotheses being generated | ||
1027 | + can become better than the worst one in the heap, then we are done with this sentence. | ||
1028 | + """ | ||
1029 | + | ||
1030 | + if len(self) < self.num_beams: | ||
1031 | + return False | ||
1032 | + elif self.early_stopping: | ||
1033 | + return True | ||
1034 | + else: | ||
1035 | + cur_score = best_sum_logprobs / cur_len ** self.length_penalty | ||
1036 | + ret = self.worst_score >= cur_score | ||
1037 | + return ret |
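The sampling path in `_generate_no_beam_search` combines `top_k_top_p_filtering` with a softmax and `torch.multinomial`. A short usage sketch of the helper defined in this new file, with invented logits (assuming the file is importable as `generation_utils`)::

    import torch
    from torch.nn import functional as F
    from generation_utils import top_k_top_p_filtering

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # (batch=1, vocab=5), made-up values

    # keep the 3 best tokens, then drop anything outside the 0.9 probability nucleus
    filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.9)

    probs = F.softmax(filtered, dim=-1)       # filtered positions get probability 0
    next_token = torch.multinomial(probs, 1)  # sample one token id, as the no-beam-search loop does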
... | @@ -21,6 +21,8 @@ from transformers import ( | ... | @@ -21,6 +21,8 @@ from transformers import ( |
21 | PretrainedConfig, | 21 | PretrainedConfig, |
22 | PreTrainedTokenizer, | 22 | PreTrainedTokenizer, |
23 | ) | 23 | ) |
24 | +from modeling_bart import BartForConditionalGeneration | ||
25 | + | ||
24 | from transformers.optimization import ( | 26 | from transformers.optimization import ( |
25 | Adafactor, | 27 | Adafactor, |
26 | get_cosine_schedule_with_warmup, | 28 | get_cosine_schedule_with_warmup, |
... | @@ -40,7 +42,7 @@ MODEL_MODES = { | ... | @@ -40,7 +42,7 @@ MODEL_MODES = { |
40 | "pretraining": AutoModelForPreTraining, | 42 | "pretraining": AutoModelForPreTraining, |
41 | "token-classification": AutoModelForTokenClassification, | 43 | "token-classification": AutoModelForTokenClassification, |
42 | "language-modeling": AutoModelWithLMHead, | 44 | "language-modeling": AutoModelWithLMHead, |
43 | - "summarization": AutoModelForSeq2SeqLM, | 45 | + "summarization": BartForConditionalGeneration, |
44 | "translation": AutoModelForSeq2SeqLM, | 46 | "translation": AutoModelForSeq2SeqLM, |
45 | } | 47 | } |
46 | 48 | ... | ... |
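With this change, the "summarization" mode resolves to the locally modified `BartForConditionalGeneration` (imported from `modeling_bart` above) instead of `AutoModelForSeq2SeqLM`. A hypothetical lookup mirroring how the training script is expected to pick its model class (the exact call site is outside this diff)::

    from modeling_bart import BartForConditionalGeneration

    mode = "summarization"
    model_class = MODEL_MODES[mode]  # now the local BartForConditionalGeneration
    model = model_class.from_pretrained("facebook/bart-large-cnn")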
modeling_bart.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. | ||
3 | +# | ||
4 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | +# you may not use this file except in compliance with the License. | ||
6 | +# You may obtain a copy of the License at | ||
7 | +# | ||
8 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | +# | ||
10 | +# Unless required by applicable law or agreed to in writing, software | ||
11 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | +# See the License for the specific language governing permissions and | ||
14 | +# limitations under the License. | ||
15 | +"""PyTorch BART model, ported from the fairseq repo.""" | ||
16 | +import math | ||
17 | +import random | ||
18 | +import warnings | ||
19 | +from typing import Dict, List, Optional, Tuple | ||
20 | + | ||
21 | +import numpy as np | ||
22 | +import torch | ||
23 | +import torch.nn.functional as F | ||
24 | +from torch import Tensor, nn | ||
25 | +from torch.nn import CrossEntropyLoss | ||
26 | + | ||
27 | +from transformers.activations import ACT2FN | ||
28 | +from transformers.configuration_bart import BartConfig | ||
29 | +from transformers.file_utils import ( | ||
30 | + add_code_sample_docstrings, | ||
31 | + add_end_docstrings, | ||
32 | + add_start_docstrings, | ||
33 | + add_start_docstrings_to_callable, | ||
34 | + replace_return_docstrings, | ||
35 | +) | ||
36 | +from transformers.modeling_outputs import ( | ||
37 | + BaseModelOutput, | ||
38 | + BaseModelOutputWithPast, | ||
39 | + Seq2SeqLMOutput, | ||
40 | + Seq2SeqModelOutput, | ||
41 | + Seq2SeqQuestionAnsweringModelOutput, | ||
42 | + Seq2SeqSequenceClassifierOutput, | ||
43 | +) | ||
44 | +from modeling_utils import PreTrainedModel | ||
45 | +import logging | ||
46 | + | ||
47 | +logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
48 | +logging.basicConfig( | ||
49 | + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
50 | + datefmt="%m/%d/%Y %H:%M:%S", | ||
51 | + level=logging.INFO, | ||
52 | +) | ||
53 | + | ||
54 | +_CONFIG_FOR_DOC = "BartConfig" | ||
55 | +_TOKENIZER_FOR_DOC = "BartTokenizer" | ||
56 | + | ||
57 | + | ||
58 | +BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ | ||
59 | + "facebook/bart-base", | ||
60 | + "facebook/bart-large", | ||
61 | + "facebook/bart-large-mnli", | ||
62 | + "facebook/bart-large-cnn", | ||
63 | + "facebook/bart-large-xsum", | ||
64 | + "facebook/mbart-large-en-ro", | ||
65 | +] | ||
66 | +# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart | ||
67 | + | ||
68 | + | ||
69 | +BART_START_DOCSTRING = r""" | ||
70 | + | ||
71 | + This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and | ||
72 | + refer to the PyTorch documentation for all matters related to general usage and behavior. | ||
73 | + | ||
74 | + Parameters: | ||
75 | + config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model. | ||
76 | + Initializing with a config file does not load the weights associated with the model, only the configuration. | ||
77 | + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. | ||
78 | + | ||
79 | +""" | ||
80 | +BART_GENERATION_EXAMPLE = r""" | ||
81 | + Summarization example:: | ||
82 | + | ||
83 | + from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig | ||
84 | + | ||
85 | + # see ``examples/summarization/bart/run_eval.py`` for a longer example | ||
86 | + model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn') | ||
87 | + tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn') | ||
88 | + | ||
89 | + ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." | ||
90 | + inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') | ||
91 | + | ||
92 | + # Generate Summary | ||
93 | + summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) | ||
94 | + print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) | ||
95 | + | ||
96 | +""" | ||
97 | + | ||
98 | +BART_INPUTS_DOCSTRING = r""" | ||
99 | + Args: | ||
100 | + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | ||
101 | + Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them. | ||
102 | + Padding will be ignored by default should you provide it. | ||
103 | + Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`. | ||
104 | + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
105 | + Mask to avoid performing attention on padding token indices in input_ids. | ||
106 | + Mask values selected in ``[0, 1]``: | ||
107 | + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. | ||
108 | + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): | ||
109 | + Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) | ||
110 | + `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. | ||
111 | + Used in the cross-attention of the decoder. | ||
112 | + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): | ||
113 | + Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper. | ||
114 | + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): | ||
115 | + Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. | ||
116 | + If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. | ||
117 | + See diagram 1 in the paper for more info on the default strategy | ||
118 | + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | ||
119 | + Contains pre-computed key and value hidden-states of the attention blocks. | ||
120 | + Can be used to speed up decoding. | ||
121 | + If ``past_key_values`` are used, the user can optionally input only the last | ||
122 | + ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape | ||
123 | + :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. | ||
124 | + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
125 | + If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see | ||
126 | + ``past_key_values``). | ||
127 | + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
128 | + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. | ||
129 | + output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
130 | + If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. | ||
131 | + return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
132 | + If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a | ||
133 | + plain tuple. | ||
134 | +""" | ||
135 | + | ||
136 | + | ||
137 | +def invert_mask(attention_mask): | ||
138 | + """Turns 1->0, 0->1, False->True, True-> False""" | ||
139 | + assert attention_mask.dim() == 2 | ||
140 | + return attention_mask.eq(0) | ||
141 | + | ||
142 | + | ||
143 | +def _prepare_bart_decoder_inputs( | ||
144 | + config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32 | ||
145 | +): | ||
146 | + """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if | ||
147 | + none are provided. This mimics the default behavior in fairseq. To override it pass in masks. | ||
148 | + Note: this is not called during generation | ||
149 | + """ | ||
150 | + pad_token_id = config.pad_token_id | ||
151 | + if decoder_input_ids is None: | ||
152 | + decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) | ||
153 | + bsz, tgt_len = decoder_input_ids.size() | ||
154 | + if decoder_padding_mask is None: | ||
155 | + decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) | ||
156 | + else: | ||
157 | + decoder_padding_mask = invert_mask(decoder_padding_mask) | ||
158 | + if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1: | ||
159 | + # never mask leading token, even if it is pad | ||
160 | + decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1] | ||
161 | + causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( | ||
162 | + dtype=causal_mask_dtype, device=decoder_input_ids.device | ||
163 | + ) | ||
164 | + return decoder_input_ids, decoder_padding_mask, causal_mask | ||
165 | + | ||
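For training (the non-cached path) this helper fabricates everything the decoder needs when only encoder inputs are given: the target ids shifted one position to the right, a padding mask over those ids, and an upper-triangular causal mask filled with -inf above the diagonal. A small hedged sketch of the returned shapes with toy values (only `pad_token_id` is read from the config here):

# Hypothetical shape check for _prepare_bart_decoder_inputs (toy values).
import torch
from types import SimpleNamespace

config = SimpleNamespace(pad_token_id=1)
input_ids = torch.tensor([[0, 7, 8, 2, 1]])                 # <s> ... </s> <pad>

dec_ids, dec_pad_mask, causal = _prepare_bart_decoder_inputs(config, input_ids)
print(dec_ids)           # tensor([[2, 0, 7, 8, 2]]) -- input_ids shifted right
print(dec_pad_mask)      # None here: the shifted ids contain no padding
print(causal.shape)      # torch.Size([5, 5]), -inf above the diagonal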
166 | + | ||
167 | +class PretrainedBartModel(PreTrainedModel): | ||
168 | + config_class = BartConfig | ||
169 | + base_model_prefix = "model" | ||
170 | + | ||
171 | + def _init_weights(self, module): | ||
172 | + std = self.config.init_std | ||
173 | + if isinstance(module, nn.Linear): | ||
174 | + module.weight.data.normal_(mean=0.0, std=std) | ||
175 | + if module.bias is not None: | ||
176 | + module.bias.data.zero_() | ||
177 | + elif isinstance(module, SinusoidalPositionalEmbedding): | ||
178 | + pass | ||
179 | + elif isinstance(module, nn.Embedding): | ||
180 | + module.weight.data.normal_(mean=0.0, std=std) | ||
181 | + if module.padding_idx is not None: | ||
182 | + module.weight.data[module.padding_idx].zero_() | ||
183 | + | ||
184 | + @property | ||
185 | + def dummy_inputs(self): | ||
186 | + pad_token = self.config.pad_token_id | ||
187 | + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) | ||
188 | + dummy_inputs = { | ||
189 | + "attention_mask": input_ids.ne(pad_token), | ||
190 | + "input_ids": input_ids, | ||
191 | + } | ||
192 | + return dummy_inputs | ||
193 | + | ||
194 | + | ||
195 | +def _make_linear_from_emb(emb): | ||
196 | + vocab_size, emb_size = emb.weight.shape | ||
197 | + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) | ||
198 | + lin_layer.weight.data = emb.weight.data | ||
199 | + return lin_layer | ||
200 | + | ||
201 | + | ||
202 | +# Helper Functions, mostly for making masks | ||
203 | +def _check_shapes(shape_1, shape2): | ||
204 | + if shape_1 != shape2: | ||
205 | + raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) | ||
206 | + | ||
207 | + | ||
208 | +def shift_tokens_right(input_ids, pad_token_id): | ||
209 | + """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).""" | ||
210 | + prev_output_tokens = input_ids.clone() | ||
211 | + index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) | ||
212 | + prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() | ||
213 | + prev_output_tokens[:, 1:] = input_ids[:, :-1] | ||
214 | + return prev_output_tokens | ||
215 | + | ||
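Concretely, `shift_tokens_right` moves every token one position to the right and places the last non-pad token (normally `</s>`) in front, which is the decoder start token BART was trained with. A worked toy example (ids are invented, pad id 1, eos id 2):

# Hypothetical worked example for shift_tokens_right (toy ids).
import torch

input_ids = torch.tensor([[0, 5, 6, 2, 1]])        # <s> a b </s> <pad>
print(shift_tokens_right(input_ids, pad_token_id=1))
# tensor([[2, 0, 5, 6, 2]])  i.e. </s> <s> a b </s>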
216 | + | ||
217 | +def make_padding_mask(input_ids, padding_idx=1): | ||
218 | + """True for pad tokens""" | ||
219 | + padding_mask = input_ids.eq(padding_idx) | ||
220 | + if not padding_mask.any(): | ||
221 | + padding_mask = None | ||
222 | + return padding_mask | ||
223 | + | ||
224 | + | ||
225 | +# Helper Modules | ||
226 | + | ||
227 | + | ||
228 | +class EncoderLayer(nn.Module): | ||
229 | + def __init__(self, config: BartConfig): | ||
230 | + super().__init__() | ||
231 | + self.embed_dim = config.d_model | ||
232 | + self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout) | ||
233 | + self.normalize_before = config.normalize_before | ||
234 | + self.self_attn_layer_norm = LayerNorm(self.embed_dim) | ||
235 | + self.dropout = config.dropout | ||
236 | + self.activation_fn = ACT2FN[config.activation_function] | ||
237 | + self.activation_dropout = config.activation_dropout | ||
238 | + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) | ||
239 | + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) | ||
240 | + self.final_layer_norm = LayerNorm(self.embed_dim) | ||
241 | + | ||
242 | + def forward(self, x, encoder_padding_mask, output_attentions=False): | ||
243 | + """ | ||
244 | + Args: | ||
245 | + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` | ||
246 | + encoder_padding_mask (ByteTensor): binary ByteTensor of shape | ||
247 | + `(batch, src_len)` where padding elements are indicated by ``1``. | ||
248 | + Positions marked ``1`` are excluded (masked out) from attention, | ||
249 | + while positions marked ``0`` are included. | ||
250 | + | ||
251 | + Returns: | ||
252 | + encoded output of shape `(seq_len, batch, embed_dim)` | ||
253 | + """ | ||
254 | + residual = x | ||
255 | + if self.normalize_before: | ||
256 | + x = self.self_attn_layer_norm(x) | ||
257 | + x, attn_weights = self.self_attn( | ||
258 | + query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions | ||
259 | + ) | ||
260 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
261 | + x = residual + x | ||
262 | + if not self.normalize_before: | ||
263 | + x = self.self_attn_layer_norm(x) | ||
264 | + | ||
265 | + residual = x | ||
266 | + if self.normalize_before: | ||
267 | + x = self.final_layer_norm(x) | ||
268 | + x = self.activation_fn(self.fc1(x)) | ||
269 | + x = F.dropout(x, p=self.activation_dropout, training=self.training) | ||
270 | + x = self.fc2(x) | ||
271 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
272 | + x = residual + x | ||
273 | + if not self.normalize_before: | ||
274 | + x = self.final_layer_norm(x) | ||
275 | + return x, attn_weights | ||
276 | + | ||
277 | + | ||
278 | +class BartEncoder(nn.Module): | ||
279 | + """ | ||
280 | + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer | ||
281 | + is a :class:`EncoderLayer`. | ||
282 | + | ||
283 | + Args: | ||
284 | + config: BartConfig | ||
285 | + """ | ||
286 | + | ||
287 | + def __init__(self, config: BartConfig, embed_tokens): | ||
288 | + super().__init__() | ||
289 | + | ||
290 | + self.dropout = config.dropout | ||
291 | + self.layerdrop = config.encoder_layerdrop | ||
292 | + | ||
293 | + embed_dim = embed_tokens.embedding_dim | ||
294 | + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 | ||
295 | + self.padding_idx = embed_tokens.padding_idx | ||
296 | + self.max_source_positions = config.max_position_embeddings | ||
297 | + | ||
298 | + self.embed_tokens = embed_tokens | ||
299 | + if config.static_position_embeddings: | ||
300 | + self.embed_positions = SinusoidalPositionalEmbedding( | ||
301 | + config.max_position_embeddings, embed_dim, self.padding_idx | ||
302 | + ) | ||
303 | + else: | ||
304 | + self.embed_positions = LearnedPositionalEmbedding( | ||
305 | + config.max_position_embeddings, | ||
306 | + embed_dim, | ||
307 | + self.padding_idx, | ||
308 | + config.extra_pos_embeddings, | ||
309 | + ) | ||
310 | + self.embed_patches = nn.Embedding(3, config.d_model) | ||
311 | + self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) | ||
312 | + self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() | ||
313 | + # mbart has one extra layer_norm | ||
314 | + self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None | ||
315 | + | ||
316 | + def forward( | ||
317 | + self, input_ids, patch_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False | ||
318 | + ): | ||
319 | + """ | ||
320 | + Args: | ||
321 | + input_ids (LongTensor): tokens in the source language of shape | ||
322 | + `(batch, src_len)` | ||
323 | + attention_mask (torch.LongTensor): indicating which indices are padding tokens. | ||
324 | + Returns: | ||
325 | + BaseModelOutput or Tuple comprised of: | ||
326 | + - **x** (Tensor): the last encoder layer's output of | ||
327 | + shape `(src_len, batch, embed_dim)` | ||
328 | + - **encoder_states** (tuple(torch.FloatTensor)): all intermediate | ||
329 | + hidden states of shape `(src_len, batch, embed_dim)`. | ||
330 | + Only populated if *output_hidden_states:* is True. | ||
331 | + - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer. | ||
332 | + During training might not be of length n_layers because of layer dropout. | ||
333 | + """ | ||
334 | + # check attention mask and invert | ||
335 | + if attention_mask is not None: | ||
336 | + attention_mask = invert_mask(attention_mask) | ||
337 | + | ||
338 | + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale | ||
339 | + embed_pos = self.embed_positions(input_ids) | ||
340 | + embed_patch = self.embed_patches(patch_ids) | ||
341 | + x = inputs_embeds + embed_pos + embed_patch | ||
342 | + x = self.layernorm_embedding(x) | ||
343 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
344 | + | ||
345 | + # B x T x C -> T x B x C | ||
346 | + x = x.transpose(0, 1) | ||
347 | + | ||
348 | + encoder_states = [] if output_hidden_states else None | ||
349 | + all_attentions = () if output_attentions else None | ||
350 | + for encoder_layer in self.layers: | ||
351 | + if output_hidden_states: | ||
352 | + encoder_states.append(x) | ||
353 | + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | ||
354 | + dropout_probability = random.uniform(0, 1) | ||
355 | + if self.training and (dropout_probability < self.layerdrop): # skip the layer | ||
356 | + attn = None | ||
357 | + else: | ||
358 | + x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions) | ||
359 | + | ||
360 | + if output_attentions: | ||
361 | + all_attentions = all_attentions + (attn,) | ||
362 | + | ||
363 | + if self.layer_norm: | ||
364 | + x = self.layer_norm(x) | ||
365 | + if output_hidden_states: | ||
366 | + encoder_states.append(x) | ||
367 | + # T x B x C -> B x T x C | ||
368 | + encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states) | ||
369 | + | ||
370 | + # T x B x C -> B x T x C | ||
371 | + x = x.transpose(0, 1) | ||
372 | + | ||
373 | + if not return_dict: | ||
374 | + return tuple(v for v in [x, encoder_states, all_attentions] if v is not None) | ||
375 | + return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) | ||
376 | + | ||
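The encoder is the heart of this change: on top of the token and position embeddings it now adds a learned patch embedding (`nn.Embedding(3, d_model)`), so `patch_ids` must have the same shape as `input_ids` with values in {0, 1, 2} (how the three ids are assigned, e.g. to context versus changed lines, is up to the data preparation and not shown in this file). A hedged sketch of a standalone encoder call with a toy config and random weights:

# Hypothetical standalone encoder call (toy config, random weights).
import torch
from transformers.configuration_bart import BartConfig

config = BartConfig(vocab_size=100, d_model=32, encoder_layers=2,
                    encoder_attention_heads=4, encoder_ffn_dim=64,
                    max_position_embeddings=64)
embed_tokens = torch.nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)
encoder = BartEncoder(config, embed_tokens)

input_ids = torch.tensor([[0, 7, 8, 9, 2]])
patch_ids = torch.tensor([[0, 1, 1, 2, 0]])        # one id per token, values in {0, 1, 2}
attention_mask = input_ids.ne(config.pad_token_id).long()
out = encoder(input_ids, patch_ids, attention_mask=attention_mask)
print(out[0].shape)      # torch.Size([1, 5, 32]): batch x src_len x d_model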
377 | + | ||
378 | +class DecoderLayer(nn.Module): | ||
379 | + def __init__(self, config: BartConfig): | ||
380 | + super().__init__() | ||
381 | + self.embed_dim = config.d_model | ||
382 | + | ||
383 | + self.self_attn = Attention( | ||
384 | + embed_dim=self.embed_dim, | ||
385 | + num_heads=config.decoder_attention_heads, | ||
386 | + dropout=config.attention_dropout, | ||
387 | + ) | ||
388 | + self.dropout = config.dropout | ||
389 | + self.activation_fn = ACT2FN[config.activation_function] | ||
390 | + self.activation_dropout = config.activation_dropout | ||
391 | + self.normalize_before = config.normalize_before | ||
392 | + | ||
393 | + self.self_attn_layer_norm = LayerNorm(self.embed_dim) | ||
394 | + self.encoder_attn = Attention( | ||
395 | + self.embed_dim, | ||
396 | + config.decoder_attention_heads, | ||
397 | + dropout=config.attention_dropout, | ||
398 | + encoder_decoder_attention=True, | ||
399 | + ) | ||
400 | + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) | ||
401 | + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) | ||
402 | + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) | ||
403 | + self.final_layer_norm = LayerNorm(self.embed_dim) | ||
404 | + | ||
405 | + def forward( | ||
406 | + self, | ||
407 | + x, | ||
408 | + encoder_hidden_states, | ||
409 | + encoder_attn_mask=None, | ||
410 | + layer_state=None, | ||
411 | + causal_mask=None, | ||
412 | + decoder_padding_mask=None, | ||
413 | + output_attentions=False, | ||
414 | + ): | ||
415 | + residual = x | ||
416 | + | ||
417 | + if layer_state is None: | ||
418 | + layer_state = {} | ||
419 | + if self.normalize_before: | ||
420 | + x = self.self_attn_layer_norm(x) | ||
421 | + # Self Attention | ||
422 | + | ||
423 | + x, self_attn_weights = self.self_attn( | ||
424 | + query=x, | ||
425 | + key=x, | ||
426 | + layer_state=layer_state, # adds keys to layer state | ||
427 | + key_padding_mask=decoder_padding_mask, | ||
428 | + attn_mask=causal_mask, | ||
429 | + output_attentions=output_attentions, | ||
430 | + ) | ||
431 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
432 | + x = residual + x | ||
433 | + if not self.normalize_before: | ||
434 | + x = self.self_attn_layer_norm(x) | ||
435 | + | ||
436 | + # Cross attention | ||
437 | + residual = x | ||
438 | + assert self.encoder_attn.cache_key != self.self_attn.cache_key | ||
439 | + if self.normalize_before: | ||
440 | + x = self.encoder_attn_layer_norm(x) | ||
441 | + x, _ = self.encoder_attn( | ||
442 | + query=x, | ||
443 | + key=encoder_hidden_states, | ||
444 | + key_padding_mask=encoder_attn_mask, | ||
445 | + layer_state=layer_state, # mutates layer state | ||
446 | + ) | ||
447 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
448 | + x = residual + x | ||
449 | + if not self.normalize_before: | ||
450 | + x = self.encoder_attn_layer_norm(x) | ||
451 | + | ||
452 | + # Fully Connected | ||
453 | + residual = x | ||
454 | + if self.normalize_before: | ||
455 | + x = self.final_layer_norm(x) | ||
456 | + x = self.activation_fn(self.fc1(x)) | ||
457 | + x = F.dropout(x, p=self.activation_dropout, training=self.training) | ||
458 | + x = self.fc2(x) | ||
459 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
460 | + x = residual + x | ||
461 | + if not self.normalize_before: | ||
462 | + x = self.final_layer_norm(x) | ||
463 | + return ( | ||
464 | + x, | ||
465 | + self_attn_weights, | ||
466 | + layer_state, | ||
467 | + ) # just self_attn weights for now, following t5, layer_state = cache for decoding | ||
468 | + | ||
469 | + | ||
470 | +class BartDecoder(nn.Module): | ||
471 | + """ | ||
472 | + Transformer decoder consisting of *config.decoder_layers* layers. Each layer | ||
473 | + is a :class:`DecoderLayer`. | ||
474 | + Args: | ||
475 | + config: BartConfig | ||
476 | + embed_tokens (torch.nn.Embedding): output embedding | ||
477 | + """ | ||
478 | + | ||
479 | + def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): | ||
480 | + super().__init__() | ||
481 | + self.dropout = config.dropout | ||
482 | + self.layerdrop = config.decoder_layerdrop | ||
483 | + self.padding_idx = embed_tokens.padding_idx | ||
484 | + self.max_target_positions = config.max_position_embeddings | ||
485 | + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 | ||
486 | + self.embed_tokens = embed_tokens | ||
487 | + if config.static_position_embeddings: | ||
488 | + self.embed_positions = SinusoidalPositionalEmbedding( | ||
489 | + config.max_position_embeddings, config.d_model, config.pad_token_id | ||
490 | + ) | ||
491 | + else: | ||
492 | + self.embed_positions = LearnedPositionalEmbedding( | ||
493 | + config.max_position_embeddings, | ||
494 | + config.d_model, | ||
495 | + self.padding_idx, | ||
496 | + config.extra_pos_embeddings, | ||
497 | + ) | ||
498 | + self.layers = nn.ModuleList( | ||
499 | + [DecoderLayer(config) for _ in range(config.decoder_layers)] | ||
500 | + ) # type: List[DecoderLayer] | ||
501 | + self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() | ||
502 | + self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None | ||
503 | + | ||
504 | + def forward( | ||
505 | + self, | ||
506 | + input_ids, | ||
507 | + encoder_hidden_states, | ||
508 | + encoder_padding_mask, | ||
509 | + decoder_padding_mask, | ||
510 | + decoder_causal_mask, | ||
511 | + past_key_values=None, | ||
512 | + use_cache=False, | ||
513 | + output_attentions=False, | ||
514 | + output_hidden_states=False, | ||
515 | + return_dict=False, | ||
516 | + **unused, | ||
517 | + ): | ||
518 | + """ | ||
519 | + Includes several features from "Jointly Learning to Align and | ||
520 | + Translate with Transformer Models" (Garg et al., EMNLP 2019). | ||
521 | + | ||
522 | + Args: | ||
523 | + input_ids (LongTensor): previous decoder outputs of shape | ||
524 | + `(batch, tgt_len)`, for teacher forcing | ||
525 | + encoder_hidden_states: output from the encoder, used for | ||
526 | + encoder-side attention | ||
527 | + encoder_padding_mask: for ignoring pad tokens | ||
528 | + past_key_values (dict or None): dictionary used for storing state during generation | ||
529 | + | ||
530 | + Returns: | ||
531 | + BaseModelOutputWithPast or tuple: | ||
532 | + - the decoder's features of shape `(batch, tgt_len, embed_dim)` | ||
533 | + - the cache | ||
534 | + - hidden states | ||
535 | + - attentions | ||
536 | + """ | ||
537 | + if "decoder_cached_states" in unused: | ||
538 | + warnings.warn( | ||
539 | + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
540 | + FutureWarning, | ||
541 | + ) | ||
542 | + past_key_values = unused.pop("decoder_cached_states") | ||
543 | + if "decoder_past_key_values" in unused: | ||
544 | + warnings.warn( | ||
545 | + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
546 | + FutureWarning, | ||
547 | + ) | ||
548 | + past_key_values = unused.pop("decoder_past_key_values") | ||
549 | + | ||
550 | + # check attention mask and invert | ||
551 | + if encoder_padding_mask is not None: | ||
552 | + encoder_padding_mask = invert_mask(encoder_padding_mask) | ||
553 | + | ||
554 | + # embed positions | ||
555 | + positions = self.embed_positions(input_ids, use_cache=use_cache) | ||
556 | + | ||
557 | + if use_cache: | ||
558 | + input_ids = input_ids[:, -1:] | ||
559 | + positions = positions[:, -1:] # happens after we embed them | ||
560 | + # assert input_ids.ne(self.padding_idx).any() | ||
561 | + | ||
562 | + x = self.embed_tokens(input_ids) * self.embed_scale | ||
563 | + x += positions | ||
564 | + x = self.layernorm_embedding(x) | ||
565 | + x = F.dropout(x, p=self.dropout, training=self.training) | ||
566 | + | ||
567 | + # Convert to attention format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim) | ||
568 | + x = x.transpose(0, 1) | ||
569 | + encoder_hidden_states = encoder_hidden_states.transpose(0, 1) | ||
570 | + | ||
571 | + # decoder layers | ||
572 | + all_hidden_states = () if output_hidden_states else None | ||
573 | + all_self_attns = () if output_attentions else None | ||
574 | + next_decoder_cache = [] | ||
575 | + for idx, decoder_layer in enumerate(self.layers): | ||
576 | + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | ||
577 | + if output_hidden_states: | ||
578 | + all_hidden_states += (x,) | ||
579 | + dropout_probability = random.uniform(0, 1) | ||
580 | + if self.training and (dropout_probability < self.layerdrop): | ||
581 | + continue | ||
582 | + | ||
583 | + layer_state = past_key_values[idx] if past_key_values is not None else None | ||
584 | + | ||
585 | + x, layer_self_attn, layer_past = decoder_layer( | ||
586 | + x, | ||
587 | + encoder_hidden_states, | ||
588 | + encoder_attn_mask=encoder_padding_mask, | ||
589 | + decoder_padding_mask=decoder_padding_mask, | ||
590 | + layer_state=layer_state, | ||
591 | + causal_mask=decoder_causal_mask, | ||
592 | + output_attentions=output_attentions, | ||
593 | + ) | ||
594 | + | ||
595 | + if use_cache: | ||
596 | + next_decoder_cache.append(layer_past.copy()) | ||
597 | + | ||
598 | + if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART) | ||
599 | + x = self.layer_norm(x) | ||
600 | + if output_attentions: | ||
601 | + all_self_attns += (layer_self_attn,) | ||
602 | + | ||
603 | + # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) | ||
604 | + if output_hidden_states: | ||
605 | + all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states) | ||
606 | + x = x.transpose(0, 1) | ||
607 | + encoder_hidden_states = encoder_hidden_states.transpose(0, 1) | ||
608 | + | ||
609 | + next_cache = next_decoder_cache if use_cache else None | ||
610 | + | ||
611 | + if not return_dict: | ||
612 | + return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) | ||
613 | + return BaseModelOutputWithPast( | ||
614 | + last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns | ||
615 | + ) | ||
616 | + | ||
617 | + | ||
618 | +def _reorder_buffer(attn_cache, new_order): | ||
619 | + for k, input_buffer_k in attn_cache.items(): | ||
620 | + if input_buffer_k is not None: | ||
621 | + attn_cache[k] = input_buffer_k.index_select(0, new_order) | ||
622 | + return attn_cache | ||
623 | + | ||
624 | + | ||
625 | +class Attention(nn.Module): | ||
626 | + """Multi-headed attention from 'Attention Is All You Need' paper""" | ||
627 | + | ||
628 | + def __init__( | ||
629 | + self, | ||
630 | + embed_dim, | ||
631 | + num_heads, | ||
632 | + dropout=0.0, | ||
633 | + bias=True, | ||
634 | + encoder_decoder_attention=False, # otherwise self_attention | ||
635 | + ): | ||
636 | + super().__init__() | ||
637 | + self.embed_dim = embed_dim | ||
638 | + self.num_heads = num_heads | ||
639 | + self.dropout = dropout | ||
640 | + self.head_dim = embed_dim // num_heads | ||
641 | + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" | ||
642 | + self.scaling = self.head_dim ** -0.5 | ||
643 | + | ||
644 | + self.encoder_decoder_attention = encoder_decoder_attention | ||
645 | + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
646 | + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
647 | + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
648 | + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
649 | + self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" | ||
650 | + | ||
651 | + def _shape(self, tensor, seq_len, bsz): | ||
652 | + return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) | ||
653 | + | ||
654 | + def forward( | ||
655 | + self, | ||
656 | + query, | ||
657 | + key: Optional[Tensor], | ||
658 | + key_padding_mask: Optional[Tensor] = None, | ||
659 | + layer_state: Optional[Dict[str, Optional[Tensor]]] = None, | ||
660 | + attn_mask: Optional[Tensor] = None, | ||
661 | + output_attentions=False, | ||
662 | + ) -> Tuple[Tensor, Optional[Tensor]]: | ||
663 | + """Input shape: Time(SeqLen) x Batch x Channel""" | ||
664 | + static_kv: bool = self.encoder_decoder_attention | ||
665 | + tgt_len, bsz, embed_dim = query.size() | ||
666 | + assert embed_dim == self.embed_dim | ||
667 | + assert list(query.size()) == [tgt_len, bsz, embed_dim] | ||
668 | + # get here for encoder decoder cause of static_kv | ||
669 | + if layer_state is not None: # reuse k,v and encoder_padding_mask | ||
670 | + saved_state = layer_state.get(self.cache_key, {}) | ||
671 | + if "prev_key" in saved_state and static_kv: | ||
672 | + # previous time steps are cached - no need to recompute key and value if they are static | ||
673 | + key = None | ||
674 | + else: | ||
675 | + saved_state = None | ||
676 | + layer_state = {} | ||
677 | + | ||
678 | + q = self.q_proj(query) * self.scaling | ||
679 | + if static_kv: | ||
680 | + if key is None: | ||
681 | + k = v = None | ||
682 | + else: | ||
683 | + k = self.k_proj(key) | ||
684 | + v = self.v_proj(key) | ||
685 | + else: | ||
686 | + k = self.k_proj(query) | ||
687 | + v = self.v_proj(query) | ||
688 | + | ||
689 | + q = self._shape(q, tgt_len, bsz) | ||
690 | + if k is not None: | ||
691 | + k = self._shape(k, -1, bsz) | ||
692 | + if v is not None: | ||
693 | + v = self._shape(v, -1, bsz) | ||
694 | + | ||
695 | + if saved_state is not None: | ||
696 | + k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) | ||
697 | + | ||
698 | + # Update cache | ||
699 | + layer_state[self.cache_key] = { | ||
700 | + "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), | ||
701 | + "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), | ||
702 | + "prev_key_padding_mask": key_padding_mask if not static_kv else None, | ||
703 | + } | ||
704 | + | ||
705 | + assert k is not None | ||
706 | + src_len = k.size(1) | ||
707 | + attn_weights = torch.bmm(q, k.transpose(1, 2)) | ||
708 | + assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) | ||
709 | + | ||
710 | + if attn_mask is not None: | ||
711 | + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask | ||
712 | + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | ||
713 | + | ||
714 | + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. | ||
715 | + if key_padding_mask is not None and key_padding_mask.dim() == 0: | ||
716 | + key_padding_mask = None | ||
717 | + assert key_padding_mask is None or key_padding_mask.size()[:2] == ( | ||
718 | + bsz, | ||
719 | + src_len, | ||
720 | + ) | ||
721 | + | ||
722 | + if key_padding_mask is not None: # don't attend to padding symbols | ||
723 | + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | ||
724 | + reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2) | ||
725 | + attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) | ||
726 | + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | ||
727 | + attn_weights = F.softmax(attn_weights, dim=-1) | ||
728 | + attn_probs = F.dropout( | ||
729 | + attn_weights, | ||
730 | + p=self.dropout, | ||
731 | + training=self.training, | ||
732 | + ) | ||
733 | + | ||
734 | + assert v is not None | ||
735 | + attn_output = torch.bmm(attn_probs, v) | ||
736 | + assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) | ||
737 | + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | ||
738 | + attn_output = self.out_proj(attn_output) | ||
739 | + if output_attentions: | ||
740 | + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | ||
741 | + else: | ||
742 | + attn_weights = None | ||
743 | + return attn_output, attn_weights | ||
744 | + | ||
745 | + def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): | ||
746 | + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | ||
747 | + if "prev_key" in saved_state: | ||
748 | + _prev_key = saved_state["prev_key"] | ||
749 | + assert _prev_key is not None | ||
750 | + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) | ||
751 | + if static_kv: | ||
752 | + k = prev_key | ||
753 | + else: | ||
754 | + assert k is not None | ||
755 | + k = torch.cat([prev_key, k], dim=1) | ||
756 | + if "prev_value" in saved_state: | ||
757 | + _prev_value = saved_state["prev_value"] | ||
758 | + assert _prev_value is not None | ||
759 | + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) | ||
760 | + if static_kv: | ||
761 | + v = prev_value | ||
762 | + else: | ||
763 | + assert v is not None | ||
764 | + v = torch.cat([prev_value, v], dim=1) | ||
765 | + assert k is not None and v is not None | ||
766 | + prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None) | ||
767 | + if prev_key_padding_mask is not None: | ||
768 | + if static_kv: | ||
769 | + new_key_padding_mask = prev_key_padding_mask | ||
770 | + else: | ||
771 | + new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1) | ||
772 | + else: | ||
773 | + new_key_padding_mask = key_padding_mask | ||
774 | + return k, v, new_key_padding_mask | ||
775 | + | ||
776 | + | ||
777 | +class BartClassificationHead(nn.Module): | ||
778 | + """Head for sentence-level classification tasks.""" | ||
779 | + | ||
780 | + # This can trivially be shared with RobertaClassificationHead | ||
781 | + | ||
782 | + def __init__( | ||
783 | + self, | ||
784 | + input_dim, | ||
785 | + inner_dim, | ||
786 | + num_classes, | ||
787 | + pooler_dropout, | ||
788 | + ): | ||
789 | + super().__init__() | ||
790 | + self.dense = nn.Linear(input_dim, inner_dim) | ||
791 | + self.dropout = nn.Dropout(p=pooler_dropout) | ||
792 | + self.out_proj = nn.Linear(inner_dim, num_classes) | ||
793 | + | ||
794 | + def forward(self, x): | ||
795 | + x = self.dropout(x) | ||
796 | + x = self.dense(x) | ||
797 | + x = torch.tanh(x) | ||
798 | + x = self.dropout(x) | ||
799 | + x = self.out_proj(x) | ||
800 | + return x | ||
801 | + | ||
802 | + | ||
803 | +class LearnedPositionalEmbedding(nn.Embedding): | ||
804 | + """ | ||
805 | + This module learns positional embeddings up to a fixed maximum size. | ||
806 | + Padding ids are ignored by either offsetting based on padding_idx | ||
807 | + or by setting padding_idx to None and ensuring that the appropriate | ||
808 | + position ids are passed to the forward function. | ||
809 | + """ | ||
810 | + | ||
811 | + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset): | ||
812 | + # Bart is set up so that if padding_idx is specified, the embedding ids are offset by 2 | ||
813 | + # and num_embeddings is adjusted accordingly. Other models don't have this hack. | ||
814 | + self.offset = offset | ||
815 | + assert padding_idx is not None | ||
816 | + num_embeddings += offset | ||
817 | + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) | ||
818 | + | ||
819 | + def forward(self, input_ids, use_cache=False): | ||
820 | + """Input is expected to be of size [bsz x seqlen].""" | ||
821 | + bsz, seq_len = input_ids.shape[:2] | ||
822 | + if use_cache: | ||
823 | + positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing | ||
824 | + else: | ||
825 | + # starts at 0, ends at seq_len - 1 | ||
826 | + positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) | ||
827 | + return super().forward(positions + self.offset) | ||
828 | + | ||
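Because of the BART-specific offset, position ids are shifted by `offset` (2 for BART) before the lookup, and in the cached generation path only the newest position is embedded. A small hedged sketch with toy sizes:

# Hypothetical sketch of the position-id offset (toy sizes).
import torch

pos_emb = LearnedPositionalEmbedding(num_embeddings=16, embedding_dim=8, padding_idx=1, offset=2)
input_ids = torch.zeros(1, 5, dtype=torch.long)
print(pos_emb(input_ids).shape)                    # torch.Size([5, 8]): rows 2..6 of the table
print(pos_emb(input_ids, use_cache=True).shape)    # torch.Size([1, 1, 8]): only the newest position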
829 | + | ||
830 | +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): | ||
831 | + if torch.cuda.is_available(): | ||
832 | + try: | ||
833 | + from apex.normalization import FusedLayerNorm | ||
834 | + | ||
835 | + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | ||
836 | + except ImportError: | ||
837 | + pass | ||
838 | + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | ||
839 | + | ||
840 | + | ||
841 | +def fill_with_neg_inf(t): | ||
842 | + """FP16-compatible function that fills a input_ids with -inf.""" | ||
843 | + return t.float().fill_(float("-inf")).type_as(t) | ||
844 | + | ||
845 | + | ||
846 | +# Public API | ||
847 | +def _get_shape(t): | ||
848 | + return getattr(t, "shape", None) | ||
849 | + | ||
850 | + | ||
851 | +@add_start_docstrings( | ||
852 | + "The bare BART Model outputting raw hidden-states without any specific head on top.", | ||
853 | + BART_START_DOCSTRING, | ||
854 | +) | ||
855 | +class BartModel(PretrainedBartModel): | ||
856 | + def __init__(self, config: BartConfig): | ||
857 | + super().__init__(config) | ||
858 | + | ||
859 | + padding_idx, vocab_size = config.pad_token_id, config.vocab_size | ||
860 | + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) | ||
861 | + | ||
862 | + self.encoder = BartEncoder(config, self.shared) | ||
863 | + self.decoder = BartDecoder(config, self.shared) | ||
864 | + | ||
865 | + self.init_weights() | ||
866 | + | ||
867 | + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
868 | + @add_code_sample_docstrings( | ||
869 | + tokenizer_class=_TOKENIZER_FOR_DOC, | ||
870 | + checkpoint="facebook/bart-large", | ||
871 | + output_type=BaseModelOutputWithPast, | ||
872 | + config_class=_CONFIG_FOR_DOC, | ||
873 | + ) | ||
874 | + def forward( | ||
875 | + self, | ||
876 | + input_ids, | ||
877 | + patch_ids=None, | ||
878 | + attention_mask=None, | ||
879 | + decoder_input_ids=None, | ||
880 | + encoder_outputs: Optional[Tuple] = None, | ||
881 | + decoder_attention_mask=None, | ||
882 | + past_key_values=None, | ||
883 | + use_cache=None, | ||
884 | + output_attentions=None, | ||
885 | + output_hidden_states=None, | ||
886 | + return_dict=None, | ||
887 | + **kwargs, | ||
888 | + ): | ||
889 | + if "decoder_past_key_values" in kwargs: | ||
890 | + warnings.warn( | ||
891 | + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
892 | + FutureWarning, | ||
893 | + ) | ||
894 | + past_key_values = kwargs.pop("decoder_past_key_values") | ||
895 | + | ||
896 | + if decoder_input_ids is None: | ||
897 | + use_cache = False | ||
898 | + | ||
899 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
900 | + output_hidden_states = ( | ||
901 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
902 | + ) | ||
903 | + use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
904 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
905 | + | ||
906 | + # make masks if user doesn't supply | ||
907 | + if not use_cache: | ||
908 | + decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( | ||
909 | + self.config, | ||
910 | + input_ids, | ||
911 | + decoder_input_ids=decoder_input_ids, | ||
912 | + decoder_padding_mask=decoder_attention_mask, | ||
913 | + causal_mask_dtype=self.shared.weight.dtype, | ||
914 | + ) | ||
915 | + else: | ||
916 | + decoder_padding_mask, causal_mask = None, None | ||
917 | + | ||
918 | + assert decoder_input_ids is not None | ||
919 | + | ||
920 | + if encoder_outputs is None: | ||
921 | + encoder_outputs = self.encoder( | ||
922 | + input_ids=input_ids, | ||
923 | + patch_ids=patch_ids, | ||
924 | + attention_mask=attention_mask, | ||
925 | + output_attentions=output_attentions, | ||
926 | + output_hidden_states=output_hidden_states, | ||
927 | + return_dict=return_dict, | ||
928 | + ) | ||
929 | + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True | ||
930 | + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): | ||
931 | + encoder_outputs = BaseModelOutput( | ||
932 | + last_hidden_state=encoder_outputs[0], | ||
933 | + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, | ||
934 | + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, | ||
935 | + ) | ||
936 | + | ||
937 | + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
938 | + decoder_outputs = self.decoder( | ||
939 | + decoder_input_ids, | ||
940 | + encoder_outputs[0], | ||
941 | + attention_mask, | ||
942 | + decoder_padding_mask, | ||
943 | + decoder_causal_mask=causal_mask, | ||
944 | + past_key_values=past_key_values, | ||
945 | + use_cache=use_cache, | ||
946 | + output_attentions=output_attentions, | ||
947 | + output_hidden_states=output_hidden_states, | ||
948 | + return_dict=return_dict, | ||
949 | + ) | ||
950 | + | ||
951 | + if not return_dict: | ||
952 | + return decoder_outputs + encoder_outputs | ||
953 | + | ||
954 | + return Seq2SeqModelOutput( | ||
955 | + last_hidden_state=decoder_outputs.last_hidden_state, | ||
956 | + past_key_values=decoder_outputs.past_key_values, | ||
957 | + decoder_hidden_states=decoder_outputs.hidden_states, | ||
958 | + decoder_attentions=decoder_outputs.attentions, | ||
959 | + encoder_last_hidden_state=encoder_outputs.last_hidden_state, | ||
960 | + encoder_hidden_states=encoder_outputs.hidden_states, | ||
961 | + encoder_attentions=encoder_outputs.attentions, | ||
962 | + ) | ||
963 | + | ||
964 | + def get_input_embeddings(self): | ||
965 | + return self.shared | ||
966 | + | ||
967 | + def set_input_embeddings(self, value): | ||
968 | + self.shared = value | ||
969 | + self.encoder.embed_tokens = self.shared | ||
970 | + self.decoder.embed_tokens = self.shared | ||
971 | + | ||
972 | + def get_output_embeddings(self): | ||
973 | + return _make_linear_from_emb(self.shared) # make it on the fly | ||
974 | + | ||
975 | + | ||
976 | +@add_start_docstrings( | ||
977 | + "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING | ||
978 | +) | ||
979 | +class BartForConditionalGeneration(PretrainedBartModel): | ||
980 | + base_model_prefix = "model" | ||
981 | + authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] | ||
982 | + | ||
983 | + def __init__(self, config: BartConfig): | ||
984 | + super().__init__(config) | ||
985 | + base_model = BartModel(config) | ||
986 | + self.model = base_model | ||
987 | + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) | ||
988 | + | ||
989 | + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: | ||
990 | + old_num_tokens = self.model.shared.num_embeddings | ||
991 | + new_embeddings = super().resize_token_embeddings(new_num_tokens) | ||
992 | + self.model.shared = new_embeddings | ||
993 | + self._resize_final_logits_bias(new_num_tokens, old_num_tokens) | ||
994 | + return new_embeddings | ||
995 | + | ||
996 | + def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: | ||
997 | + if new_num_tokens <= old_num_tokens: | ||
998 | + new_bias = self.final_logits_bias[:, :new_num_tokens] | ||
999 | + else: | ||
1000 | + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) | ||
1001 | + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) | ||
1002 | + self.register_buffer("final_logits_bias", new_bias) | ||
1003 | + | ||
1004 | + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1005 | + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) | ||
1006 | + @add_end_docstrings(BART_GENERATION_EXAMPLE) | ||
1007 | + def forward( | ||
1008 | + self, | ||
1009 | + input_ids, | ||
1010 | + patch_ids, | ||
1011 | + attention_mask=None, | ||
1012 | + encoder_outputs=None, | ||
1013 | + decoder_input_ids=None, | ||
1014 | + decoder_attention_mask=None, | ||
1015 | + past_key_values=None, | ||
1016 | + labels=None, | ||
1017 | + use_cache=None, | ||
1018 | + output_attentions=None, | ||
1019 | + output_hidden_states=None, | ||
1020 | + return_dict=None, | ||
1021 | + **unused, | ||
1022 | + ): | ||
1023 | + r""" | ||
1024 | + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
1025 | + Labels for computing the masked language modeling loss. | ||
1026 | + Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring). | ||
1027 | + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens | ||
1028 | + with labels in ``[0, ..., config.vocab_size]``. | ||
1029 | + | ||
1030 | + Returns: | ||
1031 | + | ||
1032 | + Conditional generation example:: | ||
1033 | + | ||
1034 | + # Mask filling only works for bart-large | ||
1035 | + from transformers import BartTokenizer, BartForConditionalGeneration | ||
1036 | + tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') | ||
1037 | + TXT = "My friends are <mask> but they eat too many carbs." | ||
1038 | + | ||
1039 | + model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') | ||
1040 | + input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] | ||
1041 | + logits = model(input_ids).logits | ||
1042 | + | ||
1043 | + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() | ||
1044 | + probs = logits[0, masked_index].softmax(dim=0) | ||
1045 | + values, predictions = probs.topk(5) | ||
1046 | + | ||
1047 | + tokenizer.decode(predictions).split() | ||
1048 | + # ['good', 'great', 'all', 'really', 'very'] | ||
1049 | + """ | ||
1050 | + if "lm_labels" in unused: | ||
1051 | + warnings.warn( | ||
1052 | + "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", | ||
1053 | + FutureWarning, | ||
1054 | + ) | ||
1055 | + labels = unused.pop("lm_labels") | ||
1056 | + if "decoder_cached_states" in unused: | ||
1057 | + warnings.warn( | ||
1058 | + "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
1059 | + FutureWarning, | ||
1060 | + ) | ||
1061 | + past_key_values = unused.pop("decoder_cached_states") | ||
1062 | + if "decoder_past_key_values" in unused: | ||
1063 | + warnings.warn( | ||
1064 | + "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
1065 | + FutureWarning, | ||
1066 | + ) | ||
1067 | + past_key_values = unused.pop("decoder_past_key_values") | ||
1068 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
1069 | + | ||
1070 | + if labels is not None: | ||
1071 | + use_cache = False | ||
1072 | + if decoder_input_ids is None: | ||
1073 | + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) | ||
1074 | + | ||
1075 | + outputs = self.model( | ||
1076 | + input_ids, | ||
1077 | + patch_ids=patch_ids, | ||
1078 | + attention_mask=attention_mask, | ||
1079 | + decoder_input_ids=decoder_input_ids, | ||
1080 | + encoder_outputs=encoder_outputs, | ||
1081 | + decoder_attention_mask=decoder_attention_mask, | ||
1082 | + past_key_values=past_key_values, | ||
1083 | + use_cache=use_cache, | ||
1084 | + output_attentions=output_attentions, | ||
1085 | + output_hidden_states=output_hidden_states, | ||
1086 | + return_dict=return_dict, | ||
1087 | + ) | ||
1088 | + lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) | ||
1089 | + | ||
1090 | + masked_lm_loss = None | ||
1091 | + if labels is not None: | ||
1092 | + loss_fct = CrossEntropyLoss() | ||
1093 | + # TODO(SS): do we need to ignore pad tokens in labels? | ||
1094 | + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) | ||
1095 | + | ||
1096 | + if not return_dict: | ||
1097 | + output = (lm_logits,) + outputs[1:] | ||
1098 | + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output | ||
1099 | + | ||
1100 | + return Seq2SeqLMOutput( | ||
1101 | + loss=masked_lm_loss, | ||
1102 | + logits=lm_logits, | ||
1103 | + past_key_values=outputs.past_key_values, | ||
1104 | + decoder_hidden_states=outputs.decoder_hidden_states, | ||
1105 | + decoder_attentions=outputs.decoder_attentions, | ||
1106 | + encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1107 | + encoder_hidden_states=outputs.encoder_hidden_states, | ||
1108 | + encoder_attentions=outputs.encoder_attentions, | ||
1109 | + ) | ||
1110 | + | ||
1111 | + def prepare_inputs_for_generation( | ||
1112 | + self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs | ||
1113 | + ): | ||
1114 | + return { | ||
1115 | + "input_ids": None, # encoder_outputs is defined. input_ids not needed | ||
1116 | + "encoder_outputs": encoder_outputs, | ||
1117 | + "past_key_values": past, | ||
1118 | + "decoder_input_ids": decoder_input_ids, | ||
1119 | + "attention_mask": attention_mask, | ||
1120 | + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) | ||
1121 | + } | ||
1122 | + | ||
1123 | + def adjust_logits_during_generation(self, logits, cur_len, max_length): | ||
1124 | + if cur_len == 1 and self.config.force_bos_token_to_be_generated: | ||
1125 | + self._force_token_ids_generation(logits, self.config.bos_token_id) | ||
1126 | + elif cur_len == max_length - 1 and self.config.eos_token_id is not None: | ||
1127 | + self._force_token_ids_generation(logits, self.config.eos_token_id) | ||
1128 | + return logits | ||
1129 | + | ||
1130 | + def _force_token_ids_generation(self, scores, token_id) -> None: | ||
1131 | + """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" | ||
1132 | + scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf") | ||
1133 | + | ||
1134 | + @staticmethod | ||
1135 | + def _reorder_cache(past, beam_idx): | ||
1136 | + reordered_past = [] | ||
1137 | + for layer_past in past: | ||
1138 | + # get the correct batch idx from decoder layer's batch dim for cross and self-attn | ||
1139 | + layer_past_new = { | ||
1140 | + attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() | ||
1141 | + } | ||
1142 | + reordered_past.append(layer_past_new) | ||
1143 | + return reordered_past | ||
1144 | + | ||
1145 | + def get_encoder(self): | ||
1146 | + return self.model.encoder | ||
1147 | + | ||
1148 | + def get_output_embeddings(self): | ||
1149 | + return _make_linear_from_emb(self.model.shared) # make it on the fly | ||
1150 | + | ||
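Compared to the upstream class, the conditional-generation head now takes `patch_ids` as its second positional argument and passes it through `BartModel` to the encoder. A hedged end-to-end sketch of a training-style forward pass with toy tensors and random weights (all ids and sizes are invented):

# Hypothetical forward pass with patch information (toy ids, random weights).
import torch
from transformers.configuration_bart import BartConfig

config = BartConfig(vocab_size=100, d_model=32,
                    encoder_layers=2, decoder_layers=2,
                    encoder_attention_heads=4, decoder_attention_heads=4,
                    encoder_ffn_dim=64, decoder_ffn_dim=64,
                    max_position_embeddings=64)
model = BartForConditionalGeneration(config)

input_ids = torch.tensor([[0, 7, 8, 9, 2]])
patch_ids = torch.tensor([[0, 1, 1, 2, 0]])        # same shape as input_ids
labels    = torch.tensor([[0, 12, 13, 2, 1]])      # 1 is the pad id

out = model(input_ids, patch_ids,
            attention_mask=input_ids.ne(config.pad_token_id).long(),
            labels=labels, return_dict=True)
print(out.loss, out.logits.shape)                  # scalar loss, torch.Size([1, 5, 100])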
1151 | + | ||
1152 | +@add_start_docstrings( | ||
1153 | + """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, | ||
1154 | + BART_START_DOCSTRING, | ||
1155 | +) | ||
1156 | +class BartForSequenceClassification(PretrainedBartModel): | ||
1157 | + def __init__(self, config: BartConfig, **kwargs): | ||
1158 | + super().__init__(config, **kwargs) | ||
1159 | + self.model = BartModel(config) | ||
1160 | + self.classification_head = BartClassificationHead( | ||
1161 | + config.d_model, | ||
1162 | + config.d_model, | ||
1163 | + config.num_labels, | ||
1164 | + config.classif_dropout, | ||
1165 | + ) | ||
1166 | + self.model._init_weights(self.classification_head.dense) | ||
1167 | + self.model._init_weights(self.classification_head.out_proj) | ||
1168 | + | ||
1169 | + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1170 | + @add_code_sample_docstrings( | ||
1171 | + tokenizer_class=_TOKENIZER_FOR_DOC, | ||
1172 | + checkpoint="facebook/bart-large", | ||
1173 | + output_type=Seq2SeqSequenceClassifierOutput, | ||
1174 | + config_class=_CONFIG_FOR_DOC, | ||
1175 | + ) | ||
1176 | + def forward( | ||
1177 | + self, | ||
1178 | + input_ids, | ||
1179 | + attention_mask=None, | ||
1180 | + encoder_outputs=None, | ||
1181 | + decoder_input_ids=None, | ||
1182 | + decoder_attention_mask=None, | ||
1183 | + labels=None, | ||
1184 | + use_cache=None, | ||
1185 | + output_attentions=None, | ||
1186 | + output_hidden_states=None, | ||
1187 | + return_dict=None, | ||
1188 | + ): | ||
1189 | + r""" | ||
1190 | + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1191 | + Labels for computing the sequence classification/regression loss. | ||
1192 | + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. | ||
1193 | + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||
1194 | + """ | ||
1195 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
1196 | + if labels is not None: | ||
1197 | + use_cache = False | ||
1198 | + | ||
1199 | + outputs = self.model( | ||
1200 | + input_ids, | ||
1201 | + attention_mask=attention_mask, | ||
1202 | + decoder_input_ids=decoder_input_ids, | ||
1203 | + decoder_attention_mask=decoder_attention_mask, | ||
1204 | + encoder_outputs=encoder_outputs, | ||
1205 | + use_cache=use_cache, | ||
1206 | + output_attentions=output_attentions, | ||
1207 | + output_hidden_states=output_hidden_states, | ||
1208 | + return_dict=return_dict, | ||
1209 | + ) | ||
1210 | + x = outputs[0] # last hidden state | ||
1211 | + eos_mask = input_ids.eq(self.config.eos_token_id) | ||
1212 | + if len(torch.unique(eos_mask.sum(1))) > 1: | ||
1213 | + raise ValueError("All examples must have the same number of <eos> tokens.") | ||
1214 | + sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] | ||
1215 | + logits = self.classification_head(sentence_representation) | ||
1216 | + | ||
1217 | + loss = None | ||
1218 | + if labels is not None: | ||
1219 | + loss_fct = CrossEntropyLoss() | ||
1220 | + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) | ||
1221 | + | ||
1222 | + if not return_dict: | ||
1223 | + output = (logits,) + outputs[1:] | ||
1224 | + return ((loss,) + output) if loss is not None else output | ||
1225 | + | ||
1226 | + return Seq2SeqSequenceClassifierOutput( | ||
1227 | + loss=loss, | ||
1228 | + logits=logits, | ||
1229 | + past_key_values=outputs.past_key_values, | ||
1230 | + decoder_hidden_states=outputs.decoder_hidden_states, | ||
1231 | + decoder_attentions=outputs.decoder_attentions, | ||
1232 | + encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1233 | + encoder_hidden_states=outputs.encoder_hidden_states, | ||
1234 | + encoder_attentions=outputs.encoder_attentions, | ||
1235 | + ) | ||
1236 | + | ||
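A minimal usage sketch for the classification head above (not part of the diff; it assumes the classes in this file are importable, that a small BartConfig is enough for a smoke test, and that the patched BartModel still accepts plain input_ids without patch_ids):

import torch
from transformers import BartConfig

config = BartConfig(d_model=64, encoder_layers=2, decoder_layers=2,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=128, decoder_ffn_dim=128,
                    vocab_size=100, num_labels=3)
model = BartForSequenceClassification(config)

# Every sequence must end with the same number of <eos> tokens (id 2 by default),
# because the pooled representation is read off the final <eos> position.
input_ids = torch.tensor([[0, 10, 11, 12, 2], [0, 20, 21, 22, 2]])
labels = torch.tensor([1, 2])
out = model(input_ids, labels=labels, return_dict=True)
print(out.loss, out.logits.shape)   # logits: (batch_size, num_labels)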
1237 | + | ||
1238 | +@add_start_docstrings( | ||
1239 | + """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of | ||
1240 | + the hidden-states output to compute `span start logits` and `span end logits`). """, | ||
1241 | + BART_START_DOCSTRING, | ||
1242 | +) | ||
1243 | +class BartForQuestionAnswering(PretrainedBartModel): | ||
1244 | + def __init__(self, config): | ||
1245 | + super().__init__(config) | ||
1246 | + | ||
1247 | + config.num_labels = 2 | ||
1248 | + self.num_labels = config.num_labels | ||
1249 | + | ||
1250 | + self.model = BartModel(config) | ||
1251 | + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) | ||
1252 | + | ||
1253 | + self.model._init_weights(self.qa_outputs) | ||
1254 | + | ||
1255 | + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1256 | + @add_code_sample_docstrings( | ||
1257 | + tokenizer_class=_TOKENIZER_FOR_DOC, | ||
1258 | + checkpoint="facebook/bart-large", | ||
1259 | + output_type=Seq2SeqQuestionAnsweringModelOutput, | ||
1260 | + config_class=_CONFIG_FOR_DOC, | ||
1261 | + ) | ||
1262 | + def forward( | ||
1263 | + self, | ||
1264 | + input_ids, | ||
1265 | + attention_mask=None, | ||
1266 | + encoder_outputs=None, | ||
1267 | + decoder_input_ids=None, | ||
1268 | + decoder_attention_mask=None, | ||
1269 | + start_positions=None, | ||
1270 | + end_positions=None, | ||
1271 | + use_cache=None, | ||
1272 | + output_attentions=None, | ||
1273 | + output_hidden_states=None, | ||
1274 | + return_dict=None, | ||
1275 | + ): | ||
1276 | + r""" | ||
1277 | + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1278 | + Labels for position (index) of the start of the labelled span for computing the token classification loss. | ||
1279 | + Positions are clamped to the length of the sequence (`sequence_length`). | ||
1280 | + Positions outside of the sequence are not taken into account for computing the loss. | ||
1281 | + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1282 | + Labels for position (index) of the end of the labelled span for computing the token classification loss. | ||
1283 | + Positions are clamped to the length of the sequence (`sequence_length`). | ||
1284 | + Positions outside of the sequence are not taken into account for computing the loss. | ||
1285 | + """ | ||
1286 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
1287 | + if start_positions is not None and end_positions is not None: | ||
1288 | + use_cache = False | ||
1289 | + | ||
1290 | + outputs = self.model( | ||
1291 | + input_ids, | ||
1292 | + attention_mask=attention_mask, | ||
1293 | + decoder_input_ids=decoder_input_ids, | ||
1294 | + decoder_attention_mask=decoder_attention_mask, | ||
1295 | + encoder_outputs=encoder_outputs, | ||
1296 | + use_cache=use_cache, | ||
1297 | + output_attentions=output_attentions, | ||
1298 | + output_hidden_states=output_hidden_states, | ||
1299 | + return_dict=return_dict, | ||
1300 | + ) | ||
1301 | + | ||
1302 | + sequence_output = outputs[0] | ||
1303 | + | ||
1304 | + logits = self.qa_outputs(sequence_output) | ||
1305 | + start_logits, end_logits = logits.split(1, dim=-1) | ||
1306 | + start_logits = start_logits.squeeze(-1) | ||
1307 | + end_logits = end_logits.squeeze(-1) | ||
1308 | + | ||
1309 | + total_loss = None | ||
1310 | + if start_positions is not None and end_positions is not None: | ||
1311 | + # If we are on multi-GPU, splitting adds an extra dimension | ||
1312 | + if len(start_positions.size()) > 1: | ||
1313 | + start_positions = start_positions.squeeze(-1) | ||
1314 | + if len(end_positions.size()) > 1: | ||
1315 | + end_positions = end_positions.squeeze(-1) | ||
1316 | + # sometimes the start/end positions are outside our model inputs, we ignore these terms | ||
1317 | + ignored_index = start_logits.size(1) | ||
1318 | + start_positions.clamp_(0, ignored_index) | ||
1319 | + end_positions.clamp_(0, ignored_index) | ||
1320 | + | ||
1321 | + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) | ||
1322 | + start_loss = loss_fct(start_logits, start_positions) | ||
1323 | + end_loss = loss_fct(end_logits, end_positions) | ||
1324 | + total_loss = (start_loss + end_loss) / 2 | ||
1325 | + | ||
1326 | + if not return_dict: | ||
1327 | + output = ( | ||
1328 | + start_logits, | ||
1329 | + end_logits, | ||
1330 | + ) + outputs[1:] | ||
1331 | + return ((total_loss,) + output) if total_loss is not None else output | ||
1332 | + | ||
1333 | + return Seq2SeqQuestionAnsweringModelOutput( | ||
1334 | + loss=total_loss, | ||
1335 | + start_logits=start_logits, | ||
1336 | + end_logits=end_logits, | ||
1337 | + past_key_values=outputs.past_key_values, | ||
1338 | + decoder_hidden_states=outputs.decoder_hidden_states, | ||
1339 | + decoder_attentions=outputs.decoder_attentions, | ||
1340 | + encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1341 | + encoder_hidden_states=outputs.encoder_hidden_states, | ||
1342 | + encoder_attentions=outputs.encoder_attentions, | ||
1343 | + ) | ||
1344 | + | ||
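A matching sketch for the question-answering head (not part of the diff; same importability caveats as the classification sketch above):

import torch
from transformers import BartConfig

config = BartConfig(d_model=64, encoder_layers=2, decoder_layers=2,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=128, decoder_ffn_dim=128, vocab_size=100)
model = BartForQuestionAnswering(config)

input_ids = torch.tensor([[0, 10, 11, 12, 13, 2]])
start_positions = torch.tensor([2])            # hypothetical answer span: tokens 2..3
end_positions = torch.tensor([3])
out = model(input_ids, start_positions=start_positions,
            end_positions=end_positions, return_dict=True)
print(out.loss, out.start_logits.shape, out.end_logits.shape)   # logits: (batch_size, seq_len)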
1345 | + | ||
1346 | +class SinusoidalPositionalEmbedding(nn.Embedding): | ||
1347 | + """This module produces sinusoidal positional embeddings of any length.""" | ||
1348 | + | ||
1349 | + def __init__(self, num_positions, embedding_dim, padding_idx=None): | ||
1350 | + super().__init__(num_positions, embedding_dim) | ||
1351 | + if embedding_dim % 2 != 0: | ||
1352 | + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") | ||
1353 | + self.weight = self._init_weight(self.weight) | ||
1354 | + | ||
1355 | + @staticmethod | ||
1356 | + def _init_weight(out: nn.Parameter): | ||
1357 | + """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. | ||
1358 | + The cos features are in the 2nd half of the vector. [dim // 2:] | ||
1359 | + """ | ||
1360 | + n_pos, dim = out.shape | ||
1361 | + position_enc = np.array( | ||
1362 | + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] | ||
1363 | + ) | ||
1364 | + out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos | ||
1365 | + out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) | ||
1366 | + out.detach_() | ||
1367 | + out.requires_grad = False | ||
1368 | + return out | ||
1369 | + | ||
1370 | + @torch.no_grad() | ||
1371 | + def forward(self, input_ids, use_cache=False): | ||
1372 | + """Input is expected to be of size [bsz x seqlen].""" | ||
1373 | + bsz, seq_len = input_ids.shape[:2] | ||
1374 | + if use_cache: | ||
1375 | + positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing | ||
1376 | + else: | ||
1377 | + # starts at 0, ends at seq_len - 1 | ||
1378 | + positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device) | ||
1379 | + return super().forward(positions) |
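A quick sanity sketch for the sinusoidal embeddings (not part of the diff; it assumes the class above is importable): the first dim // 2 features hold the sine component and the last dim // 2 the cosine component, so position 0 maps to zeros followed by ones.

import torch

emb = SinusoidalPositionalEmbedding(num_positions=16, embedding_dim=8)
dummy_ids = torch.zeros(2, 5, dtype=torch.long)     # [bsz, seq_len]; only the shape is used
pos = emb(dummy_ids)                                # -> [seq_len, embedding_dim]
print(pos.shape)                                    # torch.Size([5, 8])
print(torch.allclose(pos[0, :4], torch.zeros(4)))   # True: sin(0) == 0
print(torch.allclose(pos[0, 4:], torch.ones(4)))    # True: cos(0) == 1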
modeling_utils.py
0 → 100644
1 | +# coding=utf-8 | ||
2 | +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. | ||
3 | +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
4 | +# | ||
5 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | +# you may not use this file except in compliance with the License. | ||
7 | +# You may obtain a copy of the License at | ||
8 | +# | ||
9 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | +# | ||
11 | +# Unless required by applicable law or agreed to in writing, software | ||
12 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | +# See the License for the specific language governing permissions and | ||
15 | +# limitations under the License. | ||
16 | + | ||
17 | +import inspect | ||
18 | +import os | ||
19 | +import re | ||
20 | +from dataclasses import dataclass | ||
21 | +from typing import Callable, Dict, List, Optional, Set, Tuple, Union | ||
22 | + | ||
23 | +import torch | ||
24 | +from torch import Tensor, device, dtype, nn | ||
25 | +from torch.nn import CrossEntropyLoss | ||
26 | +from torch.nn import functional as F | ||
27 | + | ||
28 | +from transformers.activations import get_activation | ||
29 | +from transformers.configuration_utils import PretrainedConfig | ||
30 | +from transformers.file_utils import ( | ||
31 | + DUMMY_INPUTS, | ||
32 | + TF2_WEIGHTS_NAME, | ||
33 | + TF_WEIGHTS_NAME, | ||
34 | + WEIGHTS_NAME, | ||
35 | + ModelOutput, | ||
36 | + cached_path, | ||
37 | + hf_bucket_url, | ||
38 | + is_remote_url, | ||
39 | + is_torch_tpu_available, | ||
40 | + replace_return_docstrings, | ||
41 | +) | ||
42 | +from generation_utils import GenerationMixin | ||
43 | +import logging | ||
44 | + | ||
45 | +logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
46 | +logging.basicConfig( | ||
47 | + format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
48 | + datefmt="%m/%d/%Y %H:%M:%S", | ||
49 | + level=logging.INFO, | ||
50 | +) | ||
51 | + | ||
52 | + | ||
53 | +try: | ||
54 | + from torch.nn import Identity | ||
55 | +except ImportError: | ||
56 | + # Older PyTorch compatibility | ||
57 | + class Identity(nn.Module): | ||
58 | + r"""A placeholder identity operator that is argument-insensitive.""" | ||
59 | + | ||
60 | + def __init__(self, *args, **kwargs): | ||
61 | + super().__init__() | ||
62 | + | ||
63 | + def forward(self, input): | ||
64 | + return input | ||
65 | + | ||
66 | + | ||
67 | +def find_pruneable_heads_and_indices( | ||
68 | + heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] | ||
69 | +) -> Tuple[Set[int], torch.LongTensor]: | ||
70 | + """ | ||
71 | + Finds the heads and their indices taking :obj:`already_pruned_heads` into account. | ||
72 | + | ||
73 | + Args: | ||
74 | + heads (:obj:`List[int]`): List of the indices of heads to prune. | ||
75 | + n_heads (:obj:`int`): The number of heads in the model. | ||
76 | + head_size (:obj:`int`): The size of each head. | ||
77 | + already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads. | ||
78 | + | ||
79 | + Returns: | ||
80 | + :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the heads left to prune (after removing any already pruned) and the flat indices of the positions to keep. | ||
81 | + """ | ||
82 | + mask = torch.ones(n_heads, head_size) | ||
83 | + heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads | ||
84 | + for head in heads: | ||
85 | + # Compute how many pruned heads are before the head and move the index accordingly | ||
86 | + head = head - sum(1 if h < head else 0 for h in already_pruned_heads) | ||
87 | + mask[head] = 0 | ||
88 | + mask = mask.view(-1).contiguous().eq(1) | ||
89 | + index: torch.LongTensor = torch.arange(len(mask))[mask].long() | ||
90 | + return heads, index | ||
91 | + | ||
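A worked example calling the helper defined above (not part of the diff): pruning head 1 of a 4-head layer with head size 8 leaves 3 heads, i.e. 24 of the 32 flattened positions.

heads, index = find_pruneable_heads_and_indices(
    heads=[1], n_heads=4, head_size=8, already_pruned_heads=set()
)
print(heads)         # {1}
print(index.shape)   # torch.Size([24]) == (4 - 1) * 8 surviving positions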
92 | + | ||
93 | +class ModuleUtilsMixin: | ||
94 | + """ | ||
95 | + A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. | ||
96 | + """ | ||
97 | + | ||
98 | + def num_parameters(self, only_trainable: bool = False) -> int: | ||
99 | + """ | ||
100 | + Get the number of (optionally, trainable) parameters in the model. | ||
101 | + | ||
102 | + Args: | ||
103 | + only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
104 | + Whether or not to return only the number of trainable parameters | ||
105 | + | ||
106 | + Returns: | ||
107 | + :obj:`int`: The number of parameters. | ||
108 | + """ | ||
109 | + params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() | ||
110 | + return sum(p.numel() for p in params) | ||
111 | + | ||
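A tiny sketch (not part of the diff): the same mixin methods also ship with the transformers library, so a small BertModel can serve as a runnable stand-in here and in the sketches below.

from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
print(model.num_parameters(), model.num_parameters(only_trainable=True))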
112 | + @staticmethod | ||
113 | + def _hook_rss_memory_pre_forward(module, *args, **kwargs): | ||
114 | + try: | ||
115 | + import psutil | ||
116 | + except (ImportError): | ||
117 | + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") | ||
118 | + | ||
119 | + process = psutil.Process(os.getpid()) | ||
120 | + mem = process.memory_info() | ||
121 | + module.mem_rss_pre_forward = mem.rss | ||
122 | + return None | ||
123 | + | ||
124 | + @staticmethod | ||
125 | + def _hook_rss_memory_post_forward(module, *args, **kwargs): | ||
126 | + try: | ||
127 | + import psutil | ||
128 | + except (ImportError): | ||
129 | + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") | ||
130 | + | ||
131 | + process = psutil.Process(os.getpid()) | ||
132 | + mem = process.memory_info() | ||
133 | + module.mem_rss_post_forward = mem.rss | ||
134 | + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward | ||
135 | + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) | ||
136 | + return None | ||
137 | + | ||
138 | + def add_memory_hooks(self): | ||
139 | + """ | ||
140 | + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. | ||
141 | + | ||
142 | + Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to | ||
143 | + zero with :obj:`model.reset_memory_hooks_state()`. | ||
144 | + """ | ||
145 | + for module in self.modules(): | ||
146 | + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) | ||
147 | + module.register_forward_hook(self._hook_rss_memory_post_forward) | ||
148 | + self.reset_memory_hooks_state() | ||
149 | + | ||
150 | + def reset_memory_hooks_state(self): | ||
151 | + """ | ||
152 | + Reset the :obj:`mem_rss_diff` attribute of each module (see | ||
153 | + :func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`). | ||
154 | + """ | ||
155 | + for module in self.modules(): | ||
156 | + module.mem_rss_diff = 0 | ||
157 | + module.mem_rss_post_forward = 0 | ||
158 | + module.mem_rss_pre_forward = 0 | ||
159 | + | ||
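A sketch of the memory hooks (not part of the diff; psutil must be installed, and the small BertModel stand-in from the previous sketch is reused):

import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
model.add_memory_hooks()
_ = model(torch.tensor([[1, 5, 6, 2]]))                   # any ordinary forward pass
worst = max(model.modules(), key=lambda m: getattr(m, "mem_rss_diff", 0))
print(type(worst).__name__, worst.mem_rss_diff, "bytes")  # module with the largest RSS growth
model.reset_memory_hooks_state()                          # zero the counters before the next run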
160 | + @property | ||
161 | + def device(self) -> device: | ||
162 | + """ | ||
163 | + :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same | ||
164 | + device). | ||
165 | + """ | ||
166 | + try: | ||
167 | + return next(self.parameters()).device | ||
168 | + except StopIteration: | ||
169 | + # For nn.DataParallel compatibility in PyTorch 1.5 | ||
170 | + | ||
171 | + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | ||
172 | + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | ||
173 | + return tuples | ||
174 | + | ||
175 | + gen = self._named_members(get_members_fn=find_tensor_attributes) | ||
176 | + first_tuple = next(gen) | ||
177 | + return first_tuple[1].device | ||
178 | + | ||
179 | + @property | ||
180 | + def dtype(self) -> dtype: | ||
181 | + """ | ||
182 | + :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). | ||
183 | + """ | ||
184 | + try: | ||
185 | + return next(self.parameters()).dtype | ||
186 | + except StopIteration: | ||
187 | + # For nn.DataParallel compatibility in PyTorch 1.5 | ||
188 | + | ||
189 | + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | ||
190 | + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] | ||
191 | + return tuples | ||
192 | + | ||
193 | + gen = self._named_members(get_members_fn=find_tensor_attributes) | ||
194 | + first_tuple = next(gen) | ||
195 | + return first_tuple[1].dtype | ||
196 | + | ||
197 | + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: | ||
198 | + """ | ||
199 | + Invert an attention mask (e.g., switches 0. and 1.). | ||
200 | + | ||
201 | + Args: | ||
202 | + encoder_attention_mask (:obj:`torch.Tensor`): An attention mask. | ||
203 | + | ||
204 | + Returns: | ||
205 | + :obj:`torch.Tensor`: The inverted attention mask. | ||
206 | + """ | ||
207 | + if encoder_attention_mask.dim() == 3: | ||
208 | + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] | ||
209 | + if encoder_attention_mask.dim() == 2: | ||
210 | + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] | ||
211 | + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition | ||
212 | + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow | ||
213 | + # /transformer/transformer_layers.py#L270 | ||
214 | + # encoder_extended_attention_mask = (encoder_extended_attention_mask == | ||
215 | + # encoder_extended_attention_mask.transpose(-1, -2)) | ||
216 | + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility | ||
217 | + | ||
218 | + if self.dtype == torch.float16: | ||
219 | + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 | ||
220 | + elif self.dtype == torch.float32: | ||
221 | + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 | ||
222 | + else: | ||
223 | + raise ValueError( | ||
224 | + "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( | ||
225 | + self.dtype | ||
226 | + ) | ||
227 | + ) | ||
228 | + | ||
229 | + return encoder_extended_attention_mask | ||
230 | + | ||
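A worked example for invert_attention_mask (not part of the diff, using the BertModel stand-in): a 2-D padding mask becomes a 4-D additive mask with zeros for real tokens and a large negative bias for padding.

import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
mask = torch.tensor([[1, 1, 1, 0]])       # last position is padding
inv = model.invert_attention_mask(mask)
print(inv.shape)                          # torch.Size([1, 1, 1, 4])
print(inv[0, 0, 0])                       # zeros for real tokens, a large negative value for the padded one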
231 | + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor: | ||
232 | + """ | ||
233 | + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. | ||
234 | + | ||
235 | + Arguments: | ||
236 | + attention_mask (:obj:`torch.Tensor`): | ||
237 | + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. | ||
238 | + input_shape (:obj:`Tuple[int]`): | ||
239 | + The shape of the input to the model. | ||
240 | + device (:obj:`torch.device`): | ||
241 | + The device of the input to the model. | ||
242 | + | ||
243 | + Returns: | ||
244 | + :obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`. | ||
245 | + """ | ||
246 | + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | ||
247 | + # ourselves in which case we just need to make it broadcastable to all heads. | ||
248 | + if attention_mask.dim() == 3: | ||
249 | + extended_attention_mask = attention_mask[:, None, :, :] | ||
250 | + elif attention_mask.dim() == 2: | ||
251 | + # Provided a padding mask of dimensions [batch_size, seq_length] | ||
252 | + # - if the model is a decoder, apply a causal mask in addition to the padding mask | ||
253 | + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] | ||
254 | + if self.config.is_decoder: | ||
255 | + batch_size, seq_length = input_shape | ||
256 | + seq_ids = torch.arange(seq_length, device=device) | ||
257 | + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] | ||
258 | + # causal and attention masks must have same type with pytorch version < 1.3 | ||
259 | + causal_mask = causal_mask.to(attention_mask.dtype) | ||
260 | + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] | ||
261 | + else: | ||
262 | + extended_attention_mask = attention_mask[:, None, None, :] | ||
263 | + else: | ||
264 | + raise ValueError( | ||
265 | + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( | ||
266 | + input_shape, attention_mask.shape | ||
267 | + ) | ||
268 | + ) | ||
269 | + | ||
270 | + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for | ||
271 | + # masked positions, this operation will create a tensor which is 0.0 for | ||
272 | + # positions we want to attend and -10000.0 for masked positions. | ||
273 | + # Since we are adding it to the raw scores before the softmax, this is | ||
274 | + # effectively the same as removing these entirely. | ||
275 | + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility | ||
276 | + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | ||
277 | + return extended_attention_mask | ||
278 | + | ||
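A sketch of the decoder branch (not part of the diff, BertModel stand-in with is_decoder=True): a plain 2-D mask of ones is expanded into a per-query causal mask.

import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100, is_decoder=True))
mask = torch.ones(1, 4)
ext = model.get_extended_attention_mask(mask, input_shape=(1, 4), device=mask.device)
print(ext.shape)   # torch.Size([1, 1, 4, 4])
print(ext[0, 0])   # lower triangle (incl. diagonal) is 0.0; future positions get a large negative bias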
279 | + def get_head_mask( | ||
280 | + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False | ||
281 | + ) -> Tensor: | ||
282 | + """ | ||
283 | + Prepare the head mask if needed. | ||
284 | + | ||
285 | + Args: | ||
286 | + head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`): | ||
287 | + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). | ||
288 | + num_hidden_layers (:obj:`int`): | ||
289 | + The number of hidden layers in the model. | ||
290 | + is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
291 | + Whether or not the attention scores are computed by chunks. | ||
292 | + | ||
293 | + Returns: | ||
294 | + :obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]` | ||
295 | + or list with :obj:`[None]` for each layer. | ||
296 | + """ | ||
297 | + if head_mask is not None: | ||
298 | + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) | ||
299 | + if is_attention_chunked is True: | ||
300 | + head_mask = head_mask.unsqueeze(-1) | ||
301 | + else: | ||
302 | + head_mask = [None] * num_hidden_layers | ||
303 | + | ||
304 | + return head_mask | ||
305 | + | ||
306 | + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): | ||
307 | + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" | ||
308 | + if head_mask.dim() == 1: | ||
309 | + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) | ||
310 | + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) | ||
311 | + elif head_mask.dim() == 2: | ||
312 | + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer | ||
313 | + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" | ||
314 | + head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility | ||
315 | + return head_mask | ||
316 | + | ||
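A quick example for get_head_mask (not part of the diff, BertModel stand-in): a 1-D per-head mask is broadcast to every layer and expanded to 5 dimensions.

import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=3, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
head_mask = torch.tensor([1.0, 0.0])                    # keep head 0, drop head 1, in every layer
out = model.get_head_mask(head_mask, num_hidden_layers=3)
print(out.shape)                                        # torch.Size([3, 1, 2, 1, 1])
print(model.get_head_mask(None, num_hidden_layers=3))   # [None, None, None]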
317 | + | ||
318 | +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ||
319 | + r""" | ||
320 | + Base class for all models. | ||
321 | + | ||
322 | + :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods | ||
323 | + for loading, downloading and saving models as well as a few methods common to all models to: | ||
324 | + | ||
325 | + * resize the input embeddings, | ||
326 | + * prune heads in the self-attention heads. | ||
327 | + | ||
328 | + Class attributes (overridden by derived classes): | ||
329 | + - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of | ||
330 | + :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. | ||
331 | + - **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a | ||
332 | + PyTorch model, taking as arguments: | ||
333 | + | ||
334 | + - **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the | ||
335 | + TensorFlow checkpoint. | ||
336 | + - **config** (:class:`~transformers.PreTrainedConfig`) -- An instance of the configuration associated | ||
337 | + to the model. | ||
338 | + - **path** (:obj:`str`) -- A path to the TensorFlow checkpoint. | ||
339 | + | ||
340 | + - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in | ||
341 | + derived classes of the same architecture adding modules on top of the base model. | ||
342 | + - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore | ||
343 | + when loading the model (and avoid unnecessary warnings). | ||
344 | + """ | ||
345 | + config_class = None | ||
346 | + base_model_prefix = "" | ||
347 | + authorized_missing_keys = None | ||
348 | + | ||
349 | + @property | ||
350 | + def dummy_inputs(self) -> Dict[str, torch.Tensor]: | ||
351 | + """ | ||
352 | + :obj:`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network. | ||
353 | + """ | ||
354 | + return {"input_ids": torch.tensor(DUMMY_INPUTS)} | ||
355 | + | ||
356 | + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): | ||
357 | + super().__init__() | ||
358 | + if not isinstance(config, PretrainedConfig): | ||
359 | + raise ValueError( | ||
360 | + "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " | ||
361 | + "To create a model from a pretrained model use " | ||
362 | + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( | ||
363 | + self.__class__.__name__, self.__class__.__name__ | ||
364 | + ) | ||
365 | + ) | ||
366 | + # Save config in model | ||
367 | + self.config = config | ||
368 | + | ||
369 | + @property | ||
370 | + def base_model(self) -> nn.Module: | ||
371 | + """ | ||
372 | + :obj:`torch.nn.Module`: The main body of the model. | ||
373 | + """ | ||
374 | + return getattr(self, self.base_model_prefix, self) | ||
375 | + | ||
376 | + def get_input_embeddings(self) -> nn.Module: | ||
377 | + """ | ||
378 | + Returns the model's input embeddings. | ||
379 | + | ||
380 | + Returns: | ||
381 | + :obj:`nn.Module`: A torch module mapping vocabulary to hidden states. | ||
382 | + """ | ||
383 | + base_model = getattr(self, self.base_model_prefix, self) | ||
384 | + if base_model is not self: | ||
385 | + return base_model.get_input_embeddings() | ||
386 | + else: | ||
387 | + raise NotImplementedError | ||
388 | + | ||
389 | + def set_input_embeddings(self, value: nn.Module): | ||
390 | + """ | ||
391 | + Set model's input embeddings | ||
392 | + | ||
393 | + Args: | ||
394 | + value (:obj:`nn.Module`): A module mapping vocabulary to hidden states. | ||
395 | + """ | ||
396 | + base_model = getattr(self, self.base_model_prefix, self) | ||
397 | + if base_model is not self: | ||
398 | + base_model.set_input_embeddings(value) | ||
399 | + else: | ||
400 | + raise NotImplementedError | ||
401 | + | ||
402 | + def get_output_embeddings(self) -> nn.Module: | ||
403 | + """ | ||
404 | + Returns the model's output embeddings. | ||
405 | + | ||
406 | + Returns: | ||
407 | + :obj:`nn.Module`: A torch module mapping hidden states to vocabulary. | ||
408 | + """ | ||
409 | + return None # Overwrite for models with output embeddings | ||
410 | + | ||
411 | + def tie_weights(self): | ||
412 | + """ | ||
413 | + Tie the weights between the input embeddings and the output embeddings. | ||
414 | + | ||
415 | + If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning | ||
416 | + the weights instead. | ||
417 | + """ | ||
418 | + output_embeddings = self.get_output_embeddings() | ||
419 | + if output_embeddings is not None and self.config.tie_word_embeddings: | ||
420 | + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) | ||
421 | + | ||
422 | + if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: | ||
423 | + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) | ||
424 | + | ||
425 | + @staticmethod | ||
426 | + def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str): | ||
427 | + uninitialized_encoder_weights: List[str] = [] | ||
428 | + assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal." | ||
429 | + | ||
430 | + def tie_encoder_to_decoder_recursively( | ||
431 | + decoder_pointer: nn.Module, | ||
432 | + encoder_pointer: nn.Module, | ||
433 | + module_name: str, | ||
434 | + uninitialized_encoder_weights: List[str], | ||
435 | + depth=0, | ||
436 | + ): | ||
437 | + assert isinstance(decoder_pointer, nn.Module) and isinstance( | ||
438 | + encoder_pointer, nn.Module | ||
439 | + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" | ||
440 | + if hasattr(decoder_pointer, "weight"): | ||
441 | + assert hasattr(encoder_pointer, "weight") | ||
442 | + encoder_pointer.weight = decoder_pointer.weight | ||
443 | + if hasattr(decoder_pointer, "bias"): | ||
444 | + assert hasattr(encoder_pointer, "bias") | ||
445 | + encoder_pointer.bias = decoder_pointer.bias | ||
446 | + return | ||
447 | + | ||
448 | + encoder_modules = encoder_pointer._modules | ||
449 | + decoder_modules = decoder_pointer._modules | ||
450 | + if len(decoder_modules) > 0: | ||
451 | + assert ( | ||
452 | + len(encoder_modules) > 0 | ||
453 | + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" | ||
454 | + | ||
455 | + all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) | ||
456 | + encoder_layer_pos = 0 | ||
457 | + for name, module in decoder_modules.items(): | ||
458 | + if name.isdigit(): | ||
459 | + encoder_name = str(int(name) + encoder_layer_pos) | ||
460 | + decoder_name = name | ||
461 | + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])): | ||
462 | + # this can happen if the name corresponds to a position in a module list of layers | ||
463 | + # in this case the decoder has added a cross-attention that the encoder does not have | ||
464 | + # thus skip this step and subtract one layer position from the encoder | ||
465 | + encoder_layer_pos -= 1 | ||
466 | + continue | ||
467 | + elif name not in encoder_modules: | ||
468 | + continue | ||
469 | + elif depth > 500: | ||
470 | + raise ValueError( | ||
471 | + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." | ||
472 | + ) | ||
473 | + else: | ||
474 | + decoder_name = encoder_name = name | ||
475 | + tie_encoder_to_decoder_recursively( | ||
476 | + decoder_modules[decoder_name], | ||
477 | + encoder_modules[encoder_name], | ||
478 | + module_name + "/" + name, | ||
479 | + uninitialized_encoder_weights, | ||
480 | + depth=depth + 1, | ||
481 | + ) | ||
482 | + all_encoder_weights.remove(module_name + "/" + encoder_name) | ||
483 | + | ||
484 | + uninitialized_encoder_weights += list(all_encoder_weights) | ||
485 | + | ||
486 | + # tie weights recursively | ||
487 | + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | ||
488 | + if len(uninitialized_encoder_weights) > 0: | ||
489 | + logger.warning( | ||
490 | + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | ||
491 | + ) | ||
492 | + | ||
493 | + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): | ||
494 | + """Tie or clone module weights depending of whether we are using TorchScript or not""" | ||
495 | + if self.config.torchscript: | ||
496 | + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) | ||
497 | + else: | ||
498 | + output_embeddings.weight = input_embeddings.weight | ||
499 | + | ||
500 | + if getattr(output_embeddings, "bias", None) is not None: | ||
501 | + output_embeddings.bias.data = torch.nn.functional.pad( | ||
502 | + output_embeddings.bias.data, | ||
503 | + ( | ||
504 | + 0, | ||
505 | + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], | ||
506 | + ), | ||
507 | + "constant", | ||
508 | + 0, | ||
509 | + ) | ||
510 | + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): | ||
511 | + output_embeddings.out_features = input_embeddings.num_embeddings | ||
512 | + | ||
513 | + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding: | ||
514 | + """ | ||
515 | + Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. | ||
516 | + | ||
517 | + Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method. | ||
518 | + | ||
519 | + Arguments: | ||
520 | + new_num_tokens (:obj:`int`, `optional`): | ||
521 | + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized | ||
522 | + vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`, | ||
523 | + just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model wihtout doing | ||
524 | + anything. | ||
525 | + | ||
526 | + Return: | ||
527 | + :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. | ||
528 | + """ | ||
529 | + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed | ||
530 | + model_embeds = base_model._resize_token_embeddings(new_num_tokens) | ||
531 | + if new_num_tokens is None: | ||
532 | + return model_embeds | ||
533 | + | ||
534 | + # Update base model and current model config | ||
535 | + self.config.vocab_size = new_num_tokens | ||
536 | + base_model.vocab_size = new_num_tokens | ||
537 | + | ||
538 | + # Tie weights again if needed | ||
539 | + self.tie_weights() | ||
540 | + | ||
541 | + return model_embeds | ||
542 | + | ||
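A sketch of resizing the vocabulary (not part of the diff, BertModel stand-in), e.g. after adding new special tokens to a tokenizer:

from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
emb = model.resize_token_embeddings(110)             # grow by 10 token ids
print(emb.num_embeddings, model.config.vocab_size)   # 110 110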
543 | + def _resize_token_embeddings(self, new_num_tokens): | ||
544 | + old_embeddings = self.get_input_embeddings() | ||
545 | + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) | ||
546 | + self.set_input_embeddings(new_embeddings) | ||
547 | + return self.get_input_embeddings() | ||
548 | + | ||
549 | + def _get_resized_embeddings( | ||
550 | + self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None | ||
551 | + ) -> torch.nn.Embedding: | ||
552 | + """ | ||
553 | + Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly | ||
554 | + initialized vectors at the end. Reducing the size will remove vectors from the end | ||
555 | + | ||
556 | + Args: | ||
557 | + old_embeddings (:obj:`torch.nn.Embedding`): | ||
558 | + Old embeddings to be resized. | ||
559 | + new_num_tokens (:obj:`int`, `optional`): | ||
560 | + New number of tokens in the embedding matrix. | ||
561 | + | ||
562 | + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove | ||
563 | + vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens | ||
564 | + :obj:`torch.nn.Embedding` module of the model without doing anything. | ||
565 | + | ||
566 | + Return: | ||
567 | + :obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if | ||
568 | + :obj:`new_num_tokens` is :obj:`None` | ||
569 | + """ | ||
570 | + if new_num_tokens is None: | ||
571 | + return old_embeddings | ||
572 | + | ||
573 | + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() | ||
574 | + if old_num_tokens == new_num_tokens: | ||
575 | + return old_embeddings | ||
576 | + | ||
577 | + # Build new embeddings | ||
578 | + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) | ||
579 | + new_embeddings.to(old_embeddings.weight.device) | ||
580 | + | ||
581 | + # initialize all new embeddings (in particular added tokens) | ||
582 | + self._init_weights(new_embeddings) | ||
583 | + | ||
584 | + # Copy token embeddings from the previous weights | ||
585 | + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) | ||
586 | + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] | ||
587 | + | ||
588 | + return new_embeddings | ||
589 | + | ||
590 | + def init_weights(self): | ||
591 | + """ | ||
592 | + Initializes and prunes weights if needed. | ||
593 | + """ | ||
594 | + # Initialize weights | ||
595 | + self.apply(self._init_weights) | ||
596 | + | ||
597 | + # Prune heads if needed | ||
598 | + if self.config.pruned_heads: | ||
599 | + self.prune_heads(self.config.pruned_heads) | ||
600 | + | ||
601 | + # Tie weights if needed | ||
602 | + self.tie_weights() | ||
603 | + | ||
604 | + def prune_heads(self, heads_to_prune: Dict[int, List[int]]): | ||
605 | + """ | ||
606 | + Prunes heads of the base model. | ||
607 | + | ||
608 | + Arguments: | ||
609 | + heads_to_prune (:obj:`Dict[int, List[int]]`): | ||
610 | + Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list | ||
611 | + of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will | ||
612 | + prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. | ||
613 | + """ | ||
614 | + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads | ||
615 | + for layer, heads in heads_to_prune.items(): | ||
616 | + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) | ||
617 | + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON | ||
618 | + | ||
619 | + self.base_model._prune_heads(heads_to_prune) | ||
620 | + | ||
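A sketch of head pruning (not part of the diff, BertModel stand-in with 4 heads per layer):

from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=3, num_attention_heads=4,
                             intermediate_size=128, vocab_size=100))
model.prune_heads({1: [0, 2], 2: [1]})   # prune heads 0 and 2 in layer 1, head 1 in layer 2
print(model.config.pruned_heads)         # {1: [0, 2], 2: [1]}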
621 | + def save_pretrained(self, save_directory): | ||
622 | + """ | ||
623 | + Save a model and its configuration file to a directory, so that it can be re-loaded using the | ||
624 | + :func:`~transformers.PreTrainedModel.from_pretrained` class method. | ||
625 | + | ||
626 | + Arguments: | ||
627 | + save_directory (:obj:`str`): | ||
628 | + Directory to which to save. Will be created if it doesn't exist. | ||
629 | + """ | ||
630 | + if os.path.isfile(save_directory): | ||
631 | + logger.error("Provided path ({}) should be a directory, not a file".format(save_directory)) | ||
632 | + return | ||
633 | + os.makedirs(save_directory, exist_ok=True) | ||
634 | + | ||
635 | + # Only save the model itself if we are using distributed training | ||
636 | + model_to_save = self.module if hasattr(self, "module") else self | ||
637 | + | ||
638 | + # Attach architecture to the config | ||
639 | + model_to_save.config.architectures = [model_to_save.__class__.__name__] | ||
640 | + | ||
641 | + # If we save using the predefined names, we can load using `from_pretrained` | ||
642 | + output_model_file = os.path.join(save_directory, WEIGHTS_NAME) | ||
643 | + | ||
644 | + if getattr(self.config, "xla_device", False): | ||
645 | + import torch_xla.core.xla_model as xm | ||
646 | + | ||
647 | + if xm.is_master_ordinal(): | ||
648 | + # Save configuration file | ||
649 | + model_to_save.config.save_pretrained(save_directory) | ||
650 | + # xm.save takes care of saving only from master | ||
651 | + xm.save(model_to_save.state_dict(), output_model_file) | ||
652 | + else: | ||
653 | + model_to_save.config.save_pretrained(save_directory) | ||
654 | + torch.save(model_to_save.state_dict(), output_model_file) | ||
655 | + | ||
656 | + logger.info("Model weights saved in {}".format(output_model_file)) | ||
657 | + | ||
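A save/load round trip sketch (not part of the diff, BertModel stand-in):

import tempfile
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=2,
                             intermediate_size=128, vocab_size=100))
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)                  # writes the config and weights files
    reloaded = BertModel.from_pretrained(tmp)   # rebuilds the model and loads the weights
    print(type(reloaded).__name__)              # BertModel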
658 | + @classmethod | ||
659 | + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||
660 | + r""" | ||
661 | + Instantiate a pretrained pytorch model from a pre-trained model configuration. | ||
662 | + | ||
663 | + The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). | ||
664 | + To train the model, you should first set it back in training mode with ``model.train()``. | ||
665 | + | ||
666 | + The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come | ||
667 | + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning | ||
668 | + task. | ||
669 | + | ||
670 | + The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those | ||
671 | + weights are discarded. | ||
672 | + | ||
673 | + Parameters: | ||
674 | + pretrained_model_name_or_path (:obj:`str`, `optional`): | ||
675 | + Can be either: | ||
676 | + | ||
677 | + - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g., | ||
678 | + ``bert-base-uncased``. | ||
679 | + - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g., | ||
680 | + ``dbmdz/bert-base-german-cased``. | ||
681 | + - A path to a `directory` containing model weights saved using | ||
682 | + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | ||
683 | + - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In | ||
684 | + this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided | ||
685 | + as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in | ||
686 | + a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. | ||
687 | + - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword | ||
688 | + arguments ``config`` and ``state_dict``). | ||
689 | + model_args (sequence of positional arguments, `optional`): | ||
690 | + All remaining positional arguments will be passed to the underlying model's ``__init__`` method. | ||
691 | + config (:obj:`Union[PretrainedConfig, str]`, `optional`): | ||
692 | + Can be either: | ||
693 | + | ||
694 | + - an instance of a class derived from :class:`~transformers.PretrainedConfig`, | ||
695 | + - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`. | ||
696 | + | ||
697 | + Configuration for the model to use instead of an automatically loaded configuration. Configuration can | ||
698 | + be automatically loaded when: | ||
699 | + | ||
700 | + - The model is a model provided by the library (loaded with the `shortcut name` string of a | ||
701 | + pretrained model). | ||
702 | + - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | ||
703 | + by supplying the save directory. | ||
704 | + - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a | ||
705 | + configuration JSON file named `config.json` is found in the directory. | ||
706 | + state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`): | ||
707 | + A state dictionary to use instead of a state dictionary loaded from saved weights file. | ||
708 | + | ||
709 | + This option can be used if you want to create a model from a pretrained configuration but load your own | ||
710 | + weights. In this case though, you should check if using | ||
711 | + :func:`~transformers.PreTrainedModel.save_pretrained` and | ||
712 | + :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. | ||
713 | + cache_dir (:obj:`str`, `optional`): | ||
714 | + Path to a directory in which a downloaded pretrained model configuration should be cached if the | ||
715 | + standard cache should not be used. | ||
716 | + from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
717 | + Load the model weights from a TensorFlow checkpoint save file (see docstring of | ||
718 | + ``pretrained_model_name_or_path`` argument). | ||
719 | + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
720 | + Whether or not to force the (re-)download of the model weights and configuration files, overriding the | ||
721 | + cached versions if they exist. | ||
722 | + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
723 | + Whether or not to delete incompletely received files. Will attempt to resume the download if such a | ||
724 | + file exists. | ||
725 | + proxies (:obj:`Dict[str, str]`, `optional`): | ||
726 | + A dictionary of proxy servers to use by protocol or endpoint, e.g., | ||
727 | + :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each | ||
728 | + request. | ||
729 | + output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
730 | + Whether or not to also return a dictionary containing missing keys, unexpected keys and error | ||
731 | + messages. | ||
732 | + local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
733 | + Whether or not to only look at local files (i.e., do not try to download the model). | ||
734 | + use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
735 | + Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on | ||
736 | + our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. | ||
737 | + kwargs (remaining dictionary of keyword arguments, `optional`): | ||
738 | + Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., | ||
739 | + :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or | ||
740 | + automatically loaded: | ||
741 | + | ||
742 | + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | ||
743 | + underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | ||
744 | + already been done) | ||
745 | + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | ||
746 | + initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | ||
747 | + ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | ||
748 | + with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | ||
749 | + attribute will be passed to the underlying model's ``__init__`` function. | ||
750 | + | ||
751 | + Examples:: | ||
752 | + | ||
753 | + from transformers import BertConfig, BertModel | ||
754 | + # Download model and configuration from S3 and cache. | ||
755 | + model = BertModel.from_pretrained('bert-base-uncased') | ||
756 | + # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable). | ||
757 | + model = BertModel.from_pretrained('./test/saved_model/') | ||
758 | + # Update configuration during loading. | ||
759 | + model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) | ||
760 | + assert model.config.output_attention == True | ||
761 | + # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). | ||
762 | + config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') | ||
763 | + model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) | ||
764 | + """ | ||
765 | + config = kwargs.pop("config", None) | ||
766 | + state_dict = kwargs.pop("state_dict", None) | ||
767 | + cache_dir = kwargs.pop("cache_dir", None) | ||
768 | + from_tf = kwargs.pop("from_tf", False) | ||
769 | + force_download = kwargs.pop("force_download", False) | ||
770 | + resume_download = kwargs.pop("resume_download", False) | ||
771 | + proxies = kwargs.pop("proxies", None) | ||
772 | + output_loading_info = kwargs.pop("output_loading_info", False) | ||
773 | + local_files_only = kwargs.pop("local_files_only", False) | ||
774 | + use_cdn = kwargs.pop("use_cdn", True) | ||
775 | + | ||
776 | + # Load config if we don't provide a configuration | ||
777 | + if not isinstance(config, PretrainedConfig): | ||
778 | + config_path = config if config is not None else pretrained_model_name_or_path | ||
779 | + config, model_kwargs = cls.config_class.from_pretrained( | ||
780 | + config_path, | ||
781 | + *model_args, | ||
782 | + cache_dir=cache_dir, | ||
783 | + return_unused_kwargs=True, | ||
784 | + force_download=force_download, | ||
785 | + resume_download=resume_download, | ||
786 | + proxies=proxies, | ||
787 | + local_files_only=local_files_only, | ||
788 | + **kwargs, | ||
789 | + ) | ||
790 | + else: | ||
791 | + model_kwargs = kwargs | ||
792 | + | ||
793 | + # Load model | ||
794 | + if pretrained_model_name_or_path is not None: | ||
795 | + if os.path.isdir(pretrained_model_name_or_path): | ||
796 | + if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")): | ||
797 | + # Load from a TF 1.0 checkpoint | ||
798 | + archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") | ||
799 | + elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): | ||
800 | + # Load from a TF 2.0 checkpoint | ||
801 | + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) | ||
802 | + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): | ||
803 | + # Load from a PyTorch checkpoint | ||
804 | + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) | ||
805 | + else: | ||
806 | + raise EnvironmentError( | ||
807 | + "Error no file named {} found in directory {} or `from_tf` set to False".format( | ||
808 | + [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], | ||
809 | + pretrained_model_name_or_path, | ||
810 | + ) | ||
811 | + ) | ||
812 | + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): | ||
813 | + archive_file = pretrained_model_name_or_path | ||
814 | + elif os.path.isfile(pretrained_model_name_or_path + ".index"): | ||
815 | + assert ( | ||
816 | + from_tf | ||
817 | + ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( | ||
818 | + pretrained_model_name_or_path + ".index" | ||
819 | + ) | ||
820 | + archive_file = pretrained_model_name_or_path + ".index" | ||
821 | + else: | ||
822 | + archive_file = hf_bucket_url( | ||
823 | + pretrained_model_name_or_path, | ||
824 | + filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), | ||
825 | + use_cdn=use_cdn, | ||
826 | + ) | ||
827 | + | ||
828 | + try: | ||
829 | + # Load from URL or cache if already cached | ||
830 | + resolved_archive_file = cached_path( | ||
831 | + archive_file, | ||
832 | + cache_dir=cache_dir, | ||
833 | + force_download=force_download, | ||
834 | + proxies=proxies, | ||
835 | + resume_download=resume_download, | ||
836 | + local_files_only=local_files_only, | ||
837 | + ) | ||
838 | + if resolved_archive_file is None: | ||
839 | + raise EnvironmentError | ||
840 | + except EnvironmentError: | ||
841 | + msg = ( | ||
842 | + f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" | ||
843 | + f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" | ||
844 | + f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n" | ||
845 | + ) | ||
846 | + raise EnvironmentError(msg) | ||
847 | + | ||
848 | + if resolved_archive_file == archive_file: | ||
849 | + logger.info("loading weights file {}".format(archive_file)) | ||
850 | + else: | ||
851 | + logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file)) | ||
852 | + else: | ||
853 | + resolved_archive_file = None | ||
854 | + | ||
855 | + # Instantiate model. | ||
856 | + model = cls(config, *model_args, **model_kwargs) | ||
857 | + | ||
858 | + if state_dict is None and not from_tf: | ||
859 | + try: | ||
860 | + state_dict = torch.load(resolved_archive_file, map_location="cpu") | ||
861 | + except Exception: | ||
862 | + raise OSError( | ||
863 | + "Unable to load weights from pytorch checkpoint file. " | ||
864 | + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " | ||
865 | + ) | ||
866 | + | ||
867 | + missing_keys = [] | ||
868 | + unexpected_keys = [] | ||
869 | + error_msgs = [] | ||
870 | + | ||
871 | + if from_tf: | ||
872 | + if resolved_archive_file.endswith(".index"): | ||
873 | + # Load from a TensorFlow 1.X checkpoint - provided by original authors | ||
874 | + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' | ||
875 | + else: | ||
876 | + # Load from our TensorFlow 2.0 checkpoints | ||
877 | + try: | ||
878 | + from transformers import load_tf2_checkpoint_in_pytorch_model | ||
879 | + | ||
880 | + model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True) | ||
881 | + except ImportError: | ||
882 | + logger.error( | ||
883 | + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " | ||
884 | + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." | ||
885 | + ) | ||
886 | + raise | ||
887 | + else: | ||
888 | + # Convert old format to new format if needed from a PyTorch state_dict | ||
889 | + old_keys = [] | ||
890 | + new_keys = [] | ||
891 | + for key in state_dict.keys(): | ||
892 | + new_key = None | ||
893 | + if "gamma" in key: | ||
894 | + new_key = key.replace("gamma", "weight") | ||
895 | + if "beta" in key: | ||
896 | + new_key = key.replace("beta", "bias") | ||
897 | + if new_key: | ||
898 | + old_keys.append(key) | ||
899 | + new_keys.append(new_key) | ||
900 | + for old_key, new_key in zip(old_keys, new_keys): | ||
901 | + state_dict[new_key] = state_dict.pop(old_key) | ||
902 | + | ||
903 | + # copy state_dict so _load_from_state_dict can modify it | ||
904 | + metadata = getattr(state_dict, "_metadata", None) | ||
905 | + state_dict = state_dict.copy() | ||
906 | + if metadata is not None: | ||
907 | + state_dict._metadata = metadata | ||
908 | + | ||
909 | + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | ||
910 | + # so we need to apply the function recursively. | ||
911 | + def load(module: nn.Module, prefix=""): | ||
912 | + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | ||
913 | + module._load_from_state_dict( | ||
914 | + state_dict, | ||
915 | + prefix, | ||
916 | + local_metadata, | ||
917 | + True, | ||
918 | + missing_keys, | ||
919 | + unexpected_keys, | ||
920 | + error_msgs, | ||
921 | + ) | ||
922 | + for name, child in module._modules.items(): | ||
923 | + if child is not None: | ||
924 | + load(child, prefix + name + ".") | ||
925 | + | ||
926 | + # Make sure we are able to load base models as well as derived models (with heads) | ||
927 | + start_prefix = "" | ||
928 | + model_to_load = model | ||
929 | + has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()) | ||
930 | + if not hasattr(model, cls.base_model_prefix) and has_prefix_module: | ||
931 | + start_prefix = cls.base_model_prefix + "." | ||
932 | + if hasattr(model, cls.base_model_prefix) and not has_prefix_module: | ||
933 | + model_to_load = getattr(model, cls.base_model_prefix) | ||
934 | + | ||
935 | + load(model_to_load, prefix=start_prefix) | ||
936 | + | ||
937 | + if model.__class__.__name__ != model_to_load.__class__.__name__: | ||
938 | + base_model_state_dict = model_to_load.state_dict().keys() | ||
939 | + head_model_state_dict_without_base_prefix = [ | ||
940 | + key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys() | ||
941 | + ] | ||
942 | + missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict) | ||
943 | + | ||
944 | + # Some models may have keys that are not in the state by design, removing them before needlessly warning | ||
945 | + # the user. | ||
946 | + if cls.authorized_missing_keys is not None: | ||
947 | + for pat in cls.authorized_missing_keys: | ||
948 | + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] | ||
949 | + | ||
950 | + if len(unexpected_keys) > 0: | ||
951 | + logger.warning( | ||
952 | + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | ||
953 | + f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | ||
954 | + f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | ||
955 | + f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" | ||
956 | + f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " | ||
957 | + f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." | ||
958 | + ) | ||
959 | + else: | ||
960 | + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | ||
961 | + if len(missing_keys) > 0: | ||
962 | + logger.warning( | ||
963 | + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | ||
964 | + f"and are newly initialized: {missing_keys}\n" | ||
965 | + f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | ||
966 | + ) | ||
967 | + else: | ||
968 | + logger.info( | ||
969 | + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" | ||
970 | + f"If your task is similar to the task the model of the checkpoint was trained on, " | ||
971 | + f"you can already use {model.__class__.__name__} for predictions without further training." | ||
972 | + ) | ||
973 | + if len(error_msgs) > 0: | ||
974 | + raise RuntimeError( | ||
975 | + "Error(s) in loading state_dict for {}:\n\t{}".format( | ||
976 | + model.__class__.__name__, "\n\t".join(error_msgs) | ||
977 | + ) | ||
978 | + ) | ||
979 | + # make sure token embedding weights are still tied if needed | ||
980 | + model.tie_weights() | ||
981 | + | ||
982 | + # Set model in evaluation mode to deactivate DropOut modules by default | ||
983 | + model.eval() | ||
984 | + | ||
985 | + if output_loading_info: | ||
986 | + loading_info = { | ||
987 | + "missing_keys": missing_keys, | ||
988 | + "unexpected_keys": unexpected_keys, | ||
989 | + "error_msgs": error_msgs, | ||
990 | + } | ||
991 | + return model, loading_info | ||
992 | + | ||
993 | + if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available(): | ||
994 | + import torch_xla.core.xla_model as xm | ||
995 | + | ||
996 | + model = xm.send_cpu_data_to_device(model, xm.xla_device()) | ||
997 | + model.to(xm.xla_device()) | ||
998 | + | ||
999 | + return model | ||
1000 | + | ||
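As a side note on the loading path above, the old-format conversion simply renames legacy "gamma"/"beta" LayerNorm parameter names to "weight"/"bias". A minimal sketch of that renaming, using hypothetical keys and placeholder values instead of real tensors:

# Hypothetical state_dict keys; the values stand in for real tensors.
state_dict = {
    "encoder.layer.0.LayerNorm.gamma": 1.0,
    "encoder.layer.0.LayerNorm.beta": 0.0,
}
renamed = {
    key.replace("gamma", "weight").replace("beta", "bias"): value
    for key, value in state_dict.items()
}
# renamed == {"encoder.layer.0.LayerNorm.weight": 1.0,
#             "encoder.layer.0.LayerNorm.bias": 0.0}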
1001 | + | ||
1002 | +class Conv1D(nn.Module): | ||
1003 | + """ | ||
1004 | + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). | ||
1005 | + | ||
1006 | + Basically works like a linear layer but the weights are transposed. | ||
1007 | + | ||
1008 | + Args: | ||
1009 | + nf (:obj:`int`): The number of output features. | ||
1010 | + nx (:obj:`int`): The number of input features. | ||
1011 | + """ | ||
1012 | + | ||
1013 | + def __init__(self, nf, nx): | ||
1014 | + super().__init__() | ||
1015 | + self.nf = nf | ||
1016 | + w = torch.empty(nx, nf) | ||
1017 | + nn.init.normal_(w, std=0.02) | ||
1018 | + self.weight = nn.Parameter(w) | ||
1019 | + self.bias = nn.Parameter(torch.zeros(nf)) | ||
1020 | + | ||
1021 | + def forward(self, x): | ||
1022 | + size_out = x.size()[:-1] + (self.nf,) | ||
1023 | + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) | ||
1024 | + x = x.view(*size_out) | ||
1025 | + return x | ||
1026 | + | ||
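A quick sanity check of the transposed-weight convention: a Conv1D(nf, nx) stores its weight with shape (nx, nf) and computes the same map as an nn.Linear(nx, nf) whose weight is that matrix transposed. A minimal sketch, assuming the Conv1D class above is in scope:

import torch
from torch import nn

nx, nf = 4, 3
conv = Conv1D(nf, nx)                      # weight shape (nx, nf)
linear = nn.Linear(nx, nf)                 # weight shape (nf, nx)
with torch.no_grad():
    linear.weight.copy_(conv.weight.t())   # transpose to match the Linear convention
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 5, nx)                  # (batch, seq_len, nx)
assert torch.allclose(conv(x), linear(x), atol=1e-6)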
1027 | + | ||
1028 | +class PoolerStartLogits(nn.Module): | ||
1029 | + """ | ||
1030 | + Compute SQuAD start logits from sequence hidden states. | ||
1031 | + | ||
1032 | + Args: | ||
1033 | + config (:class:`~transformers.PretrainedConfig`): | ||
1034 | + The config used by the model, will be used to grab the :obj:`hidden_size` of the model. | ||
1035 | + """ | ||
1036 | + | ||
1037 | + def __init__(self, config: PretrainedConfig): | ||
1038 | + super().__init__() | ||
1039 | + self.dense = nn.Linear(config.hidden_size, 1) | ||
1040 | + | ||
1041 | + def forward( | ||
1042 | + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None | ||
1043 | + ) -> torch.FloatTensor: | ||
1044 | + """ | ||
1045 | + Args: | ||
1046 | + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1047 | + The final hidden states of the model. | ||
1048 | + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1049 | + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). | ||
1050 | + 1.0 means the token should be masked. | ||
1051 | + | ||
1052 | + Returns: | ||
1053 | + :obj:`torch.FloatTensor`: The start logits for SQuAD. | ||
1054 | + """ | ||
1055 | + x = self.dense(hidden_states).squeeze(-1) | ||
1056 | + | ||
1057 | + if p_mask is not None: | ||
1058 | + if next(self.parameters()).dtype == torch.float16: | ||
1059 | + x = x * (1 - p_mask) - 65500 * p_mask | ||
1060 | + else: | ||
1061 | + x = x * (1 - p_mask) - 1e30 * p_mask | ||
1062 | + | ||
1063 | + return x | ||
1064 | + | ||
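The masking idiom used above (and again in PoolerEndLogits) keeps valid logits untouched and pushes masked positions to a very negative value so they receive no probability after softmax; 65500 is used under fp16 to stay within the half-precision range. A tiny float32 sketch of the idea:

import torch

logits = torch.tensor([2.0, 1.0, 3.0])
p_mask = torch.tensor([0.0, 1.0, 0.0])          # 1.0 marks positions to mask out
masked = logits * (1 - p_mask) - 1e30 * p_mask  # masked position is pushed to ~-1e30
probs = torch.softmax(masked, dim=-1)
assert probs[1].item() < 1e-6                   # the masked token gets essentially no mass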
1065 | + | ||
1066 | +class PoolerEndLogits(nn.Module): | ||
1067 | + """ | ||
1068 | + Compute SQuAD end logits from sequence hidden states. | ||
1069 | + | ||
1070 | + Args: | ||
1071 | + config (:class:`~transformers.PretrainedConfig`): | ||
1072 | + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the | ||
1073 | + :obj:`layer_norm_eps` to use. | ||
1074 | + """ | ||
1075 | + | ||
1076 | + def __init__(self, config: PretrainedConfig): | ||
1077 | + super().__init__() | ||
1078 | + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) | ||
1079 | + self.activation = nn.Tanh() | ||
1080 | + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
1081 | + self.dense_1 = nn.Linear(config.hidden_size, 1) | ||
1082 | + | ||
1083 | + def forward( | ||
1084 | + self, | ||
1085 | + hidden_states: torch.FloatTensor, | ||
1086 | + start_states: Optional[torch.FloatTensor] = None, | ||
1087 | + start_positions: Optional[torch.LongTensor] = None, | ||
1088 | + p_mask: Optional[torch.FloatTensor] = None, | ||
1089 | + ) -> torch.FloatTensor: | ||
1090 | + """ | ||
1091 | + Args: | ||
1092 | + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1093 | + The final hidden states of the model. | ||
1094 | + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): | ||
1095 | + The hidden states of the first tokens for the labeled span. | ||
1096 | + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1097 | + The position of the first token for the labeled span. | ||
1098 | + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1099 | + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). | ||
1100 | + 1.0 means the token should be masked. | ||
1101 | + | ||
1102 | + .. note:: | ||
1103 | + | ||
1104 | + One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set, | ||
1105 | + ``start_positions`` overrides ``start_states``. | ||
1106 | + | ||
1107 | + Returns: | ||
1108 | + :obj:`torch.FloatTensor`: The end logits for SQuAD. | ||
1109 | + """ | ||
1110 | + assert ( | ||
1111 | + start_states is not None or start_positions is not None | ||
1112 | + ), "One of start_states, start_positions should be not None" | ||
1113 | + if start_positions is not None: | ||
1114 | + slen, hsz = hidden_states.shape[-2:] | ||
1115 | + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | ||
1116 | + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) | ||
1117 | + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) | ||
1118 | + | ||
1119 | + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) | ||
1120 | + x = self.activation(x) | ||
1121 | + x = self.LayerNorm(x) | ||
1122 | + x = self.dense_1(x).squeeze(-1) | ||
1123 | + | ||
1124 | + if p_mask is not None: | ||
1125 | + if next(self.parameters()).dtype == torch.float16: | ||
1126 | + x = x * (1 - p_mask) - 65500 * p_mask | ||
1127 | + else: | ||
1128 | + x = x * (1 - p_mask) - 1e30 * p_mask | ||
1129 | + | ||
1130 | + return x | ||
1131 | + | ||
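The gather/expand step above selects, for every example, the hidden state at its start position and broadcasts it across the sequence so it can be concatenated with every token. A small self-contained sketch of that indexing:

import torch

bsz, slen, hsz = 2, 5, 4
hidden = torch.randn(bsz, slen, hsz)
start_positions = torch.tensor([1, 3])

idx = start_positions[:, None, None].expand(-1, -1, hsz)   # (bsz, 1, hsz)
start_states = hidden.gather(-2, idx)                       # (bsz, 1, hsz)
start_states = start_states.expand(-1, slen, -1)            # (bsz, slen, hsz)
assert torch.equal(start_states[0, 2], hidden[0, 1])        # every position sees example 0's start state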
1132 | + | ||
1133 | +class PoolerAnswerClass(nn.Module): | ||
1134 | + """ | ||
1135 | + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. | ||
1136 | + | ||
1137 | + Args: | ||
1138 | + config (:class:`~transformers.PretrainedConfig`): | ||
1139 | + The config used by the model, will be used to grab the :obj:`hidden_size` of the model. | ||
1140 | + """ | ||
1141 | + | ||
1142 | + def __init__(self, config): | ||
1143 | + super().__init__() | ||
1144 | + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) | ||
1145 | + self.activation = nn.Tanh() | ||
1146 | + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) | ||
1147 | + | ||
1148 | + def forward( | ||
1149 | + self, | ||
1150 | + hidden_states: torch.FloatTensor, | ||
1151 | + start_states: Optional[torch.FloatTensor] = None, | ||
1152 | + start_positions: Optional[torch.LongTensor] = None, | ||
1153 | + cls_index: Optional[torch.LongTensor] = None, | ||
1154 | + ) -> torch.FloatTensor: | ||
1155 | + """ | ||
1156 | + Args: | ||
1157 | + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1158 | + The final hidden states of the model. | ||
1159 | + start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): | ||
1160 | + The hidden states of the first tokens for the labeled span. | ||
1161 | + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1162 | + The position of the first token for the labeled span. | ||
1163 | + cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1164 | + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. | ||
1165 | + | ||
1166 | + .. note:: | ||
1167 | + | ||
1168 | + One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set, | ||
1169 | + ``start_positions`` overrides ``start_states``. | ||
1170 | + | ||
1171 | + Returns: | ||
1172 | + :obj:`torch.FloatTensor`: The SQuAD 2.0 answer class. | ||
1173 | + """ | ||
1174 | + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. | ||
1175 | + hsz = hidden_states.shape[-1] | ||
1176 | + assert ( | ||
1177 | + start_states is not None or start_positions is not None | ||
1178 | + ), "One of start_states, start_positions should be not None" | ||
1179 | + if start_positions is not None: | ||
1180 | + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | ||
1181 | + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) | ||
1182 | + | ||
1183 | + if cls_index is not None: | ||
1184 | + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) | ||
1185 | + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) | ||
1186 | + else: | ||
1187 | + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) | ||
1188 | + | ||
1189 | + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) | ||
1190 | + x = self.activation(x) | ||
1191 | + x = self.dense_1(x).squeeze(-1) | ||
1192 | + | ||
1193 | + return x | ||
1194 | + | ||
1195 | + | ||
1196 | +@dataclass | ||
1197 | +class SquadHeadOutput(ModelOutput): | ||
1198 | + """ | ||
1199 | + Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`. | ||
1200 | + | ||
1201 | + Args: | ||
1202 | + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided): | ||
1203 | + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. | ||
1204 | + start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1205 | + Log probabilities for the top config.start_n_top start token possibilities (beam-search). | ||
1206 | + start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1207 | + Indices for the top config.start_n_top start token possibilities (beam-search). | ||
1208 | + end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1209 | + Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). | ||
1210 | + end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1211 | + Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). | ||
1212 | + cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1213 | + Log probabilities for the ``is_impossible`` label of the answers. | ||
1214 | + | ||
1215 | + """ | ||
1216 | + | ||
1217 | + loss: Optional[torch.FloatTensor] = None | ||
1218 | + start_top_log_probs: Optional[torch.FloatTensor] = None | ||
1219 | + start_top_index: Optional[torch.LongTensor] = None | ||
1220 | + end_top_log_probs: Optional[torch.FloatTensor] = None | ||
1221 | + end_top_index: Optional[torch.LongTensor] = None | ||
1222 | + cls_logits: Optional[torch.FloatTensor] = None | ||
1223 | + | ||
1224 | + | ||
1225 | +class SQuADHead(nn.Module): | ||
1226 | + r""" | ||
1227 | + A SQuAD head inspired by XLNet. | ||
1228 | + | ||
1229 | + Args: | ||
1230 | + config (:class:`~transformers.PretrainedConfig`): | ||
1231 | + The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the | ||
1232 | + :obj:`layer_norm_eps` to use. | ||
1233 | + """ | ||
1234 | + | ||
1235 | + def __init__(self, config): | ||
1236 | + super().__init__() | ||
1237 | + self.start_n_top = config.start_n_top | ||
1238 | + self.end_n_top = config.end_n_top | ||
1239 | + | ||
1240 | + self.start_logits = PoolerStartLogits(config) | ||
1241 | + self.end_logits = PoolerEndLogits(config) | ||
1242 | + self.answer_class = PoolerAnswerClass(config) | ||
1243 | + | ||
1244 | + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) | ||
1245 | + def forward( | ||
1246 | + self, | ||
1247 | + hidden_states: torch.FloatTensor, | ||
1248 | + start_positions: Optional[torch.LongTensor] = None, | ||
1249 | + end_positions: Optional[torch.LongTensor] = None, | ||
1250 | + cls_index: Optional[torch.LongTensor] = None, | ||
1251 | + is_impossible: Optional[torch.LongTensor] = None, | ||
1252 | + p_mask: Optional[torch.FloatTensor] = None, | ||
1253 | + return_dict: bool = False, | ||
1254 | + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: | ||
1255 | + """ | ||
1256 | + Args: | ||
1257 | + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1258 | + Final hidden states of the model on the sequence tokens. | ||
1259 | + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1260 | + Positions of the first token for the labeled span. | ||
1261 | + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1262 | + Positions of the last token for the labeled span. | ||
1263 | + cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1264 | + Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. | ||
1265 | + is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1266 | + Whether the question has a possible answer in the paragraph or not. | ||
1267 | + p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1268 | + Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). | ||
1269 | + 1.0 means the token should be masked. | ||
1270 | + return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
1271 | + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | ||
1272 | + | ||
1273 | + Returns: | ||
1274 | + """ | ||
1275 | + start_logits = self.start_logits(hidden_states, p_mask=p_mask) | ||
1276 | + | ||
1277 | + if start_positions is not None and end_positions is not None: | ||
1278 | + # If we are on multi-GPU, let's remove the dimension added by batch splitting | ||
1279 | + for x in (start_positions, end_positions, cls_index, is_impossible): | ||
1280 | + if x is not None and x.dim() > 1: | ||
1281 | + x.squeeze_(-1) | ||
1282 | + | ||
1283 | + # during training, compute the end logits based on the ground truth of the start position | ||
1284 | + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) | ||
1285 | + | ||
1286 | + loss_fct = CrossEntropyLoss() | ||
1287 | + start_loss = loss_fct(start_logits, start_positions) | ||
1288 | + end_loss = loss_fct(end_logits, end_positions) | ||
1289 | + total_loss = (start_loss + end_loss) / 2 | ||
1290 | + | ||
1291 | + if cls_index is not None and is_impossible is not None: | ||
1292 | + # Predict answerability from the representation of CLS and START | ||
1293 | + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) | ||
1294 | + loss_fct_cls = nn.BCEWithLogitsLoss() | ||
1295 | + cls_loss = loss_fct_cls(cls_logits, is_impossible) | ||
1296 | + | ||
1297 | + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss | ||
1298 | + total_loss += cls_loss * 0.5 | ||
1299 | + | ||
1300 | + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) | ||
1301 | + | ||
1302 | + else: | ||
1303 | + # during inference, compute the end logits based on beam search | ||
1304 | + bsz, slen, hsz = hidden_states.size() | ||
1305 | + start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) | ||
1306 | + | ||
1307 | + start_top_log_probs, start_top_index = torch.topk( | ||
1308 | + start_log_probs, self.start_n_top, dim=-1 | ||
1309 | + ) # shape (bsz, start_n_top) | ||
1310 | + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) | ||
1311 | + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) | ||
1312 | + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) | ||
1313 | + | ||
1314 | + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( | ||
1315 | + start_states | ||
1316 | + ) # shape (bsz, slen, start_n_top, hsz) | ||
1317 | + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None | ||
1318 | + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) | ||
1319 | + end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) | ||
1320 | + | ||
1321 | + end_top_log_probs, end_top_index = torch.topk( | ||
1322 | + end_log_probs, self.end_n_top, dim=1 | ||
1323 | + ) # shape (bsz, end_n_top, start_n_top) | ||
1324 | + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) | ||
1325 | + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) | ||
1326 | + | ||
1327 | + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) | ||
1328 | + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) | ||
1329 | + | ||
1330 | + if not return_dict: | ||
1331 | + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) | ||
1332 | + else: | ||
1333 | + return SquadHeadOutput( | ||
1334 | + start_top_log_probs=start_top_log_probs, | ||
1335 | + start_top_index=start_top_index, | ||
1336 | + end_top_log_probs=end_top_log_probs, | ||
1337 | + end_top_index=end_top_index, | ||
1338 | + cls_logits=cls_logits, | ||
1339 | + ) | ||
1340 | + | ||
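For orientation, a minimal usage sketch of the SQuAD head above. The config here is a hypothetical stand-in (a real PretrainedConfig would normally supply these fields); with gold positions the head returns a loss, without them it returns the beam-search candidates:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=8, layer_norm_eps=1e-12, start_n_top=2, end_n_top=2)
head = SQuADHead(cfg)

hidden = torch.randn(3, 10, 8)                 # (bsz, slen, hsz)
start = torch.tensor([1, 4, 7])
end = torch.tensor([2, 5, 9])

# Training: gold start/end positions are provided, so a single loss tensor is returned.
loss = head(hidden, start_positions=start, end_positions=end)[0]

# Inference: no positions, so top-k start/end candidates and the answerability logits come back.
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = head(hidden)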
1341 | + | ||
1342 | +class SequenceSummary(nn.Module): | ||
1343 | + r""" | ||
1344 | + Compute a single vector summary of a sequence hidden states. | ||
1345 | + | ||
1346 | + Args: | ||
1347 | + config (:class:`~transformers.PretrainedConfig`): | ||
1348 | + The config used by the model. Relevant arguments in the config class of the model are (refer to the | ||
1349 | + actual config class of your model for the default values it uses): | ||
1350 | + | ||
1351 | + - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: | ||
1352 | + | ||
1353 | + - :obj:`"last"` -- Take the last token hidden state (like XLNet) | ||
1354 | + - :obj:`"first"` -- Take the first token hidden state (like Bert) | ||
1355 | + - :obj:`"mean"` -- Take the mean of all tokens hidden states | ||
1356 | + - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) | ||
1357 | + - :obj:`"attn"` -- Not implemented now, use multi-head attention | ||
1358 | + | ||
1359 | + - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. | ||
1360 | + - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to | ||
1361 | + :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). | ||
1362 | + - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the | ||
1363 | + output, another string or :obj:`None` will add no activation. | ||
1364 | + - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and | ||
1365 | + activation. | ||
1366 | + - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and | ||
1367 | + activation. | ||
1368 | + """ | ||
1369 | + | ||
1370 | + def __init__(self, config: PretrainedConfig): | ||
1371 | + super().__init__() | ||
1372 | + | ||
1373 | + self.summary_type = getattr(config, "summary_type", "last") | ||
1374 | + if self.summary_type == "attn": | ||
1375 | + # We should use a standard multi-head attention module with absolute positional embedding for that. | ||
1376 | + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 | ||
1377 | + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 | ||
1378 | + raise NotImplementedError | ||
1379 | + | ||
1380 | + self.summary = Identity() | ||
1381 | + if hasattr(config, "summary_use_proj") and config.summary_use_proj: | ||
1382 | + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: | ||
1383 | + num_classes = config.num_labels | ||
1384 | + else: | ||
1385 | + num_classes = config.hidden_size | ||
1386 | + self.summary = nn.Linear(config.hidden_size, num_classes) | ||
1387 | + | ||
1388 | + activation_string = getattr(config, "summary_activation", None) | ||
1389 | + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() | ||
1390 | + | ||
1391 | + self.first_dropout = Identity() | ||
1392 | + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: | ||
1393 | + self.first_dropout = nn.Dropout(config.summary_first_dropout) | ||
1394 | + | ||
1395 | + self.last_dropout = Identity() | ||
1396 | + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: | ||
1397 | + self.last_dropout = nn.Dropout(config.summary_last_dropout) | ||
1398 | + | ||
1399 | + def forward( | ||
1400 | + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None | ||
1401 | + ) -> torch.FloatTensor: | ||
1402 | + """ | ||
1403 | + Compute a single vector summary of a sequence hidden states. | ||
1404 | + | ||
1405 | + Args: | ||
1406 | + hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`): | ||
1407 | + The hidden states of the last layer. | ||
1408 | + cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): | ||
1409 | + Used if :obj:`summary_type == "cls_index"`. If :obj:`None`, the last token of the sequence is taken as the | ||
1410 | + classification token. | ||
1411 | + | ||
1412 | + Returns: | ||
1413 | + :obj:`torch.FloatTensor`: The summary of the sequence hidden states. | ||
1414 | + """ | ||
1415 | + if self.summary_type == "last": | ||
1416 | + output = hidden_states[:, -1] | ||
1417 | + elif self.summary_type == "first": | ||
1418 | + output = hidden_states[:, 0] | ||
1419 | + elif self.summary_type == "mean": | ||
1420 | + output = hidden_states.mean(dim=1) | ||
1421 | + elif self.summary_type == "cls_index": | ||
1422 | + if cls_index is None: | ||
1423 | + cls_index = torch.full_like( | ||
1424 | + hidden_states[..., :1, :], | ||
1425 | + hidden_states.shape[-2] - 1, | ||
1426 | + dtype=torch.long, | ||
1427 | + ) | ||
1428 | + else: | ||
1429 | + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) | ||
1430 | + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) | ||
1431 | + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states | ||
1432 | + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) | ||
1433 | + elif self.summary_type == "attn": | ||
1434 | + raise NotImplementedError | ||
1435 | + | ||
1436 | + output = self.first_dropout(output) | ||
1437 | + output = self.summary(output) | ||
1438 | + output = self.activation(output) | ||
1439 | + output = self.last_dropout(output) | ||
1440 | + | ||
1441 | + return output | ||
1442 | + | ||
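A short usage sketch of the summary module above, using a hypothetical bare-bones config: with only `hidden_size` and `summary_type` set, the optional projection, activation and dropouts all fall back to Identity:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=16, summary_type="mean")
summary = SequenceSummary(cfg)

hidden = torch.randn(2, 7, 16)     # (batch, seq_len, hidden_size)
pooled = summary(hidden)           # mean over the sequence dimension
assert pooled.shape == (2, 16)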
1443 | + | ||
1444 | +def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear: | ||
1445 | + """ | ||
1446 | + Prune a linear layer to keep only entries in index. | ||
1447 | + | ||
1448 | + Used to remove heads. | ||
1449 | + | ||
1450 | + Args: | ||
1451 | + layer (:obj:`torch.nn.Linear`): The layer to prune. | ||
1452 | + index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1453 | + dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices. | ||
1454 | + | ||
1455 | + Returns: | ||
1456 | + :obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1457 | + """ | ||
1458 | + index = index.to(layer.weight.device) | ||
1459 | + W = layer.weight.index_select(dim, index).clone().detach() | ||
1460 | + if layer.bias is not None: | ||
1461 | + if dim == 1: | ||
1462 | + b = layer.bias.clone().detach() | ||
1463 | + else: | ||
1464 | + b = layer.bias[index].clone().detach() | ||
1465 | + new_size = list(layer.weight.size()) | ||
1466 | + new_size[dim] = len(index) | ||
1467 | + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) | ||
1468 | + new_layer.weight.requires_grad = False | ||
1469 | + new_layer.weight.copy_(W.contiguous()) | ||
1470 | + new_layer.weight.requires_grad = True | ||
1471 | + if layer.bias is not None: | ||
1472 | + new_layer.bias.requires_grad = False | ||
1473 | + new_layer.bias.copy_(b.contiguous()) | ||
1474 | + new_layer.bias.requires_grad = True | ||
1475 | + return new_layer | ||
1476 | + | ||
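A small equivalence check for the pruning helper above: keeping output features 0, 2 and 5 of a Linear(8, 6) yields a Linear(8, 3) whose outputs match the corresponding columns of the original layer's output:

import torch
from torch import nn

layer = nn.Linear(8, 6)
index = torch.tensor([0, 2, 5])
pruned = prune_linear_layer(layer, index, dim=0)   # prune along the output dimension

x = torch.randn(4, 8)
assert torch.allclose(pruned(x), layer(x)[:, index], atol=1e-6)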
1477 | + | ||
1478 | +def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: | ||
1479 | + """ | ||
1480 | + Prune a Conv1D layer to keep only entries in index. A Conv1D works like a Linear layer (used e.g. in GPT-2) but | ||
1481 | + the weights are transposed. | ||
1482 | + | ||
1483 | + Used to remove heads. | ||
1484 | + | ||
1485 | + Args: | ||
1486 | + layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune. | ||
1487 | + index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1488 | + dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices. | ||
1489 | + | ||
1490 | + Returns: | ||
1491 | + :class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1492 | + """ | ||
1493 | + index = index.to(layer.weight.device) | ||
1494 | + W = layer.weight.index_select(dim, index).clone().detach() | ||
1495 | + if dim == 0: | ||
1496 | + b = layer.bias.clone().detach() | ||
1497 | + else: | ||
1498 | + b = layer.bias[index].clone().detach() | ||
1499 | + new_size = list(layer.weight.size()) | ||
1500 | + new_size[dim] = len(index) | ||
1501 | + new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) | ||
1502 | + new_layer.weight.requires_grad = False | ||
1503 | + new_layer.weight.copy_(W.contiguous()) | ||
1504 | + new_layer.weight.requires_grad = True | ||
1505 | + new_layer.bias.requires_grad = False | ||
1506 | + new_layer.bias.copy_(b.contiguous()) | ||
1507 | + new_layer.bias.requires_grad = True | ||
1508 | + return new_layer | ||
1509 | + | ||
1510 | + | ||
1511 | +def prune_layer( | ||
1512 | + layer: Union[torch.nn.Linear, Conv1D], index: torch.LongTensor, dim: Optional[int] = None | ||
1513 | +) -> Union[torch.nn.Linear, Conv1D]: | ||
1514 | + """ | ||
1515 | + Prune a Conv1D or linear layer to keep only entries in index. | ||
1516 | + | ||
1517 | + Used to remove heads. | ||
1518 | + | ||
1519 | + Args: | ||
1520 | + layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune. | ||
1521 | + index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1522 | + dim (:obj:`int`, `optional`): The dimension on which to keep the indices. | ||
1523 | + | ||
1524 | + Returns: | ||
1525 | + :obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`: | ||
1526 | + The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1527 | + """ | ||
1528 | + if isinstance(layer, nn.Linear): | ||
1529 | + return prune_linear_layer(layer, index, dim=0 if dim is None else dim) | ||
1530 | + elif isinstance(layer, Conv1D): | ||
1531 | + return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) | ||
1532 | + else: | ||
1533 | + raise ValueError("Can't prune layer of class {}".format(layer.__class__)) | ||
1534 | + | ||
1535 | + | ||
1536 | +def apply_chunking_to_forward( | ||
1537 | + forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors | ||
1538 | +) -> torch.Tensor: | ||
1539 | + """ | ||
1540 | + This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the | ||
1541 | + dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory. | ||
1542 | + | ||
1543 | + If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as | ||
1544 | + directly applying :obj:`forward_fn` to :obj:`input_tensors`. | ||
1545 | + | ||
1546 | + Args: | ||
1547 | + forward_fn (:obj:`Callable[..., torch.Tensor]`): | ||
1548 | + The forward function of the model. | ||
1549 | + chunk_size (:obj:`int`): | ||
1550 | + The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`. | ||
1551 | + chunk_dim (:obj:`int`): | ||
1552 | + The dimension over which the :obj:`input_tensors` should be chunked. | ||
1553 | + input_tensors (:obj:`Tuple[torch.Tensor]`): | ||
1554 | + The input tensors of ``forward_fn`` which will be chunked. | ||
1555 | + Returns: | ||
1556 | + :obj:`torch.Tensor`: A tensor with the same shape as :obj:`forward_fn` would have produced if applied directly. | ||
1557 | + | ||
1558 | + | ||
1559 | + Examples:: | ||
1560 | + | ||
1561 | + # rename the usual forward() fn to forward_chunk() | ||
1562 | + def forward_chunk(self, hidden_states): | ||
1563 | + hidden_states = self.decoder(hidden_states) | ||
1564 | + return hidden_states | ||
1565 | + | ||
1566 | + # implement a chunked forward function | ||
1567 | + def forward(self, hidden_states): | ||
1568 | + return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) | ||
1569 | + """ | ||
1570 | + | ||
1571 | + assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors) | ||
1572 | + tensor_shape = input_tensors[0].shape | ||
1573 | + assert all( | ||
1574 | + input_tensor.shape == tensor_shape for input_tensor in input_tensors | ||
1575 | + ), "All input tenors have to be of the same shape" | ||
1576 | + | ||
1577 | + # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility | ||
1578 | + num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) | ||
1579 | + assert num_args_in_forward_chunk_fn == len( | ||
1580 | + input_tensors | ||
1581 | + ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( | ||
1582 | + num_args_in_forward_chunk_fn, len(input_tensors) | ||
1583 | + ) | ||
1584 | + | ||
1585 | + if chunk_size > 0: | ||
1586 | + assert ( | ||
1587 | + input_tensors[0].shape[chunk_dim] % chunk_size == 0 | ||
1588 | + ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( | ||
1589 | + input_tensors[0].shape[chunk_dim], chunk_size | ||
1590 | + ) | ||
1591 | + | ||
1592 | + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size | ||
1593 | + | ||
1594 | + # chunk input tensor into tuples | ||
1595 | + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) | ||
1596 | + # apply forward fn to every tuple | ||
1597 | + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) | ||
1598 | + # concatenate output at same dimension | ||
1599 | + return torch.cat(output_chunks, dim=chunk_dim) | ||
1600 | + | ||
1601 | + return forward_fn(*input_tensors) |
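As a complement to the docstring example, a runnable sketch checking that chunking a position-wise feed-forward over the sequence dimension gives the same result as applying it to the full tensor (only the peak memory differs):

import torch
from torch import nn

ffn = nn.Linear(16, 16)

def forward_chunk(hidden_states):
    return ffn(hidden_states)

hidden = torch.randn(2, 8, 16)                                  # seq_len 8 is divisible by chunk_size 4
chunked = apply_chunking_to_forward(forward_chunk, 4, 1, hidden)
assert torch.allclose(chunked, forward_chunk(hidden), atol=1e-6)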