Showing 7 changed files with 6 additions and 6 deletions
File moved
File moved
File moved
... | @@ -3,9 +3,7 @@ | ... | @@ -3,9 +3,7 @@ |
3 | 3 | ||
4 | import torch | 4 | import torch |
5 | import torch.nn as nn | 5 | import torch.nn as nn |
6 | -import torch | 6 | + |
7 | -from torch.autograd import Variable | ||
8 | -import copy | ||
9 | class Seq2Seq(nn.Module): | 7 | class Seq2Seq(nn.Module): |
10 | """ | 8 | """ |
11 | Build Seqence-to-Sequence. | 9 | Build Seqence-to-Sequence. |
... | @@ -162,7 +160,7 @@ class Beam(object): | ... | @@ -162,7 +160,7 @@ class Beam(object): |
162 | 160 | ||
163 | # bestScoresId is flattened beam x word array, so calculate which | 161 | # bestScoresId is flattened beam x word array, so calculate which |
164 | # word and beam each score came from | 162 | # word and beam each score came from |
165 | - prevK = bestScoresId / numWords | 163 | + prevK = bestScoresId // numWords |
166 | self.prevKs.append(prevK) | 164 | self.prevKs.append(prevK) |
167 | self.nextYs.append((bestScoresId - prevK * numWords)) | 165 | self.nextYs.append((bestScoresId - prevK * numWords)) |
168 | 166 | ... | ... |
... | @@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss. | ... | @@ -22,7 +22,6 @@ using a masked language modeling (MLM) loss. |
22 | from __future__ import absolute_import | 22 | from __future__ import absolute_import |
23 | import os | 23 | import os |
24 | import sys | 24 | import sys |
25 | -import bleu | ||
26 | import pickle | 25 | import pickle |
27 | import torch | 26 | import torch |
28 | import json | 27 | import json |
... | @@ -35,11 +34,14 @@ from itertools import cycle | ... | @@ -35,11 +34,14 @@ from itertools import cycle |
35 | import torch.nn as nn | 34 | import torch.nn as nn |
36 | from model import Seq2Seq | 35 | from model import Seq2Seq |
37 | from tqdm import tqdm, trange | 36 | from tqdm import tqdm, trange |
38 | -from customized_roberta import RobertaModel | ||
39 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset | 37 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset |
40 | from torch.utils.data.distributed import DistributedSampler | 38 | from torch.utils.data.distributed import DistributedSampler |
41 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, | 39 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, |
42 | RobertaConfig, RobertaTokenizer) | 40 | RobertaConfig, RobertaTokenizer) |
41 | + | ||
42 | +import train.bleu as bleu | ||
43 | +from train.customized_roberta import RobertaModel | ||
44 | + | ||
43 | MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)} | 45 | MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)} |
44 | 46 | ||
45 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', | 47 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', | ... | ... |
-
Please register or login to post a comment