graykode

(add) splitting code

...@@ -15,14 +15,14 @@ ...@@ -15,14 +15,14 @@
15 import os 15 import os
16 import re 16 import re
17 import json 17 import json
18 +import random
18 import jsonlines 19 import jsonlines
19 import argparse 20 import argparse
20 from tqdm import tqdm 21 from tqdm import tqdm
21 from functools import partial 22 from functools import partial
22 -from collections import defaultdict
23 from multiprocessing.pool import Pool 23 from multiprocessing.pool import Pool
24 from transformers import RobertaTokenizer 24 from transformers import RobertaTokenizer
25 -from pydriller import GitRepository, RepositoryMining 25 +from pydriller import RepositoryMining
26 26
27 def message_cleaner(message): 27 def message_cleaner(message):
28 msg = message.split("\n")[0] 28 msg = message.split("\n")[0]
...@@ -69,6 +69,12 @@ def jobs(repo, args): ...@@ -69,6 +69,12 @@ def jobs(repo, args):
69 } 69 }
70 ) 70 )
71 71
def write_jsonl(lines, path, mode):
    """Append records to ``<path>/<mode>.jsonl``, one JSON document per line.

    Args:
        lines: Iterable of JSON-serializable records to write.
        path: Directory in which the output file lives.
        mode: Base name of the output file (callers pass the split name,
            e.g. ``'train'``). This is NOT a file-open mode; the file is
            always opened for appending, so existing content is kept.
    """
    saved_path = os.path.join(path, mode)
    # Open the file once for the whole batch instead of re-opening it for
    # every record; append mode also creates the file if it does not exist.
    with jsonlines.open(f"{saved_path}.jsonl", mode="a") as writer:
        writer.write_all(lines)
77 +
72 def main(args): 78 def main(args):
73 repos = set() 79 repos = set()
74 with open(args.repositories, encoding="utf-8") as f: 80 with open(args.repositories, encoding="utf-8") as f:
...@@ -85,6 +91,27 @@ def main(args): ...@@ -85,6 +91,27 @@ def main(args):
85 for i, _ in tqdm(enumerate(pool.imap_unordered(func, repos))): 91 for i, _ in tqdm(enumerate(pool.imap_unordered(func, repos))):
86 pbar.update() 92 pbar.update()
87 93
94 + data = []
95 + with open(args.output_file, encoding="utf-8") as f:
96 + for idx, line in enumerate(f):
97 + line = line.strip()
98 + data.append(json.loads(line))
99 +
100 + random.shuffle(data)
101 + n_data = len(data)
102 + write_jsonl(
103 + data[:int(n_data * 0.9)],
104 + path=args.output_dir, mode='train'
105 + )
106 + write_jsonl(
107 + data[int(n_data * 0.9):int(n_data * 0.95)],
108 + path=args.output_dir, mode='validation'
109 + )
110 + write_jsonl(
111 + data[int(n_data * 0.95):],
112 + path=args.output_dir, mode='test'
113 + )
114 +
88 115
89 if __name__ == "__main__": 116 if __name__ == "__main__":
90 parser = argparse.ArgumentParser(description="") 117 parser = argparse.ArgumentParser(description="")
......