Showing 13 changed files with 0 additions and 6984 deletions.
commit_suggester.py (deleted, mode 100644 → 0)
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import torch | ||
16 | -import argparse | ||
17 | -import subprocess | ||
18 | -from transformers import AutoTokenizer | ||
19 | - | ||
20 | -from preprocess import diff_parse, truncate | ||
21 | -from train import BartForConditionalGeneration | ||
22 | - | ||
23 | -def get_length(chunks): | ||
24 | - cnt = 0 | ||
25 | - for chunk in chunks: | ||
26 | - cnt += len(chunk) | ||
27 | - return cnt | ||
28 | - | ||
29 | -def suggester(chunks, model, tokenizer, device): | ||
30 | - max_source_length = get_length(chunks) | ||
31 | - | ||
32 | - input_ids, attention_masks, patch_ids = zip(*chunks) | ||
33 | - input_ids = torch.LongTensor( | ||
34 | - [truncate(input_ids, max_source_length, value=0)] | ||
35 | - ).to(device) | ||
36 | - attention_masks = torch.LongTensor( | ||
37 | - [truncate(attention_masks, max_source_length, value=1)] | ||
38 | - ).to(device) | ||
39 | - patch_ids = torch.LongTensor( | ||
40 | - [truncate(patch_ids, max_source_length, value=0)] | ||
41 | - ).to(device) | ||
42 | - | ||
43 | - summaries = model.generate( | ||
44 | - input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks | ||
45 | - ) | ||
46 | - return tokenizer.batch_decode( | ||
47 | - summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
48 | - ) | ||
49 | - | ||
50 | - | ||
51 | -def main(args): | ||
52 | - device = torch.device( | ||
53 | - "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" | ||
54 | - ) | ||
55 | - model = BartForConditionalGeneration.from_pretrained(args.output_dir).to(device) | ||
56 | - | ||
57 | - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) | ||
58 | - | ||
59 | - if args.unittest: | ||
60 | - with open("test.source", "r") as f: | ||
61 | - chunks = diff_parse(f.read(), tokenizer) | ||
62 | - else: | ||
63 | - proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE) | ||
64 | - staged_files = proc.stdout.readlines() | ||
65 | - staged_files = [f.decode("utf-8") for f in staged_files] | ||
66 | - staged_files = [f.strip() for f in staged_files] | ||
67 | - chunks = "\n".join(staged_files) | ||
68 | - | ||
69 | - chunks = diff_parse(chunks, tokenizer) | ||
70 | - if not chunks: | ||
71 | - print('There is no file in staged state.') | ||
72 | - return | ||
73 | - | ||
74 | - commit_message = suggester( | ||
75 | - chunks, | ||
76 | - model=model, | ||
77 | - tokenizer=tokenizer, | ||
78 | - device=device, | ||
79 | - ) | ||
80 | - print(commit_message) | ||
81 | - | ||
82 | - | ||
83 | -if __name__ == "__main__": | ||
84 | - parser = argparse.ArgumentParser(description="Code to collect commits on github") | ||
85 | - parser.add_argument( | ||
86 | - "--no_cuda", action="store_true", help="Whether not to use CUDA when available" | ||
87 | - ) | ||
88 | - parser.add_argument( | ||
89 | - "--unittest", action="store_true", help="Unittest with an one batch git diff" | ||
90 | - ) | ||
91 | - parser.add_argument( | ||
92 | - "--output_dir", | ||
93 | - type=str, | ||
94 | - required=True, | ||
95 | - help="The output directory where the model predictions and checkpoints will be written.", | ||
96 | - ) | ||
97 | - parser.add_argument( | ||
98 | - "--tokenizer_name", | ||
99 | - default="sshleifer/distilbart-xsum-6-6", | ||
100 | - type=str, | ||
101 | - help="Pretrained tokenizer name or path if not the same as model_name", | ||
102 | - ) | ||
103 | - args = parser.parse_args() | ||
104 | - | ||
105 | - main(args) |
preprocess/__init__.py (deleted, mode 100644 → 0)
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -from .gitcommit import diff_parse, truncate | ||
16 | - | ||
17 | -__all__ = [ | ||
18 | - "diff_parse", | ||
19 | - "truncate", | ||
20 | -] |
preprocess/gitcommit.py (deleted, mode 100644 → 0)
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import os | ||
16 | -import re | ||
17 | -import enum | ||
18 | -import random | ||
19 | -import logging | ||
20 | -import tempfile | ||
21 | -import argparse | ||
22 | -import numpy as np | ||
23 | -from tqdm import * | ||
24 | -import whatthepatch | ||
25 | -from git import Repo | ||
26 | -from functools import partial | ||
27 | -from multiprocessing.pool import Pool | ||
28 | -from transformers import AutoTokenizer | ||
29 | - | ||
30 | -from matorage import * | ||
31 | - | ||
32 | -logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
33 | -logging.basicConfig( | ||
34 | - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
35 | - datefmt="%m/%d/%Y %H:%M:%S", | ||
36 | - level=logging.INFO, | ||
37 | -) | ||
38 | - | ||
39 | - | ||
40 | -class PATCH(enum.Enum): | ||
41 | - PLUS = 1 | ||
42 | - MINUS = 2 | ||
43 | - | ||
44 | - | ||
45 | -def truncate(tuple, max_length, value=0): | ||
46 | - ls = [] | ||
47 | - for t in tuple: | ||
48 | - if isinstance(t, int): | ||
49 | - t = [t] | ||
50 | - ls.extend(t) | ||
51 | - ls = ls[: max_length - 1] | ||
52 | - ls.insert(0, value) | ||
53 | - if len(ls) < max_length: | ||
54 | - ls.extend([0] * (max_length - len(ls))) | ||
55 | - assert len(ls) == max_length | ||
56 | - return ls | ||
57 | - | ||
58 | - | ||
59 | -def encode_line(tokenizer, line, patch): | ||
60 | - line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip() | ||
61 | - tokens = tokenizer.tokenize(line) | ||
62 | - tokens = tokenizer.convert_tokens_to_ids(tokens) | ||
63 | - return (tokens, [1] * len(tokens), len(tokens) * [patch.value]) | ||
64 | - | ||
65 | - | ||
66 | -def diff_parse(diff, tokenizer): | ||
67 | - chunks = [] | ||
68 | - for diff in whatthepatch.parse_patch(diff): | ||
69 | - if diff.header.old_path != diff.header.new_path: | ||
70 | - chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS)) | ||
71 | - chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS)) | ||
72 | - if not diff.changes: | ||
73 | - continue | ||
74 | - for change in diff.changes: | ||
75 | - if change.old == None and change.new != None: | ||
76 | - chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS)) | ||
77 | - elif change.old != None and change.new == None: | ||
78 | - chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS)) | ||
79 | - return chunks | ||
80 | - | ||
81 | - | ||
82 | -def sha_parse(sha, tokenizer, max_length=1024): | ||
83 | - | ||
84 | - chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer) | ||
85 | - if not chunks: | ||
86 | - return None | ||
87 | - | ||
88 | - input_ids, attention_masks, patch_ids = zip(*chunks) | ||
89 | - input_ids = truncate(input_ids, max_length, value=0) | ||
90 | - attention_masks = truncate(attention_masks, max_length, value=1) | ||
91 | - patch_ids = truncate(patch_ids, max_length, value=0) | ||
92 | - | ||
93 | - return (input_ids, attention_masks, patch_ids) | ||
94 | - | ||
95 | - | ||
96 | -def message_parse(msg, tokenizer, max_length=56): | ||
97 | - msg = re.sub(r"(\(|)#([0-9])+(\)|)", "", msg) | ||
98 | - | ||
99 | - msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip() | ||
100 | - msg = tokenizer.tokenize(msg) | ||
101 | - msg = tokenizer.convert_tokens_to_ids(msg) | ||
102 | - msg = truncate(msg, max_length, value=0) | ||
103 | - | ||
104 | - return msg | ||
105 | - | ||
106 | - | ||
107 | -def jobs(sha_msgs, args, data_config, train=True): | ||
108 | - | ||
109 | - input_ids, attention_masks, patch_ids, targets = [], [], [], [] | ||
110 | - data_saver = DataSaver(config=data_config) | ||
111 | - | ||
112 | - for sha_msg in sha_msgs: | ||
113 | - sha, msg = sha_msg | ||
114 | - | ||
115 | - source = sha_parse( | ||
116 | - sha, tokenizer=args.tokenizer, max_length=args.max_source_length | ||
117 | - ) | ||
118 | - if not source: | ||
119 | - continue | ||
120 | - input_id, attention_mask, patch_id = source | ||
121 | - target = message_parse( | ||
122 | - msg, | ||
123 | - tokenizer=args.tokenizer, | ||
124 | - max_length=( | ||
125 | - args.max_target_length if train else args.val_max_target_length | ||
126 | - ), | ||
127 | - ) | ||
128 | - | ||
129 | - input_ids.append(input_id) | ||
130 | - attention_masks.append(attention_mask) | ||
131 | - patch_ids.append(patch_id) | ||
132 | - targets.append(target) | ||
133 | - | ||
134 | - data_saver( | ||
135 | - { | ||
136 | - "input_ids": np.asarray(input_ids), | ||
137 | - "attention_masks": np.asarray(attention_masks), | ||
138 | - "patch_ids": np.asarray(patch_ids), | ||
139 | - "targets": np.asarray(targets), | ||
140 | - } | ||
141 | - ) | ||
142 | - data_saver.disconnect() | ||
143 | - | ||
144 | - | ||
145 | -def start(chunked_sha_msgs, train=True): | ||
146 | - | ||
147 | - logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation")) | ||
148 | - | ||
149 | - max_target_length = args.max_target_length if train else args.val_max_target_length | ||
150 | - | ||
151 | - data_config = DataConfig( | ||
152 | - endpoint=args.endpoint, | ||
153 | - access_key=os.environ["access_key"], | ||
154 | - secret_key=os.environ["secret_key"], | ||
155 | - region=args.region, | ||
156 | - dataset_name="commit-autosuggestions", | ||
157 | - additional={ | ||
158 | - "mode": ("training" if train else "evaluation"), | ||
159 | - "max_source_length": args.max_source_length, | ||
160 | - "max_target_length": max_target_length, | ||
161 | - "url": args.url, | ||
162 | - }, | ||
163 | - attributes=[ | ||
164 | - ("input_ids", "int32", (args.max_source_length,)), | ||
165 | - ("attention_masks", "int32", (args.max_source_length,)), | ||
166 | - ("patch_ids", "int32", (args.max_source_length,)), | ||
167 | - ("targets", "int32", (max_target_length,)), | ||
168 | - ], | ||
169 | - ) | ||
170 | - | ||
171 | - func = partial(jobs, args=args, data_config=data_config, train=train) | ||
172 | - with Pool(processes=args.num_workers) as pool: | ||
173 | - with tqdm(total=len(chunked_sha_msgs)) as pbar: | ||
174 | - for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))): | ||
175 | - pbar.update() | ||
176 | - | ||
177 | - | ||
178 | -def main(args): | ||
179 | - if "access_key" not in os.environ or "secret_key" not in os.environ: | ||
180 | - raise OSError("access_key or secret_key are not found.") | ||
181 | - | ||
182 | - sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()] | ||
183 | - random.shuffle(sha_msgs) | ||
184 | - chunked_sha_msgs = [ | ||
185 | - sha_msgs[x : x + args.matorage_batch] | ||
186 | - for x in range(0, len(sha_msgs), args.matorage_batch) | ||
187 | - ] | ||
188 | - | ||
189 | - barrier = int(len(chunked_sha_msgs) * (1 - args.p_val)) | ||
190 | - if args.do_train: | ||
191 | - start(chunked_sha_msgs[:barrier], train=True) | ||
192 | - if args.do_predict: | ||
193 | - start(chunked_sha_msgs[barrier:], train=False) | ||
194 | - | ||
195 | - | ||
196 | -if __name__ == "__main__": | ||
197 | - parser = argparse.ArgumentParser(description="Code to collect commits on github") | ||
198 | - parser.add_argument("--url", type=str, required=True, help="github url") | ||
199 | - parser.add_argument( | ||
200 | - "--endpoint", | ||
201 | - type=str, | ||
202 | - required=True, | ||
203 | - help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
204 | - ) | ||
205 | - parser.add_argument( | ||
206 | - "--region", | ||
207 | - type=str, | ||
208 | - default=None, | ||
209 | - help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
210 | - ) | ||
211 | - parser.add_argument( | ||
212 | - "--tokenizer_name", | ||
213 | - default="sshleifer/distilbart-xsum-6-6", | ||
214 | - type=str, | ||
215 | - help="Pretrained tokenizer name or path if not the same as model_name", | ||
216 | - ) | ||
217 | - parser.add_argument( | ||
218 | - "--matorage_batch", | ||
219 | - default=1024, | ||
220 | - type=int, | ||
221 | - help="The smallest batch size stored atomically in matorage.", | ||
222 | - ) | ||
223 | - parser.add_argument( | ||
224 | - "--num_workers", default=4, type=int, help="number of process", | ||
225 | - ) | ||
226 | - parser.add_argument( | ||
227 | - "--max_source_length", | ||
228 | - default=1024, | ||
229 | - type=int, | ||
230 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
231 | - "than this will be truncated, sequences shorter will be padded.", | ||
232 | - ) | ||
233 | - parser.add_argument( | ||
234 | - "--max_target_length", | ||
235 | - default=56, | ||
236 | - type=int, | ||
237 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
238 | - "than this will be truncated, sequences shorter will be padded.", | ||
239 | - ) | ||
240 | - parser.add_argument( | ||
241 | - "--val_max_target_length", | ||
242 | - default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. | ||
243 | - type=int, | ||
244 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
245 | - "than this will be truncated, sequences shorter will be padded.", | ||
246 | - ) | ||
247 | - parser.add_argument( | ||
248 | - "--p_val", type=float, default=0.25, help="percent of validation dataset" | ||
249 | - ) | ||
250 | - parser.add_argument("--do_train", action="store_true", default=False) | ||
251 | - parser.add_argument("--do_predict", action="store_true", default=False) | ||
252 | - args = parser.parse_args() | ||
253 | - | ||
254 | - args.local_path = args.url.split("/")[-1] | ||
255 | - logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}") | ||
256 | - repo = ( | ||
257 | - Repo(args.local_path) | ||
258 | - if os.path.exists(args.local_path) | ||
259 | - else Repo.clone_from(args.url, to_path=args.local_path, branch="master") | ||
260 | - ) | ||
261 | - args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) | ||
262 | - | ||
263 | - main(args) |
test.source (deleted, mode 100644 → 0)
1 | -commit b5a5268dabb2a4dea1c3c543a1ddff501b87a447 | ||
2 | -Author: jbrockmendel <jbrockmendel@gmail.com> | ||
3 | -Date: Tue Sep 8 18:33:41 2020 -0700 | ||
4 | - | ||
5 | - STY: De-privatize imported names (#36235) | ||
6 | - | ||
7 | -diff --git a/pandas/_libs/interval.pyx b/pandas/_libs/interval.pyx | ||
8 | -index 931ad8326..f8bcbcfb1 100644 | ||
9 | ---- a/pandas/_libs/interval.pyx | ||
10 | -+++ b/pandas/_libs/interval.pyx | ||
11 | -@@ -46,7 +46,7 @@ from pandas._libs.tslibs.util cimport ( | ||
12 | - is_timedelta64_object, | ||
13 | - ) | ||
14 | - | ||
15 | --_VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither']) | ||
16 | -+VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither']) | ||
17 | - | ||
18 | - | ||
19 | - cdef class IntervalMixin: | ||
20 | -@@ -318,7 +318,7 @@ cdef class Interval(IntervalMixin): | ||
21 | - self._validate_endpoint(left) | ||
22 | - self._validate_endpoint(right) | ||
23 | - | ||
24 | -- if closed not in _VALID_CLOSED: | ||
25 | -+ if closed not in VALID_CLOSED: | ||
26 | - raise ValueError(f"invalid option for 'closed': {closed}") | ||
27 | - if not left <= right: | ||
28 | - raise ValueError("left side of interval must be <= right side") | ||
29 | -diff --git a/pandas/core/arrays/_arrow_utils.py b/pandas/core/arrays/_arrow_utils.py | ||
30 | -index 4a33e0e84..c89f5554d 100644 | ||
31 | ---- a/pandas/core/arrays/_arrow_utils.py | ||
32 | -+++ b/pandas/core/arrays/_arrow_utils.py | ||
33 | -@@ -4,7 +4,7 @@ import json | ||
34 | - import numpy as np | ||
35 | - import pyarrow | ||
36 | - | ||
37 | --from pandas.core.arrays.interval import _VALID_CLOSED | ||
38 | -+from pandas.core.arrays.interval import VALID_CLOSED | ||
39 | - | ||
40 | - _pyarrow_version_ge_015 = LooseVersion(pyarrow.__version__) >= LooseVersion("0.15") | ||
41 | - | ||
42 | -@@ -83,7 +83,7 @@ if _pyarrow_version_ge_015: | ||
43 | - def __init__(self, subtype, closed): | ||
44 | - # attributes need to be set first before calling | ||
45 | - # super init (as that calls serialize) | ||
46 | -- assert closed in _VALID_CLOSED | ||
47 | -+ assert closed in VALID_CLOSED | ||
48 | - self._closed = closed | ||
49 | - if not isinstance(subtype, pyarrow.DataType): | ||
50 | - subtype = pyarrow.type_for_alias(str(subtype)) | ||
51 | -diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py | ||
52 | -index d76e0fd62..1dbd3cfc6 100644 | ||
53 | ---- a/pandas/core/arrays/interval.py | ||
54 | -+++ b/pandas/core/arrays/interval.py | ||
55 | -@@ -5,7 +5,12 @@ import numpy as np | ||
56 | - | ||
57 | - from pandas._config import get_option | ||
58 | - | ||
59 | --from pandas._libs.interval import Interval, IntervalMixin, intervals_to_interval_bounds | ||
60 | -+from pandas._libs.interval import ( | ||
61 | -+ VALID_CLOSED, | ||
62 | -+ Interval, | ||
63 | -+ IntervalMixin, | ||
64 | -+ intervals_to_interval_bounds, | ||
65 | -+) | ||
66 | - from pandas.compat.numpy import function as nv | ||
67 | - from pandas.util._decorators import Appender | ||
68 | - | ||
69 | -@@ -42,7 +47,6 @@ from pandas.core.construction import array | ||
70 | - from pandas.core.indexers import check_array_indexer | ||
71 | - from pandas.core.indexes.base import ensure_index | ||
72 | - | ||
73 | --_VALID_CLOSED = {"left", "right", "both", "neither"} | ||
74 | - _interval_shared_docs = {} | ||
75 | - | ||
76 | - _shared_docs_kwargs = dict( | ||
77 | -@@ -475,7 +479,7 @@ class IntervalArray(IntervalMixin, ExtensionArray): | ||
78 | - * left and right have the same missing values | ||
79 | - * left is always below right | ||
80 | - """ | ||
81 | -- if self.closed not in _VALID_CLOSED: | ||
82 | -+ if self.closed not in VALID_CLOSED: | ||
83 | - msg = f"invalid option for 'closed': {self.closed}" | ||
84 | - raise ValueError(msg) | ||
85 | - if len(self.left) != len(self.right): | ||
86 | -@@ -1012,7 +1016,7 @@ class IntervalArray(IntervalMixin, ExtensionArray): | ||
87 | - ) | ||
88 | - ) | ||
89 | - def set_closed(self, closed): | ||
90 | -- if closed not in _VALID_CLOSED: | ||
91 | -+ if closed not in VALID_CLOSED: | ||
92 | - msg = f"invalid option for 'closed': {closed}" | ||
93 | - raise ValueError(msg) | ||
94 | - | ||
95 | -diff --git a/pandas/core/arrays/sparse/__init__.py b/pandas/core/arrays/sparse/__init__.py | ||
96 | -index e928db499..e9ff4b7d4 100644 | ||
97 | ---- a/pandas/core/arrays/sparse/__init__.py | ||
98 | -+++ b/pandas/core/arrays/sparse/__init__.py | ||
99 | -@@ -5,6 +5,6 @@ from pandas.core.arrays.sparse.array import ( | ||
100 | - BlockIndex, | ||
101 | - IntIndex, | ||
102 | - SparseArray, | ||
103 | -- _make_index, | ||
104 | -+ make_sparse_index, | ||
105 | - ) | ||
106 | - from pandas.core.arrays.sparse.dtype import SparseDtype | ||
107 | -diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py | ||
108 | -index 47c960dc9..853f7bb0b 100644 | ||
109 | ---- a/pandas/core/arrays/sparse/array.py | ||
110 | -+++ b/pandas/core/arrays/sparse/array.py | ||
111 | -@@ -1556,7 +1556,7 @@ def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None, copy | ||
112 | - else: | ||
113 | - indices = mask.nonzero()[0].astype(np.int32) | ||
114 | - | ||
115 | -- index = _make_index(length, indices, kind) | ||
116 | -+ index = make_sparse_index(length, indices, kind) | ||
117 | - sparsified_values = arr[mask] | ||
118 | - if dtype is not None: | ||
119 | - sparsified_values = astype_nansafe(sparsified_values, dtype=dtype) | ||
120 | -@@ -1564,7 +1564,7 @@ def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None, copy | ||
121 | - return sparsified_values, index, fill_value | ||
122 | - | ||
123 | - | ||
124 | --def _make_index(length, indices, kind): | ||
125 | -+def make_sparse_index(length, indices, kind): | ||
126 | - | ||
127 | - if kind == "block" or isinstance(kind, BlockIndex): | ||
128 | - locs, lens = splib.get_blocks(indices) | ||
129 | -diff --git a/pandas/core/computation/engines.py b/pandas/core/computation/engines.py | ||
130 | -index 0cdc0f530..77a378369 100644 | ||
131 | ---- a/pandas/core/computation/engines.py | ||
132 | -+++ b/pandas/core/computation/engines.py | ||
133 | -@@ -130,7 +130,7 @@ class PythonEngine(AbstractEngine): | ||
134 | - pass | ||
135 | - | ||
136 | - | ||
137 | --_engines: Dict[str, Type[AbstractEngine]] = { | ||
138 | -+ENGINES: Dict[str, Type[AbstractEngine]] = { | ||
139 | - "numexpr": NumExprEngine, | ||
140 | - "python": PythonEngine, | ||
141 | - } | ||
142 | -diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py | ||
143 | -index f6a793514..630606b4d 100644 | ||
144 | ---- a/pandas/core/computation/eval.py | ||
145 | -+++ b/pandas/core/computation/eval.py | ||
146 | -@@ -9,8 +9,8 @@ import warnings | ||
147 | - from pandas._libs.lib import no_default | ||
148 | - from pandas.util._validators import validate_bool_kwarg | ||
149 | - | ||
150 | --from pandas.core.computation.engines import _engines | ||
151 | --from pandas.core.computation.expr import Expr, _parsers | ||
152 | -+from pandas.core.computation.engines import ENGINES | ||
153 | -+from pandas.core.computation.expr import PARSERS, Expr | ||
154 | - from pandas.core.computation.parsing import tokenize_string | ||
155 | - from pandas.core.computation.scope import ensure_scope | ||
156 | - | ||
157 | -@@ -43,8 +43,8 @@ def _check_engine(engine: Optional[str]) -> str: | ||
158 | - if engine is None: | ||
159 | - engine = "numexpr" if NUMEXPR_INSTALLED else "python" | ||
160 | - | ||
161 | -- if engine not in _engines: | ||
162 | -- valid_engines = list(_engines.keys()) | ||
163 | -+ if engine not in ENGINES: | ||
164 | -+ valid_engines = list(ENGINES.keys()) | ||
165 | - raise KeyError( | ||
166 | - f"Invalid engine '{engine}' passed, valid engines are {valid_engines}" | ||
167 | - ) | ||
168 | -@@ -75,9 +75,9 @@ def _check_parser(parser: str): | ||
169 | - KeyError | ||
170 | - * If an invalid parser is passed | ||
171 | - """ | ||
172 | -- if parser not in _parsers: | ||
173 | -+ if parser not in PARSERS: | ||
174 | - raise KeyError( | ||
175 | -- f"Invalid parser '{parser}' passed, valid parsers are {_parsers.keys()}" | ||
176 | -+ f"Invalid parser '{parser}' passed, valid parsers are {PARSERS.keys()}" | ||
177 | - ) | ||
178 | - | ||
179 | - | ||
180 | -@@ -341,7 +341,7 @@ def eval( | ||
181 | - parsed_expr = Expr(expr, engine=engine, parser=parser, env=env) | ||
182 | - | ||
183 | - # construct the engine and evaluate the parsed expression | ||
184 | -- eng = _engines[engine] | ||
185 | -+ eng = ENGINES[engine] | ||
186 | - eng_inst = eng(parsed_expr) | ||
187 | - ret = eng_inst.evaluate() | ||
188 | - | ||
189 | -diff --git a/pandas/core/computation/expr.py b/pandas/core/computation/expr.py | ||
190 | -index 8cff6abc0..f5897277d 100644 | ||
191 | ---- a/pandas/core/computation/expr.py | ||
192 | -+++ b/pandas/core/computation/expr.py | ||
193 | -@@ -782,7 +782,7 @@ class Expr: | ||
194 | - self.env = env or Scope(level=level + 1) | ||
195 | - self.engine = engine | ||
196 | - self.parser = parser | ||
197 | -- self._visitor = _parsers[parser](self.env, self.engine, self.parser) | ||
198 | -+ self._visitor = PARSERS[parser](self.env, self.engine, self.parser) | ||
199 | - self.terms = self.parse() | ||
200 | - | ||
201 | - @property | ||
202 | -@@ -814,4 +814,4 @@ class Expr: | ||
203 | - return frozenset(term.name for term in com.flatten(self.terms)) | ||
204 | - | ||
205 | - | ||
206 | --_parsers = {"python": PythonExprVisitor, "pandas": PandasExprVisitor} | ||
207 | -+PARSERS = {"python": PythonExprVisitor, "pandas": PandasExprVisitor} | ||
208 | -diff --git a/pandas/core/config_init.py b/pandas/core/config_init.py | ||
209 | -index 0c23f1b4b..bfe20551c 100644 | ||
210 | ---- a/pandas/core/config_init.py | ||
211 | -+++ b/pandas/core/config_init.py | ||
212 | -@@ -314,9 +314,9 @@ pc_latex_multirow = """ | ||
213 | - | ||
214 | - | ||
215 | - def table_schema_cb(key): | ||
216 | -- from pandas.io.formats.printing import _enable_data_resource_formatter | ||
217 | -+ from pandas.io.formats.printing import enable_data_resource_formatter | ||
218 | - | ||
219 | -- _enable_data_resource_formatter(cf.get_option(key)) | ||
220 | -+ enable_data_resource_formatter(cf.get_option(key)) | ||
221 | - | ||
222 | - | ||
223 | - def is_terminal() -> bool: | ||
224 | -diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py | ||
225 | -index 72003eab2..e870187fc 100644 | ||
226 | ---- a/pandas/core/groupby/generic.py | ||
227 | -+++ b/pandas/core/groupby/generic.py | ||
228 | -@@ -70,9 +70,9 @@ from pandas.core.groupby.groupby import ( | ||
229 | - GroupBy, | ||
230 | - _agg_template, | ||
231 | - _apply_docs, | ||
232 | -- _group_selection_context, | ||
233 | - _transform_template, | ||
234 | - get_groupby, | ||
235 | -+ group_selection_context, | ||
236 | - ) | ||
237 | - from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba | ||
238 | - from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same | ||
239 | -@@ -230,7 +230,7 @@ class SeriesGroupBy(GroupBy[Series]): | ||
240 | - raise NotImplementedError( | ||
241 | - "Numba engine can only be used with a single function." | ||
242 | - ) | ||
243 | -- with _group_selection_context(self): | ||
244 | -+ with group_selection_context(self): | ||
245 | - data = self._selected_obj | ||
246 | - result, index = self._aggregate_with_numba( | ||
247 | - data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs | ||
248 | -@@ -685,7 +685,7 @@ class SeriesGroupBy(GroupBy[Series]): | ||
249 | - self, normalize=False, sort=True, ascending=False, bins=None, dropna=True | ||
250 | - ): | ||
251 | - | ||
252 | -- from pandas.core.reshape.merge import _get_join_indexers | ||
253 | -+ from pandas.core.reshape.merge import get_join_indexers | ||
254 | - from pandas.core.reshape.tile import cut | ||
255 | - | ||
256 | - if bins is not None and not np.iterable(bins): | ||
257 | -@@ -787,7 +787,7 @@ class SeriesGroupBy(GroupBy[Series]): | ||
258 | - | ||
259 | - right = [diff.cumsum() - 1, codes[-1]] | ||
260 | - | ||
261 | -- _, idx = _get_join_indexers(left, right, sort=False, how="left") | ||
262 | -+ _, idx = get_join_indexers(left, right, sort=False, how="left") | ||
263 | - out = np.where(idx != -1, out[idx], 0) | ||
264 | - | ||
265 | - if sort: | ||
266 | -@@ -942,7 +942,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]): | ||
267 | - raise NotImplementedError( | ||
268 | - "Numba engine can only be used with a single function." | ||
269 | - ) | ||
270 | -- with _group_selection_context(self): | ||
271 | -+ with group_selection_context(self): | ||
272 | - data = self._selected_obj | ||
273 | - result, index = self._aggregate_with_numba( | ||
274 | - data, func, *args, engine_kwargs=engine_kwargs, **kwargs | ||
275 | -diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py | ||
276 | -index 6ef2e6703..1e3e56f4f 100644 | ||
277 | ---- a/pandas/core/groupby/groupby.py | ||
278 | -+++ b/pandas/core/groupby/groupby.py | ||
279 | -@@ -459,9 +459,9 @@ class GroupByPlot(PandasObject): | ||
280 | - | ||
281 | - | ||
282 | - @contextmanager | ||
283 | --def _group_selection_context(groupby: "_GroupBy"): | ||
284 | -+def group_selection_context(groupby: "_GroupBy"): | ||
285 | - """ | ||
286 | -- Set / reset the _group_selection_context. | ||
287 | -+ Set / reset the group_selection_context. | ||
288 | - """ | ||
289 | - groupby._set_group_selection() | ||
290 | - try: | ||
291 | -@@ -737,7 +737,7 @@ b 2""", | ||
292 | - def _make_wrapper(self, name: str) -> Callable: | ||
293 | - assert name in self._apply_allowlist | ||
294 | - | ||
295 | -- with _group_selection_context(self): | ||
296 | -+ with group_selection_context(self): | ||
297 | - # need to setup the selection | ||
298 | - # as are not passed directly but in the grouper | ||
299 | - f = getattr(self._obj_with_exclusions, name) | ||
300 | -@@ -868,7 +868,7 @@ b 2""", | ||
301 | - # fails on *some* columns, e.g. a numeric operation | ||
302 | - # on a string grouper column | ||
303 | - | ||
304 | -- with _group_selection_context(self): | ||
305 | -+ with group_selection_context(self): | ||
306 | - return self._python_apply_general(f, self._selected_obj) | ||
307 | - | ||
308 | - return result | ||
309 | -@@ -994,7 +994,7 @@ b 2""", | ||
310 | - alias: str, | ||
311 | - npfunc: Callable, | ||
312 | - ): | ||
313 | -- with _group_selection_context(self): | ||
314 | -+ with group_selection_context(self): | ||
315 | - # try a cython aggregation if we can | ||
316 | - try: | ||
317 | - return self._cython_agg_general( | ||
318 | -@@ -1499,7 +1499,7 @@ class GroupBy(_GroupBy[FrameOrSeries]): | ||
319 | - ) | ||
320 | - else: | ||
321 | - func = lambda x: x.var(ddof=ddof) | ||
322 | -- with _group_selection_context(self): | ||
323 | -+ with group_selection_context(self): | ||
324 | - return self._python_agg_general(func) | ||
325 | - | ||
326 | - @Substitution(name="groupby") | ||
327 | -@@ -1658,7 +1658,7 @@ class GroupBy(_GroupBy[FrameOrSeries]): | ||
328 | - | ||
329 | - @doc(DataFrame.describe) | ||
330 | - def describe(self, **kwargs): | ||
331 | -- with _group_selection_context(self): | ||
332 | -+ with group_selection_context(self): | ||
333 | - result = self.apply(lambda x: x.describe(**kwargs)) | ||
334 | - if self.axis == 1: | ||
335 | - return result.T | ||
336 | -@@ -1963,7 +1963,7 @@ class GroupBy(_GroupBy[FrameOrSeries]): | ||
337 | - nth_values = list(set(n)) | ||
338 | - | ||
339 | - nth_array = np.array(nth_values, dtype=np.intp) | ||
340 | -- with _group_selection_context(self): | ||
341 | -+ with group_selection_context(self): | ||
342 | - | ||
343 | - mask_left = np.in1d(self._cumcount_array(), nth_array) | ||
344 | - mask_right = np.in1d( | ||
345 | -@@ -2226,7 +2226,7 @@ class GroupBy(_GroupBy[FrameOrSeries]): | ||
346 | - 5 0 | ||
347 | - dtype: int64 | ||
348 | - """ | ||
349 | -- with _group_selection_context(self): | ||
350 | -+ with group_selection_context(self): | ||
351 | - index = self._selected_obj.index | ||
352 | - result = self._obj_1d_constructor(self.grouper.group_info[0], index) | ||
353 | - if not ascending: | ||
354 | -@@ -2287,7 +2287,7 @@ class GroupBy(_GroupBy[FrameOrSeries]): | ||
355 | - 5 0 | ||
356 | - dtype: int64 | ||
357 | - """ | ||
358 | -- with _group_selection_context(self): | ||
359 | -+ with group_selection_context(self): | ||
360 | - index = self._selected_obj.index | ||
361 | - cumcounts = self._cumcount_array(ascending=ascending) | ||
362 | - return self._obj_1d_constructor(cumcounts, index) | ||
363 | -diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py | ||
364 | -index 526dae7e2..8014b16d0 100644 | ||
365 | ---- a/pandas/core/indexes/base.py | ||
366 | -+++ b/pandas/core/indexes/base.py | ||
367 | -@@ -3660,7 +3660,7 @@ class Index(IndexOpsMixin, PandasObject): | ||
368 | - return result | ||
369 | - | ||
370 | - def _join_non_unique(self, other, how="left", return_indexers=False): | ||
371 | -- from pandas.core.reshape.merge import _get_join_indexers | ||
372 | -+ from pandas.core.reshape.merge import get_join_indexers | ||
373 | - | ||
374 | - # We only get here if dtypes match | ||
375 | - assert self.dtype == other.dtype | ||
376 | -@@ -3668,7 +3668,7 @@ class Index(IndexOpsMixin, PandasObject): | ||
377 | - lvalues = self._get_engine_target() | ||
378 | - rvalues = other._get_engine_target() | ||
379 | - | ||
380 | -- left_idx, right_idx = _get_join_indexers( | ||
381 | -+ left_idx, right_idx = get_join_indexers( | ||
382 | - [lvalues], [rvalues], how=how, sort=True | ||
383 | - ) | ||
384 | - | ||
385 | -diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py | ||
386 | -index 3f72577c9..154f41bf0 100644 | ||
387 | ---- a/pandas/core/indexes/interval.py | ||
388 | -+++ b/pandas/core/indexes/interval.py | ||
389 | -@@ -59,7 +59,6 @@ from pandas.core.ops import get_op_result_name | ||
390 | - if TYPE_CHECKING: | ||
391 | - from pandas import CategoricalIndex # noqa:F401 | ||
392 | - | ||
393 | --_VALID_CLOSED = {"left", "right", "both", "neither"} | ||
394 | - _index_doc_kwargs = dict(ibase._index_doc_kwargs) | ||
395 | - | ||
396 | - _index_doc_kwargs.update( | ||
397 | -diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py | ||
398 | -index 030dec369..9f19ea9ae 100644 | ||
399 | ---- a/pandas/core/reshape/merge.py | ||
400 | -+++ b/pandas/core/reshape/merge.py | ||
401 | -@@ -859,7 +859,7 @@ class _MergeOperation: | ||
402 | - | ||
403 | - def _get_join_indexers(self): | ||
404 | - """ return the join indexers """ | ||
405 | -- return _get_join_indexers( | ||
406 | -+ return get_join_indexers( | ||
407 | - self.left_join_keys, self.right_join_keys, sort=self.sort, how=self.how | ||
408 | - ) | ||
409 | - | ||
410 | -@@ -1298,7 +1298,7 @@ class _MergeOperation: | ||
411 | - raise ValueError("Not a valid argument for validate") | ||
412 | - | ||
413 | - | ||
414 | --def _get_join_indexers( | ||
415 | -+def get_join_indexers( | ||
416 | - left_keys, right_keys, sort: bool = False, how: str = "inner", **kwargs | ||
417 | - ): | ||
418 | - """ | ||
419 | -diff --git a/pandas/io/formats/printing.py b/pandas/io/formats/printing.py | ||
420 | -index edc6fbfff..0d2ca83f1 100644 | ||
421 | ---- a/pandas/io/formats/printing.py | ||
422 | -+++ b/pandas/io/formats/printing.py | ||
423 | -@@ -243,7 +243,7 @@ def pprint_thing_encoded( | ||
424 | - return value.encode(encoding, errors) | ||
425 | - | ||
426 | - | ||
427 | --def _enable_data_resource_formatter(enable: bool) -> None: | ||
428 | -+def enable_data_resource_formatter(enable: bool) -> None: | ||
429 | - if "IPython" not in sys.modules: | ||
430 | - # definitely not in IPython | ||
431 | - return | ||
432 | -diff --git a/pandas/tests/arrays/sparse/test_libsparse.py b/pandas/tests/arrays/sparse/test_libsparse.py | ||
433 | -index a2f861d37..2d6e657de 100644 | ||
434 | ---- a/pandas/tests/arrays/sparse/test_libsparse.py | ||
435 | -+++ b/pandas/tests/arrays/sparse/test_libsparse.py | ||
436 | -@@ -8,7 +8,7 @@ import pandas.util._test_decorators as td | ||
437 | - | ||
438 | - from pandas import Series | ||
439 | - import pandas._testing as tm | ||
440 | --from pandas.core.arrays.sparse import BlockIndex, IntIndex, _make_index | ||
441 | -+from pandas.core.arrays.sparse import BlockIndex, IntIndex, make_sparse_index | ||
442 | - | ||
443 | - TEST_LENGTH = 20 | ||
444 | - | ||
445 | -@@ -273,41 +273,43 @@ class TestSparseIndexIntersect: | ||
446 | - | ||
447 | - class TestSparseIndexCommon: | ||
448 | - def test_int_internal(self): | ||
449 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="integer") | ||
450 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="integer") | ||
451 | - assert isinstance(idx, IntIndex) | ||
452 | - assert idx.npoints == 2 | ||
453 | - tm.assert_numpy_array_equal(idx.indices, np.array([2, 3], dtype=np.int32)) | ||
454 | - | ||
455 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind="integer") | ||
456 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="integer") | ||
457 | - assert isinstance(idx, IntIndex) | ||
458 | - assert idx.npoints == 0 | ||
459 | - tm.assert_numpy_array_equal(idx.indices, np.array([], dtype=np.int32)) | ||
460 | - | ||
461 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer") | ||
462 | -+ idx = make_sparse_index( | ||
463 | -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer" | ||
464 | -+ ) | ||
465 | - assert isinstance(idx, IntIndex) | ||
466 | - assert idx.npoints == 4 | ||
467 | - tm.assert_numpy_array_equal(idx.indices, np.array([0, 1, 2, 3], dtype=np.int32)) | ||
468 | - | ||
469 | - def test_block_internal(self): | ||
470 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="block") | ||
471 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="block") | ||
472 | - assert isinstance(idx, BlockIndex) | ||
473 | - assert idx.npoints == 2 | ||
474 | - tm.assert_numpy_array_equal(idx.blocs, np.array([2], dtype=np.int32)) | ||
475 | - tm.assert_numpy_array_equal(idx.blengths, np.array([2], dtype=np.int32)) | ||
476 | - | ||
477 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind="block") | ||
478 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="block") | ||
479 | - assert isinstance(idx, BlockIndex) | ||
480 | - assert idx.npoints == 0 | ||
481 | - tm.assert_numpy_array_equal(idx.blocs, np.array([], dtype=np.int32)) | ||
482 | - tm.assert_numpy_array_equal(idx.blengths, np.array([], dtype=np.int32)) | ||
483 | - | ||
484 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block") | ||
485 | -+ idx = make_sparse_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block") | ||
486 | - assert isinstance(idx, BlockIndex) | ||
487 | - assert idx.npoints == 4 | ||
488 | - tm.assert_numpy_array_equal(idx.blocs, np.array([0], dtype=np.int32)) | ||
489 | - tm.assert_numpy_array_equal(idx.blengths, np.array([4], dtype=np.int32)) | ||
490 | - | ||
491 | -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block") | ||
492 | -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block") | ||
493 | - assert isinstance(idx, BlockIndex) | ||
494 | - assert idx.npoints == 3 | ||
495 | - tm.assert_numpy_array_equal(idx.blocs, np.array([0, 2], dtype=np.int32)) | ||
496 | -@@ -315,7 +317,7 @@ class TestSparseIndexCommon: | ||
497 | - | ||
498 | - def test_lookup(self): | ||
499 | - for kind in ["integer", "block"]: | ||
500 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind=kind) | ||
501 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind=kind) | ||
502 | - assert idx.lookup(-1) == -1 | ||
503 | - assert idx.lookup(0) == -1 | ||
504 | - assert idx.lookup(1) == -1 | ||
505 | -@@ -323,12 +325,14 @@ class TestSparseIndexCommon: | ||
506 | - assert idx.lookup(3) == 1 | ||
507 | - assert idx.lookup(4) == -1 | ||
508 | - | ||
509 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind=kind) | ||
510 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind=kind) | ||
511 | - | ||
512 | - for i in range(-1, 5): | ||
513 | - assert idx.lookup(i) == -1 | ||
514 | - | ||
515 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind) | ||
516 | -+ idx = make_sparse_index( | ||
517 | -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind | ||
518 | -+ ) | ||
519 | - assert idx.lookup(-1) == -1 | ||
520 | - assert idx.lookup(0) == 0 | ||
521 | - assert idx.lookup(1) == 1 | ||
522 | -@@ -336,7 +340,7 @@ class TestSparseIndexCommon: | ||
523 | - assert idx.lookup(3) == 3 | ||
524 | - assert idx.lookup(4) == -1 | ||
525 | - | ||
526 | -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind) | ||
527 | -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind) | ||
528 | - assert idx.lookup(-1) == -1 | ||
529 | - assert idx.lookup(0) == 0 | ||
530 | - assert idx.lookup(1) == -1 | ||
531 | -@@ -346,7 +350,7 @@ class TestSparseIndexCommon: | ||
532 | - | ||
533 | - def test_lookup_array(self): | ||
534 | - for kind in ["integer", "block"]: | ||
535 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind=kind) | ||
536 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind=kind) | ||
537 | - | ||
538 | - res = idx.lookup_array(np.array([-1, 0, 2], dtype=np.int32)) | ||
539 | - exp = np.array([-1, -1, 0], dtype=np.int32) | ||
540 | -@@ -356,11 +360,13 @@ class TestSparseIndexCommon: | ||
541 | - exp = np.array([-1, 0, -1, 1], dtype=np.int32) | ||
542 | - tm.assert_numpy_array_equal(res, exp) | ||
543 | - | ||
544 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind=kind) | ||
545 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind=kind) | ||
546 | - res = idx.lookup_array(np.array([-1, 0, 2, 4], dtype=np.int32)) | ||
547 | - exp = np.array([-1, -1, -1, -1], dtype=np.int32) | ||
548 | - | ||
549 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind) | ||
550 | -+ idx = make_sparse_index( | ||
551 | -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind | ||
552 | -+ ) | ||
553 | - res = idx.lookup_array(np.array([-1, 0, 2], dtype=np.int32)) | ||
554 | - exp = np.array([-1, 0, 2], dtype=np.int32) | ||
555 | - tm.assert_numpy_array_equal(res, exp) | ||
556 | -@@ -369,7 +375,7 @@ class TestSparseIndexCommon: | ||
557 | - exp = np.array([-1, 2, 1, 3], dtype=np.int32) | ||
558 | - tm.assert_numpy_array_equal(res, exp) | ||
559 | - | ||
560 | -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind) | ||
561 | -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind) | ||
562 | - res = idx.lookup_array(np.array([2, 1, 3, 0], dtype=np.int32)) | ||
563 | - exp = np.array([1, -1, 2, 0], dtype=np.int32) | ||
564 | - tm.assert_numpy_array_equal(res, exp) | ||
565 | -@@ -402,25 +408,25 @@ class TestSparseIndexCommon: | ||
566 | - | ||
567 | - class TestBlockIndex: | ||
568 | - def test_block_internal(self): | ||
569 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="block") | ||
570 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="block") | ||
571 | - assert isinstance(idx, BlockIndex) | ||
572 | - assert idx.npoints == 2 | ||
573 | - tm.assert_numpy_array_equal(idx.blocs, np.array([2], dtype=np.int32)) | ||
574 | - tm.assert_numpy_array_equal(idx.blengths, np.array([2], dtype=np.int32)) | ||
575 | - | ||
576 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind="block") | ||
577 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="block") | ||
578 | - assert isinstance(idx, BlockIndex) | ||
579 | - assert idx.npoints == 0 | ||
580 | - tm.assert_numpy_array_equal(idx.blocs, np.array([], dtype=np.int32)) | ||
581 | - tm.assert_numpy_array_equal(idx.blengths, np.array([], dtype=np.int32)) | ||
582 | - | ||
583 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block") | ||
584 | -+ idx = make_sparse_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block") | ||
585 | - assert isinstance(idx, BlockIndex) | ||
586 | - assert idx.npoints == 4 | ||
587 | - tm.assert_numpy_array_equal(idx.blocs, np.array([0], dtype=np.int32)) | ||
588 | - tm.assert_numpy_array_equal(idx.blengths, np.array([4], dtype=np.int32)) | ||
589 | - | ||
590 | -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block") | ||
591 | -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block") | ||
592 | - assert isinstance(idx, BlockIndex) | ||
593 | - assert idx.npoints == 3 | ||
594 | - tm.assert_numpy_array_equal(idx.blocs, np.array([0, 2], dtype=np.int32)) | ||
595 | -@@ -428,7 +434,7 @@ class TestBlockIndex: | ||
596 | - | ||
597 | - def test_make_block_boundary(self): | ||
598 | - for i in [5, 10, 100, 101]: | ||
599 | -- idx = _make_index(i, np.arange(0, i, 2, dtype=np.int32), kind="block") | ||
600 | -+ idx = make_sparse_index(i, np.arange(0, i, 2, dtype=np.int32), kind="block") | ||
601 | - | ||
602 | - exp = np.arange(0, i, 2, dtype=np.int32) | ||
603 | - tm.assert_numpy_array_equal(idx.blocs, exp) | ||
604 | -@@ -514,17 +520,19 @@ class TestIntIndex: | ||
605 | - IntIndex(length=5, indices=[1, 3, 3]) | ||
606 | - | ||
607 | - def test_int_internal(self): | ||
608 | -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="integer") | ||
609 | -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="integer") | ||
610 | - assert isinstance(idx, IntIndex) | ||
611 | - assert idx.npoints == 2 | ||
612 | - tm.assert_numpy_array_equal(idx.indices, np.array([2, 3], dtype=np.int32)) | ||
613 | - | ||
614 | -- idx = _make_index(4, np.array([], dtype=np.int32), kind="integer") | ||
615 | -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="integer") | ||
616 | - assert isinstance(idx, IntIndex) | ||
617 | - assert idx.npoints == 0 | ||
618 | - tm.assert_numpy_array_equal(idx.indices, np.array([], dtype=np.int32)) | ||
619 | - | ||
620 | -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer") | ||
621 | -+ idx = make_sparse_index( | ||
622 | -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer" | ||
623 | -+ ) | ||
624 | - assert isinstance(idx, IntIndex) | ||
625 | - assert idx.npoints == 4 | ||
626 | - tm.assert_numpy_array_equal(idx.indices, np.array([0, 1, 2, 3], dtype=np.int32)) | ||
627 | -diff --git a/pandas/tests/computation/test_compat.py b/pandas/tests/computation/test_compat.py | ||
628 | -index ead102f53..9fc3ed480 100644 | ||
629 | ---- a/pandas/tests/computation/test_compat.py | ||
630 | -+++ b/pandas/tests/computation/test_compat.py | ||
631 | -@@ -5,7 +5,7 @@ import pytest | ||
632 | - from pandas.compat._optional import VERSIONS | ||
633 | - | ||
634 | - import pandas as pd | ||
635 | --from pandas.core.computation.engines import _engines | ||
636 | -+from pandas.core.computation.engines import ENGINES | ||
637 | - import pandas.core.computation.expr as expr | ||
638 | - | ||
639 | - | ||
640 | -@@ -26,8 +26,8 @@ def test_compat(): | ||
641 | - pytest.skip("not testing numexpr version compat") | ||
642 | - | ||
643 | - | ||
644 | --@pytest.mark.parametrize("engine", _engines) | ||
645 | --@pytest.mark.parametrize("parser", expr._parsers) | ||
646 | -+@pytest.mark.parametrize("engine", ENGINES) | ||
647 | -+@pytest.mark.parametrize("parser", expr.PARSERS) | ||
648 | - def test_invalid_numexpr_version(engine, parser): | ||
649 | - def testit(): | ||
650 | - a, b = 1, 2 # noqa | ||
651 | -diff --git a/pandas/tests/computation/test_eval.py b/pandas/tests/computation/test_eval.py | ||
652 | -index 72dc04e68..cca64a6bf 100644 | ||
653 | ---- a/pandas/tests/computation/test_eval.py | ||
654 | -+++ b/pandas/tests/computation/test_eval.py | ||
655 | -@@ -19,7 +19,7 @@ from pandas import DataFrame, Series, compat, date_range | ||
656 | - import pandas._testing as tm | ||
657 | - from pandas.core.computation import pytables | ||
658 | - from pandas.core.computation.check import NUMEXPR_VERSION | ||
659 | --from pandas.core.computation.engines import NumExprClobberingError, _engines | ||
660 | -+from pandas.core.computation.engines import ENGINES, NumExprClobberingError | ||
661 | - import pandas.core.computation.expr as expr | ||
662 | - from pandas.core.computation.expr import ( | ||
663 | - BaseExprVisitor, | ||
664 | -@@ -46,14 +46,14 @@ from pandas.core.computation.ops import ( | ||
665 | - f"installed->{NUMEXPR_INSTALLED}", | ||
666 | - ), | ||
667 | - ) | ||
668 | -- for engine in _engines | ||
669 | -+ for engine in ENGINES | ||
670 | - ) | ||
671 | - ) # noqa | ||
672 | - def engine(request): | ||
673 | - return request.param | ||
674 | - | ||
675 | - | ||
676 | --@pytest.fixture(params=expr._parsers) | ||
677 | -+@pytest.fixture(params=expr.PARSERS) | ||
678 | - def parser(request): | ||
679 | - return request.param | ||
680 | - | ||
681 | -@@ -77,7 +77,7 @@ def unary_fns_for_ne(): | ||
682 | - | ||
683 | - | ||
684 | - def engine_has_neg_frac(engine): | ||
685 | -- return _engines[engine].has_neg_frac | ||
686 | -+ return ENGINES[engine].has_neg_frac | ||
687 | - | ||
688 | - | ||
689 | - def _eval_single_bin(lhs, cmp1, rhs, engine): | ||
690 | -@@ -168,7 +168,7 @@ class TestEvalNumexprPandas: | ||
691 | - def setup_method(self, method): | ||
692 | - self.setup_ops() | ||
693 | - self.setup_data() | ||
694 | -- self.current_engines = (engine for engine in _engines if engine != self.engine) | ||
695 | -+ self.current_engines = (engine for engine in ENGINES if engine != self.engine) | ||
696 | - | ||
697 | - def teardown_method(self, method): | ||
698 | - del self.lhses, self.rhses, self.scalar_rhses, self.scalar_lhses | ||
699 | -@@ -1921,7 +1921,7 @@ _parsers: Dict[str, Type[BaseExprVisitor]] = { | ||
700 | - } | ||
701 | - | ||
702 | - | ||
703 | --@pytest.mark.parametrize("engine", _engines) | ||
704 | -+@pytest.mark.parametrize("engine", ENGINES) | ||
705 | - @pytest.mark.parametrize("parser", _parsers) | ||
706 | - def test_disallowed_nodes(engine, parser): | ||
707 | - VisitorClass = _parsers[parser] | ||
...
\ No newline at end of file
train.py (deleted, mode 100644 → 0)
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -import os | ||
16 | -import argparse | ||
17 | -import pytorch_lightning as pl | ||
18 | -from train.finetune import main, SummarizationModule | ||
19 | - | ||
20 | -if __name__ == "__main__": | ||
21 | - parser = argparse.ArgumentParser() | ||
22 | - parser = pl.Trainer.add_argparse_args(parser) | ||
23 | - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) | ||
24 | - | ||
25 | - args = parser.parse_args() | ||
26 | - | ||
27 | - main(args) | ||
\ No newline at end of file
train/__init__.py (deleted, mode 100644 → 0)
1 | -# Copyright 2020-present Tae Hwan Jung | ||
2 | -# | ||
3 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | -# you may not use this file except in compliance with the License. | ||
5 | -# You may obtain a copy of the License at | ||
6 | -# | ||
7 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | -# | ||
9 | -# Unless required by applicable law or agreed to in writing, software | ||
10 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | -# See the License for the specific language governing permissions and | ||
13 | -# limitations under the License. | ||
14 | - | ||
15 | -from train.modeling_bart import BartForConditionalGeneration | ||
16 | - | ||
17 | -__all__ = ["BartForConditionalGeneration"] |
train/callbacks.py (deleted, mode 100644 → 0)
1 | -import logging | ||
2 | -import os | ||
3 | -from pathlib import Path | ||
4 | - | ||
5 | -import numpy as np | ||
6 | -import pytorch_lightning as pl | ||
7 | -import torch | ||
8 | -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint | ||
9 | -from pytorch_lightning.utilities import rank_zero_only | ||
10 | - | ||
11 | - | ||
12 | -def count_trainable_parameters(model): | ||
13 | - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | ||
14 | - params = sum([np.prod(p.size()) for p in model_parameters]) | ||
15 | - return params | ||
16 | - | ||
17 | - | ||
18 | -logger = logging.getLogger(__name__) | ||
19 | - | ||
20 | - | ||
21 | -class Seq2SeqLoggingCallback(pl.Callback): | ||
22 | - def on_batch_end(self, trainer, pl_module): | ||
23 | - lrs = { | ||
24 | - f"lr_group_{i}": param["lr"] | ||
25 | - for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups) | ||
26 | - } | ||
27 | - pl_module.logger.log_metrics(lrs) | ||
28 | - | ||
29 | - @rank_zero_only | ||
30 | - def _write_logs( | ||
31 | - self, | ||
32 | - trainer: pl.Trainer, | ||
33 | - pl_module: pl.LightningModule, | ||
34 | - type_path: str, | ||
35 | - save_generations=True, | ||
36 | - ) -> None: | ||
37 | - logger.info( | ||
38 | - f"***** {type_path} results at step {trainer.global_step:05d} *****" | ||
39 | - ) | ||
40 | - metrics = trainer.callback_metrics | ||
41 | - trainer.logger.log_metrics( | ||
42 | - { | ||
43 | - k: v | ||
44 | - for k, v in metrics.items() | ||
45 | - if k not in ["log", "progress_bar", "preds"] | ||
46 | - } | ||
47 | - ) | ||
48 | - # Log results | ||
49 | - od = Path(pl_module.hparams.output_dir) | ||
50 | - if type_path == "test": | ||
51 | - results_file = od / "test_results.txt" | ||
52 | - generations_file = od / "test_generations.txt" | ||
53 | - else: | ||
54 | - # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json | ||
55 | - # If people want this it will be easy enough to add back. | ||
56 | - results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" | ||
57 | - generations_file = ( | ||
58 | - od / f"{type_path}_generations/{trainer.global_step:05d}.txt" | ||
59 | - ) | ||
60 | - results_file.parent.mkdir(exist_ok=True) | ||
61 | - generations_file.parent.mkdir(exist_ok=True) | ||
62 | - with open(results_file, "a+") as writer: | ||
63 | - for key in sorted(metrics): | ||
64 | - if key in ["log", "progress_bar", "preds"]: | ||
65 | - continue | ||
66 | - val = metrics[key] | ||
67 | - if isinstance(val, torch.Tensor): | ||
68 | - val = val.item() | ||
69 | - msg = f"{key}: {val:.6f}\n" | ||
70 | - writer.write(msg) | ||
71 | - | ||
72 | - if not save_generations: | ||
73 | - return | ||
74 | - | ||
75 | - if "preds" in metrics: | ||
76 | - content = "\n".join(metrics["preds"]) | ||
77 | - generations_file.open("w+").write(content) | ||
78 | - | ||
79 | - @rank_zero_only | ||
80 | - def on_train_start(self, trainer, pl_module): | ||
81 | - try: | ||
82 | - npars = pl_module.model.model.num_parameters() | ||
83 | - except AttributeError: | ||
84 | - npars = pl_module.model.num_parameters() | ||
85 | - | ||
86 | - n_trainable_pars = count_trainable_parameters(pl_module) | ||
87 | - # mp stands for million parameters | ||
88 | - trainer.logger.log_metrics( | ||
89 | - {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6} | ||
90 | - ) | ||
91 | - | ||
92 | - @rank_zero_only | ||
93 | - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | ||
94 | - return self._write_logs(trainer, pl_module, "test") | ||
95 | - | ||
96 | - | ||
97 | -def get_checkpoint_callback(output_dir, metric): | ||
98 | - """Saves the best model by validation ROUGE2 score.""" | ||
99 | - if metric == "rouge2": | ||
100 | - exp = "{val_avg_rouge2:.4f}-{step_count}" | ||
101 | - elif metric == "bleu": | ||
102 | - exp = "{val_avg_bleu:.4f}-{step_count}" | ||
103 | - else: | ||
104 | - raise NotImplementedError( | ||
105 | - f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." | ||
106 | - ) | ||
107 | - | ||
108 | - checkpoint_callback = ModelCheckpoint( | ||
109 | - filepath=os.path.join(output_dir, exp), | ||
110 | - monitor=f"val_{metric}", | ||
111 | - mode="max", | ||
112 | - save_top_k=1, | ||
113 | - period=0, # period=0 so a checkpoint can be saved every time validation runs, not just at the end of an epoch. | ||
114 | - ) | ||
115 | - return checkpoint_callback | ||
116 | - | ||
117 | - | ||
118 | -def get_early_stopping_callback(metric, patience): | ||
119 | - return EarlyStopping( | ||
120 | - monitor=f"val_{metric}", mode="max", patience=patience, verbose=True, | ||
121 | - ) |
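For reference, a minimal usage sketch of the three helpers above (Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback), assuming the pre-1.0 pytorch_lightning Trainer API this code targets (ModelCheckpoint(filepath=...), Trainer(checkpoint_callback=..., early_stop_callback=...)); the output directory, metric, and patience below are placeholders:

import pytorch_lightning as pl

from train.callbacks import (
    Seq2SeqLoggingCallback,
    get_checkpoint_callback,
    get_early_stopping_callback,
)

output_dir = "weight/"   # placeholder: where checkpoints are written
metric = "rouge2"        # get_checkpoint_callback also accepts "bleu"

trainer = pl.Trainer(
    callbacks=[Seq2SeqLoggingCallback()],                                 # logs learning rates, metrics, and test generations
    checkpoint_callback=get_checkpoint_callback(output_dir, metric),      # keeps the single best val_rouge2 checkpoint
    early_stop_callback=get_early_stopping_callback(metric, patience=3),  # stops after 3 validation checks without improvement
    max_epochs=3,
)
# trainer.fit(model)   # `model` would be the SummarizationModule defined in train/finetune.py below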
train/finetune.py
deleted
100644 → 0
1 | -import argparse | ||
2 | -import glob | ||
3 | -import logging | ||
4 | -import os | ||
5 | -import time | ||
6 | -from collections import defaultdict | ||
7 | -from pathlib import Path | ||
8 | -from typing import Dict, List, Tuple | ||
9 | - | ||
10 | -import numpy as np | ||
11 | -import pytorch_lightning as pl | ||
12 | -import torch | ||
13 | -from torch.utils.data import DataLoader | ||
14 | - | ||
15 | -from train.lightning_base import BaseTransformer, add_generic_args, generic_train | ||
16 | -from transformers import MBartTokenizer, T5ForConditionalGeneration | ||
17 | -from transformers.modeling_bart import shift_tokens_right | ||
18 | - | ||
19 | -from matorage import DataConfig | ||
20 | -from matorage.torch import Dataset | ||
21 | - | ||
22 | - | ||
23 | -try: | ||
24 | - from .callbacks import ( | ||
25 | - Seq2SeqLoggingCallback, | ||
26 | - get_checkpoint_callback, | ||
27 | - get_early_stopping_callback, | ||
28 | - ) | ||
29 | - from .utils import ( | ||
30 | - ROUGE_KEYS, | ||
31 | - LegacySeq2SeqDataset, | ||
32 | - Seq2SeqDataset, | ||
33 | - assert_all_frozen, | ||
34 | - calculate_bleu, | ||
35 | - calculate_rouge, | ||
36 | - flatten_list, | ||
37 | - freeze_params, | ||
38 | - get_git_info, | ||
39 | - label_smoothed_nll_loss, | ||
40 | - lmap, | ||
41 | - pickle_save, | ||
42 | - save_git_info, | ||
43 | - save_json, | ||
44 | - use_task_specific_params, | ||
45 | - ) | ||
46 | -except ImportError: | ||
47 | - from callbacks import ( | ||
48 | - Seq2SeqLoggingCallback, | ||
49 | - get_checkpoint_callback, | ||
50 | - get_early_stopping_callback, | ||
51 | - ) | ||
52 | - from utils import ( | ||
53 | - ROUGE_KEYS, | ||
54 | - LegacySeq2SeqDataset, | ||
55 | - Seq2SeqDataset, | ||
56 | - assert_all_frozen, | ||
57 | - calculate_bleu, | ||
58 | - calculate_rouge, | ||
59 | - flatten_list, | ||
60 | - freeze_params, | ||
61 | - get_git_info, | ||
62 | - label_smoothed_nll_loss, | ||
63 | - lmap, | ||
64 | - pickle_save, | ||
65 | - save_git_info, | ||
66 | - save_json, | ||
67 | - use_task_specific_params, | ||
68 | - ) | ||
69 | - | ||
70 | -logger = logging.getLogger(__name__) | ||
71 | - | ||
72 | - | ||
73 | -class SummarizationModule(BaseTransformer): | ||
74 | - mode = "summarization" | ||
75 | - loss_names = ["loss"] | ||
76 | - metric_names = ROUGE_KEYS | ||
77 | - default_val_metric = "rouge2" | ||
78 | - | ||
79 | - def __init__(self, hparams, **kwargs): | ||
80 | - super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) | ||
81 | - use_task_specific_params(self.model, "summarization") | ||
82 | - save_git_info(self.hparams.output_dir) | ||
83 | - self.metrics_save_path = Path(self.output_dir) / "metrics.json" | ||
84 | - self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" | ||
85 | - pickle_save(self.hparams, self.hparams_save_path) | ||
86 | - self.step_count = 0 | ||
87 | - self.metrics = defaultdict(list) | ||
88 | - | ||
89 | - self.target_lens = { | ||
90 | - "train": self.hparams.max_target_length, | ||
91 | - "val": self.hparams.val_max_target_length, | ||
92 | - "test": self.hparams.test_max_target_length, | ||
93 | - } | ||
94 | - assert ( | ||
95 | - self.target_lens["train"] <= self.target_lens["val"] | ||
96 | - ), f"target_lens: {self.target_lens}" | ||
97 | - assert ( | ||
98 | - self.target_lens["train"] <= self.target_lens["test"] | ||
99 | - ), f"target_lens: {self.target_lens}" | ||
100 | - | ||
101 | - if self.hparams.freeze_embeds: | ||
102 | - self.freeze_embeds() | ||
103 | - if self.hparams.freeze_encoder: | ||
104 | - freeze_params(self.model.get_encoder()) | ||
105 | - assert_all_frozen(self.model.get_encoder()) | ||
106 | - | ||
107 | - self.hparams.git_sha = get_git_info()["repo_sha"] | ||
108 | - self.num_workers = hparams.num_workers | ||
109 | - self.decoder_start_token_id = None # default to config | ||
110 | - if self.model.config.decoder_start_token_id is None and isinstance( | ||
111 | - self.tokenizer, MBartTokenizer | ||
112 | - ): | ||
113 | - self.decoder_start_token_id = self.tokenizer.lang_code_to_id[ | ||
114 | - hparams.tgt_lang | ||
115 | - ] | ||
116 | - self.model.config.decoder_start_token_id = self.decoder_start_token_id | ||
117 | - | ||
118 | - self.eval_beams = ( | ||
119 | - self.model.config.num_beams | ||
120 | - if self.hparams.eval_beams is None | ||
121 | - else self.hparams.eval_beams | ||
122 | - ) | ||
123 | - assert ( | ||
124 | - self.eval_beams >= 1 | ||
125 | - ), f"got self.eval_beams={self.eval_beams}. Need an integer > 1" | ||
126 | - self.val_metric = ( | ||
127 | - self.default_val_metric | ||
128 | - if self.hparams.val_metric is None | ||
129 | - else self.hparams.val_metric | ||
130 | - ) | ||
131 | - | ||
132 | - def freeze_embeds(self): | ||
133 | - """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" | ||
134 | - try: | ||
135 | - freeze_params(self.model.model.shared) | ||
136 | - for d in [self.model.model.encoder, self.model.model.decoder]: | ||
137 | - freeze_params(d.embed_positions) | ||
138 | - freeze_params(d.embed_tokens) | ||
139 | - except AttributeError: | ||
140 | - freeze_params(self.model.shared) | ||
141 | - for d in [self.model.encoder, self.model.decoder]: | ||
142 | - freeze_params(d.embed_tokens) | ||
143 | - | ||
144 | - def forward(self, input_ids, patch_ids, **kwargs): | ||
145 | - return self.model(input_ids, patch_ids, **kwargs) | ||
146 | - | ||
147 | - def ids_to_clean_text(self, generated_ids: List[int]): | ||
148 | - gen_text = self.tokenizer.batch_decode( | ||
149 | - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True | ||
150 | - ) | ||
151 | - return lmap(str.strip, gen_text) | ||
152 | - | ||
153 | - def _step(self, batch: dict) -> Tuple: | ||
154 | - pad_token_id = self.tokenizer.pad_token_id | ||
155 | - src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long() | ||
156 | - tgt_ids = batch[3].long() | ||
157 | - if isinstance(self.model, T5ForConditionalGeneration): | ||
158 | - decoder_input_ids = self.model._shift_right(tgt_ids) | ||
159 | - else: | ||
160 | - decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) | ||
161 | - | ||
162 | - outputs = self( | ||
163 | - src_ids, | ||
164 | - src_patch, | ||
165 | - attention_mask=src_mask, | ||
166 | - decoder_input_ids=decoder_input_ids, | ||
167 | - use_cache=False, | ||
168 | - ) | ||
169 | - lm_logits = outputs[0] | ||
170 | - if self.hparams.label_smoothing == 0: | ||
171 | - # Same behavior as modeling_bart.py, besides ignoring pad_token_id | ||
172 | - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) | ||
173 | - | ||
174 | - assert lm_logits.shape[-1] == self.model.config.vocab_size | ||
175 | - loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) | ||
176 | - else: | ||
177 | - lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) | ||
178 | - loss, nll_loss = label_smoothed_nll_loss( | ||
179 | - lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id | ||
180 | - ) | ||
181 | - return (loss,) | ||
182 | - | ||
183 | - @property | ||
184 | - def pad(self) -> int: | ||
185 | - return self.tokenizer.pad_token_id | ||
186 | - | ||
187 | - def training_step(self, batch, batch_idx) -> Dict: | ||
188 | - loss_tensors = self._step(batch) | ||
189 | - | ||
190 | - logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} | ||
191 | - # tokens per batch | ||
192 | - logs["tpb"] = ( | ||
193 | - batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum() | ||
194 | - ) | ||
195 | - return {"loss": loss_tensors[0], "log": logs} | ||
196 | - | ||
197 | - def validation_step(self, batch, batch_idx) -> Dict: | ||
198 | - return self._generative_step(batch) | ||
199 | - | ||
200 | - def validation_epoch_end(self, outputs, prefix="val") -> Dict: | ||
201 | - self.step_count += 1 | ||
202 | - losses = { | ||
203 | - k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names | ||
204 | - } | ||
205 | - loss = losses["loss"] | ||
206 | - rouges = { | ||
207 | - k: np.array([x[k] for x in outputs]).mean() | ||
208 | - for k in self.metric_names + ["gen_time", "gen_len"] | ||
209 | - } | ||
210 | - rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as( | ||
211 | - loss | ||
212 | - ) | ||
213 | - rouges.update({k: v.item() for k, v in losses.items()}) | ||
214 | - losses.update(rouges) | ||
215 | - metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} | ||
216 | - metrics["step_count"] = self.step_count | ||
217 | - self.save_metrics(metrics, prefix) # writes to self.metrics_save_path | ||
218 | - preds = flatten_list([x["preds"] for x in outputs]) | ||
219 | - return { | ||
220 | - "log": metrics, | ||
221 | - "preds": preds, | ||
222 | - f"{prefix}_loss": loss, | ||
223 | - f"{prefix}_{self.val_metric}": rouge_tensor, | ||
224 | - } | ||
225 | - | ||
226 | - def save_metrics(self, latest_metrics, type_path) -> None: | ||
227 | - self.metrics[type_path].append(latest_metrics) | ||
228 | - save_json(self.metrics, self.metrics_save_path) | ||
229 | - | ||
230 | - def calc_generative_metrics(self, preds, target) -> Dict: | ||
231 | - return calculate_rouge(preds, target) | ||
232 | - | ||
233 | - def _generative_step(self, batch: dict) -> dict: | ||
234 | - t0 = time.time() | ||
235 | - generated_ids = self.model.generate( | ||
236 | - batch[0].long(), | ||
237 | - patch_ids=batch[2].long(), | ||
238 | - attention_mask=batch[1].long(), | ||
239 | - use_cache=True, | ||
240 | - decoder_start_token_id=self.decoder_start_token_id, | ||
241 | - ) | ||
242 | - gen_time = (time.time() - t0) / batch[0].shape[0] | ||
243 | - preds: List[str] = self.ids_to_clean_text(generated_ids) | ||
244 | - target: List[str] = self.ids_to_clean_text(batch[3]) | ||
245 | - loss_tensors = self._step(batch) | ||
246 | - base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} | ||
247 | - rouge: Dict = self.calc_generative_metrics(preds, target) | ||
248 | - summ_len = np.mean(lmap(len, generated_ids)) | ||
249 | - base_metrics.update( | ||
250 | - gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge | ||
251 | - ) | ||
252 | - return base_metrics | ||
253 | - | ||
254 | - def test_step(self, batch, batch_idx): | ||
255 | - return self._generative_step(batch) | ||
256 | - | ||
257 | - def test_epoch_end(self, outputs): | ||
258 | - return self.validation_epoch_end(outputs, prefix="test") | ||
259 | - | ||
260 | - def get_dataset(self, type_path) -> Seq2SeqDataset: | ||
261 | - max_target_length = self.target_lens[type_path] | ||
262 | - data_config = DataConfig( | ||
263 | - endpoint=self.hparams.endpoint, | ||
264 | - access_key=os.environ["access_key"], | ||
265 | - secret_key=os.environ["secret_key"], | ||
266 | - region=self.hparams.region, | ||
267 | - dataset_name="commit-autosuggestions", | ||
268 | - additional={ | ||
269 | - "mode": ("training" if type_path == "train" else "evaluation"), | ||
270 | - "max_source_length": self.hparams.max_source_length, | ||
271 | - "max_target_length": max_target_length, | ||
272 | - "url": self.hparams.url, | ||
273 | - }, | ||
274 | - attributes=[ | ||
275 | - ("input_ids", "int32", (self.hparams.max_source_length,)), | ||
276 | - ("attention_masks", "int32", (self.hparams.max_source_length,)), | ||
277 | - ("patch_ids", "int32", (self.hparams.max_source_length,)), | ||
278 | - ("targets", "int32", (max_target_length,)), | ||
279 | - ], | ||
280 | - ) | ||
281 | - return Dataset(config=data_config, clear=True) | ||
282 | - | ||
283 | - def get_dataloader( | ||
284 | - self, type_path: str, batch_size: int, shuffle: bool = False | ||
285 | - ) -> DataLoader: | ||
286 | - dataset = self.get_dataset(type_path) | ||
287 | - sampler = None | ||
288 | - | ||
289 | - dataloader = DataLoader( | ||
290 | - dataset, | ||
291 | - batch_size=batch_size, | ||
292 | - shuffle=shuffle, | ||
293 | - num_workers=self.num_workers, | ||
294 | - sampler=sampler, | ||
295 | - ) | ||
296 | - return dataloader | ||
297 | - | ||
298 | - def train_dataloader(self) -> DataLoader: | ||
299 | - dataloader = self.get_dataloader( | ||
300 | - "train", batch_size=self.hparams.train_batch_size, shuffle=True | ||
301 | - ) | ||
302 | - return dataloader | ||
303 | - | ||
304 | - def val_dataloader(self) -> DataLoader: | ||
305 | - return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size) | ||
306 | - | ||
307 | - def test_dataloader(self) -> DataLoader: | ||
308 | - return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size) | ||
309 | - | ||
310 | - @staticmethod | ||
311 | - def add_model_specific_args(parser, root_dir): | ||
312 | - BaseTransformer.add_model_specific_args(parser, root_dir) | ||
313 | - add_generic_args(parser, root_dir) | ||
314 | - parser.add_argument("--url", type=str, required=True, help="github url") | ||
315 | - parser.add_argument( | ||
316 | - "--endpoint", | ||
317 | - type=str, | ||
318 | - required=True, | ||
319 | - help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
320 | - ) | ||
321 | - parser.add_argument( | ||
322 | - "--region", | ||
323 | - type=str, | ||
324 | - default=None, | ||
325 | - help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html", | ||
326 | - ) | ||
327 | - parser.add_argument( | ||
328 | - "--max_source_length", | ||
329 | - default=1024, | ||
330 | - type=int, | ||
331 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
332 | - "than this will be truncated, sequences shorter will be padded.", | ||
333 | - ) | ||
334 | - parser.add_argument( | ||
335 | - "--max_target_length", | ||
336 | - default=56, | ||
337 | - type=int, | ||
338 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
339 | - "than this will be truncated, sequences shorter will be padded.", | ||
340 | - ) | ||
341 | - parser.add_argument( | ||
342 | - "--val_max_target_length", | ||
343 | - default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. | ||
344 | - type=int, | ||
345 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
346 | - "than this will be truncated, sequences shorter will be padded.", | ||
347 | - ) | ||
348 | - parser.add_argument( | ||
349 | - "--test_max_target_length", | ||
350 | - default=142, | ||
351 | - type=int, | ||
352 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
353 | - "than this will be truncated, sequences shorter will be padded.", | ||
354 | - ) | ||
355 | - parser.add_argument("--freeze_encoder", action="store_true") | ||
356 | - parser.add_argument("--freeze_embeds", action="store_true") | ||
357 | - parser.add_argument("--sortish_sampler", action="store_true", default=False) | ||
358 | - parser.add_argument( | ||
359 | - "--logger_name", | ||
360 | - type=str, | ||
361 | - choices=["default", "wandb", "wandb_shared"], | ||
362 | - default="default", | ||
363 | - ) | ||
364 | - parser.add_argument( | ||
365 | - "--n_train", | ||
366 | - type=int, | ||
367 | - default=-1, | ||
368 | - required=False, | ||
369 | - help="# examples. -1 means use all.", | ||
370 | - ) | ||
371 | - parser.add_argument( | ||
372 | - "--n_val", | ||
373 | - type=int, | ||
374 | - default=500, | ||
375 | - required=False, | ||
376 | - help="# examples. -1 means use all.", | ||
377 | - ) | ||
378 | - parser.add_argument( | ||
379 | - "--n_test", | ||
380 | - type=int, | ||
381 | - default=-1, | ||
382 | - required=False, | ||
383 | - help="# examples. -1 means use all.", | ||
384 | - ) | ||
385 | - parser.add_argument( | ||
386 | - "--task", | ||
387 | - type=str, | ||
388 | - default="summarization", | ||
389 | - required=False, | ||
390 | - help="# examples. -1 means use all.", | ||
391 | - ) | ||
392 | - parser.add_argument( | ||
393 | - "--label_smoothing", type=float, default=0.0, required=False | ||
394 | - ) | ||
395 | - parser.add_argument("--src_lang", type=str, default="", required=False) | ||
396 | - parser.add_argument("--tgt_lang", type=str, default="", required=False) | ||
397 | - parser.add_argument("--eval_beams", type=int, default=None, required=False) | ||
398 | - parser.add_argument("--val_metric", type=str, default=None, required=False) | ||
399 | - parser.add_argument( | ||
400 | - "--early_stopping_patience", | ||
401 | - type=int, | ||
402 | - default=-1, | ||
403 | - required=False, | ||
404 | - help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.", | ||
405 | - ) | ||
406 | - return parser | ||
407 | - | ||
408 | - | ||
409 | -class TranslationModule(SummarizationModule): | ||
410 | - mode = "translation" | ||
411 | - loss_names = ["loss"] | ||
412 | - metric_names = ["bleu"] | ||
413 | - default_val_metric = "bleu" | ||
414 | - | ||
415 | - def __init__(self, hparams, **kwargs): | ||
416 | - super().__init__(hparams, **kwargs) | ||
417 | - self.dataset_kwargs["src_lang"] = hparams.src_lang | ||
418 | - self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang | ||
419 | - | ||
420 | - def calc_generative_metrics(self, preds, target) -> dict: | ||
421 | - return calculate_bleu(preds, target) | ||
422 | - | ||
423 | - | ||
424 | -def main(args, model=None) -> SummarizationModule: | ||
425 | - Path(args.output_dir).mkdir(exist_ok=True) | ||
426 | - if len(os.listdir(args.output_dir)) > 3 and args.do_train: | ||
427 | - raise ValueError( | ||
428 | - "Output directory ({}) already exists and is not empty.".format( | ||
429 | - args.output_dir | ||
430 | - ) | ||
431 | - ) | ||
432 | - if model is None: | ||
433 | - if args.task == "summarization": | ||
434 | - model: SummarizationModule = SummarizationModule(args) | ||
435 | - else: | ||
436 | - model: SummarizationModule = TranslationModule(args) | ||
437 | - | ||
438 | - logger = True | ||
439 | - es_callback = False | ||
440 | - trainer: pl.Trainer = generic_train( | ||
441 | - model, | ||
442 | - args, | ||
443 | - logging_callback=Seq2SeqLoggingCallback(), | ||
444 | - checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), | ||
445 | - early_stopping_callback=es_callback, | ||
446 | - logger=logger, | ||
447 | - # TODO: early stopping callback seems messed up | ||
448 | - ) | ||
449 | - pickle_save(model.hparams, model.output_dir / "hparams.pkl") | ||
450 | - if not args.do_predict: | ||
451 | - return model | ||
452 | - | ||
453 | - model.hparams.test_checkpoint = "" | ||
454 | - checkpoints = list( | ||
455 | - sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)) | ||
456 | - ) | ||
457 | - if checkpoints: | ||
458 | - model.hparams.test_checkpoint = checkpoints[-1] | ||
459 | - trainer.resume_from_checkpoint = checkpoints[-1] | ||
460 | - trainer.logger.log_hyperparams(model.hparams) | ||
461 | - | ||
462 | - # test() without a model tests using the best checkpoint automatically | ||
463 | - trainer.test() | ||
464 | - return model | ||
\ No newline at end of file
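To show how the parser and main() defined above fit together, here is a minimal driver sketch. It assumes the generic flags (--output_dir, --do_train, ...) are registered by add_generic_args in train/lightning_base.py; every value shown (URL, endpoint, keys, paths) is a placeholder, and any additional flags required by BaseTransformer and add_generic_args (pretrained model name, batch sizes, ...) would be appended to the same list. The access_key/secret_key environment variables are read inside SummarizationModule.get_dataset().

import argparse
import os

from train.finetune import SummarizationModule, main

# matorage credentials consumed by get_dataset() (placeholder values)
os.environ.setdefault("access_key", "<storage-access-key>")
os.environ.setdefault("secret_key", "<storage-secret-key>")

parser = argparse.ArgumentParser()
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args([
    "--url", "https://github.com/<owner>/<repo>",   # GitHub repository the commit data was collected from
    "--endpoint", "127.0.0.1:9000",                 # matorage / S3-compatible storage endpoint
    "--output_dir", "weight/",
    "--do_train",
    # ...plus any flags required by BaseTransformer / add_generic_args (model name, batch sizes, gpus, ...)
])

model = main(args)   # builds the module, trains it, and runs trainer.test() when --do_predict is set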
train/generation_utils.py
deleted
100644 → 0
1 | -# coding=utf-8 | ||
2 | -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. | ||
3 | -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
4 | -# | ||
5 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | -# you may not use this file except in compliance with the License. | ||
7 | -# You may obtain a copy of the License at | ||
8 | -# | ||
9 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | -# | ||
11 | -# Unless required by applicable law or agreed to in writing, software | ||
12 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | -# See the License for the specific language governing permissions and | ||
15 | -# limitations under the License. | ||
16 | - | ||
17 | -from typing import Iterable, List, Optional, Tuple | ||
18 | - | ||
19 | -import torch | ||
20 | -from torch import Tensor | ||
21 | -from torch.nn import functional as F | ||
22 | - | ||
23 | -from transformers.file_utils import ModelOutput | ||
24 | -import logging | ||
25 | - | ||
26 | -logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
27 | -logging.basicConfig( | ||
28 | - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
29 | - datefmt="%m/%d/%Y %H:%M:%S", | ||
30 | - level=logging.INFO, | ||
31 | -) | ||
32 | - | ||
33 | - | ||
34 | -class GenerationMixin: | ||
35 | - """ | ||
36 | - A class containing all of the functions supporting generation, to be used as a mixin in | ||
37 | - :class:`~transformers.PreTrainedModel`. | ||
38 | - """ | ||
39 | - | ||
40 | - def prepare_inputs_for_generation(self, input_ids, **kwargs): | ||
41 | - """ | ||
42 | - Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the | ||
43 | - generate method. | ||
44 | - """ | ||
45 | - return {"input_ids": input_ids} | ||
46 | - | ||
47 | - def adjust_logits_during_generation(self, logits, **kwargs): | ||
48 | - """ | ||
49 | - Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in | ||
50 | - the generate method. | ||
51 | - """ | ||
52 | - return logits | ||
53 | - | ||
54 | - def enforce_repetition_penalty_( | ||
55 | - self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty | ||
56 | - ): | ||
57 | - """ | ||
58 | - Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__). | ||
59 | - """ | ||
60 | - for i in range(batch_size * num_beams): | ||
61 | - for previous_token in set(prev_output_tokens[i].tolist()): | ||
62 | - # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability | ||
63 | - if lprobs[i, previous_token] < 0: | ||
64 | - lprobs[i, previous_token] *= repetition_penalty | ||
65 | - else: | ||
66 | - lprobs[i, previous_token] /= repetition_penalty | ||
67 | - | ||
68 | - def postprocess_next_token_scores( | ||
69 | - self, | ||
70 | - scores, | ||
71 | - input_ids, | ||
72 | - no_repeat_ngram_size, | ||
73 | - bad_words_ids, | ||
74 | - cur_len, | ||
75 | - min_length, | ||
76 | - max_length, | ||
77 | - eos_token_id, | ||
78 | - repetition_penalty, | ||
79 | - batch_size, | ||
80 | - num_beams, | ||
81 | - ): | ||
82 | - # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) | ||
83 | - if repetition_penalty != 1.0: | ||
84 | - self.enforce_repetition_penalty_( | ||
85 | - scores, batch_size, num_beams, input_ids, repetition_penalty, | ||
86 | - ) | ||
87 | - | ||
88 | - # set eos token prob to zero if min_length is not reached | ||
89 | - if eos_token_id is not None and cur_len < min_length: | ||
90 | - scores[:, eos_token_id] = -float("inf") | ||
91 | - | ||
92 | - if no_repeat_ngram_size > 0: | ||
93 | - # calculate a list of banned tokens to prevent repetitively generating the same ngrams | ||
94 | - num_batch_hypotheses = batch_size * num_beams | ||
95 | - # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 | ||
96 | - banned_batch_tokens = calc_banned_ngram_tokens( | ||
97 | - input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len | ||
98 | - ) | ||
99 | - for i, banned_tokens in enumerate(banned_batch_tokens): | ||
100 | - scores[i, banned_tokens] = -float("inf") | ||
101 | - | ||
102 | - if bad_words_ids is not None: | ||
103 | - # Exclude EOS token (already processed) | ||
104 | - bad_words_ids = list( | ||
105 | - filter( | ||
106 | - lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids | ||
107 | - ) | ||
108 | - ) | ||
109 | - # calculate a list of banned tokens according to bad words | ||
110 | - banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) | ||
111 | - # Modify the scores in place by setting the banned tokens logits to `-inf` | ||
112 | - set_scores_to_inf_for_banned_tokens(scores, banned_tokens) | ||
113 | - | ||
114 | - return scores | ||
115 | - | ||
116 | - @torch.no_grad() | ||
117 | - def generate( | ||
118 | - self, | ||
119 | - input_ids: Optional[torch.LongTensor] = None, | ||
120 | - patch_ids: Optional[torch.LongTensor] = None, | ||
121 | - max_length: Optional[int] = None, | ||
122 | - min_length: Optional[int] = None, | ||
123 | - do_sample: Optional[bool] = None, | ||
124 | - early_stopping: Optional[bool] = None, | ||
125 | - num_beams: Optional[int] = None, | ||
126 | - temperature: Optional[float] = None, | ||
127 | - top_k: Optional[int] = None, | ||
128 | - top_p: Optional[float] = None, | ||
129 | - repetition_penalty: Optional[float] = None, | ||
130 | - bad_words_ids: Optional[Iterable[int]] = None, | ||
131 | - bos_token_id: Optional[int] = None, | ||
132 | - pad_token_id: Optional[int] = None, | ||
133 | - eos_token_id: Optional[int] = None, | ||
134 | - length_penalty: Optional[float] = None, | ||
135 | - no_repeat_ngram_size: Optional[int] = None, | ||
136 | - num_return_sequences: Optional[int] = None, | ||
137 | - attention_mask: Optional[torch.LongTensor] = None, | ||
138 | - decoder_start_token_id: Optional[int] = None, | ||
139 | - use_cache: Optional[bool] = None, | ||
140 | - **model_kwargs, | ||
141 | - ) -> torch.LongTensor: | ||
142 | - r""" | ||
143 | - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, | ||
144 | - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. | ||
145 | - | ||
146 | - Adapted in part from `Facebook's XLM beam search code | ||
147 | - <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. | ||
148 | - | ||
149 | - Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the | ||
150 | - attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values | ||
151 | - indicated are the default values of those config. | ||
152 | - | ||
153 | - Most of these parameters are explained in more detail in `this blog post | ||
154 | - <https://huggingface.co/blog/how-to-generate>`__. | ||
155 | - | ||
156 | - Parameters: | ||
157 | - | ||
158 | - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | ||
159 | - The sequence used as a prompt for the generation. If :obj:`None` the method initializes | ||
160 | - it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. | ||
161 | - max_length (:obj:`int`, `optional`, defaults to 20): | ||
162 | - The maximum length of the sequence to be generated. | ||
163 | - min_length (:obj:`int`, `optional`, defaults to 10): | ||
164 | - The minimum length of the sequence to be generated. | ||
165 | - do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
166 | - Whether or not to use sampling ; use greedy decoding otherwise. | ||
167 | - early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
168 | - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. | ||
169 | - num_beams (:obj:`int`, `optional`, defaults to 1): | ||
170 | - Number of beams for beam search. 1 means no beam search. | ||
171 | - temperature (:obj:`float`, `optional`, defaults to 1.0): | ||
172 | - The value used to module the next token probabilities. | ||
173 | - top_k (:obj:`int`, `optional`, defaults to 50): | ||
174 | - The number of highest probability vocabulary tokens to keep for top-k-filtering. | ||
175 | - top_p (:obj:`float`, `optional`, defaults to 1.0): | ||
176 | - If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or | ||
177 | - higher are kept for generation. | ||
178 | - repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): | ||
179 | - The parameter for repetition penalty. 1.0 means no penalty. See `this paper | ||
180 | - <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. | ||
181 | - pad_token_id (:obj:`int`, `optional`): | ||
182 | - The id of the `padding` token. | ||
183 | - bos_token_id (:obj:`int`, `optional`): | ||
184 | - The id of the `beginning-of-sequence` token. | ||
185 | - eos_token_id (:obj:`int`, `optional`): | ||
186 | - The id of the `end-of-sequence` token. | ||
187 | - length_penalty (:obj:`float`, `optional`, defaults to 1.0): | ||
188 | - Exponential penalty to the length. 1.0 means no penalty. | ||
189 | - | ||
190 | - Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in | ||
191 | - order to encourage the model to produce longer sequences. | ||
192 | - no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): | ||
193 | - If set to int > 0, all ngrams of that size can only occur once. | ||
194 | - bad_words_ids(:obj:`List[int]`, `optional`): | ||
195 | - List of token ids that are not allowed to be generated. In order to get the tokens of the words that | ||
196 | - should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. | ||
197 | - num_return_sequences(:obj:`int`, `optional`, defaults to 1): | ||
198 | - The number of independently computed returned sequences for each element in the batch. | ||
199 | - attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): | ||
200 | - Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for | ||
201 | - tokens that are not masked, and 0 for masked tokens. | ||
202 | - | ||
203 | - If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. | ||
204 | - | ||
205 | - `What are attention masks? <../glossary.html#attention-mask>`__ | ||
206 | - decoder_start_token_id (:obj:`int`, `optional`): | ||
207 | - If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. | ||
208 | - use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
209 | - Whether or not the model should use the past last key/values attentions (if applicable to the model) to | ||
210 | - speed up decoding. | ||
211 | - model_kwargs: | ||
212 | - Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. | ||
213 | - | ||
214 | - Return: | ||
215 | - | ||
216 | - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: | ||
217 | - The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or | ||
218 | - shorter if all batches finished early due to the :obj:`eos_token_id`. | ||
219 | - | ||
220 | - Examples:: | ||
221 | - | ||
222 | - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer | ||
223 | - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. | ||
224 | - outputs = model.generate(max_length=40) # do greedy decoding | ||
225 | - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) | ||
226 | - | ||
227 | - tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer | ||
228 | - model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. | ||
229 | - input_context = 'The dog' | ||
230 | - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
231 | - outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' | ||
232 | - for i in range(3): # 3 output sequences were generated | ||
233 | - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) | ||
234 | - | ||
235 | - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer | ||
236 | - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. | ||
237 | - input_context = 'The dog' | ||
238 | - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
239 | - outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling | ||
240 | - for i in range(3): # 3 output sequences were generated | ||
241 | - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) | ||
242 | - | ||
243 | - tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer | ||
244 | - model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. | ||
245 | - input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl | ||
246 | - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
247 | - outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences | ||
248 | - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) | ||
249 | - | ||
250 | - tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer | ||
251 | - model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. | ||
252 | - input_context = 'My cute dog' # prompt to complete | ||
253 | - bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] | ||
254 | - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context | ||
255 | - outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated | ||
256 | - """ | ||
257 | - | ||
258 | - # We cannot generate if the model does not have a LM head | ||
259 | - if self.get_output_embeddings() is None: | ||
260 | - raise AttributeError( | ||
261 | - "You tried to generate sequences with a model that does not have a LM Head." | ||
262 | - "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" | ||
263 | - ) | ||
264 | - | ||
265 | - max_length = max_length if max_length is not None else self.config.max_length | ||
266 | - min_length = min_length if min_length is not None else self.config.min_length | ||
267 | - do_sample = do_sample if do_sample is not None else self.config.do_sample | ||
268 | - early_stopping = ( | ||
269 | - early_stopping if early_stopping is not None else self.config.early_stopping | ||
270 | - ) | ||
271 | - use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
272 | - num_beams = num_beams if num_beams is not None else self.config.num_beams | ||
273 | - temperature = ( | ||
274 | - temperature if temperature is not None else self.config.temperature | ||
275 | - ) | ||
276 | - top_k = top_k if top_k is not None else self.config.top_k | ||
277 | - top_p = top_p if top_p is not None else self.config.top_p | ||
278 | - repetition_penalty = ( | ||
279 | - repetition_penalty | ||
280 | - if repetition_penalty is not None | ||
281 | - else self.config.repetition_penalty | ||
282 | - ) | ||
283 | - bos_token_id = ( | ||
284 | - bos_token_id if bos_token_id is not None else self.config.bos_token_id | ||
285 | - ) | ||
286 | - pad_token_id = ( | ||
287 | - pad_token_id if pad_token_id is not None else self.config.pad_token_id | ||
288 | - ) | ||
289 | - eos_token_id = ( | ||
290 | - eos_token_id if eos_token_id is not None else self.config.eos_token_id | ||
291 | - ) | ||
292 | - length_penalty = ( | ||
293 | - length_penalty if length_penalty is not None else self.config.length_penalty | ||
294 | - ) | ||
295 | - no_repeat_ngram_size = ( | ||
296 | - no_repeat_ngram_size | ||
297 | - if no_repeat_ngram_size is not None | ||
298 | - else self.config.no_repeat_ngram_size | ||
299 | - ) | ||
300 | - bad_words_ids = ( | ||
301 | - bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids | ||
302 | - ) | ||
303 | - num_return_sequences = ( | ||
304 | - num_return_sequences | ||
305 | - if num_return_sequences is not None | ||
306 | - else self.config.num_return_sequences | ||
307 | - ) | ||
308 | - decoder_start_token_id = ( | ||
309 | - decoder_start_token_id | ||
310 | - if decoder_start_token_id is not None | ||
311 | - else self.config.decoder_start_token_id | ||
312 | - ) | ||
313 | - | ||
314 | - if input_ids is not None: | ||
315 | - batch_size = input_ids.shape[0] # overridden by the input batch_size | ||
316 | - else: | ||
317 | - batch_size = 1 | ||
318 | - | ||
319 | - assert ( | ||
320 | - isinstance(max_length, int) and max_length > 0 | ||
321 | - ), "`max_length` should be a strictly positive integer." | ||
322 | - assert ( | ||
323 | - isinstance(min_length, int) and min_length >= 0 | ||
324 | - ), "`min_length` should be a positive integer." | ||
325 | - assert isinstance(do_sample, bool), "`do_sample` should be a boolean." | ||
326 | - assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." | ||
327 | - assert isinstance(use_cache, bool), "`use_cache` should be a boolean." | ||
328 | - assert ( | ||
329 | - isinstance(num_beams, int) and num_beams > 0 | ||
330 | - ), "`num_beams` should be a strictly positive integer." | ||
331 | - assert temperature > 0, "`temperature` should be strictly positive." | ||
332 | - assert ( | ||
333 | - isinstance(top_k, int) and top_k >= 0 | ||
334 | - ), "`top_k` should be a positive integer." | ||
335 | - assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." | ||
336 | - assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." | ||
337 | - assert input_ids is not None or ( | ||
338 | - isinstance(bos_token_id, int) and bos_token_id >= 0 | ||
339 | - ), "If input_ids is not defined, `bos_token_id` should be a positive integer." | ||
340 | - assert pad_token_id is None or ( | ||
341 | - isinstance(pad_token_id, int) and (pad_token_id >= 0) | ||
342 | - ), "`pad_token_id` should be a positive integer." | ||
343 | - assert (eos_token_id is None) or ( | ||
344 | - isinstance(eos_token_id, int) and (eos_token_id >= 0) | ||
345 | - ), "`eos_token_id` should be a positive integer." | ||
346 | - assert length_penalty > 0, "`length_penalty` should be strictly positive." | ||
347 | - assert ( | ||
348 | - isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 | ||
349 | - ), "`no_repeat_ngram_size` should be a positive integer." | ||
350 | - assert ( | ||
351 | - isinstance(num_return_sequences, int) and num_return_sequences > 0 | ||
352 | - ), "`num_return_sequences` should be a strictly positive integer." | ||
353 | - assert ( | ||
354 | - bad_words_ids is None | ||
355 | - or isinstance(bad_words_ids, list) | ||
356 | - and isinstance(bad_words_ids[0], list) | ||
357 | - ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" | ||
358 | - | ||
359 | - if input_ids is None: | ||
360 | - assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( | ||
361 | - "you should either supply a context to complete as `input_ids` input " | ||
362 | - "or a `bos_token_id` (integer >= 0) as a first token to start the generation." | ||
363 | - ) | ||
364 | - input_ids = torch.full( | ||
365 | - (batch_size, 1), | ||
366 | - bos_token_id, | ||
367 | - dtype=torch.long, | ||
368 | - device=next(self.parameters()).device, | ||
369 | - ) | ||
370 | - else: | ||
371 | - assert ( | ||
372 | - input_ids.dim() == 2 | ||
373 | - ), "Input prompt should be of shape (batch_size, sequence length)." | ||
374 | - | ||
375 | - # do not allow duplicated outputs when greedy decoding | ||
376 | - if do_sample is False: | ||
377 | - if num_beams == 1: | ||
378 | - # no_beam_search greedy generation conditions | ||
379 | - assert ( | ||
380 | - num_return_sequences == 1 | ||
381 | - ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" | ||
382 | - | ||
383 | - else: | ||
384 | - # beam_search greedy generation conditions | ||
385 | - assert ( | ||
386 | - num_beams >= num_return_sequences | ||
387 | - ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" | ||
388 | - | ||
389 | - # create attention mask if necessary | ||
390 | - # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 | ||
391 | - if ( | ||
392 | - (attention_mask is None) | ||
393 | - and (pad_token_id is not None) | ||
394 | - and (pad_token_id in input_ids) | ||
395 | - ): | ||
396 | - attention_mask = input_ids.ne(pad_token_id).long() | ||
397 | - elif attention_mask is None: | ||
398 | - attention_mask = input_ids.new_ones(input_ids.shape) | ||
399 | - | ||
400 | - # set pad_token_id to eos_token_id if not set. Important that this is done after | ||
401 | - # attention_mask is created | ||
402 | - if pad_token_id is None and eos_token_id is not None: | ||
403 | - logger.warning( | ||
404 | - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format( | ||
405 | - eos_token_id | ||
406 | - ) | ||
407 | - ) | ||
408 | - pad_token_id = eos_token_id | ||
409 | - | ||
410 | - # current position and vocab size | ||
411 | - if hasattr(self.config, "vocab_size"): | ||
412 | - vocab_size = self.config.vocab_size | ||
413 | - elif ( | ||
414 | - self.config.is_encoder_decoder | ||
415 | - and hasattr(self.config, "decoder") | ||
416 | - and hasattr(self.config.decoder, "vocab_size") | ||
417 | - ): | ||
418 | - vocab_size = self.config.decoder.vocab_size | ||
419 | - | ||
420 | - # set effective batch size and effective batch multiplier according to do_sample | ||
421 | - if do_sample: | ||
422 | - effective_batch_size = batch_size * num_return_sequences | ||
423 | - effective_batch_mult = num_return_sequences | ||
424 | - else: | ||
425 | - effective_batch_size = batch_size | ||
426 | - effective_batch_mult = 1 | ||
427 | - | ||
428 | - if self.config.is_encoder_decoder: | ||
429 | - if decoder_start_token_id is None: | ||
430 | - # see if BOS token can be used for decoder_start_token_id | ||
431 | - if bos_token_id is not None: | ||
432 | - decoder_start_token_id = bos_token_id | ||
433 | - elif hasattr(self.config, "decoder") and hasattr( | ||
434 | - self.config.decoder, "bos_token_id" | ||
435 | - ): | ||
436 | - decoder_start_token_id = self.config.decoder.bos_token_id | ||
437 | - else: | ||
438 | - raise ValueError( | ||
439 | - "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" | ||
440 | - ) | ||
441 | - | ||
442 | - assert hasattr( | ||
443 | - self, "get_encoder" | ||
444 | - ), "{} should have a 'get_encoder' function defined".format(self) | ||
445 | - assert callable(self.get_encoder), "{} should be a method".format( | ||
446 | - self.get_encoder | ||
447 | - ) | ||
448 | - | ||
449 | - # get encoder and store encoder outputs | ||
450 | - encoder = self.get_encoder() | ||
451 | - encoder_outputs: ModelOutput = encoder( | ||
452 | - input_ids, patch_ids, attention_mask=attention_mask, return_dict=True | ||
453 | - ) | ||
454 | - | ||
455 | - # Expand input ids if num_beams > 1 or num_return_sequences > 1 | ||
456 | - if num_return_sequences > 1 or num_beams > 1: | ||
457 | - input_ids_len = input_ids.shape[-1] | ||
458 | - input_ids = input_ids.unsqueeze(1).expand( | ||
459 | - batch_size, effective_batch_mult * num_beams, input_ids_len | ||
460 | - ) | ||
461 | - patch_ids = patch_ids.unsqueeze(1).expand( | ||
462 | - batch_size, effective_batch_mult * num_beams, input_ids_len | ||
463 | - ) | ||
464 | - attention_mask = attention_mask.unsqueeze(1).expand( | ||
465 | - batch_size, effective_batch_mult * num_beams, input_ids_len | ||
466 | - ) | ||
467 | - | ||
468 | - input_ids = input_ids.contiguous().view( | ||
469 | - effective_batch_size * num_beams, input_ids_len | ||
470 | - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) | ||
471 | - patch_ids = patch_ids.contiguous().view( | ||
472 | - effective_batch_size * num_beams, input_ids_len | ||
473 | - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) | ||
474 | - attention_mask = attention_mask.contiguous().view( | ||
475 | - effective_batch_size * num_beams, input_ids_len | ||
476 | - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) | ||
477 | - | ||
478 | - if self.config.is_encoder_decoder: | ||
479 | - # create empty decoder_input_ids | ||
480 | - input_ids = torch.full( | ||
481 | - (effective_batch_size * num_beams, 1), | ||
482 | - decoder_start_token_id, | ||
483 | - dtype=torch.long, | ||
484 | - device=next(self.parameters()).device, | ||
485 | - ) | ||
486 | - cur_len = 1 | ||
487 | - | ||
488 | - assert ( | ||
489 | - batch_size == encoder_outputs.last_hidden_state.shape[0] | ||
490 | - ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " | ||
491 | - | ||
492 | - # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) | ||
493 | - expanded_batch_idxs = ( | ||
494 | - torch.arange(batch_size) | ||
495 | - .view(-1, 1) | ||
496 | - .repeat(1, num_beams * effective_batch_mult) | ||
497 | - .view(-1) | ||
498 | - .to(input_ids.device) | ||
499 | - ) | ||
500 | - | ||
501 | - # expand encoder_outputs | ||
502 | - encoder_outputs[ | ||
503 | - "last_hidden_state" | ||
504 | - ] = encoder_outputs.last_hidden_state.index_select(0, expanded_batch_idxs) | ||
505 | - | ||
506 | - # save encoder_outputs in `model_kwargs` | ||
507 | - model_kwargs["encoder_outputs"] = encoder_outputs | ||
508 | - | ||
509 | - else: | ||
510 | - cur_len = input_ids.shape[-1] | ||
511 | - | ||
512 | - assert ( | ||
513 | - cur_len < max_length | ||
514 | - ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" | ||
515 | - | ||
516 | - if num_beams > 1: | ||
517 | - output = self._generate_beam_search( | ||
518 | - input_ids, | ||
519 | - cur_len=cur_len, | ||
520 | - max_length=max_length, | ||
521 | - min_length=min_length, | ||
522 | - do_sample=do_sample, | ||
523 | - early_stopping=early_stopping, | ||
524 | - temperature=temperature, | ||
525 | - top_k=top_k, | ||
526 | - top_p=top_p, | ||
527 | - repetition_penalty=repetition_penalty, | ||
528 | - no_repeat_ngram_size=no_repeat_ngram_size, | ||
529 | - bad_words_ids=bad_words_ids, | ||
530 | - pad_token_id=pad_token_id, | ||
531 | - eos_token_id=eos_token_id, | ||
532 | - batch_size=effective_batch_size, | ||
533 | - num_return_sequences=num_return_sequences, | ||
534 | - length_penalty=length_penalty, | ||
535 | - num_beams=num_beams, | ||
536 | - vocab_size=vocab_size, | ||
537 | - attention_mask=attention_mask, | ||
538 | - use_cache=use_cache, | ||
539 | - model_kwargs=model_kwargs, | ||
540 | - ) | ||
541 | - else: | ||
542 | - output = self._generate_no_beam_search( | ||
543 | - input_ids, | ||
544 | - cur_len=cur_len, | ||
545 | - max_length=max_length, | ||
546 | - min_length=min_length, | ||
547 | - do_sample=do_sample, | ||
548 | - temperature=temperature, | ||
549 | - top_k=top_k, | ||
550 | - top_p=top_p, | ||
551 | - repetition_penalty=repetition_penalty, | ||
552 | - no_repeat_ngram_size=no_repeat_ngram_size, | ||
553 | - bad_words_ids=bad_words_ids, | ||
554 | - pad_token_id=pad_token_id, | ||
555 | - eos_token_id=eos_token_id, | ||
556 | - batch_size=effective_batch_size, | ||
557 | - attention_mask=attention_mask, | ||
558 | - use_cache=use_cache, | ||
559 | - model_kwargs=model_kwargs, | ||
560 | - ) | ||
561 | - | ||
562 | - return output | ||
563 | - | ||
564 | - def _generate_no_beam_search( | ||
565 | - self, | ||
566 | - input_ids, | ||
567 | - cur_len, | ||
568 | - max_length, | ||
569 | - min_length, | ||
570 | - do_sample, | ||
571 | - temperature, | ||
572 | - top_k, | ||
573 | - top_p, | ||
574 | - repetition_penalty, | ||
575 | - no_repeat_ngram_size, | ||
576 | - bad_words_ids, | ||
577 | - pad_token_id, | ||
578 | - eos_token_id, | ||
579 | - batch_size, | ||
580 | - attention_mask, | ||
581 | - use_cache, | ||
582 | - model_kwargs, | ||
583 | - ): | ||
584 | - """Generate sequences for each example without beam search (num_beams == 1). | ||
585 | - All returned sequences are generated independently. | ||
586 | - """ | ||
587 | - # length of generated sentences / unfinished sentences | ||
588 | - unfinished_sents = input_ids.new(batch_size).fill_(1) | ||
589 | - sent_lengths = input_ids.new(batch_size).fill_(max_length) | ||
590 | - | ||
591 | - past = None | ||
592 | - while cur_len < max_length: | ||
593 | - model_inputs = self.prepare_inputs_for_generation( | ||
594 | - input_ids, | ||
595 | - past=past, | ||
596 | - attention_mask=attention_mask, | ||
597 | - use_cache=use_cache, | ||
598 | - **model_kwargs, | ||
599 | - ) | ||
600 | - | ||
601 | - outputs = self(**model_inputs, return_dict=True) | ||
602 | - next_token_logits = outputs.logits[:, -1, :] | ||
603 | - | ||
604 | - scores = self.postprocess_next_token_scores( | ||
605 | - scores=next_token_logits, | ||
606 | - input_ids=input_ids, | ||
607 | - no_repeat_ngram_size=no_repeat_ngram_size, | ||
608 | - bad_words_ids=bad_words_ids, | ||
609 | - cur_len=cur_len, | ||
610 | - min_length=min_length, | ||
611 | - max_length=max_length, | ||
612 | - eos_token_id=eos_token_id, | ||
613 | - repetition_penalty=repetition_penalty, | ||
614 | - batch_size=batch_size, | ||
615 | - num_beams=1, | ||
616 | - ) | ||
617 | - | ||
618 | - # if model has past, then set the past variable to speed up decoding | ||
619 | - if "past_key_values" in outputs: | ||
620 | - past = outputs.past_key_values | ||
621 | - elif "mems" in outputs: | ||
622 | - past = outputs.mems | ||
623 | - | ||
624 | - if do_sample: | ||
625 | - # Temperature (higher temperature => more likely to sample low probability tokens) | ||
626 | - if temperature != 1.0: | ||
627 | - scores = scores / temperature | ||
628 | - # Top-p/top-k filtering | ||
629 | - next_token_logscores = top_k_top_p_filtering( | ||
630 | - scores, top_k=top_k, top_p=top_p | ||
631 | - ) | ||
632 | - # Sample | ||
633 | - probs = F.softmax(next_token_logscores, dim=-1) | ||
634 | - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) | ||
635 | - else: | ||
636 | - # Greedy decoding | ||
637 | - next_token = torch.argmax(next_token_logits, dim=-1) | ||
638 | - | ||
639 | - # update generations and finished sentences | ||
640 | - if eos_token_id is not None: | ||
641 | - # pad finished sentences if eos_token_id exist | ||
642 | - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * ( | ||
643 | - 1 - unfinished_sents | ||
644 | - ) | ||
645 | - else: | ||
646 | - tokens_to_add = next_token | ||
647 | - | ||
648 | - # add token and increase length by one | ||
649 | - input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) | ||
650 | - cur_len = cur_len + 1 | ||
651 | - | ||
652 | - if eos_token_id is not None: | ||
653 | - eos_in_sents = tokens_to_add == eos_token_id | ||
654 | - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length | ||
655 | - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul( | ||
656 | - eos_in_sents.long() | ||
657 | - ).bool() | ||
658 | - sent_lengths.masked_fill_( | ||
659 | - is_sents_unfinished_and_token_to_add_is_eos, cur_len | ||
660 | - ) | ||
661 | - # unfinished_sents is set to zero if eos in sentence | ||
662 | - unfinished_sents.mul_((~eos_in_sents).long()) | ||
663 | - | ||
664 | - # stop when there is a </s> in each sentence, or if we exceed the maximum length | ||
665 | - if unfinished_sents.max() == 0: | ||
666 | - break | ||
667 | - | ||
668 | - # extend attention_mask for new generated input if only decoder | ||
669 | - if self.config.is_encoder_decoder is False: | ||
670 | - attention_mask = torch.cat( | ||
671 | - [ | ||
672 | - attention_mask, | ||
673 | - attention_mask.new_ones((attention_mask.shape[0], 1)), | ||
674 | - ], | ||
675 | - dim=-1, | ||
676 | - ) | ||
677 | - | ||
678 | - return input_ids | ||
679 | - | ||
680 | - def _generate_beam_search( | ||
681 | - self, | ||
682 | - input_ids, | ||
683 | - cur_len, | ||
684 | - max_length, | ||
685 | - min_length, | ||
686 | - do_sample, | ||
687 | - early_stopping, | ||
688 | - temperature, | ||
689 | - top_k, | ||
690 | - top_p, | ||
691 | - repetition_penalty, | ||
692 | - no_repeat_ngram_size, | ||
693 | - bad_words_ids, | ||
694 | - pad_token_id, | ||
695 | - eos_token_id, | ||
696 | - batch_size, | ||
697 | - num_return_sequences, | ||
698 | - length_penalty, | ||
699 | - num_beams, | ||
700 | - vocab_size, | ||
701 | - attention_mask, | ||
702 | - use_cache, | ||
703 | - model_kwargs, | ||
704 | - ): | ||
705 | - """Generate sequences for each example with beam search.""" | ||
706 | - | ||
707 | - # generated hypotheses | ||
708 | - generated_hyps = [ | ||
709 | - BeamHypotheses( | ||
710 | - num_beams, max_length, length_penalty, early_stopping=early_stopping | ||
711 | - ) | ||
712 | - for _ in range(batch_size) | ||
713 | - ] | ||
714 | - | ||
715 | - # scores for each sentence in the beam | ||
716 | - beam_scores = torch.zeros( | ||
717 | - (batch_size, num_beams), dtype=torch.float, device=input_ids.device | ||
718 | - ) | ||
719 | - | ||
720 | - # for greedy decoding, make sure only tokens of the first beam are considered, to avoid sampling the exact same tokens in every beam | ||
721 | - if do_sample is False: | ||
722 | - beam_scores[:, 1:] = -1e9 | ||
723 | - beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) | ||
724 | - | ||
725 | - # cache compute states | ||
726 | - past = None | ||
727 | - | ||
728 | - # done sentences | ||
729 | - done = [False for _ in range(batch_size)] | ||
730 | - | ||
731 | - while cur_len < max_length: | ||
732 | - model_inputs = self.prepare_inputs_for_generation( | ||
733 | - input_ids, | ||
734 | - past=past, | ||
735 | - attention_mask=attention_mask, | ||
736 | - use_cache=use_cache, | ||
737 | - **model_kwargs, | ||
738 | - ) | ||
739 | - outputs = self( | ||
740 | - **model_inputs, return_dict=True | ||
741 | - ) # (batch_size * num_beams, cur_len, vocab_size) | ||
742 | - next_token_logits = outputs.logits[ | ||
743 | - :, -1, : | ||
744 | - ] # (batch_size * num_beams, vocab_size) | ||
745 | - | ||
746 | - # if model has past, then set the past variable to speed up decoding | ||
747 | - if "past_key_values" in outputs: | ||
748 | - past = outputs.past_key_values | ||
749 | - elif "mems" in outputs: | ||
750 | - past = outputs.mems | ||
751 | - | ||
752 | - if self.config.is_encoder_decoder and do_sample is False: | ||
753 | - # TODO (PVP) still a bit hacky here - there might be a better solution | ||
754 | - next_token_logits = self.adjust_logits_during_generation( | ||
755 | - next_token_logits, cur_len=cur_len, max_length=max_length | ||
756 | - ) | ||
757 | - | ||
758 | - scores = F.log_softmax( | ||
759 | - next_token_logits, dim=-1 | ||
760 | - ) # (batch_size * num_beams, vocab_size) | ||
761 | - | ||
762 | - scores = self.postprocess_next_token_scores( | ||
763 | - scores=scores, | ||
764 | - input_ids=input_ids, | ||
765 | - no_repeat_ngram_size=no_repeat_ngram_size, | ||
766 | - bad_words_ids=bad_words_ids, | ||
767 | - cur_len=cur_len, | ||
768 | - min_length=min_length, | ||
769 | - max_length=max_length, | ||
770 | - eos_token_id=eos_token_id, | ||
771 | - repetition_penalty=repetition_penalty, | ||
772 | - batch_size=batch_size, | ||
773 | - num_beams=num_beams, | ||
774 | - ) | ||
775 | - | ||
776 | - assert scores.shape == ( | ||
777 | - batch_size * num_beams, | ||
778 | - vocab_size, | ||
779 | - ), "Shapes of scores: {} != {}".format( | ||
780 | - scores.shape, (batch_size * num_beams, vocab_size) | ||
781 | - ) | ||
782 | - | ||
783 | - if do_sample: | ||
784 | - _scores = scores + beam_scores[:, None].expand_as( | ||
785 | - scores | ||
786 | - ) # (batch_size * num_beams, vocab_size) | ||
787 | - # Temperature | ||
788 | - if temperature != 1.0: | ||
789 | - _scores = _scores / temperature | ||
790 | - # Top-p/top-k filtering | ||
791 | - _scores = top_k_top_p_filtering( | ||
792 | - _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 | ||
793 | - ) # (batch_size * num_beams, vocab_size) | ||
794 | - # re-organize to group the beam together to sample from all beam_idxs | ||
795 | - _scores = _scores.contiguous().view( | ||
796 | - batch_size, num_beams * vocab_size | ||
797 | - ) # (batch_size, num_beams * vocab_size) | ||
798 | - | ||
799 | - # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) | ||
800 | - probs = F.softmax(_scores, dim=-1) | ||
801 | - next_tokens = torch.multinomial( | ||
802 | - probs, num_samples=2 * num_beams | ||
803 | - ) # (batch_size, num_beams * 2) | ||
804 | - # Compute next scores | ||
805 | - next_scores = torch.gather( | ||
806 | - _scores, -1, next_tokens | ||
807 | - ) # (batch_size, num_beams * 2) | ||
808 | - # sort the sampled vector to make sure that the first num_beams samples are the best | ||
809 | - next_scores, next_scores_indices = torch.sort( | ||
810 | - next_scores, descending=True, dim=1 | ||
811 | - ) | ||
812 | - next_tokens = torch.gather( | ||
813 | - next_tokens, -1, next_scores_indices | ||
814 | - ) # (batch_size, num_beams * 2) | ||
815 | - | ||
816 | - else: | ||
817 | - next_scores = scores + beam_scores[:, None].expand_as( | ||
818 | - scores | ||
819 | - ) # (batch_size * num_beams, vocab_size) | ||
820 | - | ||
821 | - # re-organize to group the beam together (we keep the top hypotheses across beams) | ||
822 | - next_scores = next_scores.view( | ||
823 | - batch_size, num_beams * vocab_size | ||
824 | - ) # (batch_size, num_beams * vocab_size) | ||
825 | - | ||
826 | - next_scores, next_tokens = torch.topk( | ||
827 | - next_scores, 2 * num_beams, dim=1, largest=True, sorted=True | ||
828 | - ) | ||
829 | - | ||
830 | - assert ( | ||
831 | - next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) | ||
832 | - ) | ||
833 | - | ||
834 | - # next batch beam content | ||
835 | - next_batch_beam = [] | ||
836 | - | ||
837 | - # for each sentence | ||
838 | - for batch_idx in range(batch_size): | ||
839 | - | ||
840 | - # if we are done with this sentence, add a pad token | ||
841 | - if done[batch_idx]: | ||
842 | - assert ( | ||
843 | - len(generated_hyps[batch_idx]) >= num_beams | ||
844 | - ), "Batch can only be done if at least {} beams have been generated".format( | ||
845 | - num_beams | ||
846 | - ) | ||
847 | - assert ( | ||
848 | - eos_token_id is not None and pad_token_id is not None | ||
849 | - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" | ||
850 | - next_batch_beam.extend( | ||
851 | - [(0, pad_token_id, 0)] * num_beams | ||
852 | - ) # pad the batch | ||
853 | - continue | ||
854 | - | ||
855 | - # next sentence beam content, this will get added to next_batch_beam | ||
856 | - next_sent_beam = [] | ||
857 | - | ||
858 | - # next tokens for this sentence | ||
859 | - for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( | ||
860 | - zip(next_tokens[batch_idx], next_scores[batch_idx]) | ||
861 | - ): | ||
862 | - # get beam and token IDs | ||
863 | - beam_id = beam_token_id // vocab_size | ||
864 | - token_id = beam_token_id % vocab_size | ||
865 | - | ||
866 | - effective_beam_id = batch_idx * num_beams + beam_id | ||
867 | - # add to generated hypotheses if end of sentence | ||
868 | - if (eos_token_id is not None) and (token_id.item() == eos_token_id): | ||
869 | - # if beam_token does not belong to top num_beams tokens, it should not be added | ||
870 | - is_beam_token_worse_than_top_num_beams = ( | ||
871 | - beam_token_rank >= num_beams | ||
872 | - ) | ||
873 | - if is_beam_token_worse_than_top_num_beams: | ||
874 | - continue | ||
875 | - generated_hyps[batch_idx].add( | ||
876 | - input_ids[effective_beam_id].clone(), | ||
877 | - beam_token_score.item(), | ||
878 | - ) | ||
879 | - else: | ||
880 | - # add next predicted token since it is not eos_token | ||
881 | - next_sent_beam.append( | ||
882 | - (beam_token_score, token_id, effective_beam_id) | ||
883 | - ) | ||
884 | - | ||
885 | - # once the beam for next step is full, don't add more tokens to it. | ||
886 | - if len(next_sent_beam) == num_beams: | ||
887 | - break | ||
888 | - | ||
889 | - # Check if we are done so that we can save a pad step if all(done) | ||
890 | - done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( | ||
891 | - next_scores[batch_idx].max().item(), cur_len | ||
892 | - ) | ||
893 | - | ||
894 | - # update next beam content | ||
895 | - assert len(next_sent_beam) == num_beams, "Beam should always be full" | ||
896 | - next_batch_beam.extend(next_sent_beam) | ||
897 | - assert len(next_batch_beam) == num_beams * ( | ||
898 | - batch_idx + 1 | ||
899 | - ), "We should have added num_beams each step" | ||
900 | - | ||
901 | - # stop when we are done with each sentence | ||
902 | - if all(done): | ||
903 | - break | ||
904 | - | ||
905 | - # sanity check / prepare next batch | ||
906 | - assert len(next_batch_beam) == batch_size * num_beams | ||
907 | - beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) | ||
908 | - beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) | ||
909 | - beam_idx = input_ids.new([x[2] for x in next_batch_beam]) | ||
910 | - | ||
911 | - # re-order batch and update current length | ||
912 | - input_ids = input_ids[beam_idx, :] | ||
913 | - input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) | ||
914 | - cur_len = cur_len + 1 | ||
915 | - | ||
916 | - # re-order internal states | ||
917 | - if past is not None: | ||
918 | - past = self._reorder_cache(past, beam_idx) | ||
919 | - | ||
920 | - # extend attention_mask for the newly generated token if the model is decoder-only | ||
921 | - if self.config.is_encoder_decoder is False: | ||
922 | - attention_mask = torch.cat( | ||
923 | - [ | ||
924 | - attention_mask, | ||
925 | - attention_mask.new_ones((attention_mask.shape[0], 1)), | ||
926 | - ], | ||
927 | - dim=-1, | ||
928 | - ) | ||
929 | - | ||
930 | - # finalize all open beam hypotheses and add to generated hypotheses | ||
931 | - for batch_idx in range(batch_size): | ||
932 | - if done[batch_idx]: | ||
933 | - continue | ||
934 | - | ||
935 | - # test that beam scores match previously calculated scores if not eos and batch_idx not done | ||
936 | - if eos_token_id is not None and all( | ||
937 | - (token_id % vocab_size).item() != eos_token_id | ||
938 | - for token_id in next_tokens[batch_idx] | ||
939 | - ): | ||
940 | - assert torch.all( | ||
941 | - next_scores[batch_idx, :num_beams] | ||
942 | - == beam_scores.view(batch_size, num_beams)[batch_idx] | ||
943 | - ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( | ||
944 | - next_scores[:, :num_beams][batch_idx], | ||
945 | - beam_scores.view(batch_size, num_beams)[batch_idx], | ||
946 | - ) | ||
947 | - | ||
948 | - # need to add best num_beams hypotheses to generated hyps | ||
949 | - for beam_id in range(num_beams): | ||
950 | - effective_beam_id = batch_idx * num_beams + beam_id | ||
951 | - final_score = beam_scores[effective_beam_id].item() | ||
952 | - final_tokens = input_ids[effective_beam_id] | ||
953 | - generated_hyps[batch_idx].add(final_tokens, final_score) | ||
954 | - | ||
955 | - # depending on whether greedy generation is wanted or not, define different output_batch_size and output_num_return_sequences_per_batch | ||
956 | - output_batch_size = ( | ||
957 | - batch_size if do_sample else batch_size * num_return_sequences | ||
958 | - ) | ||
959 | - output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences | ||
960 | - | ||
961 | - # select the best hypotheses | ||
962 | - sent_lengths = input_ids.new(output_batch_size) | ||
963 | - best = [] | ||
964 | - | ||
965 | - # retrieve best hypotheses | ||
966 | - for i, hypotheses in enumerate(generated_hyps): | ||
967 | - sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) | ||
968 | - for j in range(output_num_return_sequences_per_batch): | ||
969 | - effective_batch_idx = output_num_return_sequences_per_batch * i + j | ||
970 | - best_hyp = sorted_hyps.pop()[1] | ||
971 | - sent_lengths[effective_batch_idx] = len(best_hyp) | ||
972 | - best.append(best_hyp) | ||
973 | - | ||
974 | - # shorter hypotheses are padded up to the longest one | ||
975 | - if sent_lengths.min().item() != sent_lengths.max().item(): | ||
976 | - assert pad_token_id is not None, "`Pad_token_id` has to be defined" | ||
977 | - sent_max_len = min(sent_lengths.max().item() + 1, max_length) | ||
978 | - decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) | ||
979 | - | ||
980 | - # fill with hypothesis and eos_token_id if necessary | ||
981 | - for i, hypo in enumerate(best): | ||
982 | - decoded[i, : sent_lengths[i]] = hypo | ||
983 | - if sent_lengths[i] < max_length: | ||
984 | - decoded[i, sent_lengths[i]] = eos_token_id | ||
985 | - else: | ||
986 | - # none of the hypotheses have an eos_token | ||
987 | - assert all(len(hypo) == max_length for hypo in best) | ||
988 | - decoded = ( | ||
989 | - torch.stack(best).type(torch.long).to(next(self.parameters()).device) | ||
990 | - ) | ||
991 | - | ||
992 | - return decoded | ||
993 | - | ||
994 | - @staticmethod | ||
995 | - def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: | ||
996 | - return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) | ||
997 | - | ||
998 | - | ||
999 | -def calc_banned_ngram_tokens( | ||
1000 | - prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int | ||
1001 | -) -> List[List[int]]: | ||
1002 | - """Copied from fairseq for no_repeat_ngram in beam_search""" | ||
1003 | - if cur_len + 1 < no_repeat_ngram_size: | ||
1004 | - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet | ||
1005 | - return [[] for _ in range(num_hypos)] | ||
1006 | - generated_ngrams = [{} for _ in range(num_hypos)] | ||
1007 | - for idx in range(num_hypos): | ||
1008 | - gen_tokens = prev_input_ids[idx].tolist() | ||
1009 | - generated_ngram = generated_ngrams[idx] | ||
1010 | - for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): | ||
1011 | - prev_ngram_tuple = tuple(ngram[:-1]) | ||
1012 | - generated_ngram[prev_ngram_tuple] = generated_ngram.get( | ||
1013 | - prev_ngram_tuple, [] | ||
1014 | - ) + [ngram[-1]] | ||
1015 | - | ||
1016 | - def _get_generated_ngrams(hypo_idx): | ||
1017 | - # Before decoding the next token, prevent decoding of ngrams that have already appeared | ||
1018 | - start_idx = cur_len + 1 - no_repeat_ngram_size | ||
1019 | - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) | ||
1020 | - return generated_ngrams[hypo_idx].get(ngram_idx, []) | ||
1021 | - | ||
1022 | - banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] | ||
1023 | - return banned_tokens | ||
1024 | - | ||
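A quick sketch of what `calc_banned_ngram_tokens` returns, assuming the helper above is in scope (toy token ids, purely illustrative):

    import torch

    # one hypothesis that already contains the bigram (7, 5); cur_len is the current sequence length
    prev = torch.tensor([[5, 7, 5, 7]])
    banned = calc_banned_ngram_tokens(prev, num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
    print(banned)  # [[5]] -> emitting 5 right after the trailing 7 would repeat the bigram (7, 5)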
1025 | - | ||
1026 | -def calc_banned_bad_words_ids( | ||
1027 | - prev_input_ids: Iterable[int], bad_words_ids: Iterable[int] | ||
1028 | -) -> Iterable[int]: | ||
1029 | - banned_tokens = [] | ||
1030 | - | ||
1031 | - def _tokens_match(prev_tokens, tokens): | ||
1032 | - if len(tokens) == 0: | ||
1033 | - # if the bad word is a single token, always ban it | ||
1034 | - return True | ||
1035 | - if len(tokens) > len(prev_tokens): | ||
1036 | - # if bad word tokens are longer than prev tokens they can't be equal | ||
1037 | - return False | ||
1038 | - | ||
1039 | - if prev_tokens[-len(tokens) :] == tokens: | ||
1040 | - # if tokens match | ||
1041 | - return True | ||
1042 | - else: | ||
1043 | - return False | ||
1044 | - | ||
1045 | - for prev_input_ids_slice in prev_input_ids: | ||
1046 | - banned_tokens_slice = [] | ||
1047 | - | ||
1048 | - for banned_token_seq in bad_words_ids: | ||
1049 | - assert ( | ||
1050 | - len(banned_token_seq) > 0 | ||
1051 | - ), "Banned words token sequences {} cannot have an empty list".format( | ||
1052 | - bad_words_ids | ||
1053 | - ) | ||
1054 | - | ||
1055 | - if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: | ||
1056 | - # if tokens do not match continue | ||
1057 | - continue | ||
1058 | - | ||
1059 | - banned_tokens_slice.append(banned_token_seq[-1]) | ||
1060 | - | ||
1061 | - banned_tokens.append(banned_tokens_slice) | ||
1062 | - | ||
1063 | - return banned_tokens | ||
1064 | - | ||
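For illustration, a small sketch of `calc_banned_bad_words_ids` with plain Python lists (hypothetical token ids, assuming the helper above is in scope):

    # the generated prefix ends in (9, 4); bad word sequences are [9, 4, 11] and the single token [7]
    banned = calc_banned_bad_words_ids([[2, 9, 4]], [[9, 4, 11], [7]])
    print(banned)  # [[11, 7]] -> 11 would complete a banned sequence; 7 is a one-token bad word, so it is always banned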
1065 | - | ||
1066 | -def set_scores_to_inf_for_banned_tokens( | ||
1067 | - scores: torch.Tensor, banned_tokens: List[List[int]] | ||
1068 | -) -> None: | ||
1069 | - """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be | ||
1070 | - a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...] | ||
1071 | - Args: | ||
1072 | - scores: logits distribution of shape (batch size, vocabulary size) | ||
1073 | - banned_tokens: list of list of tokens to ban of length (batch_size) | ||
1074 | - """ | ||
1075 | - banned_mask_list = [] | ||
1076 | - for idx, batch_banned_tokens in enumerate(banned_tokens): | ||
1077 | - for token in batch_banned_tokens: | ||
1078 | - banned_mask_list.append([idx, token]) | ||
1079 | - if not banned_mask_list: | ||
1080 | - return | ||
1081 | - banned_mask = torch.LongTensor(banned_mask_list) | ||
1082 | - indices = torch.ones(len(banned_mask)) | ||
1083 | - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates: | ||
1084 | - # [ 0 1 1 ] | ||
1085 | - # [ 0 0 0 ] | ||
1086 | - # [ 1 0 0 ] | ||
1087 | - | ||
1088 | - banned_mask = ( | ||
1089 | - torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()) | ||
1090 | - .to(scores.device) | ||
1091 | - .to_dense() | ||
1092 | - .bool() | ||
1093 | - ) | ||
1094 | - scores.masked_fill_(banned_mask, -float("inf")) | ||
1095 | - | ||
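The helper above builds a sparse boolean mask before calling `masked_fill_`; its effect is equivalent to this simplified dense loop (an illustrative sketch, not the code path used above):

    import torch

    scores = torch.zeros(2, 5)
    banned_tokens = [[1, 3], []]  # ban tokens 1 and 3 in row 0, nothing in row 1
    for row, tokens in enumerate(banned_tokens):
        if tokens:
            scores[row, tokens] = -float("inf")
    # row 0 now has -inf at positions 1 and 3, so those tokens can never be chosen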
1096 | - | ||
1097 | -def top_k_top_p_filtering( | ||
1098 | - logits: Tensor, | ||
1099 | - top_k: int = 0, | ||
1100 | - top_p: float = 1.0, | ||
1101 | - filter_value: float = -float("Inf"), | ||
1102 | - min_tokens_to_keep: int = 1, | ||
1103 | -) -> Tensor: | ||
1104 | - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering | ||
1105 | - Args: | ||
1106 | - logits: logits distribution shape (batch size, vocabulary size) | ||
1107 | - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). | ||
1108 | - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). | ||
1109 | - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) | ||
1110 | - Make sure we keep at least min_tokens_to_keep per batch example in the output | ||
1111 | - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 | ||
1112 | - """ | ||
1113 | - if top_k > 0: | ||
1114 | - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | ||
1115 | - # Remove all tokens with a probability less than the last token of the top-k | ||
1116 | - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | ||
1117 | - logits[indices_to_remove] = filter_value | ||
1118 | - | ||
1119 | - if top_p < 1.0: | ||
1120 | - sorted_logits, sorted_indices = torch.sort(logits, descending=True) | ||
1121 | - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||
1122 | - | ||
1123 | - # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept) | ||
1124 | - sorted_indices_to_remove = cumulative_probs > top_p | ||
1125 | - if min_tokens_to_keep > 1: | ||
1126 | - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | ||
1127 | - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | ||
1128 | - # Shift the indices to the right to keep also the first token above the threshold | ||
1129 | - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||
1130 | - sorted_indices_to_remove[..., 0] = 0 | ||
1131 | - | ||
1132 | - # scatter sorted tensors to original indexing | ||
1133 | - indices_to_remove = sorted_indices_to_remove.scatter( | ||
1134 | - 1, sorted_indices, sorted_indices_to_remove | ||
1135 | - ) | ||
1136 | - logits[indices_to_remove] = filter_value | ||
1137 | - return logits | ||
1138 | - | ||
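A minimal demo of `top_k_top_p_filtering` with made-up logits (values chosen only to make the effect visible):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
    kept = top_k_top_p_filtering(logits.clone(), top_k=2)  # keep only the two highest logits
    print(kept)                                            # tensor([[2.0, 1.0, -inf, -inf]])
    probs = F.softmax(kept, dim=-1)                        # the filtered positions get probability 0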
1139 | - | ||
1140 | -class BeamHypotheses(object): | ||
1141 | - def __init__(self, num_beams, max_length, length_penalty, early_stopping): | ||
1142 | - """ | ||
1143 | - Initialize n-best list of hypotheses. | ||
1144 | - """ | ||
1145 | - self.max_length = max_length - 1 # ignoring bos_token | ||
1146 | - self.length_penalty = length_penalty | ||
1147 | - self.early_stopping = early_stopping | ||
1148 | - self.num_beams = num_beams | ||
1149 | - self.beams = [] | ||
1150 | - self.worst_score = 1e9 | ||
1151 | - | ||
1152 | - def __len__(self): | ||
1153 | - """ | ||
1154 | - Number of hypotheses in the list. | ||
1155 | - """ | ||
1156 | - return len(self.beams) | ||
1157 | - | ||
1158 | - def add(self, hyp, sum_logprobs): | ||
1159 | - """ | ||
1160 | - Add a new hypothesis to the list. | ||
1161 | - """ | ||
1162 | - score = sum_logprobs / len(hyp) ** self.length_penalty | ||
1163 | - if len(self) < self.num_beams or score > self.worst_score: | ||
1164 | - self.beams.append((score, hyp)) | ||
1165 | - if len(self) > self.num_beams: | ||
1166 | - sorted_scores = sorted( | ||
1167 | - [(s, idx) for idx, (s, _) in enumerate(self.beams)] | ||
1168 | - ) | ||
1169 | - del self.beams[sorted_scores[0][1]] | ||
1170 | - self.worst_score = sorted_scores[1][0] | ||
1171 | - else: | ||
1172 | - self.worst_score = min(score, self.worst_score) | ||
1173 | - | ||
1174 | - def is_done(self, best_sum_logprobs, cur_len): | ||
1175 | - """ | ||
1176 | - If there are enough hypotheses and none of the hypotheses being generated | ||
1177 | - can become better than the worst one in the heap, then we are done with this sentence. | ||
1178 | - """ | ||
1179 | - | ||
1180 | - if len(self) < self.num_beams: | ||
1181 | - return False | ||
1182 | - elif self.early_stopping: | ||
1183 | - return True | ||
1184 | - else: | ||
1185 | - cur_score = best_sum_logprobs / cur_len ** self.length_penalty | ||
1186 | - ret = self.worst_score >= cur_score | ||
1187 | - return ret |
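To make the length-penalty bookkeeping in `BeamHypotheses.add` concrete, a small sketch with hypothetical log-probabilities:

    import torch

    hyps = BeamHypotheses(num_beams=2, max_length=20, length_penalty=1.0, early_stopping=False)
    hyps.add(torch.tensor([0, 5, 7, 2]), sum_logprobs=-2.0)  # score = -2.0 / 4**1.0 = -0.50
    hyps.add(torch.tensor([0, 5, 2]), sum_logprobs=-1.8)     # score = -1.8 / 3**1.0 = -0.60
    # a length_penalty above 1.0 divides by a larger power of the length, which makes longer
    # hypotheses comparatively less negative, i.e. beam search then favours longer outputs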
train/lightning_base.py
deleted
100644 → 0
1 | -import argparse | ||
2 | -import logging | ||
3 | -import os | ||
4 | -from pathlib import Path | ||
5 | -from typing import Any, Dict | ||
6 | - | ||
7 | -import pytorch_lightning as pl | ||
8 | -from pytorch_lightning.utilities import rank_zero_info | ||
9 | - | ||
10 | -from transformers import ( | ||
11 | - AdamW, | ||
12 | - AutoConfig, | ||
13 | - AutoModel, | ||
14 | - AutoModelForPreTraining, | ||
15 | - AutoModelForQuestionAnswering, | ||
16 | - AutoModelForSeq2SeqLM, | ||
17 | - AutoModelForSequenceClassification, | ||
18 | - AutoModelForTokenClassification, | ||
19 | - AutoModelWithLMHead, | ||
20 | - AutoTokenizer, | ||
21 | - PretrainedConfig, | ||
22 | - PreTrainedTokenizer, | ||
23 | -) | ||
24 | -from train.modeling_bart import BartForConditionalGeneration | ||
25 | - | ||
26 | -from transformers.optimization import ( | ||
27 | - Adafactor, | ||
28 | - get_cosine_schedule_with_warmup, | ||
29 | - get_cosine_with_hard_restarts_schedule_with_warmup, | ||
30 | - get_linear_schedule_with_warmup, | ||
31 | - get_polynomial_decay_schedule_with_warmup, | ||
32 | -) | ||
33 | - | ||
34 | - | ||
35 | -logger = logging.getLogger(__name__) | ||
36 | - | ||
37 | - | ||
38 | -MODEL_MODES = { | ||
39 | - "base": AutoModel, | ||
40 | - "sequence-classification": AutoModelForSequenceClassification, | ||
41 | - "question-answering": AutoModelForQuestionAnswering, | ||
42 | - "pretraining": AutoModelForPreTraining, | ||
43 | - "token-classification": AutoModelForTokenClassification, | ||
44 | - "language-modeling": AutoModelWithLMHead, | ||
45 | - "summarization": BartForConditionalGeneration, | ||
46 | - "translation": AutoModelForSeq2SeqLM, | ||
47 | -} | ||
48 | - | ||
49 | - | ||
50 | -# update this and the import above to support new schedulers from transformers.optimization | ||
51 | -arg_to_scheduler = { | ||
52 | - "linear": get_linear_schedule_with_warmup, | ||
53 | - "cosine": get_cosine_schedule_with_warmup, | ||
54 | - "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, | ||
55 | - "polynomial": get_polynomial_decay_schedule_with_warmup, | ||
56 | - # '': get_constant_schedule, # not supported for now | ||
57 | - # '': get_constant_schedule_with_warmup, # not supported for now | ||
58 | -} | ||
59 | -arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) | ||
60 | -arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" | ||
61 | - | ||
62 | - | ||
63 | -class BaseTransformer(pl.LightningModule): | ||
64 | - def __init__( | ||
65 | - self, | ||
66 | - hparams: argparse.Namespace, | ||
67 | - num_labels=None, | ||
68 | - mode="base", | ||
69 | - config=None, | ||
70 | - tokenizer=None, | ||
71 | - model=None, | ||
72 | - **config_kwargs, | ||
73 | - ): | ||
74 | - """Initialize a model, tokenizer and config.""" | ||
75 | - super().__init__() | ||
76 | - # TODO: move to self.save_hyperparameters() | ||
77 | - # self.save_hyperparameters() | ||
78 | - # can also expand arguments into trainer signature for easier reading | ||
79 | - | ||
80 | - self.save_hyperparameters(hparams) | ||
81 | - self.step_count = 0 | ||
82 | - self.output_dir = Path(self.hparams.output_dir) | ||
83 | - cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None | ||
84 | - if config is None: | ||
85 | - self.config = AutoConfig.from_pretrained( | ||
86 | - self.hparams.config_name | ||
87 | - if self.hparams.config_name | ||
88 | - else self.hparams.model_name_or_path, | ||
89 | - **({"num_labels": num_labels} if num_labels is not None else {}), | ||
90 | - cache_dir=cache_dir, | ||
91 | - **config_kwargs, | ||
92 | - ) | ||
93 | - else: | ||
94 | - self.config: PretrainedConfig = config | ||
95 | - | ||
96 | - extra_model_params = ( | ||
97 | - "encoder_layerdrop", | ||
98 | - "decoder_layerdrop", | ||
99 | - "dropout", | ||
100 | - "attention_dropout", | ||
101 | - ) | ||
102 | - for p in extra_model_params: | ||
103 | - if getattr(self.hparams, p, None): | ||
104 | - assert hasattr( | ||
105 | - self.config, p | ||
106 | - ), f"model config doesn't have a `{p}` attribute" | ||
107 | - setattr(self.config, p, getattr(self.hparams, p)) | ||
108 | - | ||
109 | - if tokenizer is None: | ||
110 | - self.tokenizer = AutoTokenizer.from_pretrained( | ||
111 | - self.hparams.tokenizer_name | ||
112 | - if self.hparams.tokenizer_name | ||
113 | - else self.hparams.model_name_or_path, | ||
114 | - cache_dir=cache_dir, | ||
115 | - ) | ||
116 | - else: | ||
117 | - self.tokenizer: PreTrainedTokenizer = tokenizer | ||
118 | - self.model_type = MODEL_MODES[mode] | ||
119 | - if model is None: | ||
120 | - self.model = self.model_type.from_pretrained( | ||
121 | - self.hparams.model_name_or_path, | ||
122 | - from_tf=bool(".ckpt" in self.hparams.model_name_or_path), | ||
123 | - config=self.config, | ||
124 | - cache_dir=cache_dir, | ||
125 | - ) | ||
126 | - else: | ||
127 | - self.model = model | ||
128 | - self.model.resize_token_embeddings(len(tokenizer)) | ||
129 | - | ||
130 | - def load_hf_checkpoint(self, *args, **kwargs): | ||
131 | - self.model = self.model_type.from_pretrained(*args, **kwargs) | ||
132 | - | ||
133 | - def get_lr_scheduler(self): | ||
134 | - get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] | ||
135 | - scheduler = get_schedule_func( | ||
136 | - self.opt, | ||
137 | - num_warmup_steps=self.hparams.warmup_steps, | ||
138 | - num_training_steps=self.total_steps, | ||
139 | - ) | ||
140 | - scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} | ||
141 | - return scheduler | ||
142 | - | ||
143 | - def configure_optimizers(self): | ||
144 | - """Prepare optimizer and schedule (linear warmup and decay)""" | ||
145 | - model = self.model | ||
146 | - no_decay = ["bias", "LayerNorm.weight"] | ||
147 | - optimizer_grouped_parameters = [ | ||
148 | - { | ||
149 | - "params": [ | ||
150 | - p | ||
151 | - for n, p in model.named_parameters() | ||
152 | - if not any(nd in n for nd in no_decay) | ||
153 | - ], | ||
154 | - "weight_decay": self.hparams.weight_decay, | ||
155 | - }, | ||
156 | - { | ||
157 | - "params": [ | ||
158 | - p | ||
159 | - for n, p in model.named_parameters() | ||
160 | - if any(nd in n for nd in no_decay) | ||
161 | - ], | ||
162 | - "weight_decay": 0.0, | ||
163 | - }, | ||
164 | - ] | ||
165 | - if self.hparams.adafactor: | ||
166 | - optimizer = Adafactor( | ||
167 | - optimizer_grouped_parameters, | ||
168 | - lr=self.hparams.learning_rate, | ||
169 | - scale_parameter=False, | ||
170 | - relative_step=False, | ||
171 | - ) | ||
172 | - | ||
173 | - else: | ||
174 | - optimizer = AdamW( | ||
175 | - optimizer_grouped_parameters, | ||
176 | - lr=self.hparams.learning_rate, | ||
177 | - eps=self.hparams.adam_epsilon, | ||
178 | - ) | ||
179 | - self.opt = optimizer | ||
180 | - | ||
181 | - scheduler = self.get_lr_scheduler() | ||
182 | - | ||
183 | - return [optimizer], [scheduler] | ||
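As a sketch of how the two parameter groups above are split, using hypothetical parameter names:

    no_decay = ["bias", "LayerNorm.weight"]
    names = ["encoder.fc1.weight", "encoder.fc1.bias", "encoder.LayerNorm.weight"]
    for n in names:
        group = "no weight decay" if any(nd in n for nd in no_decay) else "weight decay"
        print(f"{n} -> {group}")  # only the first name keeps weight decay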
184 | - | ||
185 | - def test_step(self, batch, batch_nb): | ||
186 | - return self.validation_step(batch, batch_nb) | ||
187 | - | ||
188 | - def test_epoch_end(self, outputs): | ||
189 | - return self.validation_end(outputs) | ||
190 | - | ||
191 | - @property | ||
192 | - def total_steps(self) -> int: | ||
193 | - """The number of total training steps that will be run. Used for lr scheduler purposes.""" | ||
194 | - num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores | ||
195 | - effective_batch_size = ( | ||
196 | - self.hparams.train_batch_size | ||
197 | - * self.hparams.accumulate_grad_batches | ||
198 | - * num_devices | ||
199 | - ) | ||
200 | - dataset_size = len(self.train_loader.dataset) | ||
201 | - return (dataset_size / effective_batch_size) * self.hparams.max_epochs | ||
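A worked example of the step count above, with hypothetical numbers:

    # dataset_size = 10_000, train_batch_size = 32, accumulate_grad_batches = 2, gpus = 1, max_epochs = 3
    # effective_batch_size = 32 * 2 * 1 = 64
    # total_steps = (10_000 / 64) * 3 = 468.75, i.e. roughly 469 optimizer steps for the lr scheduler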
202 | - | ||
203 | - def setup(self, mode): | ||
204 | - if mode == "fit": | ||
205 | - self.train_loader = self.get_dataloader( | ||
206 | - "train", self.hparams.train_batch_size, shuffle=True | ||
207 | - ) | ||
208 | - | ||
209 | - def get_dataloader(self, type_path, batch_size, shuffle=False): | ||
210 | - raise NotImplementedError("You must implement this for your task") | ||
211 | - | ||
212 | - def train_dataloader(self): | ||
213 | - return self.train_loader | ||
214 | - | ||
215 | - def val_dataloader(self): | ||
216 | - return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False) | ||
217 | - | ||
218 | - def test_dataloader(self): | ||
219 | - return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False) | ||
220 | - | ||
221 | - def _feature_file(self, mode): | ||
222 | - return os.path.join( | ||
223 | - self.hparams.data_dir, | ||
224 | - "cached_{}_{}_{}".format( | ||
225 | - mode, | ||
226 | - list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), | ||
227 | - str(self.hparams.max_seq_length), | ||
228 | - ), | ||
229 | - ) | ||
230 | - | ||
231 | - @pl.utilities.rank_zero_only | ||
232 | - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: | ||
233 | - save_path = self.output_dir.joinpath("best_tfmr") | ||
234 | - self.model.config.save_step = self.step_count | ||
235 | - self.model.save_pretrained(save_path) | ||
236 | - self.tokenizer.save_pretrained(save_path) | ||
237 | - | ||
238 | - @staticmethod | ||
239 | - def add_model_specific_args(parser, root_dir): | ||
240 | - parser.add_argument( | ||
241 | - "--model_name_or_path", | ||
242 | - default=None, | ||
243 | - type=str, | ||
244 | - required=True, | ||
245 | - help="Path to pretrained model or model identifier from huggingface.co/models", | ||
246 | - ) | ||
247 | - parser.add_argument( | ||
248 | - "--config_name", | ||
249 | - default="", | ||
250 | - type=str, | ||
251 | - help="Pretrained config name or path if not the same as model_name", | ||
252 | - ) | ||
253 | - parser.add_argument( | ||
254 | - "--tokenizer_name", | ||
255 | - default=None, | ||
256 | - type=str, | ||
257 | - help="Pretrained tokenizer name or path if not the same as model_name", | ||
258 | - ) | ||
259 | - parser.add_argument( | ||
260 | - "--cache_dir", | ||
261 | - default="", | ||
262 | - type=str, | ||
263 | - help="Where do you want to store the pre-trained models downloaded from s3", | ||
264 | - ) | ||
265 | - parser.add_argument( | ||
266 | - "--encoder_layerdrop", | ||
267 | - type=float, | ||
268 | - help="Encoder layer dropout probability (Optional). Goes into model.config", | ||
269 | - ) | ||
270 | - parser.add_argument( | ||
271 | - "--decoder_layerdrop", | ||
272 | - type=float, | ||
273 | - help="Decoder layer dropout probability (Optional). Goes into model.config", | ||
274 | - ) | ||
275 | - parser.add_argument( | ||
276 | - "--dropout", | ||
277 | - type=float, | ||
278 | - help="Dropout probability (Optional). Goes into model.config", | ||
279 | - ) | ||
280 | - parser.add_argument( | ||
281 | - "--attention_dropout", | ||
282 | - type=float, | ||
283 | - help="Attention dropout probability (Optional). Goes into model.config", | ||
284 | - ) | ||
285 | - parser.add_argument( | ||
286 | - "--learning_rate", | ||
287 | - default=5e-5, | ||
288 | - type=float, | ||
289 | - help="The initial learning rate for Adam.", | ||
290 | - ) | ||
291 | - parser.add_argument( | ||
292 | - "--lr_scheduler", | ||
293 | - default="linear", | ||
294 | - choices=arg_to_scheduler_choices, | ||
295 | - metavar=arg_to_scheduler_metavar, | ||
296 | - type=str, | ||
297 | - help="Learning rate scheduler", | ||
298 | - ) | ||
299 | - parser.add_argument( | ||
300 | - "--weight_decay", | ||
301 | - default=0.0, | ||
302 | - type=float, | ||
303 | - help="Weight decay if we apply some.", | ||
304 | - ) | ||
305 | - parser.add_argument( | ||
306 | - "--adam_epsilon", | ||
307 | - default=1e-8, | ||
308 | - type=float, | ||
309 | - help="Epsilon for Adam optimizer.", | ||
310 | - ) | ||
311 | - parser.add_argument( | ||
312 | - "--warmup_steps", | ||
313 | - default=0, | ||
314 | - type=int, | ||
315 | - help="Linear warmup over warmup_steps.", | ||
316 | - ) | ||
317 | - parser.add_argument( | ||
318 | - "--num_workers", default=4, type=int, help="kwarg passed to DataLoader" | ||
319 | - ) | ||
320 | - parser.add_argument( | ||
321 | - "--num_train_epochs", dest="max_epochs", default=3, type=int | ||
322 | - ) | ||
323 | - parser.add_argument("--train_batch_size", default=32, type=int) | ||
324 | - parser.add_argument("--eval_batch_size", default=32, type=int) | ||
325 | - parser.add_argument("--adafactor", action="store_true") | ||
326 | - | ||
327 | - | ||
328 | -class LoggingCallback(pl.Callback): | ||
329 | - def on_batch_end(self, trainer, pl_module): | ||
330 | - lr_scheduler = trainer.lr_schedulers[0]["scheduler"] | ||
331 | - lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} | ||
332 | - pl_module.logger.log_metrics(lrs) | ||
333 | - | ||
334 | - def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | ||
335 | - rank_zero_info("***** Validation results *****") | ||
336 | - metrics = trainer.callback_metrics | ||
337 | - # Log results | ||
338 | - for key in sorted(metrics): | ||
339 | - if key not in ["log", "progress_bar"]: | ||
340 | - rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) | ||
341 | - | ||
342 | - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): | ||
343 | - rank_zero_info("***** Test results *****") | ||
344 | - metrics = trainer.callback_metrics | ||
345 | - # Log and save results to file | ||
346 | - output_test_results_file = os.path.join( | ||
347 | - pl_module.hparams.output_dir, "test_results.txt" | ||
348 | - ) | ||
349 | - with open(output_test_results_file, "w") as writer: | ||
350 | - for key in sorted(metrics): | ||
351 | - if key not in ["log", "progress_bar"]: | ||
352 | - rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) | ||
353 | - writer.write("{} = {}\n".format(key, str(metrics[key]))) | ||
354 | - | ||
355 | - | ||
356 | -def add_generic_args(parser, root_dir) -> None: | ||
357 | - # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser) | ||
358 | - parser.add_argument( | ||
359 | - "--output_dir", | ||
360 | - default=None, | ||
361 | - type=str, | ||
362 | - required=True, | ||
363 | - help="The output directory where the model predictions and checkpoints will be written.", | ||
364 | - ) | ||
365 | - parser.add_argument( | ||
366 | - "--fp16", | ||
367 | - action="store_true", | ||
368 | - help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", | ||
369 | - ) | ||
370 | - | ||
371 | - parser.add_argument( | ||
372 | - "--fp16_opt_level", | ||
373 | - type=str, | ||
374 | - default="O2", | ||
375 | - help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." | ||
376 | - "See details at https://nvidia.github.io/apex/amp.html", | ||
377 | - ) | ||
378 | - parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) | ||
379 | - parser.add_argument( | ||
380 | - "--max_grad_norm", | ||
381 | - dest="gradient_clip_val", | ||
382 | - default=1.0, | ||
383 | - type=float, | ||
384 | - help="Max gradient norm", | ||
385 | - ) | ||
386 | - parser.add_argument( | ||
387 | - "--do_train", action="store_true", help="Whether to run training." | ||
388 | - ) | ||
389 | - parser.add_argument( | ||
390 | - "--do_predict", | ||
391 | - action="store_true", | ||
392 | - help="Whether to run predictions on the test set.", | ||
393 | - ) | ||
394 | - parser.add_argument( | ||
395 | - "--gradient_accumulation_steps", | ||
396 | - dest="accumulate_grad_batches", | ||
397 | - type=int, | ||
398 | - default=1, | ||
399 | - help="Number of updates steps to accumulate before performing a backward/update pass.", | ||
400 | - ) | ||
401 | - parser.add_argument( | ||
402 | - "--seed", type=int, default=42, help="random seed for initialization" | ||
403 | - ) | ||
404 | - | ||
405 | - | ||
406 | -def generic_train( | ||
407 | - model: BaseTransformer, | ||
408 | - args: argparse.Namespace, | ||
409 | - early_stopping_callback=False, | ||
410 | - logger=True, # can pass WandbLogger() here | ||
411 | - extra_callbacks=[], | ||
412 | - checkpoint_callback=None, | ||
413 | - logging_callback=None, | ||
414 | - **extra_train_kwargs, | ||
415 | -): | ||
416 | - pl.seed_everything(args.seed) | ||
417 | - | ||
418 | - # init model | ||
419 | - odir = Path(model.hparams.output_dir) | ||
420 | - odir.mkdir(exist_ok=True) | ||
421 | - | ||
422 | - # add custom checkpoints | ||
423 | - if checkpoint_callback is None: | ||
424 | - checkpoint_callback = pl.callbacks.ModelCheckpoint( | ||
425 | - filepath=args.output_dir, | ||
426 | - prefix="checkpoint", | ||
427 | - monitor="val_loss", | ||
428 | - mode="min", | ||
429 | - save_top_k=1, | ||
430 | - ) | ||
431 | - if logging_callback is None: | ||
432 | - logging_callback = LoggingCallback() | ||
433 | - | ||
434 | - train_params = {} | ||
435 | - | ||
436 | - # TODO: remove with PyTorch 1.6 since pl uses native amp | ||
437 | - if args.fp16: | ||
438 | - train_params["precision"] = 16 | ||
439 | - train_params["amp_level"] = args.fp16_opt_level | ||
440 | - | ||
441 | - if args.gpus > 1: | ||
442 | - train_params["distributed_backend"] = "ddp" | ||
443 | - | ||
444 | - trainer = pl.Trainer.from_argparse_args( | ||
445 | - args, | ||
446 | - weights_summary=None, | ||
447 | - callbacks=[logging_callback] + extra_callbacks, | ||
448 | - logger=logger, | ||
449 | - checkpoint_callback=checkpoint_callback, | ||
450 | - early_stop_callback=early_stopping_callback, | ||
451 | - **train_params, | ||
452 | - ) | ||
453 | - | ||
454 | - if args.do_train: | ||
455 | - trainer.fit(model) | ||
456 | - | ||
457 | - return trainer |
train/modeling_bart.py
deleted
100644 → 0
1 | -# coding=utf-8 | ||
2 | -# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. | ||
3 | -# | ||
4 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | -# you may not use this file except in compliance with the License. | ||
6 | -# You may obtain a copy of the License at | ||
7 | -# | ||
8 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | -# | ||
10 | -# Unless required by applicable law or agreed to in writing, software | ||
11 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | -# See the License for the specific language governing permissions and | ||
14 | -# limitations under the License. | ||
15 | -"""PyTorch BART model, ported from the fairseq repo.""" | ||
16 | -import math | ||
17 | -import random | ||
18 | -import warnings | ||
19 | -from typing import Dict, List, Optional, Tuple | ||
20 | - | ||
21 | -import numpy as np | ||
22 | -import torch | ||
23 | -import torch.nn.functional as F | ||
24 | -from torch import Tensor, nn | ||
25 | -from torch.nn import CrossEntropyLoss | ||
26 | - | ||
27 | -from transformers.activations import ACT2FN | ||
28 | -from transformers.configuration_bart import BartConfig | ||
29 | -from transformers.file_utils import ( | ||
30 | - add_code_sample_docstrings, | ||
31 | - add_end_docstrings, | ||
32 | - add_start_docstrings, | ||
33 | - add_start_docstrings_to_callable, | ||
34 | - replace_return_docstrings, | ||
35 | -) | ||
36 | -from transformers.modeling_outputs import ( | ||
37 | - BaseModelOutput, | ||
38 | - BaseModelOutputWithPast, | ||
39 | - Seq2SeqLMOutput, | ||
40 | - Seq2SeqModelOutput, | ||
41 | - Seq2SeqQuestionAnsweringModelOutput, | ||
42 | - Seq2SeqSequenceClassifierOutput, | ||
43 | -) | ||
44 | -from train.modeling_utils import PreTrainedModel | ||
45 | -import logging | ||
46 | - | ||
47 | -logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
48 | -logging.basicConfig( | ||
49 | - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
50 | - datefmt="%m/%d/%Y %H:%M:%S", | ||
51 | - level=logging.INFO, | ||
52 | -) | ||
53 | - | ||
54 | -_CONFIG_FOR_DOC = "BartConfig" | ||
55 | -_TOKENIZER_FOR_DOC = "BartTokenizer" | ||
56 | - | ||
57 | - | ||
58 | -BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ | ||
59 | - "facebook/bart-base", | ||
60 | - "facebook/bart-large", | ||
61 | - "facebook/bart-large-mnli", | ||
62 | - "facebook/bart-large-cnn", | ||
63 | - "facebook/bart-large-xsum", | ||
64 | - "facebook/mbart-large-en-ro", | ||
65 | -] | ||
66 | -# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart | ||
67 | - | ||
68 | - | ||
69 | -BART_START_DOCSTRING = r""" | ||
70 | - | ||
71 | - This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and | ||
72 | - refer to the PyTorch documentation for all matters related to general usage and behavior. | ||
73 | - | ||
74 | - Parameters: | ||
75 | - config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model. | ||
76 | - Initializing with a config file does not load the weights associated with the model, only the configuration. | ||
77 | - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. | ||
78 | - | ||
79 | -""" | ||
80 | -BART_GENERATION_EXAMPLE = r""" | ||
81 | - Summarization example:: | ||
82 | - | ||
83 | - from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig | ||
84 | - | ||
85 | - # see ``examples/summarization/bart/run_eval.py`` for a longer example | ||
86 | - model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn') | ||
87 | - tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn') | ||
88 | - | ||
89 | - ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." | ||
90 | - inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') | ||
91 | - | ||
92 | - # Generate Summary | ||
93 | - summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) | ||
94 | - print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) | ||
95 | - | ||
96 | -""" | ||
97 | - | ||
98 | -BART_INPUTS_DOCSTRING = r""" | ||
99 | - Args: | ||
100 | - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): | ||
101 | - Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them. | ||
102 | - Padding will be ignored by default should you provide it. | ||
103 | - Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`. | ||
104 | - attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
105 | - Mask to avoid performing attention on padding token indices in input_ids. | ||
106 | - Mask values selected in ``[0, 1]``: | ||
107 | - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. | ||
108 | - encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`): | ||
109 | - Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`) | ||
110 | - `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder. | ||
111 | - Used in the cross-attention of the decoder. | ||
112 | - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): | ||
113 | - Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper. | ||
114 | - decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): | ||
115 | - Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. | ||
116 | - If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. | ||
117 | - See diagram 1 in the paper for more info on the default strategy | ||
118 | - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | ||
119 | - Contains pre-computed key and value hidden-states of the attention blocks. | ||
120 | - Can be used to speed up decoding. | ||
121 | - If ``past_key_values`` are used, the user can optionally input only the last | ||
122 | - ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape | ||
123 | - :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. | ||
124 | - use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
125 | - If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see | ||
126 | - ``past_key_values``). | ||
127 | - output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
128 | - If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. | ||
129 | - output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
130 | - If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. | ||
131 | - return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): | ||
132 | - If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a | ||
133 | - plain tuple. | ||
134 | -""" | ||
135 | - | ||
136 | - | ||
137 | -def invert_mask(attention_mask): | ||
138 | - """Turns 1->0, 0->1, False->True, True-> False""" | ||
139 | - assert attention_mask.dim() == 2 | ||
140 | - return attention_mask.eq(0) | ||
141 | - | ||
142 | - | ||
143 | -def _prepare_bart_decoder_inputs( | ||
144 | - config, | ||
145 | - input_ids, | ||
146 | - decoder_input_ids=None, | ||
147 | - decoder_padding_mask=None, | ||
148 | - causal_mask_dtype=torch.float32, | ||
149 | -): | ||
150 | - """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if | ||
151 | - none are provided. This mimics the default behavior in fairseq. To override it pass in masks. | ||
152 | - Note: this is not called during generation | ||
153 | - """ | ||
154 | - pad_token_id = config.pad_token_id | ||
155 | - if decoder_input_ids is None: | ||
156 | - decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) | ||
157 | - bsz, tgt_len = decoder_input_ids.size() | ||
158 | - if decoder_padding_mask is None: | ||
159 | - decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) | ||
160 | - else: | ||
161 | - decoder_padding_mask = invert_mask(decoder_padding_mask) | ||
162 | - if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1: | ||
163 | - # never mask leading token, even if it is pad | ||
164 | - decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1] | ||
165 | - causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( | ||
166 | - dtype=causal_mask_dtype, device=decoder_input_ids.device | ||
167 | - ) | ||
168 | - return decoder_input_ids, decoder_padding_mask, causal_mask | ||
169 | - | ||
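The causal mask built at the end of `_prepare_bart_decoder_inputs` is the usual upper-triangular mask; an equivalent toy construction for tgt_len = 4 (for illustration, not the exact helper used above):

    import torch

    tgt_len = 4
    causal_mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
    # row i has 0.0 at columns 0..i and -inf above the diagonal, so position i can only attend
    # to positions <= i
    print(causal_mask)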
170 | - | ||
171 | -class PretrainedBartModel(PreTrainedModel): | ||
172 | - config_class = BartConfig | ||
173 | - base_model_prefix = "model" | ||
174 | - | ||
175 | - def _init_weights(self, module): | ||
176 | - std = self.config.init_std | ||
177 | - if isinstance(module, nn.Linear): | ||
178 | - module.weight.data.normal_(mean=0.0, std=std) | ||
179 | - if module.bias is not None: | ||
180 | - module.bias.data.zero_() | ||
181 | - elif isinstance(module, SinusoidalPositionalEmbedding): | ||
182 | - pass | ||
183 | - elif isinstance(module, nn.Embedding): | ||
184 | - module.weight.data.normal_(mean=0.0, std=std) | ||
185 | - if module.padding_idx is not None: | ||
186 | - module.weight.data[module.padding_idx].zero_() | ||
187 | - | ||
188 | - @property | ||
189 | - def dummy_inputs(self): | ||
190 | - pad_token = self.config.pad_token_id | ||
191 | - input_ids = torch.tensor( | ||
192 | - [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device | ||
193 | - ) | ||
194 | - dummy_inputs = { | ||
195 | - "attention_mask": input_ids.ne(pad_token), | ||
196 | - "input_ids": input_ids, | ||
197 | - } | ||
198 | - return dummy_inputs | ||
199 | - | ||
200 | - | ||
201 | -def _make_linear_from_emb(emb): | ||
202 | - vocab_size, emb_size = emb.weight.shape | ||
203 | - lin_layer = nn.Linear(vocab_size, emb_size, bias=False) | ||
204 | - lin_layer.weight.data = emb.weight.data | ||
205 | - return lin_layer | ||
206 | - | ||
207 | - | ||
208 | -# Helper Functions, mostly for making masks | ||
209 | -def _check_shapes(shape_1, shape2): | ||
210 | - if shape_1 != shape2: | ||
211 | - raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) | ||
212 | - | ||
213 | - | ||
214 | -def shift_tokens_right(input_ids, pad_token_id): | ||
215 | - """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).""" | ||
216 | - prev_output_tokens = input_ids.clone() | ||
217 | - index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) | ||
218 | - prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() | ||
219 | - prev_output_tokens[:, 1:] = input_ids[:, :-1] | ||
220 | - return prev_output_tokens | ||
221 | - | ||
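A small demo of `shift_tokens_right` with toy ids (assuming the BART convention of pad_token_id = 1 and eos = 2):

    import torch

    input_ids = torch.tensor([[0, 8, 12, 2, 1]])  # ends with eos (2) followed by one pad (1)
    print(shift_tokens_right(input_ids, pad_token_id=1))
    # tensor([[ 2,  0,  8, 12,  2]]) -> the final eos is wrapped around to position 0 and the rest
    # shifts one step to the right, which is the decoder input used for teacher forcing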
222 | - | ||
223 | -def make_padding_mask(input_ids, padding_idx=1): | ||
224 | - """True for pad tokens""" | ||
225 | - padding_mask = input_ids.eq(padding_idx) | ||
226 | - if not padding_mask.any(): | ||
227 | - padding_mask = None | ||
228 | - return padding_mask | ||
229 | - | ||
230 | - | ||
231 | -# Helper Modules | ||
232 | - | ||
233 | - | ||
234 | -class EncoderLayer(nn.Module): | ||
235 | - def __init__(self, config: BartConfig): | ||
236 | - super().__init__() | ||
237 | - self.embed_dim = config.d_model | ||
238 | - self.self_attn = Attention( | ||
239 | - self.embed_dim, | ||
240 | - config.encoder_attention_heads, | ||
241 | - dropout=config.attention_dropout, | ||
242 | - ) | ||
243 | - self.normalize_before = config.normalize_before | ||
244 | - self.self_attn_layer_norm = LayerNorm(self.embed_dim) | ||
245 | - self.dropout = config.dropout | ||
246 | - self.activation_fn = ACT2FN[config.activation_function] | ||
247 | - self.activation_dropout = config.activation_dropout | ||
248 | - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) | ||
249 | - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) | ||
250 | - self.final_layer_norm = LayerNorm(self.embed_dim) | ||
251 | - | ||
252 | - def forward(self, x, encoder_padding_mask, output_attentions=False): | ||
253 | - """ | ||
254 | - Args: | ||
255 | - x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` | ||
256 | - encoder_padding_mask (ByteTensor): binary ByteTensor of shape | ||
257 | - `(batch, src_len)` where padding elements are indicated by ``1``. | ||
258 | - A position marked ``1`` is excluded from (masked out of) attention, | ||
259 | - while ``0`` means the position is included in attention. | ||
260 | - | ||
261 | - Returns: | ||
262 | - encoded output of shape `(seq_len, batch, embed_dim)` | ||
263 | - """ | ||
264 | - residual = x | ||
265 | - if self.normalize_before: | ||
266 | - x = self.self_attn_layer_norm(x) | ||
267 | - x, attn_weights = self.self_attn( | ||
268 | - query=x, | ||
269 | - key=x, | ||
270 | - key_padding_mask=encoder_padding_mask, | ||
271 | - output_attentions=output_attentions, | ||
272 | - ) | ||
273 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
274 | - x = residual + x | ||
275 | - if not self.normalize_before: | ||
276 | - x = self.self_attn_layer_norm(x) | ||
277 | - | ||
278 | - residual = x | ||
279 | - if self.normalize_before: | ||
280 | - x = self.final_layer_norm(x) | ||
281 | - x = self.activation_fn(self.fc1(x)) | ||
282 | - x = F.dropout(x, p=self.activation_dropout, training=self.training) | ||
283 | - x = self.fc2(x) | ||
284 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
285 | - x = residual + x | ||
286 | - if not self.normalize_before: | ||
287 | - x = self.final_layer_norm(x) | ||
288 | - return x, attn_weights | ||
289 | - | ||
290 | - | ||
291 | -class BartEncoder(nn.Module): | ||
292 | - """ | ||
293 | - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer | ||
294 | - is a :class:`EncoderLayer`. | ||
295 | - | ||
296 | - Args: | ||
297 | - config: BartConfig | ||
298 | - """ | ||
299 | - | ||
300 | - def __init__(self, config: BartConfig, embed_tokens): | ||
301 | - super().__init__() | ||
302 | - | ||
303 | - self.dropout = config.dropout | ||
304 | - self.layerdrop = config.encoder_layerdrop | ||
305 | - | ||
306 | - embed_dim = embed_tokens.embedding_dim | ||
307 | - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 | ||
308 | - self.padding_idx = embed_tokens.padding_idx | ||
309 | - self.max_source_positions = config.max_position_embeddings | ||
310 | - | ||
311 | - self.embed_tokens = embed_tokens | ||
312 | - if config.static_position_embeddings: | ||
313 | - self.embed_positions = SinusoidalPositionalEmbedding( | ||
314 | - config.max_position_embeddings, embed_dim, self.padding_idx | ||
315 | - ) | ||
316 | - else: | ||
317 | - self.embed_positions = LearnedPositionalEmbedding( | ||
318 | - config.max_position_embeddings, | ||
319 | - embed_dim, | ||
320 | - self.padding_idx, | ||
321 | - config.extra_pos_embeddings, | ||
322 | - ) | ||
323 | - self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0) | ||
324 | - self.layers = nn.ModuleList( | ||
325 | - [EncoderLayer(config) for _ in range(config.encoder_layers)] | ||
326 | - ) | ||
327 | - self.layernorm_embedding = ( | ||
328 | - LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() | ||
329 | - ) | ||
330 | - # mbart has one extra layer_norm | ||
331 | - self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None | ||
332 | - | ||
333 | - def forward( | ||
334 | - self, | ||
335 | - input_ids, | ||
336 | - patch_ids, | ||
337 | - attention_mask=None, | ||
338 | - output_attentions=False, | ||
339 | - output_hidden_states=False, | ||
340 | - return_dict=False, | ||
341 | - ): | ||
342 | - """ | ||
343 | - Args: | ||
344 | - input_ids (LongTensor): tokens in the source language of shape | ||
345 | - `(batch, src_len)` | ||
346 | - attention_mask (torch.LongTensor): indicating which indices are padding tokens. | ||
347 | - Returns: | ||
348 | - BaseModelOutput or Tuple comprised of: | ||
349 | - - **x** (Tensor): the last encoder layer's output of | ||
350 | - shape `(src_len, batch, embed_dim)` | ||
351 | - - **encoder_states** (tuple(torch.FloatTensor)): all intermediate | ||
352 | - hidden states of shape `(src_len, batch, embed_dim)`. | ||
353 | - Only populated if *output_hidden_states* is True. | ||
354 | - - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer. | ||
355 | - During training might not be of length n_layers because of layer dropout. | ||
356 | - """ | ||
357 | - # check attention mask and invert | ||
358 | - if attention_mask is not None: | ||
359 | - attention_mask = invert_mask(attention_mask) | ||
360 | - | ||
361 | - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale | ||
362 | - embed_pos = self.embed_positions(input_ids) | ||
363 | - embed_patch = self.embed_patches(patch_ids) | ||
364 | - x = inputs_embeds + embed_pos + embed_patch | ||
365 | - x = self.layernorm_embedding(x) | ||
366 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
367 | - | ||
368 | - # B x T x C -> T x B x C | ||
369 | - x = x.transpose(0, 1) | ||
370 | - | ||
371 | - encoder_states = [] if output_hidden_states else None | ||
372 | - all_attentions = () if output_attentions else None | ||
373 | - for encoder_layer in self.layers: | ||
374 | - if output_hidden_states: | ||
375 | - encoder_states.append(x) | ||
376 | - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | ||
377 | - dropout_probability = random.uniform(0, 1) | ||
378 | - if self.training and ( | ||
379 | - dropout_probability < self.layerdrop | ||
380 | - ): # skip the layer | ||
381 | - attn = None | ||
382 | - else: | ||
383 | - x, attn = encoder_layer( | ||
384 | - x, attention_mask, output_attentions=output_attentions | ||
385 | - ) | ||
386 | - | ||
387 | - if output_attentions: | ||
388 | - all_attentions = all_attentions + (attn,) | ||
389 | - | ||
390 | - if self.layer_norm: | ||
391 | - x = self.layer_norm(x) | ||
392 | - if output_hidden_states: | ||
393 | - encoder_states.append(x) | ||
394 | - # T x B x C -> B x T x C | ||
395 | - encoder_states = tuple( | ||
396 | - hidden_state.transpose(0, 1) for hidden_state in encoder_states | ||
397 | - ) | ||
398 | - | ||
399 | - # T x B x C -> B x T x C | ||
400 | - x = x.transpose(0, 1) | ||
401 | - | ||
402 | - if not return_dict: | ||
403 | - return tuple( | ||
404 | - v for v in [x, encoder_states, all_attentions] if v is not None | ||
405 | - ) | ||
406 | - return BaseModelOutput( | ||
407 | - last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions | ||
408 | - ) | ||
409 | - | ||
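The encoder above departs from vanilla BART by summing a third embedding table, embed_patches (three entries, padding index 0), with the token and positional embeddings before the usual layer norm and dropout. Below is a minimal sketch of that addition with toy sizes; it assumes the three patch ids distinguish padding from two kinds of diff lines (the exact semantics come from the preprocessing code, not shown here), and it omits the embed_scale factor for brevity.

import torch
import torch.nn as nn

d_model, vocab_size, pad_idx = 16, 100, 1
embed_tokens = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
embed_positions = nn.Embedding(32, d_model)      # stand-in for the learned positional table
embed_patches = nn.Embedding(3, d_model, padding_idx=0)

input_ids = torch.tensor([[5, 8, 9, pad_idx]])   # (batch, src_len)
patch_ids = torch.tensor([[1, 2, 2, 0]])         # hypothetical per-token patch labels
positions = torch.arange(input_ids.size(1)).unsqueeze(0)

# As in BartEncoder.forward: token + positional + patch embeddings are summed,
# then layer-normalized and dropped out before the encoder layers.
x = embed_tokens(input_ids) + embed_positions(positions) + embed_patches(patch_ids)
print(x.shape)  # torch.Size([1, 4, 16])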
410 | - | ||
411 | -class DecoderLayer(nn.Module): | ||
412 | - def __init__(self, config: BartConfig): | ||
413 | - super().__init__() | ||
414 | - self.embed_dim = config.d_model | ||
415 | - | ||
416 | - self.self_attn = Attention( | ||
417 | - embed_dim=self.embed_dim, | ||
418 | - num_heads=config.decoder_attention_heads, | ||
419 | - dropout=config.attention_dropout, | ||
420 | - ) | ||
421 | - self.dropout = config.dropout | ||
422 | - self.activation_fn = ACT2FN[config.activation_function] | ||
423 | - self.activation_dropout = config.activation_dropout | ||
424 | - self.normalize_before = config.normalize_before | ||
425 | - | ||
426 | - self.self_attn_layer_norm = LayerNorm(self.embed_dim) | ||
427 | - self.encoder_attn = Attention( | ||
428 | - self.embed_dim, | ||
429 | - config.decoder_attention_heads, | ||
430 | - dropout=config.attention_dropout, | ||
431 | - encoder_decoder_attention=True, | ||
432 | - ) | ||
433 | - self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) | ||
434 | - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) | ||
435 | - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) | ||
436 | - self.final_layer_norm = LayerNorm(self.embed_dim) | ||
437 | - | ||
438 | - def forward( | ||
439 | - self, | ||
440 | - x, | ||
441 | - encoder_hidden_states, | ||
442 | - encoder_attn_mask=None, | ||
443 | - layer_state=None, | ||
444 | - causal_mask=None, | ||
445 | - decoder_padding_mask=None, | ||
446 | - output_attentions=False, | ||
447 | - ): | ||
448 | - residual = x | ||
449 | - | ||
450 | - if layer_state is None: | ||
451 | - layer_state = {} | ||
452 | - if self.normalize_before: | ||
453 | - x = self.self_attn_layer_norm(x) | ||
454 | - # Self Attention | ||
455 | - | ||
456 | - x, self_attn_weights = self.self_attn( | ||
457 | - query=x, | ||
458 | - key=x, | ||
459 | - layer_state=layer_state, # adds keys to layer state | ||
460 | - key_padding_mask=decoder_padding_mask, | ||
461 | - attn_mask=causal_mask, | ||
462 | - output_attentions=output_attentions, | ||
463 | - ) | ||
464 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
465 | - x = residual + x | ||
466 | - if not self.normalize_before: | ||
467 | - x = self.self_attn_layer_norm(x) | ||
468 | - | ||
469 | - # Cross attention | ||
470 | - residual = x | ||
471 | - assert self.encoder_attn.cache_key != self.self_attn.cache_key | ||
472 | - if self.normalize_before: | ||
473 | - x = self.encoder_attn_layer_norm(x) | ||
474 | - x, _ = self.encoder_attn( | ||
475 | - query=x, | ||
476 | - key=encoder_hidden_states, | ||
477 | - key_padding_mask=encoder_attn_mask, | ||
478 | - layer_state=layer_state, # mutates layer state | ||
479 | - ) | ||
480 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
481 | - x = residual + x | ||
482 | - if not self.normalize_before: | ||
483 | - x = self.encoder_attn_layer_norm(x) | ||
484 | - | ||
485 | - # Fully Connected | ||
486 | - residual = x | ||
487 | - if self.normalize_before: | ||
488 | - x = self.final_layer_norm(x) | ||
489 | - x = self.activation_fn(self.fc1(x)) | ||
490 | - x = F.dropout(x, p=self.activation_dropout, training=self.training) | ||
491 | - x = self.fc2(x) | ||
492 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
493 | - x = residual + x | ||
494 | - if not self.normalize_before: | ||
495 | - x = self.final_layer_norm(x) | ||
496 | - return ( | ||
497 | - x, | ||
498 | - self_attn_weights, | ||
499 | - layer_state, | ||
500 | - ) # just self_attn weights for now, following t5, layer_state = cache for decoding | ||
501 | - | ||
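Each sub-block of DecoderLayer (self-attention, cross-attention, feed-forward) follows the same residual pattern, with the layer norm applied before the sub-layer when config.normalize_before is set (pre-norm, as in mBART) and after it otherwise (post-norm, as in BART). A small sketch of that pattern, with dropout omitted for brevity:

def residual_block(x, sublayer, norm, normalize_before):
    # Pre-norm: normalize, run the sub-layer, then add the residual.
    # Post-norm: run the sub-layer, add the residual, then normalize.
    residual = x
    if normalize_before:
        x = norm(x)
    x = sublayer(x)
    x = residual + x
    if not normalize_before:
        x = norm(x)
    return x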
502 | - | ||
503 | -class BartDecoder(nn.Module): | ||
504 | - """ | ||
505 | - Transformer decoder consisting of *config.decoder_layers* layers. Each layer | ||
506 | - is a :class:`DecoderLayer`. | ||
507 | - Args: | ||
508 | - config: BartConfig | ||
509 | - embed_tokens (torch.nn.Embedding): output embedding | ||
510 | - """ | ||
511 | - | ||
512 | - def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): | ||
513 | - super().__init__() | ||
514 | - self.dropout = config.dropout | ||
515 | - self.layerdrop = config.decoder_layerdrop | ||
516 | - self.padding_idx = embed_tokens.padding_idx | ||
517 | - self.max_target_positions = config.max_position_embeddings | ||
518 | - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 | ||
519 | - self.embed_tokens = embed_tokens | ||
520 | - if config.static_position_embeddings: | ||
521 | - self.embed_positions = SinusoidalPositionalEmbedding( | ||
522 | - config.max_position_embeddings, config.d_model, config.pad_token_id | ||
523 | - ) | ||
524 | - else: | ||
525 | - self.embed_positions = LearnedPositionalEmbedding( | ||
526 | - config.max_position_embeddings, | ||
527 | - config.d_model, | ||
528 | - self.padding_idx, | ||
529 | - config.extra_pos_embeddings, | ||
530 | - ) | ||
531 | - self.layers = nn.ModuleList( | ||
532 | - [DecoderLayer(config) for _ in range(config.decoder_layers)] | ||
533 | - ) # type: List[DecoderLayer] | ||
534 | - self.layernorm_embedding = ( | ||
535 | - LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity() | ||
536 | - ) | ||
537 | - self.layer_norm = ( | ||
538 | - LayerNorm(config.d_model) if config.add_final_layer_norm else None | ||
539 | - ) | ||
540 | - | ||
541 | - def forward( | ||
542 | - self, | ||
543 | - input_ids, | ||
544 | - encoder_hidden_states, | ||
545 | - encoder_padding_mask, | ||
546 | - decoder_padding_mask, | ||
547 | - decoder_causal_mask, | ||
548 | - past_key_values=None, | ||
549 | - use_cache=False, | ||
550 | - output_attentions=False, | ||
551 | - output_hidden_states=False, | ||
552 | - return_dict=False, | ||
553 | - **unused, | ||
554 | - ): | ||
555 | - """ | ||
556 | - Includes several features from "Jointly Learning to Align and | ||
557 | - Translate with Transformer Models" (Garg et al., EMNLP 2019). | ||
558 | - | ||
559 | - Args: | ||
560 | - input_ids (LongTensor): previous decoder outputs of shape | ||
561 | - `(batch, tgt_len)`, for teacher forcing | ||
562 | - encoder_hidden_states: output from the encoder, used for | ||
563 | - encoder-side attention | ||
564 | - encoder_padding_mask: for ignoring pad tokens | ||
565 | - past_key_values (dict or None): dictionary used for storing state during generation | ||
566 | - | ||
567 | - Returns: | ||
568 | - BaseModelOutputWithPast or tuple: | ||
569 | - - the decoder's features of shape `(batch, tgt_len, embed_dim)` | ||
570 | - - the cache | ||
571 | - - hidden states | ||
572 | - - attentions | ||
573 | - """ | ||
574 | - if "decoder_cached_states" in unused: | ||
575 | - warnings.warn( | ||
576 | - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
577 | - FutureWarning, | ||
578 | - ) | ||
579 | - past_key_values = unused.pop("decoder_cached_states") | ||
580 | - if "decoder_past_key_values" in unused: | ||
581 | - warnings.warn( | ||
582 | - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
583 | - FutureWarning, | ||
584 | - ) | ||
585 | - past_key_values = unused.pop("decoder_past_key_values") | ||
586 | - | ||
587 | - # check attention mask and invert | ||
588 | - if encoder_padding_mask is not None: | ||
589 | - encoder_padding_mask = invert_mask(encoder_padding_mask) | ||
590 | - | ||
591 | - # embed positions | ||
592 | - positions = self.embed_positions(input_ids, use_cache=use_cache) | ||
593 | - | ||
594 | - if use_cache: | ||
595 | - input_ids = input_ids[:, -1:] | ||
596 | - positions = positions[:, -1:] # happens after we embed them | ||
597 | - # assert input_ids.ne(self.padding_idx).any() | ||
598 | - | ||
599 | - x = self.embed_tokens(input_ids) * self.embed_scale | ||
600 | - x += positions | ||
601 | - x = self.layernorm_embedding(x) | ||
602 | - x = F.dropout(x, p=self.dropout, training=self.training) | ||
603 | - | ||
604 | - # Convert from batch-first to the layers' internal format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim) | ||
605 | - x = x.transpose(0, 1) | ||
606 | - encoder_hidden_states = encoder_hidden_states.transpose(0, 1) | ||
607 | - | ||
608 | - # decoder layers | ||
609 | - all_hidden_states = () if output_hidden_states else None | ||
610 | - all_self_attns = () if output_attentions else None | ||
611 | - next_decoder_cache = [] | ||
612 | - for idx, decoder_layer in enumerate(self.layers): | ||
613 | - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | ||
614 | - if output_hidden_states: | ||
615 | - all_hidden_states += (x,) | ||
616 | - dropout_probability = random.uniform(0, 1) | ||
617 | - if self.training and (dropout_probability < self.layerdrop): | ||
618 | - continue | ||
619 | - | ||
620 | - layer_state = past_key_values[idx] if past_key_values is not None else None | ||
621 | - | ||
622 | - x, layer_self_attn, layer_past = decoder_layer( | ||
623 | - x, | ||
624 | - encoder_hidden_states, | ||
625 | - encoder_attn_mask=encoder_padding_mask, | ||
626 | - decoder_padding_mask=decoder_padding_mask, | ||
627 | - layer_state=layer_state, | ||
628 | - causal_mask=decoder_causal_mask, | ||
629 | - output_attentions=output_attentions, | ||
630 | - ) | ||
631 | - | ||
632 | - if use_cache: | ||
633 | - next_decoder_cache.append(layer_past.copy()) | ||
634 | - | ||
635 | - if self.layer_norm and ( | ||
636 | - idx == len(self.layers) - 1 | ||
637 | - ): # if config.add_final_layer_norm (mBART) | ||
638 | - x = self.layer_norm(x) | ||
639 | - if output_attentions: | ||
640 | - all_self_attns += (layer_self_attn,) | ||
641 | - | ||
642 | - # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) | ||
643 | - if output_hidden_states: | ||
644 | - all_hidden_states = tuple( | ||
645 | - hidden_state.transpose(0, 1) for hidden_state in all_hidden_states | ||
646 | - ) | ||
647 | - x = x.transpose(0, 1) | ||
648 | - encoder_hidden_states = encoder_hidden_states.transpose(0, 1) | ||
649 | - | ||
650 | - next_cache = next_decoder_cache if use_cache else None | ||
651 | - | ||
652 | - if not return_dict: | ||
653 | - return tuple( | ||
654 | - v | ||
655 | - for v in [x, next_cache, all_hidden_states, all_self_attns] | ||
656 | - if v is not None | ||
657 | - ) | ||
658 | - return BaseModelOutputWithPast( | ||
659 | - last_hidden_state=x, | ||
660 | - past_key_values=next_cache, | ||
661 | - hidden_states=all_hidden_states, | ||
662 | - attentions=all_self_attns, | ||
663 | - ) | ||
664 | - | ||
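Both BartEncoder and BartDecoder above use LayerDrop (Fan et al., https://arxiv.org/abs/1909.11556): at training time each layer is skipped with probability layerdrop, while at inference every layer runs. A stripped-down sketch of that control flow, assuming layers is any list of callables:

import random

def run_with_layerdrop(x, layers, layerdrop, training):
    for layer in layers:
        # Skip the whole layer with probability `layerdrop`, but only while training.
        if training and random.uniform(0, 1) < layerdrop:
            continue
        x = layer(x)
    return x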
665 | - | ||
666 | -def _reorder_buffer(attn_cache, new_order): | ||
667 | - for k, input_buffer_k in attn_cache.items(): | ||
668 | - if input_buffer_k is not None: | ||
669 | - attn_cache[k] = input_buffer_k.index_select(0, new_order) | ||
670 | - return attn_cache | ||
671 | - | ||
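_reorder_buffer is what beam search relies on to keep the decoding cache aligned with the beams: whenever beams are reshuffled, every cached tensor is re-indexed along its batch dimension with the new beam order (see _reorder_cache in BartForConditionalGeneration further down). A small sketch with toy shapes, assuming the function as defined above:

import torch

attn_cache = {
    "prev_key": torch.arange(6.0).view(3, 1, 2, 1),    # (beams, heads, seq_len, head_dim)
    "prev_value": torch.arange(6.0).view(3, 1, 2, 1),
    "prev_key_padding_mask": None,                      # None entries are left untouched
}
new_order = torch.tensor([2, 0, 1])                     # beam 2 becomes the new beam 0, etc.
reordered = _reorder_buffer(attn_cache, new_order)
print(reordered["prev_key"].flatten())                  # tensor([4., 5., 0., 1., 2., 3.])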
672 | - | ||
673 | -class Attention(nn.Module): | ||
674 | - """Multi-headed attention from 'Attention Is All You Need' paper""" | ||
675 | - | ||
676 | - def __init__( | ||
677 | - self, | ||
678 | - embed_dim, | ||
679 | - num_heads, | ||
680 | - dropout=0.0, | ||
681 | - bias=True, | ||
682 | - encoder_decoder_attention=False, # otherwise self_attention | ||
683 | - ): | ||
684 | - super().__init__() | ||
685 | - self.embed_dim = embed_dim | ||
686 | - self.num_heads = num_heads | ||
687 | - self.dropout = dropout | ||
688 | - self.head_dim = embed_dim // num_heads | ||
689 | - assert ( | ||
690 | - self.head_dim * num_heads == self.embed_dim | ||
691 | - ), "embed_dim must be divisible by num_heads" | ||
692 | - self.scaling = self.head_dim ** -0.5 | ||
693 | - | ||
694 | - self.encoder_decoder_attention = encoder_decoder_attention | ||
695 | - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
696 | - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
697 | - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
698 | - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | ||
699 | - self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" | ||
700 | - | ||
701 | - def _shape(self, tensor, seq_len, bsz): | ||
702 | - return ( | ||
703 | - tensor.contiguous() | ||
704 | - .view(seq_len, bsz * self.num_heads, self.head_dim) | ||
705 | - .transpose(0, 1) | ||
706 | - ) | ||
707 | - | ||
708 | - def forward( | ||
709 | - self, | ||
710 | - query, | ||
711 | - key: Optional[Tensor], | ||
712 | - key_padding_mask: Optional[Tensor] = None, | ||
713 | - layer_state: Optional[Dict[str, Optional[Tensor]]] = None, | ||
714 | - attn_mask: Optional[Tensor] = None, | ||
715 | - output_attentions=False, | ||
716 | - ) -> Tuple[Tensor, Optional[Tensor]]: | ||
717 | - """Input shape: Time(SeqLen) x Batch x Channel""" | ||
718 | - static_kv: bool = self.encoder_decoder_attention | ||
719 | - tgt_len, bsz, embed_dim = query.size() | ||
720 | - assert embed_dim == self.embed_dim | ||
721 | - assert list(query.size()) == [tgt_len, bsz, embed_dim] | ||
722 | - # get here for encoder decoder cause of static_kv | ||
723 | - if layer_state is not None: # reuse k,v and encoder_padding_mask | ||
724 | - saved_state = layer_state.get(self.cache_key, {}) | ||
725 | - if "prev_key" in saved_state and static_kv: | ||
726 | - # previous time steps are cached - no need to recompute key and value if they are static | ||
727 | - key = None | ||
728 | - else: | ||
729 | - saved_state = None | ||
730 | - layer_state = {} | ||
731 | - | ||
732 | - q = self.q_proj(query) * self.scaling | ||
733 | - if static_kv: | ||
734 | - if key is None: | ||
735 | - k = v = None | ||
736 | - else: | ||
737 | - k = self.k_proj(key) | ||
738 | - v = self.v_proj(key) | ||
739 | - else: | ||
740 | - k = self.k_proj(query) | ||
741 | - v = self.v_proj(query) | ||
742 | - | ||
743 | - q = self._shape(q, tgt_len, bsz) | ||
744 | - if k is not None: | ||
745 | - k = self._shape(k, -1, bsz) | ||
746 | - if v is not None: | ||
747 | - v = self._shape(v, -1, bsz) | ||
748 | - | ||
749 | - if saved_state is not None: | ||
750 | - k, v, key_padding_mask = self._use_saved_state( | ||
751 | - k, v, saved_state, key_padding_mask, static_kv, bsz | ||
752 | - ) | ||
753 | - | ||
754 | - # Update cache | ||
755 | - layer_state[self.cache_key] = { | ||
756 | - "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), | ||
757 | - "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), | ||
758 | - "prev_key_padding_mask": key_padding_mask if not static_kv else None, | ||
759 | - } | ||
760 | - | ||
761 | - assert k is not None | ||
762 | - src_len = k.size(1) | ||
763 | - attn_weights = torch.bmm(q, k.transpose(1, 2)) | ||
764 | - assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) | ||
765 | - | ||
766 | - if attn_mask is not None: | ||
767 | - attn_weights = ( | ||
768 | - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask | ||
769 | - ) | ||
770 | - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | ||
771 | - | ||
772 | - # This is part of a workaround to get around fork/join parallelism not supporting Optional types. | ||
773 | - if key_padding_mask is not None and key_padding_mask.dim() == 0: | ||
774 | - key_padding_mask = None | ||
775 | - assert key_padding_mask is None or key_padding_mask.size()[:2] == ( | ||
776 | - bsz, | ||
777 | - src_len, | ||
778 | - ) | ||
779 | - | ||
780 | - if key_padding_mask is not None: # don't attend to padding symbols | ||
781 | - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | ||
782 | - reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2) | ||
783 | - attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) | ||
784 | - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | ||
785 | - attn_weights = F.softmax(attn_weights, dim=-1) | ||
786 | - attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) | ||
787 | - | ||
788 | - assert v is not None | ||
789 | - attn_output = torch.bmm(attn_probs, v) | ||
790 | - assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) | ||
791 | - attn_output = ( | ||
792 | - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | ||
793 | - ) | ||
794 | - attn_output = self.out_proj(attn_output) | ||
795 | - if output_attentions: | ||
796 | - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | ||
797 | - else: | ||
798 | - attn_weights = None | ||
799 | - return attn_output, attn_weights | ||
800 | - | ||
801 | - def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): | ||
802 | - # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | ||
803 | - if "prev_key" in saved_state: | ||
804 | - _prev_key = saved_state["prev_key"] | ||
805 | - assert _prev_key is not None | ||
806 | - prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) | ||
807 | - if static_kv: | ||
808 | - k = prev_key | ||
809 | - else: | ||
810 | - assert k is not None | ||
811 | - k = torch.cat([prev_key, k], dim=1) | ||
812 | - if "prev_value" in saved_state: | ||
813 | - _prev_value = saved_state["prev_value"] | ||
814 | - assert _prev_value is not None | ||
815 | - prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) | ||
816 | - if static_kv: | ||
817 | - v = prev_value | ||
818 | - else: | ||
819 | - assert v is not None | ||
820 | - v = torch.cat([prev_value, v], dim=1) | ||
821 | - assert k is not None and v is not None | ||
822 | - prev_key_padding_mask: Optional[Tensor] = saved_state.get( | ||
823 | - "prev_key_padding_mask", None | ||
824 | - ) | ||
825 | - if prev_key_padding_mask is not None: | ||
826 | - if static_kv: | ||
827 | - new_key_padding_mask = prev_key_padding_mask | ||
828 | - else: | ||
829 | - new_key_padding_mask = torch.cat( | ||
830 | - [prev_key_padding_mask, key_padding_mask], dim=1 | ||
831 | - ) | ||
832 | - else: | ||
833 | - new_key_padding_mask = key_padding_mask | ||
834 | - return k, v, new_key_padding_mask | ||
835 | - | ||
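A minimal usage sketch of the Attention module above, assuming the class as defined here. Note the (seq_len, batch, embed_dim) input convention from the forward docstring; the sizes are illustrative only.

import torch

attn = Attention(embed_dim=16, num_heads=4)
query = torch.randn(7, 2, 16)                    # (tgt_len, bsz, embed_dim)
out, weights = attn(query=query, key=query, output_attentions=True)
print(out.shape)      # torch.Size([7, 2, 16])
print(weights.shape)  # torch.Size([2, 4, 7, 7]), i.e. (bsz, num_heads, tgt_len, src_len)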
836 | - | ||
837 | -class BartClassificationHead(nn.Module): | ||
838 | - """Head for sentence-level classification tasks.""" | ||
839 | - | ||
840 | - # This can trivially be shared with RobertaClassificationHead | ||
841 | - | ||
842 | - def __init__( | ||
843 | - self, input_dim, inner_dim, num_classes, pooler_dropout, | ||
844 | - ): | ||
845 | - super().__init__() | ||
846 | - self.dense = nn.Linear(input_dim, inner_dim) | ||
847 | - self.dropout = nn.Dropout(p=pooler_dropout) | ||
848 | - self.out_proj = nn.Linear(inner_dim, num_classes) | ||
849 | - | ||
850 | - def forward(self, x): | ||
851 | - x = self.dropout(x) | ||
852 | - x = self.dense(x) | ||
853 | - x = torch.tanh(x) | ||
854 | - x = self.dropout(x) | ||
855 | - x = self.out_proj(x) | ||
856 | - return x | ||
857 | - | ||
858 | - | ||
859 | -class LearnedPositionalEmbedding(nn.Embedding): | ||
860 | - """ | ||
861 | - This module learns positional embeddings up to a fixed maximum size. | ||
862 | - Padding ids are ignored by either offsetting based on padding_idx | ||
863 | - or by setting padding_idx to None and ensuring that the appropriate | ||
864 | - position ids are passed to the forward function. | ||
865 | - """ | ||
866 | - | ||
867 | - def __init__( | ||
868 | - self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset | ||
869 | - ): | ||
870 | - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 | ||
871 | - # and adjust num_embeddings appropriately. Other models dont have this hack | ||
872 | - # and adjust num_embeddings appropriately. Other models don't have this hack. | ||
873 | - assert padding_idx is not None | ||
874 | - num_embeddings += offset | ||
875 | - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) | ||
876 | - | ||
877 | - def forward(self, input_ids, use_cache=False): | ||
878 | - """Input is expected to be of size [bsz x seqlen].""" | ||
879 | - bsz, seq_len = input_ids.shape[:2] | ||
880 | - if use_cache: | ||
881 | - positions = input_ids.data.new(1, 1).fill_( | ||
882 | - seq_len - 1 | ||
883 | - ) # called before slicing | ||
884 | - else: | ||
885 | - # starts at 0, ends at seq_len - 1 | ||
886 | - positions = torch.arange( | ||
887 | - seq_len, dtype=torch.long, device=self.weight.device | ||
888 | - ) | ||
889 | - return super().forward(positions + self.offset) | ||
890 | - | ||
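A short sketch of the offset behaviour described in the comment above, assuming the class as defined here: with offset=2 (BART's extra_pos_embeddings value), position 0 reads row 2 of the table, so the first two rows are effectively reserved.

import torch

emb = LearnedPositionalEmbedding(num_embeddings=8, embedding_dim=4, padding_idx=1, offset=2)
input_ids = torch.zeros(1, 3, dtype=torch.long)   # only the (bsz, seq_len) shape matters here
positions = emb(input_ids)
print(emb.weight.shape)  # torch.Size([10, 4]); the table grew by the offset
print(positions.shape)   # torch.Size([3, 4]); one row per position, broadcast over the batch later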
891 | - | ||
892 | -def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): | ||
893 | - if torch.cuda.is_available(): | ||
894 | - try: | ||
895 | - from apex.normalization import FusedLayerNorm | ||
896 | - | ||
897 | - return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | ||
898 | - except ImportError: | ||
899 | - pass | ||
900 | - return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | ||
901 | - | ||
902 | - | ||
903 | -def fill_with_neg_inf(t): | ||
904 | - """FP16-compatible function that fills the given tensor with -inf.""" | ||
905 | - return t.float().fill_(float("-inf")).type_as(t) | ||
906 | - | ||
907 | - | ||
908 | -# Public API | ||
909 | -def _get_shape(t): | ||
910 | - return getattr(t, "shape", None) | ||
911 | - | ||
912 | - | ||
913 | -@add_start_docstrings( | ||
914 | - "The bare BART Model outputting raw hidden-states without any specific head on top.", | ||
915 | - BART_START_DOCSTRING, | ||
916 | -) | ||
917 | -class BartModel(PretrainedBartModel): | ||
918 | - def __init__(self, config: BartConfig): | ||
919 | - super().__init__(config) | ||
920 | - | ||
921 | - padding_idx, vocab_size = config.pad_token_id, config.vocab_size | ||
922 | - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) | ||
923 | - | ||
924 | - self.encoder = BartEncoder(config, self.shared) | ||
925 | - self.decoder = BartDecoder(config, self.shared) | ||
926 | - | ||
927 | - self.init_weights() | ||
928 | - | ||
929 | - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
930 | - @add_code_sample_docstrings( | ||
931 | - tokenizer_class=_TOKENIZER_FOR_DOC, | ||
932 | - checkpoint="facebook/bart-large", | ||
933 | - output_type=BaseModelOutputWithPast, | ||
934 | - config_class=_CONFIG_FOR_DOC, | ||
935 | - ) | ||
936 | - def forward( | ||
937 | - self, | ||
938 | - input_ids, | ||
939 | - patch_ids=None, | ||
940 | - attention_mask=None, | ||
941 | - decoder_input_ids=None, | ||
942 | - encoder_outputs: Optional[Tuple] = None, | ||
943 | - decoder_attention_mask=None, | ||
944 | - past_key_values=None, | ||
945 | - use_cache=None, | ||
946 | - output_attentions=None, | ||
947 | - output_hidden_states=None, | ||
948 | - return_dict=None, | ||
949 | - **kwargs, | ||
950 | - ): | ||
951 | - if "decoder_past_key_values" in kwargs: | ||
952 | - warnings.warn( | ||
953 | - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
954 | - FutureWarning, | ||
955 | - ) | ||
956 | - past_key_values = kwargs.pop("decoder_past_key_values") | ||
957 | - | ||
958 | - if decoder_input_ids is None: | ||
959 | - use_cache = False | ||
960 | - | ||
961 | - output_attentions = ( | ||
962 | - output_attentions | ||
963 | - if output_attentions is not None | ||
964 | - else self.config.output_attentions | ||
965 | - ) | ||
966 | - output_hidden_states = ( | ||
967 | - output_hidden_states | ||
968 | - if output_hidden_states is not None | ||
969 | - else self.config.output_hidden_states | ||
970 | - ) | ||
971 | - use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
972 | - return_dict = ( | ||
973 | - return_dict if return_dict is not None else self.config.use_return_dict | ||
974 | - ) | ||
975 | - | ||
976 | - # make masks if user doesn't supply | ||
977 | - if not use_cache: | ||
978 | - ( | ||
979 | - decoder_input_ids, | ||
980 | - decoder_padding_mask, | ||
981 | - causal_mask, | ||
982 | - ) = _prepare_bart_decoder_inputs( | ||
983 | - self.config, | ||
984 | - input_ids, | ||
985 | - decoder_input_ids=decoder_input_ids, | ||
986 | - decoder_padding_mask=decoder_attention_mask, | ||
987 | - causal_mask_dtype=self.shared.weight.dtype, | ||
988 | - ) | ||
989 | - else: | ||
990 | - decoder_padding_mask, causal_mask = None, None | ||
991 | - | ||
992 | - assert decoder_input_ids is not None | ||
993 | - | ||
994 | - if encoder_outputs is None: | ||
995 | - encoder_outputs = self.encoder( | ||
996 | - input_ids=input_ids, | ||
997 | - patch_ids=patch_ids, | ||
998 | - attention_mask=attention_mask, | ||
999 | - output_attentions=output_attentions, | ||
1000 | - output_hidden_states=output_hidden_states, | ||
1001 | - return_dict=return_dict, | ||
1002 | - ) | ||
1003 | - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True | ||
1004 | - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): | ||
1005 | - encoder_outputs = BaseModelOutput( | ||
1006 | - last_hidden_state=encoder_outputs[0], | ||
1007 | - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, | ||
1008 | - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, | ||
1009 | - ) | ||
1010 | - | ||
1011 | - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
1012 | - decoder_outputs = self.decoder( | ||
1013 | - decoder_input_ids, | ||
1014 | - encoder_outputs[0], | ||
1015 | - attention_mask, | ||
1016 | - decoder_padding_mask, | ||
1017 | - decoder_causal_mask=causal_mask, | ||
1018 | - past_key_values=past_key_values, | ||
1019 | - use_cache=use_cache, | ||
1020 | - output_attentions=output_attentions, | ||
1021 | - output_hidden_states=output_hidden_states, | ||
1022 | - return_dict=return_dict, | ||
1023 | - ) | ||
1024 | - | ||
1025 | - if not return_dict: | ||
1026 | - return decoder_outputs + encoder_outputs | ||
1027 | - | ||
1028 | - return Seq2SeqModelOutput( | ||
1029 | - last_hidden_state=decoder_outputs.last_hidden_state, | ||
1030 | - past_key_values=decoder_outputs.past_key_values, | ||
1031 | - decoder_hidden_states=decoder_outputs.hidden_states, | ||
1032 | - decoder_attentions=decoder_outputs.attentions, | ||
1033 | - encoder_last_hidden_state=encoder_outputs.last_hidden_state, | ||
1034 | - encoder_hidden_states=encoder_outputs.hidden_states, | ||
1035 | - encoder_attentions=encoder_outputs.attentions, | ||
1036 | - ) | ||
1037 | - | ||
1038 | - def get_input_embeddings(self): | ||
1039 | - return self.shared | ||
1040 | - | ||
1041 | - def set_input_embeddings(self, value): | ||
1042 | - self.shared = value | ||
1043 | - self.encoder.embed_tokens = self.shared | ||
1044 | - self.decoder.embed_tokens = self.shared | ||
1045 | - | ||
1046 | - def get_output_embeddings(self): | ||
1047 | - return _make_linear_from_emb(self.shared) # make it on the fly | ||
1048 | - | ||
1049 | - | ||
1050 | -@add_start_docstrings( | ||
1051 | - "The BART Model with a language modeling head. Can be used for summarization.", | ||
1052 | - BART_START_DOCSTRING, | ||
1053 | -) | ||
1054 | -class BartForConditionalGeneration(PretrainedBartModel): | ||
1055 | - base_model_prefix = "model" | ||
1056 | - authorized_missing_keys = [ | ||
1057 | - r"final_logits_bias", | ||
1058 | - r"encoder\.version", | ||
1059 | - r"decoder\.version", | ||
1060 | - ] | ||
1061 | - | ||
1062 | - def __init__(self, config: BartConfig): | ||
1063 | - super().__init__(config) | ||
1064 | - base_model = BartModel(config) | ||
1065 | - self.model = base_model | ||
1066 | - self.register_buffer( | ||
1067 | - "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)) | ||
1068 | - ) | ||
1069 | - | ||
1070 | - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: | ||
1071 | - old_num_tokens = self.model.shared.num_embeddings | ||
1072 | - new_embeddings = super().resize_token_embeddings(new_num_tokens) | ||
1073 | - self.model.shared = new_embeddings | ||
1074 | - self._resize_final_logits_bias(new_num_tokens, old_num_tokens) | ||
1075 | - return new_embeddings | ||
1076 | - | ||
1077 | - def _resize_final_logits_bias( | ||
1078 | - self, new_num_tokens: int, old_num_tokens: int | ||
1079 | - ) -> None: | ||
1080 | - if new_num_tokens <= old_num_tokens: | ||
1081 | - new_bias = self.final_logits_bias[:, :new_num_tokens] | ||
1082 | - else: | ||
1083 | - extra_bias = torch.zeros( | ||
1084 | - (1, new_num_tokens - old_num_tokens), | ||
1085 | - device=self.final_logits_bias.device, | ||
1086 | - ) | ||
1087 | - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) | ||
1088 | - self.register_buffer("final_logits_bias", new_bias) | ||
1089 | - | ||
1090 | - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1091 | - @replace_return_docstrings( | ||
1092 | - output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC | ||
1093 | - ) | ||
1094 | - @add_end_docstrings(BART_GENERATION_EXAMPLE) | ||
1095 | - def forward( | ||
1096 | - self, | ||
1097 | - input_ids, | ||
1098 | - patch_ids, | ||
1099 | - attention_mask=None, | ||
1100 | - encoder_outputs=None, | ||
1101 | - decoder_input_ids=None, | ||
1102 | - decoder_attention_mask=None, | ||
1103 | - past_key_values=None, | ||
1104 | - labels=None, | ||
1105 | - use_cache=None, | ||
1106 | - output_attentions=None, | ||
1107 | - output_hidden_states=None, | ||
1108 | - return_dict=None, | ||
1109 | - **unused, | ||
1110 | - ): | ||
1111 | - r""" | ||
1112 | - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
1113 | - Labels for computing the masked language modeling loss. | ||
1114 | - Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring). | ||
1115 | - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens | ||
1116 | - with labels in ``[0, ..., config.vocab_size]``. | ||
1117 | - | ||
1118 | - Returns: | ||
1119 | - | ||
1120 | - Conditional generation example:: | ||
1121 | - | ||
1122 | - # Mask filling only works for bart-large | ||
1123 | - from transformers import BartTokenizer, BartForConditionalGeneration | ||
1124 | - tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') | ||
1125 | - TXT = "My friends are <mask> but they eat too many carbs." | ||
1126 | - | ||
1127 | - model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') | ||
1128 | - input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] | ||
1129 | - logits = model(input_ids).logits | ||
1130 | - | ||
1131 | - masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() | ||
1132 | - probs = logits[0, masked_index].softmax(dim=0) | ||
1133 | - values, predictions = probs.topk(5) | ||
1134 | - | ||
1135 | - tokenizer.decode(predictions).split() | ||
1136 | - # ['good', 'great', 'all', 'really', 'very'] | ||
1137 | - """ | ||
1138 | - if "lm_labels" in unused: | ||
1139 | - warnings.warn( | ||
1140 | - "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", | ||
1141 | - FutureWarning, | ||
1142 | - ) | ||
1143 | - labels = unused.pop("lm_labels") | ||
1144 | - if "decoder_cached_states" in unused: | ||
1145 | - warnings.warn( | ||
1146 | - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
1147 | - FutureWarning, | ||
1148 | - ) | ||
1149 | - past_key_values = unused.pop("decoder_cached_states") | ||
1150 | - if "decoder_past_key_values" in unused: | ||
1151 | - warnings.warn( | ||
1152 | - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", | ||
1153 | - FutureWarning, | ||
1154 | - ) | ||
1155 | - past_key_values = unused.pop("decoder_past_key_values") | ||
1156 | - return_dict = ( | ||
1157 | - return_dict if return_dict is not None else self.config.use_return_dict | ||
1158 | - ) | ||
1159 | - | ||
1160 | - if labels is not None: | ||
1161 | - use_cache = False | ||
1162 | - if decoder_input_ids is None: | ||
1163 | - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) | ||
1164 | - | ||
1165 | - outputs = self.model( | ||
1166 | - input_ids, | ||
1167 | - patch_ids=patch_ids, | ||
1168 | - attention_mask=attention_mask, | ||
1169 | - decoder_input_ids=decoder_input_ids, | ||
1170 | - encoder_outputs=encoder_outputs, | ||
1171 | - decoder_attention_mask=decoder_attention_mask, | ||
1172 | - past_key_values=past_key_values, | ||
1173 | - use_cache=use_cache, | ||
1174 | - output_attentions=output_attentions, | ||
1175 | - output_hidden_states=output_hidden_states, | ||
1176 | - return_dict=return_dict, | ||
1177 | - ) | ||
1178 | - lm_logits = F.linear( | ||
1179 | - outputs[0], self.model.shared.weight, bias=self.final_logits_bias | ||
1180 | - ) | ||
1181 | - | ||
1182 | - masked_lm_loss = None | ||
1183 | - if labels is not None: | ||
1184 | - loss_fct = CrossEntropyLoss() | ||
1185 | - # TODO(SS): do we need to ignore pad tokens in labels? | ||
1186 | - masked_lm_loss = loss_fct( | ||
1187 | - lm_logits.view(-1, self.config.vocab_size), labels.view(-1) | ||
1188 | - ) | ||
1189 | - | ||
1190 | - if not return_dict: | ||
1191 | - output = (lm_logits,) + outputs[1:] | ||
1192 | - return ( | ||
1193 | - ((masked_lm_loss,) + output) if masked_lm_loss is not None else output | ||
1194 | - ) | ||
1195 | - | ||
1196 | - return Seq2SeqLMOutput( | ||
1197 | - loss=masked_lm_loss, | ||
1198 | - logits=lm_logits, | ||
1199 | - past_key_values=outputs.past_key_values, | ||
1200 | - decoder_hidden_states=outputs.decoder_hidden_states, | ||
1201 | - decoder_attentions=outputs.decoder_attentions, | ||
1202 | - encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1203 | - encoder_hidden_states=outputs.encoder_hidden_states, | ||
1204 | - encoder_attentions=outputs.encoder_attentions, | ||
1205 | - ) | ||
1206 | - | ||
1207 | - def prepare_inputs_for_generation( | ||
1208 | - self, | ||
1209 | - decoder_input_ids, | ||
1210 | - past, | ||
1211 | - attention_mask, | ||
1212 | - use_cache, | ||
1213 | - encoder_outputs, | ||
1214 | - **kwargs, | ||
1215 | - ): | ||
1216 | - return { | ||
1217 | - "input_ids": None, # encoder_outputs is defined. input_ids not needed | ||
1218 | - "patch_ids": None, # encoder_outputs is defined. patch_ids not needed | ||
1219 | - "encoder_outputs": encoder_outputs, | ||
1220 | - "past_key_values": past, | ||
1221 | - "decoder_input_ids": decoder_input_ids, | ||
1222 | - "attention_mask": attention_mask, | ||
1223 | - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) | ||
1224 | - } | ||
1225 | - | ||
1226 | - def adjust_logits_during_generation(self, logits, cur_len, max_length): | ||
1227 | - if cur_len == 1 and self.config.force_bos_token_to_be_generated: | ||
1228 | - self._force_token_ids_generation(logits, self.config.bos_token_id) | ||
1229 | - elif cur_len == max_length - 1 and self.config.eos_token_id is not None: | ||
1230 | - self._force_token_ids_generation(logits, self.config.eos_token_id) | ||
1231 | - return logits | ||
1232 | - | ||
1233 | - def _force_token_ids_generation(self, scores, token_id) -> None: | ||
1234 | - """Force token_id to be generated by setting the probability of all other tokens to 0 (logprob=-float("inf")).""" | ||
1235 | - scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float( | ||
1236 | - "inf" | ||
1237 | - ) | ||
1238 | - | ||
1239 | - @staticmethod | ||
1240 | - def _reorder_cache(past, beam_idx): | ||
1241 | - reordered_past = [] | ||
1242 | - for layer_past in past: | ||
1243 | - # get the correct batch idx from decoder layer's batch dim for cross and self-attn | ||
1244 | - layer_past_new = { | ||
1245 | - attn_key: _reorder_buffer(attn_cache, beam_idx) | ||
1246 | - for attn_key, attn_cache in layer_past.items() | ||
1247 | - } | ||
1248 | - reordered_past.append(layer_past_new) | ||
1249 | - return reordered_past | ||
1250 | - | ||
1251 | - def get_encoder(self): | ||
1252 | - return self.model.encoder | ||
1253 | - | ||
1254 | - def get_output_embeddings(self): | ||
1255 | - return _make_linear_from_emb(self.model.shared) # make it on the fly | ||
1256 | - | ||
1257 | - | ||
1258 | -@add_start_docstrings( | ||
1259 | - """Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks. """, | ||
1260 | - BART_START_DOCSTRING, | ||
1261 | -) | ||
1262 | -class BartForSequenceClassification(PretrainedBartModel): | ||
1263 | - def __init__(self, config: BartConfig, **kwargs): | ||
1264 | - super().__init__(config, **kwargs) | ||
1265 | - self.model = BartModel(config) | ||
1266 | - self.classification_head = BartClassificationHead( | ||
1267 | - config.d_model, config.d_model, config.num_labels, config.classif_dropout, | ||
1268 | - ) | ||
1269 | - self.model._init_weights(self.classification_head.dense) | ||
1270 | - self.model._init_weights(self.classification_head.out_proj) | ||
1271 | - | ||
1272 | - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1273 | - @add_code_sample_docstrings( | ||
1274 | - tokenizer_class=_TOKENIZER_FOR_DOC, | ||
1275 | - checkpoint="facebook/bart-large", | ||
1276 | - output_type=Seq2SeqSequenceClassifierOutput, | ||
1277 | - config_class=_CONFIG_FOR_DOC, | ||
1278 | - ) | ||
1279 | - def forward( | ||
1280 | - self, | ||
1281 | - input_ids, | ||
1282 | - attention_mask=None, | ||
1283 | - encoder_outputs=None, | ||
1284 | - decoder_input_ids=None, | ||
1285 | - decoder_attention_mask=None, | ||
1286 | - labels=None, | ||
1287 | - use_cache=None, | ||
1288 | - output_attentions=None, | ||
1289 | - output_hidden_states=None, | ||
1290 | - return_dict=None, | ||
1291 | - ): | ||
1292 | - r""" | ||
1293 | - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1294 | - Labels for computing the sequence classification/regression loss. | ||
1295 | - Indices should be in :obj:`[0, ..., config.num_labels - 1]`. | ||
1296 | - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||
1297 | - """ | ||
1298 | - return_dict = ( | ||
1299 | - return_dict if return_dict is not None else self.config.use_return_dict | ||
1300 | - ) | ||
1301 | - if labels is not None: | ||
1302 | - use_cache = False | ||
1303 | - | ||
1304 | - outputs = self.model( | ||
1305 | - input_ids, | ||
1306 | - attention_mask=attention_mask, | ||
1307 | - decoder_input_ids=decoder_input_ids, | ||
1308 | - decoder_attention_mask=decoder_attention_mask, | ||
1309 | - encoder_outputs=encoder_outputs, | ||
1310 | - use_cache=use_cache, | ||
1311 | - output_attentions=output_attentions, | ||
1312 | - output_hidden_states=output_hidden_states, | ||
1313 | - return_dict=return_dict, | ||
1314 | - ) | ||
1315 | - x = outputs[0] # last hidden state | ||
1316 | - eos_mask = input_ids.eq(self.config.eos_token_id) | ||
1317 | - if len(torch.unique(eos_mask.sum(1))) > 1: | ||
1318 | - raise ValueError("All examples must have the same number of <eos> tokens.") | ||
1319 | - sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[ | ||
1320 | - :, -1, : | ||
1321 | - ] | ||
1322 | - logits = self.classification_head(sentence_representation) | ||
1323 | - | ||
1324 | - loss = None | ||
1325 | - if labels is not None: | ||
1326 | - loss_fct = CrossEntropyLoss() | ||
1327 | - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) | ||
1328 | - | ||
1329 | - if not return_dict: | ||
1330 | - output = (logits,) + outputs[1:] | ||
1331 | - return ((loss,) + output) if loss is not None else output | ||
1332 | - | ||
1333 | - return Seq2SeqSequenceClassifierOutput( | ||
1334 | - loss=loss, | ||
1335 | - logits=logits, | ||
1336 | - past_key_values=outputs.past_key_values, | ||
1337 | - decoder_hidden_states=outputs.decoder_hidden_states, | ||
1338 | - decoder_attentions=outputs.decoder_attentions, | ||
1339 | - encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1340 | - encoder_hidden_states=outputs.encoder_hidden_states, | ||
1341 | - encoder_attentions=outputs.encoder_attentions, | ||
1342 | - ) | ||
1343 | - | ||
1344 | - | ||
1345 | -@add_start_docstrings( | ||
1346 | - """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of | ||
1347 | - the hidden-states output to compute `span start logits` and `span end logits`). """, | ||
1348 | - BART_START_DOCSTRING, | ||
1349 | -) | ||
1350 | -class BartForQuestionAnswering(PretrainedBartModel): | ||
1351 | - def __init__(self, config): | ||
1352 | - super().__init__(config) | ||
1353 | - | ||
1354 | - config.num_labels = 2 | ||
1355 | - self.num_labels = config.num_labels | ||
1356 | - | ||
1357 | - self.model = BartModel(config) | ||
1358 | - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) | ||
1359 | - | ||
1360 | - self.model._init_weights(self.qa_outputs) | ||
1361 | - | ||
1362 | - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) | ||
1363 | - @add_code_sample_docstrings( | ||
1364 | - tokenizer_class=_TOKENIZER_FOR_DOC, | ||
1365 | - checkpoint="facebook/bart-large", | ||
1366 | - output_type=Seq2SeqQuestionAnsweringModelOutput, | ||
1367 | - config_class=_CONFIG_FOR_DOC, | ||
1368 | - ) | ||
1369 | - def forward( | ||
1370 | - self, | ||
1371 | - input_ids, | ||
1372 | - attention_mask=None, | ||
1373 | - encoder_outputs=None, | ||
1374 | - decoder_input_ids=None, | ||
1375 | - decoder_attention_mask=None, | ||
1376 | - start_positions=None, | ||
1377 | - end_positions=None, | ||
1378 | - use_cache=None, | ||
1379 | - output_attentions=None, | ||
1380 | - output_hidden_states=None, | ||
1381 | - return_dict=None, | ||
1382 | - ): | ||
1383 | - r""" | ||
1384 | - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1385 | - Labels for position (index) of the start of the labelled span for computing the token classification loss. | ||
1386 | - Positions are clamped to the length of the sequence (`sequence_length`). | ||
1387 | - Positions outside of the sequence are not taken into account when computing the loss. | ||
1388 | - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): | ||
1389 | - Labels for position (index) of the end of the labelled span for computing the token classification loss. | ||
1390 | - Positions are clamped to the length of the sequence (`sequence_length`). | ||
1391 | - Positions outside of the sequence are not taken into account when computing the loss. | ||
1392 | - """ | ||
1393 | - return_dict = ( | ||
1394 | - return_dict if return_dict is not None else self.config.use_return_dict | ||
1395 | - ) | ||
1396 | - if start_positions is not None and end_positions is not None: | ||
1397 | - use_cache = False | ||
1398 | - | ||
1399 | - outputs = self.model( | ||
1400 | - input_ids, | ||
1401 | - attention_mask=attention_mask, | ||
1402 | - decoder_input_ids=decoder_input_ids, | ||
1403 | - decoder_attention_mask=decoder_attention_mask, | ||
1404 | - encoder_outputs=encoder_outputs, | ||
1405 | - use_cache=use_cache, | ||
1406 | - output_attentions=output_attentions, | ||
1407 | - output_hidden_states=output_hidden_states, | ||
1408 | - return_dict=return_dict, | ||
1409 | - ) | ||
1410 | - | ||
1411 | - sequence_output = outputs[0] | ||
1412 | - | ||
1413 | - logits = self.qa_outputs(sequence_output) | ||
1414 | - start_logits, end_logits = logits.split(1, dim=-1) | ||
1415 | - start_logits = start_logits.squeeze(-1) | ||
1416 | - end_logits = end_logits.squeeze(-1) | ||
1417 | - | ||
1418 | - total_loss = None | ||
1419 | - if start_positions is not None and end_positions is not None: | ||
1420 | - # If we are on multi-GPU, splitting adds an extra dimension; squeeze it away | ||
1421 | - if len(start_positions.size()) > 1: | ||
1422 | - start_positions = start_positions.squeeze(-1) | ||
1423 | - if len(end_positions.size()) > 1: | ||
1424 | - end_positions = end_positions.squeeze(-1) | ||
1425 | - # sometimes the start/end positions are outside our model inputs; we ignore these terms | ||
1426 | - ignored_index = start_logits.size(1) | ||
1427 | - start_positions.clamp_(0, ignored_index) | ||
1428 | - end_positions.clamp_(0, ignored_index) | ||
1429 | - | ||
1430 | - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) | ||
1431 | - start_loss = loss_fct(start_logits, start_positions) | ||
1432 | - end_loss = loss_fct(end_logits, end_positions) | ||
1433 | - total_loss = (start_loss + end_loss) / 2 | ||
1434 | - | ||
1435 | - if not return_dict: | ||
1436 | - output = (start_logits, end_logits,) + outputs[1:] | ||
1437 | - return ((total_loss,) + output) if total_loss is not None else output | ||
1438 | - | ||
1439 | - return Seq2SeqQuestionAnsweringModelOutput( | ||
1440 | - loss=total_loss, | ||
1441 | - start_logits=start_logits, | ||
1442 | - end_logits=end_logits, | ||
1443 | - past_key_values=outputs.past_key_values, | ||
1444 | - decoder_hidden_states=outputs.decoder_hidden_states, | ||
1445 | - decoder_attentions=outputs.decoder_attentions, | ||
1446 | - encoder_last_hidden_state=outputs.encoder_last_hidden_state, | ||
1447 | - encoder_hidden_states=outputs.encoder_hidden_states, | ||
1448 | - encoder_attentions=outputs.encoder_attentions, | ||
1449 | - ) | ||
1450 | - | ||
1451 | - | ||
1452 | -class SinusoidalPositionalEmbedding(nn.Embedding): | ||
1453 | - """This module produces sinusoidal positional embeddings of any length.""" | ||
1454 | - | ||
1455 | - def __init__(self, num_positions, embedding_dim, padding_idx=None): | ||
1456 | - super().__init__(num_positions, embedding_dim) | ||
1457 | - if embedding_dim % 2 != 0: | ||
1458 | - raise NotImplementedError( | ||
1459 | - f"odd embedding_dim {embedding_dim} not supported" | ||
1460 | - ) | ||
1461 | - self.weight = self._init_weight(self.weight) | ||
1462 | - | ||
1463 | - @staticmethod | ||
1464 | - def _init_weight(out: nn.Parameter): | ||
1465 | - """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. | ||
1466 | - The cos features are in the 2nd half of the vector. [dim // 2:] | ||
1467 | - """ | ||
1468 | - n_pos, dim = out.shape | ||
1469 | - position_enc = np.array( | ||
1470 | - [ | ||
1471 | - [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] | ||
1472 | - for pos in range(n_pos) | ||
1473 | - ] | ||
1474 | - ) | ||
1475 | - out[:, 0 : dim // 2] = torch.FloatTensor( | ||
1476 | - np.sin(position_enc[:, 0::2]) | ||
1477 | - ) # This line breaks for odd embedding_dim (guarded against above) | ||
1478 | - out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) | ||
1479 | - out.detach_() | ||
1480 | - out.requires_grad = False | ||
1481 | - return out | ||
1482 | - | ||
1483 | - @torch.no_grad() | ||
1484 | - def forward(self, input_ids, use_cache=False): | ||
1485 | - """Input is expected to be of size [bsz x seqlen].""" | ||
1486 | - bsz, seq_len = input_ids.shape[:2] | ||
1487 | - if use_cache: | ||
1488 | - positions = input_ids.data.new(1, 1).fill_( | ||
1489 | - seq_len - 1 | ||
1490 | - ) # called before slicing | ||
1491 | - else: | ||
1492 | - # starts at 0, ends at seq_len - 1 | ||
1493 | - positions = torch.arange( | ||
1494 | - seq_len, dtype=torch.long, device=self.weight.device | ||
1495 | - ) | ||
1496 | - return super().forward(positions) |
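A quick check of the layout described in _init_weight, assuming the class as defined above: the sin features fill the first half of each row and the cos features the second half (not interleaved), so the row for position 0 is all zeros followed by all ones.

import torch

emb = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=8)
w = emb.weight
print(torch.allclose(w[0, :4], torch.zeros(4)))  # True: sin(0) == 0 for every frequency
print(torch.allclose(w[0, 4:], torch.ones(4)))   # True: cos(0) == 1 for every frequency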
train/modeling_utils.py
deleted
100644 → 0
1 | -# coding=utf-8 | ||
2 | -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. | ||
3 | -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
4 | -# | ||
5 | -# Licensed under the Apache License, Version 2.0 (the "License"); | ||
6 | -# you may not use this file except in compliance with the License. | ||
7 | -# You may obtain a copy of the License at | ||
8 | -# | ||
9 | -# http://www.apache.org/licenses/LICENSE-2.0 | ||
10 | -# | ||
11 | -# Unless required by applicable law or agreed to in writing, software | ||
12 | -# distributed under the License is distributed on an "AS IS" BASIS, | ||
13 | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
14 | -# See the License for the specific language governing permissions and | ||
15 | -# limitations under the License. | ||
16 | - | ||
17 | -import inspect | ||
18 | -import os | ||
19 | -import re | ||
20 | -from dataclasses import dataclass | ||
21 | -from typing import Callable, Dict, List, Optional, Set, Tuple, Union | ||
22 | - | ||
23 | -import torch | ||
24 | -from torch import Tensor, device, dtype, nn | ||
25 | -from torch.nn import CrossEntropyLoss | ||
26 | -from torch.nn import functional as F | ||
27 | - | ||
28 | -from transformers.activations import get_activation | ||
29 | -from transformers.configuration_utils import PretrainedConfig | ||
30 | -from transformers.file_utils import ( | ||
31 | - DUMMY_INPUTS, | ||
32 | - TF2_WEIGHTS_NAME, | ||
33 | - TF_WEIGHTS_NAME, | ||
34 | - WEIGHTS_NAME, | ||
35 | - ModelOutput, | ||
36 | - cached_path, | ||
37 | - hf_bucket_url, | ||
38 | - is_remote_url, | ||
39 | - is_torch_tpu_available, | ||
40 | - replace_return_docstrings, | ||
41 | -) | ||
42 | -from train.generation_utils import GenerationMixin | ||
43 | -import logging | ||
44 | - | ||
45 | -logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
46 | -logging.basicConfig( | ||
47 | - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s", | ||
48 | - datefmt="%m/%d/%Y %H:%M:%S", | ||
49 | - level=logging.INFO, | ||
50 | -) | ||
51 | - | ||
52 | - | ||
53 | -try: | ||
54 | - from torch.nn import Identity | ||
55 | -except ImportError: | ||
56 | - # Older PyTorch compatibility | ||
57 | - class Identity(nn.Module): | ||
58 | - r"""A placeholder identity operator that is argument-insensitive.""" | ||
59 | - | ||
60 | - def __init__(self, *args, **kwargs): | ||
61 | - super().__init__() | ||
62 | - | ||
63 | - def forward(self, input): | ||
64 | - return input | ||
65 | - | ||
66 | - | ||
67 | -def find_pruneable_heads_and_indices( | ||
68 | - heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] | ||
69 | -) -> Tuple[Set[int], torch.LongTensor]: | ||
70 | - """ | ||
71 | - Finds the heads and their indices taking :obj:`already_pruned_heads` into account. | ||
72 | - | ||
73 | - Args: | ||
74 | - heads (:obj:`List[int]`): List of the indices of heads to prune. | ||
75 | - n_heads (:obj:`int`): The number of heads in the model. | ||
76 | - head_size (:obj:`int`): The size of each head. | ||
77 | - already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads. | ||
78 | - | ||
79 | - Returns: | ||
80 | - :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. | ||
81 | - """ | ||
82 | - mask = torch.ones(n_heads, head_size) | ||
83 | - heads = ( | ||
84 | - set(heads) - already_pruned_heads | ||
85 | - ) # Convert to set and remove already pruned heads | ||
86 | - for head in heads: | ||
87 | - # Compute how many pruned heads are before the head and move the index accordingly | ||
88 | - head = head - sum(1 if h < head else 0 for h in already_pruned_heads) | ||
89 | - mask[head] = 0 | ||
90 | - mask = mask.view(-1).contiguous().eq(1) | ||
91 | - index: torch.LongTensor = torch.arange(len(mask))[mask].long() | ||
92 | - return heads, index | ||
93 | - | ||
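A usage sketch of find_pruneable_heads_and_indices, assuming the function as defined above: heads that were already pruned are dropped from the request, the remaining head index is remapped into the current (smaller) head layout, and the returned index lists the flat positions that survive pruning.

import torch

heads, index = find_pruneable_heads_and_indices(
    heads=[0, 2], n_heads=3, head_size=3, already_pruned_heads={0}
)
print(heads)  # {2}  (head 0 was already pruned, so only head 2 is left to prune)
print(index)  # tensor([0, 1, 2, 6, 7, 8])  (flat positions of the two surviving heads)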
94 | - | ||
95 | -class ModuleUtilsMixin: | ||
96 | - """ | ||
97 | - A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. | ||
98 | - """ | ||
99 | - | ||
100 | - def num_parameters(self, only_trainable: bool = False) -> int: | ||
101 | - """ | ||
102 | - Get the number of (optionally, trainable) parameters in the model. | ||
103 | - | ||
104 | - Args: | ||
105 | - only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
106 | - Whether or not to return only the number of trainable parameters | ||
107 | - | ||
108 | - Returns: | ||
109 | - :obj:`int`: The number of parameters. | ||
110 | - """ | ||
111 | - params = ( | ||
112 | - filter(lambda x: x.requires_grad, self.parameters()) | ||
113 | - if only_trainable | ||
114 | - else self.parameters() | ||
115 | - ) | ||
116 | - return sum(p.numel() for p in params) | ||
117 | - | ||
118 | - @staticmethod | ||
119 | - def _hook_rss_memory_pre_forward(module, *args, **kwargs): | ||
120 | - try: | ||
121 | - import psutil | ||
122 | - except (ImportError): | ||
123 | - raise ImportError( | ||
124 | - "You need to install psutil (pip install psutil) to use memory tracing." | ||
125 | - ) | ||
126 | - | ||
127 | - process = psutil.Process(os.getpid()) | ||
128 | - mem = process.memory_info() | ||
129 | - module.mem_rss_pre_forward = mem.rss | ||
130 | - return None | ||
131 | - | ||
132 | - @staticmethod | ||
133 | - def _hook_rss_memory_post_forward(module, *args, **kwargs): | ||
134 | - try: | ||
135 | - import psutil | ||
136 | - except (ImportError): | ||
137 | - raise ImportError( | ||
138 | - "You need to install psutil (pip install psutil) to use memory tracing." | ||
139 | - ) | ||
140 | - | ||
141 | - process = psutil.Process(os.getpid()) | ||
142 | - mem = process.memory_info() | ||
143 | - module.mem_rss_post_forward = mem.rss | ||
144 | - mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward | ||
145 | - module.mem_rss_diff = mem_rss_diff + ( | ||
146 | - module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0 | ||
147 | - ) | ||
148 | - return None | ||
149 | - | ||
150 | - def add_memory_hooks(self): | ||
151 | - """ | ||
152 | - Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. | ||
153 | - | ||
154 | - Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to | ||
155 | - zero with :obj:`model.reset_memory_hooks_state()`. | ||
156 | - """ | ||
157 | - for module in self.modules(): | ||
158 | - module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) | ||
159 | - module.register_forward_hook(self._hook_rss_memory_post_forward) | ||
160 | - self.reset_memory_hooks_state() | ||
161 | - | ||
162 | - def reset_memory_hooks_state(self): | ||
163 | - """ | ||
164 | - Reset the :obj:`mem_rss_diff` attribute of each module (see | ||
165 | - :func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`). | ||
166 | - """ | ||
167 | - for module in self.modules(): | ||
168 | - module.mem_rss_diff = 0 | ||
169 | - module.mem_rss_post_forward = 0 | ||
170 | - module.mem_rss_pre_forward = 0 | ||
171 | - | ||
172 | - @property | ||
173 | - def device(self) -> device: | ||
174 | - """ | ||
175 | - :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same | ||
176 | - device). | ||
177 | - """ | ||
178 | - try: | ||
179 | - return next(self.parameters()).device | ||
180 | - except StopIteration: | ||
181 | - # For nn.DataParallel compatibility in PyTorch 1.5 | ||
182 | - | ||
183 | - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | ||
184 | - tuples = [ | ||
185 | - (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) | ||
186 | - ] | ||
187 | - return tuples | ||
188 | - | ||
189 | - gen = self._named_members(get_members_fn=find_tensor_attributes) | ||
190 | - first_tuple = next(gen) | ||
191 | - return first_tuple[1].device | ||
192 | - | ||
193 | - @property | ||
194 | - def dtype(self) -> dtype: | ||
195 | - """ | ||
196 | - :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). | ||
197 | - """ | ||
198 | - try: | ||
199 | - return next(self.parameters()).dtype | ||
200 | - except StopIteration: | ||
201 | - # For nn.DataParallel compatibility in PyTorch 1.5 | ||
202 | - | ||
203 | - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | ||
204 | - tuples = [ | ||
205 | - (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) | ||
206 | - ] | ||
207 | - return tuples | ||
208 | - | ||
209 | - gen = self._named_members(get_members_fn=find_tensor_attributes) | ||
210 | - first_tuple = next(gen) | ||
211 | - return first_tuple[1].dtype | ||
212 | - | ||
213 | - def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: | ||
214 | - """ | ||
215 | - Invert an attention mask (e.g., switches 0. and 1.). | ||
216 | - | ||
217 | - Args: | ||
218 | - encoder_attention_mask (:obj:`torch.Tensor`): An attention mask. | ||
219 | - | ||
220 | - Returns: | ||
221 | - :obj:`torch.Tensor`: The inverted attention mask. | ||
222 | - """ | ||
223 | - if encoder_attention_mask.dim() == 3: | ||
224 | - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] | ||
225 | - if encoder_attention_mask.dim() == 2: | ||
226 | - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] | ||
227 | - # T5 has a mask that can compare sequence ids; we can simulate this here with this transposition | ||
228 | - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow | ||
229 | - # /transformer/transformer_layers.py#L270 | ||
230 | - # encoder_extended_attention_mask = (encoder_extended_attention_mask == | ||
231 | - # encoder_extended_attention_mask.transpose(-1, -2)) | ||
232 | - encoder_extended_attention_mask = encoder_extended_attention_mask.to( | ||
233 | - dtype=self.dtype | ||
234 | - ) # fp16 compatibility | ||
235 | - | ||
236 | - if self.dtype == torch.float16: | ||
237 | - encoder_extended_attention_mask = ( | ||
238 | - 1.0 - encoder_extended_attention_mask | ||
239 | - ) * -1e4 | ||
240 | - elif self.dtype == torch.float32: | ||
241 | - encoder_extended_attention_mask = ( | ||
242 | - 1.0 - encoder_extended_attention_mask | ||
243 | - ) * -1e9 | ||
244 | - else: | ||
245 | - raise ValueError( | ||
246 | - "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format( | ||
247 | - self.dtype | ||
248 | - ) | ||
249 | - ) | ||
250 | - | ||
251 | - return encoder_extended_attention_mask | ||
252 | - | ||
253 | - def get_extended_attention_mask( | ||
254 | - self, attention_mask: Tensor, input_shape: Tuple[int], device: device | ||
255 | - ) -> Tensor: | ||
256 | - """ | ||
257 | - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. | ||
258 | - | ||
259 | - Arguments: | ||
260 | - attention_mask (:obj:`torch.Tensor`): | ||
261 | - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. | ||
262 | - input_shape (:obj:`Tuple[int]`): | ||
263 | - The shape of the input to the model. | ||
264 | - device (:obj:`torch.device`): | ||
265 | - The device of the input to the model. | ||
266 | - | ||
267 | - Returns: | ||
268 | - :obj:`torch.Tensor`: The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`. | ||
269 | - """ | ||
270 | - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] | ||
271 | - # ourselves in which case we just need to make it broadcastable to all heads. | ||
272 | - if attention_mask.dim() == 3: | ||
273 | - extended_attention_mask = attention_mask[:, None, :, :] | ||
274 | - elif attention_mask.dim() == 2: | ||
275 | - # Provided a padding mask of dimensions [batch_size, seq_length] | ||
276 | - # - if the model is a decoder, apply a causal mask in addition to the padding mask | ||
277 | - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] | ||
278 | - if self.config.is_decoder: | ||
279 | - batch_size, seq_length = input_shape | ||
280 | - seq_ids = torch.arange(seq_length, device=device) | ||
281 | - causal_mask = ( | ||
282 | - seq_ids[None, None, :].repeat(batch_size, seq_length, 1) | ||
283 | - <= seq_ids[None, :, None] | ||
284 | - ) | ||
285 | - # causal and attention masks must have same type with pytorch version < 1.3 | ||
286 | - causal_mask = causal_mask.to(attention_mask.dtype) | ||
287 | - extended_attention_mask = ( | ||
288 | - causal_mask[:, None, :, :] * attention_mask[:, None, None, :] | ||
289 | - ) | ||
290 | - else: | ||
291 | - extended_attention_mask = attention_mask[:, None, None, :] | ||
292 | - else: | ||
293 | - raise ValueError( | ||
294 | - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( | ||
295 | - input_shape, attention_mask.shape | ||
296 | - ) | ||
297 | - ) | ||
298 | - | ||
299 | - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for | ||
300 | - # masked positions, this operation will create a tensor which is 0.0 for | ||
301 | - # positions we want to attend and -10000.0 for masked positions. | ||
302 | - # Since we are adding it to the raw scores before the softmax, this is | ||
303 | - # effectively the same as removing these entirely. | ||
304 | - extended_attention_mask = extended_attention_mask.to( | ||
305 | - dtype=self.dtype | ||
306 | - ) # fp16 compatibility | ||
307 | - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | ||
308 | - return extended_attention_mask | ||
309 | - | ||
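A small numeric sketch (illustration only, not part of the deleted file) of the encoder branch above: a 2D padding mask becomes a broadcastable additive mask whose masked positions are pushed to -10000 before the softmax:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                   # [batch_size, seq_length]
extended = attention_mask[:, None, None, :].to(torch.float32)   # [batch, 1, 1, seq_length]
extended = (1.0 - extended) * -10000.0
print(extended)  # attended positions stay 0., the padded position becomes -10000.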
310 | - def get_head_mask( | ||
311 | - self, | ||
312 | - head_mask: Optional[Tensor], | ||
313 | - num_hidden_layers: int, | ||
314 | - is_attention_chunked: bool = False, | ||
315 | - ) -> Tensor: | ||
316 | - """ | ||
317 | - Prepare the head mask if needed. | ||
318 | - | ||
319 | - Args: | ||
320 | - head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`): | ||
321 | - The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). | ||
322 | - num_hidden_layers (:obj:`int`): | ||
323 | - The number of hidden layers in the model. | ||
324 | - is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
325 | - Whether or not the attention scores are computed by chunks. | ||
326 | - | ||
327 | - Returns: | ||
328 | - :obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]` | ||
329 | - or list with :obj:`[None]` for each layer. | ||
330 | - """ | ||
331 | - if head_mask is not None: | ||
332 | - head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) | ||
333 | - if is_attention_chunked is True: | ||
334 | - head_mask = head_mask.unsqueeze(-1) | ||
335 | - else: | ||
336 | - head_mask = [None] * num_hidden_layers | ||
337 | - | ||
338 | - return head_mask | ||
339 | - | ||
340 | - def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): | ||
341 | - """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" | ||
342 | - if head_mask.dim() == 1: | ||
343 | - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) | ||
344 | - head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) | ||
345 | - elif head_mask.dim() == 2: | ||
346 | - head_mask = ( | ||
347 | - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) | ||
348 | - ) # We can specify head_mask for each layer | ||
349 | - assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" | ||
350 | - head_mask = head_mask.to( | ||
351 | - dtype=self.dtype | ||
352 | - ) # switch to float if needed + fp16 compatibility | ||
353 | - return head_mask | ||
354 | - | ||
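A short sketch (illustration only, sizes made up) of the broadcasting done by `_convert_head_mask_to_5d` above: a per-head 1D mask is expanded to one mask per layer that can multiply the attention scores:

import torch

num_hidden_layers, num_heads = 2, 4
head_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])                  # [num_heads], 0.0 = discard head
head_mask = head_mask[None, None, :, None, None]                # [1, 1, num_heads, 1, 1]
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
print(head_mask.shape)  # torch.Size([2, 1, 4, 1, 1])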
355 | - | ||
356 | -class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): | ||
357 | - r""" | ||
358 | - Base class for all models. | ||
359 | - | ||
360 | - :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods | ||
361 | - for loading, downloading and saving models as well as a few methods common to all models to: | ||
362 | - | ||
363 | - * resize the input embeddings, | ||
364 | - * prune heads in the self-attention heads. | ||
365 | - | ||
366 | - Class attributes (overridden by derived classes): | ||
367 | - - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of | ||
368 | - :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. | ||
369 | - - **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a | ||
370 | - PyTorch model, taking as arguments: | ||
371 | - | ||
372 | - - **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the | ||
373 | - TensorFlow checkpoint. | ||
374 | - - **config** (:class:`~transformers.PreTrainedConfig`) -- An instance of the configuration associated | ||
375 | - to the model. | ||
376 | - - **path** (:obj:`str`) -- A path to the TensorFlow checkpoint. | ||
377 | - | ||
378 | - - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated with the base model in | ||
379 | - derived classes of the same architecture adding modules on top of the base model. | ||
380 | - - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of regular expression patterns of tensor names to ignore | ||
381 | - when loading the model (to avoid unnecessary warnings). | ||
382 | - """ | ||
383 | - config_class = None | ||
384 | - base_model_prefix = "" | ||
385 | - authorized_missing_keys = None | ||
386 | - | ||
387 | - @property | ||
388 | - def dummy_inputs(self) -> Dict[str, torch.Tensor]: | ||
389 | - """ | ||
390 | - :obj:`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network. | ||
391 | - """ | ||
392 | - return {"input_ids": torch.tensor(DUMMY_INPUTS)} | ||
393 | - | ||
394 | - def __init__(self, config: PretrainedConfig, *inputs, **kwargs): | ||
395 | - super().__init__() | ||
396 | - if not isinstance(config, PretrainedConfig): | ||
397 | - raise ValueError( | ||
398 | - "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. " | ||
399 | - "To create a model from a pretrained model use " | ||
400 | - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( | ||
401 | - self.__class__.__name__, self.__class__.__name__ | ||
402 | - ) | ||
403 | - ) | ||
404 | - # Save config in model | ||
405 | - self.config = config | ||
406 | - | ||
407 | - @property | ||
408 | - def base_model(self) -> nn.Module: | ||
409 | - """ | ||
410 | - :obj:`torch.nn.Module`: The main body of the model. | ||
411 | - """ | ||
412 | - return getattr(self, self.base_model_prefix, self) | ||
413 | - | ||
414 | - def get_input_embeddings(self) -> nn.Module: | ||
415 | - """ | ||
416 | - Returns the model's input embeddings. | ||
417 | - | ||
418 | - Returns: | ||
419 | - :obj:`nn.Module`: A torch module mapping vocabulary to hidden states. | ||
420 | - """ | ||
421 | - base_model = getattr(self, self.base_model_prefix, self) | ||
422 | - if base_model is not self: | ||
423 | - return base_model.get_input_embeddings() | ||
424 | - else: | ||
425 | - raise NotImplementedError | ||
426 | - | ||
427 | - def set_input_embeddings(self, value: nn.Module): | ||
428 | - """ | ||
429 | - Set the model's input embeddings. | ||
430 | - | ||
431 | - Args: | ||
432 | - value (:obj:`nn.Module`): A module mapping vocabulary to hidden states. | ||
433 | - """ | ||
434 | - base_model = getattr(self, self.base_model_prefix, self) | ||
435 | - if base_model is not self: | ||
436 | - base_model.set_input_embeddings(value) | ||
437 | - else: | ||
438 | - raise NotImplementedError | ||
439 | - | ||
440 | - def get_output_embeddings(self) -> nn.Module: | ||
441 | - """ | ||
442 | - Returns the model's output embeddings. | ||
443 | - | ||
444 | - Returns: | ||
445 | - :obj:`nn.Module`: A torch module mapping hidden states to vocabulary. | ||
446 | - """ | ||
447 | - return None # Overwrite for models with output embeddings | ||
448 | - | ||
449 | - def tie_weights(self): | ||
450 | - """ | ||
451 | - Tie the weights between the input embeddings and the output embeddings. | ||
452 | - | ||
453 | - If the :obj:`torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing, so | ||
454 | - we clone the weights instead. | ||
455 | - """ | ||
456 | - output_embeddings = self.get_output_embeddings() | ||
457 | - if output_embeddings is not None and self.config.tie_word_embeddings: | ||
458 | - self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) | ||
459 | - | ||
460 | - if self.config.is_encoder_decoder and self.config.tie_encoder_decoder: | ||
461 | - self._tie_encoder_decoder_weights( | ||
462 | - self.encoder, self.decoder, self.base_model_prefix | ||
463 | - ) | ||
464 | - | ||
465 | - @staticmethod | ||
466 | - def _tie_encoder_decoder_weights( | ||
467 | - encoder: nn.Module, decoder: nn.Module, base_model_prefix: str | ||
468 | - ): | ||
469 | - uninitialized_encoder_weights: List[str] = [] | ||
470 | - assert ( | ||
471 | - decoder.__class__ == encoder.__class__ | ||
472 | - ), f"{decoder.__class__} and {encoder.__class__} have to be equal." | ||
473 | - | ||
474 | - def tie_encoder_to_decoder_recursively( | ||
475 | - decoder_pointer: nn.Module, | ||
476 | - encoder_pointer: nn.Module, | ||
477 | - module_name: str, | ||
478 | - uninitialized_encoder_weights: List[str], | ||
479 | - depth=0, | ||
480 | - ): | ||
481 | - assert isinstance(decoder_pointer, nn.Module) and isinstance( | ||
482 | - encoder_pointer, nn.Module | ||
483 | - ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" | ||
484 | - if hasattr(decoder_pointer, "weight"): | ||
485 | - assert hasattr(encoder_pointer, "weight") | ||
486 | - encoder_pointer.weight = decoder_pointer.weight | ||
487 | - if hasattr(decoder_pointer, "bias"): | ||
488 | - assert hasattr(encoder_pointer, "bias") | ||
489 | - encoder_pointer.bias = decoder_pointer.bias | ||
490 | - return | ||
491 | - | ||
492 | - encoder_modules = encoder_pointer._modules | ||
493 | - decoder_modules = decoder_pointer._modules | ||
494 | - if len(decoder_modules) > 0: | ||
495 | - assert ( | ||
496 | - len(encoder_modules) > 0 | ||
497 | - ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" | ||
498 | - | ||
499 | - all_encoder_weights = set( | ||
500 | - [ | ||
501 | - module_name + "/" + sub_name | ||
502 | - for sub_name in encoder_modules.keys() | ||
503 | - ] | ||
504 | - ) | ||
505 | - encoder_layer_pos = 0 | ||
506 | - for name, module in decoder_modules.items(): | ||
507 | - if name.isdigit(): | ||
508 | - encoder_name = str(int(name) + encoder_layer_pos) | ||
509 | - decoder_name = name | ||
510 | - if not isinstance( | ||
511 | - decoder_modules[decoder_name], | ||
512 | - type(encoder_modules[encoder_name]), | ||
513 | - ): | ||
514 | - # this can happen if the name corresponds to a position in a ModuleList of layers | ||
515 | - # in this case the decoder has added a cross-attention layer that the encoder does not have | ||
516 | - # thus skip this step and subtract one layer position from the encoder | ||
517 | - encoder_layer_pos -= 1 | ||
518 | - continue | ||
519 | - elif name not in encoder_modules: | ||
520 | - continue | ||
521 | - elif depth > 500: | ||
522 | - raise ValueError( | ||
523 | - "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." | ||
524 | - ) | ||
525 | - else: | ||
526 | - decoder_name = encoder_name = name | ||
527 | - tie_encoder_to_decoder_recursively( | ||
528 | - decoder_modules[decoder_name], | ||
529 | - encoder_modules[encoder_name], | ||
530 | - module_name + "/" + name, | ||
531 | - uninitialized_encoder_weights, | ||
532 | - depth=depth + 1, | ||
533 | - ) | ||
534 | - all_encoder_weights.remove(module_name + "/" + encoder_name) | ||
535 | - | ||
536 | - uninitialized_encoder_weights += list(all_encoder_weights) | ||
537 | - | ||
538 | - # tie weights recursively | ||
539 | - tie_encoder_to_decoder_recursively( | ||
540 | - decoder, encoder, base_model_prefix, uninitialized_encoder_weights | ||
541 | - ) | ||
542 | - if len(uninitialized_encoder_weights) > 0: | ||
543 | - logger.warning( | ||
544 | - f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | ||
545 | - ) | ||
546 | - | ||
547 | - def _tie_or_clone_weights(self, output_embeddings, input_embeddings): | ||
548 | - """Tie or clone module weights depending of whether we are using TorchScript or not""" | ||
549 | - if self.config.torchscript: | ||
550 | - output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) | ||
551 | - else: | ||
552 | - output_embeddings.weight = input_embeddings.weight | ||
553 | - | ||
554 | - if getattr(output_embeddings, "bias", None) is not None: | ||
555 | - output_embeddings.bias.data = torch.nn.functional.pad( | ||
556 | - output_embeddings.bias.data, | ||
557 | - ( | ||
558 | - 0, | ||
559 | - output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], | ||
560 | - ), | ||
561 | - "constant", | ||
562 | - 0, | ||
563 | - ) | ||
564 | - if hasattr(output_embeddings, "out_features") and hasattr( | ||
565 | - input_embeddings, "num_embeddings" | ||
566 | - ): | ||
567 | - output_embeddings.out_features = input_embeddings.num_embeddings | ||
568 | - | ||
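A tiny sketch (illustration only, module sizes made up) of the non-TorchScript branch above: tying assigns the same `nn.Parameter` object to both modules, so updating one updates both:

import torch.nn as nn

input_embeddings = nn.Embedding(10, 8)
output_embeddings = nn.Linear(8, 10, bias=False)
output_embeddings.weight = input_embeddings.weight  # tie: share one Parameter
assert output_embeddings.weight is input_embeddings.weight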
569 | - def resize_token_embeddings( | ||
570 | - self, new_num_tokens: Optional[int] = None | ||
571 | - ) -> torch.nn.Embedding: | ||
572 | - """ | ||
573 | - Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. | ||
574 | - | ||
575 | - Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method. | ||
576 | - | ||
577 | - Arguments: | ||
578 | - new_num_tokens (:obj:`int`, `optional`): | ||
579 | - The number of new tokens in the embedding matrix. Increasing the size will add newly initialized | ||
580 | - vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`, | ||
581 | - just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model without doing | ||
582 | - anything. | ||
583 | - | ||
584 | - Return: | ||
585 | - :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. | ||
586 | - """ | ||
587 | - base_model = getattr( | ||
588 | - self, self.base_model_prefix, self | ||
589 | - ) # get the base model if needed | ||
590 | - model_embeds = base_model._resize_token_embeddings(new_num_tokens) | ||
591 | - if new_num_tokens is None: | ||
592 | - return model_embeds | ||
593 | - | ||
594 | - # Update base model and current model config | ||
595 | - self.config.vocab_size = new_num_tokens | ||
596 | - base_model.vocab_size = new_num_tokens | ||
597 | - | ||
598 | - # Tie weights again if needed | ||
599 | - self.tie_weights() | ||
600 | - | ||
601 | - return model_embeds | ||
602 | - | ||
603 | - def _resize_token_embeddings(self, new_num_tokens): | ||
604 | - old_embeddings = self.get_input_embeddings() | ||
605 | - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) | ||
606 | - self.set_input_embeddings(new_embeddings) | ||
607 | - return self.get_input_embeddings() | ||
608 | - | ||
609 | - def _get_resized_embeddings( | ||
610 | - self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None | ||
611 | - ) -> torch.nn.Embedding: | ||
612 | - """ | ||
613 | - Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly | ||
614 | - initialized vectors at the end. Reducing the size will remove vectors from the end. | ||
615 | - | ||
616 | - Args: | ||
617 | - old_embeddings (:obj:`torch.nn.Embedding`): | ||
618 | - Old embeddings to be resized. | ||
619 | - new_num_tokens (:obj:`int`, `optional`): | ||
620 | - New number of tokens in the embedding matrix. | ||
621 | - | ||
622 | - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove | ||
623 | - vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens | ||
624 | - :obj:`torch.nn.Embedding` module of the model without doing anything. | ||
625 | - | ||
626 | - Return: | ||
627 | - :obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if | ||
628 | - :obj:`new_num_tokens` is :obj:`None` | ||
629 | - """ | ||
630 | - if new_num_tokens is None: | ||
631 | - return old_embeddings | ||
632 | - | ||
633 | - old_num_tokens, old_embedding_dim = old_embeddings.weight.size() | ||
634 | - if old_num_tokens == new_num_tokens: | ||
635 | - return old_embeddings | ||
636 | - | ||
637 | - # Build new embeddings | ||
638 | - new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) | ||
639 | - new_embeddings.to(old_embeddings.weight.device) | ||
640 | - | ||
641 | - # initialize all new embeddings (in particular added tokens) | ||
642 | - self._init_weights(new_embeddings) | ||
643 | - | ||
644 | - # Copy token embeddings from the previous weights | ||
645 | - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) | ||
646 | - new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[ | ||
647 | - :num_tokens_to_copy, : | ||
648 | - ] | ||
649 | - | ||
650 | - return new_embeddings | ||
651 | - | ||
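A rough sketch (illustration only, sizes made up) of the copy step in `_get_resized_embeddings` above: build a larger embedding matrix, then copy the overlapping rows from the old one:

import torch.nn as nn

old = nn.Embedding(10, 8)   # old vocabulary of 10 tokens, hidden size 8
new = nn.Embedding(12, 8)   # grow the vocabulary to 12 tokens
num_to_copy = min(old.num_embeddings, new.num_embeddings)
new.weight.data[:num_to_copy, :] = old.weight.data[:num_to_copy, :]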
652 | - def init_weights(self): | ||
653 | - """ | ||
654 | - Initializes and prunes weights if needed. | ||
655 | - """ | ||
656 | - # Initialize weights | ||
657 | - self.apply(self._init_weights) | ||
658 | - | ||
659 | - # Prune heads if needed | ||
660 | - if self.config.pruned_heads: | ||
661 | - self.prune_heads(self.config.pruned_heads) | ||
662 | - | ||
663 | - # Tie weights if needed | ||
664 | - self.tie_weights() | ||
665 | - | ||
666 | - def prune_heads(self, heads_to_prune: Dict[int, List[int]]): | ||
667 | - """ | ||
668 | - Prunes heads of the base model. | ||
669 | - | ||
670 | - Arguments: | ||
671 | - heads_to_prune (:obj:`Dict[int, List[int]]`): | ||
672 | - Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list | ||
673 | - of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will | ||
674 | - prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. | ||
675 | - """ | ||
676 | - # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads | ||
677 | - for layer, heads in heads_to_prune.items(): | ||
678 | - union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) | ||
679 | - self.config.pruned_heads[layer] = list( | ||
680 | - union_heads | ||
681 | - ) # Unfortunately we have to store it as list for JSON | ||
682 | - | ||
683 | - self.base_model._prune_heads(heads_to_prune) | ||
684 | - | ||
685 | - def save_pretrained(self, save_directory): | ||
686 | - """ | ||
687 | - Save a model and its configuration file to a directory, so that it can be re-loaded using the | ||
688 | - :func:`~transformers.PreTrainedModel.from_pretrained` class method. | ||
689 | - | ||
690 | - Arguments: | ||
691 | - save_directory (:obj:`str`): | ||
692 | - Directory to which to save. Will be created if it doesn't exist. | ||
693 | - """ | ||
694 | - if os.path.isfile(save_directory): | ||
695 | - logger.error( | ||
696 | - "Provided path ({}) should be a directory, not a file".format( | ||
697 | - save_directory | ||
698 | - ) | ||
699 | - ) | ||
700 | - return | ||
701 | - os.makedirs(save_directory, exist_ok=True) | ||
702 | - | ||
703 | - # Only save the model itself if we are using distributed training | ||
704 | - model_to_save = self.module if hasattr(self, "module") else self | ||
705 | - | ||
706 | - # Attach architecture to the config | ||
707 | - model_to_save.config.architectures = [model_to_save.__class__.__name__] | ||
708 | - | ||
709 | - # If we save using the predefined names, we can load using `from_pretrained` | ||
710 | - output_model_file = os.path.join(save_directory, WEIGHTS_NAME) | ||
711 | - | ||
712 | - if getattr(self.config, "xla_device", False): | ||
713 | - import torch_xla.core.xla_model as xm | ||
714 | - | ||
715 | - if xm.is_master_ordinal(): | ||
716 | - # Save configuration file | ||
717 | - model_to_save.config.save_pretrained(save_directory) | ||
718 | - # xm.save takes care of saving only from master | ||
719 | - xm.save(model_to_save.state_dict(), output_model_file) | ||
720 | - else: | ||
721 | - model_to_save.config.save_pretrained(save_directory) | ||
722 | - torch.save(model_to_save.state_dict(), output_model_file) | ||
723 | - | ||
724 | - logger.info("Model weights saved in {}".format(output_model_file)) | ||
725 | - | ||
726 | - @classmethod | ||
727 | - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | ||
728 | - r""" | ||
729 | - Instantiate a pretrained pytorch model from a pre-trained model configuration. | ||
730 | - | ||
731 | - The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). | ||
732 | - To train the model, you should first set it back in training mode with ``model.train()``. | ||
733 | - | ||
734 | - The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come | ||
735 | - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning | ||
736 | - task. | ||
737 | - | ||
738 | - The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those | ||
739 | - weights are discarded. | ||
740 | - | ||
741 | - Parameters: | ||
742 | - pretrained_model_name_or_path (:obj:`str`, `optional`): | ||
743 | - Can be either: | ||
744 | - | ||
745 | - - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g., | ||
746 | - ``bert-base-uncased``. | ||
747 | - - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g., | ||
748 | - ``dbmdz/bert-base-german-cased``. | ||
749 | - - A path to a `directory` containing model weights saved using | ||
750 | - :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. | ||
751 | - - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In | ||
752 | - this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided | ||
753 | - as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in | ||
754 | - a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. | ||
755 | - - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword | ||
756 | - arguments ``config`` and ``state_dict``). | ||
757 | - model_args (sequence of positional arguments, `optional`): | ||
758 | - All remaining positional arguments will be passed to the underlying model's ``__init__`` method. | ||
759 | - config (:obj:`Union[PretrainedConfig, str]`, `optional`): | ||
760 | - Can be either: | ||
761 | - | ||
762 | - - an instance of a class derived from :class:`~transformers.PretrainedConfig`, | ||
763 | - - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`. | ||
764 | - | ||
765 | - Configuration for the model to use instead of an automatically loaded configuration. Configuration can | ||
766 | - be automatically loaded when: | ||
767 | - | ||
768 | - - The model is a model provided by the library (loaded with the `shortcut name` string of a | ||
769 | - pretrained model). | ||
770 | - - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded | ||
771 | - by suppling the save directory. | ||
772 | - - The model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a | ||
773 | - configuration JSON file named `config.json` is found in the directory. | ||
774 | - state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`): | ||
775 | - A state dictionary to use instead of a state dictionary loaded from saved weights file. | ||
776 | - | ||
777 | - This option can be used if you want to create a model from a pretrained configuration but load your own | ||
778 | - weights. In this case though, you should check if using | ||
779 | - :func:`~transformers.PreTrainedModel.save_pretrained` and | ||
780 | - :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. | ||
781 | - cache_dir (:obj:`str`, `optional`): | ||
782 | - Path to a directory in which a downloaded pretrained model configuration should be cached if the | ||
783 | - standard cache should not be used. | ||
784 | - from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
785 | - Load the model weights from a TensorFlow checkpoint save file (see docstring of | ||
786 | - ``pretrained_model_name_or_path`` argument). | ||
787 | - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
788 | - Whether or not to force the (re-)download of the model weights and configuration files, overriding the | ||
789 | - cached versions if they exist. | ||
790 | - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
791 | - Whether or not to delete incompletely received files. Will attempt to resume the download if such a | ||
792 | - file exists. | ||
793 | - proxies (:obj:`Dict[str, str]`, `optional`): | ||
794 | - A dictionary of proxy servers to use by protocol or endpoint, e.g., | ||
795 | - :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each | ||
796 | - request. | ||
797 | - output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
798 | - Whether or not to also return a dictionary containing missing keys, unexpected keys and error | ||
799 | - messages. | ||
800 | - local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
801 | - Whether or not to only look at local files (i.e., do not try to download the model). | ||
802 | - use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): | ||
803 | - Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on | ||
804 | - our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. | ||
805 | - kwargs (remaining dictionary of keyword arguments, `optional`): | ||
806 | - Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., | ||
807 | - :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or | ||
808 | - automatically loaded: | ||
809 | - | ||
810 | - - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the | ||
811 | - underlying model's ``__init__`` method (we assume all relevant updates to the configuration have | ||
812 | - already been done) | ||
813 | - - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class | ||
814 | - initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of | ||
815 | - ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute | ||
816 | - with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration | ||
817 | - attribute will be passed to the underlying model's ``__init__`` function. | ||
818 | - | ||
819 | - Examples:: | ||
820 | - | ||
821 | - from transformers import BertConfig, BertModel | ||
822 | - # Download model and configuration from S3 and cache. | ||
823 | - model = BertModel.from_pretrained('bert-base-uncased') | ||
824 | - # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable). | ||
825 | - model = BertModel.from_pretrained('./test/saved_model/') | ||
826 | - # Update configuration during loading. | ||
827 | - model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True) | ||
828 | - assert model.config.output_attentions == True | ||
829 | - # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). | ||
830 | - config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') | ||
831 | - model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) | ||
832 | - """ | ||
833 | - config = kwargs.pop("config", None) | ||
834 | - state_dict = kwargs.pop("state_dict", None) | ||
835 | - cache_dir = kwargs.pop("cache_dir", None) | ||
836 | - from_tf = kwargs.pop("from_tf", False) | ||
837 | - force_download = kwargs.pop("force_download", False) | ||
838 | - resume_download = kwargs.pop("resume_download", False) | ||
839 | - proxies = kwargs.pop("proxies", None) | ||
840 | - output_loading_info = kwargs.pop("output_loading_info", False) | ||
841 | - local_files_only = kwargs.pop("local_files_only", False) | ||
842 | - use_cdn = kwargs.pop("use_cdn", True) | ||
843 | - | ||
844 | - # Load config if we don't provide a configuration | ||
845 | - if not isinstance(config, PretrainedConfig): | ||
846 | - config_path = ( | ||
847 | - config if config is not None else pretrained_model_name_or_path | ||
848 | - ) | ||
849 | - config, model_kwargs = cls.config_class.from_pretrained( | ||
850 | - config_path, | ||
851 | - *model_args, | ||
852 | - cache_dir=cache_dir, | ||
853 | - return_unused_kwargs=True, | ||
854 | - force_download=force_download, | ||
855 | - resume_download=resume_download, | ||
856 | - proxies=proxies, | ||
857 | - local_files_only=local_files_only, | ||
858 | - **kwargs, | ||
859 | - ) | ||
860 | - else: | ||
861 | - model_kwargs = kwargs | ||
862 | - | ||
863 | - # Load model | ||
864 | - if pretrained_model_name_or_path is not None: | ||
865 | - if os.path.isdir(pretrained_model_name_or_path): | ||
866 | - if from_tf and os.path.isfile( | ||
867 | - os.path.join( | ||
868 | - pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" | ||
869 | - ) | ||
870 | - ): | ||
871 | - # Load from a TF 1.0 checkpoint | ||
872 | - archive_file = os.path.join( | ||
873 | - pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" | ||
874 | - ) | ||
875 | - elif from_tf and os.path.isfile( | ||
876 | - os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) | ||
877 | - ): | ||
878 | - # Load from a TF 2.0 checkpoint | ||
879 | - archive_file = os.path.join( | ||
880 | - pretrained_model_name_or_path, TF2_WEIGHTS_NAME | ||
881 | - ) | ||
882 | - elif os.path.isfile( | ||
883 | - os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) | ||
884 | - ): | ||
885 | - # Load from a PyTorch checkpoint | ||
886 | - archive_file = os.path.join( | ||
887 | - pretrained_model_name_or_path, WEIGHTS_NAME | ||
888 | - ) | ||
889 | - else: | ||
890 | - raise EnvironmentError( | ||
891 | - "Error no file named {} found in directory {} or `from_tf` set to False".format( | ||
892 | - [ | ||
893 | - WEIGHTS_NAME, | ||
894 | - TF2_WEIGHTS_NAME, | ||
895 | - TF_WEIGHTS_NAME + ".index", | ||
896 | - ], | ||
897 | - pretrained_model_name_or_path, | ||
898 | - ) | ||
899 | - ) | ||
900 | - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url( | ||
901 | - pretrained_model_name_or_path | ||
902 | - ): | ||
903 | - archive_file = pretrained_model_name_or_path | ||
904 | - elif os.path.isfile(pretrained_model_name_or_path + ".index"): | ||
905 | - assert ( | ||
906 | - from_tf | ||
907 | - ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( | ||
908 | - pretrained_model_name_or_path + ".index" | ||
909 | - ) | ||
910 | - archive_file = pretrained_model_name_or_path + ".index" | ||
911 | - else: | ||
912 | - archive_file = hf_bucket_url( | ||
913 | - pretrained_model_name_or_path, | ||
914 | - filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), | ||
915 | - use_cdn=use_cdn, | ||
916 | - ) | ||
917 | - | ||
918 | - try: | ||
919 | - # Load from URL or cache if already cached | ||
920 | - resolved_archive_file = cached_path( | ||
921 | - archive_file, | ||
922 | - cache_dir=cache_dir, | ||
923 | - force_download=force_download, | ||
924 | - proxies=proxies, | ||
925 | - resume_download=resume_download, | ||
926 | - local_files_only=local_files_only, | ||
927 | - ) | ||
928 | - if resolved_archive_file is None: | ||
929 | - raise EnvironmentError | ||
930 | - except EnvironmentError: | ||
931 | - msg = ( | ||
932 | - f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" | ||
933 | - f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" | ||
934 | - f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n" | ||
935 | - ) | ||
936 | - raise EnvironmentError(msg) | ||
937 | - | ||
938 | - if resolved_archive_file == archive_file: | ||
939 | - logger.info("loading weights file {}".format(archive_file)) | ||
940 | - else: | ||
941 | - logger.info( | ||
942 | - "loading weights file {} from cache at {}".format( | ||
943 | - archive_file, resolved_archive_file | ||
944 | - ) | ||
945 | - ) | ||
946 | - else: | ||
947 | - resolved_archive_file = None | ||
948 | - | ||
949 | - # Instantiate model. | ||
950 | - model = cls(config, *model_args, **model_kwargs) | ||
951 | - | ||
952 | - if state_dict is None and not from_tf: | ||
953 | - try: | ||
954 | - state_dict = torch.load(resolved_archive_file, map_location="cpu") | ||
955 | - except Exception: | ||
956 | - raise OSError( | ||
957 | - "Unable to load weights from pytorch checkpoint file. " | ||
958 | - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " | ||
959 | - ) | ||
960 | - | ||
961 | - missing_keys = [] | ||
962 | - unexpected_keys = [] | ||
963 | - error_msgs = [] | ||
964 | - | ||
965 | - if from_tf: | ||
966 | - if resolved_archive_file.endswith(".index"): | ||
967 | - # Load from a TensorFlow 1.X checkpoint - provided by original authors | ||
968 | - model = cls.load_tf_weights( | ||
969 | - model, config, resolved_archive_file[:-6] | ||
970 | - ) # Remove the '.index' | ||
971 | - else: | ||
972 | - # Load from our TensorFlow 2.0 checkpoints | ||
973 | - try: | ||
974 | - from transformers import load_tf2_checkpoint_in_pytorch_model | ||
975 | - | ||
976 | - model = load_tf2_checkpoint_in_pytorch_model( | ||
977 | - model, resolved_archive_file, allow_missing_keys=True | ||
978 | - ) | ||
979 | - except ImportError: | ||
980 | - logger.error( | ||
981 | - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see " | ||
982 | - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." | ||
983 | - ) | ||
984 | - raise | ||
985 | - else: | ||
986 | - # Convert old format to new format if needed from a PyTorch state_dict | ||
987 | - old_keys = [] | ||
988 | - new_keys = [] | ||
989 | - for key in state_dict.keys(): | ||
990 | - new_key = None | ||
991 | - if "gamma" in key: | ||
992 | - new_key = key.replace("gamma", "weight") | ||
993 | - if "beta" in key: | ||
994 | - new_key = key.replace("beta", "bias") | ||
995 | - if new_key: | ||
996 | - old_keys.append(key) | ||
997 | - new_keys.append(new_key) | ||
998 | - for old_key, new_key in zip(old_keys, new_keys): | ||
999 | - state_dict[new_key] = state_dict.pop(old_key) | ||
1000 | - | ||
1001 | - # copy state_dict so _load_from_state_dict can modify it | ||
1002 | - metadata = getattr(state_dict, "_metadata", None) | ||
1003 | - state_dict = state_dict.copy() | ||
1004 | - if metadata is not None: | ||
1005 | - state_dict._metadata = metadata | ||
1006 | - | ||
1007 | - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants | ||
1008 | - # so we need to apply the function recursively. | ||
1009 | - def load(module: nn.Module, prefix=""): | ||
1010 | - local_metadata = ( | ||
1011 | - {} if metadata is None else metadata.get(prefix[:-1], {}) | ||
1012 | - ) | ||
1013 | - module._load_from_state_dict( | ||
1014 | - state_dict, | ||
1015 | - prefix, | ||
1016 | - local_metadata, | ||
1017 | - True, | ||
1018 | - missing_keys, | ||
1019 | - unexpected_keys, | ||
1020 | - error_msgs, | ||
1021 | - ) | ||
1022 | - for name, child in module._modules.items(): | ||
1023 | - if child is not None: | ||
1024 | - load(child, prefix + name + ".") | ||
1025 | - | ||
1026 | - # Make sure we are able to load base models as well as derived models (with heads) | ||
1027 | - start_prefix = "" | ||
1028 | - model_to_load = model | ||
1029 | - has_prefix_module = any( | ||
1030 | - s.startswith(cls.base_model_prefix) for s in state_dict.keys() | ||
1031 | - ) | ||
1032 | - if not hasattr(model, cls.base_model_prefix) and has_prefix_module: | ||
1033 | - start_prefix = cls.base_model_prefix + "." | ||
1034 | - if hasattr(model, cls.base_model_prefix) and not has_prefix_module: | ||
1035 | - model_to_load = getattr(model, cls.base_model_prefix) | ||
1036 | - | ||
1037 | - load(model_to_load, prefix=start_prefix) | ||
1038 | - | ||
1039 | - if model.__class__.__name__ != model_to_load.__class__.__name__: | ||
1040 | - base_model_state_dict = model_to_load.state_dict().keys() | ||
1041 | - head_model_state_dict_without_base_prefix = [ | ||
1042 | - key.split(cls.base_model_prefix + ".")[-1] | ||
1043 | - for key in model.state_dict().keys() | ||
1044 | - ] | ||
1045 | - missing_keys.extend( | ||
1046 | - head_model_state_dict_without_base_prefix - base_model_state_dict | ||
1047 | - ) | ||
1048 | - | ||
1049 | - # Some models may have keys that are not in the state by design, removing them before needlessly warning | ||
1050 | - # the user. | ||
1051 | - if cls.authorized_missing_keys is not None: | ||
1052 | - for pat in cls.authorized_missing_keys: | ||
1053 | - missing_keys = [ | ||
1054 | - k for k in missing_keys if re.search(pat, k) is None | ||
1055 | - ] | ||
1056 | - | ||
1057 | - if len(unexpected_keys) > 0: | ||
1058 | - logger.warning( | ||
1059 | - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | ||
1060 | - f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | ||
1061 | - f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | ||
1062 | - f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n" | ||
1063 | - f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " | ||
1064 | - f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." | ||
1065 | - ) | ||
1066 | - else: | ||
1067 | - logger.info( | ||
1068 | - f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" | ||
1069 | - ) | ||
1070 | - if len(missing_keys) > 0: | ||
1071 | - logger.warning( | ||
1072 | - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " | ||
1073 | - f"and are newly initialized: {missing_keys}\n" | ||
1074 | - f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." | ||
1075 | - ) | ||
1076 | - else: | ||
1077 | - logger.info( | ||
1078 | - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" | ||
1079 | - f"If your task is similar to the task the model of the checkpoint was trained on, " | ||
1080 | - f"you can already use {model.__class__.__name__} for predictions without further training." | ||
1081 | - ) | ||
1082 | - if len(error_msgs) > 0: | ||
1083 | - raise RuntimeError( | ||
1084 | - "Error(s) in loading state_dict for {}:\n\t{}".format( | ||
1085 | - model.__class__.__name__, "\n\t".join(error_msgs) | ||
1086 | - ) | ||
1087 | - ) | ||
1088 | - # make sure token embedding weights are still tied if needed | ||
1089 | - model.tie_weights() | ||
1090 | - | ||
1091 | - # Set model in evaluation mode to deactivate DropOut modules by default | ||
1092 | - model.eval() | ||
1093 | - | ||
1094 | - if output_loading_info: | ||
1095 | - loading_info = { | ||
1096 | - "missing_keys": missing_keys, | ||
1097 | - "unexpected_keys": unexpected_keys, | ||
1098 | - "error_msgs": error_msgs, | ||
1099 | - } | ||
1100 | - return model, loading_info | ||
1101 | - | ||
1102 | - if ( | ||
1103 | - hasattr(config, "xla_device") | ||
1104 | - and config.xla_device | ||
1105 | - and is_torch_tpu_available() | ||
1106 | - ): | ||
1107 | - import torch_xla.core.xla_model as xm | ||
1108 | - | ||
1109 | - model = xm.send_cpu_data_to_device(model, xm.xla_device()) | ||
1110 | - model.to(xm.xla_device()) | ||
1111 | - | ||
1112 | - return model | ||
1113 | - | ||
1114 | - | ||
1115 | -class Conv1D(nn.Module): | ||
1116 | - """ | ||
1117 | - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). | ||
1118 | - | ||
1119 | - Basically works like a linear layer but the weights are transposed. | ||
1120 | - | ||
1121 | - Args: | ||
1122 | - nf (:obj:`int`): The number of output features. | ||
1123 | - nx (:obj:`int`): The number of input features. | ||
1124 | - """ | ||
1125 | - | ||
1126 | - def __init__(self, nf, nx): | ||
1127 | - super().__init__() | ||
1128 | - self.nf = nf | ||
1129 | - w = torch.empty(nx, nf) | ||
1130 | - nn.init.normal_(w, std=0.02) | ||
1131 | - self.weight = nn.Parameter(w) | ||
1132 | - self.bias = nn.Parameter(torch.zeros(nf)) | ||
1133 | - | ||
1134 | - def forward(self, x): | ||
1135 | - size_out = x.size()[:-1] + (self.nf,) | ||
1136 | - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) | ||
1137 | - x = x.view(*size_out) | ||
1138 | - return x | ||
1139 | - | ||
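A quick sketch (illustration only) checking the claim in the docstring above: `Conv1D` acts like a linear layer with a transposed weight matrix, i.e. `y = x @ W + b`. It assumes the `Conv1D` class defined above is in scope:

import torch

nf, nx = 6, 4
conv = Conv1D(nf, nx)
x = torch.randn(2, 3, nx)
assert torch.allclose(conv(x), x @ conv.weight + conv.bias, atol=1e-6)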
1140 | - | ||
1141 | -class PoolerStartLogits(nn.Module): | ||
1142 | - """ | ||
1143 | - Compute SQuAD start logits from sequence hidden states. | ||
1144 | - | ||
1145 | - Args: | ||
1146 | - config (:class:`~transformers.PretrainedConfig`): | ||
1147 | - The config used by the model, will be used to grab the :obj:`hidden_size` of the model. | ||
1148 | - """ | ||
1149 | - | ||
1150 | - def __init__(self, config: PretrainedConfig): | ||
1151 | - super().__init__() | ||
1152 | - self.dense = nn.Linear(config.hidden_size, 1) | ||
1153 | - | ||
1154 | - def forward( | ||
1155 | - self, | ||
1156 | - hidden_states: torch.FloatTensor, | ||
1157 | - p_mask: Optional[torch.FloatTensor] = None, | ||
1158 | - ) -> torch.FloatTensor: | ||
1159 | - """ | ||
1160 | - Args: | ||
1161 | - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1162 | - The final hidden states of the model. | ||
1163 | - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1164 | - Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). | ||
1165 | - 1.0 means token should be masked. | ||
1166 | - | ||
1167 | - Returns: | ||
1168 | - :obj:`torch.FloatTensor`: The start logits for SQuAD. | ||
1169 | - """ | ||
1170 | - x = self.dense(hidden_states).squeeze(-1) | ||
1171 | - | ||
1172 | - if p_mask is not None: | ||
1173 | - if next(self.parameters()).dtype == torch.float16: | ||
1174 | - x = x * (1 - p_mask) - 65500 * p_mask | ||
1175 | - else: | ||
1176 | - x = x * (1 - p_mask) - 1e30 * p_mask | ||
1177 | - | ||
1178 | - return x | ||
1179 | - | ||
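A numeric sketch (illustration only, not part of the deleted file) of the `p_mask` penalty used above: masked positions receive a huge negative logit so they contribute nothing after the softmax:

import torch

logits = torch.tensor([2.0, -1.0, 0.5])
p_mask = torch.tensor([0.0, 0.0, 1.0])   # 1.0 means the token should be masked
print(logits * (1 - p_mask) - 1e30 * p_mask)  # the masked position becomes -1e30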
1180 | - | ||
1181 | -class PoolerEndLogits(nn.Module): | ||
1182 | - """ | ||
1183 | - Compute SQuAD end logits from sequence hidden states. | ||
1184 | - | ||
1185 | - Args: | ||
1186 | - config (:class:`~transformers.PretrainedConfig`): | ||
1187 | - The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the | ||
1188 | - :obj:`layer_norm_eps` to use. | ||
1189 | - """ | ||
1190 | - | ||
1191 | - def __init__(self, config: PretrainedConfig): | ||
1192 | - super().__init__() | ||
1193 | - self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) | ||
1194 | - self.activation = nn.Tanh() | ||
1195 | - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||
1196 | - self.dense_1 = nn.Linear(config.hidden_size, 1) | ||
1197 | - | ||
1198 | - def forward( | ||
1199 | - self, | ||
1200 | - hidden_states: torch.FloatTensor, | ||
1201 | - start_states: Optional[torch.FloatTensor] = None, | ||
1202 | - start_positions: Optional[torch.LongTensor] = None, | ||
1203 | - p_mask: Optional[torch.FloatTensor] = None, | ||
1204 | - ) -> torch.FloatTensor: | ||
1205 | - """ | ||
1206 | - Args: | ||
1207 | - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1208 | - The final hidden states of the model. | ||
1209 | - start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): | ||
1210 | - The hidden states of the first tokens for the labeled span. | ||
1211 | - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1212 | - The position of the first token for the labeled span. | ||
1213 | - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1214 | - Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS). | ||
1215 | - 1.0 means token should be masked. | ||
1216 | - | ||
1217 | - .. note:: | ||
1218 | - | ||
1219 | - One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set, | ||
1220 | - ``start_positions`` overrides ``start_states``. | ||
1221 | - | ||
1222 | - Returns: | ||
1223 | - :obj:`torch.FloatTensor`: The end logits for SQuAD. | ||
1224 | - """ | ||
1225 | - assert ( | ||
1226 | - start_states is not None or start_positions is not None | ||
1227 | - ), "One of start_states, start_positions should be not None" | ||
1228 | - if start_positions is not None: | ||
1229 | - slen, hsz = hidden_states.shape[-2:] | ||
1230 | - start_positions = start_positions[:, None, None].expand( | ||
1231 | - -1, -1, hsz | ||
1232 | - ) # shape (bsz, 1, hsz) | ||
1233 | - start_states = hidden_states.gather( | ||
1234 | - -2, start_positions | ||
1235 | - ) # shape (bsz, 1, hsz) | ||
1236 | - start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) | ||
1237 | - | ||
1238 | - x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) | ||
1239 | - x = self.activation(x) | ||
1240 | - x = self.LayerNorm(x) | ||
1241 | - x = self.dense_1(x).squeeze(-1) | ||
1242 | - | ||
1243 | - if p_mask is not None: | ||
1244 | - if next(self.parameters()).dtype == torch.float16: | ||
1245 | - x = x * (1 - p_mask) - 65500 * p_mask | ||
1246 | - else: | ||
1247 | - x = x * (1 - p_mask) - 1e30 * p_mask | ||
1248 | - | ||
1249 | - return x | ||
1250 | - | ||
1251 | - | ||
1252 | -class PoolerAnswerClass(nn.Module): | ||
1253 | - """ | ||
1254 | - Compute SQuAD 2.0 answer class from classification and start tokens hidden states. | ||
1255 | - | ||
1256 | - Args: | ||
1257 | - config (:class:`~transformers.PretrainedConfig`): | ||
1258 | - The config used by the model, will be used to grab the :obj:`hidden_size` of the model. | ||
1259 | - """ | ||
1260 | - | ||
1261 | - def __init__(self, config): | ||
1262 | - super().__init__() | ||
1263 | - self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) | ||
1264 | - self.activation = nn.Tanh() | ||
1265 | - self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) | ||
1266 | - | ||
1267 | - def forward( | ||
1268 | - self, | ||
1269 | - hidden_states: torch.FloatTensor, | ||
1270 | - start_states: Optional[torch.FloatTensor] = None, | ||
1271 | - start_positions: Optional[torch.LongTensor] = None, | ||
1272 | - cls_index: Optional[torch.LongTensor] = None, | ||
1273 | - ) -> torch.FloatTensor: | ||
1274 | - """ | ||
1275 | - Args: | ||
1276 | - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1277 | - The final hidden states of the model. | ||
1278 | - start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`): | ||
1279 | - The hidden states of the first tokens for the labeled span. | ||
1280 | - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1281 | - The position of the first token for the labeled span. | ||
1282 | - cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1283 | - Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. | ||
1284 | - | ||
1285 | - .. note:: | ||
1286 | - | ||
1287 | - One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set, | ||
1288 | - ``start_positions`` overrides ``start_states``. | ||
1289 | - | ||
1290 | - Returns: | ||
1291 | - :obj:`torch.FloatTensor`: The SQuAD 2.0 answer class. | ||
1292 | - """ | ||
1293 | - # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. | ||
1294 | - hsz = hidden_states.shape[-1] | ||
1295 | - assert ( | ||
1296 | - start_states is not None or start_positions is not None | ||
1297 | - ), "One of start_states, start_positions should be not None" | ||
1298 | - if start_positions is not None: | ||
1299 | - start_positions = start_positions[:, None, None].expand( | ||
1300 | - -1, -1, hsz | ||
1301 | - ) # shape (bsz, 1, hsz) | ||
1302 | - start_states = hidden_states.gather(-2, start_positions).squeeze( | ||
1303 | - -2 | ||
1304 | - ) # shape (bsz, hsz) | ||
1305 | - | ||
1306 | - if cls_index is not None: | ||
1307 | - cls_index = cls_index[:, None, None].expand( | ||
1308 | - -1, -1, hsz | ||
1309 | - ) # shape (bsz, 1, hsz) | ||
1310 | - cls_token_state = hidden_states.gather(-2, cls_index).squeeze( | ||
1311 | - -2 | ||
1312 | - ) # shape (bsz, hsz) | ||
1313 | - else: | ||
1314 | - cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) | ||
1315 | - | ||
1316 | - x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) | ||
1317 | - x = self.activation(x) | ||
1318 | - x = self.dense_1(x).squeeze(-1) | ||
1319 | - | ||
1320 | - return x | ||
1321 | - | ||
1322 | - | ||
1323 | -@dataclass | ||
1324 | -class SquadHeadOutput(ModelOutput): | ||
1325 | - """ | ||
1326 | - Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`. | ||
1327 | - | ||
1328 | - Args: | ||
1329 | - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided): | ||
1330 | - Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. | ||
1331 | - start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1332 | - Log probabilities for the top config.start_n_top start token possibilities (beam-search). | ||
1333 | - start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1334 | - Indices for the top config.start_n_top start token possibilities (beam-search). | ||
1335 | - end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1336 | - Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). | ||
1337 | - end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1338 | - Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search). | ||
1339 | - cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): | ||
1340 | - Log probabilities for the ``is_impossible`` label of the answers. | ||
1341 | - | ||
1342 | - """ | ||
1343 | - | ||
1344 | - loss: Optional[torch.FloatTensor] = None | ||
1345 | - start_top_log_probs: Optional[torch.FloatTensor] = None | ||
1346 | - start_top_index: Optional[torch.LongTensor] = None | ||
1347 | - end_top_log_probs: Optional[torch.FloatTensor] = None | ||
1348 | - end_top_index: Optional[torch.LongTensor] = None | ||
1349 | - cls_logits: Optional[torch.FloatTensor] = None | ||
1350 | - | ||
1351 | - | ||
1352 | -class SQuADHead(nn.Module): | ||
1353 | - r""" | ||
1354 | - A SQuAD head inspired by XLNet. | ||
1355 | - | ||
1356 | - Args: | ||
1357 | - config (:class:`~transformers.PretrainedConfig`): | ||
1358 | - The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the | ||
1359 | - :obj:`layer_norm_eps` to use. | ||
1360 | - """ | ||
1361 | - | ||
1362 | - def __init__(self, config): | ||
1363 | - super().__init__() | ||
1364 | - self.start_n_top = config.start_n_top | ||
1365 | - self.end_n_top = config.end_n_top | ||
1366 | - | ||
1367 | - self.start_logits = PoolerStartLogits(config) | ||
1368 | - self.end_logits = PoolerEndLogits(config) | ||
1369 | - self.answer_class = PoolerAnswerClass(config) | ||
1370 | - | ||
1371 | - @replace_return_docstrings( | ||
1372 | - output_type=SquadHeadOutput, config_class=PretrainedConfig | ||
1373 | - ) | ||
1374 | - def forward( | ||
1375 | - self, | ||
1376 | - hidden_states: torch.FloatTensor, | ||
1377 | - start_positions: Optional[torch.LongTensor] = None, | ||
1378 | - end_positions: Optional[torch.LongTensor] = None, | ||
1379 | - cls_index: Optional[torch.LongTensor] = None, | ||
1380 | - is_impossible: Optional[torch.LongTensor] = None, | ||
1381 | - p_mask: Optional[torch.FloatTensor] = None, | ||
1382 | - return_dict: bool = False, | ||
1383 | - ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: | ||
1384 | - """ | ||
1385 | - Args: | ||
1386 | - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`): | ||
1387 | - Final hidden states of the model on the sequence tokens. | ||
1388 | - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1389 | - Positions of the first token for the labeled span. | ||
1390 | - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1391 | - Positions of the last token for the labeled span. | ||
1392 | - cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1393 | - Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token. | ||
1394 | - is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): | ||
1395 | - Whether the question has a possible answer in the paragraph or not. | ||
1396 | - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`): | ||
1397 | - Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). | ||
1398 | - 1.0 means token should be masked. | ||
1399 | - return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`): | ||
1400 | - Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. | ||
1401 | - | ||
1402 | - Returns: | ||
1403 | - """ | ||
1404 | - start_logits = self.start_logits(hidden_states, p_mask=p_mask) | ||
1405 | - | ||
1406 | - if start_positions is not None and end_positions is not None: | ||
1407 | - # If we are on multi-GPU, let's remove the dimension added by batch splitting | ||
1408 | - for x in (start_positions, end_positions, cls_index, is_impossible): | ||
1409 | - if x is not None and x.dim() > 1: | ||
1410 | - x.squeeze_(-1) | ||
1411 | - | ||
1412 | - # during training, compute the end logits based on the ground truth of the start position | ||
1413 | - end_logits = self.end_logits( | ||
1414 | - hidden_states, start_positions=start_positions, p_mask=p_mask | ||
1415 | - ) | ||
1416 | - | ||
1417 | - loss_fct = CrossEntropyLoss() | ||
1418 | - start_loss = loss_fct(start_logits, start_positions) | ||
1419 | - end_loss = loss_fct(end_logits, end_positions) | ||
1420 | - total_loss = (start_loss + end_loss) / 2 | ||
1421 | - | ||
1422 | - if cls_index is not None and is_impossible is not None: | ||
1423 | - # Predict answerability from the representation of CLS and START | ||
1424 | - cls_logits = self.answer_class( | ||
1425 | - hidden_states, start_positions=start_positions, cls_index=cls_index | ||
1426 | - ) | ||
1427 | - loss_fct_cls = nn.BCEWithLogitsLoss() | ||
1428 | - cls_loss = loss_fct_cls(cls_logits, is_impossible) | ||
1429 | - | ||
1430 | - # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss | ||
1431 | - total_loss += cls_loss * 0.5 | ||
1432 | - | ||
1433 | - return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) | ||
1434 | - | ||
1435 | - else: | ||
1436 | - # during inference, compute the end logits based on beam search | ||
1437 | - bsz, slen, hsz = hidden_states.size() | ||
1438 | - start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) | ||
1439 | - | ||
1440 | - start_top_log_probs, start_top_index = torch.topk( | ||
1441 | - start_log_probs, self.start_n_top, dim=-1 | ||
1442 | - ) # shape (bsz, start_n_top) | ||
1443 | - start_top_index_exp = start_top_index.unsqueeze(-1).expand( | ||
1444 | - -1, -1, hsz | ||
1445 | - ) # shape (bsz, start_n_top, hsz) | ||
1446 | - start_states = torch.gather( | ||
1447 | - hidden_states, -2, start_top_index_exp | ||
1448 | - ) # shape (bsz, start_n_top, hsz) | ||
1449 | - start_states = start_states.unsqueeze(1).expand( | ||
1450 | - -1, slen, -1, -1 | ||
1451 | - ) # shape (bsz, slen, start_n_top, hsz) | ||
1452 | - | ||
1453 | - hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( | ||
1454 | - start_states | ||
1455 | - ) # shape (bsz, slen, start_n_top, hsz) | ||
1456 | - p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None | ||
1457 | - end_logits = self.end_logits( | ||
1458 | - hidden_states_expanded, start_states=start_states, p_mask=p_mask | ||
1459 | - ) | ||
1460 | - end_log_probs = F.softmax( | ||
1461 | - end_logits, dim=1 | ||
1462 | - ) # shape (bsz, slen, start_n_top) | ||
1463 | - | ||
1464 | - end_top_log_probs, end_top_index = torch.topk( | ||
1465 | - end_log_probs, self.end_n_top, dim=1 | ||
1466 | - ) # shape (bsz, end_n_top, start_n_top) | ||
1467 | - end_top_log_probs = end_top_log_probs.view( | ||
1468 | - -1, self.start_n_top * self.end_n_top | ||
1469 | - ) | ||
1470 | - end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) | ||
1471 | - | ||
1472 | - start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) | ||
1473 | - cls_logits = self.answer_class( | ||
1474 | - hidden_states, start_states=start_states, cls_index=cls_index | ||
1475 | - ) | ||
1476 | - | ||
1477 | - if not return_dict: | ||
1478 | - return ( | ||
1479 | - start_top_log_probs, | ||
1480 | - start_top_index, | ||
1481 | - end_top_log_probs, | ||
1482 | - end_top_index, | ||
1483 | - cls_logits, | ||
1484 | - ) | ||
1485 | - else: | ||
1486 | - return SquadHeadOutput( | ||
1487 | - start_top_log_probs=start_top_log_probs, | ||
1488 | - start_top_index=start_top_index, | ||
1489 | - end_top_log_probs=end_top_log_probs, | ||
1490 | - end_top_index=end_top_index, | ||
1491 | - cls_logits=cls_logits, | ||
1492 | - ) | ||
1493 | - | ||
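
A minimal inference sketch for the head above. It assumes the PoolerStartLogits/PoolerEndLogits/PoolerAnswerClass classes defined earlier in this file are in scope, and that a bare PretrainedConfig can carry the extra attributes; the sizes (768 hidden, 384 tokens, top-5 beams) are illustrative, not taken from this repository:

    import torch
    from transformers import PretrainedConfig

    # Illustrative config values; a real model's config would supply these.
    config = PretrainedConfig(hidden_size=768, layer_norm_eps=1e-12,
                              start_n_top=5, end_n_top=5)
    head = SQuADHead(config)

    # Without start/end positions the forward pass takes the beam-search branch.
    hidden_states = torch.randn(2, 384, 768)   # (batch_size, seq_len, hidden_size)
    out = head(hidden_states, return_dict=True)
    print(out.start_top_index.shape)   # torch.Size([2, 5])
    print(out.end_top_index.shape)     # torch.Size([2, 25]) = start_n_top * end_n_top
    print(out.cls_logits.shape)        # torch.Size([2])
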
1494 | - | ||
1495 | -class SequenceSummary(nn.Module): | ||
1496 | - r""" | ||
1497 | - Compute a single vector summary of a sequence hidden states. | ||
1498 | - | ||
1499 | - Args: | ||
1500 | - config (:class:`~transformers.PretrainedConfig`): | ||
1501 | - The config used by the model. Relevant arguments in the config class of the model are (refer to the | ||
1502 | - actual config class of your model for the default values it uses): | ||
1503 | - | ||
1504 | - - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are: | ||
1505 | - | ||
1506 | - - :obj:`"last"` -- Take the last token hidden state (like XLNet) | ||
1507 | - - :obj:`"first"` -- Take the first token hidden state (like Bert) | ||
1508 | - - :obj:`"mean"` -- Take the mean of all tokens hidden states | ||
1509 | - - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) | ||
1510 | - - :obj:`"attn"` -- Not implemented for now; would use multi-head attention | ||
1511 | - | ||
1512 | - - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction. | ||
1513 | - - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to | ||
1514 | - :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`). | ||
1515 | - - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the | ||
1516 | - output, another string or :obj:`None` will add no activation. | ||
1517 | - - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and | ||
1518 | - activation. | ||
1519 | - - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and | ||
1520 | - activation. | ||
1521 | - """ | ||
1522 | - | ||
1523 | - def __init__(self, config: PretrainedConfig): | ||
1524 | - super().__init__() | ||
1525 | - | ||
1526 | - self.summary_type = getattr(config, "summary_type", "last") | ||
1527 | - if self.summary_type == "attn": | ||
1528 | - # We should use a standard multi-head attention module with absolute positional embedding for that. | ||
1529 | - # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 | ||
1530 | - # We can probably just use the multi-head attention module of PyTorch >=1.1.0 | ||
1531 | - raise NotImplementedError | ||
1532 | - | ||
1533 | - self.summary = Identity() | ||
1534 | - if hasattr(config, "summary_use_proj") and config.summary_use_proj: | ||
1535 | - if ( | ||
1536 | - hasattr(config, "summary_proj_to_labels") | ||
1537 | - and config.summary_proj_to_labels | ||
1538 | - and config.num_labels > 0 | ||
1539 | - ): | ||
1540 | - num_classes = config.num_labels | ||
1541 | - else: | ||
1542 | - num_classes = config.hidden_size | ||
1543 | - self.summary = nn.Linear(config.hidden_size, num_classes) | ||
1544 | - | ||
1545 | - activation_string = getattr(config, "summary_activation", None) | ||
1546 | - self.activation: Callable = get_activation( | ||
1547 | - activation_string | ||
1548 | - ) if activation_string else Identity() | ||
1549 | - | ||
1550 | - self.first_dropout = Identity() | ||
1551 | - if ( | ||
1552 | - hasattr(config, "summary_first_dropout") | ||
1553 | - and config.summary_first_dropout > 0 | ||
1554 | - ): | ||
1555 | - self.first_dropout = nn.Dropout(config.summary_first_dropout) | ||
1556 | - | ||
1557 | - self.last_dropout = Identity() | ||
1558 | - if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: | ||
1559 | - self.last_dropout = nn.Dropout(config.summary_last_dropout) | ||
1560 | - | ||
1561 | - def forward( | ||
1562 | - self, | ||
1563 | - hidden_states: torch.FloatTensor, | ||
1564 | - cls_index: Optional[torch.LongTensor] = None, | ||
1565 | - ) -> torch.FloatTensor: | ||
1566 | - """ | ||
1567 | - Compute a single vector summary of a sequence hidden states. | ||
1568 | - | ||
1569 | - Args: | ||
1570 | - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`): | ||
1571 | - The hidden states of the last layer. | ||
1572 | - cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`): | ||
1573 | - Used if :obj:`summary_type == "cls_index"`; if :obj:`None`, the last token of the sequence is used as the | ||
1574 | - classification token. | ||
1575 | - | ||
1576 | - Returns: | ||
1577 | - :obj:`torch.FloatTensor`: The summary of the sequence hidden states. | ||
1578 | - """ | ||
1579 | - if self.summary_type == "last": | ||
1580 | - output = hidden_states[:, -1] | ||
1581 | - elif self.summary_type == "first": | ||
1582 | - output = hidden_states[:, 0] | ||
1583 | - elif self.summary_type == "mean": | ||
1584 | - output = hidden_states.mean(dim=1) | ||
1585 | - elif self.summary_type == "cls_index": | ||
1586 | - if cls_index is None: | ||
1587 | - cls_index = torch.full_like( | ||
1588 | - hidden_states[..., :1, :], | ||
1589 | - hidden_states.shape[-2] - 1, | ||
1590 | - dtype=torch.long, | ||
1591 | - ) | ||
1592 | - else: | ||
1593 | - cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) | ||
1594 | - cls_index = cls_index.expand( | ||
1595 | - (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),) | ||
1596 | - ) | ||
1597 | - # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states | ||
1598 | - output = hidden_states.gather(-2, cls_index).squeeze( | ||
1599 | - -2 | ||
1600 | - ) # shape (bsz, XX, hidden_size) | ||
1601 | - elif self.summary_type == "attn": | ||
1602 | - raise NotImplementedError | ||
1603 | - | ||
1604 | - output = self.first_dropout(output) | ||
1605 | - output = self.summary(output) | ||
1606 | - output = self.activation(output) | ||
1607 | - output = self.last_dropout(output) | ||
1608 | - | ||
1609 | - return output | ||
1610 | - | ||
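
A small usage sketch for the summary module, again with an ad-hoc PretrainedConfig whose attribute values are purely illustrative:

    import torch
    from transformers import PretrainedConfig

    config = PretrainedConfig(hidden_size=768, summary_type="mean",
                              summary_use_proj=True, summary_proj_to_labels=False,
                              summary_activation="tanh")
    summary = SequenceSummary(config)

    hidden_states = torch.randn(4, 128, 768)   # (batch_size, seq_len, hidden_size)
    pooled = summary(hidden_states)            # mean over tokens -> Linear -> tanh
    print(pooled.shape)                        # torch.Size([4, 768])
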
1611 | - | ||
1612 | -def prune_linear_layer( | ||
1613 | - layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0 | ||
1614 | -) -> torch.nn.Linear: | ||
1615 | - """ | ||
1616 | - Prune a linear layer to keep only entries in index. | ||
1617 | - | ||
1618 | - Used to remove heads. | ||
1619 | - | ||
1620 | - Args: | ||
1621 | - layer (:obj:`torch.nn.Linear`): The layer to prune. | ||
1622 | - index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1623 | - dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices. | ||
1624 | - | ||
1625 | - Returns: | ||
1626 | - :obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1627 | - """ | ||
1628 | - index = index.to(layer.weight.device) | ||
1629 | - W = layer.weight.index_select(dim, index).clone().detach() | ||
1630 | - if layer.bias is not None: | ||
1631 | - if dim == 1: | ||
1632 | - b = layer.bias.clone().detach() | ||
1633 | - else: | ||
1634 | - b = layer.bias[index].clone().detach() | ||
1635 | - new_size = list(layer.weight.size()) | ||
1636 | - new_size[dim] = len(index) | ||
1637 | - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to( | ||
1638 | - layer.weight.device | ||
1639 | - ) | ||
1640 | - new_layer.weight.requires_grad = False | ||
1641 | - new_layer.weight.copy_(W.contiguous()) | ||
1642 | - new_layer.weight.requires_grad = True | ||
1643 | - if layer.bias is not None: | ||
1644 | - new_layer.bias.requires_grad = False | ||
1645 | - new_layer.bias.copy_(b.contiguous()) | ||
1646 | - new_layer.bias.requires_grad = True | ||
1647 | - return new_layer | ||
1648 | - | ||
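
For example (an illustrative call, not taken from this repository), keeping four of the six output units of a linear layer:

    import torch
    from torch import nn

    layer = nn.Linear(12, 6)
    index = torch.LongTensor([0, 2, 3, 5])
    pruned = prune_linear_layer(layer, index, dim=0)   # dim 0 = output features
    print(pruned.weight.shape)                         # torch.Size([4, 12])
    print(pruned.bias.shape)                           # torch.Size([4])
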
1649 | - | ||
1650 | -def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D: | ||
1651 | - """ | ||
1652 | - Prune a Conv1D layer to keep only entries in index. A Conv1D works like a Linear layer (used e.g. in GPT-2) but the weights | ||
1653 | - are transposed. | ||
1654 | - | ||
1655 | - Used to remove heads. | ||
1656 | - | ||
1657 | - Args: | ||
1658 | - layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune. | ||
1659 | - index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1660 | - dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices. | ||
1661 | - | ||
1662 | - Returns: | ||
1663 | - :class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1664 | - """ | ||
1665 | - index = index.to(layer.weight.device) | ||
1666 | - W = layer.weight.index_select(dim, index).clone().detach() | ||
1667 | - if dim == 0: | ||
1668 | - b = layer.bias.clone().detach() | ||
1669 | - else: | ||
1670 | - b = layer.bias[index].clone().detach() | ||
1671 | - new_size = list(layer.weight.size()) | ||
1672 | - new_size[dim] = len(index) | ||
1673 | - new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device) | ||
1674 | - new_layer.weight.requires_grad = False | ||
1675 | - new_layer.weight.copy_(W.contiguous()) | ||
1676 | - new_layer.weight.requires_grad = True | ||
1677 | - new_layer.bias.requires_grad = False | ||
1678 | - new_layer.bias.copy_(b.contiguous()) | ||
1679 | - new_layer.bias.requires_grad = True | ||
1680 | - return new_layer | ||
1681 | - | ||
1682 | - | ||
1683 | -def prune_layer( | ||
1684 | - layer: Union[torch.nn.Linear, Conv1D], | ||
1685 | - index: torch.LongTensor, | ||
1686 | - dim: Optional[int] = None, | ||
1687 | -) -> Union[torch.nn.Linear, Conv1D]: | ||
1688 | - """ | ||
1689 | - Prune a Conv1D or linear layer to keep only entries in index. | ||
1690 | - | ||
1691 | - Used to remove heads. | ||
1692 | - | ||
1693 | - Args: | ||
1694 | - layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune. | ||
1695 | - index (:obj:`torch.LongTensor`): The indices to keep in the layer. | ||
1696 | - dim (:obj:`int`, `optional`): The dimension on which to keep the indices. | ||
1697 | - | ||
1698 | - Returns: | ||
1699 | - :obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`: | ||
1700 | - The pruned layer as a new layer with :obj:`requires_grad=True`. | ||
1701 | - """ | ||
1702 | - if isinstance(layer, nn.Linear): | ||
1703 | - return prune_linear_layer(layer, index, dim=0 if dim is None else dim) | ||
1704 | - elif isinstance(layer, Conv1D): | ||
1705 | - return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) | ||
1706 | - else: | ||
1707 | - raise ValueError("Can't prune layer of class {}".format(layer.__class__)) | ||
1708 | - | ||
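
A quick sketch of the dispatch, assuming the Conv1D class defined earlier in this file; because Conv1D stores its weight transposed relative to nn.Linear, the default pruning dimension differs:

    import torch

    conv = Conv1D(8, 12)                                          # weight shape (12, 8) = (nx, nf)
    pruned_conv = prune_layer(conv, torch.LongTensor([1, 4, 6]))  # dispatches to prune_conv1d_layer, dim=1
    print(pruned_conv.weight.shape)                               # torch.Size([12, 3])
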
1709 | - | ||
1710 | -def apply_chunking_to_forward( | ||
1711 | - forward_fn: Callable[..., torch.Tensor], | ||
1712 | - chunk_size: int, | ||
1713 | - chunk_dim: int, | ||
1714 | - *input_tensors, | ||
1715 | -) -> torch.Tensor: | ||
1716 | - """ | ||
1717 | - This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the | ||
1718 | - dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory. | ||
1719 | - | ||
1720 | - If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as | ||
1721 | - directly applying :obj:`forward_fn` to :obj:`input_tensors`. | ||
1722 | - | ||
1723 | - Args: | ||
1724 | - forward_fn (:obj:`Callable[..., torch.Tensor]`): | ||
1725 | - The forward function of the model. | ||
1726 | - chunk_size (:obj:`int`): | ||
1727 | - The chunk size of a chunked tensor: :obj:`num_chunks = input_tensors[0].shape[chunk_dim] / chunk_size`. | ||
1728 | - chunk_dim (:obj:`int`): | ||
1729 | - The dimension over which the :obj:`input_tensors` should be chunked. | ||
1730 | - input_tensors (:obj:`Tuple[torch.Tensor]`): | ||
1731 | - The input tensors of ``forward_fn`` which will be chunked. | ||
1732 | - Returns: | ||
1733 | - :obj:`torch.Tensor`: A tensor with the same shape as :obj:`forward_fn` would have produced if applied directly. | ||
1734 | - | ||
1735 | - | ||
1736 | - Examples:: | ||
1737 | - | ||
1738 | - # rename the usual forward() fn to forward_chunk() | ||
1739 | - def forward_chunk(self, hidden_states): | ||
1740 | - hidden_states = self.decoder(hidden_states) | ||
1741 | - return hidden_states | ||
1742 | - | ||
1743 | - # implement a chunked forward function | ||
1744 | - def forward(self, hidden_states): | ||
1745 | - return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) | ||
1746 | - """ | ||
1747 | - | ||
1748 | - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format( | ||
1749 | - input_tensors | ||
1750 | - ) | ||
1751 | - tensor_shape = input_tensors[0].shape | ||
1752 | - assert all( | ||
1753 | - input_tensor.shape == tensor_shape for input_tensor in input_tensors | ||
1754 | - ), "All input tensors have to be of the same shape" | ||
1755 | - | ||
1756 | - # inspect.signature exists since Python 3.5 and is a plain Python function -> no problem with backward compatibility | ||
1757 | - num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) | ||
1758 | - assert num_args_in_forward_chunk_fn == len( | ||
1759 | - input_tensors | ||
1760 | - ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format( | ||
1761 | - num_args_in_forward_chunk_fn, len(input_tensors) | ||
1762 | - ) | ||
1763 | - | ||
1764 | - if chunk_size > 0: | ||
1765 | - assert ( | ||
1766 | - input_tensors[0].shape[chunk_dim] % chunk_size == 0 | ||
1767 | - ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format( | ||
1768 | - input_tensors[0].shape[chunk_dim], chunk_size | ||
1769 | - ) | ||
1770 | - | ||
1771 | - num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size | ||
1772 | - | ||
1773 | - # chunk input tensor into tuples | ||
1774 | - input_tensors_chunks = tuple( | ||
1775 | - input_tensor.chunk(num_chunks, dim=chunk_dim) | ||
1776 | - for input_tensor in input_tensors | ||
1777 | - ) | ||
1778 | - # apply forward fn to every tuple | ||
1779 | - output_chunks = tuple( | ||
1780 | - forward_fn(*input_tensors_chunk) | ||
1781 | - for input_tensors_chunk in zip(*input_tensors_chunks) | ||
1782 | - ) | ||
1783 | - # concatenate output at same dimension | ||
1784 | - return torch.cat(output_chunks, dim=chunk_dim) | ||
1785 | - | ||
1786 | - return forward_fn(*input_tensors) |
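
A self-contained sanity check (illustrative only): because a position-wise layer is independent across the sequence dimension, the chunked and unchunked results match:

    import torch
    from torch import nn

    dense = nn.Linear(16, 16)

    def forward_chunk(hidden_states):
        return dense(hidden_states)

    x = torch.randn(2, 8, 16)
    full = forward_chunk(x)
    chunked = apply_chunking_to_forward(forward_chunk, 4, 1, x)  # chunk_size=4 over dim 1
    print(torch.allclose(full, chunked, atol=1e-6))              # True
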
train/utils.py
deleted
100644 → 0
1 | -import itertools | ||
2 | -import json | ||
3 | -import linecache | ||
4 | -import os | ||
5 | -import pickle | ||
6 | -from logging import getLogger | ||
7 | -from pathlib import Path | ||
8 | -from typing import Callable, Dict, Iterable, List | ||
9 | - | ||
10 | -import git | ||
11 | -import numpy as np | ||
12 | -import torch | ||
13 | -from rouge_score import rouge_scorer, scoring | ||
14 | -from sacrebleu import corpus_bleu | ||
15 | -from torch import nn | ||
16 | -from torch.utils.data import Dataset, Sampler | ||
17 | - | ||
18 | -from transformers import BartTokenizer | ||
19 | - | ||
20 | - | ||
21 | -def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): | ||
22 | - """From fairseq""" | ||
23 | - if target.dim() == lprobs.dim() - 1: | ||
24 | - target = target.unsqueeze(-1) | ||
25 | - nll_loss = -lprobs.gather(dim=-1, index=target) | ||
26 | - smooth_loss = -lprobs.sum(dim=-1, keepdim=True) | ||
27 | - if ignore_index is not None: | ||
28 | - pad_mask = target.eq(ignore_index) | ||
29 | - nll_loss.masked_fill_(pad_mask, 0.0) | ||
30 | - smooth_loss.masked_fill_(pad_mask, 0.0) | ||
31 | - else: | ||
32 | - nll_loss = nll_loss.squeeze(-1) | ||
33 | - smooth_loss = smooth_loss.squeeze(-1) | ||
34 | - | ||
35 | - nll_loss = nll_loss.sum() # mean()? Scared to break other math. | ||
36 | - smooth_loss = smooth_loss.sum() | ||
37 | - eps_i = epsilon / lprobs.size(-1) | ||
38 | - loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss | ||
39 | - return loss, nll_loss | ||
40 | - | ||
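
A toy check of the helper above (shapes are illustrative): with epsilon set to zero the smoothed loss collapses to the plain NLL term.

    import torch

    lprobs = torch.log_softmax(torch.randn(3, 5, 10), dim=-1)   # (bsz, seq_len, vocab)
    target = torch.randint(0, 10, (3, 5))

    loss, nll = label_smoothed_nll_loss(lprobs, target, epsilon=0.1)
    loss0, nll0 = label_smoothed_nll_loss(lprobs, target, epsilon=0.0)
    print(torch.allclose(loss0, nll0))   # True
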
41 | - | ||
42 | -def encode_line( | ||
43 | - tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt" | ||
44 | -): | ||
45 | - """Only used by LegacySeq2SeqDataset""" | ||
46 | - extra_kw = ( | ||
47 | - {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} | ||
48 | - ) | ||
49 | - return tokenizer( | ||
50 | - [line], | ||
51 | - max_length=max_length, | ||
52 | - padding="max_length" if pad_to_max_length else None, | ||
53 | - truncation=True, | ||
54 | - return_tensors=return_tensors, | ||
55 | - **extra_kw, | ||
56 | - ) | ||
57 | - | ||
58 | - | ||
59 | -def lmap(f: Callable, x: Iterable) -> List: | ||
60 | - """list(map(f, x))""" | ||
61 | - return list(map(f, x)) | ||
62 | - | ||
63 | - | ||
64 | -def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: | ||
65 | - """Uses sacrebleu's corpus_bleu implementation.""" | ||
66 | - return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} | ||
67 | - | ||
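
A quick illustration (toy strings; an identical hypothesis and reference should score 100):

    print(calculate_bleu(["the cat sat on the mat"],
                         ["the cat sat on the mat"]))
    # expected: {'bleu': 100.0}
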
68 | - | ||
69 | -def trim_batch( | ||
70 | - input_ids, pad_token_id, attention_mask=None, | ||
71 | -): | ||
72 | - """Remove columns that are populated exclusively by pad_token_id""" | ||
73 | - keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) | ||
74 | - if attention_mask is None: | ||
75 | - return input_ids[:, keep_column_mask] | ||
76 | - else: | ||
77 | - return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) | ||
78 | - | ||
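
For instance (illustrative ids, with the pad id assumed to be 1 as for BART):

    import torch

    input_ids = torch.tensor([[5, 6, 1, 1],
                              [7, 1, 1, 1]])
    print(trim_batch(input_ids, pad_token_id=1))
    # tensor([[5, 6],
    #         [7, 1]])
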
79 | - | ||
80 | -class AbstractSeq2SeqDataset(Dataset): | ||
81 | - def __init__( | ||
82 | - self, | ||
83 | - tokenizer, | ||
84 | - data_dir, | ||
85 | - max_source_length, | ||
86 | - max_target_length, | ||
87 | - type_path="train", | ||
88 | - n_obs=None, | ||
89 | - src_lang=None, | ||
90 | - tgt_lang=None, | ||
91 | - prefix="", | ||
92 | - ): | ||
93 | - super().__init__() | ||
94 | - self.src_file = Path(data_dir).joinpath(type_path + ".source") | ||
95 | - self.tgt_file = Path(data_dir).joinpath(type_path + ".target") | ||
96 | - self.src_lens = self.get_char_lens(self.src_file) | ||
97 | - self.max_source_length = max_source_length | ||
98 | - self.max_target_length = max_target_length | ||
99 | - assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" | ||
100 | - self.tokenizer = tokenizer | ||
101 | - self.prefix = prefix | ||
102 | - if n_obs is not None: | ||
103 | - self.src_lens = self.src_lens[:n_obs] | ||
104 | - self.pad_token_id = self.tokenizer.pad_token_id | ||
105 | - self.src_lang = src_lang | ||
106 | - self.tgt_lang = tgt_lang | ||
107 | - self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer) | ||
108 | - | ||
109 | - def __len__(self): | ||
110 | - return len(self.src_lens) | ||
111 | - | ||
112 | - @staticmethod | ||
113 | - def get_char_lens(data_file): | ||
114 | - return [len(x) for x in Path(data_file).open().readlines()] | ||
115 | - | ||
116 | - def make_sortish_sampler(self, batch_size): | ||
117 | - return SortishSampler(self.src_lens, batch_size) | ||
118 | - | ||
119 | - def __getitem__(self, item): | ||
120 | - raise NotImplementedError("You must implement this") | ||
121 | - | ||
122 | - def collate_fn(self, batch): | ||
123 | - raise NotImplementedError("You must implement this") | ||
124 | - | ||
125 | - | ||
126 | -class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): | ||
127 | - def __getitem__(self, index) -> Dict[str, torch.Tensor]: | ||
128 | - """Call tokenizer on src and tgt_lines""" | ||
129 | - index = index + 1 # linecache starts at 1 | ||
130 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( | ||
131 | - "\n" | ||
132 | - ) | ||
133 | - tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | ||
134 | - assert source_line, f"empty source line for index {index}" | ||
135 | - assert tgt_line, f"empty tgt line for index {index}" | ||
136 | - source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) | ||
137 | - target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) | ||
138 | - | ||
139 | - source_ids = source_inputs["input_ids"].squeeze() | ||
140 | - target_ids = target_inputs["input_ids"].squeeze() | ||
141 | - src_mask = source_inputs["attention_mask"].squeeze() | ||
142 | - return { | ||
143 | - "input_ids": source_ids, | ||
144 | - "attention_mask": src_mask, | ||
145 | - "labels": target_ids, | ||
146 | - } | ||
147 | - | ||
148 | - def collate_fn(self, batch) -> Dict[str, torch.Tensor]: | ||
149 | - input_ids = torch.stack([x["input_ids"] for x in batch]) | ||
150 | - masks = torch.stack([x["attention_mask"] for x in batch]) | ||
151 | - target_ids = torch.stack([x["labels"] for x in batch]) | ||
152 | - pad_token_id = self.pad_token_id | ||
153 | - y = trim_batch(target_ids, pad_token_id) | ||
154 | - source_ids, source_mask = trim_batch( | ||
155 | - input_ids, pad_token_id, attention_mask=masks | ||
156 | - ) | ||
157 | - batch = { | ||
158 | - "input_ids": source_ids, | ||
159 | - "attention_mask": source_mask, | ||
160 | - "labels": y, | ||
161 | - } | ||
162 | - return batch | ||
163 | - | ||
164 | - | ||
165 | -class Seq2SeqDataset(AbstractSeq2SeqDataset): | ||
166 | - """A dataset that calls prepare_seq2seq_batch.""" | ||
167 | - | ||
168 | - def __getitem__(self, index) -> Dict[str, str]: | ||
169 | - index = index + 1 # linecache starts at 1 | ||
170 | - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( | ||
171 | - "\n" | ||
172 | - ) | ||
173 | - tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") | ||
174 | - assert source_line, f"empty source line for index {index}" | ||
175 | - assert tgt_line, f"empty tgt line for index {index}" | ||
176 | - return { | ||
177 | - "tgt_texts": tgt_line, | ||
178 | - "src_texts": source_line, | ||
179 | - } | ||
180 | - | ||
181 | - def collate_fn(self, batch) -> Dict[str, torch.Tensor]: | ||
182 | - """Call prepare_seq2seq_batch.""" | ||
183 | - batch_encoding = self.tokenizer.prepare_seq2seq_batch( | ||
184 | - [x["src_texts"] for x in batch], | ||
185 | - src_lang=self.src_lang, | ||
186 | - tgt_texts=[x["tgt_texts"] for x in batch], | ||
187 | - tgt_lang=self.tgt_lang, | ||
188 | - max_length=self.max_source_length, | ||
189 | - max_target_length=self.max_target_length, | ||
190 | - return_tensors="pt", | ||
191 | - add_prefix_space=self.add_prefix_space, | ||
192 | - ) | ||
193 | - return batch_encoding.data | ||
194 | - | ||
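
A rough usage sketch for the dataset classes above. The directory layout and checkpoint name are hypothetical: it assumes data_dir/train.source and data_dir/train.target each hold one example per line.

    import torch
    from torch.utils.data import DataLoader
    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # example checkpoint
    dataset = LegacySeq2SeqDataset(tokenizer, "data_dir",
                                   max_source_length=512, max_target_length=56)
    loader = DataLoader(dataset, batch_size=8,
                        collate_fn=dataset.collate_fn,
                        sampler=dataset.make_sortish_sampler(8))
    batch = next(iter(loader))   # dict with input_ids, attention_mask, labels
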
195 | - | ||
196 | -class SortishSampler(Sampler): | ||
197 | - "Go through the text data by order of src length with a bit of randomness. From fastai repo." | ||
198 | - | ||
199 | - def __init__(self, data, batch_size): | ||
200 | - self.data, self.bs = data, batch_size | ||
201 | - | ||
202 | - def key(self, i): | ||
203 | - return self.data[i] | ||
204 | - | ||
205 | - def __len__(self) -> int: | ||
206 | - return len(self.data) | ||
207 | - | ||
208 | - def __iter__(self): | ||
209 | - idxs = np.random.permutation(len(self.data)) | ||
210 | - sz = self.bs * 50 | ||
211 | - ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] | ||
212 | - sort_idx = np.concatenate( | ||
213 | - [sorted(s, key=self.key, reverse=True) for s in ck_idx] | ||
214 | - ) | ||
215 | - sz = self.bs | ||
216 | - ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] | ||
217 | - max_ck = np.argmax( | ||
218 | - [self.key(ck[0]) for ck in ck_idx] | ||
219 | - ) # find the chunk with the largest key, | ||
220 | - ck_idx[0], ck_idx[max_ck] = ( | ||
221 | - ck_idx[max_ck], | ||
222 | - ck_idx[0], | ||
223 | - ) # then make sure it goes first. | ||
224 | - sort_idx = ( | ||
225 | - np.concatenate(np.random.permutation(ck_idx[1:])) | ||
226 | - if len(ck_idx) > 1 | ||
227 | - else np.array([], dtype=np.int64)  # np.int was removed from recent NumPy versions | ||
228 | - ) | ||
229 | - sort_idx = np.concatenate((ck_idx[0], sort_idx)) | ||
230 | - return iter(sort_idx) | ||
231 | - | ||
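
An illustration with toy lengths: indices come back grouped into batch_size * 50 mega-chunks, each sorted longest-first, with the chunk containing the longest example placed first:

    lengths = [3, 10, 1, 7, 4, 8]
    sampler = SortishSampler(lengths, batch_size=2)
    print(list(iter(sampler)))   # a permutation of 0..5, roughly ordered by decreasing length
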
232 | - | ||
233 | -logger = getLogger(__name__) | ||
234 | - | ||
235 | - | ||
236 | -def use_task_specific_params(model, task): | ||
237 | - """Update config with task-specific params (e.g. summarization).""" | ||
238 | - task_specific_params = model.config.task_specific_params | ||
239 | - | ||
240 | - if task_specific_params is not None: | ||
241 | - pars = task_specific_params.get(task, {}) | ||
242 | - logger.info(f"using task specific params for {task}: {pars}") | ||
243 | - model.config.update(pars) | ||
244 | - | ||
245 | - | ||
246 | -def pickle_load(path): | ||
247 | - """pickle.load(path)""" | ||
248 | - with open(path, "rb") as f: | ||
249 | - return pickle.load(f) | ||
250 | - | ||
251 | - | ||
252 | -def pickle_save(obj, path): | ||
253 | - """pickle.dump(obj, path)""" | ||
254 | - with open(path, "wb") as f: | ||
255 | - return pickle.dump(obj, f) | ||
256 | - | ||
257 | - | ||
258 | -def flatten_list(summary_ids: List[List]): | ||
259 | - return [x for x in itertools.chain.from_iterable(summary_ids)] | ||
260 | - | ||
261 | - | ||
262 | -def save_git_info(folder_path: str) -> None: | ||
263 | - """Save git information to output_dir/git_log.json""" | ||
264 | - repo_infos = get_git_info() | ||
265 | - save_json(repo_infos, os.path.join(folder_path, "git_log.json")) | ||
266 | - | ||
267 | - | ||
268 | -def save_json(content, path): | ||
269 | - with open(path, "w") as f: | ||
270 | - json.dump(content, f, indent=4) | ||
271 | - | ||
272 | - | ||
273 | -def load_json(path): | ||
274 | - with open(path) as f: | ||
275 | - return json.load(f) | ||
276 | - | ||
277 | - | ||
278 | -def get_git_info(): | ||
279 | - repo = git.Repo(search_parent_directories=True) | ||
280 | - repo_infos = { | ||
281 | - "repo_id": str(repo), | ||
282 | - "repo_sha": str(repo.head.object.hexsha), | ||
283 | - "repo_branch": str(repo.active_branch), | ||
284 | - } | ||
285 | - return repo_infos | ||
286 | - | ||
287 | - | ||
288 | -ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] | ||
289 | - | ||
290 | - | ||
291 | -def calculate_rouge( | ||
292 | - output_lns: List[str], reference_lns: List[str], use_stemmer=True | ||
293 | -) -> Dict: | ||
294 | - scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) | ||
295 | - aggregator = scoring.BootstrapAggregator() | ||
296 | - | ||
297 | - for reference_ln, output_ln in zip(reference_lns, output_lns): | ||
298 | - scores = scorer.score(reference_ln, output_ln) | ||
299 | - aggregator.add_scores(scores) | ||
300 | - | ||
301 | - result = aggregator.aggregate() | ||
302 | - return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} | ||
303 | - | ||
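
For example (illustrative strings; the rouge_score package must be installed):

    scores = calculate_rouge(["the cat sat on the mat"],
                             ["a cat sat on a mat"])
    print(scores)   # {'rouge1': ..., 'rouge2': ..., 'rougeL': ...} as percentages
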
304 | - | ||
305 | -# Utilities for freezing parameters and checking whether they are frozen | ||
306 | - | ||
307 | - | ||
308 | -def freeze_params(model: nn.Module): | ||
309 | - """Set requires_grad=False for each of model.parameters()""" | ||
310 | - for par in model.parameters(): | ||
311 | - par.requires_grad = False | ||
312 | - | ||
313 | - | ||
314 | -def grad_status(model: nn.Module) -> Iterable: | ||
315 | - return (par.requires_grad for par in model.parameters()) | ||
316 | - | ||
317 | - | ||
318 | -def any_requires_grad(model: nn.Module) -> bool: | ||
319 | - return any(grad_status(model)) | ||
320 | - | ||
321 | - | ||
322 | -def assert_all_frozen(model): | ||
323 | - model_grads: List[bool] = list(grad_status(model)) | ||
324 | - n_require_grad = sum(lmap(int, model_grads)) | ||
325 | - npars = len(model_grads) | ||
326 | - assert not any( | ||
327 | - model_grads | ||
328 | - ), f"{n_require_grad/npars:.1%} of {npars} weights require grad" | ||
329 | - | ||
330 | - | ||
331 | -def assert_not_all_frozen(model): | ||
332 | - model_grads: List[bool] = list(grad_status(model)) | ||
333 | - npars = len(model_grads) | ||
334 | - assert any(model_grads), f"none of {npars} weights require grad" |
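
A minimal check of the freezing helpers:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    freeze_params(model)
    assert_all_frozen(model)          # passes: no parameter requires grad
    print(any_requires_grad(model))   # False
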