Showing
1 changed file
with
29 additions
and
2 deletions
... | @@ -15,14 +15,14 @@ | ... | @@ -15,14 +15,14 @@ |
15 | import os | 15 | import os |
16 | import re | 16 | import re |
17 | import json | 17 | import json |
18 | +import random | ||
18 | import jsonlines | 19 | import jsonlines |
19 | import argparse | 20 | import argparse |
20 | from tqdm import tqdm | 21 | from tqdm import tqdm |
21 | from functools import partial | 22 | from functools import partial |
22 | -from collections import defaultdict | ||
23 | from multiprocessing.pool import Pool | 23 | from multiprocessing.pool import Pool |
24 | from transformers import RobertaTokenizer | 24 | from transformers import RobertaTokenizer |
25 | -from pydriller import GitRepository, RepositoryMining | 25 | +from pydriller import RepositoryMining |
26 | 26 | ||
27 | def message_cleaner(message): | 27 | def message_cleaner(message): |
28 | msg = message.split("\n")[0] | 28 | msg = message.split("\n")[0] |
... | @@ -69,6 +69,12 @@ def jobs(repo, args): | ... | @@ -69,6 +69,12 @@ def jobs(repo, args): |
69 | } | 69 | } |
70 | ) | 70 | ) |
71 | 71 | ||
def write_jsonl(lines, path, mode):
    """Append records to ``<path>/<mode>.jsonl``, one JSON document per line.

    Args:
        lines: Iterable of JSON-serialisable objects to write.
        path: Directory that holds the output file. It is created when it
            does not exist yet. An empty string means the current directory.
        mode: Base name of the output file without extension (the caller
            passes the split name, e.g. ``'train'``). This is NOT the file
            open mode; the file is always opened for appending.

    The file is opened a single time for the whole batch instead of once
    per record. As a consequence the file is created (empty) even when
    ``lines`` yields nothing.
    """
    # os.makedirs("") raises, so only create the directory for a real path.
    if path:
        os.makedirs(path, exist_ok=True)
    saved_path = os.path.join(path, mode)
    # Append mode is kept on purpose: an existing file is extended, never
    # truncated, so repeated runs accumulate records.
    with jsonlines.open(f"{saved_path}.jsonl", mode="a") as writer:
        for line in lines:
            writer.write(line)
77 | + | ||
72 | def main(args): | 78 | def main(args): |
73 | repos = set() | 79 | repos = set() |
74 | with open(args.repositories, encoding="utf-8") as f: | 80 | with open(args.repositories, encoding="utf-8") as f: |
... | @@ -85,6 +91,27 @@ def main(args): | ... | @@ -85,6 +91,27 @@ def main(args): |
85 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, repos))): | 91 | for i, _ in tqdm(enumerate(pool.imap_unordered(func, repos))): |
86 | pbar.update() | 92 | pbar.update() |
87 | 93 | ||
94 | + data = [] | ||
95 | + with open(args.output_file, encoding="utf-8") as f: | ||
96 | + for idx, line in enumerate(f): | ||
97 | + line = line.strip() | ||
98 | + data.append(json.loads(line)) | ||
99 | + | ||
100 | + random.shuffle(data) | ||
101 | + n_data = len(data) | ||
102 | + write_jsonl( | ||
103 | + data[:int(n_data * 0.9)], | ||
104 | + path=args.output_dir, mode='train' | ||
105 | + ) | ||
106 | + write_jsonl( | ||
107 | + data[int(n_data * 0.9):int(n_data * 0.95)], | ||
108 | + path=args.output_dir, mode='validation' | ||
109 | + ) | ||
110 | + write_jsonl( | ||
111 | + data[int(n_data * 0.95):], | ||
112 | + path=args.output_dir, mode='test' | ||
113 | + ) | ||
114 | + | ||
88 | 115 | ||
89 | if __name__ == "__main__": | 116 | if __name__ == "__main__": |
90 | parser = argparse.ArgumentParser(description="") | 117 | parser = argparse.ArgumentParser(description="") | ... | ... |
-
Please register or log in to post a comment