graykode

(add) split with train and test

@@ -15,6 +15,7 @@
import os
import re
import enum
import random
import logging
import argparse
import numpy as np
@@ -87,8 +88,7 @@ def sha_parse(sha, tokenizer, max_length=1024):
return (input_ids, attention_masks, patch_ids)
def message_parse(msg, tokenizer, max_length=56):
msg = re.sub(r'#([0-9])+', '', msg)
msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)
msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
msg = tokenizer.tokenize(msg)
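Not part of this commit, but a quick before/after sketch of what the cleaning patterns in this hunk strip from a raw commit summary, assuming the three substitutions shown directly above (ticket keys such as "JIRA-123:", issue references such as "(#456)", and non-ASCII symbols):

    import re

    msg = "JIRA-123: fix crash on empty diff (#456) 🎉"
    msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)                 # drops "JIRA-123:"
    msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)                             # drops "(#456)"
    msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()   # drops the emoji
    print(msg)  # -> "fix crash on empty diff"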
@@ -97,7 +97,7 @@ def message_parse(msg, tokenizer, max_length=56):
return msg
def jobs(sha_msgs, args, data_config):
def jobs(sha_msgs, args, data_config, train=True):
input_ids, attention_masks, patch_ids, targets = [], [], [], []
data_saver = DataSaver(config=data_config)
@@ -105,11 +105,19 @@ def jobs(sha_msgs, args, data_config):
for sha_msg in sha_msgs:
sha, msg = sha_msg
source = sha_parse(sha, tokenizer=args.tokenizer)
source = sha_parse(
sha,
tokenizer=args.tokenizer,
max_length=args.max_source_length
)
if not source:
continue
input_id, attention_mask, patch_id = source
target = message_parse(msg, tokenizer=args.tokenizer)
target = message_parse(
msg,
tokenizer=args.tokenizer,
max_length=(args.max_target_length if train else args.val_max_target_length),
)
input_ids.append(input_id)
attention_masks.append(attention_mask)
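As a side note (not in the diff), the new train flag threaded through jobs() only decides which cap message_parse() truncates targets to; a hypothetical helper with placeholder numbers:

    def target_cap(train, max_target_length=56, val_max_target_length=142):
        # mirrors `args.max_target_length if train else args.val_max_target_length`
        return max_target_length if train else val_max_target_length

    print(target_cap(train=True), target_cap(train=False))  # 56 142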
@@ -124,9 +132,11 @@ def jobs(sha_msgs, args, data_config):
})
data_saver.disconnect()
def main(args):
if 'access_key' not in os.environ or 'secret_key' not in os.environ:
raise OSError("access_key or secret_key are not found.")
def start(chunked_sha_msgs, train=True):
logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
max_target_length = args.max_target_length if train else args.val_max_target_length
data_config = DataConfig(
endpoint=args.matorage_dir,
@@ -134,27 +144,39 @@ def main(args):
secret_key=os.environ['secret_key'],
dataset_name='commit-autosuggestions',
additional={
"mode" : ("training" if train else "evaluation"),
"max_source_length": args.max_source_length,
"max_target_length": args.max_target_length,
"max_target_length": max_target_length,
"url" : args.url,
},
attributes = [
attributes=[
('input_ids', 'int32', (args.max_source_length,)),
('attention_masks', 'int32', (args.max_source_length,)),
('patch_ids', 'int32', (args.max_source_length,)),
('targets', 'int32', (args.max_target_length,))
('targets', 'int32', (max_target_length,))
]
)
func = partial(jobs, args=args, data_config=data_config, train=train)
with Pool(processes=args.num_workers) as pool:
with tqdm(total=len(chunked_sha_msgs)) as pbar:
for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
pbar.update()
def main(args):
if 'access_key' not in os.environ or 'secret_key' not in os.environ:
raise OSError("access_key or secret_key are not found.")
sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
random.shuffle(sha_msgs)
chunked_sha_msgs = [
sha_msgs[x:x + args.matorage_batch]
for x in range(0, len(sha_msgs), args.matorage_batch)
]
func = partial(jobs, args=args, data_config=data_config)
with Pool(processes=args.num_workers) as pool:
with tqdm(total=len(chunked_sha_msgs)) as pbar:
for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
pbar.update()
barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
start(chunked_sha_msgs[:barrier], train=True)
start(chunked_sha_msgs[barrier:], train=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
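A self-contained sketch (numbers invented, not from this repository) of the split the reworked main() performs: shuffle the (sha, summary) pairs, chunk them by --matorage_batch, and hand the first 1 - p_val share of chunks to the training pass and the rest to the evaluation pass:

    import random

    sha_msgs = [(f"sha{i}", f"msg {i}") for i in range(1000)]   # stand-in commit list
    random.shuffle(sha_msgs)

    matorage_batch, p_val = 32, 0.25
    chunked_sha_msgs = [
        sha_msgs[x:x + matorage_batch]
        for x in range(0, len(sha_msgs), matorage_batch)
    ]                                                           # 32 chunks for 1000 commits
    barrier = int(len(chunked_sha_msgs) * (1 - p_val))          # int(32 * 0.75) == 24
    train_chunks, val_chunks = chunked_sha_msgs[:barrier], chunked_sha_msgs[barrier:]
    print(len(train_chunks), len(val_chunks))                   # 24 8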
@@ -196,6 +218,14 @@ if __name__ == "__main__":
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
args = parser.parse_args()
args.local_path = args.url.split('/')[-1]
......
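Finally, a minimal, hypothetical stand-in (simplified names and values, not the project's code) for the worker fan-out used in start() above: functools.partial freezes the shared arguments so pool.imap_unordered only has to stream each commit chunk to the workers:

    from functools import partial
    from multiprocessing import Pool

    def jobs(sha_msgs, scale=1, train=True):        # simplified stand-in for the real jobs()
        return len(sha_msgs) * scale

    if __name__ == "__main__":
        chunks = [[("sha", "msg")] * 3, [("sha", "msg")] * 2]
        func = partial(jobs, scale=10, train=True)  # like partial(jobs, args=..., data_config=..., train=train)
        with Pool(processes=2) as pool:
            for result in pool.imap_unordered(func, chunks):
                print(result)                        # 30 and 20, in completion order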