graykode

(remove) legacy commit suggester

1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -import torch
16 -import argparse
17 -import subprocess
18 -from transformers import AutoTokenizer
19 -
20 -from preprocess import diff_parse, truncate
21 -from train import BartForConditionalGeneration
22 -
23 -def get_length(chunks):
24 - cnt = 0
25 - for chunk in chunks:
26 - cnt += len(chunk[0])  # each chunk is (token_ids, attention_mask, patch_ids); count its tokens
27 - return cnt
28 -
29 -def suggester(chunks, model, tokenizer, device):
30 - max_source_length = get_length(chunks)
31 -
32 - input_ids, attention_masks, patch_ids = zip(*chunks)
33 - input_ids = torch.LongTensor(
34 - [truncate(input_ids, max_source_length, value=0)]
35 - ).to(device)
36 - attention_masks = torch.LongTensor(
37 - [truncate(attention_masks, max_source_length, value=1)]
38 - ).to(device)
39 - patch_ids = torch.LongTensor(
40 - [truncate(patch_ids, max_source_length, value=0)]
41 - ).to(device)
42 -
43 - summaries = model.generate(
44 - input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks
45 - )
46 - return tokenizer.batch_decode(
47 - summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False
48 - )
49 -
50 -
51 -def main(args):
52 - device = torch.device(
53 - "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
54 - )
55 - model = BartForConditionalGeneration.from_pretrained(args.output_dir).to(device)
56 -
57 - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
58 -
59 - if args.unittest:
60 - with open("test.source", "r") as f:
61 - chunks = f.read()  # raw diff text; parsed once by diff_parse below
62 - else:
63 - proc = subprocess.Popen(["git", "diff", "--cached"], stdout=subprocess.PIPE)
64 - staged_files = proc.stdout.readlines()
65 - staged_files = [f.decode("utf-8") for f in staged_files]
66 - staged_files = [f.rstrip("\n") for f in staged_files]  # keep leading spaces so diff context lines stay valid
67 - chunks = "\n".join(staged_files)
68 -
69 - chunks = diff_parse(chunks, tokenizer)
70 - if not chunks:
71 - print("There are no files in the staged state.")
72 - return
73 -
74 - commit_message = suggester(
75 - chunks,
76 - model=model,
77 - tokenizer=tokenizer,
78 - device=device,
79 - )
80 - print(commit_message)
81 -
82 -
83 -if __name__ == "__main__":
84 - parser = argparse.ArgumentParser(description="Suggest a commit message for the staged git diff")
85 - parser.add_argument(
86 - "--no_cuda", action="store_true", help="Do not use CUDA even when it is available"
87 - )
88 - parser.add_argument(
89 - "--unittest", action="store_true", help="Run a unit test on a one-batch git diff read from test.source"
90 - )
91 - parser.add_argument(
92 - "--output_dir",
93 - type=str,
94 - required=True,
95 - help="The directory containing the fine-tuned model checkpoint to load.",
96 - )
97 - parser.add_argument(
98 - "--tokenizer_name",
99 - default="sshleifer/distilbart-xsum-6-6",
100 - type=str,
101 - help="Pretrained tokenizer name or path if not the same as model_name",
102 - )
103 - args = parser.parse_args()
104 -
105 - main(args)
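
For reference, a minimal usage sketch of the removed suggester (not part of the deleted files). It assumes the helpers are importable as preprocess and train, as in the script above, that the script itself is importable under the hypothetical module name commit, and that a fine-tuned checkpoint exists at the hypothetical path ./model:

import torch
from transformers import AutoTokenizer

from preprocess import diff_parse
from train import BartForConditionalGeneration
from commit import suggester  # hypothetical module name for the script above

device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-6-6")
model = BartForConditionalGeneration.from_pretrained("./model").to(device)  # hypothetical checkpoint directory

# Any unified diff works as input; the CLI above feeds in the output of git diff --cached.
diff_text = """\
diff --git a/hello.py b/hello.py
--- a/hello.py
+++ b/hello.py
@@ -1 +1 @@
-print("hello")
+print("hello, world")
"""

chunks = diff_parse(diff_text, tokenizer)  # one (token_ids, attention_mask, patch_ids) tuple per changed line
print(suggester(chunks, model=model, tokenizer=tokenizer, device=device))
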
1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -from .gitcommit import diff_parse, truncate
16 -
17 -__all__ = [
18 - "diff_parse",
19 - "truncate",
20 -]
1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -import os
16 -import re
17 -import enum
18 -import random
19 -import logging
20 -import tempfile
21 -import argparse
22 -import numpy as np
23 -from tqdm import tqdm
24 -import whatthepatch
25 -from git import Repo
26 -from functools import partial
27 -from multiprocessing.pool import Pool
28 -from transformers import AutoTokenizer
29 -
30 -from matorage import DataConfig, DataSaver
31 -
32 -logger = logging.getLogger(__name__) # pylint: disable=invalid-name
33 -logging.basicConfig(
34 - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
35 - datefmt="%m/%d/%Y %H:%M:%S",
36 - level=logging.INFO,
37 -)
38 -
39 -
40 -class PATCH(enum.Enum):
41 - PLUS = 1
42 - MINUS = 2
43 -
44 -
45 -def truncate(items, max_length, value=0):
46 - ls = []
47 - for t in items:
48 - if isinstance(t, int):
49 - t = [t]
50 - ls.extend(t)
51 - ls = ls[: max_length - 1]
52 - ls.insert(0, value)
53 - if len(ls) < max_length:
54 - ls.extend([0] * (max_length - len(ls)))
55 - assert len(ls) == max_length
56 - return ls
57 -
58 -
59 -def encode_line(tokenizer, line, patch):
60 - line = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", line).strip()
61 - tokens = tokenizer.tokenize(line)
62 - tokens = tokenizer.convert_tokens_to_ids(tokens)
63 - return (tokens, [1] * len(tokens), len(tokens) * [patch.value])
64 -
65 -
66 -def diff_parse(diff, tokenizer):
67 - chunks = []
68 - for file_diff in whatthepatch.parse_patch(diff):
69 - if file_diff.header.old_path != file_diff.header.new_path:
70 - chunks.append(encode_line(tokenizer, file_diff.header.old_path, PATCH.MINUS))
71 - chunks.append(encode_line(tokenizer, file_diff.header.new_path, PATCH.PLUS))
72 - if not file_diff.changes:
73 - continue
74 - for change in file_diff.changes:
75 - if change.old is None and change.new is not None:
76 - chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
77 - elif change.old is not None and change.new is None:
78 - chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
79 - return chunks
80 -
81 -
82 -def sha_parse(sha, tokenizer, max_length=1024):
83 - # uses the module-level git.Repo created in the __main__ block below
84 - chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
85 - if not chunks:
86 - return None
87 -
88 - input_ids, attention_masks, patch_ids = zip(*chunks)
89 - input_ids = truncate(input_ids, max_length, value=0)
90 - attention_masks = truncate(attention_masks, max_length, value=1)
91 - patch_ids = truncate(patch_ids, max_length, value=0)
92 -
93 - return (input_ids, attention_masks, patch_ids)
94 -
95 -
96 -def message_parse(msg, tokenizer, max_length=56):
97 - msg = re.sub(r"\(?#[0-9]+\)?", "", msg)  # strip issue references such as "(#36235)"
98 -
99 - msg = re.sub(r"[\u0100-\uFFFF\U00010000-\U0010FFFF]+", "", msg).strip()
100 - msg = tokenizer.tokenize(msg)
101 - msg = tokenizer.convert_tokens_to_ids(msg)
102 - msg = truncate(msg, max_length, value=0)
103 -
104 - return msg
105 -
106 -
107 -def jobs(sha_msgs, args, data_config, train=True):
108 -
109 - input_ids, attention_masks, patch_ids, targets = [], [], [], []
110 - data_saver = DataSaver(config=data_config)
111 -
112 - for sha_msg in sha_msgs:
113 - sha, msg = sha_msg
114 -
115 - source = sha_parse(
116 - sha, tokenizer=args.tokenizer, max_length=args.max_source_length
117 - )
118 - if not source:
119 - continue
120 - input_id, attention_mask, patch_id = source
121 - target = message_parse(
122 - msg,
123 - tokenizer=args.tokenizer,
124 - max_length=(
125 - args.max_target_length if train else args.val_max_target_length
126 - ),
127 - )
128 -
129 - input_ids.append(input_id)
130 - attention_masks.append(attention_mask)
131 - patch_ids.append(patch_id)
132 - targets.append(target)
133 -
134 - data_saver(
135 - {
136 - "input_ids": np.asarray(input_ids),
137 - "attention_masks": np.asarray(attention_masks),
138 - "patch_ids": np.asarray(patch_ids),
139 - "targets": np.asarray(targets),
140 - }
141 - )
142 - data_saver.disconnect()
143 -
144 -
145 -def start(chunked_sha_msgs, train=True):
146 -
147 - logger.info("Start %s pre-processing" % ("training" if train else "evaluation"))
148 -
149 - max_target_length = args.max_target_length if train else args.val_max_target_length
150 -
151 - data_config = DataConfig(
152 - endpoint=args.endpoint,
153 - access_key=os.environ["access_key"],
154 - secret_key=os.environ["secret_key"],
155 - region=args.region,
156 - dataset_name="commit-autosuggestions",
157 - additional={
158 - "mode": ("training" if train else "evaluation"),
159 - "max_source_length": args.max_source_length,
160 - "max_target_length": max_target_length,
161 - "url": args.url,
162 - },
163 - attributes=[
164 - ("input_ids", "int32", (args.max_source_length,)),
165 - ("attention_masks", "int32", (args.max_source_length,)),
166 - ("patch_ids", "int32", (args.max_source_length,)),
167 - ("targets", "int32", (max_target_length,)),
168 - ],
169 - )
170 -
171 - func = partial(jobs, args=args, data_config=data_config, train=train)
172 - with Pool(processes=args.num_workers) as pool:
173 - with tqdm(total=len(chunked_sha_msgs)) as pbar:
174 - for _ in pool.imap_unordered(func, chunked_sha_msgs):
175 - pbar.update()
176 -
177 -
178 -def main(args):
179 - if "access_key" not in os.environ or "secret_key" not in os.environ:
180 - raise OSError("The access_key and secret_key environment variables must be set.")
181 -
182 - sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
183 - random.shuffle(sha_msgs)
184 - chunked_sha_msgs = [
185 - sha_msgs[x : x + args.matorage_batch]
186 - for x in range(0, len(sha_msgs), args.matorage_batch)
187 - ]
188 -
189 - barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
190 - if args.do_train:
191 - start(chunked_sha_msgs[:barrier], train=True)
192 - if args.do_predict:
193 - start(chunked_sha_msgs[barrier:], train=False)
194 -
195 -
196 -if __name__ == "__main__":
197 - parser = argparse.ArgumentParser(description="Collect commits from a GitHub repository for preprocessing")
198 - parser.add_argument("--url", type=str, required=True, help="github url")
199 - parser.add_argument(
200 - "--endpoint",
201 - type=str,
202 - required=True,
203 - help="matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
204 - )
205 - parser.add_argument(
206 - "--region",
207 - type=str,
208 - default=None,
209 - help="matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html",
210 - )
211 - parser.add_argument(
212 - "--tokenizer_name",
213 - default="sshleifer/distilbart-xsum-6-6",
214 - type=str,
215 - help="Pretrained tokenizer name or path if not the same as model_name",
216 - )
217 - parser.add_argument(
218 - "--matorage_batch",
219 - default=1024,
220 - type=int,
221 - help="The smallest batch size stored atomically in matorage.",
222 - )
223 - parser.add_argument(
224 - "--num_workers", default=4, type=int, help="number of worker processes",
225 - )
226 - parser.add_argument(
227 - "--max_source_length",
228 - default=1024,
229 - type=int,
230 - help="The maximum total input sequence length after tokenization. Sequences longer "
231 - "than this will be truncated, sequences shorter will be padded.",
232 - )
233 - parser.add_argument(
234 - "--max_target_length",
235 - default=56,
236 - type=int,
237 - help="The maximum total target sequence length after tokenization. Sequences longer "
238 - "than this will be truncated, sequences shorter will be padded.",
239 - )
240 - parser.add_argument(
241 - "--val_max_target_length",
242 - default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
243 - type=int,
244 - help="The maximum total target sequence length after tokenization. Sequences longer "
245 - "than this will be truncated, sequences shorter will be padded.",
246 - )
247 - parser.add_argument(
248 - "--p_val", type=float, default=0.25, help="fraction of the dataset used for validation"
249 - )
250 - parser.add_argument("--do_train", action="store_true", default=False)
251 - parser.add_argument("--do_predict", action="store_true", default=False)
252 - args = parser.parse_args()
253 -
254 - args.local_path = args.url.split("/")[-1]
255 - logger.info(f"master branch of {args.url} will be downloaded to {args.local_path}")
256 - repo = (
257 - Repo(args.local_path)
258 - if os.path.exists(args.local_path)
259 - else Repo.clone_from(args.url, to_path=args.local_path, branch="master")
260 - )
261 - args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
262 -
263 - main(args)
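
To make the chunk format concrete, a small worked example (not from the removed files; the token ids are made up). diff_parse yields one (token_ids, attention_mask, patch_ids) tuple per added or removed line, with patch id 1 (PATCH.PLUS) for additions and 2 (PATCH.MINUS) for deletions; truncate flattens a tuple of such per-line lists into a single fixed-length, zero-padded sequence:

from preprocess import truncate

# Two illustrative chunks: one removed line and one added line (token ids invented).
chunks = [
    ([101, 2054], [1, 1], [2, 2]),             # removed line -> patch id 2 (PATCH.MINUS)
    ([101, 2087, 999], [1, 1, 1], [1, 1, 1]),  # added line   -> patch id 1 (PATCH.PLUS)
]
input_ids, attention_masks, patch_ids = zip(*chunks)

# truncate prepends the given value, then cuts or zero-pads to max_length.
assert truncate(input_ids, 8, value=0) == [0, 101, 2054, 101, 2087, 999, 0, 0]
assert truncate(patch_ids, 8, value=0) == [0, 2, 2, 1, 1, 1, 0, 0]
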
1 -commit b5a5268dabb2a4dea1c3c543a1ddff501b87a447
2 -Author: jbrockmendel <jbrockmendel@gmail.com>
3 -Date: Tue Sep 8 18:33:41 2020 -0700
4 -
5 - STY: De-privatize imported names (#36235)
6 -
7 -diff --git a/pandas/_libs/interval.pyx b/pandas/_libs/interval.pyx
8 -index 931ad8326..f8bcbcfb1 100644
9 ---- a/pandas/_libs/interval.pyx
10 -+++ b/pandas/_libs/interval.pyx
11 -@@ -46,7 +46,7 @@ from pandas._libs.tslibs.util cimport (
12 - is_timedelta64_object,
13 - )
14 -
15 --_VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither'])
16 -+VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither'])
17 -
18 -
19 - cdef class IntervalMixin:
20 -@@ -318,7 +318,7 @@ cdef class Interval(IntervalMixin):
21 - self._validate_endpoint(left)
22 - self._validate_endpoint(right)
23 -
24 -- if closed not in _VALID_CLOSED:
25 -+ if closed not in VALID_CLOSED:
26 - raise ValueError(f"invalid option for 'closed': {closed}")
27 - if not left <= right:
28 - raise ValueError("left side of interval must be <= right side")
29 -diff --git a/pandas/core/arrays/_arrow_utils.py b/pandas/core/arrays/_arrow_utils.py
30 -index 4a33e0e84..c89f5554d 100644
31 ---- a/pandas/core/arrays/_arrow_utils.py
32 -+++ b/pandas/core/arrays/_arrow_utils.py
33 -@@ -4,7 +4,7 @@ import json
34 - import numpy as np
35 - import pyarrow
36 -
37 --from pandas.core.arrays.interval import _VALID_CLOSED
38 -+from pandas.core.arrays.interval import VALID_CLOSED
39 -
40 - _pyarrow_version_ge_015 = LooseVersion(pyarrow.__version__) >= LooseVersion("0.15")
41 -
42 -@@ -83,7 +83,7 @@ if _pyarrow_version_ge_015:
43 - def __init__(self, subtype, closed):
44 - # attributes need to be set first before calling
45 - # super init (as that calls serialize)
46 -- assert closed in _VALID_CLOSED
47 -+ assert closed in VALID_CLOSED
48 - self._closed = closed
49 - if not isinstance(subtype, pyarrow.DataType):
50 - subtype = pyarrow.type_for_alias(str(subtype))
51 -diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
52 -index d76e0fd62..1dbd3cfc6 100644
53 ---- a/pandas/core/arrays/interval.py
54 -+++ b/pandas/core/arrays/interval.py
55 -@@ -5,7 +5,12 @@ import numpy as np
56 -
57 - from pandas._config import get_option
58 -
59 --from pandas._libs.interval import Interval, IntervalMixin, intervals_to_interval_bounds
60 -+from pandas._libs.interval import (
61 -+ VALID_CLOSED,
62 -+ Interval,
63 -+ IntervalMixin,
64 -+ intervals_to_interval_bounds,
65 -+)
66 - from pandas.compat.numpy import function as nv
67 - from pandas.util._decorators import Appender
68 -
69 -@@ -42,7 +47,6 @@ from pandas.core.construction import array
70 - from pandas.core.indexers import check_array_indexer
71 - from pandas.core.indexes.base import ensure_index
72 -
73 --_VALID_CLOSED = {"left", "right", "both", "neither"}
74 - _interval_shared_docs = {}
75 -
76 - _shared_docs_kwargs = dict(
77 -@@ -475,7 +479,7 @@ class IntervalArray(IntervalMixin, ExtensionArray):
78 - * left and right have the same missing values
79 - * left is always below right
80 - """
81 -- if self.closed not in _VALID_CLOSED:
82 -+ if self.closed not in VALID_CLOSED:
83 - msg = f"invalid option for 'closed': {self.closed}"
84 - raise ValueError(msg)
85 - if len(self.left) != len(self.right):
86 -@@ -1012,7 +1016,7 @@ class IntervalArray(IntervalMixin, ExtensionArray):
87 - )
88 - )
89 - def set_closed(self, closed):
90 -- if closed not in _VALID_CLOSED:
91 -+ if closed not in VALID_CLOSED:
92 - msg = f"invalid option for 'closed': {closed}"
93 - raise ValueError(msg)
94 -
95 -diff --git a/pandas/core/arrays/sparse/__init__.py b/pandas/core/arrays/sparse/__init__.py
96 -index e928db499..e9ff4b7d4 100644
97 ---- a/pandas/core/arrays/sparse/__init__.py
98 -+++ b/pandas/core/arrays/sparse/__init__.py
99 -@@ -5,6 +5,6 @@ from pandas.core.arrays.sparse.array import (
100 - BlockIndex,
101 - IntIndex,
102 - SparseArray,
103 -- _make_index,
104 -+ make_sparse_index,
105 - )
106 - from pandas.core.arrays.sparse.dtype import SparseDtype
107 -diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py
108 -index 47c960dc9..853f7bb0b 100644
109 ---- a/pandas/core/arrays/sparse/array.py
110 -+++ b/pandas/core/arrays/sparse/array.py
111 -@@ -1556,7 +1556,7 @@ def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None, copy
112 - else:
113 - indices = mask.nonzero()[0].astype(np.int32)
114 -
115 -- index = _make_index(length, indices, kind)
116 -+ index = make_sparse_index(length, indices, kind)
117 - sparsified_values = arr[mask]
118 - if dtype is not None:
119 - sparsified_values = astype_nansafe(sparsified_values, dtype=dtype)
120 -@@ -1564,7 +1564,7 @@ def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None, copy
121 - return sparsified_values, index, fill_value
122 -
123 -
124 --def _make_index(length, indices, kind):
125 -+def make_sparse_index(length, indices, kind):
126 -
127 - if kind == "block" or isinstance(kind, BlockIndex):
128 - locs, lens = splib.get_blocks(indices)
129 -diff --git a/pandas/core/computation/engines.py b/pandas/core/computation/engines.py
130 -index 0cdc0f530..77a378369 100644
131 ---- a/pandas/core/computation/engines.py
132 -+++ b/pandas/core/computation/engines.py
133 -@@ -130,7 +130,7 @@ class PythonEngine(AbstractEngine):
134 - pass
135 -
136 -
137 --_engines: Dict[str, Type[AbstractEngine]] = {
138 -+ENGINES: Dict[str, Type[AbstractEngine]] = {
139 - "numexpr": NumExprEngine,
140 - "python": PythonEngine,
141 - }
142 -diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py
143 -index f6a793514..630606b4d 100644
144 ---- a/pandas/core/computation/eval.py
145 -+++ b/pandas/core/computation/eval.py
146 -@@ -9,8 +9,8 @@ import warnings
147 - from pandas._libs.lib import no_default
148 - from pandas.util._validators import validate_bool_kwarg
149 -
150 --from pandas.core.computation.engines import _engines
151 --from pandas.core.computation.expr import Expr, _parsers
152 -+from pandas.core.computation.engines import ENGINES
153 -+from pandas.core.computation.expr import PARSERS, Expr
154 - from pandas.core.computation.parsing import tokenize_string
155 - from pandas.core.computation.scope import ensure_scope
156 -
157 -@@ -43,8 +43,8 @@ def _check_engine(engine: Optional[str]) -> str:
158 - if engine is None:
159 - engine = "numexpr" if NUMEXPR_INSTALLED else "python"
160 -
161 -- if engine not in _engines:
162 -- valid_engines = list(_engines.keys())
163 -+ if engine not in ENGINES:
164 -+ valid_engines = list(ENGINES.keys())
165 - raise KeyError(
166 - f"Invalid engine '{engine}' passed, valid engines are {valid_engines}"
167 - )
168 -@@ -75,9 +75,9 @@ def _check_parser(parser: str):
169 - KeyError
170 - * If an invalid parser is passed
171 - """
172 -- if parser not in _parsers:
173 -+ if parser not in PARSERS:
174 - raise KeyError(
175 -- f"Invalid parser '{parser}' passed, valid parsers are {_parsers.keys()}"
176 -+ f"Invalid parser '{parser}' passed, valid parsers are {PARSERS.keys()}"
177 - )
178 -
179 -
180 -@@ -341,7 +341,7 @@ def eval(
181 - parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)
182 -
183 - # construct the engine and evaluate the parsed expression
184 -- eng = _engines[engine]
185 -+ eng = ENGINES[engine]
186 - eng_inst = eng(parsed_expr)
187 - ret = eng_inst.evaluate()
188 -
189 -diff --git a/pandas/core/computation/expr.py b/pandas/core/computation/expr.py
190 -index 8cff6abc0..f5897277d 100644
191 ---- a/pandas/core/computation/expr.py
192 -+++ b/pandas/core/computation/expr.py
193 -@@ -782,7 +782,7 @@ class Expr:
194 - self.env = env or Scope(level=level + 1)
195 - self.engine = engine
196 - self.parser = parser
197 -- self._visitor = _parsers[parser](self.env, self.engine, self.parser)
198 -+ self._visitor = PARSERS[parser](self.env, self.engine, self.parser)
199 - self.terms = self.parse()
200 -
201 - @property
202 -@@ -814,4 +814,4 @@ class Expr:
203 - return frozenset(term.name for term in com.flatten(self.terms))
204 -
205 -
206 --_parsers = {"python": PythonExprVisitor, "pandas": PandasExprVisitor}
207 -+PARSERS = {"python": PythonExprVisitor, "pandas": PandasExprVisitor}
208 -diff --git a/pandas/core/config_init.py b/pandas/core/config_init.py
209 -index 0c23f1b4b..bfe20551c 100644
210 ---- a/pandas/core/config_init.py
211 -+++ b/pandas/core/config_init.py
212 -@@ -314,9 +314,9 @@ pc_latex_multirow = """
213 -
214 -
215 - def table_schema_cb(key):
216 -- from pandas.io.formats.printing import _enable_data_resource_formatter
217 -+ from pandas.io.formats.printing import enable_data_resource_formatter
218 -
219 -- _enable_data_resource_formatter(cf.get_option(key))
220 -+ enable_data_resource_formatter(cf.get_option(key))
221 -
222 -
223 - def is_terminal() -> bool:
224 -diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
225 -index 72003eab2..e870187fc 100644
226 ---- a/pandas/core/groupby/generic.py
227 -+++ b/pandas/core/groupby/generic.py
228 -@@ -70,9 +70,9 @@ from pandas.core.groupby.groupby import (
229 - GroupBy,
230 - _agg_template,
231 - _apply_docs,
232 -- _group_selection_context,
233 - _transform_template,
234 - get_groupby,
235 -+ group_selection_context,
236 - )
237 - from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba
238 - from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same
239 -@@ -230,7 +230,7 @@ class SeriesGroupBy(GroupBy[Series]):
240 - raise NotImplementedError(
241 - "Numba engine can only be used with a single function."
242 - )
243 -- with _group_selection_context(self):
244 -+ with group_selection_context(self):
245 - data = self._selected_obj
246 - result, index = self._aggregate_with_numba(
247 - data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
248 -@@ -685,7 +685,7 @@ class SeriesGroupBy(GroupBy[Series]):
249 - self, normalize=False, sort=True, ascending=False, bins=None, dropna=True
250 - ):
251 -
252 -- from pandas.core.reshape.merge import _get_join_indexers
253 -+ from pandas.core.reshape.merge import get_join_indexers
254 - from pandas.core.reshape.tile import cut
255 -
256 - if bins is not None and not np.iterable(bins):
257 -@@ -787,7 +787,7 @@ class SeriesGroupBy(GroupBy[Series]):
258 -
259 - right = [diff.cumsum() - 1, codes[-1]]
260 -
261 -- _, idx = _get_join_indexers(left, right, sort=False, how="left")
262 -+ _, idx = get_join_indexers(left, right, sort=False, how="left")
263 - out = np.where(idx != -1, out[idx], 0)
264 -
265 - if sort:
266 -@@ -942,7 +942,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
267 - raise NotImplementedError(
268 - "Numba engine can only be used with a single function."
269 - )
270 -- with _group_selection_context(self):
271 -+ with group_selection_context(self):
272 - data = self._selected_obj
273 - result, index = self._aggregate_with_numba(
274 - data, func, *args, engine_kwargs=engine_kwargs, **kwargs
275 -diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
276 -index 6ef2e6703..1e3e56f4f 100644
277 ---- a/pandas/core/groupby/groupby.py
278 -+++ b/pandas/core/groupby/groupby.py
279 -@@ -459,9 +459,9 @@ class GroupByPlot(PandasObject):
280 -
281 -
282 - @contextmanager
283 --def _group_selection_context(groupby: "_GroupBy"):
284 -+def group_selection_context(groupby: "_GroupBy"):
285 - """
286 -- Set / reset the _group_selection_context.
287 -+ Set / reset the group_selection_context.
288 - """
289 - groupby._set_group_selection()
290 - try:
291 -@@ -737,7 +737,7 @@ b 2""",
292 - def _make_wrapper(self, name: str) -> Callable:
293 - assert name in self._apply_allowlist
294 -
295 -- with _group_selection_context(self):
296 -+ with group_selection_context(self):
297 - # need to setup the selection
298 - # as are not passed directly but in the grouper
299 - f = getattr(self._obj_with_exclusions, name)
300 -@@ -868,7 +868,7 @@ b 2""",
301 - # fails on *some* columns, e.g. a numeric operation
302 - # on a string grouper column
303 -
304 -- with _group_selection_context(self):
305 -+ with group_selection_context(self):
306 - return self._python_apply_general(f, self._selected_obj)
307 -
308 - return result
309 -@@ -994,7 +994,7 @@ b 2""",
310 - alias: str,
311 - npfunc: Callable,
312 - ):
313 -- with _group_selection_context(self):
314 -+ with group_selection_context(self):
315 - # try a cython aggregation if we can
316 - try:
317 - return self._cython_agg_general(
318 -@@ -1499,7 +1499,7 @@ class GroupBy(_GroupBy[FrameOrSeries]):
319 - )
320 - else:
321 - func = lambda x: x.var(ddof=ddof)
322 -- with _group_selection_context(self):
323 -+ with group_selection_context(self):
324 - return self._python_agg_general(func)
325 -
326 - @Substitution(name="groupby")
327 -@@ -1658,7 +1658,7 @@ class GroupBy(_GroupBy[FrameOrSeries]):
328 -
329 - @doc(DataFrame.describe)
330 - def describe(self, **kwargs):
331 -- with _group_selection_context(self):
332 -+ with group_selection_context(self):
333 - result = self.apply(lambda x: x.describe(**kwargs))
334 - if self.axis == 1:
335 - return result.T
336 -@@ -1963,7 +1963,7 @@ class GroupBy(_GroupBy[FrameOrSeries]):
337 - nth_values = list(set(n))
338 -
339 - nth_array = np.array(nth_values, dtype=np.intp)
340 -- with _group_selection_context(self):
341 -+ with group_selection_context(self):
342 -
343 - mask_left = np.in1d(self._cumcount_array(), nth_array)
344 - mask_right = np.in1d(
345 -@@ -2226,7 +2226,7 @@ class GroupBy(_GroupBy[FrameOrSeries]):
346 - 5 0
347 - dtype: int64
348 - """
349 -- with _group_selection_context(self):
350 -+ with group_selection_context(self):
351 - index = self._selected_obj.index
352 - result = self._obj_1d_constructor(self.grouper.group_info[0], index)
353 - if not ascending:
354 -@@ -2287,7 +2287,7 @@ class GroupBy(_GroupBy[FrameOrSeries]):
355 - 5 0
356 - dtype: int64
357 - """
358 -- with _group_selection_context(self):
359 -+ with group_selection_context(self):
360 - index = self._selected_obj.index
361 - cumcounts = self._cumcount_array(ascending=ascending)
362 - return self._obj_1d_constructor(cumcounts, index)
363 -diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py
364 -index 526dae7e2..8014b16d0 100644
365 ---- a/pandas/core/indexes/base.py
366 -+++ b/pandas/core/indexes/base.py
367 -@@ -3660,7 +3660,7 @@ class Index(IndexOpsMixin, PandasObject):
368 - return result
369 -
370 - def _join_non_unique(self, other, how="left", return_indexers=False):
371 -- from pandas.core.reshape.merge import _get_join_indexers
372 -+ from pandas.core.reshape.merge import get_join_indexers
373 -
374 - # We only get here if dtypes match
375 - assert self.dtype == other.dtype
376 -@@ -3668,7 +3668,7 @@ class Index(IndexOpsMixin, PandasObject):
377 - lvalues = self._get_engine_target()
378 - rvalues = other._get_engine_target()
379 -
380 -- left_idx, right_idx = _get_join_indexers(
381 -+ left_idx, right_idx = get_join_indexers(
382 - [lvalues], [rvalues], how=how, sort=True
383 - )
384 -
385 -diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py
386 -index 3f72577c9..154f41bf0 100644
387 ---- a/pandas/core/indexes/interval.py
388 -+++ b/pandas/core/indexes/interval.py
389 -@@ -59,7 +59,6 @@ from pandas.core.ops import get_op_result_name
390 - if TYPE_CHECKING:
391 - from pandas import CategoricalIndex # noqa:F401
392 -
393 --_VALID_CLOSED = {"left", "right", "both", "neither"}
394 - _index_doc_kwargs = dict(ibase._index_doc_kwargs)
395 -
396 - _index_doc_kwargs.update(
397 -diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py
398 -index 030dec369..9f19ea9ae 100644
399 ---- a/pandas/core/reshape/merge.py
400 -+++ b/pandas/core/reshape/merge.py
401 -@@ -859,7 +859,7 @@ class _MergeOperation:
402 -
403 - def _get_join_indexers(self):
404 - """ return the join indexers """
405 -- return _get_join_indexers(
406 -+ return get_join_indexers(
407 - self.left_join_keys, self.right_join_keys, sort=self.sort, how=self.how
408 - )
409 -
410 -@@ -1298,7 +1298,7 @@ class _MergeOperation:
411 - raise ValueError("Not a valid argument for validate")
412 -
413 -
414 --def _get_join_indexers(
415 -+def get_join_indexers(
416 - left_keys, right_keys, sort: bool = False, how: str = "inner", **kwargs
417 - ):
418 - """
419 -diff --git a/pandas/io/formats/printing.py b/pandas/io/formats/printing.py
420 -index edc6fbfff..0d2ca83f1 100644
421 ---- a/pandas/io/formats/printing.py
422 -+++ b/pandas/io/formats/printing.py
423 -@@ -243,7 +243,7 @@ def pprint_thing_encoded(
424 - return value.encode(encoding, errors)
425 -
426 -
427 --def _enable_data_resource_formatter(enable: bool) -> None:
428 -+def enable_data_resource_formatter(enable: bool) -> None:
429 - if "IPython" not in sys.modules:
430 - # definitely not in IPython
431 - return
432 -diff --git a/pandas/tests/arrays/sparse/test_libsparse.py b/pandas/tests/arrays/sparse/test_libsparse.py
433 -index a2f861d37..2d6e657de 100644
434 ---- a/pandas/tests/arrays/sparse/test_libsparse.py
435 -+++ b/pandas/tests/arrays/sparse/test_libsparse.py
436 -@@ -8,7 +8,7 @@ import pandas.util._test_decorators as td
437 -
438 - from pandas import Series
439 - import pandas._testing as tm
440 --from pandas.core.arrays.sparse import BlockIndex, IntIndex, _make_index
441 -+from pandas.core.arrays.sparse import BlockIndex, IntIndex, make_sparse_index
442 -
443 - TEST_LENGTH = 20
444 -
445 -@@ -273,41 +273,43 @@ class TestSparseIndexIntersect:
446 -
447 - class TestSparseIndexCommon:
448 - def test_int_internal(self):
449 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="integer")
450 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="integer")
451 - assert isinstance(idx, IntIndex)
452 - assert idx.npoints == 2
453 - tm.assert_numpy_array_equal(idx.indices, np.array([2, 3], dtype=np.int32))
454 -
455 -- idx = _make_index(4, np.array([], dtype=np.int32), kind="integer")
456 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="integer")
457 - assert isinstance(idx, IntIndex)
458 - assert idx.npoints == 0
459 - tm.assert_numpy_array_equal(idx.indices, np.array([], dtype=np.int32))
460 -
461 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer")
462 -+ idx = make_sparse_index(
463 -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer"
464 -+ )
465 - assert isinstance(idx, IntIndex)
466 - assert idx.npoints == 4
467 - tm.assert_numpy_array_equal(idx.indices, np.array([0, 1, 2, 3], dtype=np.int32))
468 -
469 - def test_block_internal(self):
470 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="block")
471 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="block")
472 - assert isinstance(idx, BlockIndex)
473 - assert idx.npoints == 2
474 - tm.assert_numpy_array_equal(idx.blocs, np.array([2], dtype=np.int32))
475 - tm.assert_numpy_array_equal(idx.blengths, np.array([2], dtype=np.int32))
476 -
477 -- idx = _make_index(4, np.array([], dtype=np.int32), kind="block")
478 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="block")
479 - assert isinstance(idx, BlockIndex)
480 - assert idx.npoints == 0
481 - tm.assert_numpy_array_equal(idx.blocs, np.array([], dtype=np.int32))
482 - tm.assert_numpy_array_equal(idx.blengths, np.array([], dtype=np.int32))
483 -
484 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block")
485 -+ idx = make_sparse_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block")
486 - assert isinstance(idx, BlockIndex)
487 - assert idx.npoints == 4
488 - tm.assert_numpy_array_equal(idx.blocs, np.array([0], dtype=np.int32))
489 - tm.assert_numpy_array_equal(idx.blengths, np.array([4], dtype=np.int32))
490 -
491 -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block")
492 -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block")
493 - assert isinstance(idx, BlockIndex)
494 - assert idx.npoints == 3
495 - tm.assert_numpy_array_equal(idx.blocs, np.array([0, 2], dtype=np.int32))
496 -@@ -315,7 +317,7 @@ class TestSparseIndexCommon:
497 -
498 - def test_lookup(self):
499 - for kind in ["integer", "block"]:
500 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind=kind)
501 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind=kind)
502 - assert idx.lookup(-1) == -1
503 - assert idx.lookup(0) == -1
504 - assert idx.lookup(1) == -1
505 -@@ -323,12 +325,14 @@ class TestSparseIndexCommon:
506 - assert idx.lookup(3) == 1
507 - assert idx.lookup(4) == -1
508 -
509 -- idx = _make_index(4, np.array([], dtype=np.int32), kind=kind)
510 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind=kind)
511 -
512 - for i in range(-1, 5):
513 - assert idx.lookup(i) == -1
514 -
515 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind)
516 -+ idx = make_sparse_index(
517 -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind
518 -+ )
519 - assert idx.lookup(-1) == -1
520 - assert idx.lookup(0) == 0
521 - assert idx.lookup(1) == 1
522 -@@ -336,7 +340,7 @@ class TestSparseIndexCommon:
523 - assert idx.lookup(3) == 3
524 - assert idx.lookup(4) == -1
525 -
526 -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind)
527 -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind)
528 - assert idx.lookup(-1) == -1
529 - assert idx.lookup(0) == 0
530 - assert idx.lookup(1) == -1
531 -@@ -346,7 +350,7 @@ class TestSparseIndexCommon:
532 -
533 - def test_lookup_array(self):
534 - for kind in ["integer", "block"]:
535 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind=kind)
536 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind=kind)
537 -
538 - res = idx.lookup_array(np.array([-1, 0, 2], dtype=np.int32))
539 - exp = np.array([-1, -1, 0], dtype=np.int32)
540 -@@ -356,11 +360,13 @@ class TestSparseIndexCommon:
541 - exp = np.array([-1, 0, -1, 1], dtype=np.int32)
542 - tm.assert_numpy_array_equal(res, exp)
543 -
544 -- idx = _make_index(4, np.array([], dtype=np.int32), kind=kind)
545 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind=kind)
546 - res = idx.lookup_array(np.array([-1, 0, 2, 4], dtype=np.int32))
547 - exp = np.array([-1, -1, -1, -1], dtype=np.int32)
548 -
549 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind)
550 -+ idx = make_sparse_index(
551 -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind=kind
552 -+ )
553 - res = idx.lookup_array(np.array([-1, 0, 2], dtype=np.int32))
554 - exp = np.array([-1, 0, 2], dtype=np.int32)
555 - tm.assert_numpy_array_equal(res, exp)
556 -@@ -369,7 +375,7 @@ class TestSparseIndexCommon:
557 - exp = np.array([-1, 2, 1, 3], dtype=np.int32)
558 - tm.assert_numpy_array_equal(res, exp)
559 -
560 -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind)
561 -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind=kind)
562 - res = idx.lookup_array(np.array([2, 1, 3, 0], dtype=np.int32))
563 - exp = np.array([1, -1, 2, 0], dtype=np.int32)
564 - tm.assert_numpy_array_equal(res, exp)
565 -@@ -402,25 +408,25 @@ class TestSparseIndexCommon:
566 -
567 - class TestBlockIndex:
568 - def test_block_internal(self):
569 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="block")
570 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="block")
571 - assert isinstance(idx, BlockIndex)
572 - assert idx.npoints == 2
573 - tm.assert_numpy_array_equal(idx.blocs, np.array([2], dtype=np.int32))
574 - tm.assert_numpy_array_equal(idx.blengths, np.array([2], dtype=np.int32))
575 -
576 -- idx = _make_index(4, np.array([], dtype=np.int32), kind="block")
577 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="block")
578 - assert isinstance(idx, BlockIndex)
579 - assert idx.npoints == 0
580 - tm.assert_numpy_array_equal(idx.blocs, np.array([], dtype=np.int32))
581 - tm.assert_numpy_array_equal(idx.blengths, np.array([], dtype=np.int32))
582 -
583 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block")
584 -+ idx = make_sparse_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="block")
585 - assert isinstance(idx, BlockIndex)
586 - assert idx.npoints == 4
587 - tm.assert_numpy_array_equal(idx.blocs, np.array([0], dtype=np.int32))
588 - tm.assert_numpy_array_equal(idx.blengths, np.array([4], dtype=np.int32))
589 -
590 -- idx = _make_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block")
591 -+ idx = make_sparse_index(4, np.array([0, 2, 3], dtype=np.int32), kind="block")
592 - assert isinstance(idx, BlockIndex)
593 - assert idx.npoints == 3
594 - tm.assert_numpy_array_equal(idx.blocs, np.array([0, 2], dtype=np.int32))
595 -@@ -428,7 +434,7 @@ class TestBlockIndex:
596 -
597 - def test_make_block_boundary(self):
598 - for i in [5, 10, 100, 101]:
599 -- idx = _make_index(i, np.arange(0, i, 2, dtype=np.int32), kind="block")
600 -+ idx = make_sparse_index(i, np.arange(0, i, 2, dtype=np.int32), kind="block")
601 -
602 - exp = np.arange(0, i, 2, dtype=np.int32)
603 - tm.assert_numpy_array_equal(idx.blocs, exp)
604 -@@ -514,17 +520,19 @@ class TestIntIndex:
605 - IntIndex(length=5, indices=[1, 3, 3])
606 -
607 - def test_int_internal(self):
608 -- idx = _make_index(4, np.array([2, 3], dtype=np.int32), kind="integer")
609 -+ idx = make_sparse_index(4, np.array([2, 3], dtype=np.int32), kind="integer")
610 - assert isinstance(idx, IntIndex)
611 - assert idx.npoints == 2
612 - tm.assert_numpy_array_equal(idx.indices, np.array([2, 3], dtype=np.int32))
613 -
614 -- idx = _make_index(4, np.array([], dtype=np.int32), kind="integer")
615 -+ idx = make_sparse_index(4, np.array([], dtype=np.int32), kind="integer")
616 - assert isinstance(idx, IntIndex)
617 - assert idx.npoints == 0
618 - tm.assert_numpy_array_equal(idx.indices, np.array([], dtype=np.int32))
619 -
620 -- idx = _make_index(4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer")
621 -+ idx = make_sparse_index(
622 -+ 4, np.array([0, 1, 2, 3], dtype=np.int32), kind="integer"
623 -+ )
624 - assert isinstance(idx, IntIndex)
625 - assert idx.npoints == 4
626 - tm.assert_numpy_array_equal(idx.indices, np.array([0, 1, 2, 3], dtype=np.int32))
627 -diff --git a/pandas/tests/computation/test_compat.py b/pandas/tests/computation/test_compat.py
628 -index ead102f53..9fc3ed480 100644
629 ---- a/pandas/tests/computation/test_compat.py
630 -+++ b/pandas/tests/computation/test_compat.py
631 -@@ -5,7 +5,7 @@ import pytest
632 - from pandas.compat._optional import VERSIONS
633 -
634 - import pandas as pd
635 --from pandas.core.computation.engines import _engines
636 -+from pandas.core.computation.engines import ENGINES
637 - import pandas.core.computation.expr as expr
638 -
639 -
640 -@@ -26,8 +26,8 @@ def test_compat():
641 - pytest.skip("not testing numexpr version compat")
642 -
643 -
644 --@pytest.mark.parametrize("engine", _engines)
645 --@pytest.mark.parametrize("parser", expr._parsers)
646 -+@pytest.mark.parametrize("engine", ENGINES)
647 -+@pytest.mark.parametrize("parser", expr.PARSERS)
648 - def test_invalid_numexpr_version(engine, parser):
649 - def testit():
650 - a, b = 1, 2 # noqa
651 -diff --git a/pandas/tests/computation/test_eval.py b/pandas/tests/computation/test_eval.py
652 -index 72dc04e68..cca64a6bf 100644
653 ---- a/pandas/tests/computation/test_eval.py
654 -+++ b/pandas/tests/computation/test_eval.py
655 -@@ -19,7 +19,7 @@ from pandas import DataFrame, Series, compat, date_range
656 - import pandas._testing as tm
657 - from pandas.core.computation import pytables
658 - from pandas.core.computation.check import NUMEXPR_VERSION
659 --from pandas.core.computation.engines import NumExprClobberingError, _engines
660 -+from pandas.core.computation.engines import ENGINES, NumExprClobberingError
661 - import pandas.core.computation.expr as expr
662 - from pandas.core.computation.expr import (
663 - BaseExprVisitor,
664 -@@ -46,14 +46,14 @@ from pandas.core.computation.ops import (
665 - f"installed->{NUMEXPR_INSTALLED}",
666 - ),
667 - )
668 -- for engine in _engines
669 -+ for engine in ENGINES
670 - )
671 - ) # noqa
672 - def engine(request):
673 - return request.param
674 -
675 -
676 --@pytest.fixture(params=expr._parsers)
677 -+@pytest.fixture(params=expr.PARSERS)
678 - def parser(request):
679 - return request.param
680 -
681 -@@ -77,7 +77,7 @@ def unary_fns_for_ne():
682 -
683 -
684 - def engine_has_neg_frac(engine):
685 -- return _engines[engine].has_neg_frac
686 -+ return ENGINES[engine].has_neg_frac
687 -
688 -
689 - def _eval_single_bin(lhs, cmp1, rhs, engine):
690 -@@ -168,7 +168,7 @@ class TestEvalNumexprPandas:
691 - def setup_method(self, method):
692 - self.setup_ops()
693 - self.setup_data()
694 -- self.current_engines = (engine for engine in _engines if engine != self.engine)
695 -+ self.current_engines = (engine for engine in ENGINES if engine != self.engine)
696 -
697 - def teardown_method(self, method):
698 - del self.lhses, self.rhses, self.scalar_rhses, self.scalar_lhses
699 -@@ -1921,7 +1921,7 @@ _parsers: Dict[str, Type[BaseExprVisitor]] = {
700 - }
701 -
702 -
703 --@pytest.mark.parametrize("engine", _engines)
704 -+@pytest.mark.parametrize("engine", ENGINES)
705 - @pytest.mark.parametrize("parser", _parsers)
706 - def test_disallowed_nodes(engine, parser):
707 - VisitorClass = _parsers[parser]
\ No newline at end of file
1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -import os
16 -import argparse
17 -import pytorch_lightning as pl
18 -from train.finetune import main, SummarizationModule
19 -
20 -if __name__ == "__main__":
21 - parser = argparse.ArgumentParser()
22 - parser = pl.Trainer.add_argparse_args(parser)
23 - parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
24 -
25 - args = parser.parse_args()
26 -
27 - main(args)
\ No newline at end of file
1 -# Copyright 2020-present Tae Hwan Jung
2 -#
3 -# Licensed under the Apache License, Version 2.0 (the "License");
4 -# you may not use this file except in compliance with the License.
5 -# You may obtain a copy of the License at
6 -#
7 -# http://www.apache.org/licenses/LICENSE-2.0
8 -#
9 -# Unless required by applicable law or agreed to in writing, software
10 -# distributed under the License is distributed on an "AS IS" BASIS,
11 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 -# See the License for the specific language governing permissions and
13 -# limitations under the License.
14 -
15 -from train.modeling_bart import BartForConditionalGeneration
16 -
17 -__all__ = ["BartForConditionalGeneration"]
1 -import logging
2 -import os
3 -from pathlib import Path
4 -
5 -import numpy as np
6 -import pytorch_lightning as pl
7 -import torch
8 -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
9 -from pytorch_lightning.utilities import rank_zero_only
10 -
11 -
12 -def count_trainable_parameters(model):
13 - model_parameters = filter(lambda p: p.requires_grad, model.parameters())
14 - params = sum([np.prod(p.size()) for p in model_parameters])
15 - return params
16 -
17 -
18 -logger = logging.getLogger(__name__)
19 -
20 -
21 -class Seq2SeqLoggingCallback(pl.Callback):
22 - def on_batch_end(self, trainer, pl_module):
23 - lrs = {
24 - f"lr_group_{i}": param["lr"]
25 - for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)
26 - }
27 - pl_module.logger.log_metrics(lrs)
28 -
29 - @rank_zero_only
30 - def _write_logs(
31 - self,
32 - trainer: pl.Trainer,
33 - pl_module: pl.LightningModule,
34 - type_path: str,
35 - save_generations=True,
36 - ) -> None:
37 - logger.info(
38 - f"***** {type_path} results at step {trainer.global_step:05d} *****"
39 - )
40 - metrics = trainer.callback_metrics
41 - trainer.logger.log_metrics(
42 - {
43 - k: v
44 - for k, v in metrics.items()
45 - if k not in ["log", "progress_bar", "preds"]
46 - }
47 - )
48 - # Log results
49 - od = Path(pl_module.hparams.output_dir)
50 - if type_path == "test":
51 - results_file = od / "test_results.txt"
52 - generations_file = od / "test_generations.txt"
53 - else:
54 - # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
55 - # If people want this it will be easy enough to add back.
56 - results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
57 - generations_file = (
58 - od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
59 - )
60 - results_file.parent.mkdir(exist_ok=True)
61 - generations_file.parent.mkdir(exist_ok=True)
62 - with open(results_file, "a+") as writer:
63 - for key in sorted(metrics):
64 - if key in ["log", "progress_bar", "preds"]:
65 - continue
66 - val = metrics[key]
67 - if isinstance(val, torch.Tensor):
68 - val = val.item()
69 - msg = f"{key}: {val:.6f}\n"
70 - writer.write(msg)
71 -
72 - if not save_generations:
73 - return
74 -
75 - if "preds" in metrics:
76 - content = "\n".join(metrics["preds"])
77 - generations_file.open("w+").write(content)
78 -
79 - @rank_zero_only
80 - def on_train_start(self, trainer, pl_module):
81 - try:
82 - npars = pl_module.model.model.num_parameters()
83 - except AttributeError:
84 - npars = pl_module.model.num_parameters()
85 -
86 - n_trainable_pars = count_trainable_parameters(pl_module)
87 - # mp stands for million parameters
88 - trainer.logger.log_metrics(
89 - {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
90 - )
91 -
92 - @rank_zero_only
93 - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
94 - return self._write_logs(trainer, pl_module, "test")
95 -
96 -
97 -def get_checkpoint_callback(output_dir, metric):
98 - """Save the best model checkpoint by the chosen validation metric (ROUGE-2 or BLEU)."""
99 - if metric == "rouge2":
100 - exp = "{val_avg_rouge2:.4f}-{step_count}"
101 - elif metric == "bleu":
102 - exp = "{val_avg_bleu:.4f}-{step_count}"
103 - else:
104 - raise NotImplementedError(
105 - f"seq2seq callbacks only support rouge2 and bleu, got {metric}. You can add your own by extending this function."
106 - )
107 -
108 - checkpoint_callback = ModelCheckpoint(
109 - filepath=os.path.join(output_dir, exp),
110 - monitor=f"val_{metric}",
111 - mode="max",
112 - save_top_k=1,
113 - period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
114 - )
115 - return checkpoint_callback
116 -
117 -
118 -def get_early_stopping_callback(metric, patience):
119 - return EarlyStopping(
120 - monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,
121 - )
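
A rough sketch (an assumption, not taken from the removed files) of how these callbacks would be wired into a Trainer under the pytorch_lightning 0.8/0.9-era API implied by the ModelCheckpoint(filepath=..., period=...) call above; the module path train.callbacks, the output directory, and the patience value are illustrative:

import pytorch_lightning as pl

from train.callbacks import (
    Seq2SeqLoggingCallback,
    get_checkpoint_callback,
    get_early_stopping_callback,
)

checkpoint = get_checkpoint_callback("output", metric="rouge2")        # keeps the best val_rouge2 checkpoint
early_stop = get_early_stopping_callback(metric="rouge2", patience=3)  # patience is an illustrative value

trainer = pl.Trainer(
    checkpoint_callback=checkpoint,
    early_stop_callback=early_stop,  # old-style kwargs; newer Lightning versions take callbacks=[...] only
    callbacks=[Seq2SeqLoggingCallback()],
)
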
1 -import argparse
2 -import glob
3 -import logging
4 -import os
5 -import time
6 -from collections import defaultdict
7 -from pathlib import Path
8 -from typing import Dict, List, Tuple
9 -
10 -import numpy as np
11 -import pytorch_lightning as pl
12 -import torch
13 -from torch.utils.data import DataLoader
14 -
15 -from train.lightning_base import BaseTransformer, add_generic_args, generic_train
16 -from transformers import MBartTokenizer, T5ForConditionalGeneration
17 -from transformers.modeling_bart import shift_tokens_right
18 -
19 -from matorage import DataConfig
20 -from matorage.torch import Dataset
21 -
22 -
23 -try:
24 - from .callbacks import (
25 - Seq2SeqLoggingCallback,
26 - get_checkpoint_callback,
27 - get_early_stopping_callback,
28 - )
29 - from .utils import (
30 - ROUGE_KEYS,
31 - LegacySeq2SeqDataset,
32 - Seq2SeqDataset,
33 - assert_all_frozen,
34 - calculate_bleu,
35 - calculate_rouge,
36 - flatten_list,
37 - freeze_params,
38 - get_git_info,
39 - label_smoothed_nll_loss,
40 - lmap,
41 - pickle_save,
42 - save_git_info,
43 - save_json,
44 - use_task_specific_params,
45 - )
46 -except ImportError:
47 - from callbacks import (
48 - Seq2SeqLoggingCallback,
49 - get_checkpoint_callback,
50 - get_early_stopping_callback,
51 - )
52 - from utils import (
53 - ROUGE_KEYS,
54 - LegacySeq2SeqDataset,
55 - Seq2SeqDataset,
56 - assert_all_frozen,
57 - calculate_bleu,
58 - calculate_rouge,
59 - flatten_list,
60 - freeze_params,
61 - get_git_info,
62 - label_smoothed_nll_loss,
63 - lmap,
64 - pickle_save,
65 - save_git_info,
66 - save_json,
67 - use_task_specific_params,
68 - )
69 -
70 -logger = logging.getLogger(__name__)
71 -
72 -
73 -class SummarizationModule(BaseTransformer):
74 - mode = "summarization"
75 - loss_names = ["loss"]
76 - metric_names = ROUGE_KEYS
77 - default_val_metric = "rouge2"
78 -
79 - def __init__(self, hparams, **kwargs):
80 - super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
81 - use_task_specific_params(self.model, "summarization")
82 - save_git_info(self.hparams.output_dir)
83 - self.metrics_save_path = Path(self.output_dir) / "metrics.json"
84 - self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
85 - pickle_save(self.hparams, self.hparams_save_path)
86 - self.step_count = 0
87 - self.metrics = defaultdict(list)
88 -
89 - self.target_lens = {
90 - "train": self.hparams.max_target_length,
91 - "val": self.hparams.val_max_target_length,
92 - "test": self.hparams.test_max_target_length,
93 - }
94 - assert (
95 - self.target_lens["train"] <= self.target_lens["val"]
96 - ), f"target_lens: {self.target_lens}"
97 - assert (
98 - self.target_lens["train"] <= self.target_lens["test"]
99 - ), f"target_lens: {self.target_lens}"
100 -
101 - if self.hparams.freeze_embeds:
102 - self.freeze_embeds()
103 - if self.hparams.freeze_encoder:
104 - freeze_params(self.model.get_encoder())
105 - assert_all_frozen(self.model.get_encoder())
106 -
107 - self.hparams.git_sha = get_git_info()["repo_sha"]
108 - self.num_workers = hparams.num_workers
109 - self.decoder_start_token_id = None # default to config
110 - if self.model.config.decoder_start_token_id is None and isinstance(
111 - self.tokenizer, MBartTokenizer
112 - ):
113 - self.decoder_start_token_id = self.tokenizer.lang_code_to_id[
114 - hparams.tgt_lang
115 - ]
116 - self.model.config.decoder_start_token_id = self.decoder_start_token_id
117 -
118 - self.eval_beams = (
119 - self.model.config.num_beams
120 - if self.hparams.eval_beams is None
121 - else self.hparams.eval_beams
122 - )
123 - assert (
124 - self.eval_beams >= 1
125 - ), f"got self.eval_beams={self.eval_beams}. Need an integer >= 1"
126 - self.val_metric = (
127 - self.default_val_metric
128 - if self.hparams.val_metric is None
129 - else self.hparams.val_metric
130 - )
131 -
132 - def freeze_embeds(self):
133 - """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
134 - try:
135 - freeze_params(self.model.model.shared)
136 - for d in [self.model.model.encoder, self.model.model.decoder]:
137 - freeze_params(d.embed_positions)
138 - freeze_params(d.embed_tokens)
139 - except AttributeError:
140 - freeze_params(self.model.shared)
141 - for d in [self.model.encoder, self.model.decoder]:
142 - freeze_params(d.embed_tokens)
143 -
144 - def forward(self, input_ids, patch_ids, **kwargs):
145 - return self.model(input_ids, patch_ids, **kwargs)
146 -
147 - def ids_to_clean_text(self, generated_ids: List[int]):
148 - gen_text = self.tokenizer.batch_decode(
149 - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
150 - )
151 - return lmap(str.strip, gen_text)
152 -
153 - def _step(self, batch: dict) -> Tuple:
154 - pad_token_id = self.tokenizer.pad_token_id
155 - src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
156 - tgt_ids = batch[3].long()
157 - if isinstance(self.model, T5ForConditionalGeneration):
158 - decoder_input_ids = self.model._shift_right(tgt_ids)
159 - else:
160 - decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
161 -
162 - outputs = self(
163 - src_ids,
164 - src_patch,
165 - attention_mask=src_mask,
166 - decoder_input_ids=decoder_input_ids,
167 - use_cache=False,
168 - )
169 - lm_logits = outputs[0]
170 - if self.hparams.label_smoothing == 0:
171 - # Same behavior as modeling_bart.py, besides ignoring pad_token_id
172 - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
173 -
174 - assert lm_logits.shape[-1] == self.model.config.vocab_size
175 - loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
176 - else:
177 - lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
178 - loss, nll_loss = label_smoothed_nll_loss(
179 - lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
180 - )
181 - return (loss,)
182 -
183 - @property
184 - def pad(self) -> int:
185 - return self.tokenizer.pad_token_id
186 -
187 - def training_step(self, batch, batch_idx) -> Dict:
188 - loss_tensors = self._step(batch)
189 -
190 - logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
191 - # tokens per batch
192 - logs["tpb"] = (
193 - batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
194 - )
195 - return {"loss": loss_tensors[0], "log": logs}
196 -
197 - def validation_step(self, batch, batch_idx) -> Dict:
198 - return self._generative_step(batch)
199 -
200 - def validation_epoch_end(self, outputs, prefix="val") -> Dict:
201 - self.step_count += 1
202 - losses = {
203 - k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names
204 - }
205 - loss = losses["loss"]
206 - rouges = {
207 - k: np.array([x[k] for x in outputs]).mean()
208 - for k in self.metric_names + ["gen_time", "gen_len"]
209 - }
210 - rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(
211 - loss
212 - )
213 - rouges.update({k: v.item() for k, v in losses.items()})
214 - losses.update(rouges)
215 - metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
216 - metrics["step_count"] = self.step_count
217 - self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
218 - preds = flatten_list([x["preds"] for x in outputs])
219 - return {
220 - "log": metrics,
221 - "preds": preds,
222 - f"{prefix}_loss": loss,
223 - f"{prefix}_{self.val_metric}": rouge_tensor,
224 - }
225 -
226 - def save_metrics(self, latest_metrics, type_path) -> None:
227 - self.metrics[type_path].append(latest_metrics)
228 - save_json(self.metrics, self.metrics_save_path)
229 -
230 - def calc_generative_metrics(self, preds, target) -> Dict:
231 - return calculate_rouge(preds, target)
232 -
233 - def _generative_step(self, batch: dict) -> dict:
234 - t0 = time.time()
235 - generated_ids = self.model.generate(
236 - batch[0].long(),
237 - patch_ids=batch[2].long(),
238 - attention_mask=batch[1].long(),
239 - use_cache=True,
240 - decoder_start_token_id=self.decoder_start_token_id,
241 - )
242 - gen_time = (time.time() - t0) / batch[0].shape[0]
243 - preds: List[str] = self.ids_to_clean_text(generated_ids)
244 - target: List[str] = self.ids_to_clean_text(batch[3])
245 - loss_tensors = self._step(batch)
246 - base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
247 - rouge: Dict = self.calc_generative_metrics(preds, target)
248 - summ_len = np.mean(lmap(len, generated_ids))
249 - base_metrics.update(
250 - gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge
251 - )
252 - return base_metrics
253 -
254 - def test_step(self, batch, batch_idx):
255 - return self._generative_step(batch)
256 -
257 - def test_epoch_end(self, outputs):
258 - return self.validation_epoch_end(outputs, prefix="test")
259 -
260 - def get_dataset(self, type_path) -> Seq2SeqDataset:
261 - max_target_length = self.target_lens[type_path]
262 - data_config = DataConfig(
263 - endpoint=self.hparams.endpoint,
264 - access_key=os.environ["access_key"],
265 - secret_key=os.environ["secret_key"],
266 - region=self.hparams.region,
267 - dataset_name="commit-autosuggestions",
268 - additional={
269 - "mode": ("training" if type_path == "train" else "evaluation"),
270 - "max_source_length": self.hparams.max_source_length,
271 - "max_target_length": max_target_length,
272 - "url": self.hparams.url,
273 - },
274 - attributes=[
275 - ("input_ids", "int32", (self.hparams.max_source_length,)),
276 - ("attention_masks", "int32", (self.hparams.max_source_length,)),
277 - ("patch_ids", "int32", (self.hparams.max_source_length,)),
278 - ("targets", "int32", (max_target_length,)),
279 - ],
280 - )
281 - return Dataset(config=data_config, clear=True)
282 -
283 - def get_dataloader(
284 - self, type_path: str, batch_size: int, shuffle: bool = False
285 - ) -> DataLoader:
286 - dataset = self.get_dataset(type_path)
287 - sampler = None
288 -
289 - dataloader = DataLoader(
290 - dataset,
291 - batch_size=batch_size,
292 - shuffle=shuffle,
293 - num_workers=self.num_workers,
294 - sampler=sampler,
295 - )
296 - return dataloader
297 -
298 - def train_dataloader(self) -> DataLoader:
299 - dataloader = self.get_dataloader(
300 - "train", batch_size=self.hparams.train_batch_size, shuffle=True
301 - )
302 - return dataloader
303 -
304 - def val_dataloader(self) -> DataLoader:
305 - return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
306 -
307 - def test_dataloader(self) -> DataLoader:
308 - return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
309 -
310 - @staticmethod
311 - def add_model_specific_args(parser, root_dir):
312 - BaseTransformer.add_model_specific_args(parser, root_dir)
313 - add_generic_args(parser, root_dir)
314 - parser.add_argument("--url", type=str, required=True, help="github url")
315 - parser.add_argument(
316 - "--endpoint",
317 - type=str,
318 - required=True,
319 -            help="matorage endpoint; see the matorage storage documentation: https://matorage.readthedocs.io/en/stable/storage.html",
320 - )
321 - parser.add_argument(
322 - "--region",
323 - type=str,
324 - default=None,
325 -            help="matorage S3 region; see the matorage storage documentation: https://matorage.readthedocs.io/en/stable/storage.html",
326 - )
327 - parser.add_argument(
328 - "--max_source_length",
329 - default=1024,
330 - type=int,
331 - help="The maximum total input sequence length after tokenization. Sequences longer "
332 - "than this will be truncated, sequences shorter will be padded.",
333 - )
334 - parser.add_argument(
335 - "--max_target_length",
336 - default=56,
337 - type=int,
338 -            help="The maximum total target sequence length after tokenization. Sequences longer "
339 -            "than this will be truncated, sequences shorter will be padded.",
340 - )
341 - parser.add_argument(
342 - "--val_max_target_length",
343 - default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
344 - type=int,
345 -            help="The maximum total target sequence length after tokenization for validation. Sequences longer "
346 -            "than this will be truncated, sequences shorter will be padded.",
347 - )
348 - parser.add_argument(
349 - "--test_max_target_length",
350 - default=142,
351 - type=int,
352 -            help="The maximum total target sequence length after tokenization for testing. Sequences longer "
353 -            "than this will be truncated, sequences shorter will be padded.",
354 - )
355 - parser.add_argument("--freeze_encoder", action="store_true")
356 - parser.add_argument("--freeze_embeds", action="store_true")
357 - parser.add_argument("--sortish_sampler", action="store_true", default=False)
358 - parser.add_argument(
359 - "--logger_name",
360 - type=str,
361 - choices=["default", "wandb", "wandb_shared"],
362 - default="default",
363 - )
364 - parser.add_argument(
365 - "--n_train",
366 - type=int,
367 - default=-1,
368 - required=False,
369 - help="# examples. -1 means use all.",
370 - )
371 - parser.add_argument(
372 - "--n_val",
373 - type=int,
374 - default=500,
375 - required=False,
376 - help="# examples. -1 means use all.",
377 - )
378 - parser.add_argument(
379 - "--n_test",
380 - type=int,
381 - default=-1,
382 - required=False,
383 - help="# examples. -1 means use all.",
384 - )
385 - parser.add_argument(
386 - "--task",
387 - type=str,
388 - default="summarization",
389 - required=False,
390 -            help="Task to fine-tune the model for: summarization or translation.",
391 - )
392 - parser.add_argument(
393 - "--label_smoothing", type=float, default=0.0, required=False
394 - )
395 - parser.add_argument("--src_lang", type=str, default="", required=False)
396 - parser.add_argument("--tgt_lang", type=str, default="", required=False)
397 - parser.add_argument("--eval_beams", type=int, default=None, required=False)
398 - parser.add_argument("--val_metric", type=str, default=None, required=False)
399 - parser.add_argument(
400 - "--early_stopping_patience",
401 - type=int,
402 - default=-1,
403 - required=False,
404 -            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
405 - )
406 - return parser
407 -
408 -
409 -class TranslationModule(SummarizationModule):
410 - mode = "translation"
411 - loss_names = ["loss"]
412 - metric_names = ["bleu"]
413 - default_val_metric = "bleu"
414 -
415 - def __init__(self, hparams, **kwargs):
416 - super().__init__(hparams, **kwargs)
417 - self.dataset_kwargs["src_lang"] = hparams.src_lang
418 - self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
419 -
420 - def calc_generative_metrics(self, preds, target) -> dict:
421 - return calculate_bleu(preds, target)
422 -
423 -
424 -def main(args, model=None) -> SummarizationModule:
425 - Path(args.output_dir).mkdir(exist_ok=True)
426 - if len(os.listdir(args.output_dir)) > 3 and args.do_train:
427 - raise ValueError(
428 - "Output directory ({}) already exists and is not empty.".format(
429 - args.output_dir
430 - )
431 - )
432 - if model is None:
433 - if args.task == "summarization":
434 - model: SummarizationModule = SummarizationModule(args)
435 - else:
436 - model: SummarizationModule = TranslationModule(args)
437 -
438 - logger = True
439 - es_callback = False
440 - trainer: pl.Trainer = generic_train(
441 - model,
442 - args,
443 - logging_callback=Seq2SeqLoggingCallback(),
444 - checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
445 - early_stopping_callback=es_callback,
446 - logger=logger,
447 - # TODO: early stopping callback seems messed up
448 - )
449 - pickle_save(model.hparams, model.output_dir / "hparams.pkl")
450 - if not args.do_predict:
451 - return model
452 -
453 - model.hparams.test_checkpoint = ""
454 - checkpoints = list(
455 - sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
456 - )
457 - if checkpoints:
458 - model.hparams.test_checkpoint = checkpoints[-1]
459 - trainer.resume_from_checkpoint = checkpoints[-1]
460 - trainer.logger.log_hyperparams(model.hparams)
461 -
462 - # test() without a model tests using the best checkpoint automatically
463 - trainer.test()
464 - return model
\ No newline at end of file
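The removed module ends here; for orientation, a hypothetical sketch of how such a module is usually driven (the entry-point wiring below is an assumption, not part of the removed file):

# Hypothetical entry point (assumption): assemble the parser defined by
# add_model_specific_args above and hand the parsed args to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    main(args)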
1 -# coding=utf-8
2 -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3 -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 -#
5 -# Licensed under the Apache License, Version 2.0 (the "License");
6 -# you may not use this file except in compliance with the License.
7 -# You may obtain a copy of the License at
8 -#
9 -# http://www.apache.org/licenses/LICENSE-2.0
10 -#
11 -# Unless required by applicable law or agreed to in writing, software
12 -# distributed under the License is distributed on an "AS IS" BASIS,
13 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 -# See the License for the specific language governing permissions and
15 -# limitations under the License.
16 -
17 -from typing import Iterable, List, Optional, Tuple
18 -
19 -import torch
20 -from torch import Tensor
21 -from torch.nn import functional as F
22 -
23 -from transformers.file_utils import ModelOutput
24 -import logging
25 -
26 -logger = logging.getLogger(__name__) # pylint: disable=invalid-name
27 -logging.basicConfig(
28 - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
29 - datefmt="%m/%d/%Y %H:%M:%S",
30 - level=logging.INFO,
31 -)
32 -
33 -
34 -class GenerationMixin:
35 - """
36 -    A class containing all of the functions supporting generation, to be used as a mixin in
37 -    :class:`~transformers.PreTrainedModel`.
38 - """
39 -
40 - def prepare_inputs_for_generation(self, input_ids, **kwargs):
41 - """
42 -        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
43 - generate method.
44 - """
45 - return {"input_ids": input_ids}
46 -
47 - def adjust_logits_during_generation(self, logits, **kwargs):
48 - """
49 -        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
50 - the generate method.
51 - """
52 - return logits
53 -
54 - def enforce_repetition_penalty_(
55 - self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty
56 - ):
57 - """
58 - Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
59 - """
60 - for i in range(batch_size * num_beams):
61 - for previous_token in set(prev_output_tokens[i].tolist()):
62 -                # if score < 0, the repetition penalty has to be multiplied to reduce the previous token probability
63 - if lprobs[i, previous_token] < 0:
64 - lprobs[i, previous_token] *= repetition_penalty
65 - else:
66 - lprobs[i, previous_token] /= repetition_penalty
67 -
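A quick numeric illustration of the asymmetric update above (hypothetical logit values): dividing positive scores and multiplying negative ones both push previously generated tokens toward lower probability.

# Hypothetical values: with repetition_penalty = 1.2,
# a negative logit -2.0 becomes -2.4 and a positive logit 3.0 becomes 2.5,
# i.e. already-generated tokens always become less likely.
penalty = 1.2
neg_logit, pos_logit = -2.0, 3.0
print(neg_logit * penalty)  # ~ -2.4 (multiplied: pushed further down)
print(pos_logit / penalty)  # ~  2.5 (divided: pulled down)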
68 - def postprocess_next_token_scores(
69 - self,
70 - scores,
71 - input_ids,
72 - no_repeat_ngram_size,
73 - bad_words_ids,
74 - cur_len,
75 - min_length,
76 - max_length,
77 - eos_token_id,
78 - repetition_penalty,
79 - batch_size,
80 - num_beams,
81 - ):
82 - # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
83 - if repetition_penalty != 1.0:
84 - self.enforce_repetition_penalty_(
85 - scores, batch_size, num_beams, input_ids, repetition_penalty,
86 - )
87 -
88 - # set eos token prob to zero if min_length is not reached
89 - if eos_token_id is not None and cur_len < min_length:
90 - scores[:, eos_token_id] = -float("inf")
91 -
92 - if no_repeat_ngram_size > 0:
93 - # calculate a list of banned tokens to prevent repetitively generating the same ngrams
94 - num_batch_hypotheses = batch_size * num_beams
95 - # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
96 - banned_batch_tokens = calc_banned_ngram_tokens(
97 - input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
98 - )
99 - for i, banned_tokens in enumerate(banned_batch_tokens):
100 - scores[i, banned_tokens] = -float("inf")
101 -
102 - if bad_words_ids is not None:
103 - # Exclude EOS token (already processed)
104 - bad_words_ids = list(
105 - filter(
106 - lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids
107 - )
108 - )
109 - # calculate a list of banned tokens according to bad words
110 - banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
111 - # Modify the scores in place by setting the banned tokens logits to `-inf`
112 - set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
113 -
114 - return scores
115 -
116 - @torch.no_grad()
117 - def generate(
118 - self,
119 - input_ids: Optional[torch.LongTensor] = None,
120 - patch_ids: Optional[torch.LongTensor] = None,
121 - max_length: Optional[int] = None,
122 - min_length: Optional[int] = None,
123 - do_sample: Optional[bool] = None,
124 - early_stopping: Optional[bool] = None,
125 - num_beams: Optional[int] = None,
126 - temperature: Optional[float] = None,
127 - top_k: Optional[int] = None,
128 - top_p: Optional[float] = None,
129 - repetition_penalty: Optional[float] = None,
130 - bad_words_ids: Optional[Iterable[int]] = None,
131 - bos_token_id: Optional[int] = None,
132 - pad_token_id: Optional[int] = None,
133 - eos_token_id: Optional[int] = None,
134 - length_penalty: Optional[float] = None,
135 - no_repeat_ngram_size: Optional[int] = None,
136 - num_return_sequences: Optional[int] = None,
137 - attention_mask: Optional[torch.LongTensor] = None,
138 - decoder_start_token_id: Optional[int] = None,
139 - use_cache: Optional[bool] = None,
140 - **model_kwargs,
141 - ) -> torch.LongTensor:
142 - r"""
143 - Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
144 - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
145 -
146 - Adapted in part from `Facebook's XLM beam search code
147 - <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
148 -
149 - Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
150 - attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
151 - indicated are the default values of those config.
152 -
153 - Most of these parameters are explained in more detail in `this blog post
154 - <https://huggingface.co/blog/how-to-generate>`__.
155 -
156 - Parameters:
157 -
158 - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
159 - The sequence used as a prompt for the generation. If :obj:`None` the method initializes
160 - it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
161 - max_length (:obj:`int`, `optional`, defaults to 20):
162 - The maximum length of the sequence to be generated.
163 - min_length (:obj:`int`, `optional`, defaults to 10):
164 - The minimum length of the sequence to be generated.
165 - do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
166 - Whether or not to use sampling ; use greedy decoding otherwise.
167 - early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
168 - Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
169 - num_beams (:obj:`int`, `optional`, defaults to 1):
170 - Number of beams for beam search. 1 means no beam search.
171 -            temperature (:obj:`float`, `optional`, defaults to 1.0):
172 - The value used to module the next token probabilities.
173 - top_k (:obj:`int`, `optional`, defaults to 50):
174 - The number of highest probability vocabulary tokens to keep for top-k-filtering.
175 - top_p (:obj:`float`, `optional`, defaults to 1.0):
176 - If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
177 - higher are kept for generation.
178 - repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
179 - The parameter for repetition penalty. 1.0 means no penalty. See `this paper
180 - <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
181 - pad_token_id (:obj:`int`, `optional`):
182 - The id of the `padding` token.
183 - bos_token_id (:obj:`int`, `optional`):
184 - The id of the `beginning-of-sequence` token.
185 - eos_token_id (:obj:`int`, `optional`):
186 - The id of the `end-of-sequence` token.
187 - length_penalty (:obj:`float`, `optional`, defaults to 1.0):
188 - Exponential penalty to the length. 1.0 means no penalty.
189 -
190 - Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
191 - order to encourage the model to produce longer sequences.
192 - no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
193 - If set to int > 0, all ngrams of that size can only occur once.
194 -            bad_words_ids(:obj:`List[List[int]]`, `optional`):
195 - List of token ids that are not allowed to be generated. In order to get the tokens of the words that
196 - should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
197 - num_return_sequences(:obj:`int`, `optional`, defaults to 1):
198 - The number of independently computed returned sequences for each element in the batch.
199 - attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
200 - Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
201 - tokens that are not masked, and 0 for masked tokens.
202 -
203 - If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.
204 -
205 - `What are attention masks? <../glossary.html#attention-mask>`__
206 - decoder_start_token_id (:obj:`int`, `optional`):
207 - If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
208 - use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
209 - Whether or not the model should use the past last key/values attentions (if applicable to the model) to
210 - speed up decoding.
211 - model_kwargs:
212 - Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
213 -
214 - Return:
215 -
216 - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
217 - The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
218 - shorter if all batches finished early due to the :obj:`eos_token_id`.
219 -
220 - Examples::
221 -
222 - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
223 - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
224 - outputs = model.generate(max_length=40) # do greedy decoding
225 - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
226 -
227 - tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
228 - model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
229 - input_context = 'The dog'
230 - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
231 - outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
232 - for i in range(3): # 3 output sequences were generated
233 - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
234 -
235 - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
236 - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
237 - input_context = 'The dog'
238 - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
239 - outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
240 - for i in range(3): # 3 output sequences were generated
241 - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
242 -
243 - tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
244 - model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
245 - input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
246 - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
247 - outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
248 - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
249 -
250 - tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
251 - model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
252 -        input_context = 'My cute dog'
253 - bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
254 - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
255 - outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
256 - """
257 -
258 - # We cannot generate if the model does not have a LM head
259 - if self.get_output_embeddings() is None:
260 - raise AttributeError(
261 - "You tried to generate sequences with a model that does not have a LM Head."
262 - "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
263 - )
264 -
265 - max_length = max_length if max_length is not None else self.config.max_length
266 - min_length = min_length if min_length is not None else self.config.min_length
267 - do_sample = do_sample if do_sample is not None else self.config.do_sample
268 - early_stopping = (
269 - early_stopping if early_stopping is not None else self.config.early_stopping
270 - )
271 - use_cache = use_cache if use_cache is not None else self.config.use_cache
272 - num_beams = num_beams if num_beams is not None else self.config.num_beams
273 - temperature = (
274 - temperature if temperature is not None else self.config.temperature
275 - )
276 - top_k = top_k if top_k is not None else self.config.top_k
277 - top_p = top_p if top_p is not None else self.config.top_p
278 - repetition_penalty = (
279 - repetition_penalty
280 - if repetition_penalty is not None
281 - else self.config.repetition_penalty
282 - )
283 - bos_token_id = (
284 - bos_token_id if bos_token_id is not None else self.config.bos_token_id
285 - )
286 - pad_token_id = (
287 - pad_token_id if pad_token_id is not None else self.config.pad_token_id
288 - )
289 - eos_token_id = (
290 - eos_token_id if eos_token_id is not None else self.config.eos_token_id
291 - )
292 - length_penalty = (
293 - length_penalty if length_penalty is not None else self.config.length_penalty
294 - )
295 - no_repeat_ngram_size = (
296 - no_repeat_ngram_size
297 - if no_repeat_ngram_size is not None
298 - else self.config.no_repeat_ngram_size
299 - )
300 - bad_words_ids = (
301 - bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
302 - )
303 - num_return_sequences = (
304 - num_return_sequences
305 - if num_return_sequences is not None
306 - else self.config.num_return_sequences
307 - )
308 - decoder_start_token_id = (
309 - decoder_start_token_id
310 - if decoder_start_token_id is not None
311 - else self.config.decoder_start_token_id
312 - )
313 -
314 - if input_ids is not None:
315 -            batch_size = input_ids.shape[0]  # overridden by the input batch_size
316 - else:
317 - batch_size = 1
318 -
319 - assert (
320 - isinstance(max_length, int) and max_length > 0
321 - ), "`max_length` should be a strictly positive integer."
322 - assert (
323 - isinstance(min_length, int) and min_length >= 0
324 - ), "`min_length` should be a positive integer."
325 - assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
326 - assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
327 - assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
328 - assert (
329 - isinstance(num_beams, int) and num_beams > 0
330 - ), "`num_beams` should be a strictly positive integer."
331 - assert temperature > 0, "`temperature` should be strictly positive."
332 - assert (
333 - isinstance(top_k, int) and top_k >= 0
334 - ), "`top_k` should be a positive integer."
335 - assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
336 - assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
337 - assert input_ids is not None or (
338 - isinstance(bos_token_id, int) and bos_token_id >= 0
339 - ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
340 - assert pad_token_id is None or (
341 - isinstance(pad_token_id, int) and (pad_token_id >= 0)
342 - ), "`pad_token_id` should be a positive integer."
343 - assert (eos_token_id is None) or (
344 - isinstance(eos_token_id, int) and (eos_token_id >= 0)
345 - ), "`eos_token_id` should be a positive integer."
346 - assert length_penalty > 0, "`length_penalty` should be strictly positive."
347 - assert (
348 - isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
349 - ), "`no_repeat_ngram_size` should be a positive integer."
350 - assert (
351 - isinstance(num_return_sequences, int) and num_return_sequences > 0
352 - ), "`num_return_sequences` should be a strictly positive integer."
353 - assert (
354 - bad_words_ids is None
355 - or isinstance(bad_words_ids, list)
356 - and isinstance(bad_words_ids[0], list)
357 - ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
358 -
359 - if input_ids is None:
360 - assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
361 - "you should either supply a context to complete as `input_ids` input "
362 - "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
363 - )
364 - input_ids = torch.full(
365 - (batch_size, 1),
366 - bos_token_id,
367 - dtype=torch.long,
368 - device=next(self.parameters()).device,
369 - )
370 - else:
371 - assert (
372 - input_ids.dim() == 2
373 - ), "Input prompt should be of shape (batch_size, sequence length)."
374 -
375 -        # do not allow duplicate outputs when doing greedy decoding
376 - if do_sample is False:
377 - if num_beams == 1:
378 - # no_beam_search greedy generation conditions
379 - assert (
380 - num_return_sequences == 1
381 - ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
382 -
383 - else:
384 - # beam_search greedy generation conditions
385 - assert (
386 - num_beams >= num_return_sequences
387 - ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
388 -
389 - # create attention mask if necessary
390 - # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
391 - if (
392 - (attention_mask is None)
393 - and (pad_token_id is not None)
394 - and (pad_token_id in input_ids)
395 - ):
396 - attention_mask = input_ids.ne(pad_token_id).long()
397 - elif attention_mask is None:
398 - attention_mask = input_ids.new_ones(input_ids.shape)
399 -
400 - # set pad_token_id to eos_token_id if not set. Important that this is done after
401 - # attention_mask is created
402 - if pad_token_id is None and eos_token_id is not None:
403 - logger.warning(
404 - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(
405 - eos_token_id
406 - )
407 - )
408 - pad_token_id = eos_token_id
409 -
410 - # current position and vocab size
411 - if hasattr(self.config, "vocab_size"):
412 - vocab_size = self.config.vocab_size
413 - elif (
414 - self.config.is_encoder_decoder
415 - and hasattr(self.config, "decoder")
416 - and hasattr(self.config.decoder, "vocab_size")
417 - ):
418 - vocab_size = self.config.decoder.vocab_size
419 -
420 - # set effective batch size and effective batch multiplier according to do_sample
421 - if do_sample:
422 - effective_batch_size = batch_size * num_return_sequences
423 - effective_batch_mult = num_return_sequences
424 - else:
425 - effective_batch_size = batch_size
426 - effective_batch_mult = 1
427 -
428 - if self.config.is_encoder_decoder:
429 - if decoder_start_token_id is None:
430 - # see if BOS token can be used for decoder_start_token_id
431 - if bos_token_id is not None:
432 - decoder_start_token_id = bos_token_id
433 - elif hasattr(self.config, "decoder") and hasattr(
434 - self.config.decoder, "bos_token_id"
435 - ):
436 - decoder_start_token_id = self.config.decoder.bos_token_id
437 - else:
438 - raise ValueError(
439 - "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
440 - )
441 -
442 - assert hasattr(
443 - self, "get_encoder"
444 - ), "{} should have a 'get_encoder' function defined".format(self)
445 - assert callable(self.get_encoder), "{} should be a method".format(
446 - self.get_encoder
447 - )
448 -
449 - # get encoder and store encoder outputs
450 - encoder = self.get_encoder()
451 - encoder_outputs: ModelOutput = encoder(
452 - input_ids, patch_ids, attention_mask=attention_mask, return_dict=True
453 - )
454 -
455 - # Expand input ids if num_beams > 1 or num_return_sequences > 1
456 - if num_return_sequences > 1 or num_beams > 1:
457 - input_ids_len = input_ids.shape[-1]
458 - input_ids = input_ids.unsqueeze(1).expand(
459 - batch_size, effective_batch_mult * num_beams, input_ids_len
460 - )
461 - patch_ids = patch_ids.unsqueeze(1).expand(
462 - batch_size, effective_batch_mult * num_beams, input_ids_len
463 - )
464 - attention_mask = attention_mask.unsqueeze(1).expand(
465 - batch_size, effective_batch_mult * num_beams, input_ids_len
466 - )
467 -
468 - input_ids = input_ids.contiguous().view(
469 - effective_batch_size * num_beams, input_ids_len
470 - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
471 - patch_ids = patch_ids.contiguous().view(
472 - effective_batch_size * num_beams, input_ids_len
473 - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
474 - attention_mask = attention_mask.contiguous().view(
475 - effective_batch_size * num_beams, input_ids_len
476 - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
477 -
478 - if self.config.is_encoder_decoder:
479 - # create empty decoder_input_ids
480 - input_ids = torch.full(
481 - (effective_batch_size * num_beams, 1),
482 - decoder_start_token_id,
483 - dtype=torch.long,
484 - device=next(self.parameters()).device,
485 - )
486 - cur_len = 1
487 -
488 - assert (
489 - batch_size == encoder_outputs.last_hidden_state.shape[0]
490 - ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
491 -
492 - # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
493 - expanded_batch_idxs = (
494 - torch.arange(batch_size)
495 - .view(-1, 1)
496 - .repeat(1, num_beams * effective_batch_mult)
497 - .view(-1)
498 - .to(input_ids.device)
499 - )
500 -
501 - # expand encoder_outputs
502 - encoder_outputs[
503 - "last_hidden_state"
504 - ] = encoder_outputs.last_hidden_state.index_select(0, expanded_batch_idxs)
505 -
506 - # save encoder_outputs in `model_kwargs`
507 - model_kwargs["encoder_outputs"] = encoder_outputs
508 -
509 - else:
510 - cur_len = input_ids.shape[-1]
511 -
512 - assert (
513 - cur_len < max_length
514 - ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
515 -
516 - if num_beams > 1:
517 - output = self._generate_beam_search(
518 - input_ids,
519 - cur_len=cur_len,
520 - max_length=max_length,
521 - min_length=min_length,
522 - do_sample=do_sample,
523 - early_stopping=early_stopping,
524 - temperature=temperature,
525 - top_k=top_k,
526 - top_p=top_p,
527 - repetition_penalty=repetition_penalty,
528 - no_repeat_ngram_size=no_repeat_ngram_size,
529 - bad_words_ids=bad_words_ids,
530 - pad_token_id=pad_token_id,
531 - eos_token_id=eos_token_id,
532 - batch_size=effective_batch_size,
533 - num_return_sequences=num_return_sequences,
534 - length_penalty=length_penalty,
535 - num_beams=num_beams,
536 - vocab_size=vocab_size,
537 - attention_mask=attention_mask,
538 - use_cache=use_cache,
539 - model_kwargs=model_kwargs,
540 - )
541 - else:
542 - output = self._generate_no_beam_search(
543 - input_ids,
544 - cur_len=cur_len,
545 - max_length=max_length,
546 - min_length=min_length,
547 - do_sample=do_sample,
548 - temperature=temperature,
549 - top_k=top_k,
550 - top_p=top_p,
551 - repetition_penalty=repetition_penalty,
552 - no_repeat_ngram_size=no_repeat_ngram_size,
553 - bad_words_ids=bad_words_ids,
554 - pad_token_id=pad_token_id,
555 - eos_token_id=eos_token_id,
556 - batch_size=effective_batch_size,
557 - attention_mask=attention_mask,
558 - use_cache=use_cache,
559 - model_kwargs=model_kwargs,
560 - )
561 -
562 - return output
563 -
564 - def _generate_no_beam_search(
565 - self,
566 - input_ids,
567 - cur_len,
568 - max_length,
569 - min_length,
570 - do_sample,
571 - temperature,
572 - top_k,
573 - top_p,
574 - repetition_penalty,
575 - no_repeat_ngram_size,
576 - bad_words_ids,
577 - pad_token_id,
578 - eos_token_id,
579 - batch_size,
580 - attention_mask,
581 - use_cache,
582 - model_kwargs,
583 - ):
584 - """Generate sequences for each example without beam search (num_beams == 1).
585 -        All returned sequences are generated independently.
586 - """
587 - # length of generated sentences / unfinished sentences
588 - unfinished_sents = input_ids.new(batch_size).fill_(1)
589 - sent_lengths = input_ids.new(batch_size).fill_(max_length)
590 -
591 - past = None
592 - while cur_len < max_length:
593 - model_inputs = self.prepare_inputs_for_generation(
594 - input_ids,
595 - past=past,
596 - attention_mask=attention_mask,
597 - use_cache=use_cache,
598 - **model_kwargs,
599 - )
600 -
601 - outputs = self(**model_inputs, return_dict=True)
602 - next_token_logits = outputs.logits[:, -1, :]
603 -
604 - scores = self.postprocess_next_token_scores(
605 - scores=next_token_logits,
606 - input_ids=input_ids,
607 - no_repeat_ngram_size=no_repeat_ngram_size,
608 - bad_words_ids=bad_words_ids,
609 - cur_len=cur_len,
610 - min_length=min_length,
611 - max_length=max_length,
612 - eos_token_id=eos_token_id,
613 - repetition_penalty=repetition_penalty,
614 - batch_size=batch_size,
615 - num_beams=1,
616 - )
617 -
618 - # if model has past, then set the past variable to speed up decoding
619 - if "past_key_values" in outputs:
620 - past = outputs.past_key_values
621 - elif "mems" in outputs:
622 - past = outputs.mems
623 -
624 - if do_sample:
625 - # Temperature (higher temperature => more likely to sample low probability tokens)
626 - if temperature != 1.0:
627 - scores = scores / temperature
628 - # Top-p/top-k filtering
629 - next_token_logscores = top_k_top_p_filtering(
630 - scores, top_k=top_k, top_p=top_p
631 - )
632 - # Sample
633 - probs = F.softmax(next_token_logscores, dim=-1)
634 - next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
635 - else:
636 - # Greedy decoding
637 - next_token = torch.argmax(next_token_logits, dim=-1)
638 -
639 - # update generations and finished sentences
640 - if eos_token_id is not None:
641 - # pad finished sentences if eos_token_id exist
642 - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (
643 - 1 - unfinished_sents
644 - )
645 - else:
646 - tokens_to_add = next_token
647 -
648 - # add token and increase length by one
649 - input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
650 - cur_len = cur_len + 1
651 -
652 - if eos_token_id is not None:
653 - eos_in_sents = tokens_to_add == eos_token_id
654 - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
655 - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(
656 - eos_in_sents.long()
657 - ).bool()
658 - sent_lengths.masked_fill_(
659 - is_sents_unfinished_and_token_to_add_is_eos, cur_len
660 - )
661 - # unfinished_sents is set to zero if eos in sentence
662 - unfinished_sents.mul_((~eos_in_sents).long())
663 -
664 -            # stop when there is a </s> in each sentence, or if we exceed the maximum length
665 - if unfinished_sents.max() == 0:
666 - break
667 -
668 -            # extend attention_mask for the newly generated token (decoder-only models only)
669 - if self.config.is_encoder_decoder is False:
670 - attention_mask = torch.cat(
671 - [
672 - attention_mask,
673 - attention_mask.new_ones((attention_mask.shape[0], 1)),
674 - ],
675 - dim=-1,
676 - )
677 -
678 - return input_ids
679 -
680 - def _generate_beam_search(
681 - self,
682 - input_ids,
683 - cur_len,
684 - max_length,
685 - min_length,
686 - do_sample,
687 - early_stopping,
688 - temperature,
689 - top_k,
690 - top_p,
691 - repetition_penalty,
692 - no_repeat_ngram_size,
693 - bad_words_ids,
694 - pad_token_id,
695 - eos_token_id,
696 - batch_size,
697 - num_return_sequences,
698 - length_penalty,
699 - num_beams,
700 - vocab_size,
701 - attention_mask,
702 - use_cache,
703 - model_kwargs,
704 - ):
705 - """Generate sequences for each example with beam search."""
706 -
707 - # generated hypotheses
708 - generated_hyps = [
709 - BeamHypotheses(
710 - num_beams, max_length, length_penalty, early_stopping=early_stopping
711 - )
712 - for _ in range(batch_size)
713 - ]
714 -
715 - # scores for each sentence in the beam
716 - beam_scores = torch.zeros(
717 - (batch_size, num_beams), dtype=torch.float, device=input_ids.device
718 - )
719 -
720 -        # for greedy decoding, make sure only tokens of the first beam are considered, to avoid sampling the exact same tokens across beams
721 - if do_sample is False:
722 - beam_scores[:, 1:] = -1e9
723 - beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
724 -
725 - # cache compute states
726 - past = None
727 -
728 - # done sentences
729 - done = [False for _ in range(batch_size)]
730 -
731 - while cur_len < max_length:
732 - model_inputs = self.prepare_inputs_for_generation(
733 - input_ids,
734 - past=past,
735 - attention_mask=attention_mask,
736 - use_cache=use_cache,
737 - **model_kwargs,
738 - )
739 - outputs = self(
740 - **model_inputs, return_dict=True
741 - ) # (batch_size * num_beams, cur_len, vocab_size)
742 - next_token_logits = outputs.logits[
743 - :, -1, :
744 - ] # (batch_size * num_beams, vocab_size)
745 -
746 - # if model has past, then set the past variable to speed up decoding
747 - if "past_key_values" in outputs:
748 - past = outputs.past_key_values
749 - elif "mems" in outputs:
750 - past = outputs.mems
751 -
752 - if self.config.is_encoder_decoder and do_sample is False:
753 - # TODO (PVP) still a bit hacky here - there might be a better solution
754 - next_token_logits = self.adjust_logits_during_generation(
755 - next_token_logits, cur_len=cur_len, max_length=max_length
756 - )
757 -
758 - scores = F.log_softmax(
759 - next_token_logits, dim=-1
760 - ) # (batch_size * num_beams, vocab_size)
761 -
762 - scores = self.postprocess_next_token_scores(
763 - scores=scores,
764 - input_ids=input_ids,
765 - no_repeat_ngram_size=no_repeat_ngram_size,
766 - bad_words_ids=bad_words_ids,
767 - cur_len=cur_len,
768 - min_length=min_length,
769 - max_length=max_length,
770 - eos_token_id=eos_token_id,
771 - repetition_penalty=repetition_penalty,
772 - batch_size=batch_size,
773 - num_beams=num_beams,
774 - )
775 -
776 - assert scores.shape == (
777 - batch_size * num_beams,
778 - vocab_size,
779 - ), "Shapes of scores: {} != {}".format(
780 - scores.shape, (batch_size * num_beams, vocab_size)
781 - )
782 -
783 - if do_sample:
784 - _scores = scores + beam_scores[:, None].expand_as(
785 - scores
786 - ) # (batch_size * num_beams, vocab_size)
787 - # Temperature
788 - if temperature != 1.0:
789 - _scores = _scores / temperature
790 - # Top-p/top-k filtering
791 - _scores = top_k_top_p_filtering(
792 - _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
793 - ) # (batch_size * num_beams, vocab_size)
794 - # re-organize to group the beam together to sample from all beam_idxs
795 - _scores = _scores.contiguous().view(
796 - batch_size, num_beams * vocab_size
797 - ) # (batch_size, num_beams * vocab_size)
798 -
799 - # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
800 - probs = F.softmax(_scores, dim=-1)
801 - next_tokens = torch.multinomial(
802 - probs, num_samples=2 * num_beams
803 - ) # (batch_size, num_beams * 2)
804 - # Compute next scores
805 - next_scores = torch.gather(
806 - _scores, -1, next_tokens
807 - ) # (batch_size, num_beams * 2)
808 - # sort the sampled vector to make sure that the first num_beams samples are the best
809 - next_scores, next_scores_indices = torch.sort(
810 - next_scores, descending=True, dim=1
811 - )
812 - next_tokens = torch.gather(
813 - next_tokens, -1, next_scores_indices
814 - ) # (batch_size, num_beams * 2)
815 -
816 - else:
817 - next_scores = scores + beam_scores[:, None].expand_as(
818 - scores
819 - ) # (batch_size * num_beams, vocab_size)
820 -
821 -                # re-organize to group the beam together (we are keeping top hypothesis across beams)
822 - next_scores = next_scores.view(
823 - batch_size, num_beams * vocab_size
824 - ) # (batch_size, num_beams * vocab_size)
825 -
826 - next_scores, next_tokens = torch.topk(
827 - next_scores, 2 * num_beams, dim=1, largest=True, sorted=True
828 - )
829 -
830 - assert (
831 - next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
832 - )
833 -
834 - # next batch beam content
835 - next_batch_beam = []
836 -
837 - # for each sentence
838 - for batch_idx in range(batch_size):
839 -
840 - # if we are done with this sentence, add a pad token
841 - if done[batch_idx]:
842 - assert (
843 - len(generated_hyps[batch_idx]) >= num_beams
844 - ), "Batch can only be done if at least {} beams have been generated".format(
845 - num_beams
846 - )
847 - assert (
848 - eos_token_id is not None and pad_token_id is not None
849 - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
850 - next_batch_beam.extend(
851 - [(0, pad_token_id, 0)] * num_beams
852 - ) # pad the batch
853 - continue
854 -
855 - # next sentence beam content, this will get added to next_batch_beam
856 - next_sent_beam = []
857 -
858 - # next tokens for this sentence
859 - for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
860 - zip(next_tokens[batch_idx], next_scores[batch_idx])
861 - ):
862 - # get beam and token IDs
863 - beam_id = beam_token_id // vocab_size
864 - token_id = beam_token_id % vocab_size
865 -
866 - effective_beam_id = batch_idx * num_beams + beam_id
867 - # add to generated hypotheses if end of sentence
868 - if (eos_token_id is not None) and (token_id.item() == eos_token_id):
869 - # if beam_token does not belong to top num_beams tokens, it should not be added
870 - is_beam_token_worse_than_top_num_beams = (
871 - beam_token_rank >= num_beams
872 - )
873 - if is_beam_token_worse_than_top_num_beams:
874 - continue
875 - generated_hyps[batch_idx].add(
876 - input_ids[effective_beam_id].clone(),
877 - beam_token_score.item(),
878 - )
879 - else:
880 - # add next predicted token since it is not eos_token
881 - next_sent_beam.append(
882 - (beam_token_score, token_id, effective_beam_id)
883 - )
884 -
885 - # once the beam for next step is full, don't add more tokens to it.
886 - if len(next_sent_beam) == num_beams:
887 - break
888 -
889 - # Check if we are done so that we can save a pad step if all(done)
890 - done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
891 - next_scores[batch_idx].max().item(), cur_len
892 - )
893 -
894 - # update next beam content
895 - assert len(next_sent_beam) == num_beams, "Beam should always be full"
896 - next_batch_beam.extend(next_sent_beam)
897 - assert len(next_batch_beam) == num_beams * (
898 - batch_idx + 1
899 - ), "We should have added num_beams each step"
900 -
901 - # stop when we are done with each sentence
902 - if all(done):
903 - break
904 -
905 - # sanity check / prepare next batch
906 - assert len(next_batch_beam) == batch_size * num_beams
907 - beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
908 - beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
909 - beam_idx = input_ids.new([x[2] for x in next_batch_beam])
910 -
911 - # re-order batch and update current length
912 - input_ids = input_ids[beam_idx, :]
913 - input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
914 - cur_len = cur_len + 1
915 -
916 - # re-order internal states
917 - if past is not None:
918 - past = self._reorder_cache(past, beam_idx)
919 -
920 -            # extend attention_mask for the newly generated token (decoder-only models only)
921 - if self.config.is_encoder_decoder is False:
922 - attention_mask = torch.cat(
923 - [
924 - attention_mask,
925 - attention_mask.new_ones((attention_mask.shape[0], 1)),
926 - ],
927 - dim=-1,
928 - )
929 -
930 - # finalize all open beam hypotheses and add to generated hypotheses
931 - for batch_idx in range(batch_size):
932 - if done[batch_idx]:
933 - continue
934 -
935 - # test that beam scores match previously calculated scores if not eos and batch_idx not done
936 - if eos_token_id is not None and all(
937 - (token_id % vocab_size).item() != eos_token_id
938 - for token_id in next_tokens[batch_idx]
939 - ):
940 - assert torch.all(
941 - next_scores[batch_idx, :num_beams]
942 - == beam_scores.view(batch_size, num_beams)[batch_idx]
943 - ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
944 - next_scores[:, :num_beams][batch_idx],
945 - beam_scores.view(batch_size, num_beams)[batch_idx],
946 - )
947 -
948 - # need to add best num_beams hypotheses to generated hyps
949 - for beam_id in range(num_beams):
950 - effective_beam_id = batch_idx * num_beams + beam_id
951 - final_score = beam_scores[effective_beam_id].item()
952 - final_tokens = input_ids[effective_beam_id]
953 - generated_hyps[batch_idx].add(final_tokens, final_score)
954 -
955 - # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
956 - output_batch_size = (
957 - batch_size if do_sample else batch_size * num_return_sequences
958 - )
959 - output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
960 -
961 - # select the best hypotheses
962 - sent_lengths = input_ids.new(output_batch_size)
963 - best = []
964 -
965 - # retrieve best hypotheses
966 - for i, hypotheses in enumerate(generated_hyps):
967 - sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
968 - for j in range(output_num_return_sequences_per_batch):
969 - effective_batch_idx = output_num_return_sequences_per_batch * i + j
970 - best_hyp = sorted_hyps.pop()[1]
971 - sent_lengths[effective_batch_idx] = len(best_hyp)
972 - best.append(best_hyp)
973 -
974 - # shorter batches are padded
975 - if sent_lengths.min().item() != sent_lengths.max().item():
976 - assert pad_token_id is not None, "`Pad_token_id` has to be defined"
977 - sent_max_len = min(sent_lengths.max().item() + 1, max_length)
978 - decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
979 -
980 - # fill with hypothesis and eos_token_id if necessary
981 - for i, hypo in enumerate(best):
982 - decoded[i, : sent_lengths[i]] = hypo
983 - if sent_lengths[i] < max_length:
984 - decoded[i, sent_lengths[i]] = eos_token_id
985 - else:
986 - # none of the hypotheses have an eos_token
987 -            assert all(len(hypo) == max_length for hypo in best)
988 - decoded = (
989 - torch.stack(best).type(torch.long).to(next(self.parameters()).device)
990 - )
991 -
992 - return decoded
993 -
994 - @staticmethod
995 - def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
996 - return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
997 -
998 -
999 -def calc_banned_ngram_tokens(
1000 - prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int
1001 -) -> List[List[int]]:
1002 - """Copied from fairseq for no_repeat_ngram in beam_search"""
1003 - if cur_len + 1 < no_repeat_ngram_size:
1004 - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
1005 - return [[] for _ in range(num_hypos)]
1006 - generated_ngrams = [{} for _ in range(num_hypos)]
1007 - for idx in range(num_hypos):
1008 - gen_tokens = prev_input_ids[idx].tolist()
1009 - generated_ngram = generated_ngrams[idx]
1010 - for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
1011 - prev_ngram_tuple = tuple(ngram[:-1])
1012 - generated_ngram[prev_ngram_tuple] = generated_ngram.get(
1013 - prev_ngram_tuple, []
1014 - ) + [ngram[-1]]
1015 -
1016 - def _get_generated_ngrams(hypo_idx):
1017 - # Before decoding the next token, prevent decoding of ngrams that have already appeared
1018 - start_idx = cur_len + 1 - no_repeat_ngram_size
1019 - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
1020 - return generated_ngrams[hypo_idx].get(ngram_idx, [])
1021 -
1022 - banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
1023 - return banned_tokens
1024 -
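A small worked example of the helper above (hypothetical token ids), assuming `torch` from the imports at the top of this file:

# Hypothetical example: with no_repeat_ngram_size=2 and the hypothesis [5, 7, 5],
# the already-seen bigram (5, 7) bans token 7 from following the current prefix ending in 5.
prev = torch.tensor([[5, 7, 5]])
banned = calc_banned_ngram_tokens(prev, num_hypos=1, no_repeat_ngram_size=2, cur_len=3)
# banned == [[7]]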
1025 -
1026 -def calc_banned_bad_words_ids(
1027 -    prev_input_ids: List[List[int]], bad_words_ids: List[List[int]]
1028 -) -> List[List[int]]:
1029 - banned_tokens = []
1030 -
1031 - def _tokens_match(prev_tokens, tokens):
1032 - if len(tokens) == 0:
1033 -            # if the bad word sequence is a single token, its (empty) prefix always matches, so always ban it
1034 - return True
1035 - if len(tokens) > len(prev_tokens):
1036 - # if bad word tokens are longer than prev tokens they can't be equal
1037 - return False
1038 -
1039 - if prev_tokens[-len(tokens) :] == tokens:
1040 - # if tokens match
1041 - return True
1042 - else:
1043 - return False
1044 -
1045 - for prev_input_ids_slice in prev_input_ids:
1046 - banned_tokens_slice = []
1047 -
1048 - for banned_token_seq in bad_words_ids:
1049 - assert (
1050 - len(banned_token_seq) > 0
1051 - ), "Banned words token sequences {} cannot have an empty list".format(
1052 - bad_words_ids
1053 - )
1054 -
1055 - if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
1056 - # if tokens do not match continue
1057 - continue
1058 -
1059 - banned_tokens_slice.append(banned_token_seq[-1])
1060 -
1061 - banned_tokens.append(banned_tokens_slice)
1062 -
1063 - return banned_tokens
1064 -
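Similarly, a hypothetical call showing how a two-token bad-words sequence is banned only when its prefix matches the end of the hypothesis:

# Hypothetical ids: ban the sequence [11, 22].
banned = calc_banned_bad_words_ids([[3, 11]], [[11, 22]])
# banned == [[22]]  -> token 22 may not follow a hypothesis ending in 11
banned = calc_banned_bad_words_ids([[3, 4]], [[11, 22]])
# banned == [[]]    -> prefix 11 not present, nothing is banned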
1065 -
1066 -def set_scores_to_inf_for_banned_tokens(
1067 - scores: torch.Tensor, banned_tokens: List[List[int]]
1068 -) -> None:
1069 -    """Modifies the scores in place by setting the banned token positions to `-inf`. `banned_tokens` is expected
1070 -    to be a list of length (batch_size) whose entries are the lists of token ids to ban for each batch element.
1071 - Args:
1072 - scores: logits distribution of shape (batch size, vocabulary size)
1073 - banned_tokens: list of list of tokens to ban of length (batch_size)
1074 - """
1075 - banned_mask_list = []
1076 - for idx, batch_banned_tokens in enumerate(banned_tokens):
1077 - for token in batch_banned_tokens:
1078 - banned_mask_list.append([idx, token])
1079 - if not banned_mask_list:
1080 - return
1081 - banned_mask = torch.LongTensor(banned_mask_list)
1082 - indices = torch.ones(len(banned_mask))
1083 - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
1084 - # [ 0 1 1 ]
1085 - # [ 0 0 0 ]
1086 - # [ 1 0 0 ]
1087 -
1088 - banned_mask = (
1089 - torch.sparse.LongTensor(banned_mask.t(), indices, scores.size())
1090 - .to(scores.device)
1091 - .to_dense()
1092 - .bool()
1093 - )
1094 - scores.masked_fill_(banned_mask, -float("inf"))
1095 -
1096 -
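A toy call illustrating the in-place masking (hypothetical shapes):

# Hypothetical: ban token 2 for the single batch row of a (1, 4) score matrix.
scores = torch.zeros(1, 4)
set_scores_to_inf_for_banned_tokens(scores, [[2]])
# scores == [[0., 0., -inf, 0.]]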
1097 -def top_k_top_p_filtering(
1098 - logits: Tensor,
1099 - top_k: int = 0,
1100 - top_p: float = 1.0,
1101 - filter_value: float = -float("Inf"),
1102 - min_tokens_to_keep: int = 1,
1103 -) -> Tensor:
1104 - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
1105 - Args:
1106 - logits: logits distribution shape (batch size, vocabulary size)
1107 - if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
1108 - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
1109 - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
1110 - Make sure we keep at least min_tokens_to_keep per batch example in the output
1111 - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
1112 - """
1113 - if top_k > 0:
1114 - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
1115 - # Remove all tokens with a probability less than the last token of the top-k
1116 - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
1117 - logits[indices_to_remove] = filter_value
1118 -
1119 - if top_p < 1.0:
1120 - sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1121 - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
1122 -
1123 - # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
1124 - sorted_indices_to_remove = cumulative_probs > top_p
1125 - if min_tokens_to_keep > 1:
1126 - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
1127 - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
1128 - # Shift the indices to the right to keep also the first token above the threshold
1129 - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1130 - sorted_indices_to_remove[..., 0] = 0
1131 -
1132 - # scatter sorted tensors to original indexing
1133 - indices_to_remove = sorted_indices_to_remove.scatter(
1134 - 1, sorted_indices, sorted_indices_to_remove
1135 - )
1136 - logits[indices_to_remove] = filter_value
1137 - return logits
1138 -
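A toy example of the filtering above (hypothetical logits):

# Hypothetical logits: keeping only the top-2 entries masks the rest to -inf.
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
# filtered == [[-inf, 3.0, 2.0, -inf]]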
1139 -
1140 -class BeamHypotheses(object):
1141 - def __init__(self, num_beams, max_length, length_penalty, early_stopping):
1142 - """
1143 - Initialize n-best list of hypotheses.
1144 - """
1145 - self.max_length = max_length - 1 # ignoring bos_token
1146 - self.length_penalty = length_penalty
1147 - self.early_stopping = early_stopping
1148 - self.num_beams = num_beams
1149 - self.beams = []
1150 - self.worst_score = 1e9
1151 -
1152 - def __len__(self):
1153 - """
1154 - Number of hypotheses in the list.
1155 - """
1156 - return len(self.beams)
1157 -
1158 - def add(self, hyp, sum_logprobs):
1159 - """
1160 - Add a new hypothesis to the list.
1161 - """
1162 - score = sum_logprobs / len(hyp) ** self.length_penalty
1163 - if len(self) < self.num_beams or score > self.worst_score:
1164 - self.beams.append((score, hyp))
1165 - if len(self) > self.num_beams:
1166 - sorted_scores = sorted(
1167 - [(s, idx) for idx, (s, _) in enumerate(self.beams)]
1168 - )
1169 - del self.beams[sorted_scores[0][1]]
1170 - self.worst_score = sorted_scores[1][0]
1171 - else:
1172 - self.worst_score = min(score, self.worst_score)
1173 -
1174 - def is_done(self, best_sum_logprobs, cur_len):
1175 - """
1176 -        If there are enough hypotheses and none of the hypotheses being generated
1177 -        can become better than the worst one in the heap, then we are done with this sentence.
1178 - """
1179 -
1180 - if len(self) < self.num_beams:
1181 - return False
1182 - elif self.early_stopping:
1183 - return True
1184 - else:
1185 - cur_score = best_sum_logprobs / cur_len ** self.length_penalty
1186 - ret = self.worst_score >= cur_score
1187 - return ret
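Finally, a small illustration of the length-penalised beam score used by `add` (hypothetical numbers):

# Hypothetical: a 4-token hypothesis with summed log-prob -2.0 and
# length_penalty=1.0 scores -2.0 / 4**1.0 = -0.5; a larger length_penalty
# divides by a bigger number and therefore favours longer hypotheses.
hyps = BeamHypotheses(num_beams=2, max_length=20, length_penalty=1.0, early_stopping=False)
hyps.add(torch.tensor([0, 5, 7, 2]), sum_logprobs=-2.0)
# hyps.beams[0][0] == -0.5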
1 -import argparse
2 -import logging
3 -import os
4 -from pathlib import Path
5 -from typing import Any, Dict
6 -
7 -import pytorch_lightning as pl
8 -from pytorch_lightning.utilities import rank_zero_info
9 -
10 -from transformers import (
11 - AdamW,
12 - AutoConfig,
13 - AutoModel,
14 - AutoModelForPreTraining,
15 - AutoModelForQuestionAnswering,
16 - AutoModelForSeq2SeqLM,
17 - AutoModelForSequenceClassification,
18 - AutoModelForTokenClassification,
19 - AutoModelWithLMHead,
20 - AutoTokenizer,
21 - PretrainedConfig,
22 - PreTrainedTokenizer,
23 -)
24 -from train.modeling_bart import BartForConditionalGeneration
25 -
26 -from transformers.optimization import (
27 - Adafactor,
28 - get_cosine_schedule_with_warmup,
29 - get_cosine_with_hard_restarts_schedule_with_warmup,
30 - get_linear_schedule_with_warmup,
31 - get_polynomial_decay_schedule_with_warmup,
32 -)
33 -
34 -
35 -logger = logging.getLogger(__name__)
36 -
37 -
38 -MODEL_MODES = {
39 - "base": AutoModel,
40 - "sequence-classification": AutoModelForSequenceClassification,
41 - "question-answering": AutoModelForQuestionAnswering,
42 - "pretraining": AutoModelForPreTraining,
43 - "token-classification": AutoModelForTokenClassification,
44 - "language-modeling": AutoModelWithLMHead,
45 - "summarization": BartForConditionalGeneration,
46 - "translation": AutoModelForSeq2SeqLM,
47 -}
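
# Note (not part of the original file): the "summarization" entry deliberately
# points at the local train.modeling_bart.BartForConditionalGeneration (the
# patch-embedding variant imported above) rather than a stock transformers Auto
# class, so fine-tuning in summarization mode picks up the modified encoder.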
48 -
49 -
50 -# update this and the import above to support new schedulers from transformers.optimization
51 -arg_to_scheduler = {
52 - "linear": get_linear_schedule_with_warmup,
53 - "cosine": get_cosine_schedule_with_warmup,
54 - "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
55 - "polynomial": get_polynomial_decay_schedule_with_warmup,
56 - # '': get_constant_schedule, # not supported for now
57 - # '': get_constant_schedule_with_warmup, # not supported for now
58 -}
59 -arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
60 -arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
61 -
62 -
63 -class BaseTransformer(pl.LightningModule):
64 - def __init__(
65 - self,
66 - hparams: argparse.Namespace,
67 - num_labels=None,
68 - mode="base",
69 - config=None,
70 - tokenizer=None,
71 - model=None,
72 - **config_kwargs,
73 - ):
74 - """Initialize a model, tokenizer and config."""
75 - super().__init__()
76 - # TODO: move to self.save_hyperparameters()
77 - # self.save_hyperparameters()
78 - # can also expand arguments into trainer signature for easier reading
79 -
80 - self.save_hyperparameters(hparams)
81 - self.step_count = 0
82 - self.output_dir = Path(self.hparams.output_dir)
83 - cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
84 - if config is None:
85 - self.config = AutoConfig.from_pretrained(
86 - self.hparams.config_name
87 - if self.hparams.config_name
88 - else self.hparams.model_name_or_path,
89 - **({"num_labels": num_labels} if num_labels is not None else {}),
90 - cache_dir=cache_dir,
91 - **config_kwargs,
92 - )
93 - else:
94 - self.config: PretrainedConfig = config
95 -
96 - extra_model_params = (
97 - "encoder_layerdrop",
98 - "decoder_layerdrop",
99 - "dropout",
100 - "attention_dropout",
101 - )
102 - for p in extra_model_params:
103 - if getattr(self.hparams, p, None):
104 - assert hasattr(
105 - self.config, p
106 - ), f"model config doesn't have a `{p}` attribute"
107 - setattr(self.config, p, getattr(self.hparams, p))
108 -
109 - if tokenizer is None:
110 - self.tokenizer = AutoTokenizer.from_pretrained(
111 - self.hparams.tokenizer_name
112 - if self.hparams.tokenizer_name
113 - else self.hparams.model_name_or_path,
114 - cache_dir=cache_dir,
115 - )
116 - else:
117 - self.tokenizer: PreTrainedTokenizer = tokenizer
118 - self.model_type = MODEL_MODES[mode]
119 - if model is None:
120 - self.model = self.model_type.from_pretrained(
121 - self.hparams.model_name_or_path,
122 - from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
123 - config=self.config,
124 - cache_dir=cache_dir,
125 - )
126 - else:
127 - self.model = model
128 -        self.model.resize_token_embeddings(len(self.tokenizer))
129 -
130 - def load_hf_checkpoint(self, *args, **kwargs):
131 - self.model = self.model_type.from_pretrained(*args, **kwargs)
132 -
133 - def get_lr_scheduler(self):
134 - get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
135 - scheduler = get_schedule_func(
136 - self.opt,
137 - num_warmup_steps=self.hparams.warmup_steps,
138 - num_training_steps=self.total_steps,
139 - )
140 - scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
141 - return scheduler
142 -
143 - def configure_optimizers(self):
144 - """Prepare optimizer and schedule (linear warmup and decay)"""
145 - model = self.model
146 - no_decay = ["bias", "LayerNorm.weight"]
147 - optimizer_grouped_parameters = [
148 - {
149 - "params": [
150 - p
151 - for n, p in model.named_parameters()
152 - if not any(nd in n for nd in no_decay)
153 - ],
154 - "weight_decay": self.hparams.weight_decay,
155 - },
156 - {
157 - "params": [
158 - p
159 - for n, p in model.named_parameters()
160 - if any(nd in n for nd in no_decay)
161 - ],
162 - "weight_decay": 0.0,
163 - },
164 - ]
165 - if self.hparams.adafactor:
166 - optimizer = Adafactor(
167 - optimizer_grouped_parameters,
168 - lr=self.hparams.learning_rate,
169 - scale_parameter=False,
170 - relative_step=False,
171 - )
172 -
173 - else:
174 - optimizer = AdamW(
175 - optimizer_grouped_parameters,
176 - lr=self.hparams.learning_rate,
177 - eps=self.hparams.adam_epsilon,
178 - )
179 - self.opt = optimizer
180 -
181 - scheduler = self.get_lr_scheduler()
182 -
183 - return [optimizer], [scheduler]
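
# Illustration (not part of the original file): the two groups above apply the
# usual weight-decay split; biases and LayerNorm weights are excluded from decay
# while every other parameter uses hparams.weight_decay. For a hypothetical
# module with a single Linear layer the grouping would look like:
#   [{"params": [fc.weight], "weight_decay": 0.01},
#    {"params": [fc.bias],   "weight_decay": 0.0}]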
184 -
185 - def test_step(self, batch, batch_nb):
186 - return self.validation_step(batch, batch_nb)
187 -
188 - def test_epoch_end(self, outputs):
189 - return self.validation_end(outputs)
190 -
191 - @property
192 - def total_steps(self) -> int:
193 - """The number of total training steps that will be run. Used for lr scheduler purposes."""
194 - num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
195 - effective_batch_size = (
196 - self.hparams.train_batch_size
197 - * self.hparams.accumulate_grad_batches
198 - * num_devices
199 - )
200 - dataset_size = len(self.train_loader.dataset)
201 - return (dataset_size / effective_batch_size) * self.hparams.max_epochs
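
# Worked example (not part of the original file): with train_batch_size=32,
# accumulate_grad_batches=2 and gpus=2, the effective batch size is
# 32 * 2 * 2 = 128; a 64,000-example dataset then takes 500 optimizer steps per
# epoch, and max_epochs=3 gives total_steps = 1,500 for the warmup/decay schedule.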
202 -
203 - def setup(self, mode):
204 - if mode == "fit":
205 - self.train_loader = self.get_dataloader(
206 - "train", self.hparams.train_batch_size, shuffle=True
207 - )
208 -
209 - def get_dataloader(self, type_path, batch_size, shuffle=False):
210 - raise NotImplementedError("You must implement this for your task")
211 -
212 - def train_dataloader(self):
213 - return self.train_loader
214 -
215 - def val_dataloader(self):
216 - return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
217 -
218 - def test_dataloader(self):
219 - return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
220 -
221 - def _feature_file(self, mode):
222 - return os.path.join(
223 - self.hparams.data_dir,
224 - "cached_{}_{}_{}".format(
225 - mode,
226 - list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
227 - str(self.hparams.max_seq_length),
228 - ),
229 - )
230 -
231 - @pl.utilities.rank_zero_only
232 - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
233 - save_path = self.output_dir.joinpath("best_tfmr")
234 - self.model.config.save_step = self.step_count
235 - self.model.save_pretrained(save_path)
236 - self.tokenizer.save_pretrained(save_path)
237 -
238 - @staticmethod
239 - def add_model_specific_args(parser, root_dir):
240 - parser.add_argument(
241 - "--model_name_or_path",
242 - default=None,
243 - type=str,
244 - required=True,
245 - help="Path to pretrained model or model identifier from huggingface.co/models",
246 - )
247 - parser.add_argument(
248 - "--config_name",
249 - default="",
250 - type=str,
251 - help="Pretrained config name or path if not the same as model_name",
252 - )
253 - parser.add_argument(
254 - "--tokenizer_name",
255 - default=None,
256 - type=str,
257 - help="Pretrained tokenizer name or path if not the same as model_name",
258 - )
259 - parser.add_argument(
260 - "--cache_dir",
261 - default="",
262 - type=str,
263 - help="Where do you want to store the pre-trained models downloaded from s3",
264 - )
265 - parser.add_argument(
266 - "--encoder_layerdrop",
267 - type=float,
268 - help="Encoder layer dropout probability (Optional). Goes into model.config",
269 - )
270 - parser.add_argument(
271 - "--decoder_layerdrop",
272 - type=float,
273 - help="Decoder layer dropout probability (Optional). Goes into model.config",
274 - )
275 - parser.add_argument(
276 - "--dropout",
277 - type=float,
278 - help="Dropout probability (Optional). Goes into model.config",
279 - )
280 - parser.add_argument(
281 - "--attention_dropout",
282 - type=float,
283 - help="Attention dropout probability (Optional). Goes into model.config",
284 - )
285 - parser.add_argument(
286 - "--learning_rate",
287 - default=5e-5,
288 - type=float,
289 - help="The initial learning rate for Adam.",
290 - )
291 - parser.add_argument(
292 - "--lr_scheduler",
293 - default="linear",
294 - choices=arg_to_scheduler_choices,
295 - metavar=arg_to_scheduler_metavar,
296 - type=str,
297 - help="Learning rate scheduler",
298 - )
299 - parser.add_argument(
300 - "--weight_decay",
301 - default=0.0,
302 - type=float,
303 - help="Weight decay if we apply some.",
304 - )
305 - parser.add_argument(
306 - "--adam_epsilon",
307 - default=1e-8,
308 - type=float,
309 - help="Epsilon for Adam optimizer.",
310 - )
311 - parser.add_argument(
312 - "--warmup_steps",
313 - default=0,
314 - type=int,
315 - help="Linear warmup over warmup_steps.",
316 - )
317 - parser.add_argument(
318 - "--num_workers", default=4, type=int, help="kwarg passed to DataLoader"
319 - )
320 - parser.add_argument(
321 - "--num_train_epochs", dest="max_epochs", default=3, type=int
322 - )
323 - parser.add_argument("--train_batch_size", default=32, type=int)
324 - parser.add_argument("--eval_batch_size", default=32, type=int)
325 - parser.add_argument("--adafactor", action="store_true")
326 -
327 -
328 -class LoggingCallback(pl.Callback):
329 - def on_batch_end(self, trainer, pl_module):
330 - lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
331 - lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
332 - pl_module.logger.log_metrics(lrs)
333 -
334 - def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
335 - rank_zero_info("***** Validation results *****")
336 - metrics = trainer.callback_metrics
337 - # Log results
338 - for key in sorted(metrics):
339 - if key not in ["log", "progress_bar"]:
340 - rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
341 -
342 - def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
343 - rank_zero_info("***** Test results *****")
344 - metrics = trainer.callback_metrics
345 - # Log and save results to file
346 - output_test_results_file = os.path.join(
347 - pl_module.hparams.output_dir, "test_results.txt"
348 - )
349 - with open(output_test_results_file, "w") as writer:
350 - for key in sorted(metrics):
351 - if key not in ["log", "progress_bar"]:
352 - rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
353 - writer.write("{} = {}\n".format(key, str(metrics[key])))
354 -
355 -
356 -def add_generic_args(parser, root_dir) -> None:
357 - # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
358 - parser.add_argument(
359 - "--output_dir",
360 - default=None,
361 - type=str,
362 - required=True,
363 - help="The output directory where the model predictions and checkpoints will be written.",
364 - )
365 - parser.add_argument(
366 - "--fp16",
367 - action="store_true",
368 - help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
369 - )
370 -
371 - parser.add_argument(
372 - "--fp16_opt_level",
373 - type=str,
374 - default="O2",
375 -        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
376 -        "See details at https://nvidia.github.io/apex/amp.html",
377 - )
378 - parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
379 - parser.add_argument(
380 - "--max_grad_norm",
381 - dest="gradient_clip_val",
382 - default=1.0,
383 - type=float,
384 - help="Max gradient norm",
385 - )
386 - parser.add_argument(
387 - "--do_train", action="store_true", help="Whether to run training."
388 - )
389 - parser.add_argument(
390 - "--do_predict",
391 - action="store_true",
392 - help="Whether to run predictions on the test set.",
393 - )
394 - parser.add_argument(
395 - "--gradient_accumulation_steps",
396 - dest="accumulate_grad_batches",
397 - type=int,
398 - default=1,
399 - help="Number of updates steps to accumulate before performing a backward/update pass.",
400 - )
401 - parser.add_argument(
402 - "--seed", type=int, default=42, help="random seed for initialization"
403 - )
404 -
405 -
406 -def generic_train(
407 - model: BaseTransformer,
408 - args: argparse.Namespace,
409 - early_stopping_callback=False,
410 - logger=True, # can pass WandbLogger() here
411 - extra_callbacks=[],
412 - checkpoint_callback=None,
413 - logging_callback=None,
414 - **extra_train_kwargs,
415 -):
416 - pl.seed_everything(args.seed)
417 -
418 - # init model
419 - odir = Path(model.hparams.output_dir)
420 - odir.mkdir(exist_ok=True)
421 -
422 - # add custom checkpoints
423 - if checkpoint_callback is None:
424 - checkpoint_callback = pl.callbacks.ModelCheckpoint(
425 - filepath=args.output_dir,
426 - prefix="checkpoint",
427 - monitor="val_loss",
428 - mode="min",
429 - save_top_k=1,
430 - )
431 - if logging_callback is None:
432 - logging_callback = LoggingCallback()
433 -
434 - train_params = {}
435 -
436 - # TODO: remove with PyTorch 1.6 since pl uses native amp
437 - if args.fp16:
438 - train_params["precision"] = 16
439 - train_params["amp_level"] = args.fp16_opt_level
440 -
441 - if args.gpus > 1:
442 - train_params["distributed_backend"] = "ddp"
443 -
444 - trainer = pl.Trainer.from_argparse_args(
445 - args,
446 - weights_summary=None,
447 - callbacks=[logging_callback] + extra_callbacks,
448 - logger=logger,
449 - checkpoint_callback=checkpoint_callback,
450 - early_stop_callback=early_stopping_callback,
451 - **train_params,
452 - )
453 -
454 - if args.do_train:
455 - trainer.fit(model)
456 -
457 - return trainer
1 -# coding=utf-8
2 -# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
3 -#
4 -# Licensed under the Apache License, Version 2.0 (the "License");
5 -# you may not use this file except in compliance with the License.
6 -# You may obtain a copy of the License at
7 -#
8 -# http://www.apache.org/licenses/LICENSE-2.0
9 -#
10 -# Unless required by applicable law or agreed to in writing, software
11 -# distributed under the License is distributed on an "AS IS" BASIS,
12 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 -# See the License for the specific language governing permissions and
14 -# limitations under the License.
15 -"""PyTorch BART model, ported from the fairseq repo."""
16 -import math
17 -import random
18 -import warnings
19 -from typing import Dict, List, Optional, Tuple
20 -
21 -import numpy as np
22 -import torch
23 -import torch.nn.functional as F
24 -from torch import Tensor, nn
25 -from torch.nn import CrossEntropyLoss
26 -
27 -from transformers.activations import ACT2FN
28 -from transformers.configuration_bart import BartConfig
29 -from transformers.file_utils import (
30 - add_code_sample_docstrings,
31 - add_end_docstrings,
32 - add_start_docstrings,
33 - add_start_docstrings_to_callable,
34 - replace_return_docstrings,
35 -)
36 -from transformers.modeling_outputs import (
37 - BaseModelOutput,
38 - BaseModelOutputWithPast,
39 - Seq2SeqLMOutput,
40 - Seq2SeqModelOutput,
41 - Seq2SeqQuestionAnsweringModelOutput,
42 - Seq2SeqSequenceClassifierOutput,
43 -)
44 -from train.modeling_utils import PreTrainedModel
45 -import logging
46 -
47 -logger = logging.getLogger(__name__) # pylint: disable=invalid-name
48 -logging.basicConfig(
49 - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
50 - datefmt="%m/%d/%Y %H:%M:%S",
51 - level=logging.INFO,
52 -)
53 -
54 -_CONFIG_FOR_DOC = "BartConfig"
55 -_TOKENIZER_FOR_DOC = "BartTokenizer"
56 -
57 -
58 -BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
59 - "facebook/bart-base",
60 - "facebook/bart-large",
61 - "facebook/bart-large-mnli",
62 - "facebook/bart-large-cnn",
63 - "facebook/bart-large-xsum",
64 - "facebook/mbart-large-en-ro",
65 -]
66 -# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
67 -
68 -
69 -BART_START_DOCSTRING = r"""
70 -
71 - This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
72 - refer to the PyTorch documentation for all matters related to general usage and behavior.
73 -
74 - Parameters:
75 - config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
76 - Initializing with a config file does not load the weights associated with the model, only the configuration.
77 - Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
78 -
79 -"""
80 -BART_GENERATION_EXAMPLE = r"""
81 - Summarization example::
82 -
83 - from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
84 -
85 - # see ``examples/summarization/bart/run_eval.py`` for a longer example
86 - model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
87 - tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
88 -
89 - ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
90 - inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
91 -
92 - # Generate Summary
93 - summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
94 - print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
95 -
96 -"""
97 -
98 -BART_INPUTS_DOCSTRING = r"""
99 - Args:
100 - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
101 - Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
102 - Padding will be ignored by default should you provide it.
103 - Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
104 - attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
105 - Mask to avoid performing attention on padding token indices in input_ids.
106 - Mask values selected in ``[0, 1]``:
107 - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
108 - encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
109 - Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
110 - `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
111 - Used in the cross-attention of the decoder.
112 - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
113 - Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
114 - decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
115 - Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
116 - If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
117 - See diagram 1 in the paper for more info on the default strategy
118 - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
119 - Contains pre-computed key and value hidden-states of the attention blocks.
120 - Can be used to speed up decoding.
121 - If ``past_key_values`` are used, the user can optionally input only the last
122 - ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
123 - :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
124 - use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
125 - If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
126 - ``past_key_values``).
127 - output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
128 - If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
129 - output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
130 - If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
131 - return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
132 - If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
133 - plain tuple.
134 -"""
135 -
136 -
137 -def invert_mask(attention_mask):
138 -    """Turns 1->0, 0->1, False->True, True->False"""
139 - assert attention_mask.dim() == 2
140 - return attention_mask.eq(0)
141 -
142 -
143 -def _prepare_bart_decoder_inputs(
144 - config,
145 - input_ids,
146 - decoder_input_ids=None,
147 - decoder_padding_mask=None,
148 - causal_mask_dtype=torch.float32,
149 -):
150 - """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
151 - none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
152 - Note: this is not called during generation
153 - """
154 - pad_token_id = config.pad_token_id
155 - if decoder_input_ids is None:
156 - decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
157 - bsz, tgt_len = decoder_input_ids.size()
158 - if decoder_padding_mask is None:
159 - decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
160 - else:
161 - decoder_padding_mask = invert_mask(decoder_padding_mask)
162 - if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
163 - # never mask leading token, even if it is pad
164 - decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
165 - causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
166 - dtype=causal_mask_dtype, device=decoder_input_ids.device
167 - )
168 - return decoder_input_ids, decoder_padding_mask, causal_mask
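
# Illustration (not part of the original file): the causal mask built above is a
# strictly upper-triangular matrix of -inf, so decoder position t can only attend
# to positions <= t. For tgt_len = 3:
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]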
169 -
170 -
171 -class PretrainedBartModel(PreTrainedModel):
172 - config_class = BartConfig
173 - base_model_prefix = "model"
174 -
175 - def _init_weights(self, module):
176 - std = self.config.init_std
177 - if isinstance(module, nn.Linear):
178 - module.weight.data.normal_(mean=0.0, std=std)
179 - if module.bias is not None:
180 - module.bias.data.zero_()
181 - elif isinstance(module, SinusoidalPositionalEmbedding):
182 - pass
183 - elif isinstance(module, nn.Embedding):
184 - module.weight.data.normal_(mean=0.0, std=std)
185 - if module.padding_idx is not None:
186 - module.weight.data[module.padding_idx].zero_()
187 -
188 - @property
189 - def dummy_inputs(self):
190 - pad_token = self.config.pad_token_id
191 - input_ids = torch.tensor(
192 - [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
193 - )
194 - dummy_inputs = {
195 - "attention_mask": input_ids.ne(pad_token),
196 - "input_ids": input_ids,
197 - }
198 - return dummy_inputs
199 -
200 -
201 -def _make_linear_from_emb(emb):
202 - vocab_size, emb_size = emb.weight.shape
203 - lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
204 - lin_layer.weight.data = emb.weight.data
205 - return lin_layer
206 -
207 -
208 -# Helper Functions, mostly for making masks
209 -def _check_shapes(shape_1, shape2):
210 - if shape_1 != shape2:
211 - raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
212 -
213 -
214 -def shift_tokens_right(input_ids, pad_token_id):
215 - """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
216 - prev_output_tokens = input_ids.clone()
217 - index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
218 - prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
219 - prev_output_tokens[:, 1:] = input_ids[:, :-1]
220 - return prev_output_tokens
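
# Illustration (not part of the original file): shift_tokens_right moves the last
# non-pad token (usually </s>) to position 0 and shifts everything else one step
# right, producing the decoder inputs used for teacher forcing. With pad_token_id = 1:
#   shift_tokens_right(torch.tensor([[0, 8, 12, 2, 1]]), 1)
#   -> tensor([[ 2,  0,  8, 12,  2]])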
221 -
222 -
223 -def make_padding_mask(input_ids, padding_idx=1):
224 - """True for pad tokens"""
225 - padding_mask = input_ids.eq(padding_idx)
226 - if not padding_mask.any():
227 - padding_mask = None
228 - return padding_mask
229 -
230 -
231 -# Helper Modules
232 -
233 -
234 -class EncoderLayer(nn.Module):
235 - def __init__(self, config: BartConfig):
236 - super().__init__()
237 - self.embed_dim = config.d_model
238 - self.self_attn = Attention(
239 - self.embed_dim,
240 - config.encoder_attention_heads,
241 - dropout=config.attention_dropout,
242 - )
243 - self.normalize_before = config.normalize_before
244 - self.self_attn_layer_norm = LayerNorm(self.embed_dim)
245 - self.dropout = config.dropout
246 - self.activation_fn = ACT2FN[config.activation_function]
247 - self.activation_dropout = config.activation_dropout
248 - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
249 - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
250 - self.final_layer_norm = LayerNorm(self.embed_dim)
251 -
252 - def forward(self, x, encoder_padding_mask, output_attentions=False):
253 - """
254 - Args:
255 - x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
256 - encoder_padding_mask (ByteTensor): binary ByteTensor of shape
257 - `(batch, src_len)` where padding elements are indicated by ``1``.
258 -                (a value of ``1`` marks the position as padding and masks it
259 -                out of attention; ``0`` means it is attended to)
260 -
261 - Returns:
262 - encoded output of shape `(seq_len, batch, embed_dim)`
263 - """
264 - residual = x
265 - if self.normalize_before:
266 - x = self.self_attn_layer_norm(x)
267 - x, attn_weights = self.self_attn(
268 - query=x,
269 - key=x,
270 - key_padding_mask=encoder_padding_mask,
271 - output_attentions=output_attentions,
272 - )
273 - x = F.dropout(x, p=self.dropout, training=self.training)
274 - x = residual + x
275 - if not self.normalize_before:
276 - x = self.self_attn_layer_norm(x)
277 -
278 - residual = x
279 - if self.normalize_before:
280 - x = self.final_layer_norm(x)
281 - x = self.activation_fn(self.fc1(x))
282 - x = F.dropout(x, p=self.activation_dropout, training=self.training)
283 - x = self.fc2(x)
284 - x = F.dropout(x, p=self.dropout, training=self.training)
285 - x = residual + x
286 - if not self.normalize_before:
287 - x = self.final_layer_norm(x)
288 - return x, attn_weights
289 -
290 -
291 -class BartEncoder(nn.Module):
292 - """
293 - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
294 - is a :class:`EncoderLayer`.
295 -
296 - Args:
297 - config: BartConfig
298 - """
299 -
300 - def __init__(self, config: BartConfig, embed_tokens):
301 - super().__init__()
302 -
303 - self.dropout = config.dropout
304 - self.layerdrop = config.encoder_layerdrop
305 -
306 - embed_dim = embed_tokens.embedding_dim
307 - self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
308 - self.padding_idx = embed_tokens.padding_idx
309 - self.max_source_positions = config.max_position_embeddings
310 -
311 - self.embed_tokens = embed_tokens
312 - if config.static_position_embeddings:
313 - self.embed_positions = SinusoidalPositionalEmbedding(
314 - config.max_position_embeddings, embed_dim, self.padding_idx
315 - )
316 - else:
317 - self.embed_positions = LearnedPositionalEmbedding(
318 - config.max_position_embeddings,
319 - embed_dim,
320 - self.padding_idx,
321 - config.extra_pos_embeddings,
322 - )
323 - self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
324 - self.layers = nn.ModuleList(
325 - [EncoderLayer(config) for _ in range(config.encoder_layers)]
326 - )
327 - self.layernorm_embedding = (
328 - LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
329 - )
330 - # mbart has one extra layer_norm
331 - self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
332 -
333 - def forward(
334 - self,
335 - input_ids,
336 - patch_ids,
337 - attention_mask=None,
338 - output_attentions=False,
339 - output_hidden_states=False,
340 - return_dict=False,
341 - ):
342 - """
343 - Args:
344 - input_ids (LongTensor): tokens in the source language of shape
345 - `(batch, src_len)`
346 - attention_mask (torch.LongTensor): indicating which indices are padding tokens.
347 - Returns:
348 - BaseModelOutput or Tuple comprised of:
349 - - **x** (Tensor): the last encoder layer's output of
350 - shape `(src_len, batch, embed_dim)`
351 - - **encoder_states** (tuple(torch.FloatTensor)): all intermediate
352 - hidden states of shape `(src_len, batch, embed_dim)`.
353 - Only populated if *output_hidden_states:* is True.
354 - - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
355 - During training might not be of length n_layers because of layer dropout.
356 - """
357 - # check attention mask and invert
358 - if attention_mask is not None:
359 - attention_mask = invert_mask(attention_mask)
360 -
361 - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
362 - embed_pos = self.embed_positions(input_ids)
363 - embed_patch = self.embed_patches(patch_ids)
364 - x = inputs_embeds + embed_pos + embed_patch
365 - x = self.layernorm_embedding(x)
366 - x = F.dropout(x, p=self.dropout, training=self.training)
367 -
368 - # B x T x C -> T x B x C
369 - x = x.transpose(0, 1)
370 -
371 - encoder_states = [] if output_hidden_states else None
372 - all_attentions = () if output_attentions else None
373 - for encoder_layer in self.layers:
374 - if output_hidden_states:
375 - encoder_states.append(x)
376 - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
377 - dropout_probability = random.uniform(0, 1)
378 - if self.training and (
379 - dropout_probability < self.layerdrop
380 - ): # skip the layer
381 - attn = None
382 - else:
383 - x, attn = encoder_layer(
384 - x, attention_mask, output_attentions=output_attentions
385 - )
386 -
387 - if output_attentions:
388 - all_attentions = all_attentions + (attn,)
389 -
390 - if self.layer_norm:
391 - x = self.layer_norm(x)
392 - if output_hidden_states:
393 - encoder_states.append(x)
394 - # T x B x C -> B x T x C
395 - encoder_states = tuple(
396 - hidden_state.transpose(0, 1) for hidden_state in encoder_states
397 - )
398 -
399 - # T x B x C -> B x T x C
400 - x = x.transpose(0, 1)
401 -
402 - if not return_dict:
403 - return tuple(
404 - v for v in [x, encoder_states, all_attentions] if v is not None
405 - )
406 - return BaseModelOutput(
407 - last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions
408 - )
409 -
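# Note (not part of the original file): compared to stock BART, this encoder adds
# a third embedding table (self.embed_patches, 3 rows with padding_idx=0), so each
# source token is represented as the sum of three signals:
#   x = embed_tokens(input_ids) * embed_scale   # what the token is
#     + embed_positions(input_ids)              # where it sits in the sequence
#     + embed_patches(patch_ids)                # which kind of diff line it belongs to
# patch_ids must therefore be integers in {0, 1, 2}, aligned one-to-one with input_ids.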
410 -
411 -class DecoderLayer(nn.Module):
412 - def __init__(self, config: BartConfig):
413 - super().__init__()
414 - self.embed_dim = config.d_model
415 -
416 - self.self_attn = Attention(
417 - embed_dim=self.embed_dim,
418 - num_heads=config.decoder_attention_heads,
419 - dropout=config.attention_dropout,
420 - )
421 - self.dropout = config.dropout
422 - self.activation_fn = ACT2FN[config.activation_function]
423 - self.activation_dropout = config.activation_dropout
424 - self.normalize_before = config.normalize_before
425 -
426 - self.self_attn_layer_norm = LayerNorm(self.embed_dim)
427 - self.encoder_attn = Attention(
428 - self.embed_dim,
429 - config.decoder_attention_heads,
430 - dropout=config.attention_dropout,
431 - encoder_decoder_attention=True,
432 - )
433 - self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
434 - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
435 - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
436 - self.final_layer_norm = LayerNorm(self.embed_dim)
437 -
438 - def forward(
439 - self,
440 - x,
441 - encoder_hidden_states,
442 - encoder_attn_mask=None,
443 - layer_state=None,
444 - causal_mask=None,
445 - decoder_padding_mask=None,
446 - output_attentions=False,
447 - ):
448 - residual = x
449 -
450 - if layer_state is None:
451 - layer_state = {}
452 - if self.normalize_before:
453 - x = self.self_attn_layer_norm(x)
454 - # Self Attention
455 -
456 - x, self_attn_weights = self.self_attn(
457 - query=x,
458 - key=x,
459 - layer_state=layer_state, # adds keys to layer state
460 - key_padding_mask=decoder_padding_mask,
461 - attn_mask=causal_mask,
462 - output_attentions=output_attentions,
463 - )
464 - x = F.dropout(x, p=self.dropout, training=self.training)
465 - x = residual + x
466 - if not self.normalize_before:
467 - x = self.self_attn_layer_norm(x)
468 -
469 - # Cross attention
470 - residual = x
471 - assert self.encoder_attn.cache_key != self.self_attn.cache_key
472 - if self.normalize_before:
473 - x = self.encoder_attn_layer_norm(x)
474 - x, _ = self.encoder_attn(
475 - query=x,
476 - key=encoder_hidden_states,
477 - key_padding_mask=encoder_attn_mask,
478 - layer_state=layer_state, # mutates layer state
479 - )
480 - x = F.dropout(x, p=self.dropout, training=self.training)
481 - x = residual + x
482 - if not self.normalize_before:
483 - x = self.encoder_attn_layer_norm(x)
484 -
485 - # Fully Connected
486 - residual = x
487 - if self.normalize_before:
488 - x = self.final_layer_norm(x)
489 - x = self.activation_fn(self.fc1(x))
490 - x = F.dropout(x, p=self.activation_dropout, training=self.training)
491 - x = self.fc2(x)
492 - x = F.dropout(x, p=self.dropout, training=self.training)
493 - x = residual + x
494 - if not self.normalize_before:
495 - x = self.final_layer_norm(x)
496 - return (
497 - x,
498 - self_attn_weights,
499 - layer_state,
500 - ) # just self_attn weights for now, following t5, layer_state = cache for decoding
501 -
502 -
503 -class BartDecoder(nn.Module):
504 - """
505 - Transformer decoder consisting of *config.decoder_layers* layers. Each layer
506 - is a :class:`DecoderLayer`.
507 - Args:
508 - config: BartConfig
509 - embed_tokens (torch.nn.Embedding): output embedding
510 - """
511 -
512 - def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
513 - super().__init__()
514 - self.dropout = config.dropout
515 - self.layerdrop = config.decoder_layerdrop
516 - self.padding_idx = embed_tokens.padding_idx
517 - self.max_target_positions = config.max_position_embeddings
518 - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
519 - self.embed_tokens = embed_tokens
520 - if config.static_position_embeddings:
521 - self.embed_positions = SinusoidalPositionalEmbedding(
522 - config.max_position_embeddings, config.d_model, config.pad_token_id
523 - )
524 - else:
525 - self.embed_positions = LearnedPositionalEmbedding(
526 - config.max_position_embeddings,
527 - config.d_model,
528 - self.padding_idx,
529 - config.extra_pos_embeddings,
530 - )
531 - self.layers = nn.ModuleList(
532 - [DecoderLayer(config) for _ in range(config.decoder_layers)]
533 - ) # type: List[DecoderLayer]
534 - self.layernorm_embedding = (
535 - LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
536 - )
537 - self.layer_norm = (
538 - LayerNorm(config.d_model) if config.add_final_layer_norm else None
539 - )
540 -
541 - def forward(
542 - self,
543 - input_ids,
544 - encoder_hidden_states,
545 - encoder_padding_mask,
546 - decoder_padding_mask,
547 - decoder_causal_mask,
548 - past_key_values=None,
549 - use_cache=False,
550 - output_attentions=False,
551 - output_hidden_states=False,
552 - return_dict=False,
553 - **unused,
554 - ):
555 - """
556 - Includes several features from "Jointly Learning to Align and
557 - Translate with Transformer Models" (Garg et al., EMNLP 2019).
558 -
559 - Args:
560 - input_ids (LongTensor): previous decoder outputs of shape
561 - `(batch, tgt_len)`, for teacher forcing
562 - encoder_hidden_states: output from the encoder, used for
563 - encoder-side attention
564 - encoder_padding_mask: for ignoring pad tokens
565 - past_key_values (dict or None): dictionary used for storing state during generation
566 -
567 - Returns:
568 - BaseModelOutputWithPast or tuple:
569 - - the decoder's features of shape `(batch, tgt_len, embed_dim)`
570 - - the cache
571 - - hidden states
572 - - attentions
573 - """
574 - if "decoder_cached_states" in unused:
575 - warnings.warn(
576 - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
577 - FutureWarning,
578 - )
579 - past_key_values = unused.pop("decoder_cached_states")
580 - if "decoder_past_key_values" in unused:
581 - warnings.warn(
582 - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
583 - FutureWarning,
584 - )
585 - past_key_values = unused.pop("decoder_past_key_values")
586 -
587 - # check attention mask and invert
588 - if encoder_padding_mask is not None:
589 - encoder_padding_mask = invert_mask(encoder_padding_mask)
590 -
591 - # embed positions
592 - positions = self.embed_positions(input_ids, use_cache=use_cache)
593 -
594 - if use_cache:
595 - input_ids = input_ids[:, -1:]
596 - positions = positions[:, -1:] # happens after we embed them
597 - # assert input_ids.ne(self.padding_idx).any()
598 -
599 - x = self.embed_tokens(input_ids) * self.embed_scale
600 - x += positions
601 - x = self.layernorm_embedding(x)
602 - x = F.dropout(x, p=self.dropout, training=self.training)
603 -
604 -        # Convert to decoder layer format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
605 - x = x.transpose(0, 1)
606 - encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
607 -
608 - # decoder layers
609 - all_hidden_states = () if output_hidden_states else None
610 - all_self_attns = () if output_attentions else None
611 - next_decoder_cache = []
612 - for idx, decoder_layer in enumerate(self.layers):
613 - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
614 - if output_hidden_states:
615 - all_hidden_states += (x,)
616 - dropout_probability = random.uniform(0, 1)
617 - if self.training and (dropout_probability < self.layerdrop):
618 - continue
619 -
620 - layer_state = past_key_values[idx] if past_key_values is not None else None
621 -
622 - x, layer_self_attn, layer_past = decoder_layer(
623 - x,
624 - encoder_hidden_states,
625 - encoder_attn_mask=encoder_padding_mask,
626 - decoder_padding_mask=decoder_padding_mask,
627 - layer_state=layer_state,
628 - causal_mask=decoder_causal_mask,
629 - output_attentions=output_attentions,
630 - )
631 -
632 - if use_cache:
633 - next_decoder_cache.append(layer_past.copy())
634 -
635 - if self.layer_norm and (
636 - idx == len(self.layers) - 1
637 - ): # if config.add_final_layer_norm (mBART)
638 - x = self.layer_norm(x)
639 - if output_attentions:
640 - all_self_attns += (layer_self_attn,)
641 -
642 - # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
643 - if output_hidden_states:
644 - all_hidden_states = tuple(
645 - hidden_state.transpose(0, 1) for hidden_state in all_hidden_states
646 - )
647 - x = x.transpose(0, 1)
648 - encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
649 -
650 - next_cache = next_decoder_cache if use_cache else None
651 -
652 - if not return_dict:
653 - return tuple(
654 - v
655 - for v in [x, next_cache, all_hidden_states, all_self_attns]
656 - if v is not None
657 - )
658 - return BaseModelOutputWithPast(
659 - last_hidden_state=x,
660 - past_key_values=next_cache,
661 - hidden_states=all_hidden_states,
662 - attentions=all_self_attns,
663 - )
664 -
665 -
666 -def _reorder_buffer(attn_cache, new_order):
667 - for k, input_buffer_k in attn_cache.items():
668 - if input_buffer_k is not None:
669 - attn_cache[k] = input_buffer_k.index_select(0, new_order)
670 - return attn_cache
671 -
672 -
673 -class Attention(nn.Module):
674 - """Multi-headed attention from 'Attention Is All You Need' paper"""
675 -
676 - def __init__(
677 - self,
678 - embed_dim,
679 - num_heads,
680 - dropout=0.0,
681 - bias=True,
682 - encoder_decoder_attention=False, # otherwise self_attention
683 - ):
684 - super().__init__()
685 - self.embed_dim = embed_dim
686 - self.num_heads = num_heads
687 - self.dropout = dropout
688 - self.head_dim = embed_dim // num_heads
689 - assert (
690 - self.head_dim * num_heads == self.embed_dim
691 - ), "embed_dim must be divisible by num_heads"
692 - self.scaling = self.head_dim ** -0.5
693 -
694 - self.encoder_decoder_attention = encoder_decoder_attention
695 - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
696 - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
697 - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
698 - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
699 - self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
700 -
701 - def _shape(self, tensor, seq_len, bsz):
702 - return (
703 - tensor.contiguous()
704 - .view(seq_len, bsz * self.num_heads, self.head_dim)
705 - .transpose(0, 1)
706 - )
707 -
708 - def forward(
709 - self,
710 - query,
711 - key: Optional[Tensor],
712 - key_padding_mask: Optional[Tensor] = None,
713 - layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
714 - attn_mask: Optional[Tensor] = None,
715 - output_attentions=False,
716 - ) -> Tuple[Tensor, Optional[Tensor]]:
717 - """Input shape: Time(SeqLen) x Batch x Channel"""
718 - static_kv: bool = self.encoder_decoder_attention
719 - tgt_len, bsz, embed_dim = query.size()
720 - assert embed_dim == self.embed_dim
721 - assert list(query.size()) == [tgt_len, bsz, embed_dim]
722 -        # we reach here for encoder-decoder attention because of static_kv
723 - if layer_state is not None: # reuse k,v and encoder_padding_mask
724 - saved_state = layer_state.get(self.cache_key, {})
725 - if "prev_key" in saved_state and static_kv:
726 - # previous time steps are cached - no need to recompute key and value if they are static
727 - key = None
728 - else:
729 - saved_state = None
730 - layer_state = {}
731 -
732 - q = self.q_proj(query) * self.scaling
733 - if static_kv:
734 - if key is None:
735 - k = v = None
736 - else:
737 - k = self.k_proj(key)
738 - v = self.v_proj(key)
739 - else:
740 - k = self.k_proj(query)
741 - v = self.v_proj(query)
742 -
743 - q = self._shape(q, tgt_len, bsz)
744 - if k is not None:
745 - k = self._shape(k, -1, bsz)
746 - if v is not None:
747 - v = self._shape(v, -1, bsz)
748 -
749 - if saved_state is not None:
750 - k, v, key_padding_mask = self._use_saved_state(
751 - k, v, saved_state, key_padding_mask, static_kv, bsz
752 - )
753 -
754 - # Update cache
755 - layer_state[self.cache_key] = {
756 - "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
757 - "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
758 - "prev_key_padding_mask": key_padding_mask if not static_kv else None,
759 - }
760 -
761 - assert k is not None
762 - src_len = k.size(1)
763 - attn_weights = torch.bmm(q, k.transpose(1, 2))
764 - assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
765 -
766 - if attn_mask is not None:
767 - attn_weights = (
768 - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
769 - )
770 - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
771 -
772 - # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
773 - if key_padding_mask is not None and key_padding_mask.dim() == 0:
774 - key_padding_mask = None
775 - assert key_padding_mask is None or key_padding_mask.size()[:2] == (
776 - bsz,
777 - src_len,
778 - )
779 -
780 - if key_padding_mask is not None: # don't attend to padding symbols
781 - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
782 - reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
783 - attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
784 - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
785 - attn_weights = F.softmax(attn_weights, dim=-1)
786 - attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
787 -
788 - assert v is not None
789 - attn_output = torch.bmm(attn_probs, v)
790 - assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
791 - attn_output = (
792 - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
793 - )
794 - attn_output = self.out_proj(attn_output)
795 - if output_attentions:
796 - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
797 - else:
798 - attn_weights = None
799 - return attn_output, attn_weights
800 -
801 - def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
802 - # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
803 - if "prev_key" in saved_state:
804 - _prev_key = saved_state["prev_key"]
805 - assert _prev_key is not None
806 - prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
807 - if static_kv:
808 - k = prev_key
809 - else:
810 - assert k is not None
811 - k = torch.cat([prev_key, k], dim=1)
812 - if "prev_value" in saved_state:
813 - _prev_value = saved_state["prev_value"]
814 - assert _prev_value is not None
815 - prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
816 - if static_kv:
817 - v = prev_value
818 - else:
819 - assert v is not None
820 - v = torch.cat([prev_value, v], dim=1)
821 - assert k is not None and v is not None
822 - prev_key_padding_mask: Optional[Tensor] = saved_state.get(
823 - "prev_key_padding_mask", None
824 - )
825 - if prev_key_padding_mask is not None:
826 - if static_kv:
827 - new_key_padding_mask = prev_key_padding_mask
828 - else:
829 - new_key_padding_mask = torch.cat(
830 - [prev_key_padding_mask, key_padding_mask], dim=1
831 - )
832 - else:
833 - new_key_padding_mask = key_padding_mask
834 - return k, v, new_key_padding_mask
835 -
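# Note (not part of the original file): during incremental decoding the decoder
# self-attention concatenates the cached keys/values with the new step
# (k = cat([prev_key, k], dim=1)), while the encoder-decoder attention is
# static_kv and simply reuses prev_key/prev_value, so the encoder states are
# projected only once per generated sequence.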
836 -
837 -class BartClassificationHead(nn.Module):
838 - """Head for sentence-level classification tasks."""
839 -
840 - # This can trivially be shared with RobertaClassificationHead
841 -
842 - def __init__(
843 - self, input_dim, inner_dim, num_classes, pooler_dropout,
844 - ):
845 - super().__init__()
846 - self.dense = nn.Linear(input_dim, inner_dim)
847 - self.dropout = nn.Dropout(p=pooler_dropout)
848 - self.out_proj = nn.Linear(inner_dim, num_classes)
849 -
850 - def forward(self, x):
851 - x = self.dropout(x)
852 - x = self.dense(x)
853 - x = torch.tanh(x)
854 - x = self.dropout(x)
855 - x = self.out_proj(x)
856 - return x
857 -
858 -
859 -class LearnedPositionalEmbedding(nn.Embedding):
860 - """
861 - This module learns positional embeddings up to a fixed maximum size.
862 - Padding ids are ignored by either offsetting based on padding_idx
863 - or by setting padding_idx to None and ensuring that the appropriate
864 - position ids are passed to the forward function.
865 - """
866 -
867 - def __init__(
868 - self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset
869 - ):
870 - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
871 -        # and adjust num_embeddings appropriately. Other models don't have this hack.
872 - self.offset = offset
873 - assert padding_idx is not None
874 - num_embeddings += offset
875 - super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
876 -
877 - def forward(self, input_ids, use_cache=False):
878 - """Input is expected to be of size [bsz x seqlen]."""
879 - bsz, seq_len = input_ids.shape[:2]
880 - if use_cache:
881 - positions = input_ids.data.new(1, 1).fill_(
882 - seq_len - 1
883 - ) # called before slicing
884 - else:
885 -            # starts at 0, ends at seq_len - 1
886 - positions = torch.arange(
887 - seq_len, dtype=torch.long, device=self.weight.device
888 - )
889 - return super().forward(positions + self.offset)
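
# Note (not part of the original file): with offset=2 (config.extra_pos_embeddings
# for standard BART) a 6-token input looks up embedding rows 2..7; __init__ grew
# num_embeddings by the same offset, so the shifted indices stay in range.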
890 -
891 -
892 -def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
893 - if torch.cuda.is_available():
894 - try:
895 - from apex.normalization import FusedLayerNorm
896 -
897 - return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
898 - except ImportError:
899 - pass
900 - return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
901 -
902 -
903 -def fill_with_neg_inf(t):
904 -    """FP16-compatible function that fills a tensor with -inf."""
905 - return t.float().fill_(float("-inf")).type_as(t)
906 -
907 -
908 -# Public API
909 -def _get_shape(t):
910 - return getattr(t, "shape", None)
911 -
912 -
913 -@add_start_docstrings(
914 - "The bare BART Model outputting raw hidden-states without any specific head on top.",
915 - BART_START_DOCSTRING,
916 -)
917 -class BartModel(PretrainedBartModel):
918 - def __init__(self, config: BartConfig):
919 - super().__init__(config)
920 -
921 - padding_idx, vocab_size = config.pad_token_id, config.vocab_size
922 - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
923 -
924 - self.encoder = BartEncoder(config, self.shared)
925 - self.decoder = BartDecoder(config, self.shared)
926 -
927 - self.init_weights()
928 -
929 - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
930 - @add_code_sample_docstrings(
931 - tokenizer_class=_TOKENIZER_FOR_DOC,
932 - checkpoint="facebook/bart-large",
933 - output_type=BaseModelOutputWithPast,
934 - config_class=_CONFIG_FOR_DOC,
935 - )
936 - def forward(
937 - self,
938 - input_ids,
939 - patch_ids=None,
940 - attention_mask=None,
941 - decoder_input_ids=None,
942 - encoder_outputs: Optional[Tuple] = None,
943 - decoder_attention_mask=None,
944 - past_key_values=None,
945 - use_cache=None,
946 - output_attentions=None,
947 - output_hidden_states=None,
948 - return_dict=None,
949 - **kwargs,
950 - ):
951 - if "decoder_past_key_values" in kwargs:
952 - warnings.warn(
953 - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
954 - FutureWarning,
955 - )
956 - past_key_values = kwargs.pop("decoder_past_key_values")
957 -
958 - if decoder_input_ids is None:
959 - use_cache = False
960 -
961 - output_attentions = (
962 - output_attentions
963 - if output_attentions is not None
964 - else self.config.output_attentions
965 - )
966 - output_hidden_states = (
967 - output_hidden_states
968 - if output_hidden_states is not None
969 - else self.config.output_hidden_states
970 - )
971 - use_cache = use_cache if use_cache is not None else self.config.use_cache
972 - return_dict = (
973 - return_dict if return_dict is not None else self.config.use_return_dict
974 - )
975 -
976 - # make masks if user doesn't supply
977 - if not use_cache:
978 - (
979 - decoder_input_ids,
980 - decoder_padding_mask,
981 - causal_mask,
982 - ) = _prepare_bart_decoder_inputs(
983 - self.config,
984 - input_ids,
985 - decoder_input_ids=decoder_input_ids,
986 - decoder_padding_mask=decoder_attention_mask,
987 - causal_mask_dtype=self.shared.weight.dtype,
988 - )
989 - else:
990 - decoder_padding_mask, causal_mask = None, None
991 -
992 - assert decoder_input_ids is not None
993 -
994 - if encoder_outputs is None:
995 - encoder_outputs = self.encoder(
996 - input_ids=input_ids,
997 - patch_ids=patch_ids,
998 - attention_mask=attention_mask,
999 - output_attentions=output_attentions,
1000 - output_hidden_states=output_hidden_states,
1001 - return_dict=return_dict,
1002 - )
1003 - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOuput when return_dict=False
1004 - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1005 - encoder_outputs = BaseModelOutput(
1006 - last_hidden_state=encoder_outputs[0],
1007 - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1008 - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1009 - )
1010 -
1011 - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1012 - decoder_outputs = self.decoder(
1013 - decoder_input_ids,
1014 - encoder_outputs[0],
1015 - attention_mask,
1016 - decoder_padding_mask,
1017 - decoder_causal_mask=causal_mask,
1018 - past_key_values=past_key_values,
1019 - use_cache=use_cache,
1020 - output_attentions=output_attentions,
1021 - output_hidden_states=output_hidden_states,
1022 - return_dict=return_dict,
1023 - )
1024 -
1025 - if not return_dict:
1026 - return decoder_outputs + encoder_outputs
1027 -
1028 - return Seq2SeqModelOutput(
1029 - last_hidden_state=decoder_outputs.last_hidden_state,
1030 - past_key_values=decoder_outputs.past_key_values,
1031 - decoder_hidden_states=decoder_outputs.hidden_states,
1032 - decoder_attentions=decoder_outputs.attentions,
1033 - encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1034 - encoder_hidden_states=encoder_outputs.hidden_states,
1035 - encoder_attentions=encoder_outputs.attentions,
1036 - )
1037 -
1038 - def get_input_embeddings(self):
1039 - return self.shared
1040 -
1041 - def set_input_embeddings(self, value):
1042 - self.shared = value
1043 - self.encoder.embed_tokens = self.shared
1044 - self.decoder.embed_tokens = self.shared
1045 -
1046 - def get_output_embeddings(self):
1047 - return _make_linear_from_emb(self.shared) # make it on the fly
1048 -
1049 -
1050 -@add_start_docstrings(
1051 - "The BART Model with a language modeling head. Can be used for summarization.",
1052 - BART_START_DOCSTRING,
1053 -)
1054 -class BartForConditionalGeneration(PretrainedBartModel):
1055 - base_model_prefix = "model"
1056 - authorized_missing_keys = [
1057 - r"final_logits_bias",
1058 - r"encoder\.version",
1059 - r"decoder\.version",
1060 - ]
1061 -
1062 - def __init__(self, config: BartConfig):
1063 - super().__init__(config)
1064 - base_model = BartModel(config)
1065 - self.model = base_model
1066 - self.register_buffer(
1067 - "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
1068 - )
1069 -
1070 - def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1071 - old_num_tokens = self.model.shared.num_embeddings
1072 - new_embeddings = super().resize_token_embeddings(new_num_tokens)
1073 - self.model.shared = new_embeddings
1074 - self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
1075 - return new_embeddings
1076 -
1077 - def _resize_final_logits_bias(
1078 - self, new_num_tokens: int, old_num_tokens: int
1079 - ) -> None:
1080 - if new_num_tokens <= old_num_tokens:
1081 - new_bias = self.final_logits_bias[:, :new_num_tokens]
1082 - else:
1083 - extra_bias = torch.zeros(
1084 - (1, new_num_tokens - old_num_tokens),
1085 - device=self.final_logits_bias.device,
1086 - )
1087 - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1088 - self.register_buffer("final_logits_bias", new_bias)
1089 -
1090 - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1091 - @replace_return_docstrings(
1092 - output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
1093 - )
1094 - @add_end_docstrings(BART_GENERATION_EXAMPLE)
1095 - def forward(
1096 - self,
1097 - input_ids,
1098 - patch_ids,
1099 - attention_mask=None,
1100 - encoder_outputs=None,
1101 - decoder_input_ids=None,
1102 - decoder_attention_mask=None,
1103 - past_key_values=None,
1104 - labels=None,
1105 - use_cache=None,
1106 - output_attentions=None,
1107 - output_hidden_states=None,
1108 - return_dict=None,
1109 - **unused,
1110 - ):
1111 - r"""
1112 - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1113 - Labels for computing the masked language modeling loss.
1114 - Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
1115 - Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
1116 - with labels in ``[0, ..., config.vocab_size]``.
1117 -
1118 - Returns:
1119 -
1120 - Conditional generation example::
1121 -
1122 - # Mask filling only works for bart-large
1123 - from transformers import BartTokenizer, BartForConditionalGeneration
1124 - tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
1125 - TXT = "My friends are <mask> but they eat too many carbs."
1126 -
1127 - model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
1128 - input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
1129 - logits = model(input_ids).logits
1130 -
1131 - masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
1132 - probs = logits[0, masked_index].softmax(dim=0)
1133 - values, predictions = probs.topk(5)
1134 -
1135 - tokenizer.decode(predictions).split()
1136 - # ['good', 'great', 'all', 'really', 'very']
1137 - """
1138 - if "lm_labels" in unused:
1139 - warnings.warn(
1140 - "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1141 - FutureWarning,
1142 - )
1143 - labels = unused.pop("lm_labels")
1144 - if "decoder_cached_states" in unused:
1145 - warnings.warn(
1146 - "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
1147 - FutureWarning,
1148 - )
1149 - past_key_values = unused.pop("decoder_cached_states")
1150 - if "decoder_past_key_values" in unused:
1151 - warnings.warn(
1152 - "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
1153 - FutureWarning,
1154 - )
1155 - past_key_values = unused.pop("decoder_past_key_values")
1156 - return_dict = (
1157 - return_dict if return_dict is not None else self.config.use_return_dict
1158 - )
1159 -
1160 - if labels is not None:
1161 - use_cache = False
1162 - if decoder_input_ids is None:
1163 - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
1164 -
1165 - outputs = self.model(
1166 - input_ids,
1167 - patch_ids=patch_ids,
1168 - attention_mask=attention_mask,
1169 - decoder_input_ids=decoder_input_ids,
1170 - encoder_outputs=encoder_outputs,
1171 - decoder_attention_mask=decoder_attention_mask,
1172 - past_key_values=past_key_values,
1173 - use_cache=use_cache,
1174 - output_attentions=output_attentions,
1175 - output_hidden_states=output_hidden_states,
1176 - return_dict=return_dict,
1177 - )
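# Project decoder hidden states onto the vocabulary using the tied (shared)
# embedding weights, then add the final_logits_bias buffer.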
1178 - lm_logits = F.linear(
1179 - outputs[0], self.model.shared.weight, bias=self.final_logits_bias
1180 - )
1181 -
1182 - masked_lm_loss = None
1183 - if labels is not None:
1184 - loss_fct = CrossEntropyLoss()
1185 - # TODO(SS): do we need to ignore pad tokens in labels?
1186 - masked_lm_loss = loss_fct(
1187 - lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
1188 - )
1189 -
1190 - if not return_dict:
1191 - output = (lm_logits,) + outputs[1:]
1192 - return (
1193 - ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1194 - )
1195 -
1196 - return Seq2SeqLMOutput(
1197 - loss=masked_lm_loss,
1198 - logits=lm_logits,
1199 - past_key_values=outputs.past_key_values,
1200 - decoder_hidden_states=outputs.decoder_hidden_states,
1201 - decoder_attentions=outputs.decoder_attentions,
1202 - encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1203 - encoder_hidden_states=outputs.encoder_hidden_states,
1204 - encoder_attentions=outputs.encoder_attentions,
1205 - )
1206 -
1207 - def prepare_inputs_for_generation(
1208 - self,
1209 - decoder_input_ids,
1210 - past,
1211 - attention_mask,
1212 - use_cache,
1213 - encoder_outputs,
1214 - **kwargs,
1215 - ):
1216 - return {
1217 - "input_ids": None, # encoder_outputs is defined. input_ids not needed
1218 - "patch_ids": None, # encoder_outputs is defined. patch_ids not needed
1219 - "encoder_outputs": encoder_outputs,
1220 - "past_key_values": past,
1221 - "decoder_input_ids": decoder_input_ids,
1222 - "attention_mask": attention_mask,
1223 - "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1224 - }
1225 -
1226 - def adjust_logits_during_generation(self, logits, cur_len, max_length):
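# Force BOS right after the decoder start token and force EOS at the last
# position, when the configuration asks for it.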
1227 - if cur_len == 1 and self.config.force_bos_token_to_be_generated:
1228 - self._force_token_ids_generation(logits, self.config.bos_token_id)
1229 - elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
1230 - self._force_token_ids_generation(logits, self.config.eos_token_id)
1231 - return logits
1232 -
1233 - def _force_token_ids_generation(self, scores, token_id) -> None:
1234 - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
1235 - scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float(
1236 - "inf"
1237 - )
1238 -
1239 - @staticmethod
1240 - def _reorder_cache(past, beam_idx):
1241 - reordered_past = []
1242 - for layer_past in past:
1243 - # get the correct batch idx from decoder layer's batch dim for cross and self-attn
1244 - layer_past_new = {
1245 - attn_key: _reorder_buffer(attn_cache, beam_idx)
1246 - for attn_key, attn_cache in layer_past.items()
1247 - }
1248 - reordered_past.append(layer_past_new)
1249 - return reordered_past
1250 -
1251 - def get_encoder(self):
1252 - return self.model.encoder
1253 -
1254 - def get_output_embeddings(self):
1255 - return _make_linear_from_emb(self.model.shared) # make it on the fly
1256 -
1257 -
1258 -@add_start_docstrings(
1259 - """Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE tasks. """,
1260 - BART_START_DOCSTRING,
1261 -)
1262 -class BartForSequenceClassification(PretrainedBartModel):
1263 - def __init__(self, config: BartConfig, **kwargs):
1264 - super().__init__(config, **kwargs)
1265 - self.model = BartModel(config)
1266 - self.classification_head = BartClassificationHead(
1267 - config.d_model, config.d_model, config.num_labels, config.classif_dropout,
1268 - )
1269 - self.model._init_weights(self.classification_head.dense)
1270 - self.model._init_weights(self.classification_head.out_proj)
1271 -
1272 - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1273 - @add_code_sample_docstrings(
1274 - tokenizer_class=_TOKENIZER_FOR_DOC,
1275 - checkpoint="facebook/bart-large",
1276 - output_type=Seq2SeqSequenceClassifierOutput,
1277 - config_class=_CONFIG_FOR_DOC,
1278 - )
1279 - def forward(
1280 - self,
1281 - input_ids,
1282 - attention_mask=None,
1283 - encoder_outputs=None,
1284 - decoder_input_ids=None,
1285 - decoder_attention_mask=None,
1286 - labels=None,
1287 - use_cache=None,
1288 - output_attentions=None,
1289 - output_hidden_states=None,
1290 - return_dict=None,
1291 - ):
1292 - r"""
1293 - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1294 - Labels for computing the sequence classification/regression loss.
1295 - Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1296 - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1297 - """
1298 - return_dict = (
1299 - return_dict if return_dict is not None else self.config.use_return_dict
1300 - )
1301 - if labels is not None:
1302 - use_cache = False
1303 -
1304 - outputs = self.model(
1305 - input_ids,
1306 - attention_mask=attention_mask,
1307 - decoder_input_ids=decoder_input_ids,
1308 - decoder_attention_mask=decoder_attention_mask,
1309 - encoder_outputs=encoder_outputs,
1310 - use_cache=use_cache,
1311 - output_attentions=output_attentions,
1312 - output_hidden_states=output_hidden_states,
1313 - return_dict=return_dict,
1314 - )
1315 - x = outputs[0] # last hidden state
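# Pool by taking the decoder hidden state at the last <eos> token of each
# sequence as the sentence representation.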
1316 - eos_mask = input_ids.eq(self.config.eos_token_id)
1317 - if len(torch.unique(eos_mask.sum(1))) > 1:
1318 - raise ValueError("All examples must have the same number of <eos> tokens.")
1319 - sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[
1320 - :, -1, :
1321 - ]
1322 - logits = self.classification_head(sentence_representation)
1323 -
1324 - loss = None
1325 - if labels is not None:
1326 - loss_fct = CrossEntropyLoss()
1327 - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1328 -
1329 - if not return_dict:
1330 - output = (logits,) + outputs[1:]
1331 - return ((loss,) + output) if loss is not None else output
1332 -
1333 - return Seq2SeqSequenceClassifierOutput(
1334 - loss=loss,
1335 - logits=logits,
1336 - past_key_values=outputs.past_key_values,
1337 - decoder_hidden_states=outputs.decoder_hidden_states,
1338 - decoder_attentions=outputs.decoder_attentions,
1339 - encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1340 - encoder_hidden_states=outputs.encoder_hidden_states,
1341 - encoder_attentions=outputs.encoder_attentions,
1342 - )
1343 -
1344 -
1345 -@add_start_docstrings(
1346 - """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
1347 - the hidden-states output to compute `span start logits` and `span end logits`). """,
1348 - BART_START_DOCSTRING,
1349 -)
1350 -class BartForQuestionAnswering(PretrainedBartModel):
1351 - def __init__(self, config):
1352 - super().__init__(config)
1353 -
1354 - config.num_labels = 2
1355 - self.num_labels = config.num_labels
1356 -
1357 - self.model = BartModel(config)
1358 - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1359 -
1360 - self.model._init_weights(self.qa_outputs)
1361 -
1362 - @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
1363 - @add_code_sample_docstrings(
1364 - tokenizer_class=_TOKENIZER_FOR_DOC,
1365 - checkpoint="facebook/bart-large",
1366 - output_type=Seq2SeqQuestionAnsweringModelOutput,
1367 - config_class=_CONFIG_FOR_DOC,
1368 - )
1369 - def forward(
1370 - self,
1371 - input_ids,
1372 - attention_mask=None,
1373 - encoder_outputs=None,
1374 - decoder_input_ids=None,
1375 - decoder_attention_mask=None,
1376 - start_positions=None,
1377 - end_positions=None,
1378 - use_cache=None,
1379 - output_attentions=None,
1380 - output_hidden_states=None,
1381 - return_dict=None,
1382 - ):
1383 - r"""
1384 - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1385 - Labels for position (index) of the start of the labelled span for computing the token classification loss.
1386 - Positions are clamped to the length of the sequence (`sequence_length`).
1387 - Positions outside of the sequence are not taken into account for computing the loss.
1388 - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1389 - Labels for position (index) of the end of the labelled span for computing the token classification loss.
1390 - Positions are clamped to the length of the sequence (`sequence_length`).
1391 - Positions outside of the sequence are not taken into account for computing the loss.
1392 - """
1393 - return_dict = (
1394 - return_dict if return_dict is not None else self.config.use_return_dict
1395 - )
1396 - if start_positions is not None and end_positions is not None:
1397 - use_cache = False
1398 -
1399 - outputs = self.model(
1400 - input_ids,
1401 - attention_mask=attention_mask,
1402 - decoder_input_ids=decoder_input_ids,
1403 - decoder_attention_mask=decoder_attention_mask,
1404 - encoder_outputs=encoder_outputs,
1405 - use_cache=use_cache,
1406 - output_attentions=output_attentions,
1407 - output_hidden_states=output_hidden_states,
1408 - return_dict=return_dict,
1409 - )
1410 -
1411 - sequence_output = outputs[0]
1412 -
1413 - logits = self.qa_outputs(sequence_output)
1414 - start_logits, end_logits = logits.split(1, dim=-1)
1415 - start_logits = start_logits.squeeze(-1)
1416 - end_logits = end_logits.squeeze(-1)
1417 -
1418 - total_loss = None
1419 - if start_positions is not None and end_positions is not None:
1420 - # If we are on multi-GPU, splitting adds a dimension
1421 - if len(start_positions.size()) > 1:
1422 - start_positions = start_positions.squeeze(-1)
1423 - if len(end_positions.size()) > 1:
1424 - end_positions = end_positions.squeeze(-1)
1425 - # sometimes the start/end positions are outside our model inputs, we ignore these terms
1426 - ignored_index = start_logits.size(1)
1427 - start_positions.clamp_(0, ignored_index)
1428 - end_positions.clamp_(0, ignored_index)
1429 -
1430 - loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1431 - start_loss = loss_fct(start_logits, start_positions)
1432 - end_loss = loss_fct(end_logits, end_positions)
1433 - total_loss = (start_loss + end_loss) / 2
1434 -
1435 - if not return_dict:
1436 - output = (start_logits, end_logits,) + outputs[1:]
1437 - return ((total_loss,) + output) if total_loss is not None else output
1438 -
1439 - return Seq2SeqQuestionAnsweringModelOutput(
1440 - loss=total_loss,
1441 - start_logits=start_logits,
1442 - end_logits=end_logits,
1443 - past_key_values=outputs.past_key_values,
1444 - decoder_hidden_states=outputs.decoder_hidden_states,
1445 - decoder_attentions=outputs.decoder_attentions,
1446 - encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1447 - encoder_hidden_states=outputs.encoder_hidden_states,
1448 - encoder_attentions=outputs.encoder_attentions,
1449 - )
1450 -
1451 -
1452 -class SinusoidalPositionalEmbedding(nn.Embedding):
1453 - """This module produces sinusoidal positional embeddings of any length."""
1454 -
1455 - def __init__(self, num_positions, embedding_dim, padding_idx=None):
1456 - super().__init__(num_positions, embedding_dim)
1457 - if embedding_dim % 2 != 0:
1458 - raise NotImplementedError(
1459 - f"odd embedding_dim {embedding_dim} not supported"
1460 - )
1461 - self.weight = self._init_weight(self.weight)
1462 -
1463 - @staticmethod
1464 - def _init_weight(out: nn.Parameter):
1465 - """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
1466 - The cos features are in the 2nd half of the vector. [dim // 2:]
1467 - """
1468 - n_pos, dim = out.shape
1469 - position_enc = np.array(
1470 - [
1471 - [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
1472 - for pos in range(n_pos)
1473 - ]
1474 - )
1475 - out[:, 0 : dim // 2] = torch.FloatTensor(
1476 - np.sin(position_enc[:, 0::2])
1477 - ) # This line breaks for odd n_pos
1478 - out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1479 - out.detach_()
1480 - out.requires_grad = False
1481 - return out
1482 -
1483 - @torch.no_grad()
1484 - def forward(self, input_ids, use_cache=False):
1485 - """Input is expected to be of size [bsz x seqlen]."""
1486 - bsz, seq_len = input_ids.shape[:2]
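# When a cache is used, only the embedding for the newest position is needed.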
1487 - if use_cache:
1488 - positions = input_ids.data.new(1, 1).fill_(
1489 - seq_len - 1
1490 - ) # called before slicing
1491 - else:
1492 - # positions start at 0 and end at seq_len - 1
1493 - positions = torch.arange(
1494 - seq_len, dtype=torch.long, device=self.weight.device
1495 - )
1496 - return super().forward(positions)
1 -# coding=utf-8
2 -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3 -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 -#
5 -# Licensed under the Apache License, Version 2.0 (the "License");
6 -# you may not use this file except in compliance with the License.
7 -# You may obtain a copy of the License at
8 -#
9 -# http://www.apache.org/licenses/LICENSE-2.0
10 -#
11 -# Unless required by applicable law or agreed to in writing, software
12 -# distributed under the License is distributed on an "AS IS" BASIS,
13 -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 -# See the License for the specific language governing permissions and
15 -# limitations under the License.
16 -
17 -import inspect
18 -import os
19 -import re
20 -from dataclasses import dataclass
21 -from typing import Callable, Dict, List, Optional, Set, Tuple, Union
22 -
23 -import torch
24 -from torch import Tensor, device, dtype, nn
25 -from torch.nn import CrossEntropyLoss
26 -from torch.nn import functional as F
27 -
28 -from transformers.activations import get_activation
29 -from transformers.configuration_utils import PretrainedConfig
30 -from transformers.file_utils import (
31 - DUMMY_INPUTS,
32 - TF2_WEIGHTS_NAME,
33 - TF_WEIGHTS_NAME,
34 - WEIGHTS_NAME,
35 - ModelOutput,
36 - cached_path,
37 - hf_bucket_url,
38 - is_remote_url,
39 - is_torch_tpu_available,
40 - replace_return_docstrings,
41 -)
42 -from train.generation_utils import GenerationMixin
43 -import logging
44 -
45 -logger = logging.getLogger(__name__) # pylint: disable=invalid-name
46 -logging.basicConfig(
47 - format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
48 - datefmt="%m/%d/%Y %H:%M:%S",
49 - level=logging.INFO,
50 -)
51 -
52 -
53 -try:
54 - from torch.nn import Identity
55 -except ImportError:
56 - # Older PyTorch compatibility
57 - class Identity(nn.Module):
58 - r"""A placeholder identity operator that is argument-insensitive."""
59 -
60 - def __init__(self, *args, **kwargs):
61 - super().__init__()
62 -
63 - def forward(self, input):
64 - return input
65 -
66 -
67 -def find_pruneable_heads_and_indices(
68 - heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
69 -) -> Tuple[Set[int], torch.LongTensor]:
70 - """
71 - Finds the heads and their indices taking :obj:`already_pruned_heads` into account.
72 -
73 - Args:
74 - heads (:obj:`List[int]`): List of the indices of heads to prune.
75 - n_heads (:obj:`int`): The number of heads in the model.
76 - head_size (:obj:`int`): The size of each head.
77 - already_pruned_heads (:obj:`Set[int]`): A set of already pruned heads.
78 -
79 - Returns:
80 - :obj:`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
81 - """
82 - mask = torch.ones(n_heads, head_size)
83 - heads = (
84 - set(heads) - already_pruned_heads
85 - ) # Convert to set and remove already pruned heads
86 - for head in heads:
87 - # Compute how many pruned heads are before the head and move the index accordingly
88 - head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
89 - mask[head] = 0
90 - mask = mask.view(-1).contiguous().eq(1)
91 - index: torch.LongTensor = torch.arange(len(mask))[mask].long()
92 - return heads, index
93 -
94 -
95 -class ModuleUtilsMixin:
96 - """
97 - A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
98 - """
99 -
100 - def num_parameters(self, only_trainable: bool = False) -> int:
101 - """
102 - Get the number of (optionally, trainable) parameters in the model.
103 -
104 - Args:
105 - only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
106 - Whether or not to return only the number of trainable parameters
107 -
108 - Returns:
109 - :obj:`int`: The number of parameters.
110 - """
111 - params = (
112 - filter(lambda x: x.requires_grad, self.parameters())
113 - if only_trainable
114 - else self.parameters()
115 - )
116 - return sum(p.numel() for p in params)
117 -
118 - @staticmethod
119 - def _hook_rss_memory_pre_forward(module, *args, **kwargs):
120 - try:
121 - import psutil
122 - except ImportError:
123 - raise ImportError(
124 - "You need to install psutil (pip install psutil) to use memory tracing."
125 - )
126 -
127 - process = psutil.Process(os.getpid())
128 - mem = process.memory_info()
129 - module.mem_rss_pre_forward = mem.rss
130 - return None
131 -
132 - @staticmethod
133 - def _hook_rss_memory_post_forward(module, *args, **kwargs):
134 - try:
135 - import psutil
136 - except ImportError:
137 - raise ImportError(
138 - "You need to install psutil (pip install psutil) to use memory tracing."
139 - )
140 -
141 - process = psutil.Process(os.getpid())
142 - mem = process.memory_info()
143 - module.mem_rss_post_forward = mem.rss
144 - mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
145 - module.mem_rss_diff = mem_rss_diff + (
146 - module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0
147 - )
148 - return None
149 -
150 - def add_memory_hooks(self):
151 - """
152 - Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
153 -
154 - Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
155 - zero with :obj:`model.reset_memory_hooks_state()`.
156 - """
157 - for module in self.modules():
158 - module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
159 - module.register_forward_hook(self._hook_rss_memory_post_forward)
160 - self.reset_memory_hooks_state()
161 -
162 - def reset_memory_hooks_state(self):
163 - """
164 - Reset the :obj:`mem_rss_diff` attribute of each module (see
165 - :func:`~transformers.modeling_utils.ModuleUtilsMixin.add_memory_hooks`).
166 - """
167 - for module in self.modules():
168 - module.mem_rss_diff = 0
169 - module.mem_rss_post_forward = 0
170 - module.mem_rss_pre_forward = 0
171 -
172 - @property
173 - def device(self) -> device:
174 - """
175 - :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
176 - device).
177 - """
178 - try:
179 - return next(self.parameters()).device
180 - except StopIteration:
181 - # For nn.DataParallel compatibility in PyTorch 1.5
182 -
183 - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
184 - tuples = [
185 - (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
186 - ]
187 - return tuples
188 -
189 - gen = self._named_members(get_members_fn=find_tensor_attributes)
190 - first_tuple = next(gen)
191 - return first_tuple[1].device
192 -
193 - @property
194 - def dtype(self) -> dtype:
195 - """
196 - :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
197 - """
198 - try:
199 - return next(self.parameters()).dtype
200 - except StopIteration:
201 - # For nn.DataParallel compatibility in PyTorch 1.5
202 -
203 - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
204 - tuples = [
205 - (k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)
206 - ]
207 - return tuples
208 -
209 - gen = self._named_members(get_members_fn=find_tensor_attributes)
210 - first_tuple = next(gen)
211 - return first_tuple[1].dtype
212 -
213 - def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
214 - """
215 - Invert an attention mask (e.g., switches 0. and 1.).
216 -
217 - Args:
218 - encoder_attention_mask (:obj:`torch.Tensor`): An attention mask.
219 -
220 - Returns:
221 - :obj:`torch.Tensor`: The inverted attention mask.
222 - """
223 - if encoder_attention_mask.dim() == 3:
224 - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
225 - if encoder_attention_mask.dim() == 2:
226 - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
227 - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
228 - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
229 - # /transformer/transformer_layers.py#L270
230 - # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
231 - # encoder_extended_attention_mask.transpose(-1, -2))
232 - encoder_extended_attention_mask = encoder_extended_attention_mask.to(
233 - dtype=self.dtype
234 - ) # fp16 compatibility
235 -
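# Use a large negative additive value so masked positions receive ~zero
# attention weight after the softmax; the magnitude depends on the dtype.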
236 - if self.dtype == torch.float16:
237 - encoder_extended_attention_mask = (
238 - 1.0 - encoder_extended_attention_mask
239 - ) * -1e4
240 - elif self.dtype == torch.float32:
241 - encoder_extended_attention_mask = (
242 - 1.0 - encoder_extended_attention_mask
243 - ) * -1e9
244 - else:
245 - raise ValueError(
246 - "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
247 - self.dtype
248 - )
249 - )
250 -
251 - return encoder_extended_attention_mask
252 -
253 - def get_extended_attention_mask(
254 - self, attention_mask: Tensor, input_shape: Tuple[int], device: device
255 - ) -> Tensor:
256 - """
257 - Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
258 -
259 - Arguments:
260 - attention_mask (:obj:`torch.Tensor`):
261 - Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
262 - input_shape (:obj:`Tuple[int]`):
263 - The shape of the input to the model.
264 - device: (:obj:`torch.device`):
265 - The device of the input to the model.
266 -
267 - Returns:
268 - :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
269 - """
270 - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
271 - # ourselves in which case we just need to make it broadcastable to all heads.
272 - if attention_mask.dim() == 3:
273 - extended_attention_mask = attention_mask[:, None, :, :]
274 - elif attention_mask.dim() == 2:
275 - # Provided a padding mask of dimensions [batch_size, seq_length]
276 - # - if the model is a decoder, apply a causal mask in addition to the padding mask
277 - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
278 - if self.config.is_decoder:
279 - batch_size, seq_length = input_shape
280 - seq_ids = torch.arange(seq_length, device=device)
281 - causal_mask = (
282 - seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
283 - <= seq_ids[None, :, None]
284 - )
285 - # causal and attention masks must have same type with pytorch version < 1.3
286 - causal_mask = causal_mask.to(attention_mask.dtype)
287 - extended_attention_mask = (
288 - causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
289 - )
290 - else:
291 - extended_attention_mask = attention_mask[:, None, None, :]
292 - else:
293 - raise ValueError(
294 - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
295 - input_shape, attention_mask.shape
296 - )
297 - )
298 -
299 - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
300 - # masked positions, this operation will create a tensor which is 0.0 for
301 - # positions we want to attend and -10000.0 for masked positions.
302 - # Since we are adding it to the raw scores before the softmax, this is
303 - # effectively the same as removing these entirely.
304 - extended_attention_mask = extended_attention_mask.to(
305 - dtype=self.dtype
306 - ) # fp16 compatibility
307 - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
308 - return extended_attention_mask
309 -
310 - def get_head_mask(
311 - self,
312 - head_mask: Optional[Tensor],
313 - num_hidden_layers: int,
314 - is_attention_chunked: bool = False,
315 - ) -> Tensor:
316 - """
317 - Prepare the head mask if needed.
318 -
319 - Args:
320 - head_mask (:obj:`torch.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
321 - The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
322 - num_hidden_layers (:obj:`int`):
323 - The number of hidden layers in the model.
324 - is_attention_chunked (:obj:`bool`, `optional`, defaults to :obj:`False`):
325 - Whether or not the attention scores are computed by chunks.
326 -
327 - Returns:
328 - :obj:`torch.Tensor` with shape :obj:`[num_hidden_layers x batch x num_heads x seq_length x seq_length]`
329 - or list with :obj:`[None]` for each layer.
330 - """
331 - if head_mask is not None:
332 - head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
333 - if is_attention_chunked is True:
334 - head_mask = head_mask.unsqueeze(-1)
335 - else:
336 - head_mask = [None] * num_hidden_layers
337 -
338 - return head_mask
339 -
340 - def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
341 - """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
342 - if head_mask.dim() == 1:
343 - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
344 - head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
345 - elif head_mask.dim() == 2:
346 - head_mask = (
347 - head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
348 - ) # We can specify head_mask for each layer
349 - assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
350 - head_mask = head_mask.to(
351 - dtype=self.dtype
352 - ) # switch to float if needed + fp16 compatibility
353 - return head_mask
354 -
355 -
356 -class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
357 - r"""
358 - Base class for all models.
359 -
360 - :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
361 - for loading, downloading and saving models as well as a few methods common to all models to:
362 -
363 - * resize the input embeddings,
364 - * prune heads in the self-attention heads.
365 -
366 - Class attributes (overridden by derived classes):
367 - - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
368 - :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
369 - - **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
370 - PyTorch model, taking as arguments:
371 -
372 - - **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
373 - TensorFlow checkpoint.
374 - - **config** (:class:`~transformers.PretrainedConfig`) -- An instance of the configuration associated
375 - with the model.
376 - - **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.
377 -
378 - - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
379 - derived classes of the same architecture adding modules on top of the base model.
380 - - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of regular-expression patterns of tensor names to ignore
381 - when loading the model (to avoid unnecessary warnings).
382 - """
383 - config_class = None
384 - base_model_prefix = ""
385 - authorized_missing_keys = None
386 -
387 - @property
388 - def dummy_inputs(self) -> Dict[str, torch.Tensor]:
389 - """
390 - :obj:`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
391 - """
392 - return {"input_ids": torch.tensor(DUMMY_INPUTS)}
393 -
394 - def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
395 - super().__init__()
396 - if not isinstance(config, PretrainedConfig):
397 - raise ValueError(
398 - "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
399 - "To create a model from a pretrained model use "
400 - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
401 - self.__class__.__name__, self.__class__.__name__
402 - )
403 - )
404 - # Save config in model
405 - self.config = config
406 -
407 - @property
408 - def base_model(self) -> nn.Module:
409 - """
410 - :obj:`torch.nn.Module`: The main body of the model.
411 - """
412 - return getattr(self, self.base_model_prefix, self)
413 -
414 - def get_input_embeddings(self) -> nn.Module:
415 - """
416 - Returns the model's input embeddings.
417 -
418 - Returns:
419 - :obj:`nn.Module`: A torch module mapping vocabulary to hidden states.
420 - """
421 - base_model = getattr(self, self.base_model_prefix, self)
422 - if base_model is not self:
423 - return base_model.get_input_embeddings()
424 - else:
425 - raise NotImplementedError
426 -
427 - def set_input_embeddings(self, value: nn.Module):
428 - """
429 - Set model's input embeddings
430 -
431 - Args:
432 - value (:obj:`nn.Module`): A module mapping vocabulary to hidden states.
433 - """
434 - base_model = getattr(self, self.base_model_prefix, self)
435 - if base_model is not self:
436 - base_model.set_input_embeddings(value)
437 - else:
438 - raise NotImplementedError
439 -
440 - def get_output_embeddings(self) -> nn.Module:
441 - """
442 - Returns the model's output embeddings.
443 -
444 - Returns:
445 - :obj:`nn.Module`: A torch module mapping hidden states to vocabulary.
446 - """
447 - return None # Overwrite for models with output embeddings
448 -
449 - def tie_weights(self):
450 - """
451 - Tie the weights between the input embeddings and the output embeddings.
452 -
453 - If the :obj:`torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing, so we clone
454 - the weights instead.
455 - """
456 - output_embeddings = self.get_output_embeddings()
457 - if output_embeddings is not None and self.config.tie_word_embeddings:
458 - self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
459 -
460 - if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
461 - self._tie_encoder_decoder_weights(
462 - self.encoder, self.decoder, self.base_model_prefix
463 - )
464 -
465 - @staticmethod
466 - def _tie_encoder_decoder_weights(
467 - encoder: nn.Module, decoder: nn.Module, base_model_prefix: str
468 - ):
469 - uninitialized_encoder_weights: List[str] = []
470 - assert (
471 - decoder.__class__ == encoder.__class__
472 - ), f"{decoder.__class__} and {encoder.__class__} have to be equal."
473 -
474 - def tie_encoder_to_decoder_recursively(
475 - decoder_pointer: nn.Module,
476 - encoder_pointer: nn.Module,
477 - module_name: str,
478 - uninitialized_encoder_weights: List[str],
479 - depth=0,
480 - ):
481 - assert isinstance(decoder_pointer, nn.Module) and isinstance(
482 - encoder_pointer, nn.Module
483 - ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
484 - if hasattr(decoder_pointer, "weight"):
485 - assert hasattr(encoder_pointer, "weight")
486 - encoder_pointer.weight = decoder_pointer.weight
487 - if hasattr(decoder_pointer, "bias"):
488 - assert hasattr(encoder_pointer, "bias")
489 - encoder_pointer.bias = decoder_pointer.bias
490 - return
491 -
492 - encoder_modules = encoder_pointer._modules
493 - decoder_modules = decoder_pointer._modules
494 - if len(decoder_modules) > 0:
495 - assert (
496 - len(encoder_modules) > 0
497 - ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
498 -
499 - all_encoder_weights = set(
500 - [
501 - module_name + "/" + sub_name
502 - for sub_name in encoder_modules.keys()
503 - ]
504 - )
505 - encoder_layer_pos = 0
506 - for name, module in decoder_modules.items():
507 - if name.isdigit():
508 - encoder_name = str(int(name) + encoder_layer_pos)
509 - decoder_name = name
510 - if not isinstance(
511 - decoder_modules[decoder_name],
512 - type(encoder_modules[encoder_name]),
513 - ):
514 - # this can happen if the name corresponds to the position in a module list of layers
515 - # in this case the decoder has added a cross-attention layer that the encoder does not have
516 - # thus skip this step and subtract one layer position from the encoder
517 - encoder_layer_pos -= 1
518 - continue
519 - elif name not in encoder_modules:
520 - continue
521 - elif depth > 500:
522 - raise ValueError(
523 - "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
524 - )
525 - else:
526 - decoder_name = encoder_name = name
527 - tie_encoder_to_decoder_recursively(
528 - decoder_modules[decoder_name],
529 - encoder_modules[encoder_name],
530 - module_name + "/" + name,
531 - uninitialized_encoder_weights,
532 - depth=depth + 1,
533 - )
534 - all_encoder_weights.remove(module_name + "/" + encoder_name)
535 -
536 - uninitialized_encoder_weights += list(all_encoder_weights)
537 -
538 - # tie weights recursively
539 - tie_encoder_to_decoder_recursively(
540 - decoder, encoder, base_model_prefix, uninitialized_encoder_weights
541 - )
542 - if len(uninitialized_encoder_weights) > 0:
543 - logger.warning(
544 - f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
545 - )
546 -
547 - def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
548 - """Tie or clone module weights depending on whether we are using TorchScript or not"""
549 - if self.config.torchscript:
550 - output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
551 - else:
552 - output_embeddings.weight = input_embeddings.weight
553 -
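# If the output layer has a bias, zero-pad it so its length matches the
# (possibly resized) tied weight matrix.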
554 - if getattr(output_embeddings, "bias", None) is not None:
555 - output_embeddings.bias.data = torch.nn.functional.pad(
556 - output_embeddings.bias.data,
557 - (
558 - 0,
559 - output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
560 - ),
561 - "constant",
562 - 0,
563 - )
564 - if hasattr(output_embeddings, "out_features") and hasattr(
565 - input_embeddings, "num_embeddings"
566 - ):
567 - output_embeddings.out_features = input_embeddings.num_embeddings
568 -
569 - def resize_token_embeddings(
570 - self, new_num_tokens: Optional[int] = None
571 - ) -> torch.nn.Embedding:
572 - """
573 - Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
574 -
575 - Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
576 -
577 - Arguments:
578 - new_num_tokens (:obj:`int`, `optional`):
579 - The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
580 - vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
581 - just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model without doing
582 - anything.
583 -
584 - Return:
585 - :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
586 - """
587 - base_model = getattr(
588 - self, self.base_model_prefix, self
589 - ) # get the base model if needed
590 - model_embeds = base_model._resize_token_embeddings(new_num_tokens)
591 - if new_num_tokens is None:
592 - return model_embeds
593 -
594 - # Update base model and current model config
595 - self.config.vocab_size = new_num_tokens
596 - base_model.vocab_size = new_num_tokens
597 -
598 - # Tie weights again if needed
599 - self.tie_weights()
600 -
601 - return model_embeds
602 -
603 - def _resize_token_embeddings(self, new_num_tokens):
604 - old_embeddings = self.get_input_embeddings()
605 - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
606 - self.set_input_embeddings(new_embeddings)
607 - return self.get_input_embeddings()
608 -
609 - def _get_resized_embeddings(
610 - self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
611 - ) -> torch.nn.Embedding:
612 - """
613 - Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
614 - initialized vectors at the end. Reducing the size will remove vectors from the end
615 -
616 - Args:
617 - old_embeddings (:obj:`torch.nn.Embedding`):
618 - Old embeddings to be resized.
619 - new_num_tokens (:obj:`int`, `optional`):
620 - New number of tokens in the embedding matrix.
621 -
622 - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
623 - vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
624 - :obj:`torch.nn.Embedding` module of the model without doing anything.
625 -
626 - Return:
627 - :obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
628 - :obj:`new_num_tokens` is :obj:`None`
629 - """
630 - if new_num_tokens is None:
631 - return old_embeddings
632 -
633 - old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
634 - if old_num_tokens == new_num_tokens:
635 - return old_embeddings
636 -
637 - # Build new embeddings
638 - new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
639 - new_embeddings.to(old_embeddings.weight.device)
640 -
641 - # initialize all new embeddings (in particular added tokens)
642 - self._init_weights(new_embeddings)
643 -
644 - # Copy token embeddings from the previous weights
645 - num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
646 - new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[
647 - :num_tokens_to_copy, :
648 - ]
649 -
650 - return new_embeddings
651 -
652 - def init_weights(self):
653 - """
654 - Initializes and prunes weights if needed.
655 - """
656 - # Initialize weights
657 - self.apply(self._init_weights)
658 -
659 - # Prune heads if needed
660 - if self.config.pruned_heads:
661 - self.prune_heads(self.config.pruned_heads)
662 -
663 - # Tie weights if needed
664 - self.tie_weights()
665 -
666 - def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
667 - """
668 - Prunes heads of the base model.
669 -
670 - Arguments:
671 - heads_to_prune (:obj:`Dict[int, List[int]]`):
672 - Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
673 - of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
674 - prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
675 - """
676 - # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
677 - for layer, heads in heads_to_prune.items():
678 - union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
679 - self.config.pruned_heads[layer] = list(
680 - union_heads
681 - ) # Unfortunately we have to store it as list for JSON
682 -
683 - self.base_model._prune_heads(heads_to_prune)
684 -
685 - def save_pretrained(self, save_directory):
686 - """
687 - Save a model and its configuration file to a directory, so that it can be re-loaded using the
688 - :func:`~transformers.PreTrainedModel.from_pretrained` class method.
689 -
690 - Arguments:
691 - save_directory (:obj:`str`):
692 - Directory to which to save. Will be created if it doesn't exist.
693 - """
694 - if os.path.isfile(save_directory):
695 - logger.error(
696 - "Provided path ({}) should be a directory, not a file".format(
697 - save_directory
698 - )
699 - )
700 - return
701 - os.makedirs(save_directory, exist_ok=True)
702 -
703 - # Only save the model itself if we are using distributed training
704 - model_to_save = self.module if hasattr(self, "module") else self
705 -
706 - # Attach architecture to the config
707 - model_to_save.config.architectures = [model_to_save.__class__.__name__]
708 -
709 - # If we save using the predefined names, we can load using `from_pretrained`
710 - output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
711 -
712 - if getattr(self.config, "xla_device", False):
713 - import torch_xla.core.xla_model as xm
714 -
715 - if xm.is_master_ordinal():
716 - # Save configuration file
717 - model_to_save.config.save_pretrained(save_directory)
718 - # xm.save takes care of saving only from master
719 - xm.save(model_to_save.state_dict(), output_model_file)
720 - else:
721 - model_to_save.config.save_pretrained(save_directory)
722 - torch.save(model_to_save.state_dict(), output_model_file)
723 -
724 - logger.info("Model weights saved in {}".format(output_model_file))
725 -
726 - @classmethod
727 - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
728 - r"""
729 - Instantiate a pretrained pytorch model from a pre-trained model configuration.
730 -
731 - The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
732 - To train the model, you should first set it back in training mode with ``model.train()``.
733 -
734 - The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
735 - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
736 - task.
737 -
738 - The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
739 - weights are discarded.
740 -
741 - Parameters:
742 - pretrained_model_name_or_path (:obj:`str`, `optional`):
743 - Can be either:
744 -
745 - - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
746 - ``bert-base-uncased``.
747 - - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
748 - ``dbmdz/bert-base-german-cased``.
749 - - A path to a `directory` containing model weights saved using
750 - :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
751 - - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
752 - this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
753 - as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
754 - a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
755 - - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
756 - arguments ``config`` and ``state_dict``).
757 - model_args (sequence of positional arguments, `optional`):
758 - All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
759 - config (:obj:`Union[PretrainedConfig, str]`, `optional`):
760 - Can be either:
761 -
762 - - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
763 - - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
764 -
765 - Configuration for the model to use instead of an automatically loaded configuration. Configuration can
766 - be automatically loaded when:
767 -
768 - - The model is a model provided by the library (loaded with the `shortcut name` string of a
769 - pretrained model).
770 - - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
771 - by supplying the save directory.
772 - - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
773 - configuration JSON file named `config.json` is found in the directory.
774 - state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
775 - A state dictionary to use instead of a state dictionary loaded from saved weights file.
776 -
777 - This option can be used if you want to create a model from a pretrained configuration but load your own
778 - weights. In this case though, you should check if using
779 - :func:`~transformers.PreTrainedModel.save_pretrained` and
780 - :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
781 - cache_dir (:obj:`str`, `optional`):
782 - Path to a directory in which a downloaded pretrained model configuration should be cached if the
783 - standard cache should not be used.
784 - from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
785 - Load the model weights from a TensorFlow checkpoint save file (see docstring of
786 - ``pretrained_model_name_or_path`` argument).
787 - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
788 - Whether or not to force the (re-)download of the model weights and configuration files, overriding the
789 - cached versions if they exist.
790 - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
791 - Whether or not to delete incompletely received files. Will attempt to resume the download if such a
792 - file exists.
793 - proxies (:obj:`Dict[str, str]`, `optional`):
794 - A dictionary of proxy servers to use by protocol or endpoint, e.g.,
795 - :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
796 - request.
797 - output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
798 - Whether or not to also return a dictionary containing missing keys, unexpected keys and error
799 - messages.
800 - local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
801 - Whether or not to only look at local files (e.g., not try downloading the model).
802 - use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
803 - Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
804 - our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
805 - kwargs (remaining dictionary of keyword arguments, `optional`):
806 - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
807 - :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
808 - automatically loaded:
809 -
810 - - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
811 - underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
812 - already been done)
813 - - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
814 - initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
815 - ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
816 - with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
817 - attribute will be passed to the underlying model's ``__init__`` function.
818 -
819 - Examples::
820 -
821 - from transformers import BertConfig, BertModel
822 - # Download model and configuration from S3 and cache.
823 - model = BertModel.from_pretrained('bert-base-uncased')
824 - # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
825 - model = BertModel.from_pretrained('./test/saved_model/')
826 - # Update configuration during loading.
827 - model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
828 - assert model.config.output_attention == True
829 - # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
830 - config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
831 - model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
832 - """
833 - config = kwargs.pop("config", None)
834 - state_dict = kwargs.pop("state_dict", None)
835 - cache_dir = kwargs.pop("cache_dir", None)
836 - from_tf = kwargs.pop("from_tf", False)
837 - force_download = kwargs.pop("force_download", False)
838 - resume_download = kwargs.pop("resume_download", False)
839 - proxies = kwargs.pop("proxies", None)
840 - output_loading_info = kwargs.pop("output_loading_info", False)
841 - local_files_only = kwargs.pop("local_files_only", False)
842 - use_cdn = kwargs.pop("use_cdn", True)
843 -
844 - # Load config if we don't provide a configuration
845 - if not isinstance(config, PretrainedConfig):
846 - config_path = (
847 - config if config is not None else pretrained_model_name_or_path
848 - )
849 - config, model_kwargs = cls.config_class.from_pretrained(
850 - config_path,
851 - *model_args,
852 - cache_dir=cache_dir,
853 - return_unused_kwargs=True,
854 - force_download=force_download,
855 - resume_download=resume_download,
856 - proxies=proxies,
857 - local_files_only=local_files_only,
858 - **kwargs,
859 - )
860 - else:
861 - model_kwargs = kwargs
862 -
863 - # Load model
864 - if pretrained_model_name_or_path is not None:
865 - if os.path.isdir(pretrained_model_name_or_path):
866 - if from_tf and os.path.isfile(
867 - os.path.join(
868 - pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
869 - )
870 - ):
871 - # Load from a TF 1.0 checkpoint
872 - archive_file = os.path.join(
873 - pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index"
874 - )
875 - elif from_tf and os.path.isfile(
876 - os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
877 - ):
878 - # Load from a TF 2.0 checkpoint
879 - archive_file = os.path.join(
880 - pretrained_model_name_or_path, TF2_WEIGHTS_NAME
881 - )
882 - elif os.path.isfile(
883 - os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
884 - ):
885 - # Load from a PyTorch checkpoint
886 - archive_file = os.path.join(
887 - pretrained_model_name_or_path, WEIGHTS_NAME
888 - )
889 - else:
890 - raise EnvironmentError(
891 - "Error no file named {} found in directory {} or `from_tf` set to False".format(
892 - [
893 - WEIGHTS_NAME,
894 - TF2_WEIGHTS_NAME,
895 - TF_WEIGHTS_NAME + ".index",
896 - ],
897 - pretrained_model_name_or_path,
898 - )
899 - )
900 - elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
901 - pretrained_model_name_or_path
902 - ):
903 - archive_file = pretrained_model_name_or_path
904 - elif os.path.isfile(pretrained_model_name_or_path + ".index"):
905 - assert (
906 - from_tf
907 - ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
908 - pretrained_model_name_or_path + ".index"
909 - )
910 - archive_file = pretrained_model_name_or_path + ".index"
911 - else:
912 - archive_file = hf_bucket_url(
913 - pretrained_model_name_or_path,
914 - filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
915 - use_cdn=use_cdn,
916 - )
917 -
918 - try:
919 - # Load from URL or cache if already cached
920 - resolved_archive_file = cached_path(
921 - archive_file,
922 - cache_dir=cache_dir,
923 - force_download=force_download,
924 - proxies=proxies,
925 - resume_download=resume_download,
926 - local_files_only=local_files_only,
927 - )
928 - if resolved_archive_file is None:
929 - raise EnvironmentError
930 - except EnvironmentError:
931 - msg = (
932 - f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
933 - f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
934 - f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
935 - )
936 - raise EnvironmentError(msg)
937 -
938 - if resolved_archive_file == archive_file:
939 - logger.info("loading weights file {}".format(archive_file))
940 - else:
941 - logger.info(
942 - "loading weights file {} from cache at {}".format(
943 - archive_file, resolved_archive_file
944 - )
945 - )
946 - else:
947 - resolved_archive_file = None
948 -
949 - # Instantiate model.
950 - model = cls(config, *model_args, **model_kwargs)
951 -
952 - if state_dict is None and not from_tf:
953 - try:
954 - state_dict = torch.load(resolved_archive_file, map_location="cpu")
955 - except Exception:
956 - raise OSError(
957 - "Unable to load weights from pytorch checkpoint file. "
958 - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
959 - )
960 -
961 - missing_keys = []
962 - unexpected_keys = []
963 - error_msgs = []
964 -
965 - if from_tf:
966 - if resolved_archive_file.endswith(".index"):
967 - # Load from a TensorFlow 1.X checkpoint - provided by original authors
968 - model = cls.load_tf_weights(
969 - model, config, resolved_archive_file[:-6]
970 - ) # Remove the '.index'
971 - else:
972 - # Load from our TensorFlow 2.0 checkpoints
973 - try:
974 - from transformers import load_tf2_checkpoint_in_pytorch_model
975 -
976 - model = load_tf2_checkpoint_in_pytorch_model(
977 - model, resolved_archive_file, allow_missing_keys=True
978 - )
979 - except ImportError:
980 - logger.error(
981 - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
982 - "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
983 - )
984 - raise
985 - else:
986 - # Convert old format to new format if needed from a PyTorch state_dict
987 - old_keys = []
988 - new_keys = []
989 - for key in state_dict.keys():
990 - new_key = None
991 - if "gamma" in key:
992 - new_key = key.replace("gamma", "weight")
993 - if "beta" in key:
994 - new_key = key.replace("beta", "bias")
995 - if new_key:
996 - old_keys.append(key)
997 - new_keys.append(new_key)
998 - for old_key, new_key in zip(old_keys, new_keys):
999 - state_dict[new_key] = state_dict.pop(old_key)
1000 -
1001 - # copy state_dict so _load_from_state_dict can modify it
1002 - metadata = getattr(state_dict, "_metadata", None)
1003 - state_dict = state_dict.copy()
1004 - if metadata is not None:
1005 - state_dict._metadata = metadata
1006 -
1007 - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
1008 - # so we need to apply the function recursively.
1009 - def load(module: nn.Module, prefix=""):
1010 - local_metadata = (
1011 - {} if metadata is None else metadata.get(prefix[:-1], {})
1012 - )
1013 - module._load_from_state_dict(
1014 - state_dict,
1015 - prefix,
1016 - local_metadata,
1017 - True,
1018 - missing_keys,
1019 - unexpected_keys,
1020 - error_msgs,
1021 - )
1022 - for name, child in module._modules.items():
1023 - if child is not None:
1024 - load(child, prefix + name + ".")
1025 -
1026 - # Make sure we are able to load base models as well as derived models (with heads)
1027 - start_prefix = ""
1028 - model_to_load = model
1029 - has_prefix_module = any(
1030 - s.startswith(cls.base_model_prefix) for s in state_dict.keys()
1031 - )
1032 - if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
1033 - start_prefix = cls.base_model_prefix + "."
1034 - if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
1035 - model_to_load = getattr(model, cls.base_model_prefix)
1036 -
1037 - load(model_to_load, prefix=start_prefix)
1038 -
1039 - if model.__class__.__name__ != model_to_load.__class__.__name__:
1040 - base_model_state_dict = model_to_load.state_dict().keys()
1041 - head_model_state_dict_without_base_prefix = [
1042 - key.split(cls.base_model_prefix + ".")[-1]
1043 - for key in model.state_dict().keys()
1044 - ]
1045 - missing_keys.extend(
1046 - head_model_state_dict_without_base_prefix - base_model_state_dict
1047 - )
1048 -
1049 - # Some models may have keys that are not in the state by design, removing them before needlessly warning
1050 - # the user.
1051 - if cls.authorized_missing_keys is not None:
1052 - for pat in cls.authorized_missing_keys:
1053 - missing_keys = [
1054 - k for k in missing_keys if re.search(pat, k) is None
1055 - ]
1056 -
1057 - if len(unexpected_keys) > 0:
1058 - logger.warning(
1059 - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
1060 - f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
1061 - f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
1062 -                    f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
1063 - f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
1064 - f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
1065 - )
1066 - else:
1067 - logger.info(
1068 - f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
1069 - )
1070 - if len(missing_keys) > 0:
1071 - logger.warning(
1072 - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
1073 - f"and are newly initialized: {missing_keys}\n"
1074 - f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1075 - )
1076 - else:
1077 - logger.info(
1078 - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
1079 - f"If your task is similar to the task the model of the checkpoint was trained on, "
1080 - f"you can already use {model.__class__.__name__} for predictions without further training."
1081 - )
1082 - if len(error_msgs) > 0:
1083 - raise RuntimeError(
1084 - "Error(s) in loading state_dict for {}:\n\t{}".format(
1085 - model.__class__.__name__, "\n\t".join(error_msgs)
1086 - )
1087 - )
1088 - # make sure token embedding weights are still tied if needed
1089 - model.tie_weights()
1090 -
1091 - # Set model in evaluation mode to deactivate DropOut modules by default
1092 - model.eval()
1093 -
1094 - if output_loading_info:
1095 - loading_info = {
1096 - "missing_keys": missing_keys,
1097 - "unexpected_keys": unexpected_keys,
1098 - "error_msgs": error_msgs,
1099 - }
1100 - return model, loading_info
1101 -
1102 - if (
1103 - hasattr(config, "xla_device")
1104 - and config.xla_device
1105 - and is_torch_tpu_available()
1106 - ):
1107 - import torch_xla.core.xla_model as xm
1108 -
1109 - model = xm.send_cpu_data_to_device(model, xm.xla_device())
1110 - model.to(xm.xla_device())
1111 -
1112 - return model
1113 -
1114 -
1115 -class Conv1D(nn.Module):
1116 - """
1117 - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
1118 -
1119 - Basically works like a linear layer but the weights are transposed.
1120 -
1121 - Args:
1122 - nf (:obj:`int`): The number of output features.
1123 - nx (:obj:`int`): The number of input features.
1124 - """
1125 -
1126 - def __init__(self, nf, nx):
1127 - super().__init__()
1128 - self.nf = nf
1129 - w = torch.empty(nx, nf)
1130 - nn.init.normal_(w, std=0.02)
1131 - self.weight = nn.Parameter(w)
1132 - self.bias = nn.Parameter(torch.zeros(nf))
1133 -
1134 - def forward(self, x):
1135 - size_out = x.size()[:-1] + (self.nf,)
1136 - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
1137 - x = x.view(*size_out)
1138 - return x
1139 -
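# --- Added illustrative sketch (not part of the original file) ---
# A minimal check, under assumed shapes, that Conv1D behaves like nn.Linear with a
# transposed weight matrix, as the docstring above describes.
def _conv1d_equivalence_demo():
    conv = Conv1D(nf=8, nx=4)                  # projects 4 input features to 8
    linear = nn.Linear(4, 8)
    with torch.no_grad():
        linear.weight.copy_(conv.weight.t())   # Conv1D stores its weight as (nx, nf)
        linear.bias.copy_(conv.bias)
    x = torch.randn(2, 3, 4)                   # (batch, seq_len, nx)
    assert torch.allclose(conv(x), linear(x), atol=1e-6)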
1140 -
1141 -class PoolerStartLogits(nn.Module):
1142 - """
1143 - Compute SQuAD start logits from sequence hidden states.
1144 -
1145 - Args:
1146 - config (:class:`~transformers.PretrainedConfig`):
1147 - The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
1148 - """
1149 -
1150 - def __init__(self, config: PretrainedConfig):
1151 - super().__init__()
1152 - self.dense = nn.Linear(config.hidden_size, 1)
1153 -
1154 - def forward(
1155 - self,
1156 - hidden_states: torch.FloatTensor,
1157 - p_mask: Optional[torch.FloatTensor] = None,
1158 - ) -> torch.FloatTensor:
1159 - """
1160 - Args:
1161 - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1162 - The final hidden states of the model.
1163 - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1164 -            Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1165 -            1.0 means the token should be masked.
1166 -
1167 - Returns:
1168 - :obj:`torch.FloatTensor`: The start logits for SQuAD.
1169 - """
1170 - x = self.dense(hidden_states).squeeze(-1)
1171 -
1172 - if p_mask is not None:
1173 - if next(self.parameters()).dtype == torch.float16:
1174 - x = x * (1 - p_mask) - 65500 * p_mask
1175 - else:
1176 - x = x * (1 - p_mask) - 1e30 * p_mask
1177 -
1178 - return x
1179 -
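# --- Added illustrative sketch (not part of the original file) ---
# Hedged example of how p_mask pushes invalid positions to very large negative
# logits; the hidden size and mask below are assumptions for illustration.
def _pooler_start_logits_demo():
    config = PretrainedConfig(hidden_size=8)
    pooler = PoolerStartLogits(config)
    hidden_states = torch.randn(2, 4, 8)        # (batch, seq_len, hidden_size)
    p_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                           [0.0, 1.0, 1.0, 1.0]])
    logits = pooler(hidden_states, p_mask=p_mask)
    assert logits.shape == (2, 4)
    assert bool((logits[p_mask == 1.0] < -1e29).all())   # masked positions become effectively -inf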
1180 -
1181 -class PoolerEndLogits(nn.Module):
1182 - """
1183 - Compute SQuAD end logits from sequence hidden states.
1184 -
1185 - Args:
1186 - config (:class:`~transformers.PretrainedConfig`):
1187 - The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
1188 - :obj:`layer_norm_eps` to use.
1189 - """
1190 -
1191 - def __init__(self, config: PretrainedConfig):
1192 - super().__init__()
1193 - self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
1194 - self.activation = nn.Tanh()
1195 - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1196 - self.dense_1 = nn.Linear(config.hidden_size, 1)
1197 -
1198 - def forward(
1199 - self,
1200 - hidden_states: torch.FloatTensor,
1201 - start_states: Optional[torch.FloatTensor] = None,
1202 - start_positions: Optional[torch.LongTensor] = None,
1203 - p_mask: Optional[torch.FloatTensor] = None,
1204 - ) -> torch.FloatTensor:
1205 - """
1206 - Args:
1207 - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1208 - The final hidden states of the model.
1209 - start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
1210 - The hidden states of the first tokens for the labeled span.
1211 - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1212 - The position of the first token for the labeled span.
1213 - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1214 -                Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1215 -                1.0 means the token should be masked.
1216 -
1217 - .. note::
1218 -
1219 -                One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
1220 - ``start_positions`` overrides ``start_states``.
1221 -
1222 - Returns:
1223 - :obj:`torch.FloatTensor`: The end logits for SQuAD.
1224 - """
1225 - assert (
1226 - start_states is not None or start_positions is not None
1227 -        ), "One of start_states or start_positions should not be None"
1228 - if start_positions is not None:
1229 - slen, hsz = hidden_states.shape[-2:]
1230 - start_positions = start_positions[:, None, None].expand(
1231 - -1, -1, hsz
1232 - ) # shape (bsz, 1, hsz)
1233 - start_states = hidden_states.gather(
1234 - -2, start_positions
1235 - ) # shape (bsz, 1, hsz)
1236 - start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
1237 -
1238 - x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
1239 - x = self.activation(x)
1240 - x = self.LayerNorm(x)
1241 - x = self.dense_1(x).squeeze(-1)
1242 -
1243 - if p_mask is not None:
1244 - if next(self.parameters()).dtype == torch.float16:
1245 - x = x * (1 - p_mask) - 65500 * p_mask
1246 - else:
1247 - x = x * (1 - p_mask) - 1e30 * p_mask
1248 -
1249 - return x
1250 -
1251 -
1252 -class PoolerAnswerClass(nn.Module):
1253 - """
1254 -    Compute the SQuAD 2.0 answer class from the classification and start tokens' hidden states.
1255 -
1256 - Args:
1257 - config (:class:`~transformers.PretrainedConfig`):
1258 - The config used by the model, will be used to grab the :obj:`hidden_size` of the model.
1259 - """
1260 -
1261 - def __init__(self, config):
1262 - super().__init__()
1263 - self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
1264 - self.activation = nn.Tanh()
1265 - self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
1266 -
1267 - def forward(
1268 - self,
1269 - hidden_states: torch.FloatTensor,
1270 - start_states: Optional[torch.FloatTensor] = None,
1271 - start_positions: Optional[torch.LongTensor] = None,
1272 - cls_index: Optional[torch.LongTensor] = None,
1273 - ) -> torch.FloatTensor:
1274 - """
1275 - Args:
1276 - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1277 - The final hidden states of the model.
1278 - start_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`, `optional`):
1279 - The hidden states of the first tokens for the labeled span.
1280 - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1281 - The position of the first token for the labeled span.
1282 - cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1283 - Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
1284 -
1285 - .. note::
1286 -
1287 -                One of ``start_states`` or ``start_positions`` should not be :obj:`None`. If both are set,
1288 - ``start_positions`` overrides ``start_states``.
1289 -
1290 - Returns:
1291 - :obj:`torch.FloatTensor`: The SQuAD 2.0 answer class.
1292 - """
1293 - # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
1294 - hsz = hidden_states.shape[-1]
1295 - assert (
1296 - start_states is not None or start_positions is not None
1297 -        ), "One of start_states or start_positions should not be None"
1298 - if start_positions is not None:
1299 - start_positions = start_positions[:, None, None].expand(
1300 - -1, -1, hsz
1301 - ) # shape (bsz, 1, hsz)
1302 - start_states = hidden_states.gather(-2, start_positions).squeeze(
1303 - -2
1304 - ) # shape (bsz, hsz)
1305 -
1306 - if cls_index is not None:
1307 - cls_index = cls_index[:, None, None].expand(
1308 - -1, -1, hsz
1309 - ) # shape (bsz, 1, hsz)
1310 - cls_token_state = hidden_states.gather(-2, cls_index).squeeze(
1311 - -2
1312 - ) # shape (bsz, hsz)
1313 - else:
1314 - cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
1315 -
1316 - x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
1317 - x = self.activation(x)
1318 - x = self.dense_1(x).squeeze(-1)
1319 -
1320 - return x
1321 -
1322 -
1323 -@dataclass
1324 -class SquadHeadOutput(ModelOutput):
1325 - """
1326 - Base class for outputs of question answering models using a :class:`~transformers.modeling_utils.SQuADHead`.
1327 -
1328 - Args:
1329 - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
1330 - Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
1331 - start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1332 -            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
1333 -        start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1334 -            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
1335 - end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1336 - Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
1337 - end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1338 - Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
1339 - cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
1340 - Log probabilities for the ``is_impossible`` label of the answers.
1341 -
1342 - """
1343 -
1344 - loss: Optional[torch.FloatTensor] = None
1345 - start_top_log_probs: Optional[torch.FloatTensor] = None
1346 - start_top_index: Optional[torch.LongTensor] = None
1347 - end_top_log_probs: Optional[torch.FloatTensor] = None
1348 - end_top_index: Optional[torch.LongTensor] = None
1349 - cls_logits: Optional[torch.FloatTensor] = None
1350 -
1351 -
1352 -class SQuADHead(nn.Module):
1353 - r"""
1354 - A SQuAD head inspired by XLNet.
1355 -
1356 - Args:
1357 - config (:class:`~transformers.PretrainedConfig`):
1358 - The config used by the model, will be used to grab the :obj:`hidden_size` of the model and the
1359 - :obj:`layer_norm_eps` to use.
1360 - """
1361 -
1362 - def __init__(self, config):
1363 - super().__init__()
1364 - self.start_n_top = config.start_n_top
1365 - self.end_n_top = config.end_n_top
1366 -
1367 - self.start_logits = PoolerStartLogits(config)
1368 - self.end_logits = PoolerEndLogits(config)
1369 - self.answer_class = PoolerAnswerClass(config)
1370 -
1371 - @replace_return_docstrings(
1372 - output_type=SquadHeadOutput, config_class=PretrainedConfig
1373 - )
1374 - def forward(
1375 - self,
1376 - hidden_states: torch.FloatTensor,
1377 - start_positions: Optional[torch.LongTensor] = None,
1378 - end_positions: Optional[torch.LongTensor] = None,
1379 - cls_index: Optional[torch.LongTensor] = None,
1380 - is_impossible: Optional[torch.LongTensor] = None,
1381 - p_mask: Optional[torch.FloatTensor] = None,
1382 - return_dict: bool = False,
1383 - ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
1384 - """
1385 - Args:
1386 - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
1387 - Final hidden states of the model on the sequence tokens.
1388 - start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1389 - Positions of the first token for the labeled span.
1390 - end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1391 - Positions of the last token for the labeled span.
1392 - cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1393 - Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
1394 - is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1395 - Whether the question has a possible answer in the paragraph or not.
1396 - p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
1397 -                Mask for tokens at invalid positions, such as query and special symbols (PAD, SEP, CLS).
1398 -                1.0 means the token should be masked.
1399 - return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
1400 -                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
1401 -
1402 - Returns:
1403 - """
1404 - start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1405 -
1406 - if start_positions is not None and end_positions is not None:
1407 - # If we are on multi-GPU, let's remove the dimension added by batch splitting
1408 - for x in (start_positions, end_positions, cls_index, is_impossible):
1409 - if x is not None and x.dim() > 1:
1410 - x.squeeze_(-1)
1411 -
1412 - # during training, compute the end logits based on the ground truth of the start position
1413 - end_logits = self.end_logits(
1414 - hidden_states, start_positions=start_positions, p_mask=p_mask
1415 - )
1416 -
1417 - loss_fct = CrossEntropyLoss()
1418 - start_loss = loss_fct(start_logits, start_positions)
1419 - end_loss = loss_fct(end_logits, end_positions)
1420 - total_loss = (start_loss + end_loss) / 2
1421 -
1422 - if cls_index is not None and is_impossible is not None:
1423 - # Predict answerability from the representation of CLS and START
1424 - cls_logits = self.answer_class(
1425 - hidden_states, start_positions=start_positions, cls_index=cls_index
1426 - )
1427 - loss_fct_cls = nn.BCEWithLogitsLoss()
1428 - cls_loss = loss_fct_cls(cls_logits, is_impossible)
1429 -
1430 - # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1431 - total_loss += cls_loss * 0.5
1432 -
1433 - return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
1434 -
1435 - else:
1436 - # during inference, compute the end logits based on beam search
1437 - bsz, slen, hsz = hidden_states.size()
1438 - start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1439 -
1440 - start_top_log_probs, start_top_index = torch.topk(
1441 - start_log_probs, self.start_n_top, dim=-1
1442 - ) # shape (bsz, start_n_top)
1443 - start_top_index_exp = start_top_index.unsqueeze(-1).expand(
1444 - -1, -1, hsz
1445 - ) # shape (bsz, start_n_top, hsz)
1446 - start_states = torch.gather(
1447 - hidden_states, -2, start_top_index_exp
1448 - ) # shape (bsz, start_n_top, hsz)
1449 - start_states = start_states.unsqueeze(1).expand(
1450 - -1, slen, -1, -1
1451 - ) # shape (bsz, slen, start_n_top, hsz)
1452 -
1453 - hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1454 - start_states
1455 - ) # shape (bsz, slen, start_n_top, hsz)
1456 - p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1457 - end_logits = self.end_logits(
1458 - hidden_states_expanded, start_states=start_states, p_mask=p_mask
1459 - )
1460 - end_log_probs = F.softmax(
1461 - end_logits, dim=1
1462 - ) # shape (bsz, slen, start_n_top)
1463 -
1464 - end_top_log_probs, end_top_index = torch.topk(
1465 - end_log_probs, self.end_n_top, dim=1
1466 - ) # shape (bsz, end_n_top, start_n_top)
1467 - end_top_log_probs = end_top_log_probs.view(
1468 - -1, self.start_n_top * self.end_n_top
1469 - )
1470 - end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1471 -
1472 - start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1473 - cls_logits = self.answer_class(
1474 - hidden_states, start_states=start_states, cls_index=cls_index
1475 - )
1476 -
1477 - if not return_dict:
1478 - return (
1479 - start_top_log_probs,
1480 - start_top_index,
1481 - end_top_log_probs,
1482 - end_top_index,
1483 - cls_logits,
1484 - )
1485 - else:
1486 - return SquadHeadOutput(
1487 - start_top_log_probs=start_top_log_probs,
1488 - start_top_index=start_top_index,
1489 - end_top_log_probs=end_top_log_probs,
1490 - end_top_index=end_top_index,
1491 - cls_logits=cls_logits,
1492 - )
1493 -
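# --- Added illustrative sketch (not part of the original file) ---
# Hedged training-mode example: with start and end positions provided, the head
# returns a scalar loss. The config values below are assumptions for illustration.
def _squad_head_demo():
    config = PretrainedConfig(
        hidden_size=16, layer_norm_eps=1e-12, start_n_top=5, end_n_top=5
    )
    head = SQuADHead(config)
    hidden_states = torch.randn(2, 10, 16)          # (batch, seq_len, hidden_size)
    outputs = head(
        hidden_states,
        start_positions=torch.tensor([1, 3]),
        end_positions=torch.tensor([4, 6]),
    )
    assert outputs[0].dim() == 0                     # a single scalar training loss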
1494 -
1495 -class SequenceSummary(nn.Module):
1496 - r"""
1497 -    Compute a single vector summary of a sequence's hidden states.
1498 -
1499 - Args:
1500 - config (:class:`~transformers.PretrainedConfig`):
1501 - The config used by the model. Relevant arguments in the config class of the model are (refer to the
1502 - actual config class of your model for the default values it uses):
1503 -
1504 - - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:
1505 -
1506 - - :obj:`"last"` -- Take the last token hidden state (like XLNet)
1507 - - :obj:`"first"` -- Take the first token hidden state (like Bert)
1508 - - :obj:`"mean"` -- Take the mean of all tokens hidden states
1509 - - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
1510 -              - :obj:`"attn"` -- Not implemented for now, uses multi-head attention
1511 -
1512 - - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
1513 - - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
1514 - :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
1515 - - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
1516 -              output; another string or :obj:`None` will add no activation.
1517 - - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
1518 - activation.
1519 -            - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
1520 - activation.
1521 - """
1522 -
1523 - def __init__(self, config: PretrainedConfig):
1524 - super().__init__()
1525 -
1526 - self.summary_type = getattr(config, "summary_type", "last")
1527 - if self.summary_type == "attn":
1528 - # We should use a standard multi-head attention module with absolute positional embedding for that.
1529 - # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1530 - # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1531 - raise NotImplementedError
1532 -
1533 - self.summary = Identity()
1534 - if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1535 - if (
1536 - hasattr(config, "summary_proj_to_labels")
1537 - and config.summary_proj_to_labels
1538 - and config.num_labels > 0
1539 - ):
1540 - num_classes = config.num_labels
1541 - else:
1542 - num_classes = config.hidden_size
1543 - self.summary = nn.Linear(config.hidden_size, num_classes)
1544 -
1545 - activation_string = getattr(config, "summary_activation", None)
1546 - self.activation: Callable = get_activation(
1547 - activation_string
1548 - ) if activation_string else Identity()
1549 -
1550 - self.first_dropout = Identity()
1551 - if (
1552 - hasattr(config, "summary_first_dropout")
1553 - and config.summary_first_dropout > 0
1554 - ):
1555 - self.first_dropout = nn.Dropout(config.summary_first_dropout)
1556 -
1557 - self.last_dropout = Identity()
1558 - if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1559 - self.last_dropout = nn.Dropout(config.summary_last_dropout)
1560 -
1561 - def forward(
1562 - self,
1563 - hidden_states: torch.FloatTensor,
1564 - cls_index: Optional[torch.LongTensor] = None,
1565 - ) -> torch.FloatTensor:
1566 - """
1567 -        Compute a single vector summary of a sequence's hidden states.
1568 -
1569 - Args:
1570 - hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
1571 - The hidden states of the last layer.
1572 - cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
1573 -                Used if :obj:`summary_type == "cls_index"`. If :obj:`None`, the last token of the sequence is taken
1574 -                as the classification token.
1575 -
1576 - Returns:
1577 - :obj:`torch.FloatTensor`: The summary of the sequence hidden states.
1578 - """
1579 - if self.summary_type == "last":
1580 - output = hidden_states[:, -1]
1581 - elif self.summary_type == "first":
1582 - output = hidden_states[:, 0]
1583 - elif self.summary_type == "mean":
1584 - output = hidden_states.mean(dim=1)
1585 - elif self.summary_type == "cls_index":
1586 - if cls_index is None:
1587 - cls_index = torch.full_like(
1588 - hidden_states[..., :1, :],
1589 - hidden_states.shape[-2] - 1,
1590 - dtype=torch.long,
1591 - )
1592 - else:
1593 - cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1594 - cls_index = cls_index.expand(
1595 - (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
1596 - )
1597 - # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1598 - output = hidden_states.gather(-2, cls_index).squeeze(
1599 - -2
1600 - ) # shape (bsz, XX, hidden_size)
1601 - elif self.summary_type == "attn":
1602 - raise NotImplementedError
1603 -
1604 - output = self.first_dropout(output)
1605 - output = self.summary(output)
1606 - output = self.activation(output)
1607 - output = self.last_dropout(output)
1608 -
1609 - return output
1610 -
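# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: a config that takes the first token and projects it to
# num_labels classes. All config values below are assumptions for illustration.
def _sequence_summary_demo():
    config = PretrainedConfig(
        summary_type="first",
        summary_use_proj=True,
        summary_proj_to_labels=True,
        num_labels=3,
        hidden_size=16,
        summary_activation="tanh",
        summary_first_dropout=0.1,
    )
    head = SequenceSummary(config)
    hidden_states = torch.randn(2, 5, 16)       # (batch, seq_len, hidden_size)
    assert head(hidden_states).shape == (2, 3)  # (batch, num_labels)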
1611 -
1612 -def prune_linear_layer(
1613 - layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0
1614 -) -> torch.nn.Linear:
1615 - """
1616 - Prune a linear layer to keep only entries in index.
1617 -
1618 - Used to remove heads.
1619 -
1620 - Args:
1621 - layer (:obj:`torch.nn.Linear`): The layer to prune.
1622 - index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1623 - dim (:obj:`int`, `optional`, defaults to 0): The dimension on which to keep the indices.
1624 -
1625 - Returns:
1626 - :obj:`torch.nn.Linear`: The pruned layer as a new layer with :obj:`requires_grad=True`.
1627 - """
1628 - index = index.to(layer.weight.device)
1629 - W = layer.weight.index_select(dim, index).clone().detach()
1630 - if layer.bias is not None:
1631 - if dim == 1:
1632 - b = layer.bias.clone().detach()
1633 - else:
1634 - b = layer.bias[index].clone().detach()
1635 - new_size = list(layer.weight.size())
1636 - new_size[dim] = len(index)
1637 - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
1638 - layer.weight.device
1639 - )
1640 - new_layer.weight.requires_grad = False
1641 - new_layer.weight.copy_(W.contiguous())
1642 - new_layer.weight.requires_grad = True
1643 - if layer.bias is not None:
1644 - new_layer.bias.requires_grad = False
1645 - new_layer.bias.copy_(b.contiguous())
1646 - new_layer.bias.requires_grad = True
1647 - return new_layer
1648 -
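# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: keep only output units 0 and 2 of a small linear layer; the
# sizes are assumptions chosen for readability.
def _prune_linear_layer_demo():
    layer = nn.Linear(3, 4)                         # 3 inputs, 4 outputs
    index = torch.LongTensor([0, 2])                # output units to keep
    pruned = prune_linear_layer(layer, index, dim=0)
    assert pruned.weight.shape == (2, 3)            # only 2 outputs remain
    assert torch.equal(pruned.weight, layer.weight[index])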
1649 -
1650 -def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
1651 - """
1652 -    Prune a Conv1D layer to keep only entries in index. A Conv1D works as a Linear layer (see e.g. BERT) but the weights
1653 - are transposed.
1654 -
1655 - Used to remove heads.
1656 -
1657 - Args:
1658 - layer (:class:`~transformers.modeling_utils.Conv1D`): The layer to prune.
1659 - index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1660 - dim (:obj:`int`, `optional`, defaults to 1): The dimension on which to keep the indices.
1661 -
1662 - Returns:
1663 - :class:`~transformers.modeling_utils.Conv1D`: The pruned layer as a new layer with :obj:`requires_grad=True`.
1664 - """
1665 - index = index.to(layer.weight.device)
1666 - W = layer.weight.index_select(dim, index).clone().detach()
1667 - if dim == 0:
1668 - b = layer.bias.clone().detach()
1669 - else:
1670 - b = layer.bias[index].clone().detach()
1671 - new_size = list(layer.weight.size())
1672 - new_size[dim] = len(index)
1673 - new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1674 - new_layer.weight.requires_grad = False
1675 - new_layer.weight.copy_(W.contiguous())
1676 - new_layer.weight.requires_grad = True
1677 - new_layer.bias.requires_grad = False
1678 - new_layer.bias.copy_(b.contiguous())
1679 - new_layer.bias.requires_grad = True
1680 - return new_layer
1681 -
1682 -
1683 -def prune_layer(
1684 - layer: Union[torch.nn.Linear, Conv1D],
1685 - index: torch.LongTensor,
1686 - dim: Optional[int] = None,
1687 -) -> Union[torch.nn.Linear, Conv1D]:
1688 - """
1689 - Prune a Conv1D or linear layer to keep only entries in index.
1690 -
1691 - Used to remove heads.
1692 -
1693 - Args:
1694 - layer (:obj:`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
1695 - index (:obj:`torch.LongTensor`): The indices to keep in the layer.
1696 - dim (:obj:`int`, `optional`): The dimension on which to keep the indices.
1697 -
1698 - Returns:
1699 - :obj:`torch.nn.Linear` or :class:`~transformers.modeling_utils.Conv1D`:
1700 - The pruned layer as a new layer with :obj:`requires_grad=True`.
1701 - """
1702 - if isinstance(layer, nn.Linear):
1703 - return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1704 - elif isinstance(layer, Conv1D):
1705 - return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1706 - else:
1707 - raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1708 -
1709 -
1710 -def apply_chunking_to_forward(
1711 - forward_fn: Callable[..., torch.Tensor],
1712 - chunk_size: int,
1713 - chunk_dim: int,
1714 - *input_tensors,
1715 -) -> torch.Tensor:
1716 - """
1717 - This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
1718 - dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.
1719 -
1720 - If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
1721 - directly applying :obj:`forward_fn` to :obj:`input_tensors`.
1722 -
1723 - Args:
1724 - forward_fn (:obj:`Callable[..., torch.Tensor]`):
1725 - The forward function of the model.
1726 - chunk_size (:obj:`int`):
1727 - The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
1728 - chunk_dim (:obj:`int`):
1729 - The dimension over which the :obj:`input_tensors` should be chunked.
1730 - input_tensors (:obj:`Tuple[torch.Tensor]`):
1731 - The input tensors of ``forward_fn`` which will be chunked.
1732 - Returns:
1733 -        :obj:`torch.Tensor`: A tensor with the same shape as the one :obj:`forward_fn` would have returned if applied directly.
1734 -
1735 -
1736 - Examples::
1737 -
1738 - # rename the usual forward() fn to forward_chunk()
1739 - def forward_chunk(self, hidden_states):
1740 - hidden_states = self.decoder(hidden_states)
1741 - return hidden_states
1742 -
1743 - # implement a chunked forward function
1744 - def forward(self, hidden_states):
1745 - return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
1746 - """
1747 -
1748 - assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(
1749 - input_tensors
1750 - )
1751 - tensor_shape = input_tensors[0].shape
1752 - assert all(
1753 - input_tensor.shape == tensor_shape for input_tensor in input_tensors
1754 -    ), "All input tensors have to be of the same shape"
1755 -
1756 -    # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
1757 - num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1758 - assert num_args_in_forward_chunk_fn == len(
1759 - input_tensors
1760 - ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1761 - num_args_in_forward_chunk_fn, len(input_tensors)
1762 - )
1763 -
1764 - if chunk_size > 0:
1765 - assert (
1766 - input_tensors[0].shape[chunk_dim] % chunk_size == 0
1767 - ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1768 - input_tensors[0].shape[chunk_dim], chunk_size
1769 - )
1770 -
1771 - num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1772 -
1773 - # chunk input tensor into tuples
1774 - input_tensors_chunks = tuple(
1775 - input_tensor.chunk(num_chunks, dim=chunk_dim)
1776 - for input_tensor in input_tensors
1777 - )
1778 - # apply forward fn to every tuple
1779 - output_chunks = tuple(
1780 - forward_fn(*input_tensors_chunk)
1781 - for input_tensors_chunk in zip(*input_tensors_chunks)
1782 - )
1783 - # concatenate output at same dimension
1784 - return torch.cat(output_chunks, dim=chunk_dim)
1785 -
1786 - return forward_fn(*input_tensors)
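
# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: chunking a position-wise forward over the sequence dimension
# gives the same output as applying it directly. Sizes are assumptions.
def _apply_chunking_demo():
    dense = nn.Linear(4, 4)

    def forward_fn(hidden_states):
        return dense(hidden_states)

    hidden_states = torch.randn(2, 8, 4)                          # (batch, seq_len, hidden)
    chunked = apply_chunking_to_forward(forward_fn, 2, 1, hidden_states)
    assert torch.allclose(chunked, forward_fn(hidden_states), atol=1e-6)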
1 -import itertools
2 -import json
3 -import linecache
4 -import os
5 -import pickle
6 -from logging import getLogger
7 -from pathlib import Path
8 -from typing import Callable, Dict, Iterable, List
9 -
10 -import git
11 -import numpy as np
12 -import torch
13 -from rouge_score import rouge_scorer, scoring
14 -from sacrebleu import corpus_bleu
15 -from torch import nn
16 -from torch.utils.data import Dataset, Sampler
17 -
18 -from transformers import BartTokenizer
19 -
20 -
21 -def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
22 - """From fairseq"""
23 - if target.dim() == lprobs.dim() - 1:
24 - target = target.unsqueeze(-1)
25 - nll_loss = -lprobs.gather(dim=-1, index=target)
26 - smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
27 - if ignore_index is not None:
28 - pad_mask = target.eq(ignore_index)
29 - nll_loss.masked_fill_(pad_mask, 0.0)
30 - smooth_loss.masked_fill_(pad_mask, 0.0)
31 - else:
32 - nll_loss = nll_loss.squeeze(-1)
33 - smooth_loss = smooth_loss.squeeze(-1)
34 -
35 - nll_loss = nll_loss.sum() # mean()? Scared to break other math.
36 - smooth_loss = smooth_loss.sum()
37 - eps_i = epsilon / lprobs.size(-1)
38 - loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
39 - return loss, nll_loss
40 -
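# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: with epsilon=0 the smoothed loss collapses to the plain NLL
# loss; the vocabulary size and targets below are assumptions.
def _label_smoothing_demo():
    lprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)   # (batch, vocab_size)
    target = torch.tensor([1, 3])
    loss, nll = label_smoothed_nll_loss(lprobs, target, epsilon=0.0)
    assert torch.allclose(loss, nll)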
41 -
42 -def encode_line(
43 - tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
44 -):
45 - """Only used by LegacyDataset"""
46 - extra_kw = (
47 - {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
48 - )
49 - return tokenizer(
50 - [line],
51 - max_length=max_length,
52 - padding="max_length" if pad_to_max_length else None,
53 - truncation=True,
54 - return_tensors=return_tensors,
55 - **extra_kw,
56 - )
57 -
58 -
59 -def lmap(f: Callable, x: Iterable) -> List:
60 - """list(map(f, x))"""
61 - return list(map(f, x))
62 -
63 -
64 -def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
65 - """Uses sacrebleu's corpus_bleu implementation."""
66 - return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
67 -
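# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: an identical hypothesis and reference should score a perfect
# BLEU of 100.
def _calculate_bleu_demo():
    hypotheses = ["the cat sat on the mat"]
    references = ["the cat sat on the mat"]
    assert calculate_bleu(hypotheses, references) == {"bleu": 100.0}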
68 -
69 -def trim_batch(
70 - input_ids, pad_token_id, attention_mask=None,
71 -):
72 - """Remove columns that are populated exclusively by pad_token_id"""
73 - keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
74 - if attention_mask is None:
75 - return input_ids[:, keep_column_mask]
76 - else:
77 - return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
78 -
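# --- Added illustrative sketch (not part of the original file) ---
# Hedged example with pad_token_id=0 (an assumption): columns made up entirely
# of padding are dropped.
def _trim_batch_demo():
    input_ids = torch.tensor([[5, 6, 0, 0],
                              [7, 0, 0, 0]])
    trimmed = trim_batch(input_ids, pad_token_id=0)
    assert trimmed.tolist() == [[5, 6], [7, 0]]   # the two all-pad columns are removed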
79 -
80 -class AbstractSeq2SeqDataset(Dataset):
81 - def __init__(
82 - self,
83 - tokenizer,
84 - data_dir,
85 - max_source_length,
86 - max_target_length,
87 - type_path="train",
88 - n_obs=None,
89 - src_lang=None,
90 - tgt_lang=None,
91 - prefix="",
92 - ):
93 - super().__init__()
94 - self.src_file = Path(data_dir).joinpath(type_path + ".source")
95 - self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
96 - self.src_lens = self.get_char_lens(self.src_file)
97 - self.max_source_length = max_source_length
98 - self.max_target_length = max_target_length
99 - assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
100 - self.tokenizer = tokenizer
101 - self.prefix = prefix
102 - if n_obs is not None:
103 - self.src_lens = self.src_lens[:n_obs]
104 - self.pad_token_id = self.tokenizer.pad_token_id
105 - self.src_lang = src_lang
106 - self.tgt_lang = tgt_lang
107 - self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
108 -
109 - def __len__(self):
110 - return len(self.src_lens)
111 -
112 - @staticmethod
113 - def get_char_lens(data_file):
114 - return [len(x) for x in Path(data_file).open().readlines()]
115 -
116 - def make_sortish_sampler(self, batch_size):
117 - return SortishSampler(self.src_lens, batch_size)
118 -
119 - def __getitem__(self, item):
120 - raise NotImplementedError("You must implement this")
121 -
122 - def collate_fn(self, batch):
123 - raise NotImplementedError("You must implement this")
124 -
125 -
126 -class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
127 - def __getitem__(self, index) -> Dict[str, torch.Tensor]:
128 - """Call tokenizer on src and tgt_lines"""
129 - index = index + 1 # linecache starts at 1
130 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
131 - "\n"
132 - )
133 - tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
134 - assert source_line, f"empty source line for index {index}"
135 - assert tgt_line, f"empty tgt line for index {index}"
136 - source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
137 - target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
138 -
139 - source_ids = source_inputs["input_ids"].squeeze()
140 - target_ids = target_inputs["input_ids"].squeeze()
141 - src_mask = source_inputs["attention_mask"].squeeze()
142 - return {
143 - "input_ids": source_ids,
144 - "attention_mask": src_mask,
145 - "labels": target_ids,
146 - }
147 -
148 - def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
149 - input_ids = torch.stack([x["input_ids"] for x in batch])
150 - masks = torch.stack([x["attention_mask"] for x in batch])
151 - target_ids = torch.stack([x["labels"] for x in batch])
152 - pad_token_id = self.pad_token_id
153 - y = trim_batch(target_ids, pad_token_id)
154 - source_ids, source_mask = trim_batch(
155 - input_ids, pad_token_id, attention_mask=masks
156 - )
157 - batch = {
158 - "input_ids": source_ids,
159 - "attention_mask": source_mask,
160 - "labels": y,
161 - }
162 - return batch
163 -
164 -
165 -class Seq2SeqDataset(AbstractSeq2SeqDataset):
166 - """A dataset that calls prepare_seq2seq_batch."""
167 -
168 - def __getitem__(self, index) -> Dict[str, str]:
169 - index = index + 1 # linecache starts at 1
170 - source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
171 - "\n"
172 - )
173 - tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
174 - assert source_line, f"empty source line for index {index}"
175 - assert tgt_line, f"empty tgt line for index {index}"
176 - return {
177 - "tgt_texts": tgt_line,
178 - "src_texts": source_line,
179 - }
180 -
181 - def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
182 - """Call prepare_seq2seq_batch."""
183 - batch_encoding = self.tokenizer.prepare_seq2seq_batch(
184 - [x["src_texts"] for x in batch],
185 - src_lang=self.src_lang,
186 - tgt_texts=[x["tgt_texts"] for x in batch],
187 - tgt_lang=self.tgt_lang,
188 - max_length=self.max_source_length,
189 - max_target_length=self.max_target_length,
190 - return_tensors="pt",
191 - add_prefix_space=self.add_prefix_space,
192 - )
193 - return batch_encoding.data
194 -
195 -
196 -class SortishSampler(Sampler):
197 - "Go through the text data by order of src length with a bit of randomness. From fastai repo."
198 -
199 - def __init__(self, data, batch_size):
200 - self.data, self.bs = data, batch_size
201 -
202 - def key(self, i):
203 - return self.data[i]
204 -
205 - def __len__(self) -> int:
206 - return len(self.data)
207 -
208 - def __iter__(self):
209 - idxs = np.random.permutation(len(self.data))
210 - sz = self.bs * 50
211 - ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
212 - sort_idx = np.concatenate(
213 - [sorted(s, key=self.key, reverse=True) for s in ck_idx]
214 - )
215 - sz = self.bs
216 - ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
217 - max_ck = np.argmax(
218 - [self.key(ck[0]) for ck in ck_idx]
219 - ) # find the chunk with the largest key,
220 - ck_idx[0], ck_idx[max_ck] = (
221 - ck_idx[max_ck],
222 - ck_idx[0],
223 - ) # then make sure it goes first.
224 - sort_idx = (
225 - np.concatenate(np.random.permutation(ck_idx[1:]))
226 - if len(ck_idx) > 1
227 -            else np.array([], dtype=np.int64)  # np.int is deprecated in recent NumPy versions
228 - )
229 - sort_idx = np.concatenate((ck_idx[0], sort_idx))
230 - return iter(sort_idx)
231 -
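# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: the sampler yields a permutation of all dataset indices,
# roughly ordered by source length. The lengths below are assumptions.
def _sortish_sampler_demo():
    src_lens = [3, 50, 7, 42, 5, 12, 9, 21]       # character length per example
    sampler = SortishSampler(src_lens, batch_size=2)
    order = list(iter(sampler))
    assert sorted(int(i) for i in order) == list(range(len(src_lens)))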
232 -
233 -logger = getLogger(__name__)
234 -
235 -
236 -def use_task_specific_params(model, task):
237 -    """Update config with task-specific params (e.g. summarization)."""
238 - task_specific_params = model.config.task_specific_params
239 -
240 - if task_specific_params is not None:
241 - pars = task_specific_params.get(task, {})
242 - logger.info(f"using task specific params for {task}: {pars}")
243 - model.config.update(pars)
244 -
245 -
246 -def pickle_load(path):
247 - """pickle.load(path)"""
248 - with open(path, "rb") as f:
249 - return pickle.load(f)
250 -
251 -
252 -def pickle_save(obj, path):
253 - """pickle.dump(obj, path)"""
254 - with open(path, "wb") as f:
255 - return pickle.dump(obj, f)
256 -
257 -
258 -def flatten_list(summary_ids: List[List]):
259 - return [x for x in itertools.chain.from_iterable(summary_ids)]
260 -
261 -
262 -def save_git_info(folder_path: str) -> None:
263 - """Save git information to output_dir/git_log.json"""
264 - repo_infos = get_git_info()
265 - save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
266 -
267 -
268 -def save_json(content, path):
269 - with open(path, "w") as f:
270 - json.dump(content, f, indent=4)
271 -
272 -
273 -def load_json(path):
274 - with open(path) as f:
275 - return json.load(f)
276 -
277 -
278 -def get_git_info():
279 - repo = git.Repo(search_parent_directories=True)
280 - repo_infos = {
281 - "repo_id": str(repo),
282 - "repo_sha": str(repo.head.object.hexsha),
283 - "repo_branch": str(repo.active_branch),
284 - }
285 - return repo_infos
286 -
287 -
288 -ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
289 -
290 -
291 -def calculate_rouge(
292 - output_lns: List[str], reference_lns: List[str], use_stemmer=True
293 -) -> Dict:
294 - scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
295 - aggregator = scoring.BootstrapAggregator()
296 -
297 - for reference_ln, output_ln in zip(reference_lns, output_lns):
298 - scores = scorer.score(reference_ln, output_ln)
299 - aggregator.add_scores(scores)
300 -
301 - result = aggregator.aggregate()
302 - return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
303 -
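# --- Added illustrative sketch (not part of the original file) ---
# Hedged example: the result is one aggregated F-measure (scaled to 0-100) per
# ROUGE key. The sentences below are assumptions for illustration.
def _calculate_rouge_demo():
    preds = ["the cat sat on the mat"]
    refs = ["a cat was sitting on the mat"]
    scores = calculate_rouge(preds, refs)
    assert set(scores) == set(ROUGE_KEYS)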
304 -
305 -# Utilities for freezing parameters and checking whether they are frozen
306 -
307 -
308 -def freeze_params(model: nn.Module):
309 - """Set requires_grad=False for each of model.parameters()"""
310 - for par in model.parameters():
311 - par.requires_grad = False
312 -
313 -
314 -def grad_status(model: nn.Module) -> Iterable:
315 - return (par.requires_grad for par in model.parameters())
316 -
317 -
318 -def any_requires_grad(model: nn.Module) -> bool:
319 - return any(grad_status(model))
320 -
321 -
322 -def assert_all_frozen(model):
323 - model_grads: List[bool] = list(grad_status(model))
324 - n_require_grad = sum(lmap(int, model_grads))
325 - npars = len(model_grads)
326 - assert not any(
327 - model_grads
328 - ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
329 -
330 -
331 -def assert_not_all_frozen(model):
332 - model_grads: List[bool] = list(grad_status(model))
333 - npars = len(model_grads)
334 - assert any(model_grads), f"none of {npars} weights require grad"
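
# --- Added illustrative sketch (not part of the original file) ---
# Hedged example tying the freezing helpers together on a small module.
def _freeze_params_demo():
    model = nn.Linear(4, 2)
    assert any_requires_grad(model)      # parameters start out trainable
    freeze_params(model)
    assert_all_frozen(model)             # passes once every parameter is frozen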