graykode

(add) split into train and test

@@ -15,6 +15,7 @@
 import os
 import re
 import enum
+import random
 import logging
 import argparse
 import numpy as np
@@ -87,8 +88,7 @@ def sha_parse(sha, tokenizer, max_length=1024):
     return (input_ids, attention_masks, patch_ids)
 
 def message_parse(msg, tokenizer, max_length=56):
-    msg = re.sub(r'#([0-9])+', '', msg)
-    msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
+    msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)
 
     msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
     msg = tokenizer.tokenize(msg)
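The new pattern also consumes the parentheses around an issue reference, so "(#123)" no longer leaves an empty "()" behind; the second substitution, which stripped Jira-style IDs like "ABC-123:", is dropped entirely. A quick sketch of the difference (the example message is hypothetical):

```python
import re

msg = "Fix tokenizer crash (#123)"

# Old pattern: removes only "#123", leaving empty parentheses behind.
print(re.sub(r'#([0-9])+', '', msg))            # -> "Fix tokenizer crash ()"

# New pattern: optionally consumes the surrounding parentheses as well.
print(re.sub(r'(\(|)#([0-9])+(\)|)', '', msg))  # -> "Fix tokenizer crash "
```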
@@ -97,7 +97,7 @@ def message_parse(msg, tokenizer, max_length=56):
 
     return msg
 
-def jobs(sha_msgs, args, data_config):
+def jobs(sha_msgs, args, data_config, train=True):
 
     input_ids, attention_masks, patch_ids, targets = [], [], [], []
     data_saver = DataSaver(config=data_config)
@@ -105,11 +105,19 @@ def jobs(sha_msgs, args, data_config):
     for sha_msg in sha_msgs:
         sha, msg = sha_msg
 
-        source = sha_parse(sha, tokenizer=args.tokenizer)
+        source = sha_parse(
+            sha,
+            tokenizer=args.tokenizer,
+            max_length=args.max_source_length
+        )
         if not source:
             continue
         input_id, attention_mask, patch_id = source
-        target = message_parse(msg, tokenizer=args.tokenizer)
+        target = message_parse(
+            msg,
+            tokenizer=args.tokenizer,
+            max_length=(args.max_target_length if train else args.val_max_target_length),
+        )
 
         input_ids.append(input_id)
         attention_masks.append(attention_mask)
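Worth noting: the old calls relied on the parse functions' signature defaults (max_length=1024 for sha_parse, 56 for message_parse), so the length flags appear to have had no effect on parsing even though the matorage attributes below were sized from them. Passing max_length explicitly closes that gap, and the train flag selects the validation cap for the eval split. A toy illustration of the old failure mode (stand-in function, hypothetical values):

```python
# Stand-in with the same default cap as message_parse(max_length=56).
def parse(tokens, max_length=56):
    return tokens[:max_length]

tokens = ["tok"] * 200
print(len(parse(tokens)))                  # 56: a larger --max_target_length was ignored
print(len(parse(tokens, max_length=128)))  # 128: explicit keyword, as in the new code
```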
@@ -124,9 +132,11 @@ def jobs(sha_msgs, args, data_config):
         })
     data_saver.disconnect()
 
-def main(args):
-    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
-        raise OSError("access_key or secret_key are not found.")
+def start(chunked_sha_msgs, train=True):
+
+    logger.info("Start %s pre-processing" % ("training" if train else "evaluation"))
+
+    max_target_length = args.max_target_length if train else args.val_max_target_length
 
     data_config = DataConfig(
         endpoint=args.matorage_dir,
@@ -134,27 +144,39 @@
         secret_key=os.environ['secret_key'],
         dataset_name='commit-autosuggestions',
         additional={
+            "mode" : ("training" if train else "evaluation"),
             "max_source_length": args.max_source_length,
-            "max_target_length": args.max_target_length,
+            "max_target_length": max_target_length,
+            "url" : args.url,
         },
-        attributes = [
+        attributes=[
             ('input_ids', 'int32', (args.max_source_length,)),
             ('attention_masks', 'int32', (args.max_source_length,)),
             ('patch_ids', 'int32', (args.max_source_length,)),
-            ('targets', 'int32', (args.max_target_length,))
+            ('targets', 'int32', (max_target_length,))
         ]
     )
 
+    func = partial(jobs, args=args, data_config=data_config, train=train)
+    with Pool(processes=args.num_workers) as pool:
+        with tqdm(total=len(chunked_sha_msgs)) as pbar:
+            for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
+                pbar.update()
+
+def main(args):
+    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
+        raise OSError("access_key or secret_key are not found.")
+
     sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
+    random.shuffle(sha_msgs)
     chunked_sha_msgs = [
         sha_msgs[x:x + args.matorage_batch]
         for x in range(0, len(sha_msgs), args.matorage_batch)
     ]
-    func = partial(jobs, args=args, data_config=data_config)
-    with Pool(processes=args.num_workers) as pool:
-        with tqdm(total=len(chunked_sha_msgs)) as pbar:
-            for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
-                pbar.update()
+
+    barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
+    start(chunked_sha_msgs[:barrier], train=True)
+    start(chunked_sha_msgs[barrier:], train=False)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
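For intuition, here is the shuffle/chunk/split flow on toy data. Note the barrier is applied to the list of chunks, not to individual commits, so the realized validation fraction is rounded to a multiple of matorage_batch commits (a sketch, not from the commit):

```python
import random

sha_msgs = list(range(20))   # stand-in for the (sha, summary) pairs
random.shuffle(sha_msgs)     # shuffle before chunking so the split is random

batch = 4                    # stand-in for args.matorage_batch
chunked = [sha_msgs[x:x + batch] for x in range(0, len(sha_msgs), batch)]

p_val = 0.25
barrier = int(len(chunked) * (1 - p_val))    # int(5 * 0.75) = 3
train, val = chunked[:barrier], chunked[barrier:]
print(len(train), len(val))                  # 3 chunks for training, 2 for validation
```

With only five chunks, validation actually receives 40% of the data rather than 25%; the approximation tightens as the number of chunks grows.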
@@ -196,6 +218,14 @@ if __name__ == "__main__":
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
+    parser.add_argument(
+        "--val_max_target_length",
+        default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
+        type=int,
+        help="The maximum total target sequence length after tokenization for validation. "
+        "Sequences longer than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument("--p_val", type=float, default=0.25, help="fraction of the dataset to reserve for validation")
     args = parser.parse_args()
 
     args.local_path = args.url.split('/')[-1]
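Assuming the entry script and remaining flags keep the names visible elsewhere in the diff (the script name and repository URL below are placeholders, and other existing flags are omitted), a run exercising the new split options might look like:

```bash
export access_key=<minio-access-key>
export secret_key=<minio-secret-key>
python gitparser.py \
    --url https://github.com/<user>/<repo> \
    --p_val 0.25 \
    --val_max_target_length 142
```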