Showing 1 changed file with 19 additions and 15 deletions
... | @@ -20,18 +20,25 @@ from transformers import AutoTokenizer | ... | @@ -20,18 +20,25 @@ from transformers import AutoTokenizer |
20 | from preprocess import diff_parse, truncate | 20 | from preprocess import diff_parse, truncate |
21 | from train import BartForConditionalGeneration | 21 | from train import BartForConditionalGeneration |
22 | 22 | ||
23 | +def get_length(chunks): | ||
24 | + cnt = 0 | ||
25 | + for chunk in chunks: | ||
26 | + cnt += len(chunk) | ||
27 | + return cnt | ||
28 | + | ||
29 | +def suggester(chunks, model, tokenizer, device): | ||
30 | + max_source_length = get_length(chunks) | ||
23 | 31 | ||
24 | -def suggester(chunks, max_source_length, model, tokenizer, device): | ||
25 | input_ids, attention_masks, patch_ids = zip(*chunks) | 32 | input_ids, attention_masks, patch_ids = zip(*chunks) |
26 | - input_ids = torch.LongTensor([truncate(input_ids, max_source_length, value=0)]).to( | 33 | + input_ids = torch.LongTensor( |
27 | - device | 34 | + [truncate(input_ids, max_source_length, value=0)] |
28 | - ) | 35 | + ).to(device) |
29 | attention_masks = torch.LongTensor( | 36 | attention_masks = torch.LongTensor( |
30 | [truncate(attention_masks, max_source_length, value=1)] | 37 | [truncate(attention_masks, max_source_length, value=1)] |
31 | ).to(device) | 38 | ).to(device) |
32 | - patch_ids = torch.LongTensor([truncate(patch_ids, max_source_length, value=0)]).to( | 39 | + patch_ids = torch.LongTensor( |
33 | - device | 40 | + [truncate(patch_ids, max_source_length, value=0)] |
34 | - ) | 41 | + ).to(device) |
35 | 42 | ||
36 | summaries = model.generate( | 43 | summaries = model.generate( |
37 | input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks | 44 | input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks |
... | @@ -59,9 +66,13 @@ def main(args): | ... | @@ -59,9 +66,13 @@ def main(args): |
59 | staged_files = [f.strip() for f in staged_files] | 66 | staged_files = [f.strip() for f in staged_files] |
60 | chunks = "\n".join(staged_files) | 67 | chunks = "\n".join(staged_files) |
61 | 68 | ||
69 | + chunks = diff_parse(chunks, tokenizer) | ||
70 | + if not chunks: | ||
71 | + print('There is no file in staged state.') | ||
72 | + return | ||
73 | + | ||
62 | commit_message = suggester( | 74 | commit_message = suggester( |
63 | chunks, | 75 | chunks, |
64 | - max_source_length=args.max_source_length, | ||
65 | model=model, | 76 | model=model, |
66 | tokenizer=tokenizer, | 77 | tokenizer=tokenizer, |
67 | device=device, | 78 | device=device, |
... | @@ -89,13 +100,6 @@ if __name__ == "__main__": | ... | @@ -89,13 +100,6 @@ if __name__ == "__main__": |
89 | type=str, | 100 | type=str, |
90 | help="Pretrained tokenizer name or path if not the same as model_name", | 101 | help="Pretrained tokenizer name or path if not the same as model_name", |
91 | ) | 102 | ) |
92 | - parser.add_argument( | ||
93 | - "--max_source_length", | ||
94 | - default=1024, | ||
95 | - type=int, | ||
96 | - help="The maximum total input sequence length after tokenization. Sequences longer " | ||
97 | - "than this will be truncated, sequences shorter will be padded.", | ||
98 | - ) | ||
99 | args = parser.parse_args() | 103 | args = parser.parse_args() |
100 | 104 | ||
101 | main(args) | 105 | main(args) | ... | ... |
Please register or log in to post a comment.