graykode

(refactor) create diff_parse function, (add) tokenizer arguments

# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gitcommit import diff_parse
# Names exported by `from <package> import *`.
__all__ = ["diff_parse"]
\ No newline at end of file
......@@ -17,6 +17,7 @@ import re
import enum
import random
import logging
import tempfile
import argparse
import numpy as np
from tqdm import *
......@@ -62,10 +63,9 @@ def encode_line(tokenizer, line, patch):
len(tokens) * [patch.value]
)
def sha_parse(sha, tokenizer, max_length=1024):
def diff_parse(diff, tokenizer):
chunks = []
for diff in whatthepatch.parse_patch(repo.git.show(sha)):
for diff in whatthepatch.parse_patch(diff):
if diff.header.old_path != diff.header.new_path:
chunks.append(encode_line(tokenizer, diff.header.old_path, PATCH.MINUS))
chunks.append(encode_line(tokenizer, diff.header.new_path, PATCH.PLUS))
......@@ -76,7 +76,11 @@ def sha_parse(sha, tokenizer, max_length=1024):
chunks.append(encode_line(tokenizer, change.line, PATCH.PLUS))
elif change.old != None and change.new == None:
chunks.append(encode_line(tokenizer, change.line, PATCH.MINUS))
return chunks
def sha_parse(sha, tokenizer, max_length=1024):
chunks = diff_parse(diff=repo.git.show(sha), tokenizer=tokenizer)
if not chunks:
return None
......@@ -202,10 +206,16 @@ if __name__ == "__main__":
help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
)
parser.add_argument(
"--tokenizer_name",
default='sshleifer/distilbart-xsum-6-6',
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--matorage_batch",
default=1024,
type=int,
help='batch size to store data.'
help='The smallest batch size stored atomically in matorage.'
)
parser.add_argument(
"--num_workers",
......@@ -246,6 +256,6 @@ if __name__ == "__main__":
if os.path.exists(args.local_path)
else Repo.clone_from(args.url, to_path=args.local_path, branch="master")
)
args.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
args.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
main(args)
......