graykode

(add) matorage runnable, edit cpython to pandas in gitignore

@@ -137,5 +137,5 @@ dmypy.json
# Cython debug symbols
cython_debug/
cpython
pandas
.idea/
\ No newline at end of file
......
@@ -16,6 +16,9 @@ from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from matorage import DataConfig
from matorage.torch import Dataset
try:
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
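Note on the new imports: matorage streams pre-tokenized tensor datasets from an S3/MinIO-style object store configured through DataConfig, instead of reading the local text files that Seq2SeqDataset used. The 'commit-autosuggestions' dataset consumed later in get_dataset is expected to have been written beforehand by a separate preprocessing step. A minimal sketch of such a writer, assuming matorage's DataSaver API; the endpoint, credentials, lengths, and zero-filled arrays are placeholders, and only the dataset name and attribute layout mirror the reader config in this commit:

    # Sketch only: populate the object store that finetune.py will read from.
    import numpy as np
    from matorage import DataConfig, DataSaver

    writer_config = DataConfig(
        endpoint="127.0.0.1:9000",            # placeholder storage endpoint
        access_key="minio",                   # placeholder credentials
        secret_key="miniosecretkey",
        dataset_name="commit-autosuggestions",
        attributes=[
            ("input_ids", "int32", (1024,)),
            ("attention_masks", "int32", (1024,)),
            ("patch_ids", "int32", (1024,)),
            ("targets", "int32", (56,)),
        ],
    )
    saver = DataSaver(config=writer_config)
    saver({                                   # one zero-filled example, batch dimension first
        "input_ids": np.zeros((1, 1024), dtype="int32"),
        "attention_masks": np.zeros((1, 1024), dtype="int32"),
        "patch_ids": np.zeros((1, 1024), dtype="int32"),
        "targets": np.zeros((1, 56), dtype="int32"),
    })
    saver.disconnect()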
@@ -75,18 +78,6 @@ class SummarizationModule(BaseTransformer):
self.step_count = 0
self.metrics = defaultdict(list)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=self.model.config.prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
@@ -107,9 +98,7 @@ class SummarizationModule(BaseTransformer):
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
self.dataset_class = (
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
@@ -137,8 +126,8 @@ class SummarizationModule(BaseTransformer):
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
tgt_ids = batch["labels"]
src_ids, src_mask, src_patch = batch[0].long(), batch[1].long(), batch[2].long()
tgt_ids = batch[3].long()
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(tgt_ids)
else:
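With the matorage Dataset, a batch is no longer a dict keyed by "input_ids"/"attention_mask"/"labels" but a positional tuple in the configured attribute order (input_ids, attention_masks, patch_ids, targets), stored as int32 and therefore cast with .long() before use, which is what the new indexing above reflects. If dict-style access were preferred again, a small adapter could restore the original keys; a sketch only, the helper below is hypothetical and not part of this commit:

    # Hypothetical adapter: rebuild the dict-style batch the original script expected.
    def tuple_batch_to_dict(batch):
        input_ids, attention_mask, patch_ids, labels = (t.long() for t in batch)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "patch_ids": patch_ids,
            "labels": labels,
        }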
@@ -168,7 +157,7 @@ class SummarizationModule(BaseTransformer):
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
logs["tpb"] = batch[0].long().ne(self.pad).sum() + batch[3].long().ne(self.pad).sum()
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
@@ -198,14 +187,15 @@ class SummarizationModule(BaseTransformer):
def _generative_step(self, batch: dict) -> dict:
t0 = time.time()
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
batch[0].long(),
attention_mask=batch[1].long(),
# patch_ids=batch[2].long(),
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
gen_time = (time.time() - t0) / batch[0].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["labels"])
target: List[str] = self.ids_to_clean_text(batch[3])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
@@ -220,29 +210,34 @@ class SummarizationModule(BaseTransformer):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = self.dataset_class(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
data_config = DataConfig(
endpoint=args.matorage_dir,
access_key=os.environ['access_key'],
secret_key=os.environ['secret_key'],
dataset_name='commit-autosuggestions',
additional={
"mode": ("training" if type_path == "train" else "evaluation"),
"max_source_length": self.hparams.max_source_length,
"max_target_length": max_target_length,
"url": args.url,
},
attributes=[
('input_ids', 'int32', (self.hparams.max_source_length,)),
('attention_masks', 'int32', (self.hparams.max_source_length,)),
('patch_ids', 'int32', (self.hparams.max_source_length,)),
('targets', 'int32', (max_target_length,))
]
)
return dataset
return Dataset(config=data_config, clear=True)
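get_dataset now builds a matorage DataConfig, taking the storage endpoint from --matorage_dir and the credentials from the access_key and secret_key environment variables, and returns a matorage.torch Dataset, so no local data_dir is read. Before launching a full run it can help to check that the dataset resolves and that batches unpack in the expected attribute order; a self-contained sketch under the assumption that the dataset was written with matching metadata (the endpoint and url values are placeholders; 1024 and 56 are the script's default source/target lengths):

    # Sketch only: confirm the matorage dataset is reachable and batches unpack as expected.
    import os
    from torch.utils.data import DataLoader
    from matorage import DataConfig
    from matorage.torch import Dataset

    config = DataConfig(
        endpoint="127.0.0.1:9000",                       # placeholder; normally --matorage_dir
        access_key=os.environ["access_key"],
        secret_key=os.environ["secret_key"],
        dataset_name="commit-autosuggestions",
        additional={
            "mode": "training",
            "max_source_length": 1024,
            "max_target_length": 56,
            "url": "https://github.com/<owner>/<repo>",  # placeholder; normally --url
        },
        attributes=[
            ("input_ids", "int32", (1024,)),
            ("attention_masks", "int32", (1024,)),
            ("patch_ids", "int32", (1024,)),
            ("targets", "int32", (56,)),
        ],
    )
    loader = DataLoader(Dataset(config=config, clear=True), batch_size=2)
    input_ids, attention_masks, patch_ids, targets = (t.long() for t in next(iter(loader)))
    print(input_ids.shape, attention_masks.shape, patch_ids.shape, targets.shape)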
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # TODO: assert earlier
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=sampler,
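The matorage Dataset behaves like a plain torch.utils.data.Dataset, so Seq2SeqDataset-specific helpers such as collate_fn and make_sortish_sampler should not be assumed to exist on it. A guarded get_dataloader body might look like the following sketch (not part of this commit; falling back to PyTorch's default collation, which stacks each attribute into the positional batches used in _step):

    # Sketch only: build the DataLoader without relying on Seq2SeqDataset helpers.
    dataset = self.get_dataset(type_path)
    sampler = None
    if self.hparams.sortish_sampler and type_path == "train" and hasattr(dataset, "make_sortish_sampler"):
        sampler = dataset.make_sortish_sampler(batch_size)
        shuffle = False
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=getattr(dataset, "collate_fn", None),  # None -> default_collate
        shuffle=shuffle,
        num_workers=self.num_workers,
        sampler=sampler,
    )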
@@ -264,6 +259,18 @@ class SummarizationModule(BaseTransformer):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--url",
type=str,
required=True,
help="github url"
)
parser.add_argument(
"--matorage_dir",
type=str,
required=True,
help="matorage storage endpoint (the directory or address where the dataset is saved)."
)
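Both new flags feed the DataConfig built in get_dataset: --matorage_dir becomes the storage endpoint and --url is recorded in the dataset's additional metadata, while the access_key and secret_key credentials are read from environment variables rather than from the command line. A tiny pre-flight check can fail faster and more clearly than the KeyError the script would otherwise raise; the helper below is a hypothetical addition, not part of this commit:

    # Hypothetical pre-flight check for the credentials get_dataset reads from os.environ.
    import os

    def check_matorage_credentials() -> None:
        missing = [key for key in ("access_key", "secret_key") if key not in os.environ]
        if missing:
            raise EnvironmentError("Missing matorage environment variable(s): " + ", ".join(missing))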
parser.add_argument(
"--max_source_length",
default=1024,
type=int,
@@ -341,29 +348,8 @@ def main(args, model=None) -> SummarizationModule:
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
if args.early_stopping_patience >= 0:
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
else:
es_callback = False
logger = True
es_callback = False
trainer: pl.Trainer = generic_train(
model,
args,
......
@@ -323,13 +323,6 @@ def add_generic_args(parser, root_dir) -> None:
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
def generic_train(
......