graykode

(refactor) rename the --matorage_dir argument to --endpoint; add --region, --do_train and --do_predict arguments

......@@ -139,9 +139,10 @@ def start(chunked_sha_msgs, train=True):
max_target_length = args.max_target_length if train else args.val_max_target_length
data_config = DataConfig(
endpoint=args.matorage_dir,
endpoint=args.endpoint,
access_key=os.environ['access_key'],
secret_key=os.environ['secret_key'],
region=args.region,
dataset_name='commit-autosuggestions',
additional={
"mode" : ("training" if train else "evaluation"),
......@@ -175,8 +176,10 @@ def main(args):
]
barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
start(chunked_sha_msgs[:barrier], train=True)
start(chunked_sha_msgs[barrier:], train=False)
if args.do_train:
start(chunked_sha_msgs[:barrier], train=True)
if args.do_predict:
start(chunked_sha_msgs[barrier:], train=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Code to collect commits on github")
......@@ -187,10 +190,16 @@ if __name__ == "__main__":
help="github url"
)
parser.add_argument(
"--matorage_dir",
"--endpoint",
type=str,
required=True,
help='matorage saved directory.'
help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
)
parser.add_argument(
"--region",
type=str,
default=None,
help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
)
parser.add_argument(
"--matorage_batch",
......@@ -226,6 +235,8 @@ if __name__ == "__main__":
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
parser.add_argument("--do_train", action="store_true", default=False)
parser.add_argument("--do_predict", action="store_true", default=False)
args = parser.parse_args()
args.local_path = args.url.split('/')[-1]
......