graykode

(add) splitting code

......@@ -15,14 +15,14 @@
import os
import re
import json
import random
import jsonlines
import argparse
from tqdm import tqdm
from functools import partial
from collections import defaultdict
from multiprocessing.pool import Pool
from transformers import RobertaTokenizer
from pydriller import GitRepository, RepositoryMining
from pydriller import RepositoryMining
def message_cleaner(message):
msg = message.split("\n")[0]
......@@ -69,6 +69,12 @@ def jobs(repo, args):
}
)
def write_jsonl(lines, path, mode):
    """Append records to ``<path>/<mode>.jsonl``, one JSON object per line.

    The target file is opened in append mode, so records already present in
    the file are kept and the new ones are added after them. The file is
    created (or touched) even when ``lines`` yields no records.

    Args:
        lines: Iterable of JSON-serializable records to write.
        path: Directory that holds the output file.
        mode: Base name of the output file, without the ``.jsonl`` suffix.
            Despite its name this is not a file-open mode.
    """
    saved_path = os.path.join(path, mode)
    # Open the file once for the whole batch instead of re-opening it for
    # every single record; the resulting file content is the same.
    with jsonlines.open(f"{saved_path}.jsonl", mode="a") as writer:
        for line in lines:
            writer.write(line)
def main(args):
repos = set()
with open(args.repositories, encoding="utf-8") as f:
......@@ -85,6 +91,27 @@ def main(args):
for i, _ in tqdm(enumerate(pool.imap_unordered(func, repos))):
pbar.update()
data = []
with open(args.output_file, encoding="utf-8") as f:
for idx, line in enumerate(f):
line = line.strip()
data.append(json.loads(line))
random.shuffle(data)
n_data = len(data)
write_jsonl(
data[:int(n_data * 0.9)],
path=args.output_dir, mode='train'
)
write_jsonl(
data[int(n_data * 0.9):int(n_data * 0.95)],
path=args.output_dir, mode='validation'
)
write_jsonl(
data[int(n_data * 0.95):],
path=args.output_dir, mode='test'
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="")
......