graykode

(fixed) remove max_source_length, add diff_parse function

@@ -20,18 +20,25 @@ from transformers import AutoTokenizer
 from preprocess import diff_parse, truncate
 from train import BartForConditionalGeneration
 
+def get_length(chunks):
+    cnt = 0
+    for chunk in chunks:
+        cnt += len(chunk)
+    return cnt
+
+def suggester(chunks, model, tokenizer, device):
+    max_source_length = get_length(chunks)
 
-def suggester(chunks, max_source_length, model, tokenizer, device):
     input_ids, attention_masks, patch_ids = zip(*chunks)
-    input_ids = torch.LongTensor([truncate(input_ids, max_source_length, value=0)]).to(
-        device
-    )
+    input_ids = torch.LongTensor(
+        [truncate(input_ids, max_source_length, value=0)]
+    ).to(device)
     attention_masks = torch.LongTensor(
         [truncate(attention_masks, max_source_length, value=1)]
     ).to(device)
-    patch_ids = torch.LongTensor([truncate(patch_ids, max_source_length, value=0)]).to(
-        device
-    )
+    patch_ids = torch.LongTensor(
+        [truncate(patch_ids, max_source_length, value=0)]
+    ).to(device)
 
     summaries = model.generate(
         input_ids=input_ids, patch_ids=patch_ids, attention_mask=attention_masks
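For context, a toy check (not part of the commit) of what the new get_length helper computes: it simply sums the lengths of its chunk sequences, so the truncation length now tracks each input rather than a fixed command-line value. The sample chunks below are hypothetical.

    def get_length(chunks):  # same helper as added in the hunk above
        cnt = 0
        for chunk in chunks:
            cnt += len(chunk)
        return cnt

    chunks = [[7, 12, 99], [4, 5]]  # hypothetical token-id chunks
    assert get_length(chunks) == 5  # 3 + 2, i.e. sum(len(c) for c in chunks)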
@@ -59,9 +66,13 @@ def main(args):
     staged_files = [f.strip() for f in staged_files]
     chunks = "\n".join(staged_files)
 
+    chunks = diff_parse(chunks, tokenizer)
+    if not chunks:
+        print("There is no file in the staged state.")
+        return
+
     commit_message = suggester(
         chunks,
-        max_source_length=args.max_source_length,
         model=model,
         tokenizer=tokenizer,
         device=device,
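A minimal sketch of the new early-exit guard; fake_diff_parse is a hypothetical stand-in for preprocess.diff_parse, assuming only that it returns an empty sequence when nothing is staged:

    def fake_diff_parse(diff_text, tokenizer=None):
        # Stand-in only: the real diff_parse tokenizes the diff; this
        # just models "no staged changes in, no chunks out".
        return [] if not diff_text.strip() else [diff_text]

    chunks = fake_diff_parse("")  # e.g. `git diff --cached` printed nothing
    if not chunks:
        print("There is no file in the staged state.")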
@@ -89,13 +100,6 @@ if __name__ == "__main__":
         type=str,
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
-    parser.add_argument(
-        "--max_source_length",
-        default=1024,
-        type=int,
-        help="The maximum total input sequence length after tokenization. Sequences longer "
-        "than this will be truncated, sequences shorter will be padded.",
-    )
     args = parser.parse_args()
 
     main(args)
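The call-site change in summary (names as in the script above; illustrative, not runnable on its own):

    # Before: a fixed CLI value was threaded through to suggester.
    # commit_message = suggester(
    #     chunks, max_source_length=args.max_source_length,
    #     model=model, tokenizer=tokenizer, device=device,
    # )

    # After: the length is derived from the parsed chunks inside suggester.
    commit_message = suggester(
        chunks, model=model, tokenizer=tokenizer, device=device
    )

Since argparse rejects unknown flags, any existing invocation that still passes --max_source_length will now fail with an "unrecognized arguments" error.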