Showing 1 changed file with 46 additions and 16 deletions
@@ -15,6 +15,7 @@
 import os
 import re
 import enum
+import random
 import logging
 import argparse
 import numpy as np
@@ -87,8 +88,7 @@ def sha_parse(sha, tokenizer, max_length=1024):
     return (input_ids, attention_masks, patch_ids)

 def message_parse(msg, tokenizer, max_length=56):
-    msg = re.sub(r'#([0-9])+', '', msg)
-    msg = re.sub(r'(\(|)([A-z])+-([0-9])+(\)|)(:|)', '', msg)
+    msg = re.sub(r'(\(|)#([0-9])+(\)|)', '', msg)

     msg = re.sub(r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+', '', msg).strip()
     msg = tokenizer.tokenize(msg)
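Review note: the two message-cleaning substitutions collapse into one. The new pattern strips GitHub issue references such as `#123` or `(#123)`; the old JIRA-style pattern (which also used the over-broad character class `[A-z]`) is dropped entirely, so references like `ABC-123:` now survive. A quick standalone check of the new pattern's behavior (the sample messages are invented):

```python
import re

# New pattern from this diff: '#' plus digits, with optional surrounding parentheses.
pattern = r'(\(|)#([0-9])+(\)|)'

print(re.sub(pattern, '', 'Fix tokenizer crash (#123)'))     # -> 'Fix tokenizer crash '
print(re.sub(pattern, '', 'Merge pull request #45 from x'))  # -> 'Merge pull request  from x'
print(re.sub(pattern, '', 'ABC-123: add logging'))           # unchanged: JIRA refs no longer stripped
```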
@@ -97,7 +97,7 @@ def message_parse(msg, tokenizer, max_length=56):

     return msg

-def jobs(sha_msgs, args, data_config):
+def jobs(sha_msgs, args, data_config, train=True):

     input_ids, attention_masks, patch_ids, targets = [], [], [], []
     data_saver = DataSaver(config=data_config)
@@ -105,11 +105,19 @@ def jobs(sha_msgs, args, data_config):
     for sha_msg in sha_msgs:
         sha, msg = sha_msg

-        source = sha_parse(sha, tokenizer=args.tokenizer)
+        source = sha_parse(
+            sha,
+            tokenizer=args.tokenizer,
+            max_length=args.max_source_length
+        )
         if not source:
             continue
         input_id, attention_mask, patch_id = source
-        target = message_parse(msg, tokenizer=args.tokenizer)
+        target = message_parse(
+            msg,
+            tokenizer=args.tokenizer,
+            max_length=(args.max_target_length if train else args.val_max_target_length),
+        )

         input_ids.append(input_id)
         attention_masks.append(attention_mask)
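Both parse calls now receive an explicit `max_length`, with the target length switching on the new `train` flag. For context, a minimal sketch of the fixed-length shaping these `max_*_length` values imply; this is an illustration assuming pad id 0 and plain Python lists, not the repository's actual `sha_parse`/`message_parse` internals:

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Force a token-id list to exactly max_length entries."""
    token_ids = token_ids[:max_length]                           # truncate anything too long
    return token_ids + [pad_id] * (max_length - len(token_ids))  # pad anything too short

assert pad_or_truncate([5, 6, 7], 5) == [5, 6, 7, 0, 0]
assert pad_or_truncate(list(range(10)), 4) == [0, 1, 2, 3]
```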
@@ -124,9 +132,11 @@ def jobs(sha_msgs, args, data_config):
         })
     data_saver.disconnect()

-def main(args):
-    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
-        raise OSError("access_key or secret_key are not found.")
+def start(chunked_sha_msgs, train=True):
+
+    logger.info(f"Start %s pre-processing" % ("training" if train else "evaluation"))
+
+    max_target_length = args.max_target_length if train else args.val_max_target_length

     data_config = DataConfig(
         endpoint=args.matorage_dir,
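A small note on the new logging line: the `f` prefix is redundant there, since the string contains no `{}` placeholders and is formatted with `%` afterwards. Either style alone does the job:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

train = True
mode = "training" if train else "evaluation"
logger.info("Start %s pre-processing", mode)  # lazy %-style args, idiomatic for logging
logger.info(f"Start {mode} pre-processing")   # or an f-string with a real placeholder
```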
@@ -134,27 +144,39 @@ def main(args):
         secret_key=os.environ['secret_key'],
         dataset_name='commit-autosuggestions',
         additional={
+            "mode" : ("training" if train else "evaluation"),
             "max_source_length": args.max_source_length,
-            "max_target_length": args.max_target_length,
+            "max_target_length": max_target_length,
+            "url" : args.url,
         },
-        attributes = [
+        attributes=[
             ('input_ids', 'int32', (args.max_source_length,)),
             ('attention_masks', 'int32', (args.max_source_length,)),
             ('patch_ids', 'int32', (args.max_source_length,)),
-            ('targets', 'int32', (args.max_target_length,))
+            ('targets', 'int32', (max_target_length,))
         ]
     )

+    func = partial(jobs, args=args, data_config=data_config, train=train)
+    with Pool(processes=args.num_workers) as pool:
+        with tqdm(total=len(chunked_sha_msgs)) as pbar:
+            for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
+                pbar.update()
+
+def main(args):
+    if 'access_key' not in os.environ or 'secret_key' not in os.environ:
+        raise OSError("access_key or secret_key are not found.")
+
     sha_msgs = [(c.hexsha, c.summary) for c in repo.iter_commits()]
+    random.shuffle(sha_msgs)
     chunked_sha_msgs = [
         sha_msgs[x:x + args.matorage_batch]
         for x in range(0, len(sha_msgs), args.matorage_batch)
     ]
-    func = partial(jobs, args=args, data_config=data_config)
-    with Pool(processes=args.num_workers) as pool:
-        with tqdm(total=len(chunked_sha_msgs)) as pbar:
-            for i, _ in tqdm(enumerate(pool.imap_unordered(func, chunked_sha_msgs))):
-                pbar.update()
+
+    barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
+    start(chunked_sha_msgs[:barrier], train=True)
+    start(chunked_sha_msgs[barrier:], train=False)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Code to collect commits on github")
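With the pool moved into `start`, `main` now only shuffles, chunks, and splits: the first `1 - p_val` fraction of chunks is written as training data and the remainder as evaluation data. A self-contained sketch of that split, with toy stand-ins for the `(hexsha, summary)` pairs pulled from `repo.iter_commits()`:

```python
import random

# Toy stand-ins for (commit sha, commit summary) pairs.
sha_msgs = [(f"sha{i}", f"msg {i}") for i in range(10)]
random.shuffle(sha_msgs)  # shuffle first so the split is not one contiguous slice of history

matorage_batch, p_val = 2, 0.25
chunked = [sha_msgs[x:x + matorage_batch] for x in range(0, len(sha_msgs), matorage_batch)]

barrier = int(len(chunked) * (1 - p_val))  # chunk index separating train from validation
train_chunks, val_chunks = chunked[:barrier], chunked[barrier:]
assert (len(train_chunks), len(val_chunks)) == (3, 2)
```

Since `start` fans each chunk out through `pool.imap_unordered`, results come back in completion order, and the `tqdm` bar counts finished chunks rather than sequence position. (The inner `tqdm` around `enumerate` duplicates the outer `pbar`; one of the two would suffice.)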
@@ -196,6 +218,14 @@ if __name__ == "__main__":
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.",
     )
+    parser.add_argument(
+        "--val_max_target_length",
+        default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
     args = parser.parse_args()

     args.local_path = args.url.split('/')[-1]
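A standalone sanity check of the two new flags (a throwaway parser, not the script's full argument set). Note that the help string for `--val_max_target_length` is copied from the input-length flags; it actually bounds the target message length during evaluation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--val_max_target_length", default=142, type=int)
parser.add_argument("--p_val", type=float, default=0.25)

args = parser.parse_args([])                  # defaults: 142 target tokens, 25% validation
assert (args.val_max_target_length, args.p_val) == (142, 0.25)

args = parser.parse_args(["--p_val", "0.1"])  # hold out 10% instead
assert args.p_val == 0.1
```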